| import operator |
| from typing import Any, Callable, Dict, Tuple, Optional |
| |
| import torch |
| import torch.fx |
| import torch.fx as fx |
| from torch.fx import Transformer, Proxy |
| from torch.fx.node import Argument, Target, Node, map_aggregate |
| from torch.fx.operator_schemas import ( |
| normalize_module, |
| normalize_function, |
| create_type_hint, |
| ) |
| |
| from .schema_type_annotation import AnnotateTypesWithSchema |
| |
| |
class NormalizeArgs(Transformer):
    """
    Normalize arguments to Python targets. This means that
    `args/kwargs` will be matched up to the module/functional's
    signature and rewritten to exclusively kwargs in positional order
    if `normalize_to_only_use_kwargs` is true. Also populates default
    values. Does not support positional-only parameters or varargs
    parameters (*args, **kwargs).

    If the argument nodes of a ``call_function`` carry 'type' metadata
    (``node.meta["type"]``), it is forwarded to ``normalize_function`` to
    disambiguate overloads. Calls for which normalization produces no
    result are re-emitted with their original args/kwargs.

    Example usage:
        m = torchvision.models.resnet18()
        traced = torch.fx.symbolic_trace(m)
        traced = NormalizeArgs(traced).transform()
    """

    def __init__(
        self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True
    ):
        """
        Args:
            module: GraphModule whose graph is rewritten by ``transform()``.
            normalize_to_only_use_kwargs: Forwarded to ``normalize_function`` /
                ``normalize_module``; when true, normalized calls are
                expressed with keyword arguments only.
        """
        super().__init__(module)
        # Maps each Proxy created in the new graph back to the node of the
        # input graph it was produced from.
        self.node_map: Dict[Proxy, Node] = {}
        self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs

    def run_node(self, n: Node) -> Any:
        """
        Transform a single node of the input graph.

        For ``call_function`` nodes the (optional) types of the arguments are
        gathered and passed to ``call_function`` so overloaded targets can be
        resolved. For every op except ``output`` the result is recorded in
        ``node_map`` and inherits the original node's ``meta`` and ``type``.
        """
        args, kwargs = self.fetch_args_kwargs_from_env(n)

        def get_type(arg: Any) -> Any:
            # A graph node carries its own type annotation (if any) in its
            # meta dict; it must be looked up on the *argument* node, not on
            # the node being transformed.
            if isinstance(arg, fx.Node):
                return arg.meta.get("type")
            return type(arg)

        if n.op == "call_function":
            # Types are derived from the ORIGINAL node's args/kwargs, which
            # still reference fx.Nodes; the fetched `args`/`kwargs` already
            # hold Proxies of the new graph and would all report `Proxy`.
            arg_types = map_aggregate(n.args, get_type)
            assert isinstance(arg_types, tuple)
            arg_types = tuple(create_type_hint(i) for i in arg_types)
            # Keyword arguments get the same treatment as positional ones so
            # e.g. a list of tensors becomes List[Tensor] in both cases.
            kwarg_types = {
                k: create_type_hint(map_aggregate(v, get_type))
                for k, v in n.kwargs.items()
            }
            out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types)
        else:
            out = super().run_node(n)
        if n.op != "output":
            self.node_map[out] = n
            # The metadata dict is shared with the original node, not copied.
            out.node.meta = n.meta
            out.node.type = n.type
        return out

    def call_function(
        self,
        target: Target,
        args: Tuple[Argument, ...],
        kwargs: Dict[str, Any],
        arg_types: Optional[Tuple[Any, ...]] = None,
        kwarg_types: Optional[Dict[str, Any]] = None,
    ):
        """
        Emit a ``call_function`` node with normalized arguments.

        Args:
            target: The callable being invoked.
            args: Positional arguments (values/Proxies of the new graph).
            kwargs: Keyword arguments (values/Proxies of the new graph).
            arg_types: Optional per-positional-argument types used for
                overload resolution.
            kwarg_types: Optional per-keyword-argument types used for
                overload resolution.

        Returns:
            The Proxy for the emitted node. If ``normalize_function`` returns
            a falsy value, the call is emitted with the original args/kwargs.
        """
        assert callable(target)
        new_args_and_kwargs = normalize_function(
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            arg_types,  # type: ignore[arg-type]
            kwarg_types,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return self.tracer.create_proxy(
                "call_function", target, new_args, new_kwargs
            )
        else:
            return super().call_function(target, args, kwargs)

    def call_module(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        """
        Emit a ``call_module`` node with arguments normalized against the
        submodule's signature.

        Args:
            target: Qualified name of the submodule.
            args: Positional arguments (values/Proxies of the new graph).
            kwargs: Keyword arguments (values/Proxies of the new graph).

        Returns:
            The result of ``Transformer.call_module``. Falls back to the
            original args/kwargs when ``normalize_module`` returns a falsy
            value.
        """
        assert isinstance(target, str)
        new_args_and_kwargs = normalize_module(
            self.module,
            target,
            args,  # type: ignore[arg-type]
            kwargs,
            self.normalize_to_only_use_kwargs,
        )
        if new_args_and_kwargs:
            new_args, new_kwargs = new_args_and_kwargs
            return super().call_module(target, new_args, new_kwargs)
        else:
            return super().call_module(target, args, kwargs)
| |
| |
class NormalizeOperators(AnnotateTypesWithSchema):
    """
    Normalize callsites that are different ways of "spelling" the same
    invocation into a single, canonical call. Currently supports:

    1. Rewriting binary ``torch`` functions (e.g. ``torch.add``) to the
       Python ``operator`` functions (e.g. ``operator.add``) that the
       corresponding Tensor magic methods are spelled with. The rewrite is
       only applied to plain ``op(lhs, rhs)`` calls: exactly two positional
       arguments and no keyword arguments, since keyword arguments have no
       ``operator`` equivalent.

    Example usage:

        m = torchvision.models.resnet18()

        traced = torch.fx.symbolic_trace(m)

        traced = NormalizeOperators(traced).transform()
    """

    # torch function -> equivalent binary operator function.
    binary_magic_method_remap: Dict[
        Callable[[Any, Any], Any], Callable[[Any, Any], Any]
    ] = {
        torch.add: operator.add,
        torch.mul: operator.mul,
        torch.sub: operator.sub,
        torch.div: operator.truediv,
        torch.floor_divide: operator.floordiv,
        torch.remainder: operator.mod,
        torch.eq: operator.eq,
        torch.ne: operator.ne,
        torch.lt: operator.lt,
        torch.le: operator.le,
        torch.gt: operator.gt,
        torch.ge: operator.ge,
    }

    def call_function(
        self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]
    ):
        """
        Emit a ``call_function`` node, replacing a remappable ``torch``
        binary function with its ``operator`` counterpart when the call is a
        plain two-argument call without keyword arguments.

        Args:
            target: The callable being invoked.
            args: Positional arguments (values/Proxies of the new graph).
            kwargs: Keyword arguments (values/Proxies of the new graph).

        Returns:
            The result of the parent class's ``call_function`` for the
            (possibly remapped) target.
        """
        # Normalize operators according to the magic methods implemented on tensors here:
        # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950

        assert callable(target)

        # Only `op(lhs, rhs)` maps 1:1 onto the binary operator. Keyword
        # arguments (e.g. `alpha=` on add/sub, `rounding_mode=` on div, `out=`)
        # cannot be expressed with `operator.*`; dropping them would silently
        # change the computation, so such calls are left untouched.
        if target in self.binary_magic_method_remap and len(args) == 2 and not kwargs:
            lhs, rhs = args
            return super().call_function(
                target=self.binary_magic_method_remap[target],
                args=(lhs, rhs),
                kwargs={},
            )

        return super().call_function(target, args, kwargs)