bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
torch_normal.h
Go to the documentation of this file.
1//
2// Created by alex on 7/26/25.
3//
4
5#ifndef TORCH_NORMAL_H
6#define TORCH_NORMAL_H
7
8#include "cuberl/base/cubeai_config.h"
9
10#ifdef USE_PYTORCH
11
14
15#include <torch/torch.h>
16
17namespace cuberl
18{
19 namespace maths::stats
20 {
24 class TorchNormalDist final : public TorchDistributionBase
25 {
26
27 public:
28
33 explicit TorchNormalDist(torch_tensor_t mu, torch_tensor_t sigma);
34
38 torch_tensor_t sample()const;
39
40
44 ~TorchNormalDist() override = default;
45
51 torch_tensor_t log_prob(torch_tensor_t value) override;
52
56 torch_tensor_t entropy() const override;
57
61 torch_tensor_t mean()const {return mean_;}
62
66 torch_tensor_t std()const {return sd_;}
67 private:
68
69 torch_tensor_t mean_;
70 torch_tensor_t sd_;
71
76 virtual torch_tensor_t sample(c10::ArrayRef<int64_t> /*sample_shape*/) { return torch_tensor_t(); }
77 };
78 }
79}
80
81#endif
82
83#endif //TORCH_NORMAL_H
real_t mean(IteratorType begin, IteratorType end, bool parallel=true)
mean: Computes the mean of the values in the provided iterator range.
Definition vector_math.h:126
Various utilities used when working with RL problems.
Definition cuberl_types.h:16