bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
normal_dist.h
Go to the documentation of this file.
1#ifndef NORMAL_DIST_H
2#define NORMAL_DIST_H
3
5#include "bitrl/bitrl_types.h"
6
7#include <cmath>
8#include <numbers>
9#include <random>
10#include <type_traits>
11#include <vector>
12
13namespace bitrl
14{
15namespace utils::maths::stats
16{
17
23template <typename RealType = real_t> class NormalDist
24{
25
26public:
27 static_assert(std::is_floating_point<RealType>::value, "Not a floating point type");
28
33
37 NormalDist();
38
42 explicit NormalDist(result_type mu, result_type std = 1.0);
43
48
52 result_type sample() const;
53
57 result_type sample(uint_t seed) const;
58
62 std::vector<result_type> sample_many(uint_t size) const;
63
67 std::vector<result_type> sample_many(uint_t size, uint_t seed) const;
68
72 result_type mean() const { return dist_.mean(); }
73
77 result_type std() const { return dist_.stddev(); }
78
79private:
85 mutable std::normal_distribution<RealType> dist_;
86};
87
88template <typename RealType>
92
93template <typename RealType> NormalDist<RealType>::NormalDist() : NormalDist<RealType>(0.0, 1.0) {}
94
95template <typename RealType> RealType NormalDist<RealType>::sample() const
96{
97
98 std::random_device rd{};
99 std::mt19937 gen{rd()};
100 return dist_(gen);
101}
102
103template <typename RealType> RealType NormalDist<RealType>::sample(uint_t seed) const
104{
105
106 std::mt19937 gen{seed};
107 return dist_(gen);
108}
109
110template <typename RealType>
111std::vector<RealType> NormalDist<RealType>::sample_many(uint_t size) const
112{
113
114 std::vector<RealType> samples(size);
115 std::random_device rd{};
116 std::mt19937 gen{rd()};
117
118 for (uint_t i = 0; i < size; ++i)
119 {
120 samples[i] = dist_(gen);
121 }
122
123 return samples;
124}
125
126template <typename RealType>
127std::vector<RealType> NormalDist<RealType>::sample_many(uint_t size, uint_t seed) const
128{
129
130 std::vector<RealType> samples(size);
131 std::mt19937 gen(seed);
132
133 for (uint_t i = 0; i < size; ++i)
134 {
135 samples[i] = dist_(gen);
136 }
137
138 return samples;
139}
140
141template <typename RealType> RealType NormalDist<RealType>::pdf(RealType x) const
142{
143
144 auto mu = dist_.mean();
145 auto std = dist_.stddev();
147 auto factor = 1.0 / (std * std::sqrt(2.0 * pi));
148 auto exp = std::exp(-0.5 * std::pow((x - mu) / std, 2.0));
149 return factor * exp;
150}
151
152} // namespace utils::maths::stats
153} // namespace bitrl
154
155#endif
Wrapper to std::normal_distribution to facilitate sampling multiple values, sampling with a given see...
Definition normal_dist.h:24
result_type mean() const
The mean value of the distribution.
Definition normal_dist.h:72
result_type std() const
The STD of the distribution.
Definition normal_dist.h:77
std::vector< result_type > sample_many(uint_t size) const
Sample `size` values from the distribution.
Definition normal_dist.h:111
NormalDist()
Constructor.
Definition normal_dist.h:93
result_type pdf(result_type x) const
Compute the value of the PDF at the given point.
Definition normal_dist.h:141
RealType result_type
Definition normal_dist.h:32
result_type sample() const
Sample from the distribution.
Definition normal_dist.h:95
const real_t PI
The PI constant.
Definition bitrl_consts.h:49
OutT resolve(const std::string &name, const std::map< std::string, std::any > &input)
Definition std_map_utils.h:25
Definition bitrl_consts.h:14
std::size_t uint_t
uint_t
Definition bitrl_types.h:43