bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
epsilon_greedy_policy.h
Go to the documentation of this file.
1#ifndef EPSILON_GREEDY_POLICY_H
2#define EPSILON_GREEDY_POLICY_H
3
4#include "cuberl/base/cubeai_config.h"
8
9#ifdef USE_PYTORCH
11#endif
12
13#include <random>
14#include <cmath>
15
16namespace cuberl {
17namespace rl {
18namespace policies {
19
20
25
30{
31public:
32
33
38
39 constexpr static real_t MIN_EPS = 0.01;
40 constexpr static real_t MAX_EPS = 1.0;
41 constexpr static real_t EPSILON_DECAY_FACTOR = 0.01;
42
47
51 explicit EpsilonGreedyPolicy(real_t eps, uint_t seed);
52
57 explicit EpsilonGreedyPolicy(real_t eps, uint_t seed,
58 EpsilonDecayOption decay_op,
59 real_t min_eps = MIN_EPS,
60 real_t max_eps=MAX_EPS,
61 real_t epsilon_decay = EPSILON_DECAY_FACTOR);
62
66 template<typename MapType>
67 output_type operator()(const MapType& q_map, uint_t state)const;
68
69
73 template<typename VecType>
74 output_type operator()(const VecType& vec)const;
75
79 template<typename MatType>
80 output_type get_action(const MatType& q_map, uint_t state_idx);
81
87 template<typename VecTp>
88 output_type get_action(const VecTp& q_map);
89
90#ifdef USE_PYTORCH
91 output_type operator()(const torch_tensor_t& vec, torch_tensor_value_type<real_t>)const;
92 output_type operator()(const torch_tensor_t& vec, torch_tensor_value_type<float_t>)const;
93 output_type operator()(const torch_tensor_t& vec, torch_tensor_value_type<int_t>)const;
94 output_type operator()(const torch_tensor_t& vec, torch_tensor_value_type<lint_t>)const;
95#endif
96
101 void on_episode(uint_t episode_idx)noexcept;
102
106 void reset()noexcept{eps_ = eps_init_;}
107
111 real_t eps_value()const noexcept{return eps_;}
112
118
122 EpsilonDecayOption decay_option()const noexcept{return decay_op_;}
123
124
125private:
126
127 real_t eps_init_;
128 real_t eps_;
129 real_t min_eps_;
130 real_t max_eps_;
131 real_t epsilon_decay_;
132 EpsilonDecayOption decay_op_;
133
137 mutable std::mt19937 generator_;
138
139 // how to select the action
140 RandomTabularPolicy random_policy_;
141 MaxTabularPolicy max_policy_;
142};
143
144inline
146 real_t min_eps, real_t max_eps, real_t epsilon_decay)
147:
148eps_init_(eps),
149eps_(eps),
150min_eps_(min_eps),
151max_eps_(max_eps),
152epsilon_decay_(epsilon_decay),
153decay_op_(decay_op),
154generator_(seed),
155random_policy_(seed),
156max_policy_()
157{}
158
159inline
161 :
162 eps_init_(eps),
163 eps_(eps),
164 min_eps_(eps),
165 max_eps_(eps),
166 epsilon_decay_(eps),
167 decay_op_(EpsilonDecayOption::NONE),
168 random_policy_(),
169 max_policy_()
170{}
171
172inline
174 :
176 eps, eps, eps)
177{}
178
179
180template<typename VecType>
182EpsilonGreedyPolicy::operator()(const VecType& vec)const{
183
184 // generate a number in [0, 1)
185 std::uniform_real_distribution<> real_dist_(0.0, 1.0);
186
187 if(real_dist_(generator_) > eps_){
188 // select greedy action with probability 1 - epsilon
189 return max_policy_.get_action(vec);
190 }
191
192 // else select a random action
193 return random_policy_(vec);
194}
195
196
197template<typename VecTp>
200 // generate a number in [0, 1)
201 std::uniform_real_distribution<> real_dist_(0.0, 1.0);
202
203 if(real_dist_(generator_) > eps_){
204 // select greedy action with probability 1 - epsilon
205 return max_policy_.get_action(vec);
206 }
207
208 // else select a random action
209 return random_policy_(vec);
210}
211
212}
213}
214}
215
216#endif // EPSILON_GREEDY_POLICY_H
The EpsilonGreedyPolicy class.
Definition epsilon_greedy_policy.h:30
EpsilonGreedyPolicy(real_t eps)
Constructor. Creates an epsilon-greedy tabular policy.
Definition epsilon_greedy_policy.h:160
void reset() noexcept
Reset the policy.
Definition epsilon_greedy_policy.h:106
static constexpr real_t MIN_EPS
Definition epsilon_greedy_policy.h:39
EpsilonDecayOption decay_option() const noexcept
Returns the decay option.
Definition epsilon_greedy_policy.h:122
real_t eps_value() const noexcept
Returns the value of the epsilon.
Definition epsilon_greedy_policy.h:111
static constexpr real_t MAX_EPS
Definition epsilon_greedy_policy.h:40
void on_episode(uint_t episode_idx) noexcept
Any actions the policy should perform on the given episode index.
static constexpr real_t EPSILON_DECAY_FACTOR
Definition epsilon_greedy_policy.h:41
output_type get_action(const MatType &q_map, uint_t state_idx)
get_action. Given a
uint_t output_type
The type returned when calling this->operator()
Definition epsilon_greedy_policy.h:37
output_type operator()(const MapType &q_map, uint_t state) const
operator(). Select an action for the given state.
void set_eps_value(real_t eps)
Set the epsilon value.
class MaxTabularPolicy
Definition max_tabular_policy.h:30
static output_type get_action(const MatType &q_map, uint_t state_idx)
get_action. Given a
class RandomTabularPolicy
Definition random_tabular_policy.h:23
double real_t
real_t
Definition bitrl_types.h:23
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
EpsilonDecayOption
The EpsilonDecayOption enum. Enumerate various decaying options.
Definition epsilon_greedy_policy.h:24
Various utilities used when working with RL problems.
Definition cuberl_types.h:16