1#ifndef TORCH_STATE_ADAPTOR_H
2#define TORCH_STATE_ADAPTOR_H
4#include "cuberl/base/cubeai_config.h"
10#include <torch/torch.h>
/// Tensor type returned by to_torch(), stack(), cat() and stack_as_float();
/// an alias of torch_tensor_t.
24 typedef torch_tensor_t value_type;
/// @brief Build a tensor (value_type) from a std::vector of real_t.
///        Presumably copies `data` into a new tensor; the definition is not
///        visible in this listing, so TODO confirm.
/// @param data          Source values.
/// @param dtype         Device to place the tensor on (default DeviceType::CPU).
///                      NOTE(review): despite its name this parameter is a
///                      DeviceType, not a scalar dtype; stack()/cat() name the
///                      same parameter `type`. Consider a consistent name.
/// @param requires_grad Presumably forwarded to the tensor's autograd flag
///                      (default false). TODO confirm.
/// @return The resulting value_type (torch_tensor_t).
26 static value_type to_torch(
const std::vector<real_t>& data,
27 DeviceType dtype=DeviceType::CPU,
bool requires_grad=
false);
/// @brief Overload of to_torch() for std::vector<uint_t>; same parameters.
29 static value_type to_torch(
const std::vector<uint_t>& data,
30 DeviceType dtype=DeviceType::CPU,
bool requires_grad=
false);
/// @brief Overload of to_torch() for std::vector<float_t>; same parameters.
32 static value_type to_torch(
const std::vector<float_t>& data,
33 DeviceType dtype=DeviceType::CPU,
bool requires_grad=
false);
/// @brief Overload of to_torch() for std::vector<int_t>; same parameters.
34 static value_type to_torch(
const std::vector<int_t>& data,
35 DeviceType dtype=DeviceType::CPU,
bool requires_grad=
false);
/// @brief Overload of to_torch() for std::vector<lint_t>; same parameters.
36 static value_type to_torch(
const std::vector<lint_t>& data,
37 DeviceType dtype=DeviceType::CPU,
bool requires_grad=
false);
/// @brief Overload of to_torch() for std::vector<bool>; same parameters.
/// NOTE(review): std::vector<bool> is a packed-bit proxy container with no
/// contiguous data(), so the definition presumably copies element by element.
/// Verify it does not assume contiguous storage.
38 static value_type to_torch(
const std::vector<bool>& data,
39 DeviceType dtype=DeviceType::CPU,
bool requires_grad=
false);
/// @brief Build a single tensor from a sequence of values of type T.
///        Presumably wraps torch::stack; the definition is not visible here,
///        so TODO confirm. stack_as_float() calls this with
///        T = std::vector<float_t>.
/// NOTE(review): the template header that declares T is not visible in this
/// listing.
/// @param values        Items to combine.
/// @param type          Target device (default DeviceType::CPU).
/// @param requires_grad Presumably sets the tensor's autograd flag
///                      (default false). TODO confirm.
/// @return The resulting value_type (torch_tensor_t).
42 static value_type stack(
const std::vector<T>& values,
43 DeviceType type=DeviceType::CPU,
44 bool requires_grad=
false);
/// @brief Build a tensor from a flat std::vector of real_t.
///        Presumably wraps torch::cat; the definition is not visible here,
///        so TODO confirm.
/// @param values        Source values.
/// @param type          Target device (default DeviceType::CPU).
/// @param requires_grad Presumably sets the tensor's autograd flag
///                      (default false). TODO confirm.
/// @return The resulting value_type (torch_tensor_t).
49 static value_type cat(
const std::vector<real_t>& values,
50 DeviceType type=DeviceType::CPU,
51 bool requires_grad=
false);
/// @brief Convert every element of a nested vector to floating point and
///        forward the result to stack() (see the out-of-class definition).
/// @param values Rows of values of type T; rows may differ in length as far
///               as this conversion is concerned.
/// @param type   Target device (default DeviceType::CPU).
/// @return The value_type produced by stack().
54 static value_type stack_as_float(
const std::vector<std::vector<T>>& values,
55 DeviceType type=DeviceType::CPU);
/// @brief Presumably copies the elements of `tensor` back into a
///        std::vector<T>. The definition is not visible here, so TODO confirm
///        the expected rank and dtype of `tensor`.
/// NOTE(review): `tensor` is taken by value. If torch_tensor_t is torch::Tensor,
/// this copies a reference-counted handle rather than the data; verify.
58 static std::vector<T> to_vector(torch_tensor_t tensor);
/// @brief Call operators that turn a scalar or a vector into a torch_tensor_t.
///        They are const, so they do not modify the adaptor. The definitions
///        are not visible here; presumably they delegate to to_torch(). TODO
///        confirm.
/// @brief Scalar real_t overload.
60 torch_tensor_t operator()(real_t value)
const;
/// @brief std::vector<real_t> overload.
61 torch_tensor_t operator()(
const std::vector<real_t>& data)
const;
/// @brief std::vector<float_t> overload.
62 torch_tensor_t operator()(
const std::vector<float_t>& data)
const;
/// @brief std::vector<int> overload.
/// NOTE(review): this uses plain `int`, while the to_torch() overloads use the
/// project alias `int_t`. Confirm the two are intended to differ.
63 torch_tensor_t operator()(
const std::vector<int>& data)
const;
/// @brief Out-of-class definition of TorchAdaptor::stack_as_float.
///        Copies `values` into a std::vector<std::vector<float_t>> of the same
///        shape, converting each element, then returns
///        TorchAdaptor::stack(values_, dtype).
/// NOTE(review): this listing omits the template header, the parameter line
/// that declares `dtype`, and the closing braces of this definition.
/// NOTE(review): the in-class declaration names the device parameter `type`,
/// but this body passes `dtype`. Confirm the definition's parameter list
/// matches.
71TorchAdaptor::value_type
72TorchAdaptor::stack_as_float(
const std::vector<std::vector<T>>& values,
// One destination row per input row; each row is sized in the loop below.
76 std::vector<std::vector<float_t>> values_(
values.size());
79 for(uint_t i=0; i<
values.size(); ++i){
80 values_[i].resize(values[i].size());
81 for(uint_t j=0; j<
values[i].size(); ++j){
// NOTE(review): the element is cast to `float`, but the destination type is
// `float_t`. If float_t is not float (e.g. double), this first narrows to
// float. Consider static_cast<float_t>; TODO confirm float_t's definition.
82 values_[i][j] =
static_cast<float>(
values[i][j]);
// Delegate to stack() with T = std::vector<float_t>; requires_grad keeps its
// default.
86 return TorchAdaptor::stack(values_, dtype);
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
list values
Definition plot_losses.py:13