bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
max_tabular_policy.h
Go to the documentation of this file.
1#ifndef MAX_TABULAR_POLICY_H
2#define MAX_TABULAR_POLICY_H
3
4#include "cuberl/base/cubeai_config.h"
8
9#ifdef USE_PYTORCH
10#include <torch/torch.h>
11#endif
12
13#include <type_traits>
14#include <vector>
15#include <string>
16#include <iostream>
17
18namespace cuberl {
19namespace rl {
20namespace policies {
21
24struct MaxTabularPolicyBuilder;
25
30{
31public:
32
39
43 template<typename MatType>
44 static output_type get_action(const MatType& q_map, uint_t state_idx);
45
51 template<typename VecTp>
52 static output_type get_action(const VecTp& q_map);
53
54#ifdef USE_PYTORCH
60 static output_type get_action(const torch_tensor_t& vec);
61#endif
62
68
73
/// Per-episode hook. The episode index argument is unnamed and the body is
/// empty, so this policy performs no work when it is called.
78 void on_episode(uint_t)noexcept{}
79
/// Reset the policy by clearing the stored state -> action table.
/// The table is left empty afterwards, so it has to be repopulated before
/// on_state() is queried again.
83 void reset()noexcept{state_action_map_.clear();}
84
/// Look up the action stored for the state index s.
/// The lookup uses the unchecked std::vector::operator[], so s must be
/// smaller than the number of entries currently held in the table.
/// NOTE(review): there is no bounds check here -- consider .at(s) if callers
/// can pass an out-of-range state; confirm against callers.
88 action_type on_state(state_type s)const{return state_action_map_[s];}
89
93 void save(const std::string& filename)const;
94
95private:
96
100 std::vector<uint_t> state_action_map_;
101};
102
103
104#ifdef USE_PYTORCH
/// Return the index of the largest entry of a torch tensor.
/// torch::argmax is invoked without a dimension argument and the resulting
/// tensor is converted to uint_t through item<uint_t>().
/// NOTE(review): presumably vec is a 1-D tensor of action values -- confirm
/// against callers. Also verify that Tensor::item is instantiated for uint_t
/// in the libtorch version in use.
105inline
106uint_t
107MaxTabularPolicy::get_action(const torch_tensor_t& vec){
108 return torch::argmax(vec).item<uint_t>();
109}
110#endif
111
112
/// Return the zero-based position of the maximum element of vec.
/// std::max_element yields the first of several equal maxima, so ties resolve
/// to the lowest index; std::distance turns that iterator into an offset
/// from vec.begin().
/// NOTE(review): <algorithm> and <iterator> are not among the includes
/// visible in this listing -- confirm they are pulled in elsewhere.
113template<typename VecTp>
116
117 return std::distance(vec.begin(),
118 std::max_element(vec.begin(),
119 vec.end()));
120
121}
122
123
125{
126
127 template<typename EnvType>
128 void build_from_state_function(const EnvType& env,
129 const DynVec<real_t>& v,
130 real_t gamma,
131 MaxTabularPolicy& policy);
132
134 MaxTabularPolicy& policy);
135};
136
/// Rebuild the state -> action table of a MaxTabularPolicy from a state-value
/// function.
/// For every state index s in [0, env.n_states()) the state-action values are
/// obtained from cuberl::rl::algos::state_actions_from_v(env, v, gamma, s) and
/// the action returned by policy.get_action() for those values is stored at
/// policy.state_action_map_[s]. Any previous content of the table is dropped.
/// @tparam EnvType environment type; its state_type and action_type must be
///         integral (enforced by the static_asserts in the body)
/// @param v      state values forwarded to state_actions_from_v
/// @param gamma  forwarded to state_actions_from_v (presumably the discount
///               factor -- confirm)
/// @param policy policy whose state -> action table is rebuilt
/// NOTE(review): the body writes the private member state_action_map_
/// directly; the friend declaration this requires is not visible in this
/// listing -- confirm it exists.
137template<typename EnvType>
138void
140 const DynVec<real_t>& v,
141 real_t gamma,
142 MaxTabularPolicy& policy){
143
// Compile-time guard: both the state and the action type must be integral.
144 static_assert(std::is_integral_v<typename EnvType::state_type>,
145 "state type must be integral");
146 static_assert(std::is_integral_v<typename EnvType::action_type>,
147 "action type must be integral");
148
149
150 typedef typename EnvType::action_type action_type;
// Start from an empty table with one slot per environment state.
151 policy.state_action_map_.clear();
152 policy.state_action_map_.resize(env.n_states());
153
154 for(uint_t s=0; s<env.n_states(); ++s){
155
// Values of the actions available in state s under the value function v.
156 auto state_vals = cuberl::rl::algos::state_actions_from_v(env, v,
157 gamma, s);
158
// Store the action selected by get_action for state s.
159 action_type action = policy.get_action(state_vals);
160 policy.state_action_map_[s] = action;
161 }
162
163}
164
165
166
167
168
169
170}
171}
172}
173
174#endif // MAX_TABULAR_POLICY_H
class MaxTabularPolicy
Definition max_tabular_policy.h:30
void reset() noexcept
Reset the policy.
Definition max_tabular_policy.h:83
uint_t action_type
Definition max_tabular_policy.h:38
void on_episode(uint_t) noexcept
Any actions the policy should perform on the given episode index.
Definition max_tabular_policy.h:78
MaxTabularPolicy()=default
Constructor.
void save(const std::string &filename) const
Save the state -> action map in a CSV file.
uint_t output_type
The output type of operator()
Definition max_tabular_policy.h:36
static output_type get_action(const MatType &q_map, uint_t state_idx)
get_action. Given a
uint_t state_type
Definition max_tabular_policy.h:37
static output_type get_action(const VecTp &q_map)
get_action. Given a vector, always returns the position of the maximum occurring element....
action_type on_state(state_type s) const
Get the action from the given state.
Definition max_tabular_policy.h:88
double real_t
real_t
Definition bitrl_types.h:23
Eigen::RowVectorX< T > DynVec
Dynamically sized row vector.
Definition bitrl_types.h:74
std::size_t uint_t
uint_t
Definition bitrl_types.h:43
Eigen::MatrixX< T > DynMat
Dynamically sized matrix to use around the library.
Definition bitrl_types.h:49
auto state_actions_from_v(const WorldTp &env, const DynVec< real_t > &v, real_t gamma, uint_t state) -> DynVec< real_t >
Given the state index returns the list of actions under the provided value functions.
Definition utils.h:23
Various utilities used when working with RL problems.
Definition cuberl_types.h:16
Definition max_tabular_policy.h:125
void build_from_state_action_function(const DynMat< real_t > &q, MaxTabularPolicy &policy)
void build_from_state_function(const EnvType &env, const DynVec< real_t > &v, real_t gamma, MaxTabularPolicy &policy)
Definition max_tabular_policy.h:139