bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
rl_serial_agent_trainer.h
Go to the documentation of this file.
1#ifndef RL_SERIAL_AGENT_TRAINER_H
2#define RL_SERIAL_AGENT_TRAINER_H
3
6
9
10#include <boost/noncopyable.hpp>
11#include <boost/log/trivial.hpp>
12#include <vector>
13#include <chrono>
14//#include <iostream>
15
16namespace cuberl {
17namespace rl {
18
19// forward declare
20struct EpisodeInfo;
21
// NOTE(review): this is a Doxygen source listing; the embedded numbers are the
// original file's line numbers. Doxygen comment blocks and several member
// declarations are elided from this view (e.g. original lines 99-118, which per
// the symbol index declare output_msg_frequency_, itr_ctrl_ and agent_).
32
33
// RLSerialAgentTrainer: serial (single-threaded) trainer that drives an
// AgentType over an EnvType episode by episode via train(). Non-copyable
// (private boost::noncopyable base).
38template<typename EnvType, typename AgentType>
39class RLSerialAgentTrainer: private boost::noncopyable
40{
41public:
42
 // convenience typedefs re-exporting the template parameters
43 typedef EnvType env_type;
44 typedef AgentType agent_type;
45
52
58
64
70
 // hook invoked after each training episode ends; the default implementation
 // forwards to the wrapped agent (see definition further below)
75 virtual void actions_after_episode_ends(env_type&, uint_t /*episode_idx*/,
76 const EpisodeInfo& einfo);
77
83
 // read-only access to the reward accumulated in each training episode.
 // NOTE(review): original body line 89 is missing from this listing;
 // presumably {return total_reward_per_episode_;} — confirm in the source.
88 const std::vector<real_t>& episodes_total_rewards()const noexcept
90
 // read-only access to the number of iterations performed per episode
95 const std::vector<uint_t>& n_itrs_per_episode()const noexcept
96 {return n_itrs_per_episode_;}
97
98protected:
99
104
110
115
 // reward accumulated in each training episode (filled by train())
119 std::vector<real_t> total_reward_per_episode_;
120
 // number of iterations performed in each training episode (filled by train())
125 std::vector<uint_t> n_itrs_per_episode_;
126
127};
128
// Constructor: copies the relevant trainer configuration values and stores a
// reference to the agent to be trained (the caller must keep the agent alive
// for the trainer's lifetime).
// NOTE(review): original line 130 (the qualified signature) is missing from
// this listing; per the symbol index it is
// RLSerialAgentTrainer(const RLSerialTrainerConfig& config, agent_type& agent).
129template<typename EnvType, typename AgentType>
131 agent_type& agent)
132 :
 // how often (in episodes) a progress message is logged; see train()
133 output_msg_frequency_(config.output_msg_frequency),
 // iteration controller bounded by the configured n_episodes / tolerance
134 itr_ctrl_(config.n_episodes, config.tolerance),
135 agent_(agent),
 // bookkeeping vectors start empty; train() clears and reserves them
136 total_reward_per_episode_(),
137 n_itrs_per_episode_()
138{}
139
// actions_before_training_begins: runs the agent's own pre-training hook,
// resets the per-episode bookkeeping, and reserves capacity for the maximum
// number of episodes so the push_backs in train() do not reallocate.
// NOTE(review): original line 142 (the qualified signature, presumably taking
// env_type& env) is missing from this listing — confirm in the source.
140template<typename EnvType, typename AgentType>
141void
143
144 agent_.actions_before_training_begins(env);
145 total_reward_per_episode_.clear();
146 n_itrs_per_episode_.clear();
147
148 total_reward_per_episode_.reserve(itr_ctrl_.get_max_iterations());
149 n_itrs_per_episode_.reserve(itr_ctrl_.get_max_iterations());
150}
151
// actions_before_episode_begins: forwards to the agent's per-episode pre-hook
// before each training episode starts.
// NOTE(review): original line 154 (the first half of the qualified signature,
// presumably taking env_type& env) is missing from this listing.
152template<typename EnvType, typename AgentType>
153void
155 uint_t episode_idx){
156 agent_.actions_before_episode_begins(env, episode_idx);
157}
158
// actions_after_episode_ends: forwards the finished episode's result to the
// agent's post-episode hook.
// NOTE(review): original line 161 (the first half of the qualified signature,
// presumably taking env_type& env and uint_t episode_idx) is missing from
// this listing.
159template<typename EnvType, typename AgentType>
160void
162 const EpisodeInfo& einfo){
163 agent_.actions_after_episode_ends(env, episode_idx, einfo);
164}
165
// actions_after_training_ends: forwards to the agent's post-training hook
// once all episodes have finished.
// NOTE(review): original line 168 (the qualified signature, presumably taking
// env_type& env) is missing from this listing.
166template<typename EnvType, typename AgentType>
167void
169 agent_.actions_after_training_ends(env);
170}
171
// train: main serial training loop. Repeatedly asks the agent to run one
// training episode until the iteration controller says stop or the agent
// requests early termination; records per-episode reward and iteration
// counts and times the whole run. Returns the controller's state (a
// bitrl::utils::IterativeAlgorithmResult per the symbol index) with
// total_time and converged filled in.
// NOTE(review): original lines 173-174 (the qualified signature) are missing
// from this listing.
172template<typename EnvType, typename AgentType>
175

176 BOOST_LOG_TRIVIAL(info)<<" Start training on environment..."; //<<env.name;

177
 // start timing the training
178 // start timing the training
179 auto start = std::chrono::steady_clock::now();
180
181 this->actions_before_training_begins(env);
182
183 uint_t episode_counter = 0;
184 bool stop_training = false;
 // the controller decides when to stop; presumably continue_iterations()
 // also advances its internal counter — confirm in
 // iterative_algorithm_controller.h
185 while(itr_ctrl_.continue_iterations()){
186
187 this->actions_before_episode_begins(env, episode_counter);
188 auto episode_info = agent_.on_training_episode(env, episode_counter);
189
 // log progress only when a valid frequency was configured
 // (INVALID_ID disables the periodic message entirely)
190 if(output_msg_frequency_ != bitrl::consts::INVALID_ID &&
191 episode_counter % output_msg_frequency_ == 0){
192
193 BOOST_LOG_TRIVIAL(info)<<episode_info;
194 }
195
196 total_reward_per_episode_.push_back(episode_info.episode_reward);
197 n_itrs_per_episode_.push_back(episode_info.episode_iterations);
198 this->actions_after_episode_ends(env, episode_counter, episode_info);
199
 // the agent can request early termination via the episode result
200 if(episode_info.stop_training){
201 BOOST_LOG_TRIVIAL(info)<<" Stopping training at index="<<episode_counter;
202
203 // assume that if we were told to stop
204 // that we have converged
205 stop_training = true;
206 break;
207 }
208 episode_counter += 1;
209 }
210
211 this->actions_after_training_ends(env);
212 auto end = std::chrono::steady_clock::now();
213 std::chrono::duration<real_t> elapsed_seconds = end-start;
214
215 BOOST_LOG_TRIVIAL(info)<<" Done... ";
216
 // augment the controller's result with timing and the convergence flag
 // (early stop is treated as convergence, see the comment above)
217 auto state = itr_ctrl_.get_state();
218 state.total_time = elapsed_seconds;
219 state.converged = stop_training;
220 return state;
221}
222
223
224}
225}
226
227#endif // RL_SERIAL_AGENT_TRAINER_H
Controller for iterative algorithms.
Definition iterative_algorithm_controller.h:17
Definition rl_serial_agent_trainer.h:40
virtual bitrl::utils::IterativeAlgorithmResult train(env_type &env)
train Iterate to train the agent on the given environment
Definition rl_serial_agent_trainer.h:174
EnvType env_type
Definition rl_serial_agent_trainer.h:43
const std::vector< uint_t > & n_itrs_per_episode() const noexcept
n_itrs_per_episode
Definition rl_serial_agent_trainer.h:95
const std::vector< real_t > & episodes_total_rewards() const noexcept
episodes_total_rewards
Definition rl_serial_agent_trainer.h:88
std::vector< real_t > total_reward_per_episode_
total_reward_per_episode_
Definition rl_serial_agent_trainer.h:119
agent_type & agent_
agent_
Definition rl_serial_agent_trainer.h:114
uint_t output_msg_frequency_
Definition rl_serial_agent_trainer.h:103
virtual void actions_after_training_ends(env_type &)
actions_after_training_ends. Execute any actions the algorithm needs after the iterations are finished.
Definition rl_serial_agent_trainer.h:168
virtual void actions_before_training_begins(env_type &)
actions_before_training_begins. Execute any actions the algorithm needs before the training begins
Definition rl_serial_agent_trainer.h:142
RLSerialAgentTrainer(const RLSerialTrainerConfig &config, agent_type &agent)
RLSerialAgentTrainer.
Definition rl_serial_agent_trainer.h:130
std::vector< uint_t > n_itrs_per_episode_
n_itrs_per_episode_ Holds the number of iterations performed per training episode
Definition rl_serial_agent_trainer.h:125
AgentType agent_type
Definition rl_serial_agent_trainer.h:44
virtual void actions_before_episode_begins(env_type &, uint_t)
actions_before_episode_begins. Execute any actions the algorithm needs before starting the episode
Definition rl_serial_agent_trainer.h:154
bitrl::utils::IterativeAlgorithmController itr_ctrl_
itr_ctrl_ Handles the iteration over the episodes
Definition rl_serial_agent_trainer.h:109
virtual void actions_after_episode_ends(env_type &, uint_t, const EpisodeInfo &einfo)
actions_after_episode_ends. Execute any actions the algorithm needs after ending the episode
Definition rl_serial_agent_trainer.h:161
const uint_t INVALID_ID
Invalid id.
Definition bitrl_consts.h:21
const real_t TOLERANCE
Tolerance used around the library.
Definition bitrl_consts.h:31
double real_t
real_t
Definition bitrl_types.h:23
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
The IterativeAlgorithmResult struct. Helper struct to assemble the result of an iterative algorithm.
Definition iterative_algorithm_result.h:19
The EpisodeInfo struct.
Definition episode_info.h:19
The RLSerialTrainerConfig struct. Configuration struct for the serial RL agent trainer.
Definition rl_serial_agent_trainer.h:27
real_t tolerance
Definition rl_serial_agent_trainer.h:30
uint_t output_msg_frequency
Definition rl_serial_agent_trainer.h:28
uint_t n_episodes
Definition rl_serial_agent_trainer.h:29