bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
mc_tree_search_solver.h
Go to the documentation of this file.
1#ifndef MC_TREE_SEARCH_BASE_H
2#define MC_TREE_SEARCH_BASE_H
3
4#include "cubeai/base/cubeai_types.h"
5#include "cubeai/rl/algorithms/rl_algorithm_base.h"
6#include "cubeai/rl/algorithms/rl_algo_config.h"
7#include "cubeai/utils/iteration_mixin.h"
8#include "cubeai/rl/episode_info.h"
9#include "cubeai/rl/algorithms/mc/mcts_node.h"
10
11#include "cubeai/base/cubeai_config.h"
12
13#ifdef CUBEAI_DEBUG
14#include <cassert>
15#endif
16
17#include <vector>
18#include <memory>
19#include <cmath>
20#include <algorithm>
21
22namespace cubeai{
23namespace rl{
24namespace algos{
25namespace mc{
26
27
31struct MCTreeSearchConfig: public RLAlgoConfig
32{
33 uint_t max_tree_depth = 1000;
34 real_t temperature = 1.0;
35
36};
37
38
39
40
41
45template<typename EnvTp, typename NodeTp=MCTSNodeBase<typename EnvTp::action_type, typename EnvTp::state_type>>
46class MCTSSolver final: public RLSolverBase<EnvTp>
47{
48public:
49
50 static_assert (std::is_default_constructible<typename EnvTp::action_type>::value,
51 "Action type is not default constructible");
52
53 static_assert (std::is_default_constructible<typename EnvTp::state_type>::value,
54 "State type is not default constructible");
55
59 typedef EnvTp env_type;
60
64 typedef NodeTp node_type;
65
69 typedef typename env_type::time_step_type time_step_type;
70
76
81 virtual void actions_before_training_begins(env_type&)override final{}
82
87 virtual void actions_after_training_ends(env_type&) override final{}
88
92 virtual void actions_before_episode_begins(env_type&, uint_t /*episode_idx*/){}
93
97 virtual void actions_after_episode_ends(env_type&, uint_t /*episode_idx*/, const EpisodeInfo& /*einfo*/){}
98
102 virtual EpisodeInfo on_training_episode(env_type&, uint_t /*episode_idx*/) override final;
103
108 time_step_type simulate_node(std::shared_ptr<node_type> node, env_type& env);
109
114 void expand_node(std::shared_ptr<node_type> node, env_type& env);
115
119 void backprop(std::shared_ptr<node_type> node);
120
125 uint_t max_depth_tree()const noexcept{return max_depth_tree_;}
126
127protected:
128
132 IterationMixin itr_mix_;
133
137 std::shared_ptr<node_type> root_;
138
143
148
149};
150
151template<typename EnvTp, typename NodeTp>
153 :
154 RLSolverBase(config.n_episodes, config.tolerance),
155 itr_mix_(config.n_itrs_per_episode, config.render_episode),
156 root_(nullptr),
157 max_depth_tree_(config.max_tree_depth),
158 temperature_(config.temperature)
159{}
160
161
162template<typename EnvTp, typename NodeTp>
163EpisodeInfo
165
166 auto best_reward = std::numeric_limits<real_t>::min();
167
168 // on every episode we reset
169 // we may want to reset the copy here
170 auto time_step = env.reset();
171
172 while(this->itr_mix_.continue_iterations()){
173
174 sum_reward_ = 0.0;
175 actions_.clear();
176
177 auto node = this->root_;
178
179 auto time_step = this->simulate_node(node, env_copy);
180
181 auto terminal = time_step.done();
182
183 if(!terminal){
184 // expand the node if this is not
185 // terminal
186 this->expand_node(node, env_copy);
187 }
188
189 // creating exhaustive list of actions
190 while(!terminal){
191 auto action = env_copy.sample_action();
192 auto new_time_step = env_copy.step(action);
193 sum_reward_ += new_time_step.reward();
194 actions_.push_back(action);
195
196 if(actions_.size() > this->max_depth_tree()){
197 sum_reward_ -= 100;
198 break;
199 }
200 }
201
202 // do some book keeping retaining the best reward value and actions
203 if( best_reward < sum_reward_){
204 best_reward = sum_reward_;
205 best_actions_ = actions_;
206 }
207
208 // backpropagating in MCTS for assigning reward value to a node.
209 this->backprop(node);
210 }
211
212 // step in the best actions
213 sum_reward_ = 0.;
214 for(auto action : best_actions_){
215 auto time_step = this->env.step(action);
216
217 sum_reward_ += time_step.reward();
218 if(time_step.done()){
219 break;
220 }
221 }
222
223
224 best_rewards_.push_back(sum_reward_);
225}
226}
227
228template<typename EnvTp, typename NodeTp>
229void
230MCTSSolver<EnvTp, NodeTp>::backprop(std::shared_ptr<node_type> node){
231
232 while(node){
233
234 node -> update_visits();
235 node -> update_total_score(sum_reward_);
236 node = node ->parent();
237 }
238
239}
240
241template<typename Env, typename NodeTp>
242void
243MCTSSolver<Env>::expand_node(std::shared_ptr<node_type> node, env_type& env){
244
245 for(uint_t a=0; a < env.n_actions(); ++a ){
246 node -> add_child(node, a);
247 }
248
249 node->shuffle_children();
250
251}
252
253
254template<typename Env, typename NodeTp>
256MCTSSolver<Env, NodeTp>::simulate_node(std::shared_ptr<node_type> node, env_type& env){
257
259 while(node -> has_children()){
260
261 if(node -> n_explored_children() < node -> n_children() ){
262
263 auto child = node->get_child(node->n_explored_children());
264 node ->update_explored_children();
265 node = child;
266 }
267 else{
268
269 node = node -> max_ucb_child(this->temperature_);
270
271 }
272
273 time_step = env.step(node ->get_action());
274 sum_reward_ += time_step.reward();
275 actions_.push_back(node ->get_action());
276 }
277
278 return time_step;
279}
280
281
282}
283}
284}
285}
286
287#endif // MC_TREE_SEARCH_BASE_H
MCTreeSearchBase.
Definition mc_tree_search_solver.h:47
virtual void actions_after_training_ends(env_type &) override final
actions_after_training_ends. Actions to execute after the training iterations have finished
Definition mc_tree_search_solver.h:87
virtual void actions_before_episode_begins(env_type &, uint_t)
actions_before_training_episode
Definition mc_tree_search_solver.h:92
env_type::time_step_type time_step_type
The time step type.
Definition mc_tree_search_solver.h:69
NodeTp node_type
node_type
Definition mc_tree_search_solver.h:64
virtual void actions_before_training_begins(env_type &) override final
actions_before_training_begins. Execute any actions the algorithm needs before starting the iteration...
Definition mc_tree_search_solver.h:81
void expand_node(std::shared_ptr< node_type > node, env_type &env)
expand_node
Definition mc_tree_search_solver.h:243
std::shared_ptr< node_type > root_
root_
Definition mc_tree_search_solver.h:137
virtual void actions_after_episode_ends(env_type &, uint_t, const EpisodeInfo &)
actions_after_training_episode
Definition mc_tree_search_solver.h:97
void backprop(std::shared_ptr< node_type > node)
backprop
Definition mc_tree_search_solver.h:230
MCTSSolver(MCTreeSearchConfig config)
MCTreeSearchBase.
Definition mc_tree_search_solver.h:152
EnvTp env_type
env_type
Definition mc_tree_search_solver.h:59
uint_t max_depth_tree_
max_depth_tree_
Definition mc_tree_search_solver.h:142
IterationMixin itr_mix_
itr_mix_
Definition mc_tree_search_solver.h:132
virtual EpisodeInfo on_training_episode(env_type &, uint_t) override final
on_episode Do one on_episode of the algorithm
Definition mc_tree_search_solver.h:164
real_t temperature_
temperature_
Definition mc_tree_search_solver.h:147
uint_t max_depth_tree() const noexcept
max_depth_tree
Definition mc_tree_search_solver.h:125
time_step_type simulate_node(std::shared_ptr< node_type > node, env_type &env)
simulate_node
Definition mc_tree_search_solver.h:256
Definition mc_tree_search_solver.h:22
The MCTreeSearchConfig struct.
Definition mc_tree_search_solver.h:32
uint_t max_tree_depth
Definition mc_tree_search_solver.h:33
real_t temperature
Definition mc_tree_search_solver.h:34