# Owner(s): ["module: dynamo"]

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.utils import disable_cache_limit


# NB: do NOT include this test class in test_dynamic_shapes.py
# (these tests patch the dynamic-shape configs themselves)


class ConfigTests(torch._dynamo.test_case.TestCase):
    @disable_cache_limit()
    def test_no_automatic_dynamic(self):
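        """With automatic_dynamic_shapes off and assume_static_by_default on,
        every distinct input size forces a fresh static compile."""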
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_static = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch(
            automatic_dynamic_shapes=False, assume_static_by_default=True
        ):
            opt_fn = torch._dynamo.optimize(cnt_static)(fn)
            for i in range(2, 12):
                opt_fn(torch.randn(i), torch.randn(i))
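        # ten distinct sizes (2 through 11) -> ten separate static graphs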
        self.assertEqual(cnt_static.frame_count, 10)

    @disable_cache_limit()
    def test_automatic_dynamic(self):
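        """With automatic_dynamic_shapes on, the first compile is static; the
        first size change triggers a single dynamic recompile that then covers
        every remaining size."""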
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch(
            automatic_dynamic_shapes=True, assume_static_by_default=True
        ):
            opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
            # NB: must not use sizes 0 or 1, as those are specialized
            for i in range(2, 12):
                opt_fn(torch.randn(i), torch.randn(i))
        # two graphs rather than ten: one static (size 2), then one dynamic
        self.assertEqual(cnt_dynamic.frame_count, 2)

    @disable_cache_limit()
    def test_no_assume_static_by_default(self):
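        """With assume_static_by_default off, the very first compile is
        already dynamic, so varying the input size causes no recompiles."""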
        def fn(a, b):
            return a - b * 10

        torch._dynamo.reset()
        cnt_dynamic = torch._dynamo.testing.CompileCounter()
        with torch._dynamo.config.patch(
            automatic_dynamic_shapes=True, assume_static_by_default=False
        ):
            opt_fn = torch._dynamo.optimize(cnt_dynamic)(fn)
            # NB: must not use sizes 0 or 1, as those are specialized
            for i in range(2, 12):
                opt_fn(torch.randn(i), torch.randn(i))
        # a single graph: shapes are dynamic from the start, so there is no
        # need to wait for a recompile
        self.assertEqual(cnt_dynamic.frame_count, 1)

    def test_config_compile_ignored(self):
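        """Every key listed below is expected to be @compile_ignored, i.e.
        excluded from the config hash used to guard compiled code."""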
        # Remove a key from this list if it is no longer relevant
        dynamo_guarded_config_ignorelist = {
            "log_file_name",
            "verbose",
            "verify_correctness",  # does not affect the model; raises a
            # RuntimeError instead (no silent change to compilation behavior)
            "cache_size_limit",
            "accumulated_cache_size_limit",
            "replay_record_enabled",
            "cprofile",  # only wraps _compile, not the compiled graph
            "repro_after",
            "repro_level",
            "repro_forward_only",
            "repro_tolerance",
            "same_two_models_use_fp64",
            "error_on_recompile",  # safe because it raises an error rather
            # than silently changing behavior
            "report_guard_failures",
            "base_dir",  # used for minifying / logging
            "DEBUG_DIR_VAR_NAME",
            "debug_dir_root",
        }
        for k in dynamo_guarded_config_ignorelist:
            assert k in torch._dynamo.config._compile_ignored_keys, k

    def test_config_hash(self):
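        """get_hash() must ignore @compile_ignored keys, change when any
        other key changes, and revert once a patch context exits."""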
        config = torch._dynamo.config
        starting_hash = config.get_hash()

        with config.patch({"verbose": not config.verbose}):
            new_hash = config.get_hash()
            assert "verbose" in config._compile_ignored_keys
            assert new_hash == starting_hash

        # verbose is compile-ignored, so the hash is unchanged both inside
        # and after the patch
        new_hash = config.get_hash()
        assert new_hash == starting_hash

        with config.patch({"dead_code_elimination": not config.dead_code_elimination}):
            changed_hash = config.get_hash()
            assert "dead_code_elimination" not in config._compile_ignored_keys
            assert changed_hash != starting_hash

            # Test nested patch: flipping an ignored key inside the outer
            # patch must not perturb the hash further
            with config.patch({"verbose": not config.verbose}):
                inner_changed_hash = config.get_hash()
                assert inner_changed_hash == changed_hash
                assert inner_changed_hash != starting_hash

        # both patches have exited, so the hash reverts to its starting value
        newest_hash = config.get_hash()
        assert changed_hash != newest_hash
        assert newest_hash == starting_hash


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()