bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
expected_sarsa.h
Go to the documentation of this file.
#ifndef EXPECTED_SARSA_H
#define EXPECTED_SARSA_H

#include "cubeai/base/cubeai_config.h"
#include "cubeai/base/cubeai_types.h"
#include "cubeai/base/cubeai_consts.h"
#include "cubeai/rl/algorithms/td/td_algo_base.h"
#include "cubeai/rl/worlds/envs_concepts.h"

#include <iostream>

#ifdef CUBERL_DEBUG
#include <cassert>
#endif
16namespace cuberl {
17namespace rl::algos::td
18{
19
20
25 template<envs::discrete_world_concept EnvTp, typename ActionSelector>
26 class ExpectedSARSA: public TDAlgoBase<EnvTp>
27 {
28 public:
29
34
39
44
48 typedef ActionSelector action_selector_type;
49
53 ExpectedSARSA(uint_t n_episodes, real_t tolerance,
54 real_t gamma, real_t eta, uint_t plot_f,
55 env_type& env, uint_t max_num_iterations_per_episode,
56 const ActionSelector& selector);
57
61 ExpectedSARSA(TDAlgoConfig config,
62 env_type& env, const ActionSelector& selector);
63
68 virtual void on_episode()override final;
69
70 private:
71
75 action_selector_type action_selector_;
76
80 uint_t current_score_counter_;
81
86 void update_q_table_(const action_type& action, const state_type& cstate,
87 const state_type& next_state, const action_type& next_action, real_t reward);
88
89 };
90
91 template <envs::discrete_world_concept EnvTp, typename ActionSelector>
92 ExpectedSARSA<EnvTp, ActionSelector>::ExpectedSARSA(uint_t n_episodes, real_t tolerance, real_t gamma,
93 real_t eta, uint_t plot_f,
94 env_type& env, uint_t max_num_iterations_per_episode, const ActionSelector& selector)
95 :
96 TDAlgoBase<EnvTp>(n_episodes, tolerance, gamma, eta, plot_f, max_num_iterations_per_episode, env),
97 action_selector_(selector),
98 current_score_counter_(0)
99 {}
100
101 template <envs::discrete_world_concept EnvTp, typename ActionSelector>
102 ExpectedSARSA<EnvTp, ActionSelector>::ExpectedSARSA(TDAlgoConfig config, env_type& env, const ActionSelector& selector)
103 :
104 TDAlgoBase<EnvTp>(config, env),
105 action_selector_(selector),
106 current_score_counter_(0)
107 {}
108
109 template <envs::discrete_world_concept EnvTp, typename ActionSelector>
110 void
112
113 // total score for the episode
114 auto score = 0.0;
115 auto state = this->env_ref_().reset().observation();
116
117 // select an action
118 auto action = action_selector_(this->q_table(), state);
119
120 uint_t itr=0;
121 for(; itr < this->n_iterations_per_episode(); ++itr){
122
123 // select an action
124 auto action = action_selector_(this->q_table(), state);
125 if(this->is_verbose()){
126 std::cout<<"Episode iteration="<<itr<<" of="<<this->n_iterations_per_episode()<<std::endl;
127 std::cout<<"State="<<state<<std::endl;
128 std::cout<<"Action="<<action<<std::endl;
129 }
130
131 // Take a on_episode
132 auto step_type_result = this->env_ref_().step(action);
133
134 auto next_state = step_type_result.observation();
135 auto reward = step_type_result.reward();
136 auto done = step_type_result.done();
137
138 // accumulate score
139 score += reward;
140
141 if(!done){
142 auto next_action = action_selector_(this->q_table(), state);
143 update_q_table_(action, state, next_state, next_action, reward);
144 state = next_state;
145 action = next_action;
146 }
147 else{
148
149 update_q_table_(action, state, CubeAIConsts::invalid_size_type(),
150 CubeAIConsts::invalid_size_type(), reward);
151
152 this->tmp_scores()[current_score_counter_++] = score;
153
154 if(current_score_counter_ >= this->render_env_frequency_){
155 current_score_counter_ = 0;
156 }
157
158 if(this->is_verbose()){
159 std::cout<<"============================================="<<std::endl;
160 std::cout<<"Break out from episode="<<this->current_episode_idx()<<std::endl;
161 std::cout<<"============================================="<<std::endl;
162 }
163
164 break;
165 }
166 }
167
168 // make any adjustments to the way
169 // actions are selected given the experience collected
170 // in the episode
171 action_selector_.adjust_on_episode(this->current_episode_idx());
172 if(current_score_counter_ >= this->render_env_frequency_){
173 current_score_counter_ = 0;
174 }
175
176 std::cout<<"Finished on_episode="<<this->current_episode_idx()<<std::endl;
177
178 }
179
180 template<envs::discrete_world_concept EnvTp, typename ActionSelector>
181 void
183 const state_type& cstate,
184 const state_type& next_state,
185 const action_type& next_action, real_t reward){
186
187#ifdef CUBERL_DEBUG
188 assert(action < this->env_ref_().n_actions() && "Inavlid action idx");
189 assert(cstate < this->env_ref_().n_states() && "Inavlid state idx");
190
191 if(next_state != CubeAIConsts::invalid_size_type())
192 assert(next_state < this->env_ref_().n_states() && "Inavlid next_state idx");
193
194 if(next_action != CubeAIConsts::invalid_size_type())
195 assert(next_action < this->env_ref_().n_actions() && "Inavlid next_action idx");
196#endif
197
198 const auto eps = action_selector_.eps_value();
199 auto q_current = this->q_table()[cstate][action];
200 auto policy_s = DynVec<real_t>(this->env_ref_().n_actions(), 1.0);
201 policy_s *= eps / this->env_ref_().n_actions();
202
203 auto state_action_values = this->q_table()[next_state];
204 auto argmax = blaze::argmax(state_action_values);
205 policy_s[argmax] = 1 - eps + (eps / this->env_ref_().n_actions());
206
207 auto q_next = state_action_values * policy_s;
208 auto td_target = reward + this->gamma() * q_next;
209 //this->q_table()[cstate][action] = q_current + (this->eta() * (td_target - q_current));
210 }
211
212}
213}
214
215#endif // EXPECTED_SARSA_H
The ExpectedSARSA class. Simple implementation of the expected SARSA algorithm.
Definition expected_sarsa.h:27
TDAlgoBase< EnvTp >::action_type action_type
action_t
Definition expected_sarsa.h:38
TDAlgoBase< EnvTp >::env_type env_type
env_t
Definition expected_sarsa.h:33
ActionSelector action_selector_type
action_selector_t
Definition expected_sarsa.h:48
ExpectedSARSA(uint_t n_episodes, real_t tolerance, real_t gamma, real_t eta, uint_t plot_f, env_type &env, uint_t max_num_iterations_per_episode, const ActionSelector &selector)
Constructor.
Definition expected_sarsa.h:92
virtual void on_episode() override final
on_episode. Performs the iterations for one training episode
Definition expected_sarsa.h:111
TDAlgoBase< EnvTp >::state_type state_type
state_t
Definition expected_sarsa.h:43
The TDAlgoBase class. Base class for deriving TD algorithms.
Definition td_algo_base.h:19
env_type::action_type action_type
action_t
Definition td_algo_base.h:30
env_type::state_type state_type
state_t
Definition td_algo_base.h:35
EnvType env_type
env_t
Definition td_algo_base.h:25
double real_t
real_t
Definition bitrl_types.h:23
Eigen::RowVectorX< T > DynVec
Dynamically sized row vector.
Definition bitrl_types.h:74
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Various utilities used when working with RL problems.
Definition cuberl_types.h:16