bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
pytorch_loss_wrapper.h
Go to the documentation of this file.
1
#ifndef PYTORCH_LOSS_WRAPPER_H
2
#define PYTORCH_LOSS_WRAPPER_H
3
4
#include "cuberl/base/cubeai_config.h"
5
6
#ifdef USE_PYTORCH
7
8
#include "
cuberl/base/cuberl_types.h
"
9
#include "
cuberl/utils/loss_type.h
"
10
#include <torch/torch.h>
11
12
13
14
15
namespace
cuberl
{
16
namespace
utils{
17
namespace
pytorch {
18
19
using namespace
cubeai::utils
;
20
24
class
PyTorchLossWrapper
25
{
26
public
:
27
32
PyTorchLossWrapper(LossType type);
33
40
torch_tensor_t calculate(torch_tensor_t input, torch_tensor_t target)
const
;
41
42
private
:
43
47
LossType type_;
48
49
};
50
51
}
52
}
53
}
54
55
#endif
56
#endif
// PYTORCH_LOSS_WRAPPER_H
cuberl_types.h
loss_type.h
cubeai::utils
Definition
loss_type.h:8
cuberl
Various utilities used when working with RL problems.
Definition
cuberl_types.h:16
libs/cuberl/include/cuberl/utils/pytorch_loss_wrapper.h
Generated by Doxygen 1.9.8