1#ifndef DOUBLE_Q_LEARNING_H
2#define DOUBLE_Q_LEARNING_H
9#include "cubeai/base/cubeai_types.h"
10#include "cubeai/rl/algorithms/td/td_algo_base.h"
11#include "cubeai/rl/rl_mixins.h"
12#include "cubeai/rl/worlds/envs_concepts.h"
13#include "cubeai/rl/episode_info.h"
14#include "cubeai/maths/matrix_utilities.h"
23namespace rl::algos::td
44 template<envs::discrete_world_concept EnvTp,
typename ActionSelector>
98 const EpisodeInfo& ){ action_selector_.adjust_on_episode(episode_idx);}
108 void save(std::string filename)
const;
129 template <envs::discrete_world_concept EnvTp,
typename ActionSelector>
135 action_selector_(selector)
139 template<envs::discrete_world_concept EnvTp,
typename ActionSelector>
145 template<envs::discrete_world_concept EnvTp,
typename ActionSelector>
149 auto start = std::chrono::steady_clock::now();
153 auto episode_score = 0.0;
155 auto state = env.reset().observation();
158 for(; itr < config_.max_num_iterations_per_episode; ++itr){
165 auto step_type_result = env.step(action);
167 auto next_state = step_type_result.observation();
168 auto reward = step_type_result.reward();
169 auto done = step_type_result.done();
172 episode_score += reward;
175 update_q_table_(action, state, next_state, reward);
183 auto end = std::chrono::steady_clock::now();
184 std::chrono::duration<real_t> elapsed_seconds = end-start;
186 info.episode_index = episode_idx;
187 info.episode_reward = episode_score;
188 info.episode_iterations = itr;
189 info.total_time = elapsed_seconds;
194 template <envs::discrete_world_concept EnvTp,
typename ActionSelector>
197 const state_type& next_state,
real_t reward){
201 std::mt19937 gen(config_.seed);
204 std::uniform_real_distribution<> real_dist_(0.0, 1.0);
207 if(real_dist_(gen) <= 0.5){
215 next_state, this->env_ref_().n_actions());
222 auto target = reward + (config_.gamma * Qsa_next);
225 auto new_value = q_current + (config_.eta * (target - q_current));
236 next_state, this->env_ref_().n_actions());
243 auto target = reward + (config_.gamma * Qsa_next);
246 auto new_value = q_current + (config_.eta * (target - q_current));
251 template <envs::discrete_world_concept EnvTp,
typename ActionSelector>
255 rlenvscpp::utils::io::CSVWriter file_writer(filename,
',',
true);
257 col_names[0] =
"state_index";
260 col_names[i + 1] =
"action_" + std::to_string(i);
263 file_writer.write_column_names(col_names);
268 file_writer.write_row(std::make_tuple(s, actions));
271 file_writer.write_row(std::make_tuple(s, actions));
The class DoubleQLearning. Simple tabular implementation of the double Q-learning algorithm.
Definition double_q_learning.h:48
DoubleQLearning(const DoubleQLearningConfig config, const ActionSelector &selector)
Constructor.
Definition double_q_learning.h:130
virtual void actions_after_training_ends(env_type &)
actions_after_training_ends. Actions to execute after the training iterations have finished
TDAlgoBase< EnvTp >::action_type action_type
action_t
Definition double_q_learning.h:60
virtual void actions_before_training_begins(env_type &)
actions_before_training_begins. Execute any actions the algorithm needs before starting the iteration...
Definition double_q_learning.h:141
ActionSelector action_selector_type
action_selector_t
Definition double_q_learning.h:70
TDAlgoBase< EnvTp >::env_type env_type
env_t
Definition double_q_learning.h:55
virtual void actions_before_episode_begins(env_type &, uint_t)
actions_before_training_episode
Definition double_q_learning.h:92
void save(std::string filename) const
Definition double_q_learning.h:253
virtual EpisodeInfo on_training_episode(env_type &, uint_t episode_idx)
on_episode Do one on_episode of the algorithm
Definition double_q_learning.h:147
TDAlgoBase< EnvTp >::state_type state_type
state_t
Definition double_q_learning.h:65
virtual void actions_after_episode_ends(env_type &, uint_t episode_idx, const EpisodeInfo &)
actions_after_training_episode
Definition double_q_learning.h:97
The TDAlgoBase class. Base class for deriving TD algorithms.
Definition td_algo_base.h:19
env_type::action_type action_type
action_t
Definition td_algo_base.h:30
env_type::state_type state_type
state_t
Definition td_algo_base.h:35
EnvType env_type
env_t
Definition td_algo_base.h:25
const std::string INVALID_STR
Invalid string.
Definition bitrl_consts.h:26
double real_t
real_t
Definition bitrl_types.h:23
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Eigen::MatrixX< T > DynMat
Dynamically sized matrix to use around the library.
Definition bitrl_types.h:49
DynVec< T > get_row(const DynMat< T > &matrix, uint_t row_idx)
Extract the row_idx-th row from the matrix.
Definition matrix_utilities.h:130
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
The EpisodeInfo struct.
Definition episode_info.h:19
Definition double_q_learning.h:27
uint_t max_num_iterations_per_episode
Definition double_q_learning.h:33
real_t gamma
Definition double_q_learning.h:31
real_t tolerance
Definition double_q_learning.h:30
uint_t seed
Definition double_q_learning.h:35
uint_t n_episodes
Definition double_q_learning.h:34
real_t eta
Definition double_q_learning.h:32
std::string path
Definition double_q_learning.h:29
Definition rl_mixins.h:302
static uint_t max_action(const TableTp &q1_table, const TableTp &q2_table, const StateTp &state, uint_t n_actions)
Returns the max action by averaging the state values from the two tables.
Definition rl_mixins.h:322
Definition rl_mixins.h:138