import builtins
import collections
import copy
import functools
import inspect
import itertools
import math
import operator
import types
import warnings

from typing import cast, Dict, Optional, Set

try:
    import numpy as np
except ModuleNotFoundError:
    np = None


import torch
import torch._functorch.deprecated as deprecated_func
from torch.fx._symbolic_trace import is_fx_tracing

from . import config
from .external_utils import is_compiling
from .utils import is_safe_constant, NP_SUPPORTED_MODULES

"""
A note on allowed functions:

Dynamo consults this file to determine if a particular function/module
is allowed to appear as a node in its fx output.

If a function is disallowed, it may either be traced through or skipped.

Trace-through means Dynamo will continue to trace the interior code of
the function/module rather than stopping at its boundary and recording it
as a node in the fx graph. Whether traced through or allowed, the functionality
of the function/module becomes part of the Dynamo graph. Caveat: if tracing
through, any interior operation could trigger its own graph break.

Skips are determined by torch/_dynamo/skipfiles.py - see "a note on
skipfiles" there.
"""


def make_function_id_set(lazy_initializer):
    """
    Track a set of `id()`s of objects that are either allowed or not
    allowed to go into the generated FX graph. Used to test for torch.*,
    numpy.*, builtins.*, etc.

    Supports user modification to customize what can be added to the
    graph and what will cause a graph break.
    """

    class FunctionIdSet:
        function_ids: Optional[Set[int]] = None
        function_names: Optional[Dict[int, str]] = None

        def __call__(self):
            if self.function_ids is None:
                value = lazy_initializer()
                if isinstance(value, dict):
                    self.function_ids = set(value.keys())
                    self.function_names = value
                else:
                    assert isinstance(value, set)
                    self.function_ids = value
            return self.function_ids

        def get_name(self, idx: int, default: str):
            self()  # lazy init
            # names are only populated when the initializer returned a dict
            assert self.function_names is not None
            return self.function_names.get(idx, default)

        def add(self, idx: int):
            self()  # lazy init
            self.function_ids.add(idx)

        def remove(self, idx: int):
            if idx in self():
                self.function_ids.remove(idx)

        def __contains__(self, idx: int):
            return idx in self()

    return FunctionIdSet()
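

# A minimal usage sketch (illustrative, not exercised anywhere in this module):
# each set below is built with @make_function_id_set and behaves like a lazily
# populated set of ids that callers may also mutate, e.g.
#
#   _builtin_function_ids.add(id(my_callable))     # allow an extra callable
#   _builtin_function_ids.remove(id(my_callable))  # undo that
#   id(print) in _builtin_function_ids             # membership via __contains__
#
# (`my_callable` is a hypothetical user function, not something defined here.)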


@make_function_id_set
def _disallowed_function_ids():
    remove = [
        True,
        False,
        None,
        collections.OrderedDict,
        copy.copy,
        copy.deepcopy,
        inspect.signature,
        math.__package__,
        torch.__builtins__,
        torch.autocast_decrement_nesting,
        torch.autocast_increment_nesting,
        torch.autograd.grad,
        torch.clear_autocast_cache,
        torch.cuda.current_device,
        torch.cuda.set_device,
        torch.distributions.constraints.is_dependent,
        torch.distributions.normal.Normal,
        torch.inference_mode,
        torch.set_anomaly_enabled,
        torch.set_autocast_cache_enabled,
        torch.set_autocast_cpu_dtype,
        torch.set_autocast_cpu_enabled,
        torch.set_autocast_enabled,
        torch.set_autocast_gpu_dtype,
        warnings.warn,
        torch._C._dynamo.eval_frame.unsupported,
        torch.Tensor.__init__,
    ]
    if torch.distributed.is_available():
        from torch.distributed import _functional_collectives

        config.skipfiles_inline_module_allowlist.add(_functional_collectives)

    # extract all dtypes from torch
    dtypes = [
        obj for obj in torch.__dict__.values() if isinstance(obj, type(torch.float32))
    ]
    remove += dtypes
    storage = [
        obj
        for obj in torch.__dict__.values()
        if isinstance(obj, type(torch.FloatStorage))
    ]
    remove += storage

    # Distributed APIs don't work well with torch.compile.
    if torch.distributed.is_available():
        remove.extend(
            torch.distributed.distributed_c10d.dynamo_unsupported_distributed_c10d_ops
        )

    return {id(x) for x in remove}


@make_function_id_set
def _allowed_function_ids():
    """
    Walk torch.* and get the ids of all the stuff in it
    """
    warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
    torch_object_ids = dict()

    def _is_allowed_module_prefix(obj):
        allowed_modules = ("torch", "math")
        # torch.nn.modules.rnn is disallowed because these modules internally
        # flatten their parameters. This flattening process will call
        # Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd; so we need to graph-break. To ensure this, we inline
        # these functions, rather than keep them opaquely in the graph.
        disallowed_modules = (
            "torch.optim.",
            "torch.utils._foreach_utils",  # omit the period so we match all the functions in this module
            "torch.nn.modules.rnn.",
            "torch._dynamo.",
            "torch._C._dynamo.",
            "torch._inductor.",
            "torch._C.inductor.",
            "torch.fx.",
            "torch.distributed.fsdp.",
            "torch.distributed._tensor.",
        )
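        # For illustration (hedged): a function whose defining module is
        # torch.nn.functional passes this check, while one defined in
        # torch.fx._symbolic_trace or under torch._dynamo.* does not.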
        allowed_modules_dot = tuple([x + "." for x in allowed_modules])
        module = inspect.getmodule(obj)
        if module is None:
            return False

        mod_name = module.__name__

        if any(mod_name.startswith(m) for m in disallowed_modules):
            return False

        return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)

    def _find_torch_objects(module):
        if any(
            module.__name__.startswith(mod_name)
            for mod_name in config.allowed_functions_module_string_ignorelist
        ):
            return
        torch_object_ids[id(module)] = module.__name__
        for name, obj in list(module.__dict__.items()):
            if id(obj) not in torch_object_ids:
                # Dynamo allows all builtins into the graph and does not attempt
                # to introspect into them. We don't want to allow instances of
                # HigherOrderOperator into the graph all the time (Dynamo needs
                # to introspect the body functions of these HigherOrderOperators
                # first, decide they are safe, and then allow them into the graph).
                # So we exclude HigherOrderOperator from being a builtin.
                import torch._ops

                if isinstance(obj, torch._ops.HigherOrderOperator):
                    continue

                # We want to trace through `grad` and `vmap`
                if obj in (
                    torch.func.grad,
                    deprecated_func.grad,
                    torch.func.vmap,
                    deprecated_func.vmap,
                ):
                    continue

                if isinstance(obj, types.ModuleType):
                    if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
                        obj
                    ):
                        torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                        _find_torch_objects(obj)
                elif _is_allowed_module_prefix(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
                elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
                    torch_object_ids[id(obj)] = f"{module.__name__}.{name}"

    _find_torch_objects(torch)
    _find_torch_objects(math)

    # torch.Tensor.{fn}
    for name in dir(torch.Tensor):
        method = getattr(torch.Tensor, name)
        if isinstance(
            method, (types.MethodDescriptorType, types.WrapperDescriptorType)
        ):
            torch_object_ids[id(method)] = f"torch.Tensor.{name}"

    for idx in _disallowed_function_ids():
        if idx in torch_object_ids:
            del torch_object_ids[idx]

    for extra in (is_fx_tracing, is_compiling):
        torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}"

    return torch_object_ids


@make_function_id_set
def _allowed_user_defined_function_ids():
    rv = {}
    return rv


@make_function_id_set
def _builtin_function_ids():
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and callable(v)
    }
    rv.update(
        {
            id(v): f"operator.{k}"
            for k, v in operator.__dict__.items()
            if not k.startswith("_") and callable(v)
        }
    )
    rv.update(
        {id(v): f"itertools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
    )
    rv.update({id(cast): "typing.cast"})
    rv[id(functools.reduce)] = "functools.reduce"
    return rv


@make_function_id_set
def _numpy_function_ids():
    rv = dict()
    for mod in NP_SUPPORTED_MODULES:
        rv.update(
            {
                id(v): f"{mod.__name__}.{k}"
                for k, v in mod.__dict__.items()
                if callable(v)
                and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
            }
        )
    return rv


@make_function_id_set
def _builtin_constant_ids():
    """
    Collects constant builtins by eliminating callable items.
    """
    rv = {
        id(v): f"builtins.{k}"
        for k, v in builtins.__dict__.items()
        if not k.startswith("_") and not callable(v)
    }
    return rv


def is_allowed(obj):
    """Is this safe to trace like torch.add?"""
    # torch.ops is populated lazily, so its entries aren't necessarily in
    # _allowed_function_ids. Figure it out by testing the type instead
    # in those cases.
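    #
    # e.g. (hedged): `torch.ops.aten.add.default` is an OpOverload, so it is
    # treated as allowed by the isinstance() check below even if its id was
    # never registered in _allowed_function_ids.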
    if id(obj) in _disallowed_function_ids:
        return False

    return id(obj) in _allowed_function_ids or isinstance(
        obj,
        (torch._ops.OpOverloadPacket, torch._ops.OpOverload, torch._ops._OpNamespace),
    )


def is_user_defined_allowed(obj):
    return id(obj) in _allowed_user_defined_function_ids


def torch_get_name(obj, default):
    """Convert a torch.* function to a string"""
    return _allowed_function_ids.get_name(id(obj), default)


def is_builtin_callable(obj):
    return id(obj) in _builtin_function_ids


def is_builtin_constant(obj):
    return id(obj) in _builtin_constant_ids


def is_numpy(obj):
    if np is None:
        return False
    return isinstance(obj, np.ndarray) or id(obj) in _numpy_function_ids