bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
torch_distribution.h
Go to the documentation of this file.
1#ifndef TORCH_DISTRIBUTION_H
2#define TORCH_DISTRIBUTION_H
3
4#include "cuberl/base/cubeai_config.h"
5
6#ifdef USE_PYTORCH
7
9#include "torch/torch.h"
10
11#include <vector>
12
13namespace cuberl {
14namespace maths {
15namespace stats {
16
21class TorchDistributionBase
22{
23
24public:
25
29 virtual ~TorchDistributionBase() = default;
30
34 virtual torch_tensor_t entropy() const = 0;
35
41 virtual torch_tensor_t log_prob(torch_tensor_t value) = 0;
42
46 virtual torch_tensor_t sample(c10::ArrayRef<int64_t> sample_shape) = 0;
47
48protected:
49
53 TorchDistributionBase() = default;
54
58 std::vector<int64_t> batch_shape_;
59
63 std::vector<int64_t> event_shape_;
64
70 std::vector<int64_t> extended_shape(c10::ArrayRef<int64_t> sample_shape);
71
72};
73
74}
75}
76}
77#endif
78#endif // TORCH_DISTRIBUTION_H
Various utilities used when working with RL problems.
Definition cuberl_types.h:16