1#ifndef POLICY_ITERATION_H
2#define POLICY_ITERATION_H
15namespace rl::algos::dp
33 template<
typename EnvType,
typename PolicyType>
87 void save(
const std::string& filename)
const;
111 template<
typename EnvType,
typename PolicyType>
120 policy_impr_(action_space_size, config.gamma,
DynVec<real_t>(), policy)
124 template<
typename EnvType,
typename PolicyType>
129 policy_impr_.actions_before_training_begins(env);
132 template<
typename EnvType,
typename PolicyType>
135 v_ = policy_eval_.get_value_function();
138 save(config_.save_path);
142 template<
typename EnvType,
typename PolicyType>
146 auto start = std::chrono::steady_clock::now();
149 auto episode_rewards = 0.0;
152 auto old_policy = policy_eval_.get_policy();
154 for(
uint_t itr=0; itr < config_.n_policy_eval_steps; ++itr ){
156 policy_eval_.on_training_episode(env, itr);
161 policy_impr_.set_value_function( policy_eval_.get_value_function());
164 auto policy_imp_info = policy_impr_.on_training_episode(env, episode_idx);
167 const auto& new_policy = policy_impr_.policy();
170 if(old_policy == new_policy){
171 info.stop_training =
true;
174 policy_eval_.update_policy(new_policy);
176 auto end = std::chrono::steady_clock::now();
177 std::chrono::duration<real_t> elapsed_seconds = end-start;
180 info.episode_index = episode_idx;
181 info.episode_reward = episode_rewards;
182 info.episode_iterations = config_.n_policy_eval_steps + policy_imp_info.episode_iterations;
183 info.total_time = elapsed_seconds;
187 template<
typename EnvType,
typename PolicyType>
196 auto vec_size =
static_cast<uint_t>(v_.size());
197 for(
uint_t s=0; s < vec_size; ++s){
198 auto row = std::make_tuple(s, v_[s]);
The CSVWriter class. Handles writing into CSV file format.
Definition csv_file_writer.h:22
void write_column_names(const std::vector< std::string > &col_names, bool write_header=true)
Write the column names.
Definition csv_file_writer.cpp:16
void write_row(const std::vector< T > &vals)
Write a row of the file.
Definition csv_file_writer.h:89
virtual void open() override
Open the file for writing.
Definition file_writer_base.cpp:21
virtual void actions_before_training_begins(env_type &)=0
actions_before_training_begins. Execute any actions the algorithm needs before starting the iteration...
The DPSolverBase class.
Definition dp_algo_base.h:21
RLSolverBase< EnvType >::env_type env_type
The environment type the solver is using.
Definition dp_algo_base.h:27
The IterativePolicyEval class.
Definition iterative_policy_evaluation.h:31
The PolicyImprovement class. PolicyImprovement is not a real algorithm in the sense that it looks for...
Definition policy_improvement.h:23
The policy iteration class.
Definition policy_iteration.h:35
virtual void actions_before_training_begins(env_type &env) override
actions_before_training_begins. Execute any actions the algorithm needs before starting the iteration...
Definition policy_iteration.h:126
void save(const std::string &filename) const
save
Definition policy_iteration.h:189
virtual void actions_after_training_ends(env_type &) override
actions_after_training_ends. Actions to execute after the training iterations have finished
Definition policy_iteration.h:134
virtual void actions_after_episode_ends(env_type &, uint_t, const EpisodeInfo &) override
actions_after_episode_ends. Actions to execute after a training episode ends
Definition policy_iteration.h:75
virtual void actions_before_episode_begins(env_type &, uint_t) override
actions_before_episode_begins. Actions to execute before a training episode begins
Definition policy_iteration.h:70
PolicyIterationSolver(PolicyIterationConfig config, uint_t action_space_size, policy_type &policy)
PolicyIteration.
Definition policy_iteration.h:112
virtual EpisodeInfo on_training_episode(env_type &env, uint_t episode_idx) override
on_training_episode. Do one training episode of the algorithm
Definition policy_iteration.h:144
PolicyType policy_type
policy_type
Definition policy_iteration.h:46
DPSolverBase< EnvType >::env_type env_type
env_t
Definition policy_iteration.h:41
const std::string INVALID_STR
Invalid string.
Definition bitrl_consts.h:26
double real_t
real_t
Definition bitrl_types.h:23
Eigen::RowVectorX< T > DynVec
Dynamically sized row vector.
Definition bitrl_types.h:74
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
The EpisodeInfo struct.
Definition episode_info.h:19
The PolicyIterationConfig struct.
Definition policy_iteration.h:23
real_t gamma
Definition policy_iteration.h:25
std::string save_path
Definition policy_iteration.h:27
real_t tolerance
Definition policy_iteration.h:26
uint_t n_policy_eval_steps
Definition policy_iteration.h:24