10#include "cuberl/base/cubeai_config.h"
22#include "cuberl/data_structs/experience_buffer.h"
25#include <boost/log/trivial.hpp>
26#include <torch/torch.h>
// Solver class templated on the environment type and the policy type.
// It derives from RLSolverBase<EnvType> and is marked final.
// NOTE(review): this listing is incomplete -- braces, access specifiers and
// some typedefs (e.g. env_type, policy_type) are not visible here.
54template<
typename EnvType,
typename PolicyType>
55class ReinforceSolver final:
public RLSolverBase<EnvType>
// State and action types are taken from the environment type.
62 typedef typename env_type::state_type
state_type;
63 typedef typename env_type::action_type action_type;
// Buffer type used to hold the experience collected during one episode;
// it is defined by ReinforceMonitor for this action/state pair.
64 typedef typename ReinforceMonitor<action_type,
65 state_type>::experience_buffer_type experience_buffer_type;
// Constructor. The optimizer is received by non-const lvalue reference to a
// unique_ptr; the definition moves from it, so the caller's pointer is empty
// after construction.
70 ReinforceSolver(ReinforceConfig opts,
72 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer);
// Hook invoked before training starts (defined out of line).
78 virtual void actions_before_training_begins(env_type&);
// The following three hooks have empty bodies in this class.
84 virtual void actions_after_training_ends(env_type&){}
89 virtual void actions_before_episode_begins(env_type&, uint_t ){}
94 virtual void actions_after_episode_ends(env_type&, uint_t ,
95 const EpisodeInfo& ){}
// Runs one training episode on the given environment and returns its EpisodeInfo.
100 virtual EpisodeInfo on_training_episode(env_type&, uint_t );
// Returns a mutable reference to the monitor that accumulates losses,
// rewards and episode durations.
105 ReinforceMonitor<action_type, state_type>& get_monitor(){
return monitor_;}
// Solver configuration (episode count, iteration cap, baseline options, ...).
112 ReinforceConfig config_;
// The policy; it is used with pointer syntax (policy_ptr_ -> ...) in the definitions.
117 policy_type policy_ptr_;
// Optimizer owned by the solver; it is moved in by the constructor.
122 std::unique_ptr<torch::optim::Optimizer> policy_optimizer_;
// Monitor holding the per-episode statistics.
127 ReinforceMonitor<action_type, state_type> monitor_;
// Plays one episode on env and stores the experience in buffer; returns a uint_t
// (the caller records it as the episode's iteration count).
134 uint_t create_episode_batch_(env_type& env, experience_buffer_type& buffer);
// The training helpers below each return a (reward, loss) pair of real_t.
140 std::tuple<real_t, real_t> train_batch_(experience_buffer_type& buffer);
146 std::tuple<real_t, real_t> train_sequential_(experience_buffer_type& buffer);
151 std::tuple<real_t, real_t> train_without_baseline_(experience_buffer_type& buffer);
156 std::tuple<real_t, real_t> train_with_baseline_(experience_buffer_type& buffer);
// Constructor definition.
// Initializes the RLSolverBase<EnvType> base and takes ownership of the
// optimizer by moving out of the caller's unique_ptr (which is null afterwards).
// The first parameter is named `config` here but `opts` in the class declaration.
// NOTE(review): the remaining member initializers and the constructor body
// are not visible in this listing.
160template<
typename EnvType,
typename PolicyType>
161ReinforceSolver<EnvType, PolicyType>::ReinforceSolver(ReinforceConfig config,
163 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer)
165 RLSolverBase<EnvType>(),
168 policy_optimizer_(std::move(policy_optimizer)),
// Pre-training hook. The environment argument is unused (unnamed parameter).
// NOTE(review): the return type line and the closing brace are not visible in this listing.
173template<
typename EnvType,
typename PolicyType>
175ReinforceSolver<EnvType, PolicyType>::actions_before_training_begins(env_type& ){
// Reserve room for one entry per episode so that the per-episode push_back
// calls made during training do not need to reallocate.
177 monitor_.policy_loss_values.reserve(config_.n_episodes);
178 monitor_.rewards.reserve(config_.n_episodes);
179 monitor_.episode_duration.reserve(config_.n_episodes);
// Calls train() on the policy -- presumably switches the policy network to
// training mode; confirm against the policy type.
182 policy_ptr_ -> train();
// Plays one episode on env, producing one experience tuple per step.
// NOTE(review): this listing is incomplete -- the return type, the declaration
// of `itr`, the rest of the experience tuple, the buffer insertion, the body of
// the done-branch and the return statement are not visible here.
186template<
typename EnvType,
typename PolicyType>
188ReinforceSolver<EnvType,
190 >::create_episode_batch_(env_type& env, experience_buffer_type& buffer){
195 typedef typename ReinforceMonitor<action_type,
196 state_type>::experience_tuple_type experience_tuple_type;
// Start a fresh episode; old_timestep always holds the step the policy acts on.
199 auto old_timestep =
env.reset();
// The episode length is capped at config_.max_itrs_per_episode.
207 for(; itr < config_.max_itrs_per_episode; ++itr){
// The policy returns the chosen action together with its log-probability.
211 auto [
action, log_prob] = policy_ptr_ -> act(old_timestep.observation());
// Apply the action and read the reward from the resulting time step.
214 auto new_timestep =
env.step(action);
215 auto reward = new_timestep.reward();
// The experience tuple starts with the observation the action was taken from.
217 experience_tuple_type exp = {old_timestep.observation(),
// Episode termination check (the branch body is not visible in this listing).
226 if (new_timestep.done()){
// Advance: the new time step becomes the one acted on in the next iteration.
230 old_timestep = new_timestep;
// Runs a single training episode: collects the episode experience, updates the
// policy (with or without a baseline) and fills in the episode's EpisodeInfo.
// NOTE(review): this listing is incomplete -- the return type, the second
// parameter (used below as episode_idx), the declaration of `info`, the
// condition selecting between the two training paths and the return statement
// are not visible here.
237template<
typename EnvType,
typename PolicyType>
239ReinforceSolver<EnvType, PolicyType>::on_training_episode(env_type& env,
// Wall-clock timing of the whole episode (collection + training).
243 auto start = std::chrono::steady_clock::now();
// Buffer constructed with the per-episode iteration cap.
246 experience_buffer_type buffer(config_.max_itrs_per_episode);
// Play the episode; itrs is the value returned by create_episode_batch_.
250 auto itrs = create_episode_batch_(env, buffer);
// Training path without a baseline.
255 auto [episode_reward, total_episode_loss] = train_without_baseline_(buffer);
256 info.episode_reward = episode_reward;
// Training path with a baseline.
260 auto [episode_reward, total_episode_loss] = train_with_baseline_(buffer);
261 info.episode_reward = episode_reward;
// Note: episode_duration stores the iteration count, not the elapsed time.
263 monitor_.episode_duration.push_back(itrs);
265 auto end = std::chrono::steady_clock::now();
266 std::chrono::duration<real_t> elapsed_seconds = end - start;
// Populate the remaining fields of the episode summary.
269 info.episode_index = episode_idx;
270 info.episode_iterations = itrs;
271 info.total_time = elapsed_seconds;
// Updates the policy with a single optimizer step for the whole episode.
// Returns the pair (R, total_episode_loss).
// NOTE(review): this listing is incomplete -- the definitions of `R` and
// `discounted_returns`, the body of the normalization branch, the remaining
// arguments of compute_loss_item / stack and the statement between zero_grad()
// and step() (presumably loss.backward()) are not visible here.
276template<
typename EnvType,
typename PolicyType>
277std::tuple<real_t, real_t>
278ReinforceSolver<EnvType, PolicyType>::train_batch_(experience_buffer_type& buffer){
280 typedef typename ReinforceMonitor<action_type,
281 state_type>::experience_tuple_type experience_tuple_type;
283 typedef std::vector<experience_tuple_type> batch_type;
// Pull the stored experience out of the buffer as a vector of tuples.
286 auto batch = buffer.template get<batch_type>();
// Presumably extracts tuple element 2 (rewards) and element 4 (log-probabilities)
// from every experience tuple -- confirm against ReinforceMonitor::get.
289 auto reward_batch = monitor_.template get<real_t, 2>(batch);
290 auto log_probs_batch = monitor_.template get<torch_tensor_t, 4>(batch);
// Optional step controlled by config_.normalize_rewards.
296 if(config_.normalize_rewards){
// One loss tensor per step, built from the discounted returns.
300 std::vector<torch_tensor_t> loss_vals = compute_loss_item(discounted_returns,
// Combine the per-step losses into one tensor that is optimized in a single step.
304 auto loss = cuberl::utils::pytorch::TorchAdaptor::stack(loss_vals,
307 policy_optimizer_ -> zero_grad();
309 policy_optimizer_ -> step();
// Convert the scalar loss tensor to real_t for reporting.
311 auto total_episode_loss = loss.item().to<
real_t>();
316 return std::make_tuple(R, total_episode_loss);
// Updates the policy with one optimizer step per loss item instead of a single
// step for the whole episode. Returns the pair (R, mean loss over the items).
// NOTE(review): this listing is incomplete -- the definitions of `R` and
// `discounted_returns`, the body of the normalization branch, the remaining
// arguments of compute_loss_item and the statement between zero_grad() and
// step() (presumably loss.backward()) are not visible here.
321template<
typename EnvType,
typename PolicyType>
322std::tuple<real_t, real_t>
323ReinforceSolver<EnvType, PolicyType>::train_sequential_(experience_buffer_type& buffer){
326 typedef typename ReinforceMonitor<action_type,
327 state_type>::experience_tuple_type experience_tuple_type;
329 typedef std::vector<experience_tuple_type> batch_type;
// Pull the stored experience out of the buffer as a vector of tuples.
332 auto batch = buffer.template get<batch_type>();
// Presumably extracts tuple element 2 (rewards) and element 4 (log-probabilities)
// from every experience tuple -- confirm against ReinforceMonitor::get.
335 auto reward_batch = monitor_.template get<real_t, 2>(batch);
336 auto log_probs_batch = monitor_.template get<torch_tensor_t, 4>(batch);
// Optional step controlled by config_.normalize_rewards.
343 if(config_.normalize_rewards){
// One loss tensor per step, built from the discounted returns.
347 std::vector<torch_tensor_t> loss_vals = compute_loss_item(discounted_returns,
// Accumulates the scalar losses of the individual updates.
352 auto total_episode_loss = 0.0;
353 for(uint_t l=0; l<loss_vals.size(); ++l){
355 auto loss = loss_vals[l];
// Gradients are cleared and the optimizer stepped once per loss item.
356 policy_optimizer_ -> zero_grad();
358 policy_optimizer_ -> step();
360 total_episode_loss += loss.item().to<
real_t>();
// The reported loss is the mean over the items.
// NOTE(review): no guard for an empty loss_vals is visible; 0.0 / 0 would yield NaN.
364 return std::make_tuple(R, total_episode_loss / loss_vals.size());
// Trains the policy without a baseline, delegating either to train_batch_ or to
// train_sequential_. In both cases the resulting loss and reward are appended to
// the monitor and returned to the caller as (episode_reward, total_episode_loss).
// NOTE(review): the condition that selects between the two paths is not
// visible in this listing.
367template<
typename EnvType,
typename PolicyType>
368std::tuple<real_t, real_t>
369ReinforceSolver<EnvType, PolicyType>::train_without_baseline_(experience_buffer_type& buffer){
// Path 1: a single update for the whole episode.
373 auto [episode_reward, total_episode_loss] = train_batch_(buffer);
374 monitor_.policy_loss_values.push_back(total_episode_loss);
375 monitor_.rewards.push_back(episode_reward);
376 return std::make_tuple(episode_reward, total_episode_loss);
// Path 2: one update per loss item.
380 auto [episode_reward, total_episode_loss] = train_sequential_(buffer);
381 monitor_.policy_loss_values.push_back(total_episode_loss);
382 monitor_.rewards.push_back(episode_reward);
383 return std::make_tuple(episode_reward, total_episode_loss);
// Trains the policy with a baseline applied to the discounted returns, using a
// single optimizer step for the whole episode. Records the loss and reward in
// the monitor and returns the pair (R, total_episode_loss).
// NOTE(review): this listing is incomplete -- the definitions of `R` and
// `discounted_returns`, the remaining arguments of several calls and the
// statement between zero_grad() and step() (presumably loss.backward()) are
// not visible here.
388template<
typename EnvType,
typename PolicyType>
389std::tuple<real_t, real_t>
390ReinforceSolver<EnvType, PolicyType>::train_with_baseline_(experience_buffer_type& buffer){
393 typedef typename ReinforceMonitor<action_type,
394 state_type>::experience_tuple_type experience_tuple_type;
395 typedef std::vector<experience_tuple_type> batch_type;
// Pull the stored experience out of the buffer as a vector of tuples.
398 auto batch = buffer.template get<batch_type>();
// Presumably extracts tuple element 2 (rewards) -- confirm against ReinforceMonitor::get.
399 auto reward_batch = monitor_.template get<real_t, 2>(batch);
// Baseline selection by config_.baseline_type: a fixed constant, ...
404 if(config_.baseline_type == BaselineEnumType::CONSTANT){
405 discounted_returns = compute_baseline_with_constant(discounted_returns,
406 config_.baseline_constant);
// ... the mean of the returns, ...
408 else if(config_.baseline_type == BaselineEnumType::MEAN){
409 discounted_returns = compute_baseline_with_mean(discounted_returns);
// ... or, otherwise, standardization of the returns.
412 discounted_returns = compute_baseline_with_standardization(discounted_returns,
// Presumably extracts tuple element 4 (log-probabilities) -- confirm against ReinforceMonitor::get.
416 auto log_probs_batch = monitor_.template get<torch_tensor_t, 4>(batch);
// One loss tensor per step, built from the baseline-adjusted returns.
417 std::vector<torch_tensor_t> loss_vals = compute_loss_item(discounted_returns,
// Combine the per-step losses into one tensor that is optimized in a single step.
421 auto loss = cuberl::utils::pytorch::TorchAdaptor::stack(loss_vals,
424 policy_optimizer_ -> zero_grad();
426 policy_optimizer_ -> step();
// Convert the scalar loss tensor to real_t for reporting.
428 auto total_episode_loss = loss.item().to<
real_t>();
// Unlike train_batch_ / train_sequential_, this function updates the monitor itself.
434 monitor_.policy_loss_values.push_back(total_episode_loss);
435 monitor_.rewards.push_back(R);
437 return std::make_tuple(R, total_episode_loss);
double real_t
real_t
Definition bitrl_types.h:23
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
PolicyType
Definition policy_type.h:8
std::iterator_traits< IteratorType >::value_type sum(IteratorType begin, IteratorType end, bool parallel=true)
Definition vector_math.h:98
std::vector< T > normalize_max(const std::vector< T > &vec)
Definition vector_math.h:564
std::vector< T > calculate_step_discounted_return(const std::vector< T > &rewards, T gamma)
Given an array of rewards, for each entry calculate the following: $$G = \sum_{k=t+1}^{T} \gamma^{k-t-1} R_k$$
Definition utils.h:161
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
std::pair< uint_t, uint_t > state_type
Definition example_15.cpp:28
int R
Definition extended_kalman_filter.py:54
dict action
Definition play.py:41
reward
Definition play.py:44
info
Definition play.py:44
dict policy
Definition play.py:26
bitrl::envs::gymnasium::CliffWorld env_type
Definition rl_example_10.cpp:32