bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
torch_bernoulli_dist.h
Go to the documentation of this file.
1#ifndef TORCH_BERNOULLI_DIST_H
2#define TORCH_BERNOULLI_DIST_H
3
8
9#include "cuberl/base/cubeai_config.h"
10
11#ifdef USE_PYTORCH
12
15
16#include <torch/torch.h>
17
18namespace cuberl {
19namespace maths {
20namespace stats {
21
25class TorchBernoulliDist final : public TorchDistributionBase
26{
27public:
28
33 TorchBernoulliDist() = default;
34
42 TorchBernoulliDist(torch_tensor_t probs, bool do_build_from_logits=false);
43
47 virtual ~TorchBernoulliDist() = default;
48
53 virtual torch_tensor_t entropy() const override;
54
60 virtual torch_tensor_t log_prob(torch_tensor_t value) override;
61
65 virtual torch_tensor_t sample(c10::ArrayRef<int64_t> sample_shape = {})override;
66
67
71 void build_from_logits(torch_tensor_t logits);
72
76 void build_from_probabilities(torch_tensor_t probs);
77
81 torch_tensor_t get_logits()const { return logits_; }
82
87 torch_tensor_t get_probs()const { return probs_; }
88
89private:
90 torch_tensor_t probs_;
91 torch_tensor_t logits_;
92 torch_tensor_t param_;
93 int num_events_;
94
95
96};
97
98}
99}
100}
101
102#endif // USE_PYTORCH
103#endif // TORCH_BERNOULLI_DIST_H
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
std::vector< real_t > get_probs(uint_t n)
Definition rl_example_3.cpp:82