| # Owner(s): ["module: dynamo"] |
| import unittest |
| import weakref |
| |
| import torch |
| |
| import torch._dynamo |
| import torch._dynamo.config |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| |
| import torch._logging |
| from torch.testing._internal.logging_utils import kwargs_to_settings, log_settings |
| |
| |
| class RecompileUxTests(torch._dynamo.test_case.TestCase): |
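    """Tests for dynamo's recompilation behavior and the warnings it surfaces."""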
| # TODO(whc) dynamo actually recompiles one more time than the cache limit |
| cache_limit = 1 |
| |
| @classmethod |
| def setUpClass(cls): |
| super().setUpClass() |
| cls._exit_stack.enter_context( |
| torch._dynamo.config.patch("cache_size_limit", cls.cache_limit) |
| ) |
| |
| def test_drop_cache_on_skip(self): |
| def model(x, i): |
| return x + i |
| |
| attached = False |
| triggered = False |
| |
| def trigger(): |
| nonlocal triggered |
| triggered = True |
| |
        def compiler(gm, example_inputs):
| nonlocal attached |
| f = gm.forward |
| assert not attached |
| # NB: making this a weakref.ref causes the cycle to no |
| # longer be promptly GC'ed |
| weakref.finalize(f, trigger) |
| attached = True |
| return f |
| |
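        # cache_limit is 1, so the second call (with a new value of i) exceeds
        # the limit; dynamo should then skip the frame and drop the cached
        # entry, releasing `f` and firing the finalizer.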
| x = torch.randn(2) |
| for i in range(2): |
| opt_model = torch._dynamo.optimize(compiler)(model) |
| opt_model(x, i) |
| |
| self.assertTrue(triggered) |
| |
| def test_loop_torture(self): |
| def loop_torture(input, iters): |
| out = input |
| # randint itself causes one graph break |
| for _ in range(iters): |
| out += input |
| return out |
| |
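        # Each random `iters` value changes the loop trip count, so dynamo
        # recompiles on every outer iteration until it hits the cache limit.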
| compile_counter = torch._dynamo.testing.CompileCounter() |
| for _ in range(10): |
| x = torch.randn(3) |
| iters = torch.randint(low=0, high=1000, size=()) |
| opt_loop_torture = torch._dynamo.optimize(compile_counter)(loop_torture) |
| opt_loop_torture(x, iters) |
| |
        # Currently, we recompile each time; we'd probably like to bail out
        # quickly and warn instead.
| # TODO(whc) these checks fail on py37. Why? |
| # self.assertEqual(counters["frames"]["total"], 2 + self.cache_limit) |
| # self.assertEqual(counters["frames"]["ok"], 1 + self.cache_limit) |
| |
| # compile_counter only sees frames that were fed to the backend compiler, |
| # which is a subset of counters["frames"]["ok"] -- probably because |
| # counters["frames"]["ok"] includes frames not containing torch ops? |
| self.assertEqual(compile_counter.frame_count, self.cache_limit) |
| |
| @torch._dynamo.config.patch("automatic_dynamic_shapes", False) |
| def test_dynamic_input(self): |
| def model(input): |
| return input + input |
| |
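        # With automatic_dynamic_shapes off, every new batch size busts the
        # size guard; once the cache limit is hit, dynamo should warn exactly
        # once and stop compiling this frame.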
| expected_recompiles = 2 |
| compile_counter = torch._dynamo.testing.CompileCounter() |
| with torch._dynamo.config.patch("cache_size_limit", expected_recompiles): |
| with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs: |
| for _ in range(10): |
| bsz = torch.randint(low=0, high=1000, size=()) |
| x = torch.randn((bsz, 3, 4)) |
| opt_model = torch._dynamo.optimize(compile_counter)(model) |
| opt_model(x) |
| |
| self.assertEqual(compile_counter.frame_count, expected_recompiles) |
| self.assertEqual(len(logs.records), 1) |
| self.assertTrue( |
| logs.records[0] |
| .getMessage() |
| .startswith("torch._dynamo hit config.cache_size_limit") |
| ) |
| |
| @unittest.skipIf(not torch.cuda.is_available(), "requires cuda") |
| def test_nvfuser_guards(self): |
        # We may want to model dynamo's guards closely enough on nvfuser's
        # ProfilingExecutor guards that dynamo is in charge of all top-level
        # recompilations, which would let us simplify the underlying
        # torchscript executor.
| def func(a, b, c): |
| return a + b * c |
| |
| a = torch.rand(3, 4, 5, device="cuda") |
| b = torch.rand(3, 4, 5, device="cuda") |
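        # b_v is a contiguous view with the same sizes/strides as b;
        # b_p has permuted strides, which should fail the stride guard.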
| b_v = torch.rand(3, 5, 4, device="cuda").view(3, 4, 5) |
| b_p = torch.rand(3, 5, 4, device="cuda").permute(0, 2, 1) |
| c = torch.rand(3, 4, 5, device="cuda") |
| compile_counter = torch._dynamo.testing.CompileCounter() |
| |
| with torch._dynamo.config.patch("cache_size_limit", 2): |
| opt_func = torch._dynamo.optimize(compile_counter)(func) |
| opt_func(a, b, c) # warmup |
| self.assertEqual(compile_counter.frame_count, 1) |
| |
| opt_func(a, b, c) # no guard fail or recompile |
| self.assertEqual(compile_counter.frame_count, 1) |
| |
| opt_func(a, b_v, c) # a view should not cause nvfuser recompile |
| self.assertEqual(compile_counter.frame_count, 1) |
| |
| opt_func(a, b_p, c) # a permutation should cause recompile |
| self.assertEqual(compile_counter.frame_count, 2) |
| |
| def assert_single_log_contains(self, logs, contains_str): |
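        """Assert exactly one log record was captured and that it contains contains_str."""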
| self.assertEqual(len(logs.records), 1) |
        self.assertIn(
            contains_str,
            logs.records[0].getMessage(),
            msg=f'Expected to find "{contains_str}" in log "{logs.records[0].getMessage()}"',
        )
| |
| def test_verbose_tensor_check(self): |
| def func(a): |
            # Warning: choose a function here whose meta implementation lives
            # entirely in C++. If you pick a Python one, Dynamo will dive into
            # torch._refs, which is OK but muddies up the warnings.
| return torch.add(a, 4) |
| |
| def cache_fail_test(cached_input, missed_input, expected_failure): |
            # TODO(whc) maybe it's hacky to have a 'test within a test' but this seemed convenient
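            # Warm the cache with cached_input, then check that missed_input
            # triggers a recompile whose log contains expected_failure.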
| torch._dynamo.reset() |
| torch._dynamo.utils.counters.clear() |
| opt_func = torch._dynamo.optimize("eager")(func) |
| # warmup |
| opt_func(cached_input) |
| |
| with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs: |
| opt_func = torch._dynamo.optimize("eager")(func) |
| opt_func(missed_input) |
| self.assert_single_log_contains(logs, expected_failure) |
| |
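        # Each case below perturbs one property of `a` (size, stride, rank,
        # dispatch key set, dtype, requires_grad) and checks the guard message.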
| a = torch.rand(3, 4, 5) |
| cache_fail_test( |
| a, |
| a[0:2, :, :], |
| "tensor 'L['a']' size mismatch at index 0. expected 3, actual 2", |
| ) |
| cache_fail_test( |
| a, |
| a.clone().as_strided((3, 4, 5), stride=(1, 3, 12)), |
| "tensor 'L['a']' stride mismatch at index 0. expected 20, actual 1", |
| ) |
| cache_fail_test( |
| a, a[0, :, :], "tensor 'L['a']' rank mismatch. expected 3, actual 2" |
| ) |
| cache_fail_test(a, a.to("meta"), "tensor 'L['a']' dispatch key set mismatch.") |
| cache_fail_test( |
| a, |
| a.to(torch.float16), |
| "tensor 'L['a']' dtype mismatch. expected Float, actual Half", |
| ) |
| a_grad = a.clone() |
| a_grad.requires_grad = True |
| cache_fail_test( |
| a, |
| a_grad, |
| "tensor 'L['a']' requires_grad mismatch. expected requires_grad=0", |
| ) |
| |
| def test_mismatched_type(self): |
| a = torch.rand(3, 4, 5) |
| b = torch.rand(3, 4, 5) |
| |
| def func(a, b): |
| return a + b |
| |
| opt_func = torch._dynamo.optimize("eager")(func) |
| # warmup |
| opt_func(a, b) |
| |
| with self.assertLogs(logger="torch._dynamo", level="WARNING") as logs: |
| opt_func = torch._dynamo.optimize("eager")(func) |
| opt_func(a, 1) |
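        # NB: the stray quote after "type," below is copied verbatim from the
        # guard failure message dynamo emits for this case.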
| self.assert_single_log_contains( |
| logs, |
| "expected type of 'L['b']' to be a tensor type, ' but found <class 'int'>", |
| ) |
| |
| @torch._dynamo.config.patch("cache_size_limit", 32) |
| def test_multiple_guard_fails(self): |
| failure_reasons = [] |
| |
| def guard_fail_fn(failure): |
| failure_reasons.append(failure[0]) |
| |
| def f(x): |
| return torch.relu(x) |
| |
| opt_f = torch._dynamo.optimize( |
| backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False |
| )(f) |
| |
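        # With dynamic=False each new input length recompiles, and every prior
        # cache entry contributes its own size-mismatch failure reason.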
| for i in range(5): |
| failure_reasons.clear() |
| opt_f(torch.randn(8 + i)) |
| |
| failure_str = "\n".join(failure_reasons) |
| for line in """\ |
| tensor 'L['x']' size mismatch at index 0. expected 11, actual 12 |
| tensor 'L['x']' size mismatch at index 0. expected 10, actual 12 |
| tensor 'L['x']' size mismatch at index 0. expected 9, actual 12 |
| tensor 'L['x']' size mismatch at index 0. expected 8, actual 12""".split( |
| "\n" |
| ): |
| self.assertIn( |
| line, |
| failure_str, |
| ) |
| |
| @torch._dynamo.config.patch("cache_size_limit", 32) |
| def test_multiple_guard_fails_report_all(self): |
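        # recompiles_verbose logging should make dynamo report the failing
        # guard of every existing cache entry, not just the first mismatch.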
| with log_settings(kwargs_to_settings(recompiles_verbose=True)): |
| failure_reasons = [] |
| |
| def guard_fail_fn(failure): |
| failure_reasons.append(failure[0]) |
| |
| def f(x): |
| return torch.ones(len(x), x[-1]) |
| |
| opt_f = torch._dynamo.optimize( |
| backend="eager", guard_fail_fn=guard_fail_fn, dynamic=False |
| )(f) |
| |
| opt_f([4, 5, 6]) |
| |
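            # Drop the ___check_type_id guard lines so only the len() guards
            # remain for the assertions below.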
| def filter_reasons(): |
| return "\n".join( |
| [ |
| line |
| for line in "\n".join(failure_reasons).splitlines() |
| if not line.startswith("___check_type_id") |
| ] |
| ) |
| |
| failure_reasons.clear() |
| opt_f([7, 8]) |
| |
| for line in """\ |
| len(L['x']) == 3""".split( |
| "\n" |
| ): |
| self.assertIn(line, filter_reasons()) |
| |
| failure_reasons.clear() |
| opt_f([9]) |
| |
| for line in """\ |
| len(L['x']) == 2 |
| len(L['x']) == 3""".split( |
| "\n" |
| ): |
| self.assertIn(line, filter_reasons()) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |