| import builtins |
| import copy |
| import functools |
| import inspect |
| import math |
| import os |
| import warnings |
| import collections |
| from itertools import chain |
| from types import CodeType, FunctionType, ModuleType |
| from typing import ( |
| Any, |
| Callable, |
| Dict, |
| List, |
| NamedTuple, |
| Optional, |
| Set, |
| Tuple, |
| Type, |
| Union, |
| ) |
| |
| import torch |
| import torch.utils._pytree as pytree |
| from torch._C import ScriptObject # type: ignore[attr-defined] |
| |
| from ._compatibility import compatibility |
| from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph |
| from .graph_module import GraphModule |
| from .node import Argument, base_types, map_aggregate |
| from .proxy import ParameterProxy, Proxy, TracerBase, Scope, ScopeContextManager |
| |
| HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS |
| |
| # These need to run in global scope to handle nested calls correctly |
| _orig_module_call: Callable = torch.nn.Module.__call__ |
| _orig_module_getattr: Callable = torch.nn.Module.__getattr__ |
| |
| _proxyable_classes: Dict[Type, None] = {} |
| |
| _is_fx_tracing_flag = False |
| |
| |
| def is_fx_tracing(): |
| return _is_fx_tracing_flag |
| |
| @compatibility(is_backward_compatible=True) |
| class ProxyableClassMeta(type): |
| """ |
| ProxyableClassMeta allows you to make construction of a given Python class |
| symbolically traceable. For example:: |
| |
| import torch |
| import torch.fx |
| |
| class TensorPair(metaclass=torch.fx.ProxyableClassMeta): |
| def __init__(self, left, right): |
| self.left, self.right = left, right |
| |
| def add(self, other): |
| l = self.left + other.left |
| r = self.right + other.right |
| return TensorPair(l, r) |
| |
| def mul(self, other): |
| l = self.left * other.left |
| r = self.right * other.right |
| return TensorPair(l, r) |
| |
| def use_tensor_pair_ctor(x : TensorPair, y : torch.Tensor): |
| s = x.add(TensorPair(y, y)) |
| return s.mul(x) |
| |
| x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) |
| y = torch.randn(5, 3) |
| ref_out = use_tensor_pair_ctor(x, y) |
| |
| traced = torch.fx.symbolic_trace(use_tensor_pair_ctor) |
| print(traced.code) |
| ''' |
| def forward(self, x : __main___TensorPair, y : torch.Tensor): |
| tensor_pair = __main___TensorPair(y, y); y = None |
| add = x.add(tensor_pair); tensor_pair = None |
| mul = add.mul(x); add = x = None |
| return mul |
| ''' |
| |
| From this example, we can see that construction of a class (``TensorPair``) |
| defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic |
| tracing. |
| """ |
| |
| def __init__(cls, name, bases, attrs): |
| _proxyable_classes.setdefault(cls) |
| super().__init__(name, bases, attrs) |
| |
| def __call__(cls, *args, **kwargs): |
| instance = cls.__new__(cls) # type: ignore[call-overload] |
| |
| if not is_fx_tracing(): |
| cls.__init__(instance, *args, **kwargs) # type: ignore[misc] |
| return instance |
| |
| found_proxies = [] |
| |
| def check_proxy(a): |
| if isinstance(a, Proxy): |
| found_proxies.append(a) |
| |
| map_aggregate(args, check_proxy) |
| map_aggregate(kwargs, check_proxy) |
| |
| if len(found_proxies) != 0: |
| tracer = found_proxies[0].tracer |
| return tracer.create_proxy("call_function", cls, args, kwargs) |
| else: |
| cls.__init__(instance, *args, **kwargs) # type: ignore[misc] |
| return instance |
| |
| |
| def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: |
| co = fn.__code__ |
| co_flags = co.co_flags & ~HAS_VARSTUFF |
| co_args: tuple |
| if hasattr(co, "co_qualname"): |
| # Python-3.11+ code signature |
| co_args = ( |
| nargs, |
| 0, |
| 0, |
| co.co_nlocals, |
| co.co_stacksize, |
| co_flags, |
| co.co_code, |
| co.co_consts, |
| co.co_names, |
| co.co_varnames, |
| co.co_filename, |
| co.co_name, |
| co.co_qualname, # type: ignore[attr-defined] |
| co.co_firstlineno, |
| co.co_lnotab, |
| co.co_exceptiontable, # type: ignore[attr-defined] |
| co.co_freevars, |
| co.co_cellvars, |
| ) |
| elif hasattr(co, "co_posonlyargcount"): |
| co_args = ( |
| nargs, |
| 0, |
| 0, |
| co.co_nlocals, |
| co.co_stacksize, |
| co_flags, |
| co.co_code, |
| co.co_consts, |
| co.co_names, |
| co.co_varnames, |
| co.co_filename, |
| co.co_name, |
| co.co_firstlineno, |
| co.co_lnotab, |
| co.co_freevars, |
| co.co_cellvars, |
| ) |
| else: |
| co_args = ( |
| nargs, |
| 0, |
| co.co_nlocals, |
| co.co_stacksize, |
| co_flags, |
| co.co_code, |
| co.co_consts, |
| co.co_names, |
| co.co_varnames, |
| co.co_filename, |
| co.co_name, |
| co.co_firstlineno, |
| co.co_lnotab, |
| co.co_freevars, |
| co.co_cellvars, |
| ) |
| new_code = CodeType(*co_args) # type: ignore[arg-type] |
| return FunctionType( |
| new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__ |
| ) |
| |
| # we need to insert placeholder nodes for *args and **kwargs |
| # we can't call this function normally, otherwise it would try to unpack them |
| # instead, let's make python think that args and kwargs are normal variables |
| |
| |
| @compatibility(is_backward_compatible=False) |
| class PHBase: |
| """ |
| Object representing an input placeholder to `concrete_args` |
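| |
| For example (mirroring the ``symbolic_trace`` example further below), ``PH`` marks |
| values inside ``concrete_args`` that should stay unspecialized; a minimal sketch:: |
| |
|     def f(x): |
|         return x['a'] + x['b'] |
| |
|     traced = torch.fx.symbolic_trace( |
|         f, concrete_args={'x': {'a': torch.fx.PH, 'b': torch.fx.PH}}) |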
| """ |
| |
| def __repr__(self): |
| return "PH" |
| |
| |
| PH = PHBase() |
| |
| |
| @compatibility(is_backward_compatible=False) |
| class PHWithMeta(PHBase): |
| """ |
| Object representing an input placeholder to `concrete_args`, carrying an |
| optional key that identifies the placeholder node during analysis |
| """ |
| def __init__(self, ph_key: Optional[str] = None): |
| super().__init__() |
| |
| # Provide a key for the user to identify the placeholder node during analysis |
| self.ph_key = ph_key |
| |
| |
| def _transfer_attrs(fr, to): |
| for attr_name in dir(fr): |
| attr_val = getattr(fr, attr_name) |
| if ( |
| not callable(attr_val) |
| and not attr_name.startswith("__") |
| and not hasattr(to, attr_name) |
| ): |
| setattr(to, attr_name, attr_val) |
| |
| |
| @compatibility(is_backward_compatible=True) |
| class Tracer(TracerBase): |
| # Reference: https://github.com/pytorch/pytorch/issues/54354 |
| # The first line of this docstring overrides the one Sphinx generates for the |
| # documentation. We need it so that Sphinx doesn't leak `math`'s path from the |
| # build environment (e.g. `<module 'math' from '/leaked/path'>`). |
| |
| """Tracer(autowrap_modules=(math,), autowrap_functions=()) |
| |
| ``Tracer`` is the class that implements the symbolic tracing functionality |
| of ``torch.fx.symbolic_trace``. A call to ``symbolic_trace(m)`` is equivalent |
| to ``Tracer().trace(m)``. |
| |
| Tracer can be subclassed to override various behaviors of the tracing |
| process. The different behaviors that can be overridden are described |
| in the docstrings of the methods on this class. |
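| |
| For example, a minimal sketch of using ``Tracer`` directly (``MyModule`` is a |
| hypothetical user-defined module) and wrapping the result in a ``GraphModule``:: |
| |
|     import torch |
|     import torch.fx |
| |
|     class MyModule(torch.nn.Module): |
|         def forward(self, x): |
|             return torch.relu(x) + 1 |
| |
|     tracer = torch.fx.Tracer() |
|     # ``trace`` returns a Graph; wrap it to obtain a runnable GraphModule |
|     graph = tracer.trace(MyModule()) |
|     gm = torch.fx.GraphModule(tracer.root, graph) |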
| """ |
| |
| # Not checking BC on this API because the default value for `autowrap_modules` |
| # includes the local filepath to the `math` module, which would jitter |
| # across machines. |
| @compatibility(is_backward_compatible=True) |
| def __init__( |
| self, |
| autowrap_modules: Tuple[ModuleType] = (math,), |
| autowrap_functions: Tuple[Callable, ...] = (), |
| param_shapes_constant: bool = False, |
| ) -> None: |
| # This method's signature is overridden by the first line of this class' |
| # docstring. If this method's signature is modified, the signature that |
| # overrides it also should be modified accordingly. |
| |
| """ |
| Construct a Tracer object. |
| |
| Args: |
| |
| autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`, |
| Python modules whose functions should be wrapped automatically |
| without needing to use fx.wrap(). Backward-compatibility for |
| this parameter is guaranteed. |
| |
| autowrap_functions (Tuple[Callable, ...]): defaults to `()`, |
| Python functions that should be wrapped automatically without |
| needing to use fx.wrap(). Backward compatibility for this |
| parameter is guaranteed. |
| |
| param_shapes_constant (bool): When this flag is set, calls to shape, |
| size and a few other shape like attributes of a module's parameter |
| will be evaluated directly, rather than returning a new Proxy value |
| for an attribute access. Backward compatibility for this parameter |
| is guaranteed. |
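| |
| A minimal sketch of ``autowrap_functions`` (``my_helper`` and ``fn`` are |
| hypothetical free functions defined at module level):: |
| |
|     import torch |
|     import torch.fx |
| |
|     def my_helper(x): |
|         return x + 1 |
| |
|     def fn(x): |
|         return my_helper(x) * 2 |
| |
|     # ``my_helper`` is preserved as a single call_function node instead of |
|     # being traced through |
|     tracer = torch.fx.Tracer(autowrap_functions=(my_helper,)) |
|     graph = tracer.trace(fn) |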
| """ |
| |
| super().__init__() |
| |
| # Functions we will eagerly wrap when we see them while tracing |
| # this captures both `math.sqrt()` and `from math import sqrt` automatically |
| self._autowrap_function_ids: Set[int] = { |
| id(value) |
| for name, value in chain(*[m.__dict__.items() for m in autowrap_modules]) |
| if not name.startswith("_") and callable(value) |
| } |
| self._autowrap_function_ids.update({id(f) for f in autowrap_functions}) |
| |
| # Python modules to apply autowrap to at the start, in addition to |
| # modules we see while tracing |
| self._autowrap_search: List[ModuleType] = list(autowrap_modules) |
| self.param_shapes_constant = param_shapes_constant |
| |
| self.submodule_paths: Optional[Dict[torch.nn.Module, str]] = None |
| self.root_module_name: str = "" |
| # Maps the containing module's name to the operator name |
| self.scope = Scope("", None) |
| # Records the module call stack |
| self.module_stack = collections.OrderedDict() |
| # Mapping of node name to module scope |
| self.node_name_to_scope: Dict[str, Tuple[str, type]] = {} |
| |
| @compatibility(is_backward_compatible=True) |
| def create_arg(self, a: Any) -> "Argument": |
| """ |
| A method to specify the behavior of tracing when preparing values to |
| be used as arguments to nodes in the ``Graph``. |
| |
| By default, the behavior includes: |
| |
| #. Iterate through collection types (e.g. tuple, list, dict) and recursively |
| call ``create_arg`` on the elements. |
| #. Given a Proxy object, return a reference to the underlying IR ``Node`` |
| #. Given a non-Proxy Tensor object, emit IR for various cases: |
| |
| * For a Parameter, emit a ``get_attr`` node referring to that Parameter |
| * For a non-Parameter Tensor, store the Tensor in a special attribute |
| on the root Module and emit a ``get_attr`` node referring to it. |
| |
| This method can be overridden to support more types. |
| |
| Args: |
| |
| a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. |
| |
| |
| Returns: |
| |
| The value ``a`` converted into the appropriate ``Argument`` |
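| |
| As a sketch, a subclass might lower an additional (hypothetical) wrapper type |
| ``MyConstant`` before falling back to the default handling:: |
| |
|     class MyTracer(torch.fx.Tracer): |
|         def create_arg(self, a): |
|             # Hypothetical: unwrap a user-defined constant holder |
|             if isinstance(a, MyConstant): |
|                 return super().create_arg(a.value) |
|             return super().create_arg(a) |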
| """ |
| # The base tracer is used to construct Graphs when there is no associated |
| # module hierarchy, so it can never create parameter references. |
| # The default tracer adds the ability to refer to parameters when |
| # tracing modules. |
| if isinstance(a, torch.nn.Parameter): |
| for n, p in self.root.named_parameters(): |
| if a is p: |
| return self.create_node("get_attr", n, (), {}) |
| raise NameError("parameter is not a member of this module") |
| elif isinstance(a, torch.Tensor): |
| for n_, p_ in self.root.named_buffers(): |
| if a is p_: |
| return self.create_node("get_attr", n_, (), {}) |
| elif isinstance(a, torch.nn.Module): |
| for n_, p_ in self.root.named_modules(): |
| if a is p_: |
| return self.create_node("get_attr", n_, (), {}) |
| # For NamedTuple instances that appear literally as args, we emit |
| # a node to construct the NamedTuple and use that Node as the argument. |
| if isinstance(a, tuple) and hasattr(a, "_fields"): |
| args = tuple(self.create_arg(elem) for elem in a) |
| return self.create_node("call_function", a.__class__, args, {}) |
| |
| # Tensors do not have a reliable string repr() from which they can be |
| # constructed (and we probably don't want to rely on that, either), so |
| # for any constant Tensor values we encounter, first search for if they |
| # are an attribute of some module in the module hierarchy. If so, emit |
| # a get_attr to retrieve that tensor. Otherwise, we'll store away the |
| # tensor value into a special attribute on the Module s.t. we can |
| # retrieve it with a get_attr. |
| if isinstance(a, (torch.Tensor, ScriptObject)): |
| qualname: Optional[str] = self.tensor_attrs.get(a) |
| |
| # Tensor was not found in the Module hierarchy, stow it away in a |
| # special attribute and set the qualname to refer to that |
| if not qualname: |
| i = 0 |
| while True: |
| qualname = f"_tensor_constant{i}" |
| if not hasattr(self.root, qualname): |
| break |
| i += 1 |
| self.tensor_attrs[a] = qualname |
| setattr(self.root, qualname, a) |
| |
| return self.create_node("get_attr", qualname, (), {}) |
| |
| if type(a) in _proxyable_classes: |
| # This is an instance of a proxyable class for which we did not |
| # witness its construction. Intern this as a constant attribute |
| |
| # TODO: binary search |
| i = 0 |
| while True: |
| qualname = f"_{a.__class__.__name__}_constant_{i}" |
| if not hasattr(self.root, qualname): |
| break |
| i += 1 |
| setattr(self.root, qualname, a) |
| |
| return self.create_node("get_attr", qualname, (), {}) |
| |
| return super().create_arg(a) |
| |
| @compatibility(is_backward_compatible=True) |
| def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: |
| """ |
| A method to specify whether a given ``nn.Module`` is a "leaf" module. |
| |
| Leaf modules are the atomic units that appear in |
| the IR, referenced by ``call_module`` calls. By default, |
| Modules in the PyTorch standard library namespace (torch.nn) |
| are leaf modules. All other modules are traced through and |
| their constituent ops are recorded, unless specified otherwise |
| via this method. |
| |
| Args: |
| |
| m (Module): The module being queried about |
| module_qualified_name (str): The path of this module relative to the root. For example, |
| if you have a module hierarchy where submodule ``foo`` contains |
| submodule ``bar``, which contains submodule ``baz``, that module will |
| appear with the qualified name ``foo.bar.baz`` here. |
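| |
| For example, a minimal sketch that additionally treats a hypothetical module |
| class ``MyLeafModule`` as a leaf:: |
| |
|     class MyTracer(torch.fx.Tracer): |
|         def is_leaf_module(self, m, module_qualified_name): |
|             if isinstance(m, MyLeafModule): |
|                 return True |
|             return super().is_leaf_module(m, module_qualified_name) |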
| """ |
| return ( |
| (m.__module__.startswith("torch.nn") or m.__module__.startswith("torch.ao.nn")) |
| and not isinstance(m, torch.nn.Sequential) |
| ) |
| |
| @compatibility(is_backward_compatible=True) |
| def path_of_module(self, mod: torch.nn.Module) -> str: |
| """ |
| Helper method to find the qualified name of ``mod`` in the Module hierarchy |
| of ``root``. For example, if ``root`` has a submodule named ``foo``, which has |
| a submodule named ``bar``, passing ``bar`` into this function will return |
| the string "foo.bar". |
| |
| Args: |
| |
| mod (Module): The ``Module`` to retrieve the qualified name for. |
| """ |
| # Prefer the O(1) algorithm |
| if self.submodule_paths: |
| path = self.submodule_paths.get(mod) |
| if path is None: |
| raise NameError("module is not installed as a submodule") |
| assert isinstance(path, str) |
| return path |
| # O(N^2) fallback in the case that we didn't store the submodule |
| # paths. |
| else: |
| for n, p in self.root.named_modules(): |
| if mod is p: |
| return n |
| raise NameError("module is not installed as a submodule") |
| |
| @compatibility(is_backward_compatible=True) |
| def call_module( |
| self, |
| m: torch.nn.Module, |
| forward: Callable[..., Any], |
| args: Tuple[Any, ...], |
| kwargs: Dict[str, Any], |
| ) -> Any: |
| """ |
| Method that specifies the behavior of this ``Tracer`` when it encounters |
| a call to an ``nn.Module`` instance. |
| |
| By default, the behavior is to check if the called module is a leaf module |
| via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to |
| ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through |
| the operations in its ``forward`` function. |
| |
| This method can be overridden to, for example, create nested traced |
| GraphModules, or to implement any other behavior you would want while |
| tracing across ``Module`` boundaries. |
| |
| Args: |
| |
| m (Module): The module for which a call is being emitted |
| forward (Callable): The forward() method of the ``Module`` to be invoked |
| args (Tuple): args of the module callsite |
| kwargs (Dict): kwargs of the module callsite |
| |
| Return: |
| |
| The return value from the Module call. In the case that a ``call_module`` |
| node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever |
| value was returned from the ``Module`` invocation. |
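| |
| For example, a minimal sketch that always emits a ``call_module`` node for |
| submodules whose qualified name starts with a hypothetical prefix ``"blocks."``:: |
| |
|     class MyTracer(torch.fx.Tracer): |
|         def call_module(self, m, forward, args, kwargs): |
|             if self.path_of_module(m).startswith("blocks."): |
|                 return self.create_proxy( |
|                     "call_module", self.path_of_module(m), args, kwargs) |
|             return super().call_module(m, forward, args, kwargs) |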
| """ |
| module_qualified_name = self.path_of_module(m) |
| with ScopeContextManager(self.scope, Scope(module_qualified_name, type(m))) as _scope: |
| # module_stack is an ordered dict so writing then deleting the |
| # entry is equivalent to push/pop on a list |
| self.module_stack[_scope.module_path] = (module_qualified_name, _scope.module_type) |
| if not self.is_leaf_module(m, module_qualified_name): |
| ret_val = forward(*args, **kwargs) |
| else: |
| ret_val = self.create_proxy("call_module", module_qualified_name, args, kwargs) |
| key, _ = self.module_stack.popitem(last=True) |
| assert key == _scope.module_path, f"Unexpected key {key}" |
| |
| return ret_val |
| |
| @compatibility(is_backward_compatible=False) |
| def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: Dict[str, Any]): |
| """ |
| Method that specifies the behavior of this ``Tracer`` when we call getattr |
| on a call to an ``nn.Module`` instance. |
| |
| By default, the behavior is to return a proxy value for the attribute. It |
| also stores the proxy value in the ``parameter_proxy_cache``, so that future |
| calls will reuse the proxy rather than creating a new one. |
| |
| This method can be overridden to, for example, not return proxies when |
| querying parameters. |
| |
| Args: |
| |
| attr (str): The name of the attribute being queried |
| attr_val (Any): The value of the attribute |
| parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies |
| |
| Return: |
| |
| The return value from the getattr call. |
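| |
| For example, a minimal sketch that skips proxying for parameters, returning |
| the attribute value directly:: |
| |
|     class MyTracer(torch.fx.Tracer): |
|         def getattr(self, attr, attr_val, parameter_proxy_cache): |
|             if isinstance(attr_val, torch.nn.Parameter): |
|                 return attr_val |
|             return super().getattr(attr, attr_val, parameter_proxy_cache) |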
| """ |
| def maybe_get_proxy_for_attr( |
| attr_val, collection_to_search, parameter_proxy_cache |
| ): |
| for n, p in collection_to_search: |
| if attr_val is p: |
| if n not in parameter_proxy_cache: |
| kwargs = {} |
| if ( |
| "proxy_factory_fn" |
| in inspect.signature(self.create_proxy).parameters |
| ): |
| kwargs["proxy_factory_fn"] = ( |
| None |
| if not self.param_shapes_constant |
| else lambda node: ParameterProxy( |
| self, node, n, attr_val |
| ) |
| ) |
| val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] |
| parameter_proxy_cache[n] = val_proxy |
| return parameter_proxy_cache[n] |
| return None |
| |
| if isinstance(attr_val, torch.nn.Parameter): |
| maybe_parameter_proxy = maybe_get_proxy_for_attr( |
| attr_val, self.root.named_parameters(), parameter_proxy_cache |
| ) |
| if maybe_parameter_proxy is not None: |
| return maybe_parameter_proxy |
| |
| if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): |
| maybe_buffer_proxy = maybe_get_proxy_for_attr( |
| attr_val, self.root.named_buffers(), parameter_proxy_cache |
| ) |
| if maybe_buffer_proxy is not None: |
| return maybe_buffer_proxy |
| |
| return attr_val |
| |
| # This method will be refactored |
| @compatibility(is_backward_compatible=False) |
| def create_args_for_root(self, root_fn, is_module, concrete_args=None): |
| """ |
| Create ``placeholder`` nodes corresponding to the signature of the ``root`` |
| Module. This method introspects root's signature and emits those |
| nodes accordingly, also supporting ``*args`` and ``**kwargs``. |
| """ |
| # In some cases, a function or method has been decorated with a wrapper |
| # defined via ``functools.wraps``. In this case, the outer code object |
| # will likely not contain the actual parameters we care about, so unwrap |
| # the function to get to the innermost callable. |
| fn_for_analysis = inspect.unwrap(root_fn) |
| co = fn_for_analysis.__code__ |
| total_args = co.co_argcount + co.co_kwonlyargcount |
| orig_args = list(co.co_varnames) |
| names_iter = iter(co.co_varnames) |
| args: List[Any] = [] |
| skip_arg_idx = 0 |
| if is_module: |
| if total_args == 0: |
| raise RuntimeError( |
| "``self`` argument cannot be part of *args expansion!" |
| ) |
| skip_arg_idx = 1 |
| next(names_iter) # skip self |
| args.append(self.root) |
| |
| sig = inspect.signature(fn_for_analysis) |
| |
| def proxy_placeholder(name: str): |
| if concrete_args is not None and name in concrete_args: |
| cnt = 0 |
| |
| def replace_ph(x): |
| nonlocal cnt |
| cnt += 1 |
| param = sig.parameters[name] |
| default = ( |
| () |
| if param.default is inspect.Parameter.empty |
| else (param.default,) |
| ) |
| out = self.create_proxy( |
| "placeholder", f"{name}_{str(cnt)}", default, {} |
| ) |
| if isinstance(x, PHBase): |
| if x != PH: |
| # Transfer attrs in the case where you're using a placeholder other |
| # than the singleton PH (PH has no attributes to transfer). |
| # Proxies were created out of the placeholders. |
| # Transfer any metadata (put on the placeholders in the form of |
| # attributes set by the user) from the placeholder to the |
| # underlying nodes (the proxy is unwrapped by the user, but |
| # the metadata should hold). |
| _transfer_attrs(fr=x, to=out.node) |
| |
| return out |
| # Union[int, bool] == bool in Python <= 3.6 |
| if ( |
| type(x) == bool |
| or type(x) in base_types |
| and type(x) != torch.Tensor |
| ): |
| torch._assert( |
| out == x, |
| f"{name} has been specialized to have value {x} but got another value", |
| ) |
| elif type(x) == type(None): |
| args = ( |
| out, |
| f"{name} has been specialized to have value None but got another value", |
| ) |
| self.create_proxy("call_function", _assert_is_none, args, {}) |
| else: |
| warnings.warn( |
| f"Was not able to add assertion to guarantee correct input {name} to " |
| f"specialized function. It is up to the user to make sure that your inputs match the " |
| f"inputs you specialized the function with." |
| ) |
| |
| return x |
| |
| return pytree.tree_map(replace_ph, concrete_args[name]) |
| if name[0] == "*": |
| default = () |
| else: |
| param = sig.parameters[name] |
| default = () if param.default is inspect.Parameter.empty else (param.default,) # type: ignore[assignment] |
| return self.create_proxy( |
| "placeholder", |
| name, |
| default, |
| {}, |
| type_expr=fn_for_analysis.__annotations__.get(name, None) |
| ) |
| |
| # This covers the very specific case where we are passing in flat |
| # concrete_args as a tuple, but our traced fn takes (*args, **kwargs). |
| # In this case, just take the concrete_args and pass them through. |
| name_idx = 0 |
| if isinstance(concrete_args, tuple) and \ |
| len(concrete_args) > 0 and \ |
| (co.co_flags & HAS_VARSTUFF) and \ |
| total_args == 1: |
| for concrete_arg in concrete_args: |
| out = self.create_proxy("placeholder", f"input_{name_idx}", (), {}) |
| if isinstance(concrete_arg, PHBase): |
| if concrete_arg != PH: |
| # Transfer attrs in the case where you're using a placeholder other |
| # than the singleton PH (PH has no attributes to transfer). |
| # Proxies were created out of the placeholders. |
| # Transfer any metadata (put on the placeholders in the form of |
| # attributes set by the user) from the placeholder to the |
| # underlying nodes (the proxy is unwrapped by the user, but |
| # the metadata should hold). |
| _transfer_attrs(fr=concrete_arg, to=out.node) |
| args.append(out) |
| name_idx += 1 |
| return root_fn, args |
| |
| arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] |
| if isinstance(concrete_args, tuple): |
| if len(arg_names) != len(concrete_args): |
| raise RuntimeError( |
| f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments" |
| ) |
| concrete_args = dict(zip(arg_names, concrete_args)) |
| args.extend(proxy_placeholder(names) for names in arg_names) |
| |
| if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: |
| # TODO: type annotations for *args and **kwargs |
| if co.co_flags & inspect.CO_VARARGS: |
| args.append(proxy_placeholder("*" + next(names_iter))) |
| if co.co_flags & inspect.CO_VARKEYWORDS: |
| args.append(proxy_placeholder("**" + next(names_iter))) |
| root_fn = _patch_function(root_fn, len(args)) |
| |
| flat_args, in_spec = pytree.tree_flatten(tuple(args)) |
| if any(not isinstance(i, pytree.LeafSpec) for i in in_spec.children_specs): |
| # In the case that we have pytree-flattened inputs in |
| # `concrete_args`, generate a flattening wrapper around the |
| # original root function and return that. |
| self.graph._codegen = _PyTreeCodeGen( |
| _PyTreeInfo(orig_args[:total_args], in_spec, None) |
| ) |
| |
| def flatten_fn(*args): |
| tree_args = pytree.tree_unflatten(list(args), in_spec) |
| tree_out = root_fn(*tree_args) |
| out_args, out_spec = pytree.tree_flatten(tree_out) |
| assert isinstance(self.graph._codegen, _PyTreeCodeGen) |
| self.graph._codegen.pytree_info = ( |
| self.graph._codegen.pytree_info._replace(out_spec=out_spec) |
| ) |
| return out_args |
| |
| return flatten_fn, flat_args |
| return root_fn, args |
| |
| @compatibility(is_backward_compatible=True) |
| def trace( |
| self, |
| root: Union[torch.nn.Module, Callable[..., Any]], |
| concrete_args: Optional[Dict[str, Any]] = None, |
| ) -> Graph: |
| """ |
| Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` |
| can either be an ``nn.Module`` instance or a Python callable. |
| |
| Note that after this call, ``self.root`` may be different from the ``root`` passed |
| in here. For example, when a free function is passed to ``trace()``, we will |
| create an ``nn.Module`` instance to use as the root and add embedded constants |
| to. |
| |
| |
| Args: |
| |
| root (Union[Module, Callable]): Either a ``Module`` or a function to be |
| traced through. Backwards-compatibility for this parameter is |
| guaranteed. |
| concrete_args (Optional[Dict[str, any]]): Concrete arguments that should |
| not be treated as Proxies. This parameter is experimental and |
| its backwards-compatibility is *NOT* guaranteed. |
| |
| Returns: |
| |
| A ``Graph`` representing the semantics of the passed-in ``root``. |
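| |
| For example, a minimal sketch of tracing a free function and inspecting the |
| resulting ``Graph``:: |
| |
|     def f(x, y): |
|         return x + y |
| |
|     tracer = torch.fx.Tracer() |
|     graph = tracer.trace(f) |
|     print(graph) |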
| """ |
| global _is_fx_tracing_flag |
| old_is_fx_tracing_flag = _is_fx_tracing_flag |
| _is_fx_tracing_flag = True |
| try: |
| if isinstance(root, torch.nn.Module): |
| self.root = root |
| |
| assert hasattr( |
| type(root), self.traced_func_name |
| ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" |
| |
| fn = getattr(type(root), self.traced_func_name) |
| self.root_module_name = root._get_name() |
| self.submodule_paths = {mod: name for name, mod in root.named_modules()} |
| else: |
| self.root = torch.nn.Module() |
| fn = root |
| |
| tracer_cls: Optional[Type[Tracer]] = getattr(self, "__class__", None) |
| self.graph = Graph(tracer_cls=tracer_cls) |
| if hasattr(fn, '__code__'): |
| code = fn.__code__ |
| self.graph._co_fields = { |
| 'co_name': code.co_name, |
| 'co_filename': code.co_filename, |
| 'co_firstlineno': code.co_firstlineno, |
| } |
| |
| # When we encounter a Tensor value that's not a parameter, we look if it |
| # is some other attribute on the model. Construct a dict mapping Tensor |
| # values to the qualified name here for efficiency. This is used downstream |
| # in create_arg |
| self.tensor_attrs: Dict[Union[torch.Tensor, ScriptObject], str] = {} |
| |
| def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: List[str]): |
| for k, v in m.__dict__.items(): |
| if isinstance(v, (torch.Tensor, ScriptObject)): |
| self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) |
| for k, v in m.named_children(): |
| collect_tensor_attrs(v, prefix_atoms + [k]) |
| |
| collect_tensor_attrs(self.root, []) |
| |
| assert isinstance(fn, FunctionType) |
| |
| fn_globals = fn.__globals__ # run before it gets patched |
| fn, args = self.create_args_for_root( |
| fn, isinstance(root, torch.nn.Module), concrete_args |
| ) |
| |
| parameter_proxy_cache: Dict[ |
| str, Proxy |
| ] = {} # Reduce number of get_attr calls |
| |
| # Method dispatch on parameters is not recorded unless it's directly used. |
| # Thus, we need to insert a proxy when __getattr__ requests a parameter. |
| @functools.wraps(_orig_module_getattr) |
| def module_getattr_wrapper(mod, attr): |
| attr_val = _orig_module_getattr(mod, attr) |
| return self.getattr(attr, attr_val, parameter_proxy_cache) |
| |
| @functools.wraps(_orig_module_call) |
| def module_call_wrapper(mod, *args, **kwargs): |
| def forward(*args, **kwargs): |
| return _orig_module_call(mod, *args, **kwargs) |
| |
| _autowrap_check( |
| patcher, |
| getattr(getattr(mod, "forward", mod), "__globals__", {}), |
| self._autowrap_function_ids, |
| ) |
| return self.call_module(mod, forward, args, kwargs) |
| |
| with _Patcher() as patcher: |
| # allow duplicate patches to support the case of nested calls |
| patcher.patch_method( |
| torch.nn.Module, |
| "__getattr__", |
| module_getattr_wrapper, |
| deduplicate=False, |
| ) |
| patcher.patch_method( |
| torch.nn.Module, "__call__", module_call_wrapper, deduplicate=False |
| ) |
| _patch_wrapped_functions(patcher) |
| _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) |
| for module in self._autowrap_search: |
| _autowrap_check( |
| patcher, module.__dict__, self._autowrap_function_ids |
| ) |
| self.create_node( |
| "output", |
| "output", |
| (self.create_arg(fn(*args)),), |
| {}, |
| type_expr=fn.__annotations__.get("return", None), |
| ) |
| |
| self.submodule_paths = None |
| finally: |
| _is_fx_tracing_flag = old_is_fx_tracing_flag |
| return self.graph |
| |
| def __deepcopy__(self, memo): |
| # _autowrap_search contains modules, which cannot be deepcopied. |
| new_tracer = Tracer.__new__(Tracer) |
| |
| for k, v in self.__dict__.items(): |
| if k in {'_autowrap_search'}: |
| new_obj = copy.copy(v) |
| else: |
| new_obj = copy.deepcopy(v, memo) |
| |
| new_tracer.__dict__[k] = new_obj |
| |
| return new_tracer |
| |
| |
| # Dictionary of (id(globals dict), function name) => globals_dict to patch for |
| # the purposes of the wrap() API. |
| # We key by the globals dict id and function name to ensure we're wrapping a given |
| # function only once. |
| _wrapped_fns_to_patch: Dict[Tuple[int, str], dict] = {} |
| |
| # List of methods on classes to wrap (class type, function name) |
| # this currently only works for Tensor.* methods that aren't traced properly |
| _wrapped_methods_to_patch: List[Tuple[type, str]] = [] |
| |
| if os.environ.get("FX_PATCH_GETITEM") == "1": |
| # This change is needed to trace models like PositionalEmbedding from BERT: |
| # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py |
| # but causes issues in quantization documented here: |
| # https://github.com/pytorch/pytorch/issues/50710 |
| # once that is fixed we can make this the default behavior. |
| _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) |
| |
| |
| def _find_proxy(*objects_to_search): |
| """ |
| Recursively search a data structure for a Proxy() and return it, |
| returning None if not found. |
| """ |
| proxy = None |
| |
| def find_proxy(x): |
| nonlocal proxy |
| if isinstance(x, Proxy): |
| proxy = x |
| |
| map_aggregate(objects_to_search, find_proxy) |
| return proxy |
| |
| |
| def _create_wrapped_func(orig_fn): |
| @functools.wraps(orig_fn) |
| def wrapped(*args, **kwargs): |
| """ |
| Given the closed-over ``orig_fn`` to invoke, search the args and kwargs for |
| a Proxy object. If there is one, emit a ``call_function`` node to preserve the |
| call to this leaf function directly. Otherwise, just return the results of |
| this function call, as this function is not being traced. |
| """ |
| proxy = _find_proxy(args, kwargs) |
| if proxy is not None: |
| return_proxy = proxy.tracer.create_proxy( |
| "call_function", orig_fn, args, kwargs |
| ) |
| return_proxy.node.meta["is_wrapped"] = True |
| return return_proxy |
| return orig_fn(*args, **kwargs) |
| |
| return wrapped |
| |
| |
| def _create_wrapped_method(cls, name): |
| orig_fn = getattr(cls, name) |
| |
| @functools.wraps(orig_fn) |
| def wrapped(*args, **kwargs): |
| """ |
| Search the args and kwargs for a Proxy object. If there is one, |
| emit a ``call_method`` node to preserve the call to this method |
| directly. Otherwise, just return the results of this function |
| call, as this function is not being traced. |
| """ |
| proxy = _find_proxy(args, kwargs) |
| if proxy is not None: |
| return proxy.tracer.create_proxy("call_method", name, args, kwargs) |
| return orig_fn(*args, **kwargs) |
| |
| return wrapped |
| |
| |
| class _PatchedFn(NamedTuple): |
| frame_dict: Any |
| fn_name: str |
| orig_fn: Any |
| |
| def revert(self): |
| raise NotImplementedError() |
| |
| |
| class _PatchedFnSetItem(_PatchedFn): |
| def revert(self): |
| self.frame_dict[self.fn_name] = self.orig_fn |
| |
| |
| class _PatchedFnDel(_PatchedFn): |
| def revert(self): |
| del self.frame_dict[self.fn_name] |
| |
| |
| class _PatchedFnSetAttr(_PatchedFn): |
| def revert(self): |
| setattr(self.frame_dict, self.fn_name, self.orig_fn) |
| |
| |
| class _Patcher: |
| def __init__(self): |
| super().__init__() |
| self.patches_made: List[_PatchedFn] = [] |
| self.visited: Set[int] = set() |
| |
| def patch( |
| self, |
| frame_dict: Dict[str, Any], |
| name: str, |
| new_fn: Callable, |
| deduplicate: bool = True, |
| ): |
| """ |
| Replace frame_dict[name] with new_fn until we exit the context manager. |
| """ |
| new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] |
| if name not in frame_dict and hasattr(builtins, name): |
| self.patches_made.append(_PatchedFnDel(frame_dict, name, None)) |
| elif getattr(frame_dict[name], "__fx_already_patched", False): |
| return # already patched, no need to do it again |
| else: |
| self.patches_made.append( |
| _PatchedFnSetItem(frame_dict, name, frame_dict[name]) |
| ) |
| frame_dict[name] = new_fn |
| |
| def patch_method( |
| self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True |
| ): |
| """ |
| Replace cls.name with new_fn until we exit the context manager. |
| """ |
| new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] |
| orig_fn = getattr(cls, name) |
| if getattr(orig_fn, "__fx_already_patched", False): |
| return # already patched, no need to do it again |
| self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn)) |
| setattr(cls, name, new_fn) |
| |
| def visit_once(self, thing: Any): |
| """Return True on the first call to with thing, otherwise false""" |
| idx = id(thing) |
| if idx in self.visited: |
| return False |
| self.visited.add(idx) |
| return True |
| |
| def __enter__(self): |
| return self |
| |
| def __exit__(self, exc_type, exc_val, exc_tb): |
| """ |
| Undo all the changes made via self.patch() and self.patch_method() |
| """ |
| while self.patches_made: |
| # unpatch in reverse order to handle duplicates correctly |
| self.patches_made.pop().revert() |
| self.visited.clear() |
| |
| |
| def _patch_wrapped_functions(patcher: _Patcher): |
| """ |
| Go through ``_wrapped_fns_to_patch`` and, for each globals dict, wrap |
| the listed global functions in the ``_create_wrapped_func`` wrapper. |
| """ |
| for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items(): |
| if name not in frame_dict and hasattr(builtins, name): |
| orig_fn = getattr(builtins, name) |
| else: |
| orig_fn = frame_dict[name] |
| patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) |
| |
| for cls, name in _wrapped_methods_to_patch: |
| patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) |
| |
| |
| def _autowrap_check( |
| patcher: _Patcher, frame_dict: Dict[str, Any], function_ids: Set[int] |
| ): |
| """ |
| Some functions, like `math.sqrt`, are common enough that we want to automatically wrap them as we see them. |
| This method searches a scope for them and patches them if found. |
| """ |
| if patcher.visit_once(frame_dict): |
| for name, value in frame_dict.items(): |
| if ( |
| not name.startswith("_") |
| and callable(value) |
| and id(value) in function_ids |
| ): |
| patcher.patch(frame_dict, name, _create_wrapped_func(value)) |
| |
| |
| @compatibility(is_backward_compatible=True) |
| def wrap(fn_or_name: Union[str, Callable]): |
| """ |
| This function can be called at module-level scope to register fn_or_name as a "leaf function". |
| A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being |
| traced through:: |
| |
| # foo/bar/baz.py |
| def my_custom_function(x, y): |
| return x * x + y * y |
| |
| torch.fx.wrap('my_custom_function') |
| |
| def fn_to_be_traced(x, y): |
| # When symbolic tracing, the below call to my_custom_function will be inserted into |
| # the graph rather than tracing it. |
| return my_custom_function(x, y) |
| |
| This function can also equivalently be used as a decorator:: |
| |
| # foo/bar/baz.py |
| @torch.fx.wrap |
| def my_custom_function(x, y): |
| return x * x + y * y |
| |
| A wrapped function can be thought of as a "leaf function", analogous to the concept of |
| "leaf modules", that is, they are functions that are left as calls in the FX trace |
| rather than traced through. |
| |
| Args: |
| |
| fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the |
| graph when it's called |
| """ |
| if not callable(fn_or_name) and not isinstance(fn_or_name, str): |
| raise RuntimeError( |
| "Unsupported type for global function! Must be either a callable or " |
| "string name" |
| ) |
| |
| if callable(fn_or_name): |
| assert not isinstance(fn_or_name, str) # to make mypy happy |
| fn_name = fn_or_name.__name__ |
| else: |
| assert isinstance( |
| fn_or_name, str |
| ), "fn_or_name must be a global function or string name" |
| fn_name = fn_or_name |
| |
| currentframe = inspect.currentframe() |
| assert currentframe is not None |
| f = currentframe.f_back |
| assert f is not None |
| if f.f_code.co_name != "<module>": |
| raise NotImplementedError("wrap must be called at the top level of a module") |
| |
| # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search |
| # semantics would be slightly different, but would add support for `from x import wrapped_function` |
| _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals |
| return fn_or_name |
| |
| |
| @compatibility(is_backward_compatible=True) |
| def symbolic_trace( |
| root: Union[torch.nn.Module, Callable[..., Any]], |
| concrete_args: Optional[Dict[str, Any]] = None, |
| ) -> GraphModule: |
| """ |
| Symbolic tracing API |
| |
| Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` |
| constructed by recording operations seen while tracing through ``root``. |
| |
| ``concrete_args`` allows you to partially specialize your function, for example to remove control flow or data structures. |
| |
| For example:: |
| |
| def f(a, b): |
| if b == True: |
| return a |
| else: |
| return a*2 |
| |
| FX typically cannot trace through this due to the presence of control |
| flow. However, we can use `concrete_args` to specialize on the value of |
| `b` to trace through this:: |
| |
| f = fx.symbolic_trace(f, concrete_args={'b': False}) |
| assert f(3, False) == 6 |
| |
| Note that although you can still pass in different values of `b`, they will be ignored. |
| |
| We can also use `concrete_args` to eliminate data-structure handling from |
| our function. This will use pytrees to flatten your input. To avoid |
| overspecializing, pass in `fx.PH` for values that shouldn't be |
| specialized. For example:: |
| |
| def f(x): |
| out = 0 |
| for v in x.values(): |
| out += v |
| return out |
| f = fx.symbolic_trace(f, concrete_args={'x': {'a': fx.PH, 'b': fx.PH, 'c': fx.PH}}) |
| assert f({'a': 1, 'b': 2, 'c': 4}) == 7 |
| |
| |
| Args: |
| root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted |
| into a Graph representation. |
| concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized |
| |
| Returns: |
| GraphModule: a Module created from the recorded operations from ``root``. |
| """ |
| tracer = Tracer() |
| graph = tracer.trace(root, concrete_args) |
| name = ( |
| root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ |
| ) |
| return GraphModule(tracer.root, graph, name) |
| |
| |
| @wrap |
| def _assert_is_none(value, msg): |
| assert value is None, msg |