1#ifndef MAX_TABULAR_POLICY_H
2#define MAX_TABULAR_POLICY_H
4#include "cuberl/base/cubeai_config.h"
10#include <torch/torch.h>
24struct MaxTabularPolicyBuilder;
43 template<
typename MatType>
51 template<
typename VecTp>
/// @brief Reset the policy by clearing the stored state -> action map.
83 void reset()noexcept{state_action_map_.clear();}
/// @brief Save the state -> action map to a CSV file.
/// @param filename Path of the file to write.
93 void save(
const std::string& filename)
const;
/// Action chosen for each state: entry s holds the action index for state index s.
100 std::vector<uint_t> state_action_map_;
108 return torch::argmax(vec).item<
uint_t>();
113template<
typename VecTp>
117 return std::distance(vec.begin(),
118 std::max_element(vec.begin(),
127 template<
typename EnvType>
137template<
typename EnvType>
144 static_assert(std::is_integral_v<typename EnvType::state_type>,
145 "state type must be integral");
146 static_assert(std::is_integral_v<typename EnvType::action_type>,
147 "action type must be integral");
150 typedef typename EnvType::action_type action_type;
151 policy.state_action_map_.clear();
152 policy.state_action_map_.resize(env.n_states());
154 for(
uint_t s=0; s<env.n_states(); ++s){
159 action_type action = policy.get_action(state_vals);
160 policy.state_action_map_[s] = action;
class MaxTabularPolicy
Definition max_tabular_policy.h:30
void reset() noexcept
Reset the policy.
Definition max_tabular_policy.h:83
uint_t action_type
Definition max_tabular_policy.h:38
void on_episode(uint_t) noexcept
Any actions the policy should perform on the given episode index.
Definition max_tabular_policy.h:78
MaxTabularPolicy()=default
Constructor.
void save(const std::string &filename) const
Save the state -> action map to a CSV file.
uint_t output_type
The output type of operator()
Definition max_tabular_policy.h:36
static output_type get_action(const MatType &q_map, uint_t state_idx)
get_action. Given a
uint_t state_type
Definition max_tabular_policy.h:37
static output_type get_action(const VecTp &q_map)
get_action. Given a vector, always returns the position of the maximum occurring element....
action_type on_state(state_type s) const
Get the action from the given state.
Definition max_tabular_policy.h:88
double real_t
real_t
Definition bitrl_types.h:23
Eigen::RowVectorX< T > DynVec
Dynamically sized row vector.
Definition bitrl_types.h:74
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Eigen::MatrixX< T > DynMat
Dynamically sized matrix to use around the library.
Definition bitrl_types.h:49
auto state_actions_from_v(const WorldTp &env, const DynVec< real_t > &v, real_t gamma, uint_t state) -> DynVec< real_t >
Given the state index, returns the list of actions under the provided value function.
Definition utils.h:23
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
Definition max_tabular_policy.h:125
void build_from_state_action_function(const DynMat< real_t > &q, MaxTabularPolicy &policy)
void build_from_state_function(const EnvType &env, const DynVec< real_t > &v, real_t gamma, MaxTabularPolicy &policy)
Definition max_tabular_policy.h:139