1#ifndef MC_TREE_SEARCH_BASE_H
2#define MC_TREE_SEARCH_BASE_H
4#include "cubeai/base/cubeai_types.h"
5#include "cubeai/rl/algorithms/rl_algorithm_base.h"
6#include "cubeai/rl/algorithms/rl_algo_config.h"
7#include "cubeai/utils/iteration_mixin.h"
8#include "cubeai/rl/episode_info.h"
9#include "cubeai/rl/algorithms/mc/mcts_node.h"
11#include "cubeai/base/cubeai_config.h"
45template<
typename EnvTp,
typename NodeTp=MCTSNodeBase<
typename EnvTp::action_type,
typename EnvTp::state_type>>
50 static_assert (std::is_default_constructible<typename EnvTp::action_type>::value,
51 "Action type is not default constructible");
53 static_assert (std::is_default_constructible<typename EnvTp::state_type>::value,
54 "State type is not default constructible");
119 void backprop(std::shared_ptr<node_type> node);
151template<
typename EnvTp,
typename NodeTp>
154 RLSolverBase(config.n_episodes, config.tolerance),
155 itr_mix_(config.n_itrs_per_episode, config.render_episode),
157 max_depth_tree_(config.max_tree_depth),
158 temperature_(config.temperature)
162template<
typename EnvTp,
typename NodeTp>
166 auto best_reward = std::numeric_limits<real_t>::min();
170 auto time_step = env.reset();
172 while(this->itr_mix_.continue_iterations()){
177 auto node = this->root_;
179 auto time_step = this->simulate_node(node, env_copy);
181 auto terminal = time_step.done();
186 this->expand_node(node, env_copy);
191 auto action = env_copy.sample_action();
192 auto new_time_step = env_copy.step(action);
193 sum_reward_ += new_time_step.reward();
194 actions_.push_back(action);
196 if(actions_.size() > this->max_depth_tree()){
203 if( best_reward < sum_reward_){
204 best_reward = sum_reward_;
205 best_actions_ = actions_;
209 this->backprop(node);
214 for(
auto action : best_actions_){
215 auto time_step = this->env.step(action);
217 sum_reward_ += time_step.reward();
218 if(time_step.done()){
224 best_rewards_.push_back(sum_reward_);
228template<
typename EnvTp,
typename NodeTp>
234 node -> update_visits();
235 node -> update_total_score(sum_reward_);
236 node = node ->parent();
241template<
typename Env,
typename NodeTp>
245 for(uint_t a=0; a < env.n_actions(); ++a ){
246 node -> add_child(node, a);
249 node->shuffle_children();
254template<
typename Env,
typename NodeTp>
259 while(node -> has_children()){
261 if(node -> n_explored_children() < node -> n_children() ){
263 auto child = node->get_child(node->n_explored_children());
264 node ->update_explored_children();
269 node = node -> max_ucb_child(this->temperature_);
273 time_step = env.step(node ->get_action());
274 sum_reward_ += time_step.reward();
275 actions_.push_back(node ->get_action());
MCTreeSearchBase.
Definition mc_tree_search_solver.h:47
virtual void actions_after_training_ends(env_type &) override final
actions_after_training_ends. Actions to execute after the training iterations have finished.
Definition mc_tree_search_solver.h:87
virtual void actions_before_episode_begins(env_type &, uint_t)
actions_before_episode_begins. Actions to execute before each training episode begins.
Definition mc_tree_search_solver.h:92
env_type::time_step_type time_step_type
The time step type.
Definition mc_tree_search_solver.h:69
NodeTp node_type
node_type
Definition mc_tree_search_solver.h:64
virtual void actions_before_training_begins(env_type &) override final
actions_before_training_begins. Execute any actions the algorithm needs before starting the iteration...
Definition mc_tree_search_solver.h:81
void expand_node(std::shared_ptr< node_type > node, env_type &env)
expand_node
Definition mc_tree_search_solver.h:243
std::shared_ptr< node_type > root_
root_
Definition mc_tree_search_solver.h:137
virtual void actions_after_episode_ends(env_type &, uint_t, const EpisodeInfo &)
actions_after_episode_ends. Actions to execute after each training episode ends.
Definition mc_tree_search_solver.h:97
void backprop(std::shared_ptr< node_type > node)
backprop
Definition mc_tree_search_solver.h:230
MCTSSolver(MCTreeSearchConfig config)
MCTSSolver. Constructs the solver from an MCTreeSearchConfig.
Definition mc_tree_search_solver.h:152
EnvTp env_type
env_type
Definition mc_tree_search_solver.h:59
uint_t max_depth_tree_
max_depth_tree_
Definition mc_tree_search_solver.h:142
IterationMixin itr_mix_
itr_mix_
Definition mc_tree_search_solver.h:132
virtual EpisodeInfo on_training_episode(env_type &, uint_t) override final
on_training_episode. Execute one training episode of the algorithm.
Definition mc_tree_search_solver.h:164
real_t temperature_
temperature_
Definition mc_tree_search_solver.h:147
uint_t max_depth_tree() const noexcept
max_depth_tree
Definition mc_tree_search_solver.h:125
time_step_type simulate_node(std::shared_ptr< node_type > node, env_type &env)
simulate_node
Definition mc_tree_search_solver.h:256
Definition mc_tree_search_solver.h:22
The MCTreeSearchConfig struct.
Definition mc_tree_search_solver.h:32
uint_t max_tree_depth
Definition mc_tree_search_solver.h:33
real_t temperature
Definition mc_tree_search_solver.h:34