bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
ppo.h
Go to the documentation of this file.
1#ifndef PPO_H
2#define PPO_H
3
10
11#include "cuberl/base/cubeai_config.h"
12
13#ifdef USE_PYTORCH
14
22
23#include <torch/torch.h>
24
25#ifdef CUBERL_DEBUG
26#include <cassert>
27#include <boost/log/trivial.hpp>
28#endif
29
30#include <chrono>
31
32namespace cuberl{
33namespace rl::algos::pg
34{
35
36 template<typename EnvType, typename PolicyType, typename CriticType>
37 class PPOSolver final: public ACSolverBase<EnvType, PolicyType,
38 CriticType,
39 A2CMonitor<typename EnvType::action_type,
40 typename EnvType::state_type>,
41 PPOConfig>
42 {
43 public:
44
48 typedef EnvType env_type;
49
53 typedef PolicyType policy_type;
54
58 typedef CriticType critic_type;
59
63 typedef typename env_type::state_type state_type;
64
68 typedef typename env_type::action_type action_type;
69
70 typedef typename A2CMonitor<action_type,
71 state_type>::experience_buffer_type experience_buffer_type;
72
76 PPOSolver(const PPOConfig& config,
77 policy_type& policy, critic_type& critic,
78 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer,
79 std::unique_ptr<torch::optim::Optimizer>& critic_optimizer);
80
84 virtual EpisodeInfo on_training_episode(env_type&, uint_t /*episode_idx*/) override final;
85
86 private:
87
88 std::tuple<real_t, real_t>
89 train_with_batch_(experience_buffer_type& buffer);
90
91 };
92
93 template<typename EnvType, typename PolicyType, typename CriticType>
94 PPOSolver<EnvType, PolicyType, CriticType>::PPOSolver(const PPOConfig& config,
95 policy_type& policy, critic_type& critic,
96 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer,
97 std::unique_ptr<torch::optim::Optimizer>& critic_optimizer)
98 :
99 ACSolverBase<EnvType, PolicyType, CriticType,
100 A2CMonitor<typename EnvType::action_type, typename EnvType::state_type>,
101 PPOConfig>(config, policy, critic, policy_optimizer, critic_optimizer)
102 {}
103
104 template<typename EnvType, typename PolicyType, typename CriticType>
105 EpisodeInfo
106 PPOSolver<EnvType, PolicyType, CriticType>::on_training_episode(env_type& env, uint_t episode_idx)
107 {
108 auto start = std::chrono::steady_clock::now();
109
110 // the buffer to use
111 experience_buffer_type buffer(this -> config_.max_itrs_per_episode);
112
113 // collect the buffer
114 auto eps_itrs = this -> create_episode_batch_(env, episode_idx, buffer);
115
116 // train the networks with from the
117 // collected buffer
118 auto [episode_reward, total_episode_loss] = train_with_batch_(buffer);
119
120 auto end = std::chrono::steady_clock::now();
121 std::chrono::duration<real_t> elapsed_seconds = end - start;
122
123 this -> monitor_.episode_duration.push_back(eps_itrs);
124
125 EpisodeInfo info;
126 info.episode_index = episode_idx;
127 info.episode_reward = episode_reward;
128 info.episode_iterations = eps_itrs;
129 info.total_time = elapsed_seconds;
130 return info;
131
132 }
133
///
/// \brief PPO update over one episode's experience buffer.
/// Builds tensors for states/actions/returns/values/log-probs, forms the
/// advantage, then runs config_.max_passes_over_batch optimization passes
/// of the clipped surrogate objective (actor) plus an MSE value loss
/// (critic). Returns (undiscounted episode reward, mean total loss).
///
template<typename EnvType, typename PolicyType, typename CriticType>
std::tuple<real_t, real_t>
PPOSolver<EnvType, PolicyType, CriticType>::train_with_batch_(experience_buffer_type& buffer){

#ifdef CUBERL_DEBUG
BOOST_LOG_TRIVIAL(info)<<"PPO: Training with batch...: ";
#endif

    // The adaptors below rebuild tensors from plain vectors, which drops
    // requires_grad; some conversions therefore pass an explicit flag to
    // re-enable it (see TorchAdaptor::stack calls further down).
    using namespace cuberl::utils::pytorch;
    using namespace cuberl::rl::algos;
    typedef typename A2CMonitor<action_type, state_type>::experience_tuple_type experience_tuple_type;
    typedef std::vector<experience_tuple_type> batch_type;

    // extract the episode's experience tuples from the buffer
    auto batch = buffer.template get<batch_type>();

    // tuple slot 0: states; converted element-wise to float vectors so
    // they can be stacked into a single torch tensor
    auto states = this -> monitor_.template get<state_type, 0>(batch);
    std::vector<std::vector<float_t>> states_f(states.size());

    for (uint_t i=0; i < states.size(); ++i)
    {
        states_f[i] = std::vector<float_t>(states[i].begin(), states[i].end());
    }

    // tuple slot 1: actions; same float conversion as the states
    auto actions = this -> monitor_.template get<action_type, 1>(batch);
    std::vector<std::vector<float_t>> actions_f(actions.size());

    for (uint_t i=0; i < actions.size(); ++i)
    {
        actions_f[i] = std::vector<float_t>(actions[i].begin(), actions[i].end());
    }

    // tuple slots 2/5/4: per-step rewards, critic value estimates and
    // action log-probabilities recorded at collection time
    auto rewards_batch = this -> monitor_.template get<float_t, 2>(batch);
    auto values_batch = this -> monitor_.template get<torch_tensor_t, 5>(batch);
    auto logprobs_batch = this -> monitor_.template get<torch_tensor_t, 4>(batch);

    // per-step discounted returns G_t for this batch
    auto discounted_returns = calculate_step_discounted_return(rewards_batch, static_cast<float_t>(this->config_.gamma));

    // states/actions are stacked with requires_grad enabled (trailing
    // 'true'); returns are targets, so they are detached from the graph
    auto torch_states_batch = TorchAdaptor::stack(states_f, this -> config_.device_type, true);
    auto torch_actions_batch = TorchAdaptor::stack(actions_f,this -> config_.device_type, true);
    auto torch_rewards_batch = TorchAdaptor::to_torch(discounted_returns, this -> config_.device_type, false).detach();
    auto torch_values_batch = TorchAdaptor::stack(values_batch, this -> config_.device_type);
    // old-policy log-probs: detached so the ratio's gradient flows only
    // through the new policy's log-probs
    auto old_torch_logprobs_batch = TorchAdaptor::stack(logprobs_batch, this -> config_.device_type).detach();

    // advantage A_t = G_t - V(s_t), detached: PPO treats the advantage
    // as a constant during the surrogate optimization
    auto advantages = (torch_rewards_batch - torch_values_batch).detach();

    // run several optimization passes over the same batch, recording the
    // total loss of each pass
    std::vector<real_t> loss_vals(this -> config_.max_passes_over_batch, 0.0);
    for (uint_t p=0; p < this -> config_.max_passes_over_batch; ++p)
    {
        // re-evaluate the collected actions under the current policy
        auto [new_log_probs, entropy, _] = this -> policy_ -> evaluate(torch_states_batch, torch_actions_batch);

        // probability ratio r_t = exp(log pi_new - log pi_old) and the
        // clipped surrogate objective of PPO
        auto ratio = (new_log_probs - old_torch_logprobs_batch).exp();
        auto surr1 = ratio * advantages;
        auto surr2 = torch::clamp(ratio, 1 - this -> config_.clip_epsilon, 1 + this -> config_.clip_epsilon) * advantages;
        auto actor_loss = -torch::min(surr1, surr2).mean();

        // critic regression target is the discounted return
        auto value_estimates = this -> critic_ -> forward(torch_states_batch).squeeze();
        auto critic_loss = torch::nn::functional::mse_loss(value_estimates, torch_rewards_batch);

        // NOTE(review): 0.5/0.01 are hard-coded value/entropy coefficients
        // (candidates for PPOConfig), and 'entropy' is used without an
        // explicit reduction — confirm evaluate() returns a scalar here.
        auto total_loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy;

        // standard zero_grad -> backward -> step sequence for both nets;
        // the scalar loss is read out before stepping
        this -> policy_optimizer_ -> zero_grad();
        this -> critic_optimizer_ -> zero_grad();
        total_loss.backward();
        loss_vals[p] = total_loss.item().template to<real_t>();
        this -> policy_optimizer_ -> step();
        this -> critic_optimizer_ -> step();

    }

    // the undiscounted sum of rewards is reported as the episode reward;
    // the loss reported is the mean over the optimization passes
    auto R = cuberl::maths::sum(rewards_batch);
    auto avg_loss = cuberl::maths::mean(loss_vals);
    this -> monitor_.policy_loss_values.push_back(avg_loss);
    this -> monitor_.rewards.push_back(R);
#ifdef CUBERL_DEBUG
BOOST_LOG_TRIVIAL(info)<<"PPO: Done...: ";
#endif
    return std::make_tuple(R, avg_loss);

}
221
222
223}
224}// cuberl
225#endif
226#endif
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
float float_t
float
Definition bitrl_types.h:28
PolicyType
Definition policy_type.h:8
real_t mean(IteratorType begin, IteratorType end, bool parallel=true)
mean Compute the mean value of the values in the provided iterator range
Definition vector_math.h:126
std::iterator_traits< IteratorType >::value_type sum(IteratorType begin, IteratorType end, bool parallel=true)
Definition vector_math.h:98
Definition dp_algo_base.h:14
std::vector< T > calculate_step_discounted_return(const std::vector< T > &rewards, T gamma)
Given an array of rewards, for each entry calculate the following: $$G = \sum_{k=t+1}^T \gamma^{k-t-1...
Definition utils.h:161
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
std::pair< uint_t, uint_t > state_type
Definition example_15.cpp:28
int R
Definition extended_kalman_filter.py:54
info
Definition play.py:44
_
Definition play.py:34
dict policy
Definition play.py:26
bitrl::envs::gymnasium::CliffWorld env_type
Definition rl_example_10.cpp:32