1#ifndef ITERATIVE_POLICY_EVALUATION_H
2#define ITERATIVE_POLICY_EVALUATION_H
15namespace rl::algos::dp
29 template<
typename EnvType,
typename PolicyType>
80 void save(
const std::string& filename)
const;
116 template<
typename EnvType,
typename PolicyType>
126 template<
typename EnvType,
typename PolicyType>
130 v_.resize(env.n_states());
131 std::for_each(v_.begin(), v_.end(),
132 [](
auto& item){item = 0.0;});
135 template<
typename EnvType,
typename PolicyType>
140 save(config_.save_path);
144 template<
typename EnvType,
typename PolicyType>
148 auto start = std::chrono::steady_clock::now();
149 auto episode_rewards = 0.0;
161 auto state_actions_probs = policy_(s);
163 for(
const auto& action_prob : state_actions_probs){
165 auto aidx = action_prob.first;
166 auto action_p = action_prob.second;
169 auto transition_dyn = env.p(s, aidx);
171 for(
auto& dyn: transition_dyn){
172 auto prob = std::get<0>(dyn);
173 auto next_state = std::get<1>(dyn);
174 auto reward = std::get<2>(dyn);
175 new_v += action_p * prob * (reward + config_.gamma * v_[next_state]);
176 episode_rewards += reward;
180 delta = std::max(delta, std::fabs(old_v - new_v));
185 auto end = std::chrono::steady_clock::now();
186 std::chrono::duration<real_t> elapsed_seconds = end-start;
189 info.episode_index = episode_idx;
190 info.episode_reward = episode_rewards;
192 info.total_time = elapsed_seconds;
194 if( delta < config_.tolerance){
195 info.stop_training =
true;
201 template<
typename EnvType,
typename PolicyType>
209 for(
uint_t s=0; s < static_cast<uint_t>(v_.size()); ++s){
210 auto row = std::make_tuple(s, v_[s]);
The IterationCounter class.
Definition iteration_counter.h:15
uint_t current_iteration_index() const noexcept
current_iteration_index
Definition iteration_counter.h:33
bool continue_iterations() noexcept
continue_iterations
Definition iteration_counter.h:58
The CSVWriter class. Handles writing into CSV file format.
Definition csv_file_writer.h:22
void write_column_names(const std::vector< std::string > &col_names, bool write_header=true)
Write the column names.
Definition csv_file_writer.cpp:16
void write_row(const std::vector< T > &vals)
Write a row of the file.
Definition csv_file_writer.h:89
virtual void open() override
Open the file for writing.
Definition file_writer_base.cpp:21
The DPSolverBase class.
Definition dp_algo_base.h:21
RLSolverBase< EnvType >::env_type env_type
The environment type the solver is using.
Definition dp_algo_base.h:27
The IterativePolicyEval class.
Definition iterative_policy_evaluation.h:31
virtual void actions_before_episode_begins(env_type &, uint_t) override
actions_before_training_episode
Definition iterative_policy_evaluation.h:65
IterativePolicyEvalConfig config_
Definition iterative_policy_evaluation.h:103
virtual void actions_before_training_begins(env_type &env) override
actions_before_training_begins. Execute any actions the algorithm needs before starting the iteration...
Definition iterative_policy_evaluation.h:128
void save(const std::string &filename) const
Definition iterative_policy_evaluation.h:203
PolicyType policy_type
policy_type
Definition iterative_policy_evaluation.h:42
virtual void actions_after_training_ends(env_type &env) override
actions_after_training_ends. Actions to execute after the training iterations have finished
Definition iterative_policy_evaluation.h:137
policy_type & policy_
policy_
Definition iterative_policy_evaluation.h:112
virtual void actions_after_episode_ends(env_type &, uint_t, const EpisodeInfo &) override
actions_after_training_episode
Definition iterative_policy_evaluation.h:70
DynVec< real_t > v_
v_
Definition iterative_policy_evaluation.h:107
DynVec< real_t > get_value_function() const
value_function
Definition iterative_policy_evaluation.h:86
DPSolverBase< EnvType >::env_type env_type
env_type
Definition iterative_policy_evaluation.h:37
void update_policy(const policy_type &other)
update_policy
Definition iterative_policy_evaluation.h:98
IterativePolicyEvalutationSolver(IterativePolicyEvalConfig config, policy_type &policy)
IterativePolicyEval.
Definition iterative_policy_evaluation.h:117
virtual EpisodeInfo on_training_episode(env_type &env, uint_t episode_idx) override
on_episode. Do one episode of the algorithm
Definition iterative_policy_evaluation.h:146
policy_type get_policy() const
get_policy
Definition iterative_policy_evaluation.h:92
const std::string INVALID_STR
Invalid string.
Definition bitrl_consts.h:26
double real_t
real_t
Definition bitrl_types.h:23
Eigen::RowVectorX< T > DynVec
Dynamically sized row vector.
Definition bitrl_types.h:74
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
The EpisodeInfo struct.
Definition episode_info.h:19
Definition iterative_policy_evaluation.h:20
std::string save_path
Definition iterative_policy_evaluation.h:23
real_t gamma
Definition iterative_policy_evaluation.h:21
real_t tolerance
Definition iterative_policy_evaluation.h:22