# Owner(s): ["oncall: jit"]
import torch
from torch.testing._internal.jit_utils import JitTestCase
class TestFuserCommon(JitTestCase): | |
def test_autodiff_fallback(self): | |
for rq in [True, False]: | |
@torch.jit.script | |
def fn(x): | |
return torch.max(x**2.0, x**3.0) | |
x = torch.randn(5, requires_grad=not rq) | |
# cause optimization to be created | |
for i in range(5): | |
fn(x) | |
# test fallback when optimization is not applicable | |
y = fn(torch.randn(5, requires_grad=rq)) | |
self.assertEqual(y.requires_grad, rq) |