bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
pytorch_loss_wrapper.h
Go to the documentation of this file.
1#ifndef PYTORCH_LOSS_WRAPPER_H
2#define PYTORCH_LOSS_WRAPPER_H
3
4#include "cuberl/base/cubeai_config.h"
5
6#ifdef USE_PYTORCH
7
10#include <torch/torch.h>
11
12
13
14
15namespace cuberl {
16namespace utils{
17namespace pytorch {
18
// Pulls every name from cubeai::utils into cuberl::utils::pytorch.
// NOTE(review): this is a using-directive inside a header, so the names are
// also visible to every translation unit that includes this file.
// NOTE(review): the directive names `cubeai` while the enclosing namespace is
// `cuberl` -- presumably LossType / torch_tensor_t are found through it;
// confirm the namespace is still correct before touching this line.
19 using namespace cubeai::utils;
20
/// Wrapper that stores a LossType and exposes a single calculate() call
/// operating on torch tensors.
///
/// Only declared when USE_PYTORCH is defined. The member definitions are not
/// part of this header, so the notes below describe the declared contract only.
24class PyTorchLossWrapper
25{
26public:
27
    /// Construct a wrapper for the given loss type.
    ///
    /// @param type Loss selector; presumably stored in type_ (definition not
    ///             visible here -- TODO confirm).
    ///
    /// NOTE(review): single-argument constructor is not `explicit`, so a
    /// LossType converts implicitly to PyTorchLossWrapper. Consider marking it
    /// `explicit` once callers have been checked for implicit conversions.
32 PyTorchLossWrapper(LossType type);
33
    /// Evaluate the loss for a pair of tensors.
    ///
    /// Const member: does not modify the wrapper's declared state.
    ///
    /// @param input  First tensor operand, taken by value.
    /// @param target Second tensor operand, taken by value.
    /// @return A torch_tensor_t -- presumably the loss selected by type_
    ///         applied to (input, target); verify against the implementation.
40 torch_tensor_t calculate(torch_tensor_t input, torch_tensor_t target)const;
41
42private:
43
    /// Loss type this wrapper was constructed with.
47 LossType type_;
48
49};
50
51}
52}
53}
54
55#endif
56#endif // PYTORCH_LOSS_WRAPPER_H
Definition loss_type.h:8
Various utilities used when working with RL problems.
Definition cuberl_types.h:16