# Owner(s): ["module: cuda graphs"]
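
# Smoke tests for the "cudagraphs" dynamo backend: each test runs a small
# model for a few iterations under torch._dynamo.optimize("cudagraphs") and
# checks numerical correctness together with the aot_autograd counters.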

import functools
import unittest

import torch
import torch._dynamo
import torch._dynamo.config
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
from torch.testing._internal.common_utils import TEST_CUDA_GRAPH


def composed(*decs):
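    """Compose several decorators into one; decs[0] ends up outermost,
    matching the order of stacked @-decorators."""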
    def deco(f):
        for dec in reversed(decs):
            f = dec(f)
        return f

    return deco


def assert_aot_autograd_counter(ok=True):
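    """Decorator asserting, via torch._dynamo.utils.counters, that
    aot_autograd compilation succeeded (ok=True) or failed (ok=False)."""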
    def deco(f):
        @functools.wraps(f)
        def wrap(self, *args, **kwargs):
            torch._dynamo.utils.counters.clear()
            r = f(self, *args, **kwargs)
            c_ok = torch._dynamo.utils.counters["aot_autograd"]["ok"]
            c_not_ok = torch._dynamo.utils.counters["aot_autograd"]["not_ok"]
            if ok:
                self.assertGreater(c_ok, 0)
                self.assertEqual(c_not_ok, 0)
            else:
                self.assertEqual(c_ok, 0)
                self.assertGreater(c_not_ok, 0)
            return r

        return wrap

    return deco


def patch_all(ok=True):
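    """Combine the dynamo config patches used by these tests with the
    aot_autograd counter assertion."""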
    return composed(
        torch._dynamo.config.patch(
            verify_correctness=True, automatic_dynamic_shapes=True
        ),
        assert_aot_autograd_counter(ok),
    )


N_ITERS = 5


@unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda")
class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
    @patch_all()
    def test_basic(self):
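        # Straight-line model with no graph breaks: both the forward and the
        # backward pass should be captured by CUDA graphs.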
        def model(x, y):
            return (x + y) * y

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

    @patch_all()
    def test_dtoh(self):
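        # .cpu() performs a device-to-host copy, which cannot run inside a
        # captured CUDA graph and therefore forces a graph break.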
        def model(x, y):
            a = x + y
            b = a.cpu() * 3
            return b

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

    @patch_all()
    def test_htod(self):
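        # A CPU (0-dim) tensor operand forces a host-to-device transfer
        # before the add can execute on the GPU.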
        def model(x, y):
            a = x + y
            return a * 3

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn((), device="cpu")
        fn(x, y)

    @patch_all()
    def test_mutate_input(self):
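        # In-place mutation of a graph input (y.add_) must stay visible to
        # the caller on every iteration.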
        def model(x, y):
            y.add_(3)
            return x * y

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    y_orig = y.clone()
                    loss = model(x, y).sum()
                    self.assertTrue(same(y, y_orig + 3))
                    loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

    @patch_all()
    def test_mutate_constant(self):
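        # c is created and mutated entirely inside the compiled region; the
        # output must observe the mutated value (1 + 2 = 3).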
        def model(x, y):
            c = torch.tensor(1)
            c.add_(2)
            return x * y * 0 + c

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    loss = model(x, y).sum()
                    self.assertTrue(same(loss, torch.tensor(3.0, device="cuda")))
                    loss.backward()

        x = torch.randn(1, device="cuda", requires_grad=True)
        y = torch.randn(1, device="cuda")
        fn(x, y)

    @patch_all()
    def test_factory(self):
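        # A tensor factory (torch.zeros) allocates inside the captured region
        # and the result is mutated in place before use.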
        def model(y):
            x = torch.zeros(3, device="cuda:0")
            x.add_(3)
            return x * y

        @torch._dynamo.optimize("cudagraphs")
        def fn(y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    loss = model(y).sum()
                    loss.backward()

        y = torch.randn(3, device="cuda:0", requires_grad=True)
        fn(y)

    @patch_all()
    def test_mutated_metadata(self):
        # more tortured example at
        # https://github.com/pytorch/pytorch/issues/81385
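        # resize_ rewrites the tensor's shape metadata and reallocates its
        # storage; the backend must not hand back stale results.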
        def model(x):
            x = x.clone()
            x.resize_(20)
            x.fill_(2)
            return x

        @torch._dynamo.optimize("cudagraphs")
        def fn(x):
            for i in range(N_ITERS):
                with self.subTest(i):
                    rx = model(x)
                    self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))

        x = torch.empty(0, device="cuda:0")
        fn(x)

    @patch_all()
    def test_dead_fill(self):
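        # y is a zero-element view of x, so y.fill_(3) is dead work: it must
        # not affect x, and y must stay empty.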
        def model(x):
            x = x.clone()
            y = x[0:0]
            x.fill_(2)
            y.fill_(3)
            return x, y

        @torch._dynamo.optimize("cudagraphs")
        def fn(x):
            for i in range(N_ITERS):
                with self.subTest(i):
                    rx, ry = model(x)
                    self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))
                    self.assertTrue(same(ry, torch.empty(0, device="cuda:0")))

        x = torch.empty(20, device="cuda:0")
        fn(x)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if not TEST_CUDA_GRAPH:
        import sys

        sys.exit(0)

    run_tests()