bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
iterative_policy_evaluation.h
Go to the documentation of this file.
1#ifndef ITERATIVE_POLICY_EVALUATION_H
2#define ITERATIVE_POLICY_EVALUATION_H
3
6
10
11#include <chrono>
12#include <cmath>
13
14namespace cuberl{
15namespace rl::algos::dp
16{
17
18
25
29 template<typename EnvType, typename PolicyType>
31 {
32 public:
33
38
42 typedef PolicyType policy_type;
43
48 policy_type& policy);
49
54 virtual void actions_before_training_begins(env_type& env)override;
55
60 virtual void actions_after_training_ends(env_type& env)override;
61
65 virtual void actions_before_episode_begins(env_type&, uint_t /*episode_idx*/)override{}
66
70 virtual void actions_after_episode_ends(env_type&, uint_t /*episode_idx*/, const EpisodeInfo& /*einfo*/)override{}
71
75 virtual EpisodeInfo on_training_episode(env_type& env, uint_t episode_idx) override;
76
80 void save(const std::string& filename)const;
81
87
93
98 void update_policy(const policy_type& other){policy_.update(other);}
99
100 protected:
101
102
108
113
114 };
115
116 template<typename EnvType, typename PolicyType>
118 policy_type& policy)
119 :
120 DPSolverBase<EnvType>(),
121 config_(config),
122 v_(),
123 policy_(policy)
124 {}
125
126 template<typename EnvType, typename PolicyType>
127 void
129
130 v_.resize(env.n_states());
131 std::for_each(v_.begin(), v_.end(),
132 [](auto& item){item = 0.0;});
133 }
134
135 template<typename EnvType, typename PolicyType>
136 void
138
139 if(config_.save_path != bitrl::consts::INVALID_STR){
140 save(config_.save_path);
141 }
142 }
143
144 template<typename EnvType, typename PolicyType>
147
148 auto start = std::chrono::steady_clock::now();
149 auto episode_rewards = 0.0;
150 auto delta = 0.0;
151
152
153 bitrl::utils::IterationCounter itr_counter(env.n_states());
154 uint_t s = 0;
155 while(itr_counter.continue_iterations()){
156 // every time we query itr_counter we increase the
157 // counter so we miss the zero state
158 auto old_v = v_[s];
159 auto new_v = 0.0;
160
161 auto state_actions_probs = policy_(s);
162
163 for(const auto& action_prob : state_actions_probs){
164
165 auto aidx = action_prob.first;
166 auto action_p = action_prob.second;
167
168 // get transition dynamics from the environment
169 auto transition_dyn = env.p(s, aidx);
170
171 for(auto& dyn: transition_dyn){
172 auto prob = std::get<0>(dyn);
173 auto next_state = std::get<1>(dyn);
174 auto reward = std::get<2>(dyn);
175 new_v += action_p * prob * (reward + config_.gamma * v_[next_state]);
176 episode_rewards += reward;
177 }
178 }
179
180 delta = std::max(delta, std::fabs(old_v - new_v));
181 v_[s] = new_v;
182 s += 1;
183 }
184
185 auto end = std::chrono::steady_clock::now();
186 std::chrono::duration<real_t> elapsed_seconds = end-start;
187
188 EpisodeInfo info;
189 info.episode_index = episode_idx;
190 info.episode_reward = episode_rewards;
191 info.episode_iterations = itr_counter.current_iteration_index();
192 info.total_time = elapsed_seconds;
193
194 if( delta < config_.tolerance){
195 info.stop_training = true;
196 }
197
198 return info;
199 }
200
201 template<typename EnvType, typename PolicyType>
202 void
204
205 bitrl::utils::io::CSVWriter file_writer(filename, ',');
206 file_writer.open();
207 file_writer.write_column_names({"state_index", "value_function"});
208
209 for(uint_t s=0; s < static_cast<uint_t>(v_.size()); ++s){
210 auto row = std::make_tuple(s, v_[s]);
211 file_writer.write_row(row);
212 }
213 }
214
215
216}
217}
218
219#endif // ITERATIVE_POLICY_EVALUATION_H
The IterationCounter class.
Definition iteration_counter.h:15
uint_t current_iteration_index() const noexcept
current_iteration_index
Definition iteration_counter.h:33
bool continue_iterations() noexcept
continue_iterations
Definition iteration_counter.h:58
The CSVWriter class. Handles writing into CSV file format.
Definition csv_file_writer.h:22
void write_column_names(const std::vector< std::string > &col_names, bool write_header=true)
Write the column names.
Definition csv_file_writer.cpp:16
void write_row(const std::vector< T > &vals)
Write a row of the file.
Definition csv_file_writer.h:89
virtual void open() override
Open the file for writing.
Definition file_writer_base.cpp:21
The DPSolverBase class.
Definition dp_algo_base.h:21
RLSolverBase< EnvType >::env_type env_type
The environment type the solver is using.
Definition dp_algo_base.h:27
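The IterativePolicyEvalutationSolver class. A dynamic-programming solver that estimates the state-value function of a given policy by iterative policy evaluation.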
Definition iterative_policy_evaluation.h:31
virtual void actions_before_episode_begins(env_type &, uint_t) override
actions_before_episode_begins. No-op for this solver.
Definition iterative_policy_evaluation.h:65
IterativePolicyEvalConfig config_
Definition iterative_policy_evaluation.h:103
virtual void actions_before_training_begins(env_type &env) override
actions_before_training_begins. Execute any actions the algorithm needs before starting the iterations; here it resizes the value function to env.n_states() entries and sets every entry to zero.
Definition iterative_policy_evaluation.h:128
void save(const std::string &filename) const
Definition iterative_policy_evaluation.h:203
PolicyType policy_type
policy_type
Definition iterative_policy_evaluation.h:42
virtual void actions_after_training_ends(env_type &env) override
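actions_after_training_ends. Actions to execute after the training iterations have finished; here the value function is saved to config_.save_path unless that path equals bitrl::consts::INVALID_STR.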
Definition iterative_policy_evaluation.h:137
policy_type & policy_
policy_
Definition iterative_policy_evaluation.h:112
virtual void actions_after_episode_ends(env_type &, uint_t, const EpisodeInfo &) override
actions_after_episode_ends. No-op for this solver.
Definition iterative_policy_evaluation.h:70
DynVec< real_t > v_
v_
Definition iterative_policy_evaluation.h:107
DynVec< real_t > get_value_function() const
value_function
Definition iterative_policy_evaluation.h:86
DPSolverBase< EnvType >::env_type env_type
env_type
Definition iterative_policy_evaluation.h:37
void update_policy(const policy_type &other)
update_policy. Updates the referenced policy from other by calling policy_.update(other).
Definition iterative_policy_evaluation.h:98
IterativePolicyEvalutationSolver(IterativePolicyEvalConfig config, policy_type &policy)
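Constructor. Stores the configuration and keeps a reference (not a copy) to the policy to be evaluated, so the caller must keep that policy alive for the solver's lifetime.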
Definition iterative_policy_evaluation.h:117
virtual EpisodeInfo on_training_episode(env_type &env, uint_t episode_idx) override
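on_training_episode. Performs one sweep of iterative policy evaluation over the states (bounded by env.n_states()), updating the value function in place; sets stop_training when the largest absolute value change in the sweep is below config_.tolerance.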
Definition iterative_policy_evaluation.h:146
policy_type get_policy() const
get_policy
Definition iterative_policy_evaluation.h:92
const std::string INVALID_STR
Invalid string.
Definition bitrl_consts.h:26
double real_t
real_t
Definition bitrl_types.h:23
Eigen::RowVectorX< T > DynVec
Dynamically sized row vector.
Definition bitrl_types.h:74
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
The EpisodeInfo struct.
Definition episode_info.h:19
Definition iterative_policy_evaluation.h:20
std::string save_path
Definition iterative_policy_evaluation.h:23
real_t gamma
Definition iterative_policy_evaluation.h:21
real_t tolerance
Definition iterative_policy_evaluation.h:22