bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
sarsa.h
Go to the documentation of this file.
1#ifndef SARSA_H
2#define SARSA_H
3
9
11#include "bitrl/bitrl_consts.h"
12
13
14#ifdef CUBERL_DEBUG
15#include <cassert>
16#endif
17
18#include <chrono>
19#include <iostream>
20#include <string>
21
22
23namespace cuberl{
24namespace rl::algos::td
25{
26
27
40
44 template<envs::discrete_world_concept EnvType, typename PolicyType>
45 class SarsaSolver final: public TDAlgoBase<EnvType>
46 {
47 public:
48
53
58
63
67 typedef PolicyType policy_type;
68
72 SarsaSolver(SarsaConfig config, const PolicyType& selector);
73
79
85
89 virtual void actions_before_episode_begins(env_type&, uint_t /*episode_idx*/){}
90
94 virtual void actions_after_episode_ends(env_type&, uint_t /*episode_idx*/,
95 const EpisodeInfo& /*einfo*/){}
96
100 virtual EpisodeInfo on_training_episode(env_type&, uint_t episode_idx);
101
105 void save(const std::string& filename)const;
106
111
112 private:
113
117 SarsaConfig config_;
118
122 policy_type policy_;
123
127 DynMat<real_t> q_table_;
128
133 void update_q_table_(const action_type& action,
134 const state_type& cstate,
135 const state_type& next_state,
136 const action_type& next_action, real_t reward);
137 };
138
139
140
141 template<envs::discrete_world_concept EnvTp, typename PolicyType>
143 const PolicyType& selector)
144 :
145 TDAlgoBase<EnvTp>(),
146 config_(config),
147 policy_(selector)
148 {}
149
150 template<envs::discrete_world_concept EnvTp, typename PolicyType>
151 void
153 q_table_ = DynMat<real_t>(env.n_states(), env.n_actions());
154
155 for(uint_t i=0; i < env.n_states(); ++i)
156 for(uint_t j=0; j < env.n_actions(); ++j)
157 q_table_(i, j) = 0.0;
158
159 }
160
161 template<envs::discrete_world_concept EnvTp, typename PolicyType>
162 void
164
165 if(config_.path != bitrl::consts::INVALID_STR){
166 save(config_.path);
167 }
168 }
169
170 template<envs::discrete_world_concept EnvTp, typename PolicyType>
173 uint_t episode_idx){
174
175 auto start = std::chrono::steady_clock::now();
176 EpisodeInfo info;
177
178 // total score for the episode
179 auto episode_score = 0.0;
180 auto time_step = env.reset();
181 auto state = time_step.observation();
182
183 uint_t itr=0;
184 for(; itr < config_.max_num_iterations_per_episode; ++itr){
185
186 // select an action
187 auto action = policy_(q_table_, state);
188
189 // Take a on_episode
190 auto step_type_result = env.step(action);
191
192 auto next_state = step_type_result.observation();
193 auto reward = step_type_result.reward();
194 auto done = step_type_result.done();
195
196 // accumulate score
197 episode_score += reward;
198
199 if(!done){
200
201 // use the policy to select the next action
202 auto next_action = policy_(q_table_, state);
203 update_q_table_(action, state, next_state, next_action, reward);
204 state = next_state;
205 action = next_action;
206 }
207 else{
208
209 update_q_table_(action, state,
212 reward);
213
214 break;
215 }
216 }
217
218 auto end = std::chrono::steady_clock::now();
219 std::chrono::duration<real_t> elapsed_seconds = end-start;
220
221 info.episode_index = episode_idx;
222 info.episode_reward = episode_score;
223 info.episode_iterations = itr;
224 info.total_time = elapsed_seconds;
225 return info;
226 }
227
228 template<envs::discrete_world_concept EnvTp, typename PolicyType>
229 void
230 SarsaSolver<EnvTp, PolicyType>::save(const std::string& filename)const{
231
232 bitrl::utils::io::CSVWriter file_writer(filename, ',');
233 file_writer.open();
234
235 std::vector<std::string> col_names(1 + q_table_.cols());
236 col_names[0] = "state_index";
237
238 for(uint_t i = 0; i< static_cast<uint_t>(q_table_.cols()); ++i){
239 col_names[i + 1] = "action_" + std::to_string(i);
240 }
241
242 file_writer.write_column_names(col_names);
243 for(uint_t s=0; s < static_cast<uint_t>(q_table_.rows()); ++s){
244 auto actions = maths::get_row(q_table_, s);
245 auto row = std::make_tuple(s, actions);
246 file_writer.write_row(row);
247 }
248
249 }
250
251 template<envs::discrete_world_concept EnvTp, typename PolicyType>
252 void
253 SarsaSolver<EnvTp, PolicyType>::update_q_table_(const action_type& action,
254 const state_type& cstate,
255 const state_type& next_state,
256 const action_type& next_action, real_t reward){
257
258 auto q_current = q_table_(cstate, action);
259
260 // with the SARSA solver we query the
261 // q-function about its value at next state when taking next action
262 // in Q-learning we form a maximum instead
263 auto q_next = next_state != bitrl::consts::INVALID_ID ? q_table_(next_state, next_action) : 0.0;
264 auto td_target = reward + config_.gamma * q_next;
265 q_table_(cstate, action) = q_current + (config_.eta * (td_target - q_current));
266
267 }
268
269 template<envs::discrete_world_concept EnvTp, typename PolicyType>
279
280
281}
282}
283
284#endif // SARSA_H
The CSVWriter class. Handles writing into CSV file format.
Definition csv_file_writer.h:22
void write_column_names(const std::vector< std::string > &col_names, bool write_header=true)
Write the column names.
Definition csv_file_writer.cpp:16
void write_row(const std::vector< T > &vals)
Write a row of the file.
Definition csv_file_writer.h:89
virtual void open() override
Open the file for writing.
Definition file_writer_base.cpp:21
The Sarsa class.
Definition sarsa.h:46
virtual EpisodeInfo on_training_episode(env_type &, uint_t episode_idx)
Run a single training episode of the algorithm and return its statistics.
Definition sarsa.h:172
SarsaSolver(SarsaConfig config, const PolicyType &selector)
SarsaSolver constructor. Builds the solver from a configuration and an action-selection policy.
Definition sarsa.h:142
void save(const std::string &filename) const
Save the learned Q-table in CSV format.
Definition sarsa.h:230
virtual void actions_before_episode_begins(env_type &, uint_t)
actions_before_training_episode
Definition sarsa.h:89
PolicyType policy_type
action_selector_t
Definition sarsa.h:67
virtual void actions_after_training_ends(env_type &)
actions_after_training_ends. Actions to execute after the training iterations have finished.
Definition sarsa.h:163
cuberl::rl::policies::MaxTabularPolicy build_policy() const
Build the policy after training.
Definition sarsa.h:271
TDAlgoBase< EnvType >::state_type state_type
state_t
Definition sarsa.h:62
TDAlgoBase< EnvType >::env_type env_type
env_t
Definition sarsa.h:52
TDAlgoBase< EnvType >::action_type action_type
action_t
Definition sarsa.h:57
virtual void actions_before_training_begins(env_type &)
actions_before_training_begins. Execute any actions the algorithm needs before starting the iteration...
Definition sarsa.h:152
virtual void actions_after_episode_ends(env_type &, uint_t, const EpisodeInfo &)
actions_after_training_episode
Definition sarsa.h:94
The TDAlgoBase class. Base class for deriving TD algorithms.
Definition td_algo_base.h:19
env_type::action_type action_type
action_t
Definition td_algo_base.h:30
env_type::state_type state_type
state_t
Definition td_algo_base.h:35
EnvType env_type
env_t
Definition td_algo_base.h:25
class MaxTabularPolicy
Definition max_tabular_policy.h:30
const uint_t INVALID_ID
Invalid id.
Definition bitrl_consts.h:21
const std::string INVALID_STR
Invalid string.
Definition bitrl_consts.h:26
double real_t
real_t
Definition bitrl_types.h:23
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Eigen::MatrixX< T > DynMat
Dynamically sized matrix to use around the library.
Definition bitrl_types.h:49
DynVec< T > get_row(const DynMat< T > &matrix, uint_t row_idx)
Extract the row_idx-th row from the matrix.
Definition matrix_utilities.h:130
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
The EpisodeInfo struct.
Definition episode_info.h:19
The SarsaConfig struct.
Definition sarsa.h:32
uint_t n_episodes
Definition sarsa.h:33
real_t tolerance
Definition sarsa.h:34
real_t gamma
Definition sarsa.h:35
std::string path
Definition sarsa.h:38
real_t eta
Definition sarsa.h:36
uint_t max_num_iterations_per_episode
Definition sarsa.h:37
Definition max_tabular_policy.h:125
void build_from_state_action_function(const DynMat< real_t > &q, MaxTabularPolicy &policy)