bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
double_q_learning.h
Go to the documentation of this file.
1#ifndef DOUBLE_Q_LEARNING_H
2#define DOUBLE_Q_LEARNING_H
3
8
9#include "cubeai/base/cubeai_types.h"
10#include "cubeai/rl/algorithms/td/td_algo_base.h"
11#include "cubeai/rl/rl_mixins.h"
12#include "cubeai/rl/worlds/envs_concepts.h"
13#include "cubeai/rl/episode_info.h"
14#include "cubeai/maths/matrix_utilities.h"
15
17#include "bitrl/bitrl_consts.h"
18
19#include <chrono>
20#include <random>
21
22namespace cuberl{
23namespace rl::algos::td
24{
25
38
39
// NOTE(review): this is a Doxygen-rendered listing; the embedded numbers are
// the original header's line numbers. Several lines were lost in extraction
// (a second protected base at original line 47, and the env_type/action_type/
// state_type typedefs around lines 55-65 -- see the trailer tooltips). Do not
// treat this listing as compilable source.
//
// Tabular double Q-learning agent over a discrete environment. Maintains two
// Q-tables (Q1/Q2) through with_double_q_table_mixin<DynMat<real_t>> and
// delegates action choice to the supplied ActionSelector policy.
44 template<envs::discrete_world_concept EnvTp, typename ActionSelector>
45 class DoubleQLearning final: public TDAlgoBase<EnvTp>,
46 protected with_double_q_table_mixin<DynMat<real_t>>,
48 {
49 public:
50
51
56
61
66
// Alias for the action-selection policy type.
70 typedef ActionSelector action_selector_type;
71
// Constructor: copies the configuration (taken by value) and the selector.
75 DoubleQLearning(const DoubleQLearningConfig config, const ActionSelector& selector);
76
82
88
// Hook run before each training episode; no-op by default.
92 virtual void actions_before_episode_begins(env_type&, uint_t /*episode_idx*/){}
93
// Hook run after each training episode; lets the selector adapt
// (e.g. schedule decay) based on the episode index.
97 virtual void actions_after_episode_ends(env_type&, uint_t episode_idx,
98 const EpisodeInfo& /*einfo*/){ action_selector_.adjust_on_episode(episode_idx);}
99
// Run a single training episode and return its statistics.
103 virtual EpisodeInfo on_training_episode(env_type&, uint_t episode_idx);
104
// Write both Q-tables to a CSV file at the given path.
108 void save(std::string filename)const;
109
110 private:
111
// Algorithm hyper-parameters (gamma, eta, seed, episode/iteration limits).
112 DoubleQLearningConfig config_;
113
// Policy object that picks actions from the two Q-tables.
118 action_selector_type action_selector_;
119
// Double-Q TD update: randomly updates either Q1 (bootstrapping from Q2)
// or Q2 (bootstrapping from Q1).
124 void update_q_table_(const action_type& action, const state_type& cstate,
125 const state_type& next_state, real_t reward);
126
127 };
128
// Constructor definition. NOTE(review): the qualified signature line
// (original line 130, DoubleQLearning<EnvTp, ActionSelector>::DoubleQLearning(...))
// and possibly a mixin initializer (line 133) were lost in the doc extraction.
// Initializes the TD base, then stores the config and the action selector.
129 template <envs::discrete_world_concept EnvTp, typename ActionSelector>
131 :
132 TDAlgoBase<EnvTp>(),
134 config_(config),
135 action_selector_(selector)
136 {}
137
138
// NOTE(review): per the trailer tooltips this is the definition of
// actions_before_training_begins (original line 141); its signature and body
// (lines 141-143) were lost in the doc extraction, leaving only the template
// header and return type here.
139 template<envs::discrete_world_concept EnvTp, typename ActionSelector>
140 void
144
// on_training_episode: run one episode against the environment.
// NOTE(review): the qualified signature (original lines 146-147, taking
// env_type& env and uint_t episode_idx per the declaration) was lost in the
// doc extraction.
145 template<envs::discrete_world_concept EnvTp, typename ActionSelector>
148
// Wall-clock timing for the whole episode.
149 auto start = std::chrono::steady_clock::now();
150 EpisodeInfo info;
151
152 // total score for the episode
153 auto episode_score = 0.0;
154
// Reset the environment and take the initial observation as the state.
155 auto state = env.reset().observation();
156
// Iterate up to the configured per-episode cap; itr survives the loop so
// the actual iteration count can be recorded below.
157 uint_t itr=0;
158 for(; itr < config_.max_num_iterations_per_episode; ++itr){
159
// select an action using both Q-tables (double-Q action selection).
160 // select an action
161 auto action = action_selector_(this->with_double_q_table_mixin<DynMat<real_t>>::q_table_1,
162 this->with_double_q_table_mixin<DynMat<real_t>>::q_table_2, state);
163
164 // Take an action on the environment
165 auto step_type_result = env.step(action);
166
167 auto next_state = step_type_result.observation();
168 auto reward = step_type_result.reward();
169 auto done = step_type_result.done();
170
171 // accumulate score
172 episode_score += reward;
173
// Double-Q TD update, then advance the state.
174 // update the table
175 update_q_table_(action, state, next_state, reward);
176 state = next_state;
177
// Terminal state: stop the episode early.
178 if(done){
179 break;
180 }
181 }
182
183 auto end = std::chrono::steady_clock::now();
184 std::chrono::duration<real_t> elapsed_seconds = end-start;
185
// Package the per-episode statistics for the caller.
186 info.episode_index = episode_idx;
187 info.episode_reward = episode_score;
188 info.episode_iterations = itr;
189 info.total_time = elapsed_seconds;
190 return info;
191
192 }
193
// update_q_table_: one double-Q TD(0) update. With probability 0.5 update Q1
// using Q2 to evaluate the greedy next action, otherwise the symmetric update
// of Q2 using Q1.
//
// BUG(review): a fresh std::mt19937 is seeded with the FIXED config_.seed on
// every call, so real_dist_(gen) produces the same first draw each time and
// the same branch is always taken -- only one of the two tables is ever
// updated. The generator should be a member (or otherwise persist across
// calls) for a genuine 50/50 coin flip.
//
// NOTE(review): the lines computing max_act (original lines 214 and 235,
// presumably calling the mixin's static max_action(q1, q2, state, n_actions)
// per the trailer tooltip) were lost in the doc extraction.
194 template <envs::discrete_world_concept EnvTp, typename ActionSelector>
195 void
196 DoubleQLearning<EnvTp, ActionSelector>::update_q_table_(const action_type& action, const state_type& cstate,
197 const state_type& next_state, real_t reward){
198
199 // flip a coin 50% of the time we update Q1
200 // whilst 50% of the time Q2
201 std::mt19937 gen(config_.seed); //rd());
202
203 // generate a number in [0, 1]
204 std::uniform_real_distribution<> real_dist_(0.0, 1.0);
205
206 // update Q1
207 if(real_dist_(gen) <= 0.5){
208
209 // the current qvalue
210 auto q_current = this->with_double_q_table_mixin<DynMat<real_t>>::template get<1>(cstate, action);
211 auto Qsa_next = 0.0;
212
213 //if(this->env_ref_().is_valid_state(next_state)){
215 next_state, this->env_ref_().n_actions());
216
// Evaluate the selected next action with the OTHER table (Q2) --
// this cross-evaluation is what reduces maximization bias.
217 // value of next state
218 Qsa_next = this->with_double_q_table_mixin<DynMat<real_t>>::template get<2>(next_state, max_act);
219 //}
220
221 // construct TD target
222 auto target = reward + (config_.gamma * Qsa_next);
223
// Move Q1(s,a) toward the target with learning rate eta.
224 // get updated value
225 auto new_value = q_current + (config_.eta * (target - q_current));
226 this->with_double_q_table_mixin<DynMat<real_t>>::template set<1>(cstate, action, new_value);
227 }
228 else{
229
// Symmetric branch: update Q2, evaluating the next action with Q1.
230 // the current qvalue
231 auto q_current = this->with_double_q_table_mixin<DynMat<real_t>>::template get<2>(cstate, action);
232 auto Qsa_next = 0.0;
233
234
236 next_state, this->env_ref_().n_actions());
237
238 // value of next state
239 Qsa_next = this->with_double_q_table_mixin<DynMat<real_t>>::template get<1>(next_state, max_act);
240
241
242 // construct TD target
243 auto target = reward + (config_.gamma * Qsa_next);
244
245 // get updated value
246 auto new_value = q_current + (config_.eta * (target - q_current));
247 this->with_double_q_table_mixin<DynMat<real_t>>::template set<2>(cstate, action, new_value);
248 }
249 }
250
// save: dump both Q-tables to CSV. Header row is "state_index,action_0,...";
// then for each state two consecutive rows are written -- first the Q1 row,
// then the Q2 row -- both keyed by the same state index.
// NOTE(review): the qualified signature (original line 253,
// DoubleQLearning<EnvTp, ActionSelector>::save(std::string filename) const)
// was lost in the doc extraction. Assumes both tables have identical
// dimensions, since only q_table_1's rows/columns are queried.
251 template <envs::discrete_world_concept EnvTp, typename ActionSelector>
252 void
254
255 rlenvscpp::utils::io::CSVWriter file_writer(filename, ',', true);
// One column for the state index plus one per action.
256 std::vector<std::string> col_names(1 + this->with_double_q_table_mixin<DynMat<real_t>>::q_table_1.columns());
257 col_names[0] = "state_index";
258
259 for(uint_t i = 0; i< this->with_double_q_table_mixin<DynMat<real_t>>::q_table_1.columns(); ++i){
260 col_names[i + 1] = "action_" + std::to_string(i);
261 }
262
263 file_writer.write_column_names(col_names);
264
// Interleave Q1 and Q2 rows state by state.
265 for(uint_t s=0; s < this->with_double_q_table_mixin<DynMat<real_t>>::q_table_1.rows(); ++s){
266
267 auto actions = maths::get_row(this->with_double_q_table_mixin<DynMat<real_t>>::q_table_1, s);
268 file_writer.write_row(std::make_tuple(s, actions));
269
270 actions = maths::get_row(this->with_double_q_table_mixin<DynMat<real_t>>::q_table_2, s);
271 file_writer.write_row(std::make_tuple(s, actions));
272 }
273 }
274
275}
276}
277
278#endif // DOUBLE_Q_LEARNING_H
The class DoubleQLearning. Simple tabular implementation of the double Q-learning algorithm.
Definition double_q_learning.h:48
DoubleQLearning(const DoubleQLearningConfig config, const ActionSelector &selector)
Constructor.
Definition double_q_learning.h:130
virtual void actions_after_training_ends(env_type &)
actions_after_training_ends. Actions to execute after the training iterations have finished
TDAlgoBase< EnvTp >::action_type action_type
action_t
Definition double_q_learning.h:60
virtual void actions_before_training_begins(env_type &)
actions_before_training_begins. Execute any actions the algorithm needs before starting the iteration...
Definition double_q_learning.h:141
ActionSelector action_selector_type
action_selector_t
Definition double_q_learning.h:70
TDAlgoBase< EnvTp >::env_type env_type
env_t
Definition double_q_learning.h:55
virtual void actions_before_episode_begins(env_type &, uint_t)
actions_before_training_episode
Definition double_q_learning.h:92
void save(std::string filename) const
Definition double_q_learning.h:253
virtual EpisodeInfo on_training_episode(env_type &, uint_t episode_idx)
on_training_episode. Run one training episode of the algorithm
Definition double_q_learning.h:147
TDAlgoBase< EnvTp >::state_type state_type
state_t
Definition double_q_learning.h:65
virtual void actions_after_episode_ends(env_type &, uint_t episode_idx, const EpisodeInfo &)
actions_after_training_episode
Definition double_q_learning.h:97
The TDAlgoBase class. Base class for deriving TD algorithms.
Definition td_algo_base.h:19
env_type::action_type action_type
action_t
Definition td_algo_base.h:30
env_type::state_type state_type
state_t
Definition td_algo_base.h:35
EnvType env_type
env_t
Definition td_algo_base.h:25
const std::string INVALID_STR
Invalid string.
Definition bitrl_consts.h:26
double real_t
real_t
Definition bitrl_types.h:23
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Eigen::MatrixX< T > DynMat
Dynamically sized matrix to use around the library.
Definition bitrl_types.h:49
DynVec< T > get_row(const DynMat< T > &matrix, uint_t row_idx)
Extract the row_idx-th row from the matrix.
Definition matrix_utilities.h:130
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
The EpisodeInfo struct.
Definition episode_info.h:19
Definition double_q_learning.h:27
uint_t max_num_iterations_per_episode
Definition double_q_learning.h:33
real_t gamma
Definition double_q_learning.h:31
real_t tolerance
Definition double_q_learning.h:30
uint_t seed
Definition double_q_learning.h:35
uint_t n_episodes
Definition double_q_learning.h:34
real_t eta
Definition double_q_learning.h:32
std::string path
Definition double_q_learning.h:29
static uint_t max_action(const TableTp &q1_table, const TableTp &q2_table, const StateTp &state, uint_t n_actions)
Returns the max action by averaging the state values from the two tables.
Definition rl_mixins.h:322
Definition rl_mixins.h:138