bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
policy_iteration.h
Go to the documentation of this file.
1#ifndef POLICY_ITERATION_H
2#define POLICY_ITERATION_H
3
8
10#include "bitrl/bitrl_consts.h"
11
12#include <string>
13
14namespace cuberl{
15namespace rl::algos::dp
16{
17
18
29
33 template<typename EnvType, typename PolicyType>
34 class PolicyIterationSolver final: public DPSolverBase<EnvType>
35 {
36 public:
37
42
46 typedef PolicyType policy_type;
47
52 uint_t action_space_size,
53 policy_type& policy);
54
59 virtual void actions_before_training_begins(env_type& env)override;
60
65 virtual void actions_after_training_ends(env_type& /*env*/)override;
66
70 virtual void actions_before_episode_begins(env_type&, uint_t /*episode_idx*/)override{}
71
75 virtual void actions_after_episode_ends(env_type&, uint_t /*episode_idx*/,
76 const EpisodeInfo& /*einfo*/)override{}
77
81 virtual EpisodeInfo on_training_episode(env_type& env, uint_t episode_idx) override;
82
87 void save(const std::string& filename)const;
88
89 private:
90
92
97
102
103
108
109 };
110
111 template<typename EnvType, typename PolicyType>
113 uint_t action_space_size,
114 policy_type& policy)
115 :
116 DPSolverBase<EnvType>(),
117 config_(config),
118 v_(),
119 policy_eval_({config.gamma, config.tolerance}, policy),
120 policy_impr_(action_space_size, config.gamma, DynVec<real_t>(), policy)
121 {}
122
123
124 template<typename EnvType, typename PolicyType>
125 void
127
128 policy_eval_.actions_before_training_begins(env);
129 policy_impr_.actions_before_training_begins(env);
130 }
131
132 template<typename EnvType, typename PolicyType>
133 void
135 v_ = policy_eval_.get_value_function();
136
137 if(config_.save_path != bitrl::consts::INVALID_STR){
138 save(config_.save_path);
139 }
140 }
141
142 template<typename EnvType, typename PolicyType>
145
146 auto start = std::chrono::steady_clock::now();
147 EpisodeInfo info;
148
149 auto episode_rewards = 0.0;
150
151 // make a copy of the policy already obtained
152 auto old_policy = policy_eval_.get_policy();
153
154 for(uint_t itr=0; itr < config_.n_policy_eval_steps; ++itr ){
155 // evaluate the policy
156 policy_eval_.on_training_episode(env, itr);
157 }
158
159 // hand the latest evaluated value function
160 // to the policy improvement step
161 policy_impr_.set_value_function( policy_eval_.get_value_function());
162
163 // improve the policy
164 auto policy_imp_info = policy_impr_.on_training_episode(env, episode_idx);
165
166 // get the improved policy
167 const auto& new_policy = policy_impr_.policy();
168
169 // policy converged
170 if(old_policy == new_policy){
171 info.stop_training = true;
172 }
173
174 policy_eval_.update_policy(new_policy);
175
176 auto end = std::chrono::steady_clock::now();
177 std::chrono::duration<real_t> elapsed_seconds = end-start;
178
179
180 info.episode_index = episode_idx;
181 info.episode_reward = episode_rewards;
182 info.episode_iterations = config_.n_policy_eval_steps + policy_imp_info.episode_iterations;
183 info.total_time = elapsed_seconds;
184 return info;
185 }
186
// Writes the solver's stored value function v_ to a CSV file.
//
// filename: path handed to the CSVWriter.
//
// Output layout: a column-name row ("state_index", "value_function")
// followed by one row per entry of v_, pairing the index s with v_[s].
// NOTE(review): v_ is assigned in actions_after_training_ends; if save is
// called before that, v_ is presumably still empty and only the column
// names are written -- confirm.
187 template<typename EnvType, typename PolicyType>
188 void
189 PolicyIterationSolver<EnvType, PolicyType>::save(const std::string& filename)const{
190
// ',' is presumably the column delimiter -- confirm against CSVWriter.
191 bitrl::utils::io::CSVWriter file_writer(filename, ',');
// open the file for writing
192 file_writer.open();
193
194 file_writer.write_column_names({"state_index", "value_function"});
195
// one row per state index: (index, value)
196 auto vec_size = static_cast<uint_t>(v_.size());
197 for(uint_t s=0; s < vec_size; ++s){
198 auto row = std::make_tuple(s, v_[s]);
199 file_writer.write_row(row);
200 }
// NOTE(review): there is no explicit close here; presumably the writer
// closes the file when it goes out of scope -- confirm.
201 }
202
203}
204}
205
206
207#endif // POLICY_ITERATION_H
The CSVWriter class. Handles writing into CSV file format.
Definition csv_file_writer.h:22
void write_column_names(const std::vector< std::string > &col_names, bool write_header=true)
Write the column names.
Definition csv_file_writer.cpp:16
void write_row(const std::vector< T > &vals)
Write a row of the file.
Definition csv_file_writer.h:89
virtual void open() override
Open the file for writing.
Definition file_writer_base.cpp:21
virtual void actions_before_training_begins(env_type &)=0
actions_before_training_begins. Execute any actions the algorithm needs before starting the iteration...
The DPSolverBase class.
Definition dp_algo_base.h:21
RLSolverBase< EnvType >::env_type env_type
The environment type the solver is using.
Definition dp_algo_base.h:27
The IterativePolicyEval class.
Definition iterative_policy_evaluation.h:31
The PolicyImprovement class. PolicyImprovement is not a real algorithm in the sense that it looks for...
Definition policy_improvement.h:23
The policy iteration class.
Definition policy_iteration.h:35
virtual void actions_before_training_begins(env_type &env) override
actions_before_training_begins. Execute any actions the algorithm needs before starting the iteration...
Definition policy_iteration.h:126
void save(const std::string &filename) const
save
Definition policy_iteration.h:189
virtual void actions_after_training_ends(env_type &) override
actions_after_training_ends. Actions to execute after the training iterations have finished
Definition policy_iteration.h:134
virtual void actions_after_episode_ends(env_type &, uint_t, const EpisodeInfo &) override
actions_after_episode_ends
Definition policy_iteration.h:75
virtual void actions_before_episode_begins(env_type &, uint_t) override
actions_before_episode_begins
Definition policy_iteration.h:70
PolicyIterationSolver(PolicyIterationConfig config, uint_t action_space_size, policy_type &policy)
PolicyIteration.
Definition policy_iteration.h:112
virtual EpisodeInfo on_training_episode(env_type &env, uint_t episode_idx) override
on_training_episode. Do one training episode of the algorithm
Definition policy_iteration.h:144
PolicyType policy_type
policy_type
Definition policy_iteration.h:46
DPSolverBase< EnvType >::env_type env_type
env_t
Definition policy_iteration.h:41
const std::string INVALID_STR
Invalid string.
Definition bitrl_consts.h:26
double real_t
real_t
Definition bitrl_types.h:23
Eigen::RowVectorX< T > DynVec
Dynamically sized row vector.
Definition bitrl_types.h:74
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
The EpisodeInfo struct.
Definition episode_info.h:19
The PolicyIterationConfig struct.
Definition policy_iteration.h:23
real_t gamma
Definition policy_iteration.h:25
std::string save_path
Definition policy_iteration.h:27
real_t tolerance
Definition policy_iteration.h:26
uint_t n_policy_eval_steps
Definition policy_iteration.h:24