#include "cuberl/base/cubeai_config.h"

#include <torch/torch.h>

#include <boost/log/trivial.hpp>

#include <chrono>
#include <memory>
#include <tuple>
#include <vector>
33namespace rl::algos::pg
36 template<
typename EnvType,
typename PolicyType,
typename CriticType>
37 class PPOSolver final:
public ACSolverBase<EnvType, PolicyType,
39 A2CMonitor<typename EnvType::action_type,
40 typename EnvType::state_type>,
58 typedef CriticType critic_type;
63 typedef typename env_type::state_type
state_type;
68 typedef typename env_type::action_type action_type;
70 typedef typename A2CMonitor<action_type,
71 state_type>::experience_buffer_type experience_buffer_type;
76 PPOSolver(
const PPOConfig& config,
77 policy_type& policy, critic_type& critic,
78 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer,
79 std::unique_ptr<torch::optim::Optimizer>& critic_optimizer);
84 virtual EpisodeInfo on_training_episode(env_type&, uint_t )
override final;
88 std::tuple<real_t, real_t>
89 train_with_batch_(experience_buffer_type& buffer);
93 template<
typename EnvType,
typename PolicyType,
typename CriticType>
94 PPOSolver<EnvType, PolicyType, CriticType>::PPOSolver(
const PPOConfig& config,
95 policy_type& policy, critic_type& critic,
96 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer,
97 std::unique_ptr<torch::optim::Optimizer>& critic_optimizer)
100 A2CMonitor<typename EnvType::action_type, typename EnvType::
state_type>,
101 PPOConfig>(config,
policy, critic, policy_optimizer, critic_optimizer)
104 template<
typename EnvType,
typename PolicyType,
typename CriticType>
106 PPOSolver<EnvType, PolicyType, CriticType>::on_training_episode(env_type& env, uint_t episode_idx)
108 auto start = std::chrono::steady_clock::now();
111 experience_buffer_type buffer(
this -> config_.max_itrs_per_episode);
114 auto eps_itrs =
this -> create_episode_batch_(env, episode_idx, buffer);
118 auto [episode_reward, total_episode_loss] = train_with_batch_(buffer);
120 auto end = std::chrono::steady_clock::now();
121 std::chrono::duration<real_t> elapsed_seconds = end - start;
123 this -> monitor_.episode_duration.push_back(eps_itrs);
126 info.episode_index = episode_idx;
127 info.episode_reward = episode_reward;
128 info.episode_iterations = eps_itrs;
129 info.total_time = elapsed_seconds;
134template<
typename EnvType,
typename PolicyType,
typename CriticType>
135std::tuple<real_t, real_t>
136PPOSolver<EnvType, PolicyType, CriticType>::train_with_batch_(experience_buffer_type& buffer){
139BOOST_LOG_TRIVIAL(info)<<
"PPO: Training with batch...: ";
145 using namespace cuberl::utils::pytorch;
147 typedef typename A2CMonitor<action_type, state_type>::experience_tuple_type experience_tuple_type;
148 typedef std::vector<experience_tuple_type> batch_type;
151 auto batch = buffer.template get<batch_type>();
152 auto states =
this -> monitor_.template get<state_type, 0>(batch);
153 std::vector<std::vector<float_t>> states_f(states.size());
155 for (
uint_t i=0; i < states.size(); ++i)
157 states_f[i] = std::vector<float_t>(states[i].begin(), states[i].end());
161 auto actions =
this -> monitor_.template get<action_type, 1>(batch);
162 std::vector<std::vector<float_t>> actions_f(actions.size());
164 for (
uint_t i=0; i < actions.size(); ++i)
166 actions_f[i] = std::vector<float_t>(actions[i].begin(), actions[i].end());
170 auto rewards_batch =
this -> monitor_.template get<float_t, 2>(batch);
171 auto values_batch =
this -> monitor_.template get<torch_tensor_t, 5>(batch);
172 auto logprobs_batch =
this -> monitor_.template get<torch_tensor_t, 4>(batch);
176 auto torch_states_batch = TorchAdaptor::stack(states_f,
this -> config_.device_type,
true);
177 auto torch_actions_batch = TorchAdaptor::stack(actions_f,
this -> config_.device_type,
true);
178 auto torch_rewards_batch = TorchAdaptor::to_torch(discounted_returns,
this -> config_.device_type,
false).detach();
179 auto torch_values_batch = TorchAdaptor::stack(values_batch,
this -> config_.device_type);
180 auto old_torch_logprobs_batch = TorchAdaptor::stack(logprobs_batch,
this -> config_.device_type).detach();
183 auto advantages = (torch_rewards_batch - torch_values_batch).detach();
186 std::vector<real_t> loss_vals(
this -> config_.max_passes_over_batch, 0.0);
187 for (
uint_t p=0; p <
this -> config_.max_passes_over_batch; ++p)
189 auto [new_log_probs, entropy,
_] =
this -> policy_ -> evaluate(torch_states_batch, torch_actions_batch);
191 auto ratio = (new_log_probs - old_torch_logprobs_batch).exp();
192 auto surr1 = ratio * advantages;
193 auto surr2 = torch::clamp(ratio, 1 -
this -> config_.clip_epsilon, 1 + this -> config_.clip_epsilon) * advantages;
194 auto actor_loss = -torch::min(surr1, surr2).mean();
196 auto value_estimates =
this -> critic_ -> forward(torch_states_batch).squeeze();
197 auto critic_loss = torch::nn::functional::mse_loss(value_estimates, torch_rewards_batch);
199 auto total_loss = actor_loss + 0.5 * critic_loss - 0.01 * entropy;
201 this -> policy_optimizer_ -> zero_grad();
202 this -> critic_optimizer_ -> zero_grad();
203 total_loss.backward();
204 loss_vals[p] = total_loss.item().template to<real_t>();
205 this -> policy_optimizer_ -> step();
206 this -> critic_optimizer_ -> step();
213 this -> monitor_.policy_loss_values.push_back(avg_loss);
214 this -> monitor_.rewards.push_back(R);
216BOOST_LOG_TRIVIAL(info)<<
"PPO: Done...: ";
218 return std::make_tuple(R, avg_loss);
/*
 * NOTE(review): the text below is Doxygen cross-reference index residue that
 * leaked into this file during extraction — it is not source code. It is kept
 * (fenced inside this comment) because it documents the symbols the original
 * implementation referenced (uint_t, float_t, sum/mean from vector_math.h,
 * calculate_step_discounted_return from utils.h).
 *
 * std::size_t uint_t
 * uint_t — Definition bitrl_types.h:43
 * float float_t — Definition bitrl_types.h:28
 * PolicyType — Definition policy_type.h:8
 * real_t mean(IteratorType begin, IteratorType end, bool parallel=true)
 *   mean: Compute the mean value of the values in the provided iterator range
 *   Definition vector_math.h:126
 * std::iterator_traits<IteratorType>::value_type sum(IteratorType begin, IteratorType end, bool parallel=true)
 *   Definition vector_math.h:98
 * Definition dp_algo_base.h:14
 * std::vector<T> calculate_step_discounted_return(const std::vector<T>& rewards, T gamma)
 *   Given an array of rewards, for each entry calculate G = sum_{k=t+1}^{T} gamma^{k-t-1} r_k
 *   Definition utils.h:161
 * Various utilities used when working with RL problems.
 *   Definition cuberl_types.h:16
 */