bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
mcts_node.h
Go to the documentation of this file.
1#ifndef MCTS_NODE_H
2#define MCTS_NODE_H
3
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <memory>
#include <random>
#include <utility>
#include <vector>

#include "cubeai/base/cubeai_types.h"
6
7namespace cubeai{
8namespace rl{
9namespace algos{
10namespace mc{
11
15template<typename ActionTp, typename StateTp>
17{
18public:
19
20 typedef ActionTp action_type;
21 typedef StateTp state_type;
22
28 MCTSNodeBase(std::shared_ptr<MCTSNodeBase> parent, action_type action);
29
33 virtual ~MCTSNodeBase()=default;
34
39 void add_child(std::shared_ptr<MCTSNodeBase<ActionTp, StateTp>> child);
40
46 void add_child(std::shared_ptr<MCTSNodeBase<ActionTp, StateTp>> parent, action_type action);
47
52 bool has_children()const noexcept{return children_.empty() != true;}
53
59 std::shared_ptr<MCTSNodeBase> get_child(uint_t cidx){return children_[cidx];}
60
65 uint_t n_children()const noexcept{return children_.size();}
66
70 void shuffle_children()noexcept;
71
76 uint_t n_explored_children()const noexcept{return explored_children_;}
77
81 void update_visits()noexcept{total_visits_ += 1;}
82
87
92 void update_total_score(real_t score)noexcept {total_score_ += score;}
93
99 real_t ucb(real_t temperature )const;
100
105 std::shared_ptr<MCTSNodeBase> max_ucb_child(real_t temperature)const;
106
111 real_t win_pct()const{return total_score_ / total_visits_ ;}
112
117 uint_t total_visits()const noexcept{return total_visits_;}
118
123 uint_t get_action()const noexcept{return action_;}
124
129 std::shared_ptr<MCTSNodeBase> parent(){return parent_;}
130
131protected:
132
137
142
147
152
156 std::shared_ptr<MCTSNodeBase> parent_;
157
161 std::vector<std::shared_ptr<MCTSNodeBase>> children_;
162
163};
164
165template<typename ActionTp, typename StateTp>
166MCTSNodeBase<ActionTp, StateTp>::MCTSNodeBase(std::shared_ptr<MCTSNodeBase> parent, action_type action)
167 :
168 total_score_(0.0),
169 total_visits_(0),
170 explored_children_(0),
171 action_(action),
172 parent_(parent),
173 children_()
174{}
175
176
177template<typename ActionTp, typename StateTp>
178void
180
181#ifdef CUBEAI_DEBUG
182 assert(child != nullptr && "Cannot add null children");
183#endif
184
185 children_.push_back(child);
186
187}
188
189template<typename ActionTp, typename StateTp>
190void
192
193 if(children_.empty()){
194 return;
195 }
196
197 std::random_shuffle(children_.begin(), children_.end());
198}
199
200template<typename ActionTp, typename StateTp>
201void
203 action_type action){
204
205 add_child(std::make_shared<MCTSNodeBase<ActionTp, StateTp>>(parent, action));
206}
207
208template<typename ActionTp, typename StateTp>
209std::shared_ptr<MCTSNodeBase<ActionTp, StateTp>>
211
212 if(children_.empty()){
213 return std::shared_ptr<MCTSNodeBase>();
214 }
215
216 auto current_ucb = children_[0]->ucb(temperature);
217 auto current_idx = 0;
218 auto dummy_idx = 0;
219 for(auto node : children_){
220
221 auto node_ucb = node->ucb(temperature);
222
223 if(node_ucb > current_ucb){
224
225 current_idx = dummy_idx;
226 current_ucb = node_ucb;
227 }
228
229 dummy_idx += 1;
230 }
231
232 return children_[current_idx];
233}
234
235template<typename ActionTp, typename StateTp>
236real_t
237MCTSNodeBase<ActionTp, StateTp>::ucb(real_t temperature )const{
238 return win_pct() + temperature * std::sqrt(std::log(parent_ -> total_visits()) / total_visits_);
239}
240
241}
242}
243}
244}
245#endif
Base class for nodes in a Monte Carlo tree search.
Definition mcts_node.h:17
std::vector< std::shared_ptr< MCTSNodeBase > > children_
children_
Definition mcts_node.h:161
void update_visits() noexcept
update_visits
Definition mcts_node.h:81
MCTSNodeBase(std::shared_ptr< MCTSNodeBase > parent, action_type action)
Constructor of MCTSNodeBase.
Definition mcts_node.h:166
uint_t get_action() const noexcept
get_action
Definition mcts_node.h:123
action_type action_
action_
Definition mcts_node.h:151
bool has_children() const noexcept
has_children
Definition mcts_node.h:52
void add_child(std::shared_ptr< MCTSNodeBase< ActionTp, StateTp > > child)
add_child
Definition mcts_node.h:179
ActionTp action_type
Definition mcts_node.h:20
StateTp state_type
Definition mcts_node.h:21
real_t win_pct() const
win_pct
Definition mcts_node.h:111
void update_explored_children() noexcept
Definition mcts_node.h:86
uint_t explored_children_
explored_children_
Definition mcts_node.h:146
uint_t total_visits() const noexcept
total_visits
Definition mcts_node.h:117
real_t ucb(real_t temperature) const
ucb
Definition mcts_node.h:237
std::shared_ptr< MCTSNodeBase > get_child(uint_t cidx)
get_child
Definition mcts_node.h:59
std::shared_ptr< MCTSNodeBase > parent()
parent
Definition mcts_node.h:129
real_t total_score_
total_score_
Definition mcts_node.h:136
uint_t n_children() const noexcept
n_children
Definition mcts_node.h:65
void shuffle_children() noexcept
shuffle_children
Definition mcts_node.h:191
std::shared_ptr< MCTSNodeBase > max_ucb_child(real_t temperature) const
max_ucb_child
Definition mcts_node.h:210
uint_t n_explored_children() const noexcept
explored_children
Definition mcts_node.h:76
uint_t total_visits_
total_visits_
Definition mcts_node.h:141
std::shared_ptr< MCTSNodeBase > parent_
parent_
Definition mcts_node.h:156
void update_total_score(real_t score) noexcept
update_total_score
Definition mcts_node.h:92
Definition mc_tree_search_solver.h:22