| # Owner(s): ["module: dynamo"] |
| import contextlib |
| |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| from torch._dynamo.testing import CompileCounter |
| from torch.backends.cuda import SDPAParams |
| |
| |
| @contextlib.contextmanager |
| def allow_in_graph_sdpa_params(): |
| global SDPAParams |
| try: |
| old = SDPAParams |
| SDPAParams = torch._dynamo.allow_in_graph(SDPAParams) |
| yield |
| finally: |
| SDPAParams = old |
| |
| |
class TestSDPA(torch._dynamo.test_case.TestCase):
    """Dynamo tests for creating, returning, and passing ``SDPAParams``.

    Every test runs inside ``allow_in_graph_sdpa_params()``. While the
    compiled functions run, the module-level ``SDPAParams`` name is rebound to
    ``torch._dynamo.allow_in_graph(SDPAParams)``.
    """

    @staticmethod
    def _make_inputs():
        """Return four fresh ``torch.randn(10)`` tensors for q, k, v and mask."""
        return tuple(torch.randn(10) for _ in range(4))

    def assert_ref_equals_params(self, actual, expected):
        """Assert that two ``SDPAParams`` agree field by field.

        Tensor fields are compared by identity (``assertIs``). A compiled
        function must therefore hand back the very tensor objects it was
        given, not copies. The scalar fields are compared by value, so a
        mis-reconstructed flag is caught as well.
        """
        self.assertIs(actual.query, expected.query)
        self.assertIs(actual.key, expected.key)
        self.assertIs(actual.value, expected.value)
        self.assertIs(actual.attn_mask, expected.attn_mask)
        self.assertEqual(actual.dropout, expected.dropout)
        self.assertEqual(actual.is_causal, expected.is_causal)
        self.assertEqual(actual.enable_gqa, expected.enable_gqa)

    def test_returns_SDPAParams(self):
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(fullgraph=True, backend=counter)
            def fn(q, k, v, m):
                return SDPAParams(q, k, v, m, 0.1, True, False)

            q, k, v, m = self._make_inputs()
            o = fn(q, k, v, m)
            self.assertIsInstance(o, SDPAParams)
            self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
            self.assertEqual(counter.frame_count, 1)

    def test_graph_break_SDPAParams(self):
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(backend=counter)
            def fn(q, k, v, m):
                z = SDPAParams(q, k, v, m, 0.1, True, False)
                torch._dynamo.graph_break()
                return z, q + 1

            q, k, v, m = self._make_inputs()
            o, q_plus_one = fn(q, k, v, m)
            self.assertIsInstance(o, SDPAParams)
            self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
            self.assertEqual(q_plus_one, q + 1)
            # Presumably one frame before the graph break and one for the
            # resumed remainder of ``fn``.
            self.assertEqual(counter.frame_count, 2)

    def test_input_SDPAParams(self):
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(backend=counter)
            def fn(sdpap, q):
                torch._dynamo.graph_break()
                return sdpap, sdpap.query + q

            q, k, v, m = self._make_inputs()
            s = SDPAParams(q, k, v, m, 0.1, True, False)
            o, total = fn(s, q)
            # The params object passed in must come back unchanged (same object).
            self.assertIs(o, s)
            # ``s.query`` is ``q``, so the computed sum is ``q + q``.
            self.assertEqual(total, q + q)
            self.assertEqual(counter.frame_count, 1)

    def test_intermediate_attr_access_SDPAParams(self):
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(fullgraph=True, backend=counter)
            def fn(q, k, v, m):
                q += 1
                z = SDPAParams(q, k, v, m, 0.1, True, False)
                a = z.query
                return a + 1, z, q

            q, k, v, m = self._make_inputs()
            out, o, _ = fn(q, k, v, m)
            # ``fn`` mutates ``q`` in place, so ``q`` now holds the incremented
            # values. The returned params must still reference that same tensor.
            expected = SDPAParams(q, k, v, m, 0.1, True, False)
            self.assert_ref_equals_params(o, expected)
            # ``a`` is the already-incremented ``q``, so ``a + 1 == q + 1`` here.
            self.assertEqual(out, q + 1)
            self.assertEqual(counter.frame_count, 1)
| |
| |
if __name__ == "__main__":
    # Run this module's test cases through Dynamo's test-case runner.
    torch._dynamo.test_case.run_tests()