import builtins
import collections
import copy
import functools
import inspect
import itertools
import math
import operator
import types
import warnings
from typing import cast, Dict, Optional, Set
try:
import numpy as np
except ModuleNotFoundError:
np = None
import torch
import torch._functorch.deprecated as deprecated_func
from torch.fx._symbolic_trace import is_fx_tracing
from . import config
from .external_utils import is_compiling
from .utils import is_safe_constant, NP_SUPPORTED_MODULES
"""
A note on allowed functions:
Dynamo consults this file to determine if a particular function/module
is allowed to appear as a node in its fx output.
If a function is disallowed, it may either be traced through or skipped.
Trace-through means dynamo will continue to trace the interior code for
the function/module rather than stopping at its boundary and recording it
as a node in the fx graph. Whether traced through or allowed, the
functionality of the function/module becomes part of the dynamo graph.
Caveat: if tracing through, any interior operation could trigger its own
graph break.
Skips are determined by torch/_dynamo/skipfiles.py - see the "note on
skipfiles" there.
"""
def make_function_id_set(lazy_initializer):
"""
    Track a set of `id()`s of objects that are either allowed or not
    allowed to go into the generated FX graph. Used to test for torch.*,
    numpy.*, builtins.*, etc.
    Supports user modification to customize what can be added to the
    graph and what will cause a graph break.
"""
class FunctionIdSet:
function_ids: Optional[Set[int]] = None
function_names: Optional[Dict[int, str]] = None
def __call__(self):
if self.function_ids is None:
value = lazy_initializer()
if isinstance(value, dict):
self.function_ids = set(value.keys())
self.function_names = value
else:
assert isinstance(value, set)
self.function_ids = value
return self.function_ids
def get_name(self, idx: int, default: str):
self() # lazy init
return self.function_names.get(idx, default)
def add(self, idx: int):
self() # lazy init
self.function_ids.add(idx)
def remove(self, idx: int):
if idx in self():
self.function_ids.remove(idx)
def __contains__(self, idx: int):
return idx in self()
return FunctionIdSet()
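# A minimal usage sketch (illustrative only; the real sets are built below).
# The lazy initializer may return either a dict mapping id -> display name or
# a plain set of ids, and callers may add/remove ids after initialization.
def _function_id_set_example():
    @make_function_id_set
    def _example_ids():
        return {id(torch.add): "torch.add"}

    assert id(torch.add) in _example_ids  # membership test triggers lazy init
    _example_ids.add(id(torch.sub))  # runtime customization
    return _example_ids.get_name(id(torch.add), "<unknown>")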
@make_function_id_set
def _disallowed_function_ids():
remove = [
True,
False,
None,
collections.OrderedDict,
copy.copy,
copy.deepcopy,
inspect.signature,
math.__package__,
torch.__builtins__,
torch.autocast_decrement_nesting,
torch.autocast_increment_nesting,
torch.autograd.grad,
torch.clear_autocast_cache,
torch.cuda.current_device,
torch.cuda.set_device,
torch.distributions.constraints.is_dependent,
torch.distributions.normal.Normal,
torch.inference_mode,
torch.set_anomaly_enabled,
torch.set_autocast_cache_enabled,
torch.set_autocast_cpu_dtype,
torch.set_autocast_cpu_enabled,
torch.set_autocast_enabled,
torch.set_autocast_gpu_dtype,
warnings.warn,
torch._C._dynamo.eval_frame.unsupported,
torch.Tensor.__init__,
]
if torch.distributed.is_available():
from torch.distributed import _functional_collectives
config.skipfiles_inline_module_allowlist.add(_functional_collectives)
# extract all dtypes from torch
dtypes = [
obj for obj in torch.__dict__.values() if isinstance(obj, type(torch.float32))
]
remove += dtypes
storage = [
obj
for obj in torch.__dict__.values()
if isinstance(obj, type(torch.FloatStorage))
]
remove += storage
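    # Note (hedged): type(torch.float32) is torch.dtype, and type(torch.FloatStorage)
    # is the legacy storage metaclass, so the comprehensions above sweep up every
    # dtype object and every legacy storage class exposed on the torch module.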
# Distributed APIs don't work well with torch.compile.
if torch.distributed.is_available():
remove.extend(
torch.distributed.distributed_c10d.dynamo_unsupported_distributed_c10d_ops
)
return {id(x) for x in remove}
@make_function_id_set
def _allowed_function_ids():
"""
    Walk torch.* and collect the ids of everything in it
"""
warnings.filterwarnings("ignore", category=UserWarning, module="torch.distributed")
torch_object_ids = dict()
def _is_allowed_module_prefix(obj):
allowed_modules = ("torch", "math")
# torch.nn.modules.rnn is disallowed because these modules internally
# flatten their parameters. This flattening process will call
# Tensor.set_ with a Storage, and Storages cannot be traced with
        # AOTAutograd, so we need to graph-break. To ensure this, we inline
        # these functions rather than keeping them opaquely in the graph.
disallowed_modules = (
"torch.optim.",
"torch.utils._foreach_utils", # omit the period so we match all the functions in this module
"torch.nn.modules.rnn.",
"torch._dynamo.",
"torch._C._dynamo.",
"torch._inductor.",
"torch._C.inductor.",
"torch.fx.",
"torch.distributed.fsdp.",
"torch.distributed._tensor.",
)
allowed_modules_dot = tuple([x + "." for x in allowed_modules])
module = inspect.getmodule(obj)
if module is None:
return False
mod_name = module.__name__
if any(mod_name.startswith(m) for m in disallowed_modules):
return False
return mod_name in allowed_modules or mod_name.startswith(allowed_modules_dot)
def _find_torch_objects(module):
if any(
module.__name__.startswith(mod_name)
for mod_name in config.allowed_functions_module_string_ignorelist
):
return
torch_object_ids[id(module)] = module.__name__
for name, obj in list(module.__dict__.items()):
if id(obj) not in torch_object_ids:
# Dynamo allows all builtins into the graph and does not attempt
# to introspect into them. We don't want to allow instances of
# HigherOrderOperator into the graph all the time (Dynamo needs
                # to introspect the body functions of these HigherOrderOperators
# first, decide they are safe, and then allow them into the graph).
# So we exclude HigherOrderOperator from being a builtin.
import torch._ops
if isinstance(obj, torch._ops.HigherOrderOperator):
continue
# We want to trace through `grad` and `vmap`
if obj in (
torch.func.grad,
deprecated_func.grad,
torch.func.vmap,
deprecated_func.vmap,
):
continue
if isinstance(obj, types.ModuleType):
if obj.__name__.startswith("torch.") and _is_allowed_module_prefix(
obj
):
torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
_find_torch_objects(obj)
elif _is_allowed_module_prefix(obj):
torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
elif inspect.getmodule(obj) is None and not is_safe_constant(obj):
torch_object_ids[id(obj)] = f"{module.__name__}.{name}"
_find_torch_objects(torch)
_find_torch_objects(math)
# torch.Tensor.{fn}
for name in dir(torch.Tensor):
method = getattr(torch.Tensor, name)
if isinstance(
method, (types.MethodDescriptorType, types.WrapperDescriptorType)
):
torch_object_ids[id(method)] = f"torch.Tensor.{name}"
for idx in _disallowed_function_ids():
if idx in torch_object_ids:
del torch_object_ids[idx]
for extra in (is_fx_tracing, is_compiling):
torch_object_ids[id(extra)] = f"{extra.__module__}.{extra.__name__}"
return torch_object_ids
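# Hedged sketch of what the walk above produces (exact contents depend on the
# torch build; torch.sin and torch.Tensor.add are typical entries):
def _allowed_function_ids_example():
    has_fn = id(torch.sin) in _allowed_function_ids
    has_method = id(torch.Tensor.add) in _allowed_function_ids
    return has_fn and has_method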
@make_function_id_set
def _allowed_user_defined_function_ids():
rv = {}
return rv
@make_function_id_set
def _builtin_function_ids():
rv = {
id(v): f"builtins.{k}"
for k, v in builtins.__dict__.items()
if not k.startswith("_") and callable(v)
}
rv.update(
{
id(v): f"operator.{k}"
for k, v in operator.__dict__.items()
if not k.startswith("_") and callable(v)
}
)
rv.update(
{id(v): f"functools.{v.__name__}" for v in (itertools.chain, itertools.islice)}
)
rv.update({id(cast): "typing.cast"})
rv[id(functools.reduce)] = "functools.reduce"
return rv
@make_function_id_set
def _numpy_function_ids():
rv = dict()
for mod in NP_SUPPORTED_MODULES:
rv.update(
{
id(v): f"{mod.__name__}.{k}"
for k, v in mod.__dict__.items()
if callable(v)
and (getattr(v, "__module__", None) or mod.__name__) == mod.__name__
}
)
return rv
@make_function_id_set
def _builtin_constant_ids():
"""
Collects constant builtins by eliminating callable items.
"""
rv = {
id(v): f"builtins.{k}"
for k, v in builtins.__dict__.items()
if not k.startswith("_") and not callable(v)
}
return rv
def is_allowed(obj):
"""Is this safe to trace like torch.add ?"""
# torch.ops is populated lazily so we don't necessarily have them in
# _allowed_function_ids. Figure it out by testing the type instead
# in those cases
if id(obj) in _disallowed_function_ids:
return False
return id(obj) in _allowed_function_ids or isinstance(
obj,
(torch._ops.OpOverloadPacket, torch._ops.OpOverload, torch._ops._OpNamespace),
)
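# Hedged examples of the policy above: plain torch.* functions and torch.ops
# namespaces/overloads are allowed, while anything listed in
# _disallowed_function_ids (e.g. torch.autograd.grad) is rejected.
def _is_allowed_example():
    assert is_allowed(torch.sin)
    assert is_allowed(torch.ops.aten.add)  # OpOverloadPacket, allowed by type
    assert not is_allowed(torch.autograd.grad)  # explicitly disallowed above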
def is_user_defined_allowed(obj):
return id(obj) in _allowed_user_defined_function_ids
def torch_get_name(obj, default):
"""Convert a torch.* function to a string"""
return _allowed_function_ids.get_name(id(obj), default)
def is_builtin_callable(obj):
return id(obj) in _builtin_function_ids
def is_builtin_constant(obj):
return id(obj) in _builtin_constant_ids
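# Hedged examples of the builtin classification above:
def _builtin_classification_example():
    assert is_builtin_callable(len)  # callable from builtins.*
    assert is_builtin_callable(operator.add)  # callable from operator.*
    assert is_builtin_constant(Ellipsis)  # non-callable builtin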
def is_numpy(obj):
if np is None:
return False
return isinstance(obj, np.ndarray) or id(obj) in _numpy_function_ids
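# Hedged example: is_numpy accepts both ndarray instances and callables from
# the numpy modules that Dynamo supports (NP_SUPPORTED_MODULES).
def _is_numpy_example():
    if np is None:
        return False
    return is_numpy(np.ones(3)) and is_numpy(np.sum)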