| import torch |
| import inspect |
| import numbers |
| import types |
| import typing |
| import enum |
| import warnings |
| from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING |
| from torch._jit_internal import boolean_dispatched |
| from ._compatibility import compatibility |
| from torch._ops import OpOverloadPacket, OpOverload |
| |
| if TYPE_CHECKING: |
| from .node import Argument |
| |
| @compatibility(is_backward_compatible=False) |
| class ArgsKwargsPair(NamedTuple): |
| """ |
| Simple named tuple for wrapping args/kwargs pairs. |
| """ |
| args: Tuple[Any, ...] |
| kwargs: Dict[str, Any] |
| |
| _manual_overrides : Dict[Callable, List[inspect.Signature]] = {} |
| |
| def _nonzero_schemas(): |
| signatures = [] |
| |
| def nonzero(self): |
| pass |
| signatures.append(inspect.signature(nonzero)) |
| |
| def nonzero(self, *, as_tuple : bool): # type: ignore[no-redef] |
| pass |
| signatures.append(inspect.signature(nonzero)) |
| |
| return signatures |
| |
| _manual_overrides[torch.nonzero] = _nonzero_schemas() |
| |
| class _FakeGlobalNamespace: |
| def __getattr__(self, name): |
| if name == 'torch': |
| return torch |
| raise RuntimeError('Expected a torch namespace lookup') |
| |
| _type_eval_globals = {'Tensor' : torch.Tensor, 'Device' : torch.device, 'Layout' : torch.layout, |
| 'number' : numbers.Number, 'Future' : torch.jit.Future, |
| 'AnyEnumType' : enum.Enum, 'QScheme' : torch.qscheme, |
| '__torch__': _FakeGlobalNamespace(), 'NoneType': type(None), |
| 't': typing.TypeVar('t')} |
| for k in dir(typing): |
| _type_eval_globals[k] = getattr(typing, k) |
| |
| def _torchscript_type_to_python_type(ts_type : 'torch._C.JitType') -> Any: |
| """ |
| Convert a TorchScript type to a Python type (including subtypes) via |
| eval'ing the annotation_str. _type_eval_globals sets up expressions |
| like "List" and "Future" to map to actual types (typing.List and jit.Future) |
| """ |
| return eval(ts_type.annotation_str, _type_eval_globals) |
| |
| def _torchscript_schema_to_signature(ts_schema : torch._C.FunctionSchema) -> inspect.Signature: |
| parameters : List[inspect.Parameter] = [] |
| for arg in ts_schema.arguments: |
| arg_type = _torchscript_type_to_python_type(arg.type) |
| default = arg.default_value if arg.has_default_value() else inspect.Parameter.empty |
| # TODO: Figure out if this is safe. It seems like when generating the type signatures for |
| # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor |
| # argument name. Downstream, if someone converts that positional argument to a keyword |
| # argument, the name mismatch will break things, so here we're going to normalize the |
| # name to "input" |
| name = arg.name if arg.name != 'self' else 'input' |
| kind = inspect.Parameter.KEYWORD_ONLY if arg.kwarg_only else inspect.Parameter.POSITIONAL_OR_KEYWORD |
| parameters.append(inspect.Parameter(name=name, kind=kind, default=default, annotation=arg_type)) |
| return_types = [_torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns] |
| if len(return_types) == 0: |
| return_type = None |
| elif len(return_types) == 1: |
| return_type = return_types[0] |
| else: |
| return_type = tuple(return_types) |
| |
| return inspect.Signature(parameters, return_annotation=return_type) |
| |
| @compatibility(is_backward_compatible=False) |
| def check_for_mutable_operation(target : Callable, args : Tuple['Argument', ...], kwargs : Dict[str, 'Argument']): |
| signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) |
| |
| if signatures and schemas: |
| matched_schemas = [] |
| |
| # Iterate through all of the schema until we find one that matches |
| # If one matches, populate `new_args_and_kwargs` with the new args/kwargs |
| # values. If none matches, `new_args_and_kwargs` will be None |
| for candidate_signature, schema in zip(signatures, schemas): |
| try: |
| candidate_signature.bind(*args, **kwargs) |
| matched_schemas.append((candidate_signature, schema)) |
| except TypeError as e: |
| continue |
| |
| def throw_if_mutable(schema): |
| if schema.is_mutable: |
| raise RuntimeError(f'Tried to trace mutable operation {schema}. FX only supports functional ' |
| f'code, so operations that mutate operands in-place (e.g. via `out` arguments) ' |
| f'are not supported') |
| |
| if len(matched_schemas) == 0: |
| # Did not match any schema. Cannot check for mutation |
| pass |
| elif len(matched_schemas) == 1: |
| # Matched exactly one schema, unambiguous |
| _, schema_to_check = matched_schemas[0] |
| throw_if_mutable(schema_to_check) |
| pass |
| else: |
| # Ambiguous schema match. Since mutability checking is best effort, |
| # do nothing. |
| pass |
| |
| @compatibility(is_backward_compatible=False) |
| def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): |
| """ |
| Given an operator on the `torch` namespace, return a list of `inspect.Signature` |
| objects corresponding to the overloads of that op.. May return `None` if a signature |
| could not be retrieved. |
| |
| Args: |
| op (Callable): An operator on the `torch` namespace to look up a signature for |
| |
| Returns: |
| Optional[List[inspect.Signature]]: A list of signatures for the overloads of this |
| operator, or None if the operator signatures could not be retrieved. If |
| return_schemas=True, returns a tuple containing the optional Python signatures |
| and the optional TorchScript Function signature |
| """ |
| if isinstance(op, OpOverload): |
| schemas = [op._schema] |
| elif isinstance(op, OpOverloadPacket): |
| schemas = [getattr(op, overload)._schema for overload in op.overloads()] |
| else: |
| override = _manual_overrides.get(op) |
| if override: |
| return (override, None) if return_schemas else None |
| |
| aten_fn = torch.jit._builtins._find_builtin(op) |
| |
| if aten_fn is None: |
| return (None, None) if return_schemas else None |
| schemas = torch._C._jit_get_schemas_for_operator(aten_fn) |
| |
| signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] |
| return (signatures, schemas) if return_schemas else signatures |
| |
| @compatibility(is_backward_compatible=False) |
| def create_type_hint(x): |
| try: |
| if isinstance(x, list) or isinstance(x, tuple): |
| # todo(chilli): Figure out the right way for mypy to handle this |
| if isinstance(x, list): |
| def ret_type(x): |
| return List[x] # type: ignore[valid-type] |
| else: |
| def ret_type(x): |
| return Tuple[x, ...] |
| if len(x) == 0: |
| return ret_type(Any) |
| base_type = x[0] |
| for t in x: |
| if issubclass(t, base_type): |
| continue |
| elif issubclass(base_type, t): |
| base_type = t |
| else: |
| return ret_type(Any) |
| return ret_type(base_type) |
| except Exception as e: |
| # We tried to create a type hint for list but failed. |
| torch.warnings.warn(f"We were not able to successfully create type hint from the type {x}") |
| pass |
| return x |
| |
| @compatibility(is_backward_compatible=False) |
| def type_matches(signature_type : Any, argument_type : Any): |
| sig_origin_type = getattr(signature_type, '__origin__', signature_type) |
| |
| if signature_type is argument_type: |
| return True |
| |
| # Union types in signature. Given type needs to match one of the |
| # contained types in the Union |
| if sig_origin_type is typing.Union and signature_type != argument_type: |
| sig_contained = signature_type.__args__ |
| return any(type_matches(c, argument_type) for c in sig_contained) |
| |
| if signature_type is List[int] and argument_type is int: |
| # int can be promoted to List[int] |
| return True |
| |
| if getattr(signature_type, '__origin__', None) in {list, List}: |
| sig_el_type = signature_type.__args__[0] |
| if not inspect.isclass(sig_el_type): |
| warnings.warn( |
| f"Does not support nested parametric types, got {signature_type}. Please file a bug.") |
| return False |
| if getattr(argument_type, '__origin__', None) in {list, List}: |
| return issubclass(argument_type.__args__[0], sig_el_type) |
| |
| def is_homogeneous_tuple(t): |
| if not getattr(t, '__origin__', None) in {tuple, Tuple}: |
| return False |
| contained = t.__args__ |
| if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason |
| return True |
| return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained) |
| |
| # Tuple[T] is accepted for List[T] parameters |
| return is_homogeneous_tuple(argument_type) |
| |
| # Dtype is an int in schemas |
| if signature_type is int and argument_type is torch.dtype: |
| return True |
| |
| if signature_type is numbers.Number and argument_type in {int, float}: |
| return True |
| if inspect.isclass(argument_type) and inspect.isclass(signature_type): |
| return issubclass(argument_type, signature_type) |
| |
| return False |
| |
| @compatibility(is_backward_compatible=False) |
| def normalize_function( |
| target: Callable, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, arg_types : Optional[Tuple[Any]] = None, |
| kwarg_types : Optional[Dict[str, Any]] = None, |
| normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: |
| """ |
| Returns normalized arguments to PyTorch functions. This means that |
| `args/kwargs` will be matched up to the functional's |
| signature and return exclusively kwargs in positional order if |
| `normalize_to_only_use_kwargs` is True. |
| Also populates default values. Does not support positional-only |
| parameters or varargs parameters (*args, **kwargs). Does not support modules. |
| |
| May require `arg_types` and `kwarg_types` in order to disambiguate overloads. |
| |
| Args: |
| target (Callable): Function that we are normalizing |
| args (Tuple[Any]): Tuple of args to the function |
| kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function |
| arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args |
| kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs |
| normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. |
| |
| Returns: |
| |
| Returns normalized_args_and_kwargs, or `None` if not successful. |
| """ |
| if kwargs is None: |
| kwargs = {} |
| new_args_and_kwargs = None |
| if not isinstance(target, types.BuiltinFunctionType) and not ( |
| isinstance(target, OpOverloadPacket) or isinstance(target, OpOverload) |
| ): |
| target_for_analysis = target |
| if target in boolean_dispatched: |
| # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have |
| # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` |
| # branches of the dispatch have exactly the same signature. If they do, use the `true` |
| # branch signature for analysis. Otherwise, leave this un-normalized |
| assert not isinstance(target, str) |
| dispatched = boolean_dispatched[target] |
| if_true, if_false = dispatched['if_true'], dispatched['if_false'] |
| if inspect.signature(if_true).parameters != inspect.signature(if_false).parameters: |
| return None |
| target_for_analysis = if_true |
| |
| assert callable(target_for_analysis) |
| sig = inspect.signature(inspect.unwrap(target_for_analysis)) |
| new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, normalize_to_only_use_kwargs) |
| else: |
| assert callable(target) |
| torch_op_schemas = get_signature_for_torch_op(target) |
| matched_schemas = [] |
| if torch_op_schemas: |
| # Iterate through all of the schema until we find one that matches |
| # If one matches, populate `new_args_and_kwargs` with the new args/kwargs |
| # values. If none matches, `new_args_and_kwargs` will be None |
| for candidate_signature in torch_op_schemas: |
| try: |
| candidate_signature.bind(*args, **kwargs) |
| matched_schemas.append(candidate_signature) |
| except TypeError as e: |
| continue |
| |
| if len(matched_schemas) == 0: |
| # Did not match any schema. Cannot normalize |
| pass |
| elif len(matched_schemas) == 1: |
| # Matched exactly one schema, unambiguous |
| new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(matched_schemas[0], args, kwargs, |
| normalize_to_only_use_kwargs) |
| else: |
| if arg_types is not None or kwarg_types is not None: |
| arg_types = arg_types if arg_types else cast(Tuple[Any], ()) |
| kwarg_types = kwarg_types if kwarg_types else {} |
| for candidate_signature in torch_op_schemas: |
| sig_matches = True |
| try: |
| bound_types = candidate_signature.bind(*arg_types, **kwarg_types) |
| for arg_name, arg_type in bound_types.arguments.items(): |
| param = candidate_signature.parameters[arg_name] |
| sig_matches = sig_matches and type_matches(param.annotation, arg_type) |
| except TypeError as e: |
| sig_matches = False |
| if sig_matches: |
| new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(candidate_signature, args, kwargs, |
| normalize_to_only_use_kwargs) |
| break |
| else: |
| # Matched more than one schema. In this situation, the caller must provide the types of |
| # the arguments of the overload they expect. |
| schema_printouts = '\n'.join(str(schema) for schema in matched_schemas) |
| raise RuntimeError(f'Tried to normalize arguments to {torch.typename(target)} but ' |
| f'the schema match was ambiguous! Please provide argument types to ' |
| f'the normalize_arguments() call. Available schemas:\n{schema_printouts}') |
| |
| return new_args_and_kwargs |
| |
| @compatibility(is_backward_compatible=False) |
| def normalize_module( |
| root: torch.nn.Module, target: str, args: Tuple[Any], kwargs : Optional[Dict[str, Any]] = None, |
| normalize_to_only_use_kwargs : bool = False) -> Optional[ArgsKwargsPair]: |
| """ |
| Returns normalized arguments to PyTorch modules. This means that |
| `args/kwargs` will be matched up to the functional's |
| signature and return exclusively kwargs in positional order if |
| `normalize_to_only_use_kwargs` is True. |
| Also populates default values. Does not support positional-only |
| parameters or varargs parameters (*args, **kwargs). |
| |
| Args: |
| root (nn.Module): root module upon which we query modules |
| target (Callable): Function that we are normalizing |
| args (Tuple[Any]): Tuple of args to the function |
| kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function |
| normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. |
| |
| Returns: |
| |
| Returns normalized_args_and_kwargs, or `None` if not successful. |
| """ |
| try: |
| submod = root.get_submodule(target) |
| except AttributeError: |
| raise RuntimeError(f"Tried to normalize node with target {target} but root did not " |
| f"have that target!") |
| if hasattr(submod.__class__, '__name__'): |
| classname = submod.__class__.__name__ |
| if getattr(torch.nn, classname, None) == submod.__class__: |
| sig = inspect.signature(inspect.unwrap(submod.forward)) |
| if kwargs is None: |
| kwargs = {} |
| new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs(sig, args, kwargs, |
| normalize_to_only_use_kwargs) |
| return new_args_and_kwargs |
| return None |
| |
| def _args_kwargs_to_normalized_args_kwargs(sig : inspect.Signature, args : Tuple[Any, ...], |
| kwargs : Dict[str, Any], |
| normalize_to_only_use_kwargs : bool) -> Optional[ArgsKwargsPair]: |
| """ |
| Given a call target, args, and kwargs, return the arguments normalized into |
| an ArgsKwargsPair, or None if the type signature is not supported by |
| this normalization. |
| |
| Args: |
| |
| target (inspect.Signature): Signature object for the target |
| args (Tuple): Arguments that appear at the callsite for `target` |
| kwargs (Dict): Keyword arguments that appear at the callsite for `target` |
| normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. |
| |
| Returns: |
| |
| Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if |
| this target is not supported. |
| """ |
| |
| # Don't currently support positional-only |
| # or varargs (*args, **kwargs) signatures |
| supported_parameter_types = { |
| inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY} |
| if any(p.kind not in supported_parameter_types for p in sig.parameters.values()): |
| return None |
| |
| bound_args = sig.bind(*args, **kwargs) |
| bound_args.apply_defaults() |
| |
| new_kwargs : Dict[str, Any] = {} |
| new_args : List[Any] = [] |
| for i, param in enumerate(sig.parameters): |
| if not normalize_to_only_use_kwargs and i < len(args): |
| new_args.append(bound_args.arguments[param]) |
| else: |
| new_kwargs[param] = bound_args.arguments[param] |
| |
| return ArgsKwargsPair(tuple(new_args), new_kwargs) |