1#ifndef RL_SERIAL_AGENT_TRAINER_H
2#define RL_SERIAL_AGENT_TRAINER_H
10#include <boost/noncopyable.hpp>
11#include <boost/log/trivial.hpp>
38template<
typename EnvType,
typename AgentType>
129template<
typename EnvType,
typename AgentType>
133 output_msg_frequency_(config.output_msg_frequency),
134 itr_ctrl_(config.n_episodes, config.tolerance),
136 total_reward_per_episode_(),
137 n_itrs_per_episode_()
140template<
typename EnvType,
typename AgentType>
144 agent_.actions_before_training_begins(env);
145 total_reward_per_episode_.clear();
146 n_itrs_per_episode_.clear();
148 total_reward_per_episode_.reserve(itr_ctrl_.get_max_iterations());
149 n_itrs_per_episode_.reserve(itr_ctrl_.get_max_iterations());
152template<
typename EnvType,
typename AgentType>
156 agent_.actions_before_episode_begins(env, episode_idx);
159template<
typename EnvType,
typename AgentType>
163 agent_.actions_after_episode_ends(env, episode_idx, einfo);
166template<
typename EnvType,
typename AgentType>
169 agent_.actions_after_training_ends(env);
172template<
typename EnvType,
typename AgentType>
176 BOOST_LOG_TRIVIAL(info)<<
" Start training on environment...";
179 auto start = std::chrono::steady_clock::now();
181 this->actions_before_training_begins(env);
183 uint_t episode_counter = 0;
184 bool stop_training =
false;
185 while(itr_ctrl_.continue_iterations()){
187 this->actions_before_episode_begins(env, episode_counter);
188 auto episode_info = agent_.on_training_episode(env, episode_counter);
191 episode_counter % output_msg_frequency_ == 0){
193 BOOST_LOG_TRIVIAL(info)<<episode_info;
196 total_reward_per_episode_.push_back(episode_info.episode_reward);
197 n_itrs_per_episode_.push_back(episode_info.episode_iterations);
198 this->actions_after_episode_ends(env, episode_counter, episode_info);
200 if(episode_info.stop_training){
201 BOOST_LOG_TRIVIAL(info)<<
" Stopping training at index="<<episode_counter;
205 stop_training =
true;
208 episode_counter += 1;
211 this->actions_after_training_ends(env);
212 auto end = std::chrono::steady_clock::now();
213 std::chrono::duration<real_t> elapsed_seconds = end-start;
215 BOOST_LOG_TRIVIAL(info)<<
" Done... ";
217 auto state = itr_ctrl_.get_state();
218 state.total_time = elapsed_seconds;
219 state.converged = stop_training;
Controller for iterative algorithms.
Definition iterative_algorithm_controller.h:17
Definition rl_serial_agent_trainer.h:40
virtual bitrl::utils::IterativeAlgorithmResult train(env_type &env)
train Iterate to train the agent on the given environment
Definition rl_serial_agent_trainer.h:174
EnvType env_type
Definition rl_serial_agent_trainer.h:43
const std::vector< uint_t > & n_itrs_per_episode() const noexcept
n_itrs_per_episode
Definition rl_serial_agent_trainer.h:95
const std::vector< real_t > & episodes_total_rewards() const noexcept
episodes_total_rewards
Definition rl_serial_agent_trainer.h:88
std::vector< real_t > total_reward_per_episode_
total_reward_per_episode_
Definition rl_serial_agent_trainer.h:119
agent_type & agent_
agent_
Definition rl_serial_agent_trainer.h:114
uint_t output_msg_frequency_
Definition rl_serial_agent_trainer.h:103
virtual void actions_after_training_ends(env_type &)
actions_after_training_ends. Execute any actions the algorithm needs after the iterations are finished.
Definition rl_serial_agent_trainer.h:168
virtual void actions_before_training_begins(env_type &)
actions_before_training_begins. Execute any actions the algorithm needs before training starts
Definition rl_serial_agent_trainer.h:142
RLSerialAgentTrainer(const RLSerialTrainerConfig &config, agent_type &agent)
RLSerialAgentTrainer.
Definition rl_serial_agent_trainer.h:130
std::vector< uint_t > n_itrs_per_episode_
n_itrs_per_episode_ Holds the number of iterations performed per training episode
Definition rl_serial_agent_trainer.h:125
AgentType agent_type
Definition rl_serial_agent_trainer.h:44
virtual void actions_before_episode_begins(env_type &, uint_t)
actions_before_episode_begins. Execute any actions the algorithm needs before starting the episode
Definition rl_serial_agent_trainer.h:154
bitrl::utils::IterativeAlgorithmController itr_ctrl_
itr_ctrl_ Handles the iteration over the episodes
Definition rl_serial_agent_trainer.h:109
virtual void actions_after_episode_ends(env_type &, uint_t, const EpisodeInfo &einfo)
actions_after_episode_ends. Execute any actions the algorithm needs after ending the episode
Definition rl_serial_agent_trainer.h:161
const uint_t INVALID_ID
Invalid id.
Definition bitrl_consts.h:21
const real_t TOLERANCE
Tolerance used around the library.
Definition bitrl_consts.h:31
double real_t
real_t
Definition bitrl_types.h:23
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
The IterativeAlgorithmResult struct. Helper struct to assemble the result of an iterative algorithm.
Definition iterative_algorithm_result.h:19
The EpisodeInfo struct.
Definition episode_info.h:19
The RLSerialTrainerConfig struct. Configuration struct for the serial RL agent trainer.
Definition rl_serial_agent_trainer.h:27
real_t tolerance
Definition rl_serial_agent_trainer.h:30
uint_t output_msg_frequency
Definition rl_serial_agent_trainer.h:28
uint_t n_episodes
Definition rl_serial_agent_trainer.h:29