| from typing import Any, Optional, Tuple, Union |
| |
| from torchgen.model import ( |
| Annotation, |
| Argument, |
| Arguments, |
| BaseOperatorName, |
| BaseTy, |
| BaseType, |
| CustomClassType, |
| FunctionSchema, |
| ListType, |
| OperatorName, |
| Return, |
| ) |
| |
| |
| # Note: These aren't actually used in torchgen, they're some utilities for generating a schema |
| # from real arguments. For example, this is used to generate HigherOrderOperators' schema since |
| # their schemas can vary for different instances of the same HOP. |
| |
| |
class TypeGen:
    """Infer a schema type (``BaseType``/``ListType``/``CustomClassType``) from an example value."""

    # Plain Python scalar types and the schema base type each one maps to.
    # Looked up by exact ``type(obj)``, so subclasses of these types are not matched.
    convert_to_base_ty = {
        int: BaseTy.int,
        float: BaseTy.float,
        str: BaseTy.str,
        bool: BaseTy.bool,
    }

    @staticmethod
    def from_example(obj: Any) -> Union[BaseType, ListType, CustomClassType]:
        """Return the schema type describing ``obj``.

        Args:
            obj: An example runtime value: a ``torch.fx.GraphModule``, ``torch.Tensor``,
                ``torch.SymInt``, ``torch.SymBool``, ``torch.ScriptObject``, a non-empty
                homogeneous list/tuple of supported values, or an ``int``/``float``/
                ``str``/``bool``.

        Returns:
            A ``BaseType`` for tensors, graph modules, symbolic values and Python scalars;
            a ``CustomClassType`` for script objects; a ``ListType`` (sized to ``len(obj)``)
            for lists and tuples.

        Raises:
            AssertionError: If ``obj`` is an empty list or tuple (no element type to infer).
            RuntimeError: If a list/tuple mixes element types, or ``type(obj)`` is not
                supported.
        """
        # torch is imported here rather than at module import time.
        import torch

        if isinstance(obj, torch.fx.GraphModule):
            return BaseType(BaseTy.GraphModule)
        if isinstance(obj, torch.Tensor):
            return BaseType(BaseTy.Tensor)
        if isinstance(obj, torch.SymInt):
            return BaseType(BaseTy.SymInt)
        if isinstance(obj, torch.SymBool):
            return BaseType(BaseTy.SymBool)
        if isinstance(obj, torch.ScriptObject):
            return CustomClassType(obj._type().name())  # type: ignore[attr-defined]
        if isinstance(obj, (list, tuple)):
            # Raised explicitly (instead of a bare ``assert``) so the check still runs
            # under ``python -O``; otherwise an empty sequence would surface as an
            # IndexError below. The exception type matches the former assert.
            if len(obj) == 0:
                raise AssertionError(
                    "Cannot generate schema for an empty list or tuple: "
                    "there is no element to infer the element type from."
                )
            all_base_tys = [TypeGen.from_example(x) for x in obj]
            # Every element must resolve to the same schema type.
            if len(set(all_base_tys)) > 1:
                raise RuntimeError(
                    f"Cannot generate schema for a sequence of args of heterogeneous types: {all_base_tys}. "
                    "Consider unpacking the argument and give proper names to them if possible "
                    "instead of using *args."
                )
            return ListType(all_base_tys[0], len(obj))

        # Fall back to the exact-type scalar table.
        tp = type(obj)
        if tp not in TypeGen.convert_to_base_ty:
            raise RuntimeError(f"unsupported type {tp}")
        return BaseType(TypeGen.convert_to_base_ty[tp])
| |
| |
class ReturnGen:
    """Builds a ``Return`` schema entry from an example output value."""

    @staticmethod
    def from_example(
        name: Optional[str], obj: Any, annotation: Optional[Annotation]
    ) -> Return:
        """Create a ``Return`` whose type is inferred from ``obj``.

        Args:
            name: Name passed through to ``Return``; may be ``None``.
            obj: Example output value; its type is inferred by ``TypeGen.from_example``.
            annotation: Annotation passed through to ``Return``; may be ``None``.

        Returns:
            The constructed ``Return``.

        Raises:
            Whatever ``TypeGen.from_example`` raises for ``obj``.
        """
        inferred_type = TypeGen.from_example(obj)
        return Return(name, inferred_type, annotation)
| |
| |
class ArgumentGen:
    """Builds an ``Argument`` schema entry from an example input value."""

    @staticmethod
    def from_example(
        name: str, obj: Any, default: Optional[str], annotation: Optional[Annotation]
    ) -> Argument:
        """Create an ``Argument`` whose type is inferred from ``obj``.

        Args:
            name: Argument name passed through to ``Argument``.
            obj: Example input value; its type is inferred by ``TypeGen.from_example``.
            default: Passed through as the ``default`` keyword; may be ``None``.
            annotation: Passed through as the ``annotation`` keyword; may be ``None``.

        Returns:
            The constructed ``Argument``.

        Raises:
            Whatever ``TypeGen.from_example`` raises for ``obj``.
        """
        inferred_type = TypeGen.from_example(obj)
        return Argument(name, inferred_type, default=default, annotation=annotation)
| |
| |
class FunctionSchemaGen:
    """Builds a ``FunctionSchema`` from an operator name and example inputs/outputs."""

    @staticmethod
    def from_example(
        op_name: str,
        example_inputs: Tuple[Tuple[str, Any], ...],
        example_outputs: Tuple[Any, ...],
    ) -> FunctionSchema:
        """Generate a schema by inferring types from example values.

        Args:
            op_name: Base name of the operator.
            example_inputs: ``(name, value)`` pairs, one per argument, in order.
            example_outputs: Example return values, in order.

        Returns:
            A ``FunctionSchema`` with one unnamed return per example output and one
            argument per example input, all without defaults or annotations.

        Raises:
            Whatever ``TypeGen.from_example`` raises for any input or output value.
        """
        inferred_args = tuple(
            ArgumentGen.from_example(arg_name, example, None, None)
            for arg_name, example in example_inputs
        )
        # Annotations and the other attributes are ignored for now; more can be added
        # when needed. Only the third positional slot of ``Arguments`` is populated.
        arguments = Arguments((), None, inferred_args, (), None, (), ())
        returns = tuple(
            ReturnGen.from_example(None, example, None) for example in example_outputs
        )
        # All three boolean flags are off and the second OperatorName field is empty.
        base_name = BaseOperatorName(op_name, False, False, False)
        operator_name = OperatorName(base_name, "")
        return FunctionSchema(operator_name, arguments, returns)