bitrl & cuberl Documentation
Simulation engine for reinforcement learning agents
Loading...
Searching...
No Matches
pytorch_optimizer_factory.h
Go to the documentation of this file.
1#ifndef PYTORCH_OPTIMIZER_FACTORY_H
2#define PYTORCH_OPTIMIZER_FACTORY_H
3
4#include "cuberl/base/cubeai_config.h"
5
6#ifdef USE_PYTORCH
7
8
#include <any>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>

#include <torch/torch.h>
15
16#ifdef CUBERL_DEBUG
17#include <cassert>
18#endif
19
20namespace cuberl {
21namespace maths{
22namespace optim {
23namespace pytorch {
24
25
29std::unique_ptr<torch::optim::OptimizerOptions>
30build_pytorch_optimizer_options(OptimzerType type, const std::map<std::string, std::any>& options);
31
37std::unique_ptr<torch::optim::Optimizer>
38build_pytorch_optimizer(OptimzerType type, torch::nn::Module& model, const torch::optim::OptimizerOptions& options);
39
45std::unique_ptr<torch::optim::Optimizer>
46build_pytorch_optimizer(OptimzerType type, torch::nn::Module& model, std::unique_ptr<torch::optim::OptimizerOptions>& options){
47 return build_pytorch_optimizer(type, model, *options.get());
48}
49
50}
51}
52}
53}
54#endif
55#endif // PYTORCH_OPTIMIZER_FACTORY_H
Various utilities used when working with RL problems.
Definition cuberl_types.h:16