| #pragma once |
| |
| #include <ATen/Tensor.h> |
| #include <torch/csrc/Export.h> |
| #include <vector> |
| |
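// Hook interfaces that are called on gradients: FunctionPreHook and
// FunctionPostHook each take gradient lists (variable_list) and return a
// variable_list.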
| |
| namespace torch { |
| namespace autograd { |
| |
| using Variable = at::Tensor; |
| using variable_list = std::vector<Variable>; |
| |
| struct TORCH_API FunctionPreHook { |
| virtual ~FunctionPreHook() = default; |
| virtual variable_list operator()(const variable_list& grads) = 0; |
| }; |
| |
| struct TORCH_API FunctionPostHook { |
| virtual ~FunctionPostHook() = default; |
| virtual variable_list operator()( |
| const variable_list& outputs /* grad_inputs */, |
| const variable_list& inputs /* grad_outputs */) = 0; |
| }; |
| |
| } // namespace autograd |
| } // namespace torch |