/// Policy module built on torch::nn::Module. It holds an action space
/// description, a shared network base, an observation normalizer and an
/// output layer (see the data members near the end of the declaration).
///
/// NOTE(review): the number fused onto the front of every line, and the
/// missing braces / access specifiers, look like artifacts of how this text
/// was extracted -- confirm against the original header before relying on it.
28class TorchPolicyImpl: public torch::nn::Module
/// Constructor.
/// @param action_space            description of the action space
/// @param base                    shared pointer to the network base
/// @param normalize_observations  defaults to false; presumably enables the
///                                observation normalizer -- TODO confirm in
///                                the implementation file
39 TorchPolicyImpl(actions::ActionSpace action_space,
40 std::shared_ptr<nets::TorchNNBase> base,
41 bool normalize_observations = false);
/// Const member taking `inputs`, `rnn_hxs` and `masks` by value and returning
/// a vector of tensors. The layout of the returned vector is not visible
/// from this declaration.
50 std::vector<torch::Tensor> act(torch::Tensor inputs,
51 torch::Tensor rnn_hxs,
52 torch::Tensor masks) const;
/// Const member returning a vector of tensors for the given `actions`.
/// NOTE(review): the embedded line numbers skip 64 inside this parameter
/// list, so a parameter may be missing from this view (possibly `masks`, by
/// analogy with the sibling methods) -- confirm.
62 std::vector<torch::Tensor> evaluate_actions(torch::Tensor inputs,
63 torch::Tensor rnn_hxs,
65 torch::Tensor actions) const;
/// Const member returning a single tensor; by its name, presumably the
/// action probabilities for `inputs` -- TODO confirm.
74 torch::Tensor get_probs(torch::Tensor inputs,
75 torch::Tensor rnn_hxs,
76 torch::Tensor masks) const;
/// Const member returning a single tensor; by its name, presumably the
/// value estimates for `inputs` -- TODO confirm.
85 torch::Tensor get_values(torch::Tensor inputs,
86 torch::Tensor rnn_hxs,
87 torch::Tensor masks) const;
/// Non-const member taking a tensor of observations; presumably feeds them
/// to `observation_normalizer_` -- TODO confirm in the implementation file.
93 void update_observation_normalizer(torch::Tensor observations);
/// Disabled accessor. NOTE(review): it refers to `base`, whereas the member
/// declared below is `base_`, so it would need updating before re-enabling.
99 //bool is_recurrent() const { return base->is_recurrent(); }
/// Returns whatever `base_->get_hidden_size()` reports.
105 uint_t get_hidden_size() const{return base_->get_hidden_size();}
/// Returns true when `observation_normalizer_` reports that it is not empty.
111 bool using_observation_normalizer() const{return !observation_normalizer_.is_empty();}
/// Action space supplied to this policy.
118 actions::ActionSpace action_space_;
/// Shared network base; queried by get_hidden_size().
123 std::shared_ptr<nets::TorchNNBase> base_;
/// Observation normalizer; its emptiness is what
/// using_observation_normalizer() tests.
128 utils::TorchObservationNormalizer observation_normalizer_;
/// Shared pointer to the output layer.
133 std::shared_ptr<cengine::ml::nets::LinearOutputLayer> output_layer_;
/// Non-const helper returning a vector of tensors.
/// NOTE(review): the embedded line numbers skip 136 inside this parameter
/// list, so a parameter may be missing from this view -- confirm.
135 std::vector<torch::Tensor> forward_gru(torch::Tensor x,
137 torch::Tensor masks);
/// Invokes libtorch's TORCH_MODULE macro for `TorchPolicy`. Per the libtorch
/// documentation this defines `TorchPolicy` as a module holder (shared-pointer
/// style wrapper) around the corresponding `TorchPolicyImpl` class.
141TORCH_MODULE(TorchPolicy);