1#ifndef TORCH_TENSOR_UTILS_H
2#define TORCH_TENSOR_UTILS_H
4#include "cuberl/base/cubeai_config.h"
8#include "cuberl/base/cubeai_types.h"
9#include "cuberl/io/torch_state_dictionary_reader.h"
10#include <torch/torch.h>
14namespace torch_utils {
21template<
typename ContainerType>
22torch_tensor_t create_mask(
const ContainerType& container){
24 return torch::tensor(container.data(), torch::dtype(torch::kBool));
30template<
typename TorchNetType>
31void copy_parameters_to( TorchNetType& to,
const std::string params_path){
34 torch::autograd::GradMode::set_enabled(
false);
39 auto new_params = TorchStateDictionaryReader(params_path);
40 auto params = to->named_parameters(
true );
41 auto buffers = to->named_buffers(
true );
60 torch::autograd::GradMode::set_enabled(
true);
Definition mc_tree_search_solver.h:22