| from .graph_module import GraphModule |
| from .graph import Graph |
| from .node import Argument, Node, Target, map_arg, map_aggregate |
| from .proxy import Proxy |
| from ._symbolic_trace import Tracer |
| from ._compatibility import compatibility |
| from . import config |
| import torch.fx.traceback as fx_traceback |
| import torch |
| from typing import Any, Dict, Iterator, List, Optional, Tuple, Union |
| import inspect |
| from contextlib import contextmanager |
| from torch.hub import tqdm |
| |
| __all__ = ['Interpreter', 'Transformer'] |
| |
| @compatibility(is_backward_compatible=True) |
| class Interpreter: |
| """ |
| An Interpreter executes an FX graph Node-by-Node. This pattern |
| can be useful for many things, including writing code |
| transformations as well as analysis passes. |
| |
| Methods in the Interpreter class can be overridden to customize |
    the behavior of execution. The call hierarchy of overridable
    methods is as follows::
| |
| run() |
| +-- run_node |
| +-- placeholder() |
| +-- get_attr() |
| +-- call_function() |
| +-- call_method() |
| +-- call_module() |
| +-- output() |
| |
| Example: |
| |
| Suppose we want to swap all instances of ``torch.neg`` with |
| ``torch.sigmoid`` and vice versa (including their ``Tensor`` |
| method equivalents). We could subclass Interpreter like so:: |
| |
| class NegSigmSwapInterpreter(Interpreter): |
| def call_function(self, target : Target, |
| args : Tuple, kwargs : Dict) -> Any: |
| if target == torch.sigmoid: |
| return torch.neg(*args, **kwargs) |
                return super().call_function(target, args, kwargs)
| |
| def call_method(self, target : Target, |
| args : Tuple, kwargs : Dict) -> Any: |
| if target == 'neg': |
| call_self, *args_tail = args |
| return call_self.sigmoid(*args_tail, **kwargs) |
                return super().call_method(target, args, kwargs)
| |
| def fn(x): |
| return torch.sigmoid(x).neg() |
| |
| gm = torch.fx.symbolic_trace(fn) |
| input = torch.randn(3, 4) |
| result = NegSigmSwapInterpreter(gm).run(input) |
| torch.testing.assert_close(result, torch.neg(input).sigmoid()) |
| |
| Args: |
| module (GraphModule): The module to be executed |
| garbage_collect_values (bool): Whether to delete values after their last |
| use within the Module's execution. This ensures optimal memory usage during |
| execution. This can be disabled to, for example, examine all of the intermediate |
| values in the execution by looking at the ``Interpreter.env`` attribute. |
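
    Example:

        A minimal sketch of inspecting intermediate values by disabling
        value garbage collection; ``gm`` is assumed to be a traced
        ``GraphModule``::

            interp = Interpreter(gm, garbage_collect_values=False)
            out = interp.run(torch.randn(3, 4))
            # With garbage collection disabled, every node's value is
            # still present in ``interp.env`` after the run.
            intermediates = {n.name: v for n, v in interp.env.items()}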
| """ |
| @compatibility(is_backward_compatible=True) |
| def __init__(self, module : GraphModule, garbage_collect_values : bool = True): |
| assert isinstance(module, GraphModule) |
| self.module = module |
| self.submodules = dict(self.module.named_modules()) |
| self.env : Dict[Node, Any] = {} |
| self.name = "Interpreter" |
| self.garbage_collect_values = garbage_collect_values |
| self.extra_traceback = True |
| |
| if self.garbage_collect_values: |
| # Run through reverse nodes and record the first instance of a use |
| # of a given node. This represents the *last* use of the node in the |
            # execution order of the program, which we will use to free values
            # once they are no longer needed
| node_to_last_use : Dict[Node, Node] = {} |
| self.user_to_last_uses : Dict[Node, List[Node]] = {} |
| |
| def register_last_uses(n : Node, user : Node): |
| if n not in node_to_last_use: |
| node_to_last_use[n] = user |
| self.user_to_last_uses.setdefault(user, []).append(n) |
| |
| for node in reversed(self.module.graph.nodes): |
| map_arg(node.args, lambda n: register_last_uses(n, node)) |
| map_arg(node.kwargs, lambda n: register_last_uses(n, node)) |
| |
| @compatibility(is_backward_compatible=True) |
| def run(self, *args, initial_env : Optional[Dict[Node, Any]] = None, enable_io_processing : bool = True) -> Any: |
| """ |
| Run `module` via interpretation and return the result. |
| |
| Args: |
| *args: The arguments to the Module to run, in positional order |
| initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. |
| This is a dict mapping `Node` to any value. This can be used, for example, to |
| pre-populate results for certain `Nodes` so as to do only partial evaluation within |
| the interpreter. |
            enable_io_processing (bool): If true, inputs are run through the graph's
                ``process_inputs`` function before execution, and the result through
                ``process_outputs`` before being returned.
| |
| Returns: |
| Any: The value returned from executing the Module |
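
        Example:

            A minimal sketch of partial evaluation via ``initial_env``; the
            traced function and the node selection here are illustrative::

                gm = torch.fx.symbolic_trace(lambda x: torch.relu(x).neg())
                relu_node = next(n for n in gm.graph.nodes
                                 if n.target == torch.relu)
                # Pretend the relu result is already known; the interpreter
                # skips that node and uses this value instead.
                env = {relu_node: torch.zeros(3, 4)}
                out = Interpreter(gm).run(torch.randn(3, 4), initial_env=env)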
| """ |
| self.env = initial_env if initial_env is not None else {} |
| |
| # Positional function args are consumed left-to-right by |
| # `placeholder` nodes. Use an iterator to keep track of |
| # position and extract those values. |
| if enable_io_processing: |
| args = self.module.graph.process_inputs(*args) |
| self.args_iter : Iterator[Any] = iter(args) |
| pbar = tqdm(total=len(self.module.graph.nodes), |
| desc=f"{self.name}: {str(list(self.module.graph.nodes)) if config.verbose_progress else ''}", |
| initial=0, position=0, leave=True, disable=config.disable_progress, delay=0) |
| |
| for node in self.module.graph.nodes: |
| pbar.update(1) |
| if node in self.env: |
| # Short circuit if we have this value. This could |
| # be used, for example, for partial evaluation |
| # where the caller has pre-populated `env` with |
| # values for a subset of the program. |
| continue |
| |
| try: |
| self.env[node] = self.run_node(node) |
| except Exception as e: |
| if self.extra_traceback: |
| msg = f"While executing {node.format_node()}" |
                    msg = f'{e.args[0]}\n\n{msg}' if e.args else msg
| msg += f"\nOriginal traceback:\n{node.stack_trace}" |
| e.args = (msg,) + e.args[1:] |
| if isinstance(e, KeyError): |
| raise RuntimeError(*e.args) from e |
| raise |
| |
| if self.garbage_collect_values: |
| for to_delete in self.user_to_last_uses.get(node, []): |
| del self.env[to_delete] |
| |
| if node.op == 'output': |
| output_val = self.env[node] |
| return self.module.graph.process_outputs(output_val) if enable_io_processing else output_val |
| |
| @compatibility(is_backward_compatible=True) |
| def boxed_run(self, args_list): |
| """ |
| Run `module` via interpretation and return the result. This uses the "boxed" |
| calling convention, where you pass a list of arguments, which will be cleared |
| by the interpreter. This ensures that input tensors are promptly deallocated. |
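
        Example:

            A minimal sketch of the boxed calling convention; the traced
            function here is illustrative::

                gm = torch.fx.symbolic_trace(lambda x: x.sigmoid())
                inputs = [torch.randn(3, 4)]
                out = Interpreter(gm).boxed_run(inputs)
                assert inputs == []  # the input list was cleared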
| """ |
| args_iter = iter(args_list) |
| env = {} |
| for n in self.module.graph.nodes: |
| if n.op == "placeholder": |
| env[n] = next(args_iter) |
| args_list.clear() |
| return self.run(initial_env=env) |
| |
| @contextmanager |
| def _set_current_node(self, node): |
| with fx_traceback.set_current_meta(node): |
| yield |
| |
| @compatibility(is_backward_compatible=True) |
| def run_node(self, n : Node) -> Any: |
| """ |
| Run a specific node ``n`` and return the result. |
| Calls into placeholder, get_attr, call_function, |
| call_method, call_module, or output depending |
| on ``node.op`` |
| |
| Args: |
| n (Node): The Node to execute |
| |
| Returns: |
| Any: The result of executing ``n`` |
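
        Example:

            Overriding ``run_node`` is a convenient hook for analysis
            passes. A minimal sketch that logs the shape of every
            ``Tensor`` value as it is produced (the class name is
            illustrative)::

                class ShapeLogger(Interpreter):
                    def run_node(self, n : Node) -> Any:
                        result = super().run_node(n)
                        if isinstance(result, torch.Tensor):
                            print(f"{n.format_node()} -> {tuple(result.shape)}")
                        return result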
| """ |
| with self._set_current_node(n): |
| args, kwargs = self.fetch_args_kwargs_from_env(n) |
| assert isinstance(args, tuple) |
| assert isinstance(kwargs, dict) |
| return getattr(self, n.op)(n.target, args, kwargs) |
| |
| # Main Node running APIs |
| @compatibility(is_backward_compatible=True) |
| def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: |
| """ |
| Execute a ``placeholder`` node. Note that this is stateful: |
| ``Interpreter`` maintains an internal iterator over |
        arguments passed to ``run``, and this method returns the
        next value from that iterator.
| |
| Args: |
| target (Target): The call target for this node. See |
| `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for |
| details on semantics |
| args (Tuple): Tuple of positional args for this invocation |
| kwargs (Dict): Dict of keyword arguments for this invocation |
| |
| Returns: |
| Any: The argument value that was retrieved. |
| """ |
| assert isinstance(target, str) |
| if target.startswith('*'): |
| # For a starred parameter e.g. `*args`, retrieve all |
| # remaining values from the args list. |
| return list(self.args_iter) |
| else: |
| try: |
| return next(self.args_iter) |
| except StopIteration as si: |
| if len(args) > 0: |
| return args[0] |
| else: |
| raise RuntimeError(f'Expected positional argument for parameter {target}, but one was not passed in!') from si |
| |
| @compatibility(is_backward_compatible=True) |
| def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: |
| """ |
| Execute a ``get_attr`` node. Will retrieve an attribute |
| value from the ``Module`` hierarchy of ``self.module``. |
| |
| Args: |
| target (Target): The call target for this node. See |
| `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for |
| details on semantics |
| args (Tuple): Tuple of positional args for this invocation |
| kwargs (Dict): Dict of keyword arguments for this invocation |
| |
        Returns:
| Any: The value of the attribute that was retrieved |
| """ |
| assert isinstance(target, str) |
| return self.fetch_attr(target) |
| |
| @compatibility(is_backward_compatible=True) |
| def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: |
| """ |
| Execute a ``call_function`` node and return the result. |
| |
| Args: |
| target (Target): The call target for this node. See |
| `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for |
| details on semantics |
| args (Tuple): Tuple of positional args for this invocation |
| kwargs (Dict): Dict of keyword arguments for this invocation |
| |
        Returns:
| Any: The value returned by the function invocation |
| """ |
| assert not isinstance(target, str) |
| |
| # Execute the function and return the result |
| return target(*args, **kwargs) |
| |
| @compatibility(is_backward_compatible=True) |
| def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: |
| """ |
| Execute a ``call_method`` node and return the result. |
| |
| Args: |
| target (Target): The call target for this node. See |
| `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for |
| details on semantics |
| args (Tuple): Tuple of positional args for this invocation |
| kwargs (Dict): Dict of keyword arguments for this invocation |
| |
        Returns:
| Any: The value returned by the method invocation |
| """ |
| # args[0] is the `self` object for this method call |
| self_obj, *args_tail = args |
| |
| # Execute the method and return the result |
| assert isinstance(target, str) |
| return getattr(self_obj, target)(*args_tail, **kwargs) |
| |
| @compatibility(is_backward_compatible=True) |
| def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: |
| """ |
| Execute a ``call_module`` node and return the result. |
| |
| Args: |
| target (Target): The call target for this node. See |
| `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for |
| details on semantics |
| args (Tuple): Tuple of positional args for this invocation |
| kwargs (Dict): Dict of keyword arguments for this invocation |
| |
        Returns:
| Any: The value returned by the module invocation |
| """ |
        # Execute the submodule and return the result
| assert isinstance(target, str) |
| submod = self.fetch_attr(target) |
| |
| return submod(*args, **kwargs) |
| |
| @compatibility(is_backward_compatible=True) |
| def output(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: |
| """ |
| Execute an ``output`` node. This really just retrieves |
| the value referenced by the ``output`` node and returns it. |
| |
| Args: |
| target (Target): The call target for this node. See |
| `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for |
| details on semantics |
| args (Tuple): Tuple of positional args for this invocation |
| kwargs (Dict): Dict of keyword arguments for this invocation |
| |
        Returns:
| Any: The return value referenced by the output node |
| """ |
| return args[0] |
| |
| # Helper methods |
| @compatibility(is_backward_compatible=True) |
| def fetch_attr(self, target : str): |
| """ |
| Fetch an attribute from the ``Module`` hierarchy of ``self.module``. |
| |
| Args: |
| target (str): The fully-qualified name of the attribute to fetch |
| |
        Returns:
| Any: The value of the attribute. |
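
        Example:

            A minimal sketch, assuming the wrapped module contains a
            ``linear`` submodule with a ``weight`` parameter::

                w = interp.fetch_attr('linear.weight')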
| """ |
| target_atoms = target.split('.') |
| attr_itr = self.module |
| for i, atom in enumerate(target_atoms): |
| if not hasattr(attr_itr, atom): |
                raise RuntimeError(f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}")
| attr_itr = getattr(attr_itr, atom) |
| return attr_itr |
| |
| @compatibility(is_backward_compatible=True) |
| def fetch_args_kwargs_from_env(self, n : Node) -> Tuple[Tuple, Dict]: |
| """ |
| Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` |
| from the current execution environment. |
| |
| Args: |
| n (Node): The node for which ``args`` and ``kwargs`` should be fetched. |
| |
        Returns:
| Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``. |
| """ |
| args = self.map_nodes_to_values(n.args, n) |
| assert isinstance(args, tuple) |
| kwargs = self.map_nodes_to_values(n.kwargs, n) |
| assert isinstance(kwargs, dict) |
| return args, kwargs |
| |
| @compatibility(is_backward_compatible=True) |
| def map_nodes_to_values(self, args : Argument, n : Node) -> Argument: |
| """ |
| Recursively descend through ``args`` and look up the concrete value |
| for each ``Node`` in the current execution environment. |
| |
| Args: |
| args (Argument): Data structure within which to look up concrete values |
| |
| n (Node): Node to which ``args`` belongs. This is only used for error reporting. |
| """ |
| def load_arg(n_arg : Node) -> Any: |
| if n_arg not in self.env: |
| raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() ' |
| f'to diagnose such issues') |
| return self.env[n_arg] |
| return map_arg(args, load_arg) |
| |
| @compatibility(is_backward_compatible=True) |
| class Transformer(Interpreter): |
| """ |
| ``Transformer`` is a special type of interpreter that produces a |
| new ``Module``. It exposes a ``transform()`` method that returns |
    the transformed ``Module``. Unlike ``Interpreter``, ``Transformer``
    does not require arguments to run; it operates entirely
    symbolically.
| |
| Example: |
| |
| Suppose we want to swap all instances of ``torch.neg`` with |
| ``torch.sigmoid`` and vice versa (including their ``Tensor`` |
| method equivalents). We could subclass ``Transformer`` like so:: |
| |
| class NegSigmSwapXformer(Transformer): |
| def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: |
| if target == torch.sigmoid: |
| return torch.neg(*args, **kwargs) |
                return super().call_function(target, args, kwargs)
| |
| def call_method(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: |
| if target == 'neg': |
| call_self, *args_tail = args |
| return call_self.sigmoid(*args_tail, **kwargs) |
                return super().call_method(target, args, kwargs)
| |
| def fn(x): |
| return torch.sigmoid(x).neg() |
| |
| gm = torch.fx.symbolic_trace(fn) |
| |
| transformed : torch.nn.Module = NegSigmSwapXformer(gm).transform() |
| input = torch.randn(3, 4) |
| torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid()) |
| |
| Args: |
| module (GraphModule): The ``Module`` to be transformed. |
| """ |
| |
| @compatibility(is_backward_compatible=True) |
| def __init__(self, module): |
| super().__init__(module) |
| self.new_graph = Graph() |
| self.new_graph.set_codegen(module.graph._codegen) |
| |
| class TransformerTracer(Tracer): |
| def __init__(self, graph: Graph): |
| super().__init__() |
| self.graph = graph |
| self.tensor_attrs: Dict[torch.Tensor, str] = {} # type: ignore[assignment] |
| |
| def is_leaf_module(self, _, __) -> bool: |
| return True |
| |
| self.tracer = TransformerTracer(self.new_graph) |
| self.tracer.root = module |
| |
| @compatibility(is_backward_compatible=True) |
| def placeholder(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: |
| """ |
| Execute a ``placeholder`` node. In ``Transformer``, this is |
| overridden to insert a new ``placeholder`` into the output |
| graph. |
| |
| Args: |
| target (Target): The call target for this node. See |
| `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for |
| details on semantics |
| args (Tuple): Tuple of positional args for this invocation |
| kwargs (Dict): Dict of keyword arguments for this invocation |
| """ |
| assert isinstance(target, str) |
| default_value = next(iter(args)) if args else inspect.Signature.empty |
| return Proxy(self.new_graph.placeholder(target, default_value=default_value), self.tracer) |
| |
| @compatibility(is_backward_compatible=True) |
| def get_attr(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Proxy: |
| """ |
| Execute a ``get_attr`` node. In ``Transformer``, this is |
| overridden to insert a new ``get_attr`` node into the output |
| graph. |
| |
| Args: |
| target (Target): The call target for this node. See |
| `Node <https://pytorch.org/docs/master/fx.html#torch.fx.Node>`__ for |
| details on semantics |
| args (Tuple): Tuple of positional args for this invocation |
| kwargs (Dict): Dict of keyword arguments for this invocation |
| """ |
| assert isinstance(target, str) |
| return self.tracer.create_proxy("get_attr", target, args, kwargs) |
| |
| @compatibility(is_backward_compatible=True) |
| def call_module(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: |
| # Override so that the leaf module policy from `self.tracer` is respected. |
| assert isinstance(target, str) |
| submod = self.fetch_attr(target) |
| return self.tracer.call_module(submod, submod.forward, args, kwargs) |
| |
| @compatibility(is_backward_compatible=True) |
| def call_function(self, target : 'Target', args : Tuple[Argument, ...], kwargs : Dict[str, Any]) -> Any: |
| # Override so that functions that were wrapped are still wrapped. |
| return self.tracer.create_proxy('call_function', target, args, kwargs) |
| |
| @compatibility(is_backward_compatible=True) |
| def transform(self) -> GraphModule: |
| """ |
| Transform ``self.module`` and return the transformed |
| ``GraphModule``. |
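
        Example:

            A minimal sketch; ``gm`` is assumed to be a traced
            ``GraphModule``::

                new_gm = Transformer(gm).transform()
                new_gm.graph.lint()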
| """ |
| with fx_traceback.preserve_node_meta(): |
| result = super().run(enable_io_processing=False) |
| if result is not None: |
| def strip_proxy(a : Union[Argument, Proxy]) -> Any: |
| return a.node if isinstance(a, Proxy) else a |
| self.new_graph.output(map_aggregate(result, strip_proxy)) |
| return GraphModule(self.module, self.new_graph) |