bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
q_learning.h
#ifndef Q_LEARNING_H
#define Q_LEARNING_H

// NOTE: the original listing elides this file's remaining includes. Judging
// from the types used below they cover at least: td_algo_base.h,
// episode_info.h, csv_file_writer.h, matrix_utilities.h,
// max_tabular_policy.h and bitrl_types.h (the exact include paths are
// library-specific and are left out here).
#include "bitrl/bitrl_consts.h"

#ifdef CUBERL_DEBUG
#include <cassert>
#endif

#include <chrono>
namespace cuberl {
namespace rl::algos::td
{

///
/// \brief The QLearningConfig struct. Configuration options
/// for the Q-learning solver.
///
struct QLearningConfig
{
    bool average_episode_reward;
    uint_t n_episodes;
    uint_t max_num_iterations_per_episode;
    real_t tolerance;
    real_t gamma;
    real_t eta;
    std::string path;
};

///
/// \brief The QLearningSolver class. Table-based implementation of the
/// Q-learning algorithm using an epsilon-greedy policy.
///
template<envs::discrete_world_concept EnvTp, typename PolicyType>
class QLearningSolver final: public TDAlgoBase<EnvTp>
{

public:

    ///
    /// \brief env_type. The environment type.
    ///
    typedef typename TDAlgoBase<EnvTp>::env_type env_type;

    ///
    /// \brief action_type. The action type.
    ///
    typedef typename TDAlgoBase<EnvTp>::action_type action_type;

    ///
    /// \brief state_type. The state type.
    ///
    typedef typename TDAlgoBase<EnvTp>::state_type state_type;

    ///
    /// \brief policy_type. The action-selection policy type.
    ///
    typedef PolicyType policy_type;

    ///
    /// \brief Constructor.
    ///
    QLearningSolver(const QLearningConfig config, const PolicyType& policy);

    ///
    /// \brief actions_before_training_begins. Execute any actions the
    /// algorithm needs before starting the training iterations.
    ///
    virtual void actions_before_training_begins(env_type&);

    ///
    /// \brief actions_after_training_ends. Actions to execute after the
    /// training iterations have finished.
    ///
    virtual void actions_after_training_ends(env_type&);

    ///
    /// \brief actions_before_episode_begins. Actions to execute before a
    /// training episode begins.
    ///
    virtual void actions_before_episode_begins(env_type&, uint_t /*episode_idx*/){}

    ///
    /// \brief actions_after_episode_ends. Actions to execute after a
    /// training episode ends.
    ///
    virtual void actions_after_episode_ends(env_type&, uint_t episode_idx,
                                            const EpisodeInfo& /*einfo*/);

    ///
    /// \brief on_training_episode. Run one training episode of the algorithm.
    ///
    virtual EpisodeInfo on_training_episode(env_type&, uint_t episode_idx);

    ///
    /// \brief save. Save the state-action function in CSV format.
    ///
    void save(const std::string& filename)const;

    ///
    /// \brief build_policy. Build the greedy policy after training.
    ///
    cuberl::rl::policies::MaxTabularPolicy build_policy() const;

private:

    ///
    /// \brief config_. The configuration of the solver.
    ///
    QLearningConfig config_;

    ///
    /// \brief policy_. The action-selection policy.
    ///
    policy_type policy_;

    ///
    /// \brief q_table_. The table representing the state-action function.
    ///
    DynMat<real_t> q_table_;

    ///
    /// \brief update_q_table_. Apply the Q-learning update to the table.
    ///
    void update_q_table_(const action_type& action, const state_type& cstate,
                         const state_type& next_state, const action_type& next_action,
                         real_t reward);

};

template <envs::discrete_world_concept EnvTp, typename PolicyType>
QLearningSolver<EnvTp, PolicyType>::QLearningSolver(const QLearningConfig config,
                                                    const PolicyType& policy)
    :
    TDAlgoBase<EnvTp>(),
    config_(config),
    policy_(policy),
    q_table_()
{}

template<envs::discrete_world_concept EnvTp, typename PolicyType>
void
QLearningSolver<EnvTp, PolicyType>::actions_before_training_begins(env_type& env){

    q_table_ = DynMat<real_t>(env.n_states(), env.n_actions());

    for(uint_t i=0; i < env.n_states(); ++i)
        for(uint_t j=0; j < env.n_actions(); ++j)
            q_table_(i, j) = 0.0;

}

template<envs::discrete_world_concept EnvTp, typename PolicyType>
void
QLearningSolver<EnvTp, PolicyType>::actions_after_training_ends(env_type&){

    if(config_.path != bitrl::consts::INVALID_STR){
        save(config_.path);
    }
}


template<envs::discrete_world_concept EnvTp, typename PolicyType>
EpisodeInfo
QLearningSolver<EnvTp, PolicyType>::on_training_episode(env_type& env, uint_t episode_idx){
    auto start = std::chrono::steady_clock::now();
    EpisodeInfo info;

    // total score for the episode
    auto episode_score = 0.0;
    auto state = env.reset().observation();

    uint_t itr=0;
    for(; itr < config_.max_num_iterations_per_episode; ++itr){

        // select an action
        auto action = policy_(q_table_, state);

        // take a step in the environment
        auto step_type_result = env.step(action);

        auto next_state = step_type_result.observation();
        auto reward = step_type_result.reward();
        auto done = step_type_result.done();

        // accumulate score
        episode_score += reward;

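        // Q-learning is off-policy: the update bootstraps from
        // max_a' Q(next_state, a'), so next_action only drives the
        // behaviour policy that generates the trajectory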
        if(!done){
            auto next_action = policy_(q_table_, next_state);
            update_q_table_(action, state, next_state, next_action, reward);
            state = next_state;
            action = next_action;
        }
        else{

            // terminal transition: the elided arguments are reconstructed
            // here; INVALID_ID signals that there is no valid next state
            // (next_action is ignored by update_q_table_)
            update_q_table_(action, state,
                            bitrl::consts::INVALID_ID,
                            bitrl::consts::INVALID_ID,
                            reward);

            break;
        }
    }

    auto end = std::chrono::steady_clock::now();
    std::chrono::duration<real_t> elapsed_seconds = end-start;

    info.episode_index = episode_idx;
    info.episode_reward = config_.average_episode_reward ? episode_score / static_cast<real_t>(itr) : episode_score;
    info.episode_iterations = itr;
    info.total_time = elapsed_seconds;
    return info;
}

template<envs::discrete_world_concept EnvTp, typename PolicyType>
void
QLearningSolver<EnvTp, PolicyType>::actions_after_episode_ends(env_type&, uint_t episode_idx,
                                                               const EpisodeInfo& /*einfo*/){
    policy_.on_episode(episode_idx);
}

template<envs::discrete_world_concept EnvTp, typename PolicyType>
void
QLearningSolver<EnvTp, PolicyType>::save(const std::string& filename)const{

    bitrl::utils::io::CSVWriter file_writer(filename, ',');
    file_writer.open();

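    // CSV layout: one row per state; the first column is the state index,
    // followed by Q(s, a) for each action a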
    std::vector<std::string> col_names(1 + q_table_.cols());
    col_names[0] = "state_index";

    for(uint_t i = 0; i< static_cast<uint_t>(q_table_.cols()); ++i){
        col_names[i + 1] = "action_" + std::to_string(i);
    }

    file_writer.write_column_names(col_names);

    for(uint_t s=0; s < static_cast<uint_t>(q_table_.rows()); ++s){
        auto actions = maths::get_row(q_table_, s);
        auto row = std::make_tuple(s, actions);
        file_writer.write_row(row);
    }
}


template<envs::discrete_world_concept EnvTp, typename PolicyType>
cuberl::rl::policies::MaxTabularPolicy
QLearningSolver<EnvTp, PolicyType>::build_policy() const{

    // body reconstructed from the elided listing: derive the greedy policy
    // from the learned state-action table (the free function is found via
    // ADL on MaxTabularPolicy)
    cuberl::rl::policies::MaxTabularPolicy policy;
    build_from_state_action_function(q_table_, policy);
    return policy;
}

template <envs::discrete_world_concept EnvTp, typename PolicyType>
void
QLearningSolver<EnvTp, PolicyType>::update_q_table_(const action_type& action, const state_type& cstate,
                                                    const state_type& next_state,
                                                    const action_type& /*next_action*/, real_t reward){

    auto q_current = q_table_(cstate, action);
    // a terminal next state is signalled via INVALID_ID and contributes zero value
    auto q_next = next_state != bitrl::consts::INVALID_ID ? cuberl::maths::get_row_max(q_table_, next_state) : 0.0;

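    // standard tabular Q-learning update:
    // Q(s,a) <- Q(s,a) + eta * (r + gamma * max_a' Q(s',a') - Q(s,a))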
    auto td_target = reward + config_.gamma * q_next;
    q_table_(cstate, action) = q_current + (config_.eta * (td_target - q_current));
}


} // namespace rl::algos::td
} // namespace cuberl

#endif // Q_LEARNING_H
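
A minimal usage sketch follows. MyGridWorld and EpsGreedyPolicy are illustrative placeholders (any environment satisfying envs::discrete_world_concept and any compatible action-selection policy would do); only QLearningConfig, QLearningSolver and the member functions shown above come from this header.

// Hypothetical training driver; placeholder types are flagged below.
#include "q_learning.h"
#include <cstddef>

int main(){

    using cuberl::rl::algos::td::QLearningConfig;
    using cuberl::rl::algos::td::QLearningSolver;

    MyGridWorld env;                        // placeholder: models envs::discrete_world_concept

    QLearningConfig config;
    config.n_episodes = 1000;
    config.max_num_iterations_per_episode = 200;
    config.gamma = 0.99;                    // discount factor
    config.eta = 0.1;                       // learning rate
    config.average_episode_reward = false;  // report total, not mean, episode reward
    config.path = "q_table.csv";            // actions_after_training_ends() saves here

    EpsGreedyPolicy policy(0.1);            // placeholder behaviour policy

    QLearningSolver<MyGridWorld, EpsGreedyPolicy> solver(config, policy);

    solver.actions_before_training_begins(env);  // allocates and zeroes the Q-table

    for(std::size_t ep = 0; ep < config.n_episodes; ++ep){
        solver.actions_before_episode_begins(env, ep);
        auto info = solver.on_training_episode(env, ep);
        solver.actions_after_episode_ends(env, ep, info);
    }

    solver.actions_after_training_ends(env);     // persists the Q-table to config.path

    auto greedy = solver.build_policy();         // greedy policy from the learned table
    return 0;
}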