bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
torch_categorical.h
Go to the documentation of this file.
1#ifndef TORCH_CATEGORICAL_H
2#define TORCH_CATEGORICAL_H
8#include "cuberl/base/cubeai_config.h"
9
10#ifdef USE_PYTORCH
11
14
15#include <torch/torch.h>
16
17
18namespace cuberl {
19namespace maths::stats
20{
21
22 class TorchCategorical final : public TorchDistributionBase
23 {
24
25
26 public:
27
32 TorchCategorical() = default;
33
40 TorchCategorical(torch_tensor_t probs, bool do_build_from_logits=false);
41
45 ~TorchCategorical() override = default;
46
51 virtual torch_tensor_t entropy() const override;
52
58 virtual torch_tensor_t log_prob(torch_tensor_t value) override;
59
63 virtual torch_tensor_t sample(c10::ArrayRef<int64_t> sample_shape = {})override;
64
68 void build_from_logits(torch_tensor_t logits);
69
73 void build_from_probabilities(torch_tensor_t probs);
74
78 torch_tensor_t get_logits()const { return logits_; }
79
84 torch_tensor_t get_probs()const { return probs_; }
85
86 private:
87 torch_tensor_t probs_;
88 torch_tensor_t logits_;
89 torch_tensor_t param_;
90 int num_events_;
91
92 };
93}
94}
95#endif
96#endif
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
std::vector< real_t > get_probs(uint_t n)
Definition rl_example_3.cpp:82