# Owner(s): ["module: dynamo"]

import dataclasses
import importlib
import inspect
import math
import types
import unittest
import warnings
from typing import Any, Dict, Set

import torch
import torch._dynamo.config as config
import torch._dynamo.test_case
import torch._functorch.deprecated as deprecated_func
from torch._dynamo.trace_rules import (
    LEGACY_MOD_INLINELIST,
    load_object,
    manual_torch_name_rule_map,
    MOD_INLINELIST,
    torch_c_binding_in_graph_functions,
    torch_non_c_binding_in_graph_functions,
)
from torch._dynamo.utils import hashable, is_safe_constant, istype
from torch._dynamo.variables import TorchInGraphFunctionVariable, UserFunctionVariable

try:
    from .utils import create_dummy_module_and_function
except ImportError:
    from utils import create_dummy_module_and_function

ignored_c_binding_in_graph_function_names = {
    # Ignored because they have manual rules defined at `trace_rules.manual_torch_name_rule_map`.
    "torch._nested_tensor_from_mask",
    "torch._nested_from_padded",
    "torch.sparse_compressed_tensor",
    "torch.sparse_bsc_tensor",
    "torch.sparse_bsr_tensor",
    "torch.sparse_coo_tensor",
    "torch.sparse_csc_tensor",
    "torch.sparse_csr_tensor",
    "torch.cuda._get_device_properties",
    # Ignored; these go through the rules defined at `trace_rules.check`.
    "torch._functionalize_are_all_mutations_under_no_grad_or_inference_mode",
    "torch._cslt_sparse_mm_search",
    "torch._C._abort",
    "torch._C._mps_is_on_macos_or_newer",
    "torch._C._swap_tensor_impl",
    "torch._C._unsafe_reset_storage",
    "torch._dynamo.eval_frame.reset_code",
    "torch._C.autocast_decrement_nesting",
    "torch._C.autocast_increment_nesting",
    "torch._C.clear_autocast_cache",
    "torch._C.set_anomaly_enabled",
    "torch._C.set_autocast_cache_enabled",
    "torch._C.set_autocast_cpu_dtype",
    "torch._C.set_autocast_cpu_enabled",
    "torch._C.set_autocast_enabled",
    "torch._C.set_autocast_gpu_dtype",
    "torch._C.set_autocast_ipu_dtype",
    "torch._C.set_autocast_ipu_enabled",
    "torch._C.set_autocast_xla_dtype",
    "torch._C.set_autocast_xla_enabled",
    "torch.resize_as_",
    "torch.resize_as_sparse_",
    "torch._C._data_address",
    "torch._C._is_cow_tensor",
    "torch._lazy_clone",
    "torch._test_parallel_materialize",
    "torch._C._storage_address",
    "torch._C._pickle_save",
    "torch._validate_sparse_compressed_tensor_args",
    "torch._validate_sparse_csr_tensor_args",
    "torch._validate_sparse_bsr_tensor_args",
    "torch._validate_sparse_csc_tensor_args",
    "torch._validate_sparse_coo_tensor_args",
    "torch._validate_sparse_bsc_tensor_args",
    "torch._validate_compressed_sparse_indices",
}

if torch._C._llvm_enabled():
    ignored_c_binding_in_graph_function_names |= {
        "torch._C._te.set_llvm_aot_workflow",
        "torch._C._te.set_llvm_target_cpu",
        "torch._C._te.set_llvm_target_attrs",
        "torch._C._te.set_llvm_target_triple",
    }


# Helper function to dump the torch name rule map generated based on
# the heuristic defined in `gen_allowed_objs_and_ids`.
def dump_allowed_torch_name_rule_map() -> None:
    m = gen_allowed_objs_and_ids(record=True, c_binding_only=False).name_rule_map
    for k, v in m.items():
        print(f'"{k}": {v.__name__},')
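
# An example output line from the dump above (illustrative; the exact entries
# depend on the torch build and on the heuristic):
#   "torch.abs": TorchInGraphFunctionVariable,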


@dataclasses.dataclass
class AllowedObjects:
    """
    Track the objects, the object id/name pairs, and the name/Dynamo-wrapping-rule
    pairs produced by the heuristic defined in `gen_allowed_objs_and_ids`.
    """

    object_ids: Dict[int, str]
    c_binding_in_graph_functions: Set[Any]
    non_c_binding_in_graph_functions: Set[Any]
    name_rule_map: Dict[str, Any]
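
# Illustrative shape of the fields above (assumed values, not exact):
#   object_ids:    {140031234320: "torch.abs", ...}
#   name_rule_map: {"torch.abs": TorchInGraphFunctionVariable, ...}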


def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObjects:
    """
    Walk torch.* and record the ids of all the objects found there.
    """
    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = dict()
    c_binding_in_graph_functions = set()
    non_c_binding_in_graph_functions = set()
    torch_name_rule_map = dict()

    # On some platforms, these functions are loaded as classes instead of functions.
    # To mitigate these weird cases, we need this special check.
    def is_special_functions(obj):
        return hashable(obj) and obj in {
            torch._C._cuda_isCurrentStreamCapturing,
            torch._C._graph_pool_handle,
        }
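
    # For example, on such platforms `torch._C._graph_pool_handle` may show up
    # as a class-like object rather than a plain function, so the isinstance()
    # checks in `heuristic_record_if_in_graph_function` below would miss it
    # without this membership test.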

    # Add obj to the c_binding_in_graph_functions set or the
    # non_c_binding_in_graph_functions set if it's a torch function or method.
    # This is used to generate the in-graph function list based on a heuristic.
    def heuristic_record_if_in_graph_function(obj, module, name):
        try:
            if hasattr(obj, "__wrapped__"):
                obj = obj.__wrapped__
        except Exception:
            pass
        if isinstance(
            obj,
            (
                types.FunctionType,
                types.BuiltinFunctionType,
                types.MethodDescriptorType,
                types.WrapperDescriptorType,
            ),
        ) or is_special_functions(obj):
            torch_name_rule_map[
                f"{module.__name__}.{name}"
            ] = TorchInGraphFunctionVariable
            if c_binding_only:
                if not hasattr(obj, "__code__"):
                    c_binding_in_graph_functions.add(obj)
            else:
                if hasattr(obj, "__code__"):
                    non_c_binding_in_graph_functions.add(obj)
                else:
                    c_binding_in_graph_functions.add(obj)
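
    # The `__code__` checks above distinguish C bindings from pure-Python
    # functions: a C-level builtin such as `torch.add` carries no `__code__`
    # attribute, while a Python-level function such as
    # `torch.nn.functional.dropout` does. (Illustrative examples.)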

    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters. This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd, so we need to graph-break. To ensure this, we inline
        # these functions rather than keeping them opaquely in the graph.
        disallowed_modules = [
            "torch.optim.",
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch._C._autograd",
            "torch._C._cudart",
            "torch._C._distributed_autograd",
            "torch._C._distributed_c10d",
            "torch._C._distributed_rpc",
            "torch._C._functorch",
            "torch._C._monitor",
            "torch._C._nvtx",
            "torch._C._lazy",
            "torch._C._profiler",
            "torch.__config__",
            "torch._custom_op",
            "torch._decomp",
            "torch._dispatch",
            "torch._export",
            "torch._functorch.make_functional",
            "torch._functorch.compile_utils",
            "torch._functorch.partitioners",
            "torch._functorch.aot_autograd",
            "torch._functorch.compilers",
            "torch._functorch.fx_minifier",
            "torch.autograd.profiler_util",
            "torch.autograd.profiler",
            "torch._jit_internal",
            "torch._library",
            "torch._lobpcg",
            "torch._logging",
            "torch._meta_registrations",
            "torch._namedtensor_internals",
            "torch._numpy",
            "torch._sources",
            "torch._subclasses",
            "torch._tensor",
            "torch._tensor_str",
            "torch._utils",
            "torch._utils_internal",
            "torch._vmap_internals",
            "torch.compiler",
            "torch.distributed",
            "torch.export",
            "torch.hub",
            "torch.jit",
            "torch.library",
            "torch.masked.maskedtensor",
            "torch.nn.init",
            "torch.nn.modules.module",
            "torch.nn.parallel",
            "torch.nn.utils",
            "torch.multiprocessing",
            "torch.onnx",
            "torch.overrides",
            "torch.package",
            "torch.profiler",
            "torch.serialization",
            "torch.storage",
            "torch.utils",
            "torch.distributed.",
        ]
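
        # (Illustrative: with torch.nn.modules.rnn inlined, Dynamo can trace
        # into e.g. `RNNBase.flatten_parameters` and place the required graph
        # break at the `Tensor.set_` call, instead of keeping the whole call
        # opaque in the graph.)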

        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__
        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False
        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                # Dynamo allows all builtins into the graph and does not attempt
                # to introspect into them. We don't want to allow instances of
                # HigherOrderOperator into the graph all the time (Dynamo needs
                # to introspect the body functions of these HigherOrderOperators
                # first, decide they are safe, and only then allow them into the
                # graph). So we exclude HigherOrderOperator from being treated
                # as a builtin.
                import torch._ops

                if isinstance(obj, torch._ops.HigherOrderOperator):
                    continue
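
                # (For instance, `torch.ops.higher_order.cond` is a
                # HigherOrderOperator; Dynamo must introspect its branch
                # functions before it can be allowed into the graph, so it is
                # skipped by this walk.)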

                # We want to trace through `grad` and `vmap`.
                if obj in (
                    torch.func.grad,
                    deprecated_func.grad,
                    torch.func.vmap,
                    deprecated_func.vmap,
                    torch.nn.functional.triplet_margin_with_distance_loss,
                    torch.cond,
                ):
                    continue

                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    if record:
                        heuristic_record_if_in_graph_function(obj, module, name)
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    if record:
                        heuristic_record_if_in_graph_function(obj, module, name)
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    return AllowedObjects(
        torch_object_ids,
        c_binding_in_graph_functions,
        non_c_binding_in_graph_functions,
        torch_name_rule_map,
    )
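

# Typical usage of the generator above (an illustrative sketch):
#   objs = gen_allowed_objs_and_ids(record=True, c_binding_only=True)
#   assert "torch" in objs.object_ids.values()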


class TraceRuleTests(torch._dynamo.test_case.TestCase):
    def _check_set_equality(self, generated, used, rule_map, ignored_set):
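        """Assert that `generated` and `used` contain exactly the same objects,
        reporting additions or removals against `rule_map` and `ignored_set`."""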
        x = generated - used
        y = used - generated
        msg1 = (
            f"New torch objects: {x} "
            f"were not added to `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. "
            "Refer to the instructions in `torch/_dynamo/trace_rules.py` for more details."
        )
        msg2 = (
            f"Existing torch objects: {y} were removed. "
            f"Please remove them from `trace_rules.{rule_map}` or `test_trace_rules.{ignored_set}`. "
            "Refer to the instructions in `torch/_dynamo/trace_rules.py` for more details."
        )
        self.assertTrue(len(x) == 0, msg1)
        self.assertTrue(len(y) == 0, msg2)

    # We use Python function and module string names for these inlinelists;
    # this unit test makes sure the functions/modules can be correctly imported
    # or loaded in case there is a typo in the strings.
    def test_skipfiles_inlinelist(self):
        for m in LEGACY_MOD_INLINELIST.union(MOD_INLINELIST):
            self.assertTrue(
                isinstance(importlib.import_module(m), types.ModuleType),
                f"{m} from trace_rules.MOD_INLINELIST/LEGACY_MOD_INLINELIST is not a Python module, please check and correct it.",
            )
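
    # For example, an inlinelist entry like "torch.distributions" must import
    # cleanly via importlib.import_module; a misspelled entry such as
    # "torch.distrbutions" would raise ModuleNotFoundError and fail the test
    # above. (Illustrative entry; the actual lists live in trace_rules.)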

    @unittest.skip(
        "This test keeps getting broken and our disable infra is not handling it well. See #120627."
    )
    def test_torch_name_rule_map_updated(self):
        # Generate the allowed objects based on the heuristic defined in
        # `gen_allowed_objs_and_ids`.
        objs = gen_allowed_objs_and_ids(record=True, c_binding_only=True)
        # Test that C-binding in-graph functions are updated in torch_name_rule_map.
        generated = objs.c_binding_in_graph_functions
        used = set()
        for x in (
            set(torch_c_binding_in_graph_functions.keys())
            | ignored_c_binding_in_graph_function_names
        ):
            obj = load_object(x)
            if obj is not None:
                used.add(obj)
        self._check_set_equality(
            generated,
            used,
            "torch_c_binding_in_graph_functions",
            "ignored_c_binding_in_graph_function_names",
        )
        # For non-C-binding in-graph functions, we only test that they can be
        # loaded successfully.
        for f in torch_non_c_binding_in_graph_functions:
            self.assertTrue(
                isinstance(
                    load_object(f),
                    (
                        types.FunctionType,
                        types.BuiltinFunctionType,
                        types.MethodDescriptorType,
                        types.WrapperDescriptorType,
                    ),
                )
            )

    def test_force_inline_torch_function(self):
        # `torch._dynamo.utils.istype` is skipped by default.
        def fn(x):
            if istype(x, torch.Tensor):
                return x + 1
            else:
                return x - 1

        _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
        # Force inline `torch._dynamo.utils.istype` by setting a trace rule.
        _manual_torch_name_rule_map["torch._dynamo.utils.istype"] = UserFunctionVariable
        _torch_name_rule_map = [
            _manual_torch_name_rule_map,
            torch_c_binding_in_graph_functions,
            torch_non_c_binding_in_graph_functions,
        ]
        self.assertTrue(
            "torch._dynamo" not in torch._dynamo.trace_rules.LEGACY_MOD_INLINELIST
        )
        self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST)
        with unittest.mock.patch(
            "torch._dynamo.trace_rules.torch_name_rule_map",
            _torch_name_rule_map,
        ), unittest.mock.patch(
            "torch._dynamo.trace_rules.get_torch_obj_rule_map",
            torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,  # bypass functools.lru_cache
        ):
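            # `get_torch_obj_rule_map` is wrapped in functools.lru_cache, so
            # patching in its `__wrapped__` form ensures the patched
            # `torch_name_rule_map` above is consulted instead of a stale
            # cached result.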
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)

    def test_force_inline_custom_function(self):
        mod, func = create_dummy_module_and_function()

        def fn(x):
            return func(x)

        _manual_torch_name_rule_map = manual_torch_name_rule_map.copy()
        # Force inline `mod.func` by setting a trace rule.
        _manual_torch_name_rule_map[
            f"{mod.__name__}.{func.__name__}"
        ] = UserFunctionVariable
        _torch_name_rule_map = [
            _manual_torch_name_rule_map,
            torch_c_binding_in_graph_functions,
            torch_non_c_binding_in_graph_functions,
        ]
        with unittest.mock.patch(
            "torch._dynamo.trace_rules.torch_name_rule_map",
            _torch_name_rule_map,
        ), unittest.mock.patch(
            "torch._dynamo.trace_rules.get_torch_obj_rule_map",
            torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__,
        ):
            # First, add the module to SKIP_DIRS so that it is skipped by default.
            torch._dynamo.trace_rules.add(mod.__name__)
            x = torch.rand(3)
            opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
            ref = fn(x)
            res = opt_fn(x)
            self.assertEqual(ref, res)


class TestModuleSurviveSkipFiles(torch._dynamo.test_case.TestCase):
    @unittest.skipIf(
        not torch.distributed.is_available(),
        "need to import MLP module from distributed",
    )
    def test_module_survive_skip_files(self):
        from torch.testing._internal.common_fsdp import MLP
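
        # `MLP` is defined under torch.testing._internal, whose files Dynamo's
        # skip rules would normally skip; the FRAME_COUNTER check below shows
        # that compiling the module still traces frames from it.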
        model = MLP(3)
        inp = torch.randn((2, 3))
        frame_count_before = torch._dynamo.convert_frame.FRAME_COUNTER
        model.compile(backend="eager")
        model(inp)
        frame_count_after = torch._dynamo.convert_frame.FRAME_COUNTER
        self.assertTrue(
            frame_count_after > frame_count_before, "MLP did not survive skip files"
        )


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()