bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
rl_mixins.h
Go to the documentation of this file.
1#ifndef RL_MIXINS_H
2#define RL_MIXINS_H
7#include "cuberl/base/cubeai_config.h"
10
11
12#ifdef CUBERL_DEBUG
13#include <cassert>
14#endif
15
16#include <map>
17#include <tuple>
18#include <random>
19
20
21namespace cuberl::rl{
22
23namespace {
24
25template<typename StateTp>
26const DynVec<real_t>&
27get_table_values_(const std::map<StateTp,DynVec<real_t>>& table, const StateTp& state ){
28
29 auto itr = table.find(state);
30#ifdef CUBEAI_DEBUG
31 if(itr == table.end()){
32 assert(false && "Invalid state given");
33 }
34#endif
35
36 return itr->second;
37
38}
39
40template<typename StateTp>
41DynVec<real_t>&
42get_table_values_(std::map<StateTp,DynVec<real_t>>& table, const StateTp& state ){
43
44 auto itr = table.find(state);
45#ifdef CUBERL_DEBUG
46 if(itr == table.end()){
47 assert(false && "Invalid state given");
48 }
49#endif
50
51 return itr->second;
52
53}
54
55}
56
64uint_t max_action(const DynMat<real_t>& qtable, uint_t state, uint_t n_actions);
65
66
71{
79 EpsilonDecayOptionType decay_op;
80
86 real_t decay_eps(uint_t episode_index);
87
91 template<typename VectorType>
92 uint_t choose_action_index(const VectorType& values)const;
93};
94
95template<typename VectorType>
98
99 std::mt19937 gen(this->with_decay_epsilon_option_mixin::seed);
100
101 // generate a number in [0, 1]
102 std::uniform_real_distribution<> real_dist_(0.0, 1.0);
103
104 if(real_dist_(gen) > this->with_decay_epsilon_option_mixin::eps){
105 // select greedy action with probability 1 - epsilon
106 return arg_max(values);
107 }
108
109 std::uniform_int_distribution<> distrib_(0, this->with_decay_epsilon_option_mixin::n_actions - 1);
110 return distrib_(gen);
111
112}
113
118{
122
127
134 void initialize(state_type n_states, action_type n_actions, real_t init_value);
135};
136
137template<typename TableTp>
139
143template<>
145{
150
155
160
167 void initialize(const std::vector<index_type>& indices, action_type n_actions, real_t init_value);
168
172 template<int index>
173 value_type get(const state_type& state, const action_type action)const;
174
175 template<int index>
176 void set(const state_type& state, const action_type action, const value_type value);
177};
178
179template<>
181with_double_q_table_mixin< DynMat<real_t>>::get<1>(const state_type& state, const action_type action)const{
182 return q_table_1(state, action);
183}
184
185template<>
188 const action_type action)const{
189 return q_table_2(state, action);
190}
191
192template<>
193void
195 const action_type action,
196 const value_type value){
197 q_table_1(state, action) = value;
198}
199
200template<>
201void
203 const action_type action,
204 const value_type value){
205 q_table_2(state, action) = value;
206}
207
208
209
210
211template<typename KeyTp>
212struct with_double_q_table_mixin<std::map<KeyTp, DynVec<real_t>>>
213{
214
215 typedef KeyTp index_type;
216 typedef KeyTp state_type;
219
223 std::map<KeyTp, DynVec<real_t>> q_table_1;
224
228 std::map<KeyTp, DynVec<real_t>> q_table_2;
229
236 void initialize(const std::vector<index_type>& indices, action_type n_actions, real_t init_value);
237
241 template<int index>
242 value_type get(const state_type& state, const action_type action)const;
243
247 template<int index>
248 void set(const state_type& state, const action_type action, const value_type value);
249
250};
251
252template<typename KeyTp>
253void
254with_double_q_table_mixin<std::map<KeyTp, DynVec<real_t>>>::initialize( const std::vector<index_type>& indices,
255 action_type n_actions,
256 real_t init_value){
257
258
259 DynVec<real_t> init_vals(n_actions, init_value);
260
261 for(uint_t i=0; i< indices.size(); ++i){
262
263 q_table_1[indices[i]] = init_vals;
264 q_table_2[indices[i]] = init_vals;
265 }
266}
267
268template<typename KeyTp>
269template<int index>
272
273 static_assert (index == 1 || index == 2, "Invalid index for template parameter");
274 if(index == 1){
275 return get_table_values_(q_table_1, state)[action];
276 }
277
278 return get_table_values_(q_table_2, state)[action];
279
280}
281
282template<typename KeyTp>
283template<int index>
284void
286 const action_type action,
287 const value_type value){
288
289 static_assert (index == 1 || index == 2, "Invalid index for template parameter");
290
291 if(index == 1){
292 auto& vals1 = get_table_values_(q_table_1, state);
293 vals1[action] = value;
294 }
295
296 auto& vals2 = get_table_values_(q_table_2, state);
297 vals2[action] = value;
298}
299
300
302{
303
307 template<typename TableTp, typename StateTp>
308 static uint_t max_action(const TableTp& q1_table, const TableTp& q2_table,
309 const StateTp& state, uint_t n_actions);
310
314 template<typename TableTp, typename StateTp>
315 static uint_t max_action(const TableTp& q1_table, const StateTp& state, uint_t n_actions);
316
317};
318
319
320template<typename TableTp, typename StateTp>
321uint_t
322with_double_q_table_max_action_mixin::max_action(const TableTp& q1_table, const TableTp& q2_table,
323 const StateTp& state, uint_t /*n_actions*/){
324
325 const auto& vals1 = get_table_values_(q1_table, state);
326 const auto& vals2 = get_table_values_(q2_table, state);
327 auto sum = vals1 + vals2;
328 return 1; //blaze::argmax(sum);
329
330}
331
332template<typename TableTp, typename StateTp>
333uint_t
334with_double_q_table_max_action_mixin::max_action(const TableTp& q_table, const StateTp& state, uint_t /*n_actions*/){
335
336 const auto& vals = get_table_values_(q_table, state);
337 return 1; //blaze::argmax(vals);
338
339}
340
341
342
343}
344
345#endif // RL_MIXINS_H
double real_t
real_t
Definition bitrl_types.h:23
Eigen::RowVectorX< T > DynVec
Dynamically sized row vector.
Definition bitrl_types.h:74
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Eigen::MatrixX< T > DynMat
Dynamically sized matrix to use around the library.
Definition bitrl_types.h:49
Definition dummy_agent.h:8
uint_t max_action(const DynMat< real_t > &qtable, uint_t state, uint_t n_actions)
max_action
The with_decay_epsilon_option_mixin struct.
Definition rl_mixins.h:71
uint_t n_actions
Definition rl_mixins.h:77
EpsilonDecayOptionType decay_op
Definition rl_mixins.h:79
real_t decay_eps(uint_t episode_index)
decay_eps
uint_t seed
Definition rl_mixins.h:78
real_t max_eps
Definition rl_mixins.h:75
real_t min_eps
Definition rl_mixins.h:74
real_t eps_init
Definition rl_mixins.h:72
real_t eps
Definition rl_mixins.h:73
uint_t choose_action_index(const VectorType &values) const
Definition rl_mixins.h:97
real_t epsilon_decay
Definition rl_mixins.h:76
static uint_t max_action(const TableTp &q1_table, const TableTp &q2_table, const StateTp &state, uint_t n_actions)
Returns the max action by averaging the state values from the two tables.
Definition rl_mixins.h:322
DynMat< value_type > q_table_2
q_table_2
Definition rl_mixins.h:159
DynMat< value_type > q_table_1
q_table_1
Definition rl_mixins.h:154
void initialize(const std::vector< index_type > &indices, action_type n_actions, real_t init_value)
initialize
void set(const state_type &state, const action_type action, const value_type value)
value_type get(const state_type &state, const action_type action) const
value_type get(const state_type &state, const action_type action) const
std::map< KeyTp, DynVec< real_t > > q_table_1
q_table_1
Definition rl_mixins.h:223
std::map< KeyTp, DynVec< real_t > > q_table_2
q_table_2
Definition rl_mixins.h:228
Definition rl_mixins.h:138
The WithQTableMixin struct.
Definition rl_mixins.h:118
void initialize(state_type n_states, action_type n_actions, real_t init_value)
initialize
real_t value_type
Definition rl_mixins.h:121
DynMat< value_type > q_table
q_table
Definition rl_mixins.h:126
uint_t action_type
Definition rl_mixins.h:120
uint_t state_type
Definition rl_mixins.h:119