# blob: 4a3a3eca384497b093021033b728290291790b39 [file] [log] [blame] [edit]
# Owner(s): ["module: onnx"]
import torch
# Autograd function that is a replica of the autograd function in
# test_utility_funs.py (test_autograd_module_name)
class CustomFunction(torch.autograd.Function):
    """Custom autograd function that clamps its input at zero.

    ``forward`` returns ``input.clamp(min=0)`` and stores the input so that
    ``backward`` can block the gradient wherever the input was negative.
    """

    @staticmethod
    def forward(ctx, input):
        """Save ``input`` on ``ctx`` and return it clamped to a minimum of 0."""
        ctx.save_for_backward(input)
        return torch.clamp(input, min=0)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` with entries zeroed where the input was < 0."""
        (saved_input,) = ctx.saved_tensors
        # Out-of-place fill yields a fresh tensor, so grad_output itself is
        # left untouched (same effect as cloning and then index-assigning).
        negative_mask = saved_input < 0
        return grad_output.masked_fill(negative_mask, 0)