| # Owner(s): ["module: inductor"] |
| import contextlib |
| |
| import sympy |
| |
| import torch |
| |
| import torch._inductor.config as inductor_config |
| from torch._inductor.codegen import triton_utils |
| from torch._inductor.codegen.common import SizeArg |
| from torch._inductor.graph import GraphLowering |
| from torch._inductor.virtualized import V |
| |
| from torch.testing._internal.common_utils import TestCase as TorchTestCase |
| from torch.testing._internal.inductor_utils import HAS_CPU, HAS_CUDA |
| |
| |
class TestCodegenTriton(TorchTestCase):
    """Checks on the Triton codegen helpers exposed by ``triton_utils``.

    Every test runs with a ``GraphLowering`` (built from a trivial traced
    module) installed as the graph handler on ``V``.
    """

    def setUp(self):
        """Trace a trivial module and install its ``GraphLowering`` on ``V``."""
        super().setUp()

        class DummyModule(torch.nn.Module):
            def forward(self, x):
                return x * 2

        traced = torch.fx.symbolic_trace(DummyModule())
        self._gm = traced
        self._graph = GraphLowering(traced)

        # The ExitStack holds the graph-handler context open until tearDown
        # closes it.
        stack = contextlib.ExitStack()
        self._stack = stack
        stack.enter_context(V.set_graph_handler(self._graph))

    def tearDown(self):
        """Exit the graph-handler context that ``setUp`` entered."""
        self._stack.close()
        super().tearDown()

    @inductor_config.patch("triton.divisible_by_16", True)
    def test_config_of_sizearg(self):
        # With the ``triton.divisible_by_16`` option patched on, the
        # assertions below pin which positional indices ``config_of`` reports
        # in ``divisible_by_16`` for a list of ``SizeArg`` expressions.
        two, eight, sixteen = map(sympy.Integer, (2, 8, 16))
        s0, s1 = (
            sympy.Symbol(name, positive=True, integer=True) for name in ("s0", "s1")
        )

        def divisible_indices(exprs):
            # Arguments are named "A", "B", ... in positional order.
            size_args = [SizeArg(label, expr) for label, expr in zip("ABCDEFG", exprs)]
            return triton_utils.config_of(size_args).divisible_by_16

        # Plain constants and bare symbols.
        plain_exprs = [
            two,  # 0: no
            eight,  # 1: no
            sixteen,  # 2: yes
            s0,  # 3: no
            s1,  # 4: no
        ]
        self.assertEqual((2,), divisible_indices(plain_exprs))

        # Products of constants and symbols.
        product_exprs = [
            two * eight,  # 0: yes
            eight * s0,  # 1: no
            two * eight * s0,  # 2: yes
            s0 * s1,  # 3: no
            sixteen * s0,  # 4: yes
            sixteen * eight * s0 * s1,  # 5: yes
            two * eight * s0 * s1,  # 6: yes
        ]
        self.assertEqual((0, 2, 4, 5, 6), divisible_indices(product_exprs))
| |
| |
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    # Run the suite only when at least one of the device flags is truthy.
    # NOTE(review): the "sympy" argument presumably names a dependency that
    # run_tests requires before running — confirm against run_tests.
    device_available = HAS_CPU or HAS_CUDA
    if device_available:
        run_tests("sympy")