#include <torch/extension.h> | |
struct Doubler { | |
Doubler(int A, int B) { | |
tensor_ = | |
torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true)); | |
} | |
torch::Tensor forward() { | |
return tensor_ * 2; | |
} | |
torch::Tensor get() const { | |
return tensor_; | |
} | |
private: | |
torch::Tensor tensor_; | |
}; |