bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
value_iteration.h
Go to the documentation of this file.
1#ifndef VALUE_ITERATION_H
2#define VALUE_ITERATION_H
3
4#include "cuberl/base/cubeai_config.h" //KERNEL_PRINT_DBG_MSGS
6
13#include "bitrl/bitrl_consts.h"
14
15#include <memory>
16#include <cmath>
17#include <string>
18
19
20namespace cuberl{
21namespace rl::algos::dp
22{
23
33
// ValueIteration: dynamic-programming value-iteration solver derived from
// DPSolverBase<EnvType>. One "episode" is one full sweep over the
// environment's states (see on_training_episode below).
// NOTE(review): this is an extracted listing; several source lines are
// collapsed (42-45, 47-50, 86-89, 93-102). Per the symbol index they hold
// the env_type alias, the ValueIteration(ValueIterationConfig) constructor,
// the build_policy declaration and the private members (config_, v_) —
// confirm against the actual header.
37 template<typename EnvType>
38 class ValueIteration: public DPSolverBase<EnvType>
39 {
40 public:
41
46
51
// Hook executed once before the outer training loop starts.
56 virtual void actions_before_training_begins(env_type& env)override;
57
// Hook executed once after training ends; saves results if configured.
62 virtual void actions_after_training_ends(env_type& /*env*/)override;
63
// Per-episode hooks: intentionally no-ops for value iteration.
67 virtual void actions_before_episode_begins(env_type&, uint_t /*episode_idx*/)override{}
68
72 virtual void actions_after_episode_ends(env_type&, uint_t /*episode_idx*/,
73 const EpisodeInfo& /*einfo*/)override{}
74
// Perform one value-iteration sweep over all states of the environment.
78 virtual EpisodeInfo on_training_episode(env_type& env, uint_t episode_idx) override;
79
// Persist the computed value function to a CSV file.
83 void save(const std::string& filename)const;
84
85
90
91 private:
92
97
102
103 };
104
// Constructor: stores the supplied configuration and default-constructs
// the DPSolverBase part. No other work is done here.
// NOTE(review): source line 106 (the signature — per the symbol index:
// ValueIteration(const ValueIterationConfig config)) is collapsed in this
// extracted listing; confirm against the actual header.
105 template<typename EnvType>
107 :
108 DPSolverBase<EnvType>(),
109 config_(config)
110 {}
111
112
// actions_before_training_begins: definition fragment only.
// NOTE(review): source lines 115-117 (qualified name, signature and body)
// are collapsed in this extracted listing; the symbol index places the
// definition at value_iteration.h:115 — consult the actual header for
// what initialization it performs.
113 template<typename EnvType>
114 void
118
// on_training_episode: one full value-iteration sweep over all states,
// updating the value table v_ in place. Returns an EpisodeInfo with the
// episode index, the number of states visited, the wall-clock time of the
// sweep, and the convergence signal.
// NOTE(review): source lines 120-121 (return type and start of signature)
// are collapsed in this listing; the symbol index gives
// EpisodeInfo on_training_episode(env_type& env, uint_t episode_idx).
119 template<typename EnvType>
122 uint_t episode_idx){
123
124 // start timing the training
125 auto start = std::chrono::steady_clock::now();
126
127 EpisodeInfo info;
// delta accumulates the largest absolute change of the value function
// seen in this sweep; it drives the convergence test below.
128 auto delta = 0.0;
129 for(uint_t s=0; s< env.n_states(); ++s){
130
131 auto v = v_[s];
// Greedy backup: maximum over the per-action values computed from the
// current value function v_ and discount factor config_.gamma.
132 auto max_val = state_actions_from_v(env, v_, config_.gamma, s).maxCoeff();
133
134 v_[s] = max_val;
135 delta = std::max(delta, std::fabs(v_[s] - v));
136 }
137
138 // inform the outer loop that
139 // we converged
140 if(delta < config_.tolerance){
141 info.stop_training = true;
142 }
143
144 auto end = std::chrono::steady_clock::now();
145 std::chrono::duration<real_t> elapsed_seconds = end-start;
146
// One "iteration" is counted per state visited in the sweep.
147 info.episode_index = episode_idx;
148 info.episode_iterations = env.n_states();
149 info.total_time = elapsed_seconds;
150
151 // this is artificial but helps
152 // to monitor convergence
153 info.episode_reward = delta;
154
155 return info;
156 }
157
// actions_after_training_ends: if a save path was configured (i.e. it is
// not the INVALID_STR sentinel), persist the value table via save().
// NOTE(review): source line 160 (the qualified signature) is collapsed in
// this extracted listing; the symbol index places the definition at
// value_iteration.h:160 — confirm against the actual header.
158 template<typename EnvType>
159 void
161 if(config_.save_path != bitrl::consts::INVALID_STR){
162 save(config_.save_path);
163 }
164
165 }
166
// save: write the state-value table v_ to a comma-separated CSV file with
// header columns "state_index" and "value_function", one row per state.
167 template<typename EnvType>
168 void
169 ValueIteration<EnvType>::save(const std::string& filename)const{
170
// NOTE(review): the writer is never explicitly closed here — presumably
// CSVWriter releases the file on destruction (RAII); confirm in
// csv_file_writer.h.
171 bitrl::utils::io::CSVWriter file_writer(filename, ',');
172 file_writer.open();
173
174 file_writer.write_column_names({"state_index", "value_function"});
175
176 for(uint_t s=0; s < static_cast<uint_t>(v_.size()); ++s){
177 auto row = std::make_tuple(s, v_[s]);
178 file_writer.write_row(row);
179 }
180 }
181
// build_policy: derive the greedy policy from the converged value function
// v_ using the policy builder's build_from_state_function.
// NOTE(review): source lines 183-187 are collapsed in this extracted
// listing; per the symbol index they hold the signature
// (cuberl::rl::policies::MaxTabularPolicy build_policy(const env_type&) const)
// and the local 'policy'/'builder' declarations — confirm against the
// actual header.
182 template<typename EnvType>
185
188 builder.build_from_state_function(env, v_,
189 config_.gamma,policy);
190 return policy;
191
192 }
193
194}
195}
196
197#endif // VALUE_ITERATION_H
The CSVWriter class. Handles writing into CSV file format.
Definition csv_file_writer.h:22
void write_column_names(const std::vector< std::string > &col_names, bool write_header=true)
Write the column names.
Definition csv_file_writer.cpp:16
void write_row(const std::vector< T > &vals)
Write a row of the file.
Definition csv_file_writer.h:89
virtual void open() override
Open the file for writing.
Definition file_writer_base.cpp:21
The DPSolverBase class.
Definition dp_algo_base.h:21
RLSolverBase< EnvType >::env_type env_type
The environment type the solver is using.
Definition dp_algo_base.h:27
ValueIteration class.
Definition value_iteration.h:39
ValueIteration(const ValueIterationConfig config)
ValueIteration.
Definition value_iteration.h:106
void save(const std::string &filename) const
Definition value_iteration.h:169
virtual void actions_before_training_begins(env_type &env) override
actions_before_training_begins. Execute any actions the algorithm needs before starting the iteration...
Definition value_iteration.h:115
DPSolverBase< EnvType >::env_type env_type
env_type. Alias for the environment type the solver operates on.
Definition value_iteration.h:45
virtual void actions_after_episode_ends(env_type &, uint_t, const EpisodeInfo &) override
actions_after_training_episode
Definition value_iteration.h:72
virtual void actions_before_episode_begins(env_type &, uint_t) override
actions_before_training_episode
Definition value_iteration.h:67
virtual void actions_after_training_ends(env_type &) override
actions_after_training_ends. Actions to execute after the training iterations have finished.
Definition value_iteration.h:160
virtual EpisodeInfo on_training_episode(env_type &env, uint_t episode_idx) override
on_training_episode. Perform one episode (a full sweep over the states) of the algorithm.
Definition value_iteration.h:121
cuberl::rl::policies::MaxTabularPolicy build_policy(const env_type &env) const
Definition value_iteration.h:184
class MaxTabularPolicy
Definition max_tabular_policy.h:30
const real_t TOLERANCE
Tolerance used around the library.
Definition bitrl_consts.h:31
const std::string INVALID_STR
Invalid string.
Definition bitrl_consts.h:26
double real_t
real_t
Definition bitrl_types.h:23
Eigen::RowVectorX< T > DynVec
Dynamically sized row vector.
Definition bitrl_types.h:74
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
auto state_actions_from_v(const WorldTp &env, const DynVec< real_t > &v, real_t gamma, uint_t state) -> DynVec< real_t >
Given the state index returns the list of actions under the provided value functions.
Definition utils.h:23
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
The EpisodeInfo struct.
Definition episode_info.h:19
The ValueIterationConfig struct.
Definition value_iteration.h:28
std::string save_path
Definition value_iteration.h:31
real_t tolerance
Definition value_iteration.h:30
real_t gamma
Definition value_iteration.h:29
Definition max_tabular_policy.h:125
void build_from_state_function(const EnvType &env, const DynVec< real_t > &v, real_t gamma, MaxTabularPolicy &policy)
Definition max_tabular_policy.h:139