| from typing import List, Optional, Sequence, Union |
| |
| from torchgen import local |
| from torchgen.api import cpp |
| |
| from torchgen.api.types import ( |
| ArgName, |
| BaseCType, |
| Binding, |
| boolT, |
| ConstRefCType, |
| CType, |
| deviceT, |
| layoutT, |
| ListCType, |
| MutRefCType, |
| NamedCType, |
| OptionalCType, |
| scalarT, |
| scalarTypeT, |
| tensorT, |
| ) |
| from torchgen.model import ( |
| Argument, |
| FunctionSchema, |
| Return, |
| SelfArgument, |
| TensorOptionsArguments, |
| Type, |
| ) |
| from torchgen.utils import assert_never |
| |
| # This file describes the translation of JIT schema to the native functions API. |
| # This looks a lot like the C++ API (which makes historical sense, because the |
| # idea was you wrote native functions to implement functions in the C++ API), |
| # but over time we have evolved the C++ API without actually changing our |
| # native:: kernels. The intention is to make native API and dispatcher API |
| # line up as closely as possible, since this results in the least overhead |
| # (no translation is needed from dispatcher API to native API). |
| # |
| # NB: this is SymInt-aware: some dispatch entries get the non-SymInt variant |
| # and others get the SymInt variant. |
| |
| |
def name(func: FunctionSchema) -> str:
    """Return the C++ identifier of the native:: kernel for ``func``.

    The identifier is the base operator name, then ``_out`` when ``func`` is
    an out= variant, then ``_<overload_name>`` when the schema has a
    non-empty overload name.
    """
    pieces = [str(func.name.name)]
    # TODO: delete this!
    if func.is_out_fn():
        pieces.append("out")
    overload = func.name.overload_name
    if overload:
        pieces.append(f"{overload}")
    return "_".join(pieces)
| |
| |
def argumenttype_type(
    t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
    """Translate a schema type into the C++ type used by native:: kernels.

    ``Tensor?``, ``Tensor?[]``, ``Scalar`` and ``Scalar?`` are special-cased
    here. Every other type is delegated to ``cpp.argumenttype_type``.

    Args:
        t: the schema type to translate.
        mutable: whether the argument is written to.
        binds: the name the resulting ``NamedCType`` binds to.
        symint: forwarded to ``cpp.argumenttype_type``.
    """
    spelled = str(t)

    if spelled == "Tensor?":
        opt_tensor: OptionalCType = OptionalCType(BaseCType(tensorT))
        # Mutable optional tensors are passed by non-const reference unless the
        # local setting requests const refs for mutable tensors.
        by_mut_ref = mutable and not local.use_const_ref_for_mutable_tensors()
        wrapped = MutRefCType(opt_tensor) if by_mut_ref else ConstRefCType(opt_tensor)
        return NamedCType(binds, wrapped)

    # The remaining special cases are all passed by const reference; pick the
    # inner type, then wrap it once.
    inner: Optional[CType]
    if spelled == "Tensor?[]":
        inner = ListCType(OptionalCType(BaseCType(tensorT)))
    elif spelled == "Scalar":
        inner = BaseCType(scalarT)
    elif spelled == "Scalar?":
        inner = OptionalCType(BaseCType(scalarT))
    else:
        inner = None

    if inner is not None:
        return NamedCType(binds, ConstRefCType(inner))
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
| |
| |
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
    """Return the C++ return type for a native:: kernel with returns ``rs``.

    The native API uses the same return-type translation as the C++ API, so
    this defers entirely to ``cpp.returns_type``.
    """
    ctype = cpp.returns_type(rs, symint=symint)
    return ctype
| |
| |
def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
    """Return the native C++ type of argument ``a``, bound to ``binds``.

    The argument counts as mutable exactly when ``a.is_write`` is true.
    """
    schema_type = a.type
    writes = a.is_write
    return argumenttype_type(schema_type, mutable=writes, binds=binds, symint=symint)
| |
| |
def argument(
    a: Union[Argument, SelfArgument, TensorOptionsArguments],
    *,
    is_out: bool,
    symint: bool,
) -> List[Binding]:
    """Produce the native-API bindings for a single schema argument.

    - A plain ``Argument`` yields one binding, defaulted from ``a.default``
      when defaults are allowed.
    - A ``SelfArgument`` is unwrapped and treated as its inner argument.
    - A ``TensorOptionsArguments`` expands into four optional bindings:
      dtype, layout, device and pin_memory.

    Args:
        a: the argument to translate.
        is_out: whether the enclosing function is an out= variant; defaults
            are only generated when this is false.
        symint: forwarded to the type/default translation helpers.
    """
    # Ideally, we NEVER default native functions. However, there are a number
    # of functions that call native:: directly and rely on the defaulting
    # existing. So for BC, we generate defaults for non-out variants (but not
    # for out variants, where it is impossible to generate an appropriate
    # default)
    should_default = not is_out

    if isinstance(a, Argument):
        default: Optional[str] = (
            cpp.default_expr(a.default, a.type, symint=symint)
            if should_default and a.default is not None
            else None
        )
        single = Binding(
            nctype=argument_type(a, binds=a.name, symint=symint),
            name=a.name,
            default=default,
            argument=a,
        )
        return [single]

    if isinstance(a, SelfArgument):
        # Erase SelfArgument from the distinction
        return argument(a.argument, is_out=is_out, symint=symint)

    if isinstance(a, TensorOptionsArguments):
        options_default = "{}" if should_default else None
        # TODO: Not sure why the arguments assigned here are for
        # TensorOptionsArguments and not the constituent pieces. It seems
        # to matter
        option_fields = (
            ("dtype", scalarTypeT),
            ("layout", layoutT),
            ("device", deviceT),
            ("pin_memory", boolT),
        )
        return [
            Binding(
                nctype=NamedCType(opt_name, OptionalCType(BaseCType(opt_ctype))),
                name=opt_name,
                default=options_default,
                argument=a,
            )
            for opt_name, opt_ctype in option_fields
        ]

    assert_never(a)
| |
| |
def arguments(func: FunctionSchema, *, symint: bool) -> List[Binding]:
    """Return the native-API bindings for every argument of ``func``.

    Non-out arguments come first, followed by out arguments. Each is expanded
    through ``argument``, with ``is_out`` set from ``func.is_out_fn()``.
    """
    # is_out_fn() describes the whole schema, not any single argument, so
    # evaluate it once rather than once per argument inside the comprehension.
    is_out = func.is_out_fn()
    args: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    return [
        binding
        for arg in args
        for binding in argument(arg, symint=symint, is_out=is_out)
    ]