bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
actor_critic_solver_base.h
Go to the documentation of this file.
1#ifndef ACTOR_CRITIC_SOLVER_BASE_H
2#define ACTOR_CRITIC_SOLVER_BASE_H
3
10
11#include "cuberl/base/cubeai_config.h"
12
13#ifdef USE_PYTORCH
14
22#include "cuberl/data_structs/experience_buffer.h"
23
24
25#include <torch/torch.h>
26
27
28#ifdef CUBERL_DEBUG
29#include <cassert>
30#include <boost/log/trivial.hpp>
31#endif
32
33#include <string>
34#include <chrono>
35#include <map>
36#include <any>
37#include <memory>
38#include <tuple>
39#include <string>
40#include <exception>
41#include <iostream>
42
43namespace cuberl{
44namespace rl{
45namespace algos {
46namespace pg {
47
61template<typename EnvType, typename PolicyType,
62 typename CriticType, typename MonitorType,
63 typename ConfigType>
64class ACSolverBase: public RLSolverBase<EnvType>
65{
66public:
67
71 typedef EnvType env_type;
72
76 typedef PolicyType policy_type;
77
81 typedef CriticType critic_type;
82
83 typedef typename env_type::state_type state_type;
84 typedef typename env_type::action_type action_type;
85
89 typedef MonitorType monitor_type;
90
92 typedef typename monitor_type::experience_buffer_type experience_buffer_type;
93 typedef typename monitor_type::experience_tuple_type experience_tuple_type;
94
98 typedef ConfigType config_type;
99
103 virtual ~ACSolverBase()=default;
104
109 virtual void actions_after_training_ends(env_type&){}
110
114 virtual void actions_before_episode_begins(env_type&,
115 uint_t /*episode_idx*/){}
116
120 virtual void actions_after_episode_ends(env_type&,
121 uint_t /*episode_idx*/,
122 const EpisodeInfo&){}
123
128 virtual void actions_before_training_begins(env_type&);
129
133 void set_train_mode()noexcept;
134
138 void set_evaluation_mode()noexcept;
139
143 monitor_type& get_monitor(){return monitor_;}
144
145
146protected:
147
153 ACSolverBase(const config_type& config,
154 policy_type& policy, critic_type& critic,
155 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer,
156 std::unique_ptr<torch::optim::Optimizer>& critic_optimizer);
157
161 config_type config_;
162
166 policy_type& policy_;
167
171 critic_type& critic_;
172
176 monitor_type monitor_;
177
181 std::unique_ptr<torch::optim::Optimizer> policy_optimizer_;
182
186 std::unique_ptr<torch::optim::Optimizer> critic_optimizer_;
187
188
193 uint_t
194 create_episode_batch_(env_type& env, uint_t /*episode_idx*/, experience_buffer_type& buffer);
195
196};
197
198template<typename EnvType, typename PolicyType,
199 typename CriticType, typename MonitorType,
200 typename ConfigType>
201ACSolverBase<EnvType, PolicyType, CriticType,
202 MonitorType, ConfigType>::ACSolverBase(const config_type& config,
203 policy_type& policy, critic_type& critic,
204 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer,
205 std::unique_ptr<torch::optim::Optimizer>& critic_optimizer)
206 :
207 RLSolverBase<EnvType>(),
208 config_(config),
209 policy_(policy),
210 critic_(critic),
211 monitor_(),
212 policy_optimizer_(std::move(policy_optimizer)),
213 critic_optimizer_(std::move(critic_optimizer))
214{}
215
216template<typename EnvType, typename PolicyType,
217 typename CriticType, typename MonitorType,
218 typename ConfigType>
219void
220ACSolverBase<EnvType, PolicyType, CriticType,
221 MonitorType, ConfigType>::set_train_mode()noexcept{
222 policy_ -> train();
223 critic_ -> train();
224
225}
226
227template<typename EnvType, typename PolicyType,
228 typename CriticType, typename MonitorType,
229 typename ConfigType>
230void
231ACSolverBase<EnvType, PolicyType, CriticType,
232 MonitorType, ConfigType>::set_evaluation_mode()noexcept{
233 policy_ -> eval();
234 critic_ -> eval();
235
236}
237
238template<typename EnvType, typename PolicyType,
239 typename CriticType, typename MonitorType,
240 typename ConfigType>
241void
242ACSolverBase<EnvType, PolicyType, CriticType,
243 MonitorType, ConfigType>::actions_before_training_begins(env_type& /*env*/){
244
245 monitor_.reset();
246 monitor_.policy_loss_values.reserve(config_.n_episodes);
247 monitor_.critic_loss_values.reserve(config_.n_episodes);
248 monitor_.rewards.reserve(config_.n_episodes);
249 monitor_.episode_duration.reserve(config_.n_episodes);
250 set_train_mode();
251}
252
253template<typename EnvType, typename PolicyType,
254 typename CriticType, typename MonitorType,
255 typename ConfigType>
256uint_t
257ACSolverBase<EnvType, PolicyType, CriticType,
258 MonitorType, ConfigType>::create_episode_batch_(env_type& env, uint_t episode_idx, experience_buffer_type& buffer)
259{
260
261#ifdef CUBERL_DEBUG
262BOOST_LOG_TRIVIAL(info)<<"Collecting batch for episode: "<<episode_idx;
263#endif
264
267
268 typedef typename MonitorType::experience_tuple_type experience_tuple_type;
269
270 // reset the environment
271 // for every episode reset the environment
272 auto old_timestep = env.reset();
273
274 // loop over the iterations
275 uint_t itrs = 0;
276 for(; itrs < config_.max_itrs_per_episode; ++itrs){
277
278 auto [action, log_prob] = policy_ -> act(old_timestep.observation());
279 auto values = critic_ -> evaluate(old_timestep.observation());
280
281 // step into the environment
282 auto next_time_step = env.step(action);
283 auto next_state = next_time_step.observation();
284 auto reward = next_time_step.reward();
285
286 experience_tuple_type exp = {old_timestep.observation(),
287 action,
288 reward,
289 next_time_step.done(),
290 log_prob,
291 values};
292
293 // put the observation into the buffer
294 buffer.append(exp);
295
296 // if the step is done then break
297 if (next_time_step.done()){
298 break;
299 }
300
301 old_timestep = next_time_step;
302
303 }
304
305#ifdef CUBERL_DEBUG
306BOOST_LOG_TRIVIAL(info)<<"Done... ";
307#endif
308
309
310 return itrs + 1;
311}
312
313}
314}
315}
316}
317
318#endif // USE_PYTORCH
319#endif
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
PolicyType
Definition policy_type.h:8
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
std::pair< uint_t, uint_t > state_type
Definition example_15.cpp:28
dict action
Definition play.py:41
reward
Definition play.py:44
env
Definition play.py:30
dict policy
Definition play.py:26
list values
Definition plot_losses.py:13
bitrl::envs::gymnasium::CliffWorld env_type
Definition rl_example_10.cpp:32