bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
a2c.h
Go to the documentation of this file.
1#ifndef A2C_H
2#define A2C_H
3
10
11#include "cuberl/base/cubeai_config.h"
12
13#ifdef USE_PYTORCH
14
22#include "cuberl/data_structs/experience_buffer.h"
23
24
25#include <torch/torch.h>
26
27#ifdef CUBERL_DEBUG
28#include <cassert>
29#endif
30
31#include <chrono>
32#include <memory>
33#include <tuple>
34
35
36namespace cuberl{
37namespace rl::algos::pg
38{
39
///
/// \brief Advantage Actor-Critic (A2C) solver. Trains an actor (policy)
/// network and a critic (value) network on the given environment by
/// rolling out one episode at a time into an experience buffer and then
/// updating both networks from the collected batch.
///
/// \tparam EnvType    Environment type; must expose reset()/step() and
///                    the state_type/action_type typedefs used below.
/// \tparam PolicyType Actor network type; the implementation calls
///                    act(), train(), eval() and parameters() through
///                    operator-> (torch-module-holder style).
/// \tparam CriticType Critic network type; the implementation calls
///                    evaluate(), train(), eval() and parameters().
///
template<typename EnvType, typename PolicyType, typename CriticType>
class A2CSolver final: public RLSolverBase<EnvType>
{
public:

    /// \brief The environment type the solver is trained on
    typedef EnvType env_type;

    /// \brief The actor (policy) network type
    typedef PolicyType policy_type;

    /// \brief The critic (value-function) network type
    typedef CriticType critic_type;

    /// \brief State and action types re-exported from the environment
    typedef typename env_type::state_type state_type;
    typedef typename env_type::action_type action_type;

    /// \brief Buffer type used to accumulate one episode of experience
    typedef typename A2CMonitor<action_type,
                                state_type>::experience_buffer_type experience_buffer_type;

    /// \brief Constructor.
    /// NOTE(review): the optimizers are passed by lvalue reference but
    /// are moved-from in the constructor body, so the caller's
    /// unique_ptrs are left empty after construction — confirm callers
    /// expect this ownership transfer.
    A2CSolver(const A2CConfig& config,
              policy_type& policy, critic_type& critic,
              std::unique_ptr<torch::optim::Optimizer>& policy_optimizer,
              std::unique_ptr<torch::optim::Optimizer>& critic_optimizer);

    /// \brief Reset the monitor, reserve its per-episode storage and
    /// switch both networks to training mode. Called once before
    /// training starts.
    virtual void actions_before_training_begins(env_type&);

    /// \brief No-op hook invoked once after training ends
    virtual void actions_after_training_ends(env_type&) override final{}

    /// \brief No-op hook invoked before every episode begins
    virtual void actions_before_episode_begins(env_type&,
                                               uint_t /*episode_idx*/) override final{}

    /// \brief No-op hook invoked after every episode ends
    virtual void actions_after_episode_ends(env_type&,
                                            uint_t /*episode_idx*/,
                                            const EpisodeInfo&) override final{}

    /// \brief Run one training episode: collect an experience batch and
    /// train actor and critic with it. Returns per-episode statistics.
    virtual EpisodeInfo on_training_episode(env_type&, uint_t /*episode_idx*/);

    /// \brief Put both networks in training mode
    void set_train_mode()noexcept;

    /// \brief Put both networks in evaluation (inference) mode
    void set_evaluation_mode()noexcept;

    /// \brief Access the monitor that accumulates training statistics
    A2CMonitor<action_type, state_type>& get_monitor(){return monitor_;}

private:

    /// \brief Algorithm configuration (gamma, episode/iteration caps,
    /// gradient-clipping flags and norms, device type)
    A2CConfig config_;

    /// \brief Reference to the actor network (not owned)
    policy_type& policy_;

    /// \brief Reference to the critic network (not owned)
    critic_type& critic_;

    /// \brief Accumulates loss values, rewards and episode durations
    A2CMonitor<action_type, state_type> monitor_;

    /// \brief Owned optimizer for the actor parameters
    std::unique_ptr<torch::optim::Optimizer> policy_optimizer_;

    /// \brief Owned optimizer for the critic parameters
    std::unique_ptr<torch::optim::Optimizer> critic_optimizer_;

    /// \brief Roll out one episode into the given buffer; returns the
    /// number of environment steps executed
    uint_t create_episode_batch_(env_type&,
                                 uint_t /*episode_idx*/,
                                 experience_buffer_type& buffer);

    /// \brief Train both networks from the collected buffer; returns
    /// (undiscounted episode reward, actor loss + critic loss)
    std::tuple<real_t, real_t>
    train_with_batch_(experience_buffer_type& buffer);

};
178
///
/// \brief Construct the solver. Copies the configuration, stores
/// references to the externally-owned actor and critic networks, and
/// takes OWNERSHIP of both optimizers: the passed-in unique_ptrs are
/// moved-from and left empty in the caller.
///
template<typename EnvType, typename PolicyType, typename CriticType>
A2CSolver<EnvType, PolicyType, CriticType>::A2CSolver(const A2CConfig& config,
                                                      policy_type& policy, critic_type& critic,
                                                      std::unique_ptr<torch::optim::Optimizer>& policy_optimizer,
                                                      std::unique_ptr<torch::optim::Optimizer>& critic_optimizer)
    :
    config_(config),
    policy_(policy),
    critic_(critic),
    monitor_(),
    // ownership transfer: the caller's unique_ptrs become null here
    policy_optimizer_(std::move(policy_optimizer)),
    critic_optimizer_(std::move(critic_optimizer))
{}
192
193 template<typename EnvType, typename PolicyType, typename CriticType>
194 void
195 A2CSolver<EnvType, PolicyType, CriticType>::set_train_mode()noexcept{
196 policy_ -> train();
197 critic_ -> train();
198
199 }
200
201 template<typename EnvType, typename PolicyType, typename CriticType>
202 void
203 A2CSolver<EnvType, PolicyType, CriticType>::set_evaluation_mode()noexcept{
204 policy_ -> eval();
205 critic_ -> eval();
206
207 }
208
209 template<typename EnvType, typename PolicyType, typename CriticType>
210 void
211 A2CSolver<EnvType, PolicyType, CriticType>::actions_before_training_begins(env_type& /*env*/){
212
213 monitor_.reset();
214 monitor_.policy_loss_values.reserve(config_.n_episodes);
215 monitor_.critic_loss_values.reserve(config_.n_episodes);
216 monitor_.rewards.reserve(config_.n_episodes);
217 monitor_.episode_duration.reserve(config_.n_episodes);
218 set_train_mode();
219 }
220
221 template<typename EnvType, typename PolicyType, typename CriticType>
222 EpisodeInfo
223 A2CSolver<EnvType, PolicyType, CriticType>::on_training_episode(env_type& env, uint_t episode_idx){
224
225 auto start = std::chrono::steady_clock::now();
226
227 // the buffer to use
228 experience_buffer_type buffer(config_.max_itrs_per_episode);
229
230 // collect the buffer
231 auto eps_itrs = create_episode_batch_(env, episode_idx, buffer);
232
233 // train the networks with from the
234 // collected buffer
235 auto [episode_reward, total_episode_loss] = train_with_batch_(buffer);
236
237 auto end = std::chrono::steady_clock::now();
238 std::chrono::duration<real_t> elapsed_seconds = end - start;
239
240 monitor_.episode_duration.push_back(eps_itrs);
241
242 EpisodeInfo info;
243 info.episode_index = episode_idx;
244 info.episode_reward = episode_reward;
245 info.episode_iterations = eps_itrs;
246 info.total_time = elapsed_seconds;
247 return info;
248 }
249
250 template<typename EnvType, typename PolicyType, typename CriticType>
251 uint_t
252 A2CSolver<EnvType, PolicyType, CriticType>::create_episode_batch_(env_type& env,
253 uint_t /*episode_idx*/,
254 experience_buffer_type& buffer){
257 typedef typename A2CMonitor<action_type,
258 state_type>::experience_tuple_type experience_tuple_type;
259
260 // reset the environment
261 // for every episode reset the environment
262 auto old_timestep = env.reset();
263
264 // loop over the iterations
265 uint_t itrs = 0;
266 for(; itrs < config_.max_itrs_per_episode; ++itrs){
267
268 auto [action, log_prob] = policy_ -> act(old_timestep.observation());
269 auto values = critic_ -> evaluate(old_timestep.observation());
270
271 // step into the environment
272 auto next_time_step = env.step(action);
273
274 auto next_state = next_time_step.observation();
275 auto reward = next_time_step.reward();
276
277 experience_tuple_type exp = {old_timestep.observation(),
278 action,
279 reward,
280 next_time_step.done(),
281 log_prob,
282 values};
283
284 // put the observation into the buffer
285 buffer.append(exp);
286
287 if (next_time_step.done()){
288 break;
289 }
290
291 old_timestep = next_time_step;
292
293 }
294
295 return itrs + 1;
296 }
297
298
299 template<typename EnvType,typename PolicyType, typename CriticType>
300 std::tuple<real_t, real_t>
301 A2CSolver<EnvType, PolicyType, CriticType>::train_with_batch_(experience_buffer_type& buffer){
302
303
304 // because of the way we treat the values
305 // we loose the requires_grad so we need to set it
306 using namespace cuberl::utils::pytorch;
307
308 typedef typename A2CMonitor<action_type,
309 state_type>::experience_tuple_type experience_tuple_type;
310 typedef std::vector<experience_tuple_type> batch_type;
311
312 // the batch for this episode
313 auto batch = buffer.template get<batch_type>();
314 auto rewards_batch = monitor_.template get<real_t, 2>(batch);
315 auto values_batch = monitor_.template get<torch_tensor_t, 5>(batch);
316 auto logprobs_batch = monitor_.template get<torch_tensor_t, 4>(batch);
317
318
319 // compute the discounted rewards for this batch
320 auto discounted_returns = cuberl::rl::algos::calculate_step_discounted_return(rewards_batch,
321 config_.gamma);
322
323 auto torch_rewards_batch = TorchAdaptor::to_torch(discounted_returns,
324 config_.device_type,
325 false);
326
327 auto torch_values_batch = TorchAdaptor::stack(values_batch,
328 config_.device_type
329 );
330
331 auto torch_logprobs_batch = TorchAdaptor::stack(logprobs_batch,
332 config_.device_type);
333
334 // form the advantage
335 auto advantage = torch_rewards_batch - torch_values_batch;
336
337 // take the mean because we collect batches
338 auto actor_loss = -(torch_logprobs_batch * advantage.detach()).mean();
339 auto critic_loss = advantage.pow(2).mean();
340
341 if(config_.clip_policy_grad){
342
343 // clip the grad if needed
344 torch::nn::utils::clip_grad_norm_(policy_->parameters(),
345 config_.max_grad_norm_policy);
346
347 }
348
349
350 if(config_.clip_critic_grad){
351 torch::nn::utils::clip_grad_norm_(critic_->parameters(),
352 config_.max_grad_norm_critic);
353
354 }
355
356 // Backward pass and optimize
357 policy_optimizer_->zero_grad();
358 critic_optimizer_ -> zero_grad();
359
360 actor_loss.backward();
361 critic_loss.backward();
362
363 policy_optimizer_ -> step();
364 critic_optimizer_ -> step();
365
366
367 auto total_episode_policy_loss = actor_loss.item().template to<real_t>();
368 auto total_episode_critic_loss = critic_loss.item().template to<real_t>();
369
370 // compute the undiscounted reward as the reward
371 // for this episode
372 auto R = cuberl::maths::sum(rewards_batch);
373
374 monitor_.policy_loss_values.push_back(total_episode_policy_loss);
375 monitor_.critic_loss_values.push_back(total_episode_critic_loss);
376 monitor_.rewards.push_back(R);
377
378 return std::make_tuple(R, total_episode_policy_loss + total_episode_critic_loss);
379
380 }
381
382
383
384}
385}
386#endif
387#endif // A2C_H
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
PolicyType
Definition policy_type.h:8
real_t mean(IteratorType begin, IteratorType end, bool parallel=true)
Compute the mean of the values in the provided iterator range.
Definition vector_math.h:126
std::iterator_traits< IteratorType >::value_type sum(IteratorType begin, IteratorType end, bool parallel=true)
Definition vector_math.h:98
std::vector< T > calculate_step_discounted_return(const std::vector< T > &rewards, T gamma)
Given an array of rewards, for each entry t calculate the discounted return $$G_t = \sum_{k=t+1}^{T} \gamma^{k-t-1} R_k$$
Definition utils.h:161
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
std::pair< uint_t, uint_t > state_type
Definition example_15.cpp:28
int R
Definition extended_kalman_filter.py:54
dict action
Definition play.py:41
reward
Definition play.py:44
info
Definition play.py:44
env
Definition play.py:30
dict policy
Definition play.py:26
list values
Definition plot_losses.py:13
bitrl::envs::gymnasium::CliffWorld env_type
Definition rl_example_10.cpp:32