blob: 4edd4de8fd6e5c3ccb52f1770efdcd1edc7f751b [file] [log] [blame] [edit]
#include <torch/extension.h>
struct Doubler {
Doubler(int A, int B) {
tensor_ =
torch::ones({A, B}, torch::dtype(torch::kFloat64).requires_grad(true));
}
torch::Tensor forward() {
return tensor_ * 2;
}
torch::Tensor get() const {
return tensor_;
}
private:
torch::Tensor tensor_;
};