8#include "cuberl/base/cubeai_config.h"
15#include <torch/torch.h>
19 namespace maths::stats
24 class TorchNormalDist final :
public TorchDistributionBase
33 explicit TorchNormalDist(torch_tensor_t mu, torch_tensor_t sigma);
38 torch_tensor_t sample()
const;
44 ~TorchNormalDist()
override =
default;
51 torch_tensor_t log_prob(torch_tensor_t value)
override;
56 torch_tensor_t entropy()
const override;
61 torch_tensor_t
mean()
const {
return mean_;}
66 torch_tensor_t std()
const {
return sd_;}
76 virtual torch_tensor_t sample(c10::ArrayRef<int64_t> ) {
return torch_tensor_t(); }
real_t mean(IteratorType begin, IteratorType end, bool parallel=true)
mean Compute the mean value of the values in the provided iterator range
Definition vector_math.h:126
Various utilities used when working with RL problems.
Definition cuberl_types.h:16