1#ifndef TORCH_DISTRIBUTION_H
2#define TORCH_DISTRIBUTION_H
4#include "cuberl/base/cubeai_config.h"
9#include "torch/torch.h"
21class TorchDistributionBase
29 virtual ~TorchDistributionBase() =
default;
34 virtual torch_tensor_t entropy()
const = 0;
41 virtual torch_tensor_t log_prob(torch_tensor_t value) = 0;
46 virtual torch_tensor_t sample(c10::ArrayRef<int64_t> sample_shape) = 0;
53 TorchDistributionBase() =
default;
58 std::vector<int64_t> batch_shape_;
63 std::vector<int64_t> event_shape_;
70 std::vector<int64_t> extended_shape(c10::ArrayRef<int64_t> sample_shape);
Various utilities used when working with RL problems.
Definition cuberl_types.h:16