bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
torch_tensor_utils.h
Go to the documentation of this file.
1#ifndef TORCH_TENSOR_UTILS_H
2#define TORCH_TENSOR_UTILS_H
3
4#include "cuberl/base/cubeai_config.h"
5
6#ifdef USE_PYTORCH
7
8#include "cuberl/base/cubeai_types.h"
9#include "cuberl/io/torch_state_dictionary_reader.h"
10#include <torch/torch.h>
11
12
13namespace cubeai{
14namespace torch_utils {
15
16
21template<typename ContainerType>
22torch_tensor_t create_mask(const ContainerType& container){
23
24 return torch::tensor(container.data(), torch::dtype(torch::kBool));
25}
26
30template<typename TorchNetType>
31void copy_parameters_to(/*TorchNetType& from,*/ TorchNetType& to, const std::string params_path){
32
33 // make parameters copying possible
34 torch::autograd::GradMode::set_enabled(false);
35
36 //
37 //from.save(params_path);
38
39 auto new_params = TorchStateDictionaryReader(params_path); // implement this
40 auto params = to->named_parameters(true /*recurse*/);
41 auto buffers = to->named_buffers(true /*recurse*/);
42
43 /*for (auto& val : new_params) {
44
45 auto name = val.key();
46 auto* t = params.find(name);
47
48 if (t != nullptr) {
49 t->copy_(val.value());
50 }
51 else {
52
53 t = buffers.find(name);
54 if (t != nullptr) {
55 t->copy_(val.value());
56 }
57 }
58 }*/
59
60 torch::autograd::GradMode::set_enabled(true);
61}
62
63
64}
65}
66
67#endif
68
69#endif // TORCH_TENSOR_UTILS_H
Definition mc_tree_search_solver.h:22