bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
torch_policy.h
Go to the documentation of this file.
1#ifndef TORCH_POLICY_H
2#define TORCH_POLICY_H
3
4/*
5#include "kernel/base/config.h"
6
7#ifdef USE_PYTORCH
8
9#include "cubic_engine/base/cubic_engine_types.h"
10#include "cubic_engine/rl/actions/action_space.h"
11#include "cubic_engine/rl/utils/torch_observation_normalizer.h"
12#include "cubic_engine/rl/networks/torch_nn.h"
13#include "cubic_engine/ml/neural_networks/torch_output_layers.h"
14#include "torch/torch.h"
15
16#include <memory>
17#include <vector>
18
19namespace cengine {
20namespace rl {
21namespace policies {
22
23
24
/// Policy module derived from torch::nn::Module. It holds an action-space
/// description, a shared base network, an observation normalizer and a
/// linear output layer (see the private members below).
///
/// NOTE(review): in this listing the whole declaration sits inside the block
/// comment that opens before the includes and closes just before the final
/// #endif, so it is currently compiled out. Keep comments in this region in
/// `//` form; a nested block comment would terminate the outer one early.
28class TorchPolicyImpl: public torch::nn::Module
29{
30
31public:
32
    /// Constructor.
    /// @param action_space description of the action space (passed by value)
    /// @param base shared pointer to the base network
    /// @param normalize_observations defaults to false; presumably enables
    ///        observation_normalizer_ -- definition not visible, TODO confirm
39 TorchPolicyImpl(actions::ActionSpace action_space,
40 std::shared_ptr<nets::TorchNNBase> base,
41 bool normalize_observations = false);
42
    /// Takes the inputs, rnn_hxs and masks tensors and returns a vector of
    /// tensors. Does not modify the policy (const).
    /// NOTE(review): the layout/meaning of the returned tensors is not
    /// visible from this header -- confirm against the definition.
    /// @param inputs input tensor
    /// @param rnn_hxs presumably the recurrent hidden states -- TODO confirm
    /// @param masks mask tensor
50 std::vector<torch::Tensor> act(torch::Tensor inputs,
51 torch::Tensor rnn_hxs,
52 torch::Tensor masks) const;
53
    /// Same inputs as act() plus an actions tensor; returns a vector of
    /// tensors. Does not modify the policy (const).
    /// NOTE(review): the contents of the returned vector are not visible
    /// from this header -- confirm against the definition.
    /// @param inputs input tensor
    /// @param rnn_hxs presumably the recurrent hidden states -- TODO confirm
    /// @param masks mask tensor
    /// @param actions tensor of actions to evaluate
62 std::vector<torch::Tensor> evaluate_actions(torch::Tensor inputs,
63 torch::Tensor rnn_hxs,
64 torch::Tensor masks,
65 torch::Tensor actions) const;
66
    /// Returns a single tensor for the given inputs, rnn_hxs and masks;
    /// by its name, presumably action probabilities -- TODO confirm.
74 torch::Tensor get_probs(torch::Tensor inputs,
75 torch::Tensor rnn_hxs,
76 torch::Tensor masks) const;
77
    /// Returns a single tensor for the given inputs, rnn_hxs and masks;
    /// by its name, presumably value estimates -- TODO confirm.
85 torch::Tensor get_values(torch::Tensor inputs,
86 torch::Tensor rnn_hxs,
87 torch::Tensor masks) const;
88
    /// Non-const: takes a tensor of observations; presumably feeds them to
    /// observation_normalizer_ -- definition not visible, TODO confirm.
93 void update_observation_normalizer(torch::Tensor observations);
94
    /// Disabled accessor. NOTE(review): it refers to `base`, whereas the
    /// member declared below is `base_`; fix the name if this is re-enabled.
99 //bool is_recurrent() const { return base->is_recurrent(); }
100
    /// Forwards to base_->get_hidden_size().
    /// NOTE(review): base_ is dereferenced without a null check.
105 uint_t get_hidden_size() const{return base_->get_hidden_size();}
106
    /// Returns true when observation_normalizer_.is_empty() is false.
111 bool using_observation_normalizer() const{return !observation_normalizer_.is_empty();}
112
113private:
114
    /// The action space handed to the constructor (held by value).
118 actions::ActionSpace action_space_;
119
    /// Shared pointer to the base network; queried by get_hidden_size().
123 std::shared_ptr<nets::TorchNNBase> base_;
124
    /// Observation normalizer; its is_empty() state backs
    /// using_observation_normalizer().
128 utils::TorchObservationNormalizer observation_normalizer_;
129
    /// Shared pointer to the linear output layer.
133 std::shared_ptr<cengine::ml::nets::LinearOutputLayer> output_layer_;
134
    /// Private helper taking x, hxs and masks and returning a vector of
    /// tensors; by its name, presumably a GRU forward pass -- TODO confirm.
    /// NOTE(review): this is non-const, so the const public methods above
    /// cannot call it on `this` as declared -- verify against the definitions.
135 std::vector<torch::Tensor> forward_gru(torch::Tensor x,
136 torch::Tensor hxs,
137 torch::Tensor masks);
138
139};
140
/// Declares the TorchPolicy handle type for TorchPolicyImpl through PyTorch's
/// TORCH_MODULE macro (per the PyTorch C++ frontend docs, a
/// torch::nn::ModuleHolder wrapper around the Impl class).
141TORCH_MODULE(TorchPolicy);
142}
143
144}
145
146}
147#endif
148*/
149#endif // TORCH_POLICY_H