bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
torch_adaptor.h
Go to the documentation of this file.
1#ifndef TORCH_STATE_ADAPTOR_H
2#define TORCH_STATE_ADAPTOR_H
3
4#include "cuberl/base/cubeai_config.h"
5
6#ifdef USE_PYTORCH
7
9
10#include <torch/torch.h>
11#include <vector>
12
13namespace cuberl{
14namespace utils {
15namespace pytorch{
16
17
21struct TorchAdaptor{
22
23
24 typedef torch_tensor_t value_type;
25
26 static value_type to_torch(const std::vector<real_t>& data,
27 DeviceType dtype=DeviceType::CPU, bool requires_grad=false);
28
29 static value_type to_torch(const std::vector<uint_t>& data,
30 DeviceType dtype=DeviceType::CPU, bool requires_grad=false);
31
32 static value_type to_torch(const std::vector<float_t>& data,
33 DeviceType dtype=DeviceType::CPU, bool requires_grad=false);
34 static value_type to_torch(const std::vector<int_t>& data,
35 DeviceType dtype=DeviceType::CPU, bool requires_grad=false);
36 static value_type to_torch(const std::vector<lint_t>& data,
37 DeviceType dtype=DeviceType::CPU, bool requires_grad=false);
38 static value_type to_torch(const std::vector<bool>& data,
39 DeviceType dtype=DeviceType::CPU, bool requires_grad=false);
40
41 template<typename T>
42 static value_type stack(const std::vector<T>& values,
43 DeviceType type=DeviceType::CPU,
44 bool requires_grad=false);
45
46// static value_type stack(const std::vector<value_type>& values,
47// DeviceType type=DeviceType::CPU)const;
48
49 static value_type cat(const std::vector<real_t>& values,
50 DeviceType type=DeviceType::CPU,
51 bool requires_grad=false);
52
53 template<typename T>
54 static value_type stack_as_float(const std::vector<std::vector<T>>& values,
55 DeviceType type=DeviceType::CPU);
56
57 template<typename T>
58 static std::vector<T> to_vector(torch_tensor_t tensor);
59
60 torch_tensor_t operator()(real_t value)const;
61 torch_tensor_t operator()(const std::vector<real_t>& data)const;
62 torch_tensor_t operator()(const std::vector<float_t>& data)const;
63 torch_tensor_t operator()(const std::vector<int>& data)const;
64
65
66
67};
68
69
70template<typename T>
71TorchAdaptor::value_type
72TorchAdaptor::stack_as_float(const std::vector<std::vector<T>>& values,
73 DeviceType dtype){
74
75
76 std::vector<std::vector<float_t>> values_(values.size());
77
78 // make the input values floats
79 for(uint_t i=0; i<values.size(); ++i){
80 values_[i].resize(values[i].size());
81 for(uint_t j=0; j<values[i].size(); ++j){
82 values_[i][j] = static_cast<float>(values[i][j]);
83 }
84 }
85
86 return TorchAdaptor::stack(values_, dtype);
87
88}
89
90}
91}
92}
93#endif
94#endif // TORCH_STATE_ADAPTOR_H
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
list values
Definition plot_losses.py:13