| # Owner(s): ["module: dynamo"] |
| import dataclasses |
| import importlib |
| import inspect |
| import math |
| import types |
| import unittest |
| import warnings |
| from typing import Any, Dict, Set |
| |
| import torch |
| import torch._dynamo.config as config |
| import torch._dynamo.test_case |
| import torch._functorch.deprecated as deprecated_func |
| from torch._dynamo.trace_rules import ( |
| LEGACY_MOD_INLINELIST, |
| load_object, |
| manual_torch_name_rule_map, |
| MOD_INLINELIST, |
| torch_c_binding_in_graph_functions, |
| torch_non_c_binding_in_graph_functions, |
| ) |
| from torch._dynamo.utils import hashable, is_safe_constant, istype |
| from torch._dynamo.variables import TorchInGraphFunctionVariable, UserFunctionVariable |
| |
| try: |
| from .utils import create_dummy_module_and_function |
| except ImportError: |
| from utils import create_dummy_module_and_function |
| |
| |
| ignored_c_binding_in_graph_function_names = { |
| # Ignored because they have manual rules defined at `trace_rules.manual_torch_name_rule_map`. |
| "torch._nested_tensor_from_mask", |
| "torch._nested_from_padded", |
| "torch.sparse_compressed_tensor", |
| "torch.sparse_bsc_tensor", |
| "torch.sparse_bsr_tensor", |
| "torch.sparse_coo_tensor", |
| "torch.sparse_csc_tensor", |
| "torch.sparse_csr_tensor", |
| "torch.cuda._get_device_properties", |
| # Ignored and go through rules defined at `trace_rules.check`. |
| "torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode", |
| "torch._cslt_sparse_mm_search", |
| "torch._C._abort", |
| "torch._C._mps_is_on_macos_or_newer", |
| "torch._C._swap_tensor_impl", |
| "torch._C._unsafe_reset_storage", |
| "torch._dynamo.eval_frame.reset_code", |
| "torch._C.autocast_decrement_nesting", |
| "torch._C.autocast_increment_nesting", |
| "torch._C.clear_autocast_cache", |
| "torch._C.set_anomaly_enabled", |
| "torch._C.set_autocast_cache_enabled", |
| "torch._C.set_autocast_cpu_dtype", |
| "torch._C.set_autocast_cpu_enabled", |
| "torch._C.set_autocast_enabled", |
| "torch._C.set_autocast_gpu_dtype", |
| "torch._C.set_autocast_ipu_dtype", |
| "torch._C.set_autocast_ipu_enabled", |
| "torch._C.set_autocast_xla_dtype", |
| "torch._C.set_autocast_xla_enabled", |
| "torch.resize_as_", |
| "torch.resize_as_sparse_", |
| "torch._C._data_address", |
| "torch._C._is_cow_tensor", |
| "torch._lazy_clone", |
| "torch._test_parallel_materialize", |
| "torch._C._storage_address", |
| "torch._C._pickle_save", |
| "torch._validate_sparse_compressed_tensor_args", |
| "torch._validate_sparse_csr_tensor_args", |
| "torch._validate_sparse_bsr_tensor_args", |
| "torch._validate_sparse_csc_tensor_args", |
| "torch._validate_sparse_coo_tensor_args", |
| "torch._validate_sparse_bsc_tensor_args", |
| "torch._validate_compressed_sparse_indices", |
| } |
| if torch._C._llvm_enabled(): |
| ignored_c_binding_in_graph_function_names |= { |
| "torch._C._te.set_llvm_aot_workflow", |
| "torch._C._te.set_llvm_target_cpu", |
| "torch._C._te.set_llvm_target_attrs", |
| "torch._C._te.set_llvm_target_triple", |
| } |
| |
| |
| # Helper function to dump the torch name rule map generated based on |
| # the heuristic defined in gen_allowed_objs_and_ids. |
| def dump_allowed_torch_name_rule_map() -> None: |
| m = gen_allowed_objs_and_ids(record=True, c_binding_only=False).name_rule_map |
| for k, v in m.items(): |
| print(f'"{k}": {v.__name__},') |
| |
| |
| @dataclasses.dataclass |
| class AllowedObjects: |
| """ |
| Track the objects, object id - name pairs, and name - dynamo wrapping rule pairs |
| from the heuristic defined in `gen_allowed_objs_and_ids`. |
| """ |
| |
| object_ids: Dict[int, str] |
| c_binding_in_graph_functions: Set[Any] |
| non_c_binding_in_graph_functions: Set[Any] |
| name_rule_map: Dict[str, Any] |
| |
| |
| def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects: |
| """ |
| Walk torch.* and get the ids of all the stuff in it |
| """ |
| |
| warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed") |
| torch_object_ids = dict() |
| c_binding_in_graph_functions = set() |
| non_c_binding_in_graph_functions = set() |
| torch_name_rule_map = dict() |
| |
| # In some platforms, these functions were loaded as classes instead of functions. |
| # To mitigate these weired cases, we need this special check. |
| def is_special_functions(obj): |
| return hashable(obj) and obj in { |
| torch._C._cuda_isCurrentStreamCapturing, |
| torch._C._graph_pool_handle, |
| } |
| |
| # Add obj to c_binding_in_graph_functions set or non_c_binding_in_graph_functions set |
| # if it's a torch function or method. |
| # This is used to generate the in graph function list based on heuristic. |
| def heuristic_record_if_in_graph_function(obj, module, name): |
| try: |
| if hasattr(obj, "__wrapped__"): |
| obj = obj.__wrapped__ |
| except Exception: |
| pass |
| if isinstance( |
| obj, |
| ( |
| types.FunctionType, |
| types.BuiltinFunctionType, |
| types.MethodDescriptorType, |
| types.WrapperDescriptorType, |
| ), |
| ) or is_special_functions(obj): |
| torch_name_rule_map[ |
| f"{module.__name__}.{name}" |
| ] = TorchInGraphFunctionVariable |
| if c_binding_only: |
| if not hasattr(obj, "__code__"): |
| c_binding_in_graph_functions.add(obj) |
| else: |
| if hasattr(obj, "__code__"): |
| non_c_binding_in_graph_functions.add(obj) |
| else: |
| c_binding_in_graph_functions.add(obj) |
| |
| def _is_allowed_module_prefix(obj): |
| allowed_modules = ("torch", "math") |
| # torch.nn.modules.rnn is disallowed because these modules internally |
| # flatten their parameters. This flattening process will call |
| # Tensor.set_ with a Storage, and Storages cannot be traced with |
| # AOTAutograd; so we need to graph-break. To ensure this, we inline |
| # these functions, rather than keep them opaque-ly in the graph. |
| disallowed_modules = [ |
| "torch.optim.", |
| "torch.nn.modules.rnn.", |
| "torch._dynamo.", |
| "torch._C._dynamo.", |
| "torch._inductor.", |
| "torch._C.inductor.", |
| "torch.fx.", |
| "torch._C._autograd", |
| "torch._C._cudart", |
| "torch._C._distributed_autograd", |
| "torch._C._distributed_c10d", |
| "torch._C._distributed_rpc", |
| "torch._C._functorch", |
| "torch._C._monitor", |
| "torch._C._nvtx", |
| "torch._C._lazy", |
| "torch._C._profiler", |
| "torch.__config__", |
| "torch._custom_op", |
| "torch._decomp", |
| "torch._dispatch", |
| "torch._export", |
| "torch._functorch.make_functional", |
| "torch._functorch.compile_utils", |
| "torch._functorch.partitioners", |
| "torch._functorch.aot_autograd", |
| "torch._functorch.compilers", |
| "torch._functorch.fx_minifier", |
| "torch.autograd.profiler_util", |
| "torch.autograd.profiler", |
| "torch._jit_internal", |
| "torch._library", |
| "torch._lobpcg", |
| "torch._logging", |
| "torch._meta_registrations", |
| "torch._namedtensor_internals", |
| "torch._numpy", |
| "torch._sources", |
| "torch._subclasses", |
| "torch._tensor", |
| "torch._tensor_str", |
| "torch._utils", |
| "torch._utils_internal", |
| "torch._vmap_internals", |
| "torch.compiler", |
| "torch.distributed", |
| "torch.export", |
| "torch.hub", |
| "torch.jit", |
| "torch.library", |
| "torch.masked.maskedtensor", |
| "torch.nn.init", |
| "torch.nn.modules.module", |
| "torch.nn.parallel", |
| "torch.nn.utils", |
| "torch.multiprocessing", |
| "torch.onnx", |
| "torch.overrides", |
| "torch.package", |
| "torch.profiler", |
| "torch.serialization", |
| "torch.storage", |
| "torch.utils", |
| "torch.distributed.", |
| ] |
| |
| allowed_modules_dot = tuple([x + "." for x in allowed_modules]) |
| module = inspect.getmodule(obj) |
| if module is None: |
| return False |
| |
| mod_name = module.__name__ |
| |
| if any(mod_name.startswith(m) for m in disallowed_modules): |
| return False |
| |
| return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot) |
| |
| def _find_torch_objects(module): |
| if any( |
| module.__name__.startswith(mod_name) |
| for mod_name in config.allowed_functions_module_string_ignorelist |
| ): |
| return |
| torch_object_ids[id(module)] = module.__name__ |
| for name, obj in list(module.__dict__.items()): |
| if id(obj) not in torch_object_ids: |
| # Dynamo allows all builtins into the graph and does not attempt |
| # to introspect into them. We don't want to allow instances of |
| # HigherOrderOperator into the graph all the time (Dynamo needs |
| # to introspect the body functions of these HigherOrderOperator |
| # first, decide they are safe, and then allow them into the graph). |
| # So we exclude HigherOrderOperator from being a builtin. |
| import torch._ops |
| |
| if isinstance(obj, torch._ops.HigherOrderOperator): |
| continue |
| |
| # We want to trace through `grad` and `vmap` |
| if obj in ( |
| torch.func.grad, |
| deprecated_func.grad, |
| torch.func.vmap, |
| deprecated_func.vmap, |
| torch.nn.functional.triplet_margin_with_distance_loss, |
| torch.cond, |
| ): |
| continue |
| |
| if isinstance(obj, types.ModuleType): |
| if obj.__name__.startswith("torch.") and _is_allowed_module_prefix( |
| obj |
| ): |
| torch_object_ids[id(obj)] = f"{module.__name__}.{name}" |
| _find_torch_objects(obj) |
| elif _is_allowed_module_prefix(obj): |
| if record: |
| heuristic_record_if_in_graph_function(obj, module, name) |
| torch_object_ids[id(obj)] = f"{module.__name__}.{name}" |
| elif inspect.getmodule(obj) is None and not is_safe_constant(obj): |
| if record: |
| heuristic_record_if_in_graph_function(obj, module, name) |
| torch_object_ids[id(obj)] = f"{module.__name__}.{name}" |
| |
| _find_torch_objects(torch) |
| _find_torch_objects(math) |
| |
| return AllowedObjects( |
| torch_object_ids, |
| c_binding_in_graph_functions, |
| non_c_binding_in_graph_functions, |
| torch_name_rule_map, |
| ) |
| |
| |
| class TraceRuleTests(torch._dynamo.test_case.TestCase): |
| def _check_set_equality(self, generated, used, rule_map, ignored_set): |
| x = generated - used |
| y = used - generated |
| msg1 = ( |
| f"New torch objects: {x} " |
| f"were not added to `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. " |
| "Refer the instruction in `torch/_dynamo/trace_rules.py` for more details." |
| ) |
| msg2 = ( |
| f"Existing torch objects: {y} were removed. " |
| f"Please remove them from `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. " |
| "Refer the instruction in `torch/_dynamo/trace_rules.py` for more details." |
| ) |
| self.assertTrue(len(x) == 0, msg1) |
| self.assertTrue(len(y) == 0, msg2) |
| |
| # We are using python function and module string names for these inlinelist, |
| # this unit test is to make sure the functions/modules can be correctly imported |
| # or loaded in case there is typo in the strings. |
| def test_skipfiles_inlinelist(self): |
| for m in LEGACY_MOD_INLINELIST.union(MOD_INLINELIST): |
| self.assertTrue( |
| isinstance(importlib.import_module(m), types.ModuleType), |
| f"{m} from trace_rules.MOD_INLINELIST/LEGACY_MOD_INLINELIST is not a python module, please check and correct it.", |
| ) |
| |
| @unittest.skip( |
| "This test keeps getting broken and our disable infra is not handling well. see #120627" |
| ) |
| def test_torch_name_rule_map_updated(self): |
| # Generate the allowed objects based on heuristic defined in `allowed_functions.py`, |
| objs = gen_allowed_objs_and_ids(record=True, c_binding_only=True) |
| # Test C binding in graph functions are updated in torch_name_rule_map. |
| generated = objs.c_binding_in_graph_functions |
| used = set() |
| for x in ( |
| set(torch_c_binding_in_graph_functions.keys()) |
| | ignored_c_binding_in_graph_function_names |
| ): |
| obj = load_object(x) |
| if obj is not None: |
| used.add(obj) |
| self._check_set_equality( |
| generated, |
| used, |
| "torch_c_binding_in_graph_functions", |
| "ignored_c_binding_in_graph_function_names", |
| ) |
| # For non C binding in graph functions, we only test if they can be loaded successfully. |
| for f in torch_non_c_binding_in_graph_functions: |
| self.assertTrue( |
| isinstance( |
| load_object(f), |
| ( |
| types.FunctionType, |
| types.BuiltinFunctionType, |
| types.MethodDescriptorType, |
| types.WrapperDescriptorType, |
| ), |
| ) |
| ) |
| |
| def test_force_inline_torch_function(self): |
| # `torch._dynamo.utils.istype` is skipped by default |
| def fn(x): |
| if istype(x, torch.Tensor): |
| return x + 1 |
| else: |
| return x - 1 |
| |
| _manual_torch_name_rule_map = manual_torch_name_rule_map.copy() |
| # Force inline `torch._dynamo.utils.istype` by setting trace rule. |
| _manual_torch_name_rule_map["torch._dynamo.utils.istype"] = UserFunctionVariable |
| |
| _torch_name_rule_map = [ |
| _manual_torch_name_rule_map, |
| torch_c_binding_in_graph_functions, |
| torch_non_c_binding_in_graph_functions, |
| ] |
| |
| self.assertTrue( |
| "torch._dynamo" not in torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST |
| ) |
| self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST) |
| |
| with unittest.mock.patch( |
| "torch._dynamo.trace_rules.torch_name_rule_map", |
| _torch_name_rule_map, |
| ), unittest.mock.patch( |
| "torch._dynamo.trace_rules.get_torch_obj_rule_map", |
| torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache |
| ): |
| x = torch.rand(3) |
| opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) |
| ref = fn(x) |
| res = opt_fn(x) |
| self.assertEqual(ref, res) |
| |
| def test_force_inline_custom_function(self): |
| mod, func = create_dummy_module_and_function() |
| |
| def fn(x): |
| return func(x) |
| |
| _manual_torch_name_rule_map = manual_torch_name_rule_map.copy() |
| # Force inline `mod.func` by setting trace rule. |
| _manual_torch_name_rule_map[ |
| f"{mod.__name__}.{func.__name__}" |
| ] = UserFunctionVariable |
| |
| _torch_name_rule_map = [ |
| _manual_torch_name_rule_map, |
| torch_c_binding_in_graph_functions, |
| torch_non_c_binding_in_graph_functions, |
| ] |
| |
| with unittest.mock.patch( |
| "torch._dynamo.trace_rules.torch_name_rule_map", |
| _torch_name_rule_map, |
| ), unittest.mock.patch( |
| "torch._dynamo.trace_rules.get_torch_obj_rule_map", |
| torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, |
| ): |
| # First adding the module to SKIP_DIRS so that it will be skipped by default. |
| torch._dynamo.trace_rules.add(mod.__name__) |
| x = torch.rand(3) |
| opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) |
| ref = fn(x) |
| res = opt_fn(x) |
| self.assertEqual(ref, res) |
| |
| |
| class TestModuleSurviveSkipFiles(torch._dynamo.test_case.TestCase): |
| @unittest.skipIf( |
| not torch.distributed.is_available(), |
| "need to import MLP module from distributed", |
| ) |
| def test_module_survive_skip_files(self): |
| from torch.testing._internal.common_fsdp import MLP |
| |
| model = MLP(3) |
| inp = torch.randn((2, 3)) |
| frame_count_before = torch._dynamo.convert_frame.FRAME_COUNTER |
| model.compile(backend="eager") |
| model(inp) |
| frame_count_after = torch._dynamo.convert_frame.FRAME_COUNTER |
| self.assertTrue( |
| frame_count_after > frame_count_before, "MLP did not survive skip files" |
| ) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |