| # Owner(s): ["module: dynamo"] |
| import sys |
| import unittest |
| from unittest.mock import MagicMock, patch |
| |
| import torch |
| import torch._dynamo |
| import torch._dynamo.backends |
| import torch._dynamo.test_case |
| from torch._dynamo.backends.debugging import ExplainWithBackend |
| from torch._dynamo.backends.onnxrt import has_onnxruntime |
| from torch._dynamo.backends.tvm import has_tvm |
| from torch._dynamo.testing import same |
| from torch.fx._lazy_graph_module import _force_skip_lazy_graph_module |
| from torch.testing._internal.inductor_utils import HAS_CUDA |
| |
| |
| requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") |
| |
| |
| class Seq(torch.nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.layers = torch.nn.Sequential( |
| torch.nn.Linear(10, 10), |
| torch.nn.ReLU(), |
| torch.nn.Linear(10, 10), |
| torch.nn.Sigmoid(), |
| ) |
| |
| def forward(self, x): |
| return self.layers(x) |
| |
| |
| class Conv_Bn_Relu(torch.nn.Module): |
| def __init__(self, in_channels, out_channels, **kwargs): |
| super().__init__() |
| self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=False, **kwargs) |
| self.bn = torch.nn.BatchNorm2d(out_channels, eps=0.001) |
| self.relu = torch.nn.ReLU() |
| |
| def forward(self, x): |
| return self.relu(self.bn(self.conv(x))) |
| |
| |
| class TestOptimizations(torch._dynamo.test_case.TestCase): |
| def test_example_inputs(self): |
| def fn(a, bc, d): |
| b, c = bc |
| return a / d - b / c |
| |
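        # A Dynamo backend is called once per compiled graph with the traced
        # fx.GraphModule and example inputs; running the graph at compile time
        # captures an output whose metadata is compared against eager below.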
| def compiler_fn(graph, example_inputs): |
| nonlocal r1 |
| r1 = graph(*example_inputs)[0] |
| return graph.forward |
| |
| a = torch.empty(2).fill_(1) |
| b = torch.empty(2).fill_(2) |
| c = torch.empty(2).fill_(3) |
| d = 4 |
| r1 = None |
| r2 = fn(a, (b, c), d) |
| opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn) |
| r3 = opt_fn(a, (b, c), d) |
| |
| self.assertIsNotNone(r1) |
| self.assertEqual(r1.size(), r2.size()) |
| self.assertEqual(r1.stride(), r2.stride()) |
| self.assertEqual(r1.dtype, r2.dtype) |
| |
| self.assertEqual(r1.size(), r3.size()) |
| self.assertEqual(r1.stride(), r3.stride()) |
| self.assertEqual(r1.dtype, r3.dtype) |
| |
| def test_example_inputs_runtime_use(self): |
| def fn(a, bc, d): |
| b, c = bc |
| return a / d - b / c |
| |
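        # Instead of running the graph at compile time, return a wrapper that
        # records the first output every time the compiled graph executes.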
| def compiler_fn(graph, example_inputs): |
| def fwd(*args): |
| nonlocal r1 |
| r = graph.forward(*args) |
| r1 = r[0] |
| return r |
| |
| return fwd |
| |
| a = torch.empty(2).fill_(1) |
| b = torch.empty(2).fill_(2) |
| c = torch.empty(2).fill_(3) |
| d = 4 |
| r1 = None |
| r2 = fn(a, (b, c), d) |
| opt_fn = torch._dynamo.optimize_assert(compiler_fn)(fn) |
| r3 = opt_fn(a, (b, c), d) |
| |
| self.assertIsNotNone(r1) |
| self.assertTrue(same(r1, r2)) |
| self.assertTrue(same(r1, r3)) |
| |
| def _check_backend_works(self, backend, options=None): |
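        # Smoke-test a backend end to end: run a small MLP both eagerly and
        # compiled, then compare outputs with a loose tolerance.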
| model = Seq().eval() |
| input = torch.randn(2, 10) |
| r1 = model(input) |
| r2 = torch.compile(model, backend=backend, options=options)(input) |
| self.assertTrue(same(r1, r2.float(), tol=0.01)) |
| |
| def test_eager(self): |
| self._check_backend_works("eager") |
| |
| def test_eager_noexcept(self): |
| self._check_backend_works("eager_noexcept") |
| |
| @_force_skip_lazy_graph_module() |
| def test_torchscript(self): |
| self._check_backend_works("ts") |
| |
| def test_aot_eager(self): |
| self._check_backend_works("aot_eager") |
| |
| def test_aot_eager_decomp_partition(self): |
| self._check_backend_works("aot_eager_decomp_partition") |
| |
| @_force_skip_lazy_graph_module() |
| def test_aot_ts(self): |
| self._check_backend_works("aot_ts") |
| |
| @requires_cuda |
| def test_aot_cudagraphs(self): |
| self._check_backend_works("cudagraphs") |
| |
| @unittest.skipIf(not has_onnxruntime(), "requires onnxruntime") |
| def test_onnxrt(self): |
| self._check_backend_works("onnxrt") |
| |
| @unittest.skipIf(not has_tvm(), "requires tvm") |
| def test_tvm(self): |
| self._check_backend_works("tvm") |
| self._check_backend_works("tvm", options={"scheduler": None}) |
| self._check_backend_works("tvm", options={"opt_level": 0}) |
| |
| def test_list_backends(self): |
| self.assertIn("inductor", torch._dynamo.list_backends()) |
| self.assertIn("inductor", torch._dynamo.list_backends(exclude_tags=None)) |
| self.assertNotIn("eager", torch._dynamo.list_backends()) |
| self.assertNotIn("eager", torch._dynamo.list_backends(exclude_tags=["debug"])) |
| self.assertIn("eager", torch._dynamo.list_backends(exclude_tags=[])) |
| |
| |
| class NormalizeIRTests(torch._dynamo.test_case.TestCase): |
| def test_inplace_normalize(self): |
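        # aot_eager must functionalize the in-place x += b (a mixed-dtype
        # in-place add) and still match the eager result.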
| def fn(a, b): |
| x = torch.cos(a) |
| x += b |
| return torch.sin(x) |
| |
| a = torch.randn(10) |
| b = torch.randn(10).to(torch.float64) |
| |
| ref = fn(a, b) |
| |
        optimized_fn = torch.compile(fn, backend="aot_eager")
| res = optimized_fn(a, b) |
| self.assertTrue(same(ref, res)) |
| |
| |
| class MPSNotSupportedTest(torch._dynamo.test_case.TestCase): |
| @unittest.skipIf(not torch.backends.mps.is_available(), "requires mps") |
| def test_mps_not_supported(self): |
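        # Inductor cannot compile for the mps device here, so torch.compile
        # should surface a RuntimeError rather than silently falling back.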
| model = Seq().to("mps") |
| example_input = torch.randn(1, 10).to("mps") |
| self.assertRaises( |
| RuntimeError, |
| lambda: torch.compile(model, backend="inductor")(example_input), |
| ) |
| |
| |
| class TestExplainWithBackend(torch._dynamo.test_case.TestCase): |
| def test_explain_with_backend(self): |
| def fn3(x): |
| x = torch.sin(x) |
| torch._dynamo.graph_break() |
| x = torch.sin(x) |
| return x |
| |
| def fn2(x): |
| x = torch.cos(x) |
| x = fn3(x) |
| x = torch.cos(x) |
| return x |
| |
| def fn1(x): |
| x = torch.tan(x) |
| x = fn2(x) |
| x = torch.tan(x) |
| return x |
| |
| def fn(x): |
| x = torch.sigmoid(x) |
| x = fn1(x) |
| x = torch.sigmoid(x) |
| return x |
| |
        # Wrap the inductor backend with ExplainWithBackend
| eb = ExplainWithBackend("inductor") |
| optimized_fn = torch.compile(fn, backend=eb) |
| input_tensor = torch.randn(5) |
| result = optimized_fn(input_tensor) |
| |
| # Check that fn still produces the same output when wrapped by ExplainWithBackend |
| self.assertTrue(torch.allclose(result, fn(input_tensor))) |
| |
        # Verify ExplainOutput contents; the exact output may change, but these fields should be present
| explain_output = eb.output() |
| explain_str = str(explain_output) |
| self.assertIn("Graph Count", explain_str) |
| self.assertIn("Graph Break Count", explain_str) |
| self.assertIn("Op Count", explain_str) |
| self.assertIn("Break Reasons", explain_str) |
| |
| # Verify that for the given functions above, we report the correct number of graphs, graph breaks, and ops |
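        # The single graph_break() in fn3 splits each of the four nested
        # frames (fn, fn1, fn2, fn3) into a graph before and after the break:
        # 2 * 4 = 8 graphs, 7 breaks between them, and one op per graph.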
| self.assertEqual(8, explain_output.graph_count) |
| self.assertEqual(7, explain_output.graph_break_count) |
| self.assertEqual(8, explain_output.op_count) |
| |
| |
| class TestCustomBackendAPI(torch._dynamo.test_case.TestCase): |
| """Test APIs documented by https://pytorch.org/docs/main/torch.compiler_custom_backends.html""" |
| |
| def test_register_backend_api(self): |
| from torch._dynamo import register_backend |
| |
| backend_run = False |
| |
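        # register_backend adds the decorated function to Dynamo's backend
        # registry, making it addressable by name in torch.compile.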
| @register_backend |
| def my_custom_backend(gm, example_inputs): |
| nonlocal backend_run |
| backend_run = True |
| return gm.forward |
| |
| def f(x): |
| return torch.relu(x) |
| |
| opt_f = torch.compile(f, backend="my_custom_backend") |
| opt_f(torch.randn(3, 3)) |
| self.assertTrue(backend_run) |
| |
| def test_aot_autograd_api(self): |
| from functorch.compile import make_boxed_func |
| from torch._dynamo.backends.common import aot_autograd |
| |
| backend_run = False |
| |
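        # AOTAutograd calls compiled functions with a single list of args
        # (the "boxed" convention, which lets it free inputs eagerly);
        # make_boxed_func adapts a regular callable to that convention.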
| def my_compiler(gm, example_inputs): |
| nonlocal backend_run |
| backend_run = True |
| return make_boxed_func(gm.forward) |
| |
| my_backend = aot_autograd(fw_compiler=my_compiler) |
| |
| def f(x): |
| return torch.relu(x) |
| |
| opt_f = torch.compile(f, backend=my_backend) |
| opt_f(torch.randn(3, 3)) |
| self.assertTrue(backend_run) |
| |
| def test_lookup_backend(self): |
| from torch._dynamo import list_backends, lookup_backend |
| |
| backends = list_backends() |
| backend_run = False |
| |
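        # lookup_backend resolves a backend by name, enabling the fallback
        # chain from the custom-backends docs: try TensorRT, then inductor,
        # then plain eager execution.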
| def my_compiler(gm, example_inputs): |
| nonlocal backend_run |
| backend_run = True |
| try: |
| trt_compiled = lookup_backend("tensorrt")(gm, example_inputs) |
| if trt_compiled is not None: |
| return trt_compiled |
| except Exception: |
| pass |
| # first backend failed, try something else... |
| try: |
| inductor_compiled = lookup_backend("inductor")(gm, example_inputs) |
| if inductor_compiled is not None: |
| return inductor_compiled |
| except Exception: |
| pass |
| return gm.forward |
| |
| def f(x): |
| return torch.relu(x) |
| |
| opt_f = torch.compile(f, backend=my_compiler) |
| opt_f(torch.randn(3, 3)) |
| self.assertTrue(backend_run) |
| |
| def test_lookup_custom_backend(self): |
| from torch._dynamo import list_backends |
| |
| backends_group = "torch_dynamo_backends" |
| name = "mycustombackend" |
| |
| mock_3_9 = MagicMock() |
| mock_3_9.load.return_value = lambda: "mocked 3.9" |
| mock_3_9.name = name |
| |
| mock_3_10 = MagicMock() |
| mock_3_10.load.return_value = lambda: "mocked 3.10" |
| |
        def mock_eps(group=None):
            # Python 3.9's entry_points() takes no arguments and returns a
            # dict of group name -> entry point list; 3.10+ accepts a group=
            # kwarg and returns a selectable EntryPoints object.
            if sys.version_info < (3, 10):
                return {backends_group: [mock_3_9]}
            else:
                assert group == backends_group, group
                mock_group = MagicMock()
                mock_group.names = [name]
                # MagicMock's __setitem__ does not feed __getitem__, so wire
                # up the lookup explicitly to return the mocked entry point.
                mock_group.__getitem__.return_value = mock_3_10
                return mock_group
| |
| with patch("importlib.metadata.entry_points", mock_eps): |
| from torch._dynamo.backends import registry |
| |
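            # Discovery results are functools-cached; clear both caches so
            # the patched entry_points is actually consulted.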
| registry._lazy_import.cache_clear() |
| registry._discover_entrypoint_backends.cache_clear() |
| |
| backends = list_backends() |
| assert name in backends, (name, backends) |
| |
| def test_backend_recompilation(self): |
| def fn(x): |
| return x + x |
| |
| input = torch.tensor(2.0) |
| |
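        # Changing backend options must trigger recompilation; the second
        # compile flips inductor's _raise_error_for_testing hook, so the call
        # should fail with BackendCompilerFailed instead of reusing the
        # previously compiled graph.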
| opt_fn = torch.compile( |
| fn, backend="inductor", options={"_raise_error_for_testing": False} |
| ) |
| opt_fn(input) |
| with self.assertRaises(torch._dynamo.exc.BackendCompilerFailed): |
| opt_fn = torch.compile( |
| fn, backend="inductor", options={"_raise_error_for_testing": True} |
| ) |
| opt_fn(input) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |