1#ifndef ACTOR_CRITIC_SOLVER_BASE_H
2#define ACTOR_CRITIC_SOLVER_BASE_H
11#include "cuberl/base/cubeai_config.h"
22#include "cuberl/data_structs/experience_buffer.h"
25#include <torch/torch.h>
30#include <boost/log/trivial.hpp>
62 typename CriticType,
typename MonitorType,
64class ACSolverBase:
public RLSolverBase<EnvType>
81 typedef CriticType critic_type;
83 typedef typename env_type::state_type
state_type;
84 typedef typename env_type::action_type action_type;
89 typedef MonitorType monitor_type;
92 typedef typename monitor_type::experience_buffer_type experience_buffer_type;
93 typedef typename monitor_type::experience_tuple_type experience_tuple_type;
98 typedef ConfigType config_type;
103 virtual ~ACSolverBase()=
default;
109 virtual void actions_after_training_ends(env_type&){}
114 virtual void actions_before_episode_begins(env_type&,
120 virtual void actions_after_episode_ends(env_type&,
122 const EpisodeInfo&){}
128 virtual void actions_before_training_begins(env_type&);
133 void set_train_mode()noexcept;
138 void set_evaluation_mode()noexcept;
143 monitor_type& get_monitor(){
return monitor_;}
153 ACSolverBase(
const config_type& config,
154 policy_type& policy, critic_type& critic,
155 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer,
156 std::unique_ptr<torch::optim::Optimizer>& critic_optimizer);
166 policy_type& policy_;
171 critic_type& critic_;
176 monitor_type monitor_;
181 std::unique_ptr<torch::optim::Optimizer> policy_optimizer_;
186 std::unique_ptr<torch::optim::Optimizer> critic_optimizer_;
194 create_episode_batch_(env_type& env, uint_t , experience_buffer_type& buffer);
198template<
typename EnvType,
typename PolicyType,
199 typename CriticType,
typename MonitorType,
202 MonitorType, ConfigType>::ACSolverBase(
const config_type& config,
203 policy_type& policy, critic_type& critic,
204 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer,
205 std::unique_ptr<torch::optim::Optimizer>& critic_optimizer)
207 RLSolverBase<EnvType>(),
212 policy_optimizer_(std::move(policy_optimizer)),
213 critic_optimizer_(std::move(critic_optimizer))
216template<
typename EnvType,
typename PolicyType,
217 typename CriticType,
typename MonitorType,
221 MonitorType, ConfigType>::set_train_mode()noexcept{
227template<
typename EnvType,
typename PolicyType,
228 typename CriticType,
typename MonitorType,
232 MonitorType, ConfigType>::set_evaluation_mode()noexcept{
238template<
typename EnvType,
typename PolicyType,
239 typename CriticType,
typename MonitorType,
243 MonitorType, ConfigType>::actions_before_training_begins(env_type& ){
246 monitor_.policy_loss_values.reserve(config_.n_episodes);
247 monitor_.critic_loss_values.reserve(config_.n_episodes);
248 monitor_.rewards.reserve(config_.n_episodes);
249 monitor_.episode_duration.reserve(config_.n_episodes);
253template<
typename EnvType,
typename PolicyType,
254 typename CriticType,
typename MonitorType,
258 MonitorType, ConfigType>::create_episode_batch_(env_type& env, uint_t episode_idx, experience_buffer_type& buffer)
262BOOST_LOG_TRIVIAL(info)<<
"Collecting batch for episode: "<<episode_idx;
268 typedef typename MonitorType::experience_tuple_type experience_tuple_type;
272 auto old_timestep =
env.reset();
276 for(; itrs < config_.max_itrs_per_episode; ++itrs){
278 auto [
action, log_prob] = policy_ -> act(old_timestep.observation());
279 auto values = critic_ -> evaluate(old_timestep.observation());
282 auto next_time_step =
env.step(action);
283 auto next_state = next_time_step.observation();
284 auto reward = next_time_step.reward();
286 experience_tuple_type exp = {old_timestep.observation(),
289 next_time_step.done(),
297 if (next_time_step.done()){
301 old_timestep = next_time_step;
306BOOST_LOG_TRIVIAL(info)<<
"Done... ";
// NOTE(review): the lines below are Doxygen cross-reference index residue
// carried over by the HTML/text extraction; they are not part of the
// header itself. Converted to comments so the file stays parseable and
// retained verbatim for traceability:
//   std::size_t uint_t — Definition bitrl_types.h:43
//   PolicyType — Definition policy_type.h:8
//   Various utilities used when working with RL problems. — Definition cuberl_types.h:16
//   std::pair<uint_t, uint_t> state_type — Definition example_15.cpp:28
//   dict action — Definition play.py:41
//   reward — Definition play.py:44
//   dict policy — Definition play.py:26
//   list values — Definition plot_losses.py:13
//   bitrl::envs::gymnasium::CliffWorld env_type — Definition rl_example_10.cpp:32