1#ifndef TORCH_BERNOULLI_DIST_H
2#define TORCH_BERNOULLI_DIST_H
9#include "cuberl/base/cubeai_config.h"
16#include <torch/torch.h>
// TorchBernoulliDist: a final class publicly derived from TorchDistributionBase.
// It declares overrides for entropy(), log_prob() and sample(), two build_*
// methods, and two inline getters, and it stores three tensors
// (probs_, logits_, param_).
// NOTE(review): judging by the name this models a Bernoulli distribution over
// torch tensors; the member definitions are not visible in this header, so the
// exact maths should be confirmed against the implementation file.
// NOTE(review): the class braces and access specifiers are not visible in this
// listing -- confirm which members are public/private in the real header.
25class TorchBernoulliDist final :
public TorchDistributionBase
// Default constructor (compiler-generated). No distribution parameters are
// supplied here; presumably one of the build_from_* methods is expected to be
// called afterwards -- TODO confirm.
33 TorchBernoulliDist() =
default;
// Constructs from a tensor argument named `probs`. `do_build_from_logits`
// defaults to false.
// NOTE(review): presumably the flag selects build_from_logits(probs) instead
// of build_from_probabilities(probs), i.e. the tensor is then interpreted as
// logits despite the parameter name -- definition not visible, confirm.
// NOTE(review): the constructor is callable with a single tensor and is not
// marked `explicit`, so implicit conversion from torch_tensor_t is allowed.
42 TorchBernoulliDist(torch_tensor_t probs,
bool do_build_from_logits=
false);
// Virtual destructor, compiler-generated.
47 virtual ~TorchBernoulliDist() =
default;
// Overrides the base-class entropy(); const, returns a tensor.
// Definition is not in this header.
53 virtual torch_tensor_t entropy()
const override;
// Overrides the base-class log_prob(); takes the tensor `value` by value and
// returns a tensor. Note that, unlike entropy(), this override is non-const.
60 virtual torch_tensor_t log_prob(torch_tensor_t value)
override;
// Overrides the base-class sample(); `sample_shape` defaults to an empty
// ArrayRef. Non-const. Returns a tensor.
// NOTE(review): how an empty sample_shape is interpreted is decided by the
// definition, which is not visible here.
65 virtual torch_tensor_t sample(c10::ArrayRef<int64_t> sample_shape = {})
override;
// (Re)initialises the distribution from a tensor of logits.
// NOTE(review): presumably fills logits_, probs_ and param_ -- confirm in the
// implementation file.
71 void build_from_logits(torch_tensor_t logits);
// (Re)initialises the distribution from a tensor of probabilities.
// NOTE(review): presumably fills probs_, logits_ and param_ -- confirm in the
// implementation file.
76 void build_from_probabilities(torch_tensor_t probs);
// Read accessor: returns the stored logits_ member by value.
81 torch_tensor_t get_logits()
const {
return logits_; }
// Read accessor: returns the stored probs_ member by value.
87 torch_tensor_t
get_probs()
const {
return probs_; }
// Tensor returned by get_probs().
90 torch_tensor_t probs_;
// Tensor returned by get_logits().
91 torch_tensor_t logits_;
// Additional tensor member; it is not read or written anywhere in this header.
// NOTE(review): presumably aliases whichever of probs_/logits_ the object was
// built from -- confirm against the implementation.
92 torch_tensor_t param_;
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
std::vector< real_t > get_probs(uint_t n)
Definition rl_example_3.cpp:82