bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
reinforce.h
Go to the documentation of this file.
1#ifndef REINFORCE_H
2#define REINFORCE_H
3
9
10#include "cuberl/base/cubeai_config.h"
11
12#ifdef USE_PYTORCH
13
22#include "cuberl/data_structs/experience_buffer.h"
24
25#include <boost/log/trivial.hpp>
26#include <torch/torch.h>
27
28#include <vector>
29
30#include <numeric>
31#include <iostream>
32#include <chrono>
33#include <memory>
34#include <tuple>
35#include <string>
36#include <iterator>
37
38namespace cuberl {
39namespace rl {
40namespace algos {
41namespace pg {
42
54template<typename EnvType, typename PolicyType>
55class ReinforceSolver final: public RLSolverBase<EnvType>
56{
57public:
58
59 typedef EnvType env_type;
60 typedef PolicyType policy_type;
61
62 typedef typename env_type::state_type state_type;
63 typedef typename env_type::action_type action_type;
64 typedef typename ReinforceMonitor<action_type,
65 state_type>::experience_buffer_type experience_buffer_type;
66
70 ReinforceSolver(ReinforceConfig opts,
71 policy_type& policy,
72 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer);
73
78 virtual void actions_before_training_begins(env_type&);
79
84 virtual void actions_after_training_ends(env_type&){}
85
89 virtual void actions_before_episode_begins(env_type&, uint_t /*episode_idx*/){}
90
94 virtual void actions_after_episode_ends(env_type&, uint_t /*episode_idx*/,
95 const EpisodeInfo& /*einfo*/){}
96
100 virtual EpisodeInfo on_training_episode(env_type&, uint_t /*episode_idx*/);
101
105 ReinforceMonitor<action_type, state_type>& get_monitor(){return monitor_;}
106
107private:
108
112 ReinforceConfig config_;
113
117 policy_type policy_ptr_;
118
122 std::unique_ptr<torch::optim::Optimizer> policy_optimizer_;
123
127 ReinforceMonitor<action_type, state_type> monitor_;
128
134 uint_t create_episode_batch_(env_type& env, experience_buffer_type& buffer);
135
140 std::tuple<real_t, real_t> train_batch_(experience_buffer_type& buffer);
141
146 std::tuple<real_t, real_t> train_sequential_(experience_buffer_type& buffer);
147
151 std::tuple<real_t, real_t> train_without_baseline_(experience_buffer_type& buffer);
152
156 std::tuple<real_t, real_t> train_with_baseline_(experience_buffer_type& buffer);
157
158};
159
160template<typename EnvType, typename PolicyType>
161ReinforceSolver<EnvType, PolicyType>::ReinforceSolver(ReinforceConfig config,
162 policy_type& policy,
163 std::unique_ptr<torch::optim::Optimizer>& policy_optimizer)
164 :
165 RLSolverBase<EnvType>(),
166 config_(config),
167 policy_ptr_(policy),
168 policy_optimizer_(std::move(policy_optimizer)),
169 monitor_()
170
171{}
172
173template<typename EnvType, typename PolicyType>
174void
175ReinforceSolver<EnvType, PolicyType>::actions_before_training_begins(env_type& /*env*/){
176
177 monitor_.policy_loss_values.reserve(config_.n_episodes);
178 monitor_.rewards.reserve(config_.n_episodes);
179 monitor_.episode_duration.reserve(config_.n_episodes);
180
181 // set the policy to train mode
182 policy_ptr_ -> train();
183
184}
185
186template<typename EnvType, typename PolicyType>
187uint_t
188ReinforceSolver<EnvType,
190 >::create_episode_batch_(env_type& env, experience_buffer_type& buffer){
191
194
195 typedef typename ReinforceMonitor<action_type,
196 state_type>::experience_tuple_type experience_tuple_type;
197
198 // for every episode reset the environment
199 auto old_timestep = env.reset();
200
201 // iterate over the given number
202 // of iterations for the episode and create
203 // the trajectory. The trajectory may be less
204 // than config_.max_itrs_per_episode
205
206 uint_t itr = 0;
207 for(; itr < config_.max_itrs_per_episode; ++itr){
208
209 // from the policy get the action to do based
210 // on the seen state.
211 auto [action, log_prob] = policy_ptr_ -> act(old_timestep.observation());
212
213 // execute the selected action on the environment
214 auto new_timestep = env.step(action);
215 auto reward = new_timestep.reward();
216
217 experience_tuple_type exp = {old_timestep.observation(),
218 action,
219 reward,
220 new_timestep.done(),
221 log_prob};
222
223 // put the observation into the buffer
224 buffer.append(exp);
225
226 if (new_timestep.done()){
227 break;
228 }
229
230 old_timestep = new_timestep;
231 }
232
233 // because we start from zero
234 return itr + 1;
235}
236
237template<typename EnvType, typename PolicyType>
238EpisodeInfo
239ReinforceSolver<EnvType, PolicyType>::on_training_episode(env_type& env,
240 uint_t episode_idx){
241
242 // start the time for the episode
243 auto start = std::chrono::steady_clock::now();
244
245 // the buffer to use
246 experience_buffer_type buffer(config_.max_itrs_per_episode);
247
248 // Accummulate the data i.e. create the
249 // batch data we need to train the parameters
250 auto itrs = create_episode_batch_(env, buffer);
251
252 EpisodeInfo info;
253 if(config_.baseline_type == cuberl::rl::algos::pg::BaselineEnumType::NONE){
254
255 auto [episode_reward, total_episode_loss] = train_without_baseline_(buffer);
256 info.episode_reward = episode_reward;
257 }
258 else{
259
260 auto [episode_reward, total_episode_loss] = train_with_baseline_(buffer);
261 info.episode_reward = episode_reward;
262 }
263 monitor_.episode_duration.push_back(itrs);
264
265 auto end = std::chrono::steady_clock::now();
266 std::chrono::duration<real_t> elapsed_seconds = end - start;
267
268 // the info class to return for the episode
269 info.episode_index = episode_idx;
270 info.episode_iterations = itrs;
271 info.total_time = elapsed_seconds;
272 return info;
273
274}
275
276template<typename EnvType, typename PolicyType>
277std::tuple<real_t, real_t>
278ReinforceSolver<EnvType, PolicyType>::train_batch_(experience_buffer_type& buffer){
279
280 typedef typename ReinforceMonitor<action_type,
281 state_type>::experience_tuple_type experience_tuple_type;
282
283 typedef std::vector<experience_tuple_type> batch_type;
284
285 // the batch for this episode
286 auto batch = buffer.template get<batch_type>();
287
288 // create the batches
289 auto reward_batch = monitor_.template get<real_t, 2>(batch);
290 auto log_probs_batch = monitor_.template get<torch_tensor_t, 4>(batch);
291
292 // compute the discounted rewards for this batch
293 auto discounted_returns = cuberl::rl::algos::calculate_step_discounted_return(reward_batch,
294 config_.gamma);
295
296 if(config_.normalize_rewards){
297 discounted_returns = cuberl::maths::normalize_max(discounted_returns);
298 }
299
300 std::vector<torch_tensor_t> loss_vals = compute_loss_item(discounted_returns,
301 log_probs_batch);
302
303
304 auto loss = cuberl::utils::pytorch::TorchAdaptor::stack(loss_vals,
305 config_.device_type,
306 true).sum();
307 policy_optimizer_ -> zero_grad();
308 loss.backward();
309 policy_optimizer_ -> step();
310
311 auto total_episode_loss = loss.item().to<real_t>();
312
313 // compute the undiscounted reward as the reward
314 // for this episode
315 auto R = cuberl::maths::sum(reward_batch);
316 return std::make_tuple(R, total_episode_loss);
317
318
319}
320
321template<typename EnvType, typename PolicyType>
322std::tuple<real_t, real_t>
323ReinforceSolver<EnvType, PolicyType>::train_sequential_(experience_buffer_type& buffer){
324
325
326 typedef typename ReinforceMonitor<action_type,
327 state_type>::experience_tuple_type experience_tuple_type;
328
329 typedef std::vector<experience_tuple_type> batch_type;
330
331 // the batch for this episode
332 auto batch = buffer.template get<batch_type>();
333
334 // create the batches
335 auto reward_batch = monitor_.template get<real_t, 2>(batch);
336 auto log_probs_batch = monitor_.template get<torch_tensor_t, 4>(batch);
337
338
339 // compute the discounted rewards for this batch
340 auto discounted_returns = cuberl::rl::algos::calculate_step_discounted_return(reward_batch,
341 config_.gamma);
342
343 if(config_.normalize_rewards){
344 discounted_returns = cuberl::maths::normalize_max(discounted_returns);
345 }
346
347 std::vector<torch_tensor_t> loss_vals = compute_loss_item(discounted_returns,
348 log_probs_batch);
349
350 //auto device = config_.device_type != DeviceType::CPU ? torch::kCUDA : torch::kCPU;
351
352 auto total_episode_loss = 0.0;
353 for(uint_t l=0; l<loss_vals.size(); ++l){
354
355 auto loss = loss_vals[l];
356 policy_optimizer_ -> zero_grad();
357 loss.backward();
358 policy_optimizer_ -> step();
359
360 total_episode_loss += loss.item().to<real_t>();
361 }
362
363 auto R = cuberl::maths::sum(reward_batch);
364 return std::make_tuple(R, total_episode_loss / loss_vals.size());
365}
366
367template<typename EnvType, typename PolicyType>
368std::tuple<real_t, real_t>
369ReinforceSolver<EnvType, PolicyType>::train_without_baseline_(experience_buffer_type& buffer){
370
371 if(config_.train_type == cuberl::utils::TrainEnumType::BATCH){
372
373 auto [episode_reward, total_episode_loss] = train_batch_(buffer);
374 monitor_.policy_loss_values.push_back(total_episode_loss);
375 monitor_.rewards.push_back(episode_reward);
376 return std::make_tuple(episode_reward, total_episode_loss);
377 }
378 else{
379
380 auto [episode_reward, total_episode_loss] = train_sequential_(buffer);
381 monitor_.policy_loss_values.push_back(total_episode_loss);
382 monitor_.rewards.push_back(episode_reward);
383 return std::make_tuple(episode_reward, total_episode_loss);
384
385 }
386}
387
388template<typename EnvType, typename PolicyType>
389std::tuple<real_t, real_t>
390ReinforceSolver<EnvType, PolicyType>::train_with_baseline_(experience_buffer_type& buffer){
391
392
393 typedef typename ReinforceMonitor<action_type,
394 state_type>::experience_tuple_type experience_tuple_type;
395 typedef std::vector<experience_tuple_type> batch_type;
396
397 // the batch for this episode
398 auto batch = buffer.template get<batch_type>();
399 auto reward_batch = monitor_.template get<real_t, 2>(batch);
400
401 // compute the discounted rewards for this batch
402 auto discounted_returns = cuberl::rl::algos::calculate_step_discounted_return(reward_batch,
403 config_.gamma);
404 if(config_.baseline_type == BaselineEnumType::CONSTANT){
405 discounted_returns = compute_baseline_with_constant(discounted_returns,
406 config_.baseline_constant);
407 }
408 else if(config_.baseline_type == BaselineEnumType::MEAN){
409 discounted_returns = compute_baseline_with_mean(discounted_returns);
410 }
411 else{
412 discounted_returns = compute_baseline_with_standardization(discounted_returns,
413 config_.eps);
414 }
415
416 auto log_probs_batch = monitor_.template get<torch_tensor_t, 4>(batch);
417 std::vector<torch_tensor_t> loss_vals = compute_loss_item(discounted_returns,
418 log_probs_batch);
419
420
421 auto loss = cuberl::utils::pytorch::TorchAdaptor::stack(loss_vals,
422 config_.device_type,
423 true).sum();
424 policy_optimizer_ -> zero_grad();
425 loss.backward();
426 policy_optimizer_ -> step();
427
428 auto total_episode_loss = loss.item().to<real_t>();
429
430 // compute the undiscounted reward as the reward
431 // for this episode
432 auto R = cuberl::maths::sum(reward_batch);
433
434 monitor_.policy_loss_values.push_back(total_episode_loss);
435 monitor_.rewards.push_back(R);
436
437 return std::make_tuple(R, total_episode_loss);
438}
439
440}
441}
442}
443}
444#endif
#endif // REINFORCE_H
double real_t
real_t
Definition bitrl_types.h:23
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
PolicyType
Definition policy_type.h:8
std::iterator_traits< IteratorType >::value_type sum(IteratorType begin, IteratorType end, bool parallel=true)
Definition vector_math.h:98
std::vector< T > normalize_max(const std::vector< T > &vec)
Definition vector_math.h:564
std::vector< T > calculate_step_discounted_return(const std::vector< T > &rewards, T gamma)
Given an array of rewards, for each entry calculate the following: $$G = \sum_{k=t+1}^T \gamma^{k-t-1...
Definition utils.h:161
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
std::pair< uint_t, uint_t > state_type
Definition example_15.cpp:28
int R
Definition extended_kalman_filter.py:54
dict action
Definition play.py:41
reward
Definition play.py:44
info
Definition play.py:44
env
Definition play.py:30
dict policy
Definition play.py:26
bitrl::envs::gymnasium::CliffWorld env_type
Definition rl_example_10.cpp:32