| # Owner(s): ["module: dynamo"] |
| import torch |
| |
| import torch._dynamo |
| import torch._dynamo.test_case |
| |
| |
| class PreDispatchTests(torch._dynamo.test_case.TestCase): |
| def test_no_grad_simple(self): |
| def f(a): |
| b = a.sin() |
| with torch.no_grad(): |
| c = b.cos() |
| return b * c.sin() |
| |
| f_compiled = torch.compile(f, backend="pre_dispatch_eager") |
| |
| a_ref = torch.randn(4, requires_grad=True) |
| a_test = a_ref.clone().detach().requires_grad_(True) |
| |
| out_ref = f(a_ref) |
| out_test = f_compiled(a_test) |
| self.assertEqual(out_ref, out_test) |
| |
| out_ref.sum().backward() |
| out_test.sum().backward() |
| self.assertEqual(a_ref.grad, a_test.grad) |
| |
| def test_enable_grad_and_no_grad(self): |
| def f(a): |
| b = a * 2 |
| with torch.no_grad(): |
| c = b * 3 |
| with torch.enable_grad(): |
| d = c * 4 |
| e = d * 5 |
| return b + c + d + e |
| |
| f_compiled = torch.compile(f, backend="pre_dispatch_eager") |
| |
| a_ref = torch.randn(4, requires_grad=True) |
| a_test = a_ref.clone().detach().requires_grad_(True) |
| |
| out_ref = f(a_ref) |
| out_test = f_compiled(a_test) |
| self.assertEqual(out_ref, out_test) |
| |
| out_ref.sum().backward() |
| out_test.sum().backward() |
| self.assertEqual(a_ref.grad, a_test.grad) |
| |
| def test_autocast_simple(self): |
| def f(a): |
| b = a * 2 |
| with torch.amp.autocast(device_type="cpu"): |
| c = torch.matmul(b, b) |
| return b + c |
| |
| f_compiled = torch.compile(f, backend="pre_dispatch_eager") |
| |
| a_ref = torch.randn(4, device="cpu", requires_grad=True) |
| a_test = a_ref.clone().detach().requires_grad_(True) |
| |
| out_ref = f(a_ref) |
| out_test = f_compiled(a_test) |
| self.assertEqual(out_ref, out_test) |
| |
| out_ref.sum().backward() |
| out_test.sum().backward() |
| self.assertEqual(a_ref.grad, a_test.grad) |
| |
| |
# Script entry point: hand control to the Dynamo test runner when this file
# is executed directly rather than imported.
if __name__ == "__main__":
    torch._dynamo.test_case.run_tests()