11#include "cuberl/base/cubeai_config.h"
22#include "cuberl/data_structs/experience_buffer.h"
25#include <torch/torch.h>
37namespace rl::algos::pg
/// A2C solver parameterised on the environment, policy and critic types.
///
/// Derives from RLSolverBase<EnvType> and is marked final. The policy and the
/// critic are held by reference (the referenced objects must outlive the
/// solver); one optimizer for each is owned through a std::unique_ptr.
///
/// NOTE(review): this listing omits several original lines (braces, and the
/// declarations of env_type, policy_type and config_ that the visible code
/// refers to) — consult the full header before relying on this summary.
53 template<
typename EnvType,
typename PolicyType,
typename CriticType>
54 class A2CSolver final:
public RLSolverBase<EnvType>
/// Alias for the critic template argument.
71 typedef CriticType critic_type;
/// State type as published by the environment type.
73 typedef typename env_type::state_type
state_type;
/// Action type as published by the environment type.
74 typedef typename env_type::action_type action_type;
/// Experience buffer type defined by the monitor for this action/state pair.
76 typedef typename A2CMonitor<action_type,
77 state_type>::experience_buffer_type experience_buffer_type;
/// Builds the solver from a configuration, the policy and critic to train,
/// and their optimizers. The optimizer pointers are taken by non-const
/// lvalue reference and the out-of-line definition moves from them, so the
/// caller's unique_ptr objects are left empty after construction.
84 A2CSolver(
const A2CConfig& config,
85 policy_type& policy, critic_type& critic,
86 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer,
87 std::unique_ptr<torch::optim::Optimizer>& critic_optimizer);
/// Hook run before training begins; defined out of line, where it reserves
/// storage in the monitor's containers.
93 virtual void actions_before_training_begins(env_type&);
/// Hook run after training ends; intentionally a no-op.
99 virtual void actions_after_training_ends(env_type&)
override final{}
/// Hook run before each episode; intentionally a no-op.
104 virtual void actions_before_episode_begins(env_type&,
105 uint_t )
override final{}
/// Hook run after each episode; intentionally a no-op.
110 virtual void actions_after_episode_ends(env_type&,
112 const EpisodeInfo&)
override final{}
/// Runs one training episode on the given environment and returns its
/// summary; the second argument is the episode index. Defined out of line.
117 virtual EpisodeInfo on_training_episode(env_type&, uint_t );
/// Declared noexcept; the definition body is not visible in this listing
/// (presumably switches the networks to training mode — TODO confirm).
122 void set_train_mode()noexcept;
/// Declared noexcept; the definition body is not visible in this listing
/// (presumably switches the networks to evaluation mode — TODO confirm).
127 void set_evaluation_mode()noexcept;
/// Mutable access to the monitor that accumulates losses, rewards and
/// episode durations.
132 A2CMonitor<action_type, state_type>& get_monitor(){
return monitor_;}
/// Policy being trained; held by reference, not owned.
144 policy_type& policy_;
/// Critic being trained; held by reference, not owned.
149 critic_type& critic_;
/// Per-episode bookkeeping (loss values, rewards, episode durations).
154 A2CMonitor<action_type, state_type> monitor_;
/// Optimizer for the policy parameters; owned by the solver.
159 std::unique_ptr<torch::optim::Optimizer> policy_optimizer_;
/// Optimizer for the critic parameters; owned by the solver.
164 std::unique_ptr<torch::optim::Optimizer> critic_optimizer_;
/// Steps the environment and collects experience for one episode; returns a
/// uint_t that the caller uses as the episode's iteration count.
/// NOTE(review): original line 171 is absent from this listing; the call
/// site passes (env, episode_idx, buffer), so a uint_t parameter presumably
/// sits between the two shown here — TODO confirm.
170 uint_t create_episode_batch_(env_type&,
172 experience_buffer_type& buffer);
/// Performs one policy/critic update from the buffered experience. The
/// definition returns the value R together with the sum of the policy and
/// critic loss values.
174 std::tuple<real_t, real_t>
175 train_with_batch_(experience_buffer_type& buffer);
/// Constructor definition.
///
/// @param config            solver configuration
/// @param policy            policy to train
/// @param critic            critic to train
/// @param policy_optimizer  optimizer for the policy; moved from, so the
///                          caller's pointer is empty afterwards
/// @param critic_optimizer  optimizer for the critic; moved from, so the
///                          caller's pointer is empty afterwards
///
/// NOTE(review): the initialisers on original lines 184-188 and the
/// constructor body are not visible in this listing.
179 template<
typename EnvType,
typename PolicyType,
typename CriticType>
180 A2CSolver<EnvType, PolicyType, CriticType>::A2CSolver(
const A2CConfig& config,
181 policy_type& policy, critic_type& critic,
182 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer,
183 std::unique_ptr<torch::optim::Optimizer>& critic_optimizer)
// Ownership of both optimizers transfers from the caller to the solver here.
189 policy_optimizer_(std::move(policy_optimizer)),
190 critic_optimizer_(std::move(critic_optimizer))
/// Definition of set_train_mode(); declared noexcept.
/// NOTE(review): the return-type line (original line 194) and the body are
/// not visible in this listing. Presumably this puts the policy and critic
/// into training mode — TODO confirm against the full source.
193 template<
typename EnvType,
typename PolicyType,
typename CriticType>
195 A2CSolver<EnvType, PolicyType, CriticType>::set_train_mode()noexcept{
/// Definition of set_evaluation_mode(); declared noexcept.
/// NOTE(review): the return-type line (original line 202) and the body are
/// not visible in this listing. Presumably this puts the policy and critic
/// into evaluation mode — TODO confirm against the full source.
201 template<
typename EnvType,
typename PolicyType,
typename CriticType>
203 A2CSolver<EnvType, PolicyType, CriticType>::set_evaluation_mode()noexcept{
/// Pre-training hook: reserves room for config_.n_episodes entries in each
/// of the monitor's four per-episode containers, so the push_back calls made
/// during training need not reallocate until that many entries are stored.
/// The environment argument is unnamed and unused in the visible lines.
209 template<
typename EnvType,
typename PolicyType,
typename CriticType>
211 A2CSolver<EnvType, PolicyType, CriticType>::actions_before_training_begins(env_type& ){
214 monitor_.policy_loss_values.reserve(config_.n_episodes);
215 monitor_.critic_loss_values.reserve(config_.n_episodes);
216 monitor_.rewards.reserve(config_.n_episodes);
217 monitor_.episode_duration.reserve(config_.n_episodes);
/// Runs one training episode: collects experience from the environment,
/// performs one update on it, and fills in the episode summary.
///
/// @param env          environment to interact with
/// @param episode_idx  index of the episode being run
/// @return the episode summary (EpisodeInfo, per the in-class declaration)
///
/// NOTE(review): the return-type line, the declaration of `info` and the
/// return statement are not visible in this listing.
221 template<
typename EnvType,
typename PolicyType,
typename CriticType>
223 A2CSolver<EnvType, PolicyType, CriticType>::on_training_episode(env_type& env, uint_t episode_idx){
// Time the episode with the monotonic steady_clock.
225 auto start = std::chrono::steady_clock::now();
// Buffer constructed with config_.max_itrs_per_episode
// (presumably its capacity — TODO confirm against the buffer type).
228 experience_buffer_type buffer(config_.max_itrs_per_episode);
// Collect the episode; the returned value is used below as its iteration count.
231 auto eps_itrs = create_episode_batch_(env, episode_idx, buffer);
// Update on the collected batch. total_episode_loss is not used in the
// lines visible here.
235 auto [episode_reward, total_episode_loss] = train_with_batch_(buffer);
237 auto end = std::chrono::steady_clock::now();
238 std::chrono::duration<real_t> elapsed_seconds = end - start;
// NOTE(review): episode_duration receives the iteration count (eps_itrs),
// not elapsed_seconds — confirm this is the intended meaning of "duration".
240 monitor_.episode_duration.push_back(eps_itrs);
// Populate the episode summary.
243 info.episode_index = episode_idx;
244 info.episode_reward = episode_reward;
245 info.episode_iterations = eps_itrs;
246 info.total_time = elapsed_seconds;
/// Interacts with the environment for at most config_.max_itrs_per_episode
/// steps, querying the policy and the critic at every step and assembling an
/// experience tuple from the results.
///
/// @param env     environment to reset and step
/// @param buffer  experience buffer for the episode
/// @return uint_t (per the in-class declaration); the caller treats it as
///         the number of iterations performed
///
/// NOTE(review): several original lines are absent from this listing — the
/// return-type line, line 253 (presumably the uint_t episode index that the
/// call site passes), the declaration of `itrs`, most of the tuple
/// initialiser, the statement that stores the tuple, the body of the `done`
/// branch and the return statement. TODO confirm against the full source.
250 template<
typename EnvType,
typename PolicyType,
typename CriticType>
252 A2CSolver<EnvType, PolicyType, CriticType>::create_episode_batch_(env_type& env,
254 experience_buffer_type& buffer){
257 typedef typename A2CMonitor<action_type,
258 state_type>::experience_tuple_type experience_tuple_type;
// Reset the environment; the returned time step supplies the first observation.
262 auto old_timestep =
env.reset();
// Step at most config_.max_itrs_per_episode times.
266 for(; itrs < config_.max_itrs_per_episode; ++itrs){
// Ask the policy for an action (with its log-probability) and the critic for
// its value estimate, both computed from the current observation.
268 auto [
action, log_prob] = policy_ -> act(old_timestep.observation());
269 auto values = critic_ -> evaluate(old_timestep.observation());
// Apply the action to the environment.
272 auto next_time_step =
env.step(action);
274 auto next_state = next_time_step.observation();
275 auto reward = next_time_step.reward();
// Assemble the experience tuple; only the current-observation and done-flag
// fields are visible in this listing.
277 experience_tuple_type exp = {old_timestep.observation(),
280 next_time_step.done(),
// Episode finished; the body of this branch is not visible here
// (presumably it leaves the loop — TODO confirm).
287 if (next_time_step.done()){
// The new time step becomes the current one for the next iteration.
291 old_timestep = next_time_step;
/// Performs one policy/critic update from the experience held in `buffer`.
///
/// Visible steps: fetch the batch from the buffer; extract rewards, critic
/// values and log-probabilities from it through the monitor; turn them into
/// tensors; compute advantage = returns - values; build the actor and critic
/// losses; back-propagate; step both optimizers; record the two loss values
/// and R in the monitor.
///
/// @param buffer  experience collected for the episode
/// @return tuple of (R, policy loss value + critic loss value)
///
/// NOTE(review): the statements that define `discounted_returns` and `R`,
/// and the trailing arguments of the to_torch/stack calls on original lines
/// 323 and 327, are absent from this listing. The integer template arguments
/// passed to monitor_.get (2, 5, 4) presumably select fields of the
/// experience tuple — TODO confirm against A2CMonitor.
299 template<
typename EnvType,
typename PolicyType,
typename CriticType>
300 std::tuple<real_t, real_t>
301 A2CSolver<EnvType, PolicyType, CriticType>::train_with_batch_(experience_buffer_type& buffer){
306 using namespace cuberl::utils::pytorch;
308 typedef typename A2CMonitor<action_type,
309 state_type>::experience_tuple_type experience_tuple_type;
310 typedef std::vector<experience_tuple_type> batch_type;
// Pull the stored experience out of the buffer and split it per quantity.
313 auto batch = buffer.template get<batch_type>();
314 auto rewards_batch = monitor_.template get<real_t, 2>(batch);
315 auto values_batch = monitor_.template get<torch_tensor_t, 5>(batch);
316 auto logprobs_batch = monitor_.template get<torch_tensor_t, 4>(batch);
// Convert the discounted returns and stack the per-step tensors.
323 auto torch_rewards_batch = TorchAdaptor::to_torch(discounted_returns,
327 auto torch_values_batch = TorchAdaptor::stack(values_batch,
331 auto torch_logprobs_batch = TorchAdaptor::stack(logprobs_batch,
332 config_.device_type);
// Advantage: returns tensor minus the critic's value estimates.
335 auto advantage = torch_rewards_batch - torch_values_batch;
// The advantage is detached in the actor loss so that back-propagating it
// does not flow through the advantage term.
338 auto actor_loss = -(torch_logprobs_batch * advantage.detach()).
mean();
// Critic loss: mean of the squared advantage.
339 auto critic_loss = advantage.pow(2).mean();
// NOTE(review): both clip_grad_norm_ calls run BEFORE zero_grad() and
// backward() below. At this point the gradients are either absent or left
// over from the previous update, and zero_grad() then clears them, so the
// clipping has no effect on the gradients consumed by step(). To take
// effect, each clip would need to sit after the corresponding backward()
// and before the corresponding step().
341 if(config_.clip_policy_grad){
344 torch::nn::utils::clip_grad_norm_(policy_->parameters(),
345 config_.max_grad_norm_policy);
350 if(config_.clip_critic_grad){
351 torch::nn::utils::clip_grad_norm_(critic_->parameters(),
352 config_.max_grad_norm_critic);
// Clear accumulated gradients, back-propagate both losses, apply the updates.
357 policy_optimizer_->zero_grad();
358 critic_optimizer_ -> zero_grad();
360 actor_loss.backward();
361 critic_loss.backward();
363 policy_optimizer_ -> step();
364 critic_optimizer_ -> step();
// Extract the scalar loss values as real_t for bookkeeping.
367 auto total_episode_policy_loss = actor_loss.item().template to<real_t>();
368 auto total_episode_critic_loss = critic_loss.item().template to<real_t>();
374 monitor_.policy_loss_values.push_back(total_episode_policy_loss);
375 monitor_.critic_loss_values.push_back(total_episode_critic_loss);
376 monitor_.rewards.push_back(R);
378 return std::make_tuple(R, total_episode_policy_loss + total_episode_critic_loss);
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
PolicyType
Definition policy_type.h:8
real_t mean(IteratorType begin, IteratorType end, bool parallel=true)
mean Compute the mean value of the values in the provided iterator range
Definition vector_math.h:126
std::iterator_traits< IteratorType >::value_type sum(IteratorType begin, IteratorType end, bool parallel=true)
Definition vector_math.h:98
std::vector< T > calculate_step_discounted_return(const std::vector< T > &rewards, T gamma)
Given an array of rewards, for each entry calculate the following: $$G = \sum_{k=t+1}^{T} \gamma^{k-t-1} R_k$$
Definition utils.h:161
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
std::pair< uint_t, uint_t > state_type
Definition example_15.cpp:28
int R
Definition extended_kalman_filter.py:54
dict action
Definition play.py:41
reward
Definition play.py:44
info
Definition play.py:44
dict policy
Definition play.py:26
list values
Definition plot_losses.py:13
bitrl::envs::gymnasium::CliffWorld env_type
Definition rl_example_10.cpp:32