1#ifndef TORCH_CATEGORICAL_H
2#define TORCH_CATEGORICAL_H
8#include "cuberl/base/cubeai_config.h"
15#include <torch/torch.h>
22 class TorchCategorical final :
public TorchDistributionBase
32 TorchCategorical() =
default;
40 TorchCategorical(torch_tensor_t probs,
bool do_build_from_logits=
false);
45 ~TorchCategorical()
override =
default;
51 virtual torch_tensor_t entropy()
const override;
58 virtual torch_tensor_t log_prob(torch_tensor_t value)
override;
63 virtual torch_tensor_t sample(c10::ArrayRef<int64_t> sample_shape = {})
override;
68 void build_from_logits(torch_tensor_t logits);
73 void build_from_probabilities(torch_tensor_t probs);
78 torch_tensor_t get_logits()
const {
return logits_; }
84 torch_tensor_t
get_probs()
const {
return probs_; }
87 torch_tensor_t probs_;
88 torch_tensor_t logits_;
89 torch_tensor_t param_;
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
std::vector< real_t > get_probs(uint_t n)
Definition rl_example_3.cpp:82