1#ifndef EXPECTED_SARSA_H
2#define EXPECTED_SARSA_H
4#include "cubeai/base/cubeai_config.h"
5#include "cubeai/base/cubeai_types.h"
6#include "cubeai/rl/algorithms/td/td_algo_base.h"
7#include "cubeai/rl/worlds/envs_concepts.h"
17namespace rl::algos::td
25 template<envs::discrete_world_concept EnvTp,
typename ActionSelector>
56 const ActionSelector& selector);
62 env_type& env,
const ActionSelector& selector);
80 uint_t current_score_counter_;
91 template <envs::discrete_world_concept EnvTp, typename ActionSelector>
94 env_type& env,
uint_t max_num_iterations_per_episode, const ActionSelector& selector)
96 TDAlgoBase<EnvTp>(n_episodes, tolerance, gamma, eta, plot_f, max_num_iterations_per_episode, env),
97 action_selector_(selector),
98 current_score_counter_(0)
101 template <envs::discrete_world_concept EnvTp,
typename ActionSelector>
105 action_selector_(selector),
106 current_score_counter_(0)
109 template <envs::discrete_world_concept EnvTp,
typename ActionSelector>
115 auto state = this->env_ref_().reset().observation();
118 auto action = action_selector_(this->q_table(), state);
121 for(; itr < this->n_iterations_per_episode(); ++itr){
124 auto action = action_selector_(this->q_table(), state);
125 if(this->is_verbose()){
126 std::cout<<
"Episode iteration="<<itr<<
" of="<<this->n_iterations_per_episode()<<std::endl;
127 std::cout<<
"State="<<state<<std::endl;
128 std::cout<<
"Action="<<action<<std::endl;
132 auto step_type_result = this->env_ref_().step(action);
134 auto next_state = step_type_result.observation();
135 auto reward = step_type_result.reward();
136 auto done = step_type_result.done();
142 auto next_action = action_selector_(this->q_table(), state);
143 update_q_table_(action, state, next_state, next_action, reward);
145 action = next_action;
149 update_q_table_(action, state, CubeAIConsts::invalid_size_type(),
150 CubeAIConsts::invalid_size_type(), reward);
152 this->tmp_scores()[current_score_counter_++] = score;
154 if(current_score_counter_ >= this->render_env_frequency_){
155 current_score_counter_ = 0;
158 if(this->is_verbose()){
159 std::cout<<
"============================================="<<std::endl;
160 std::cout<<
"Break out from episode="<<this->current_episode_idx()<<std::endl;
161 std::cout<<
"============================================="<<std::endl;
171 action_selector_.adjust_on_episode(this->current_episode_idx());
172 if(current_score_counter_ >= this->render_env_frequency_){
173 current_score_counter_ = 0;
176 std::cout<<
"Finished on_episode="<<this->current_episode_idx()<<std::endl;
180 template<envs::discrete_world_concept EnvTp,
typename ActionSelector>
183 const state_type& cstate,
184 const state_type& next_state,
185 const action_type& next_action,
real_t reward){
188 assert(action < this->env_ref_().n_actions() &&
"Inavlid action idx");
189 assert(cstate < this->env_ref_().n_states() &&
"Inavlid state idx");
191 if(next_state != CubeAIConsts::invalid_size_type())
192 assert(next_state < this->env_ref_().n_states() &&
"Inavlid next_state idx");
194 if(next_action != CubeAIConsts::invalid_size_type())
195 assert(next_action < this->env_ref_().n_actions() &&
"Inavlid next_action idx");
198 const auto eps = action_selector_.eps_value();
199 auto q_current = this->q_table()[cstate][action];
200 auto policy_s =
DynVec<real_t>(this->env_ref_().n_actions(), 1.0);
201 policy_s *= eps / this->env_ref_().n_actions();
203 auto state_action_values = this->q_table()[next_state];
204 auto argmax = blaze::argmax(state_action_values);
205 policy_s[argmax] = 1 - eps + (eps / this->env_ref_().n_actions());
207 auto q_next = state_action_values * policy_s;
208 auto td_target = reward + this->gamma() * q_next;
The ExpectedSARSA class. Simple implementation of the expected SARSA algorithm.
Definition expected_sarsa.h:27
TDAlgoBase< EnvTp >::action_type action_type
action_t: the action type, an alias of TDAlgoBase< EnvTp >::action_type.
Definition expected_sarsa.h:38
TDAlgoBase< EnvTp >::env_type env_type
env_t: the environment type, an alias of TDAlgoBase< EnvTp >::env_type.
Definition expected_sarsa.h:33
ActionSelector action_selector_type
action_selector_t: the action selector type, an alias of the ActionSelector template parameter.
Definition expected_sarsa.h:48
ExpectedSARSA(uint_t n_episodes, real_t tolerance, real_t gamma, real_t eta, uint_t plot_f, env_type &env, uint_t max_num_iterations_per_episode, const ActionSelector &selector)
Constructor.
Definition expected_sarsa.h:92
virtual void on_episode() override final
on_episode. Performs the iterations of a single training episode.
Definition expected_sarsa.h:111
TDAlgoBase< EnvTp >::state_type state_type
state_t: the state type, an alias of TDAlgoBase< EnvTp >::state_type.
Definition expected_sarsa.h:43
The TDAlgoBase class. Base class for deriving TD algorithms.
Definition td_algo_base.h:19
env_type::action_type action_type
action_t
Definition td_algo_base.h:30
env_type::state_type state_type
state_t
Definition td_algo_base.h:35
EnvType env_type
env_t
Definition td_algo_base.h:25
double real_t
real_t: the floating-point type, an alias of double.
Definition bitrl_types.h:23
Eigen::RowVectorX< T > DynVec
Dynamically sized row vector.
Definition bitrl_types.h:74
std::size_t uint_t
uint_t: the unsigned integer type, an alias of std::size_t.
Definition bitrl_types.h:43
Various utilities used when working with RL problems.
Definition cuberl_types.h:16