4 #include "cubeai/base/cubeai_types.h"
15 template<
typename ActionTp,
typename StateTp>
99 real_t
ucb(real_t temperature )
const;
105 std::shared_ptr<MCTSNodeBase>
max_ucb_child(real_t temperature)
const;
165 template<
typename ActionTp,
typename StateTp>
170 explored_children_(0),
177 template<
typename ActionTp,
typename StateTp>
182 assert(child !=
nullptr &&
"Cannot add null children");
185 children_.push_back(child);
189 template<
typename ActionTp,
typename StateTp>
193 if(children_.empty()){
197 std::random_shuffle(children_.begin(), children_.end());
200 template<
typename ActionTp,
typename StateTp>
208 template<
typename ActionTp,
typename StateTp>
209 std::shared_ptr<MCTSNodeBase<ActionTp, StateTp>>
212 if(children_.empty()){
213 return std::shared_ptr<MCTSNodeBase>();
216 auto current_ucb = children_[0]->ucb(temperature);
217 auto current_idx = 0;
219 for(
auto node : children_){
221 auto node_ucb = node->ucb(temperature);
223 if(node_ucb > current_ucb){
225 current_idx = dummy_idx;
226 current_ucb = node_ucb;
232 return children_[current_idx];
235 template<
typename ActionTp,
typename StateTp>
238 return win_pct() + temperature * std::sqrt(std::log(parent_ -> total_visits()) / total_visits_);
Base class for nodes in a Monte Carlo tree search.
Definition mcts_node.h:17
std::vector< std::shared_ptr< MCTSNodeBase > > children_
children_
Definition mcts_node.h:161
void update_visits() noexcept
update_visits
Definition mcts_node.h:81
MCTSNodeBase(std::shared_ptr< MCTSNodeBase > parent, action_type action)
MCTSNodeBase constructor.
Definition mcts_node.h:166
virtual ~MCTSNodeBase()=default
uint_t get_action() const noexcept
get_action
Definition mcts_node.h:123
action_type action_
action_
Definition mcts_node.h:151
bool has_children() const noexcept
has_children
Definition mcts_node.h:52
void add_child(std::shared_ptr< MCTSNodeBase< ActionTp, StateTp > > child)
add_child
Definition mcts_node.h:179
ActionTp action_type
Definition mcts_node.h:20
StateTp state_type
Definition mcts_node.h:21
real_t win_pct() const
win_pct
Definition mcts_node.h:111
void update_explored_children() noexcept
Definition mcts_node.h:86
uint_t explored_children_
explored_children_
Definition mcts_node.h:146
uint_t total_visits() const noexcept
total_visits
Definition mcts_node.h:117
real_t ucb(real_t temperature) const
ucb
Definition mcts_node.h:237
std::shared_ptr< MCTSNodeBase > get_child(uint_t cidx)
get_child
Definition mcts_node.h:59
std::shared_ptr< MCTSNodeBase > parent()
parent
Definition mcts_node.h:129
real_t total_score_
total_score_
Definition mcts_node.h:136
uint_t n_children() const noexcept
n_children
Definition mcts_node.h:65
void shuffle_children() noexcept
shuffle_children
Definition mcts_node.h:191
std::shared_ptr< MCTSNodeBase > max_ucb_child(real_t temperature) const
max_ucb_child
Definition mcts_node.h:210
uint_t n_explored_children() const noexcept
n_explored_children
Definition mcts_node.h:76
uint_t total_visits_
total_visits_
Definition mcts_node.h:141
std::shared_ptr< MCTSNodeBase > parent_
parent_
Definition mcts_node.h:156
void update_total_score(real_t score) noexcept
update_total_score
Definition mcts_node.h:92
Definition mc_tree_search_solver.h:22