| import dataclasses |
| import itertools |
| import re |
| |
| from dataclasses import dataclass |
| from enum import auto, Enum |
| from typing import Callable, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union |
| |
| from torchgen.utils import assert_never, NamespaceHelper |
| |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # |
| # |
| # DATA MODEL |
| # |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # |
| # |
| # Some general principles for our data model. |
| # |
| # - Stop using C++ data types as the internal data representation |
| # format. Instead, the internal data structures are centered |
| # around JIT schema representation. This avoids a big problem |
| # with the old codegen where we read in all the types from |
| # native_functions.yaml and then immediately had to retranslate |
| # them into C++ types. |
| # |
| # - More semantic data representation. Instead of representing |
| # everything as dicts and strings, we define dataclasses for |
| # every interesting entity the code generation has to deal with. |
| # These dataclasses have strong semantic invariants: for example, |
| # we generally require them to roundtrip losslessly into the |
| # form they were parsed from. These structures are immutable |
| # and you're expected to populate information once during |
| # construction. |
| |
| # Represent a source location; used for better error reporting |
| @dataclass(frozen=True) |
| class Location: |
| file: str |
| line: int |
| |
| def __str__(self) -> str: |
| return f"{self.file}:{self.line}" |
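| |
| # Illustrative example (values are made up): str(Location("native_functions.yaml", 7121)) |
| # renders as "native_functions.yaml:7121", which is the form used in error messages. |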
| |
| |
| # Valid values of the 'variants' field in native_functions.yaml |
| Variant = Enum("Variant", ("function", "method")) |
| |
| # Default kernel namespace |
| DEFAULT_KERNEL_NAMESPACE = "at::native" |
| |
| # NOTE: Keep the list in sync with `DispatchKey` in c10/core/DispatchKey.h |
| BACKEND_COMPONENTS = "CPU CUDA HIP XLA MPS IPU XPU HPU VE Lazy Meta PrivateUse1 PrivateUse2 PrivateUse3".split() |
| FUNCTIONALITY_KEYS = ["", "Quantized", "Sparse", "NestedTensor", "Autograd"] |
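| |
| # A per-backend dispatch key name is just a functionality key concatenated with a |
| # backend component; for example "Sparse" + "CPU" gives SparseCPU, and the empty |
| # functionality key "" + "CUDA" gives plain CUDA. The autogenerated block inside |
| # DispatchKey below is exactly this cross product (see codegen_per_backend_entries). |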
| |
| # This doesn't have to be in sync with the header; it only needs to contain |
| # the entries that we actually use in the codegen. |
| class DispatchKey(Enum): |
| Undefined = 0 |
| CatchAll = Undefined |
| |
| FPGA = auto() |
| ORT = auto() |
| Vulkan = auto() |
| Metal = auto() |
| MKLDNN = auto() |
| OpenGL = auto() |
| OpenCL = auto() |
| IDEEP = auto() |
| CustomRNGKeyId = auto() |
| MkldnnCPU = auto() |
| Sparse = auto() |
| SparseCsrCPU = auto() |
| SparseCsrCUDA = auto() |
| |
| ZeroTensor = auto() |
| BackendSelect = auto() |
| Named = auto() |
| AutogradOther = auto() |
| AutogradFunctionality = auto() |
| AutogradNestedTensor = auto() |
| Tracer = auto() |
| Autocast = auto() |
| Batched = auto() |
| VmapMode = auto() |
| TESTING_ONLY_GenericWrapper = auto() |
| TESTING_ONLY_GenericMode = auto() |
| |
| Autograd = auto() |
| CompositeImplicitAutograd = auto() |
| CompositeExplicitAutograd = auto() |
| CompositeExplicitAutogradNonFunctional = auto() |
| |
| # BEGIN autogenerated |
| CPU = auto() |
| CUDA = auto() |
| HIP = auto() |
| XLA = auto() |
| MPS = auto() |
| IPU = auto() |
| XPU = auto() |
| HPU = auto() |
| VE = auto() |
| Lazy = auto() |
| Meta = auto() |
| PrivateUse1 = auto() |
| PrivateUse2 = auto() |
| PrivateUse3 = auto() |
| QuantizedCPU = auto() |
| QuantizedCUDA = auto() |
| QuantizedHIP = auto() |
| QuantizedXLA = auto() |
| QuantizedMPS = auto() |
| QuantizedIPU = auto() |
| QuantizedXPU = auto() |
| QuantizedHPU = auto() |
| QuantizedVE = auto() |
| QuantizedLazy = auto() |
| QuantizedMeta = auto() |
| QuantizedPrivateUse1 = auto() |
| QuantizedPrivateUse2 = auto() |
| QuantizedPrivateUse3 = auto() |
| SparseCPU = auto() |
| SparseCUDA = auto() |
| SparseHIP = auto() |
| SparseXLA = auto() |
| SparseMPS = auto() |
| SparseIPU = auto() |
| SparseXPU = auto() |
| SparseHPU = auto() |
| SparseVE = auto() |
| SparseLazy = auto() |
| SparseMeta = auto() |
| SparsePrivateUse1 = auto() |
| SparsePrivateUse2 = auto() |
| SparsePrivateUse3 = auto() |
| NestedTensorCPU = auto() |
| NestedTensorCUDA = auto() |
| NestedTensorHIP = auto() |
| NestedTensorXLA = auto() |
| NestedTensorMPS = auto() |
| NestedTensorIPU = auto() |
| NestedTensorXPU = auto() |
| NestedTensorHPU = auto() |
| NestedTensorVE = auto() |
| NestedTensorLazy = auto() |
| NestedTensorMeta = auto() |
| NestedTensorPrivateUse1 = auto() |
| NestedTensorPrivateUse2 = auto() |
| NestedTensorPrivateUse3 = auto() |
| AutogradCPU = auto() |
| AutogradCUDA = auto() |
| AutogradHIP = auto() |
| AutogradXLA = auto() |
| AutogradMPS = auto() |
| AutogradIPU = auto() |
| AutogradXPU = auto() |
| AutogradHPU = auto() |
| AutogradVE = auto() |
| AutogradLazy = auto() |
| AutogradMeta = auto() |
| AutogradPrivateUse1 = auto() |
| AutogradPrivateUse2 = auto() |
| AutogradPrivateUse3 = auto() |
| # END autogenerated |
| |
| def __str__(self) -> str: |
| return self.name |
| |
| def lower(self) -> str: |
| return str(self).lower() |
| |
| @staticmethod |
| def parse(value: str) -> "DispatchKey": |
| for k, v in DispatchKey.__members__.items(): |
| if k == value: |
| return v |
| raise AssertionError(f"unknown dispatch key {value}") |
| |
| |
| def codegen_per_backend_entries() -> str: |
| r = [] |
| for fk in FUNCTIONALITY_KEYS: |
| for bc in BACKEND_COMPONENTS: |
| r.append(f" {fk}{bc} = auto()") |
| return "\n".join(r) |
| |
| |
| for fk in FUNCTIONALITY_KEYS: |
| for bc in BACKEND_COMPONENTS: |
| if not hasattr(DispatchKey, fk + bc): |
| r = codegen_per_backend_entries() |
| print(r) |
| raise RuntimeError( |
| f"Missing {fk}{bc} from DispatchKey enum. Here is the autogenerated list we expect to have:\n\n{r}" |
| ) |
| |
| |
| STRUCTURED_DISPATCH_KEYS = {DispatchKey.MPS, DispatchKey.CUDA, DispatchKey.CPU} |
| UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU} |
| |
| # Set of supported dispatch keys |
| dispatch_keys = [ |
| DispatchKey.CPU, |
| DispatchKey.SparseCPU, |
| DispatchKey.SparseCsrCPU, |
| DispatchKey.MkldnnCPU, |
| DispatchKey.CUDA, |
| DispatchKey.MPS, |
| DispatchKey.SparseCUDA, |
| DispatchKey.SparseCsrCUDA, |
| DispatchKey.QuantizedCPU, |
| DispatchKey.QuantizedCUDA, |
| DispatchKey.CompositeImplicitAutograd, |
| DispatchKey.CompositeExplicitAutograd, |
| DispatchKey.CompositeExplicitAutogradNonFunctional, |
| DispatchKey.NestedTensorCPU, |
| DispatchKey.NestedTensorCUDA, |
| # Meta is a magic key: it is automatically generated for structured |
| # kernels |
| DispatchKey.Meta, |
| DispatchKey.SparseMeta, |
| DispatchKey.QuantizedMeta, |
| DispatchKey.NestedTensorMeta, |
| DispatchKey.ZeroTensor, |
| ] |
| |
| # Dispatch keys that "support all backends". These are code-generated slightly |
| # differently than backend-specific keys. |
| def is_generic_dispatch_key(dk: DispatchKey) -> bool: |
| return dk in { |
| DispatchKey.CompositeExplicitAutograd, |
| DispatchKey.CompositeExplicitAutogradNonFunctional, |
| DispatchKey.CompositeImplicitAutograd, |
| } |
| |
| |
| # CUDA specific dispatch keys |
| def is_cuda_dispatch_key(dk: DispatchKey) -> bool: |
| return dk in { |
| DispatchKey.CUDA, |
| DispatchKey.QuantizedCUDA, |
| DispatchKey.SparseCUDA, |
| DispatchKey.SparseCsrCUDA, |
| DispatchKey.NestedTensorCUDA, |
| DispatchKey.AutogradCUDA, |
| } |
| |
| |
| # Structured kernel generation is only supported for certain key types; |
| # otherwise use old-style |
| def is_structured_dispatch_key(dk: DispatchKey) -> bool: |
| return dk in STRUCTURED_DISPATCH_KEYS |
| |
| |
| def is_ufunc_dispatch_key(dk: DispatchKey) -> bool: |
| # For now, ufunc dispatch keys coincide with structured keys |
| return dk in UFUNC_DISPATCH_KEYS |
| |
| |
| # This is oddly named ScalarType and not DType for symmetry with C++ |
| class ScalarType(Enum): |
| Byte = auto() |
| Char = auto() |
| Short = auto() |
| Int = auto() |
| Long = auto() |
| Half = auto() |
| Float = auto() |
| Double = auto() |
| ComplexHalf = auto() |
| ComplexFloat = auto() |
| ComplexDouble = auto() |
| Bool = auto() |
| BFloat16 = auto() |
| |
| def __str__(self) -> str: |
| return self.name |
| |
| @staticmethod |
| def maybe_parse(value: str) -> Optional["ScalarType"]: |
| for k, v in ScalarType.__members__.items(): |
| if k == value: |
| return v |
| return None |
| |
| @staticmethod |
| def parse(value: str) -> "ScalarType": |
| mb_r = ScalarType.maybe_parse(value) |
| assert mb_r is not None, f"unknown dtype {value}" |
| return mb_r |
| |
| @staticmethod |
| def parse_set(values: str) -> Set["ScalarType"]: |
| dtypes: Set[ScalarType] = set() |
| for value in values.split(", "): |
| if value in DTYPE_CLASSES: |
| dtypes.update(DTYPE_CLASSES[value]) |
| else: |
| dtypes.add(ScalarType.parse(value)) |
| return dtypes |
| |
| |
| DTYPE_CLASSES: Dict[str, Set[ScalarType]] = {} |
| # NB: Integral doesn't include boolean |
| DTYPE_CLASSES["Integral"] = { |
| ScalarType.Byte, |
| ScalarType.Char, |
| ScalarType.Int, |
| ScalarType.Long, |
| ScalarType.Short, |
| } |
| # NB: Floating doesn't include low precision types |
| DTYPE_CLASSES["Floating"] = {ScalarType.Float, ScalarType.Double} |
| DTYPE_CLASSES["Complex"] = {ScalarType.ComplexFloat, ScalarType.ComplexDouble} |
| DTYPE_CLASSES["All"] = DTYPE_CLASSES["Integral"] | DTYPE_CLASSES["Floating"] |
| DTYPE_CLASSES["AllAndComplex"] = DTYPE_CLASSES["All"] | DTYPE_CLASSES["Complex"] |
| DTYPE_CLASSES["FloatingAndComplex"] = ( |
| DTYPE_CLASSES["Floating"] | DTYPE_CLASSES["Complex"] |
| ) |
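| |
| # For example (illustrative), ScalarType.parse_set("Float, Int") yields |
| # {ScalarType.Float, ScalarType.Int}, while a class name expands to all of its |
| # members: ScalarType.parse_set("FloatingAndComplex") yields |
| # {Float, Double, ComplexFloat, ComplexDouble}. |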
| |
| |
| # Represents the valid entries for ufunc_inner_loop in native_functions.yaml. |
| # NB: if you add a new UfuncKey, you will need to teach torchgen.dest.ufunc how |
| # to process it. Most logic will ignore keys it doesn't understand, so your |
| # new key will get silently ignored until you hook in logic to deal with it. |
| class UfuncKey(Enum): |
| # These are low level keys that represent exactly one particular |
| # instantiation of the kernel produced by codegen |
| CUDAFunctor = auto() |
| CUDAFunctorOnOther = auto() |
| CUDAFunctorOnSelf = auto() |
| |
| CPUScalar = auto() |
| CPUVector = auto() |
| |
| # These are the ones users will usually specify, and |
| # implicitly "fill in" the low level keys |
| ScalarOnly = auto() # CUDA*, CPUScalar |
| Generic = auto() # CUDA*, CPU* |
| |
| def __str__(self) -> str: |
| return self.name |
| |
| @staticmethod |
| def parse(value: str) -> "UfuncKey": |
| for k, v in UfuncKey.__members__.items(): |
| if k == value: |
| return v |
| raise AssertionError(f"unknown ufunc key {value}") |
| |
| |
| class DeviceCheckType(Enum): |
| NoCheck = 0 |
| ExactSame = 1 |
| |
| |
| ViewSchemaKind = Enum( |
| "ViewSchemaKind", ("aliasing", "aliasing_inplace", "non_aliasing") |
| ) |
| |
| # The basic input to the code generation is native_functions.yaml. |
| # The name "native", BTW, comes from the distinction between native |
| # functions and legacy TH functions. The legacy TH functions are gone, |
| # but the "native" descriptor has stuck. |
| # |
| # NativeFunction models a single entry in native_functions.yaml. Its |
| # fields roughly correspond to what you would see in the YAML itself, |
| # but after canonicalization and parsing has occurred. |
| # |
| # You can see some of the overall design patterns for how we set up |
| # dataclasses in this class, but we will defer a complete discussion |
| # of this to FunctionSchema. |
| @dataclass(frozen=True) |
| class NativeFunction: |
| # The namespace for this operator. For example, if we have "at::add" |
| # then the namespace would be "at". This enables ops to be registered |
| # through the same DSL with a custom namespace. If not specified, the |
| # default namespace would be "at". |
| namespace: str |
| |
| # The function schema of the operator in question. This schema |
| # has been parsed; see FunctionSchema for more about its structure. |
| # (This type is quoted as we are forward referencing a type |
| # defined later in the file. I opted for this ordering of the |
| # classes for expository clarity.) |
| func: "FunctionSchema" |
| |
| # Whether or not to generate mutable tensor arguments like regular |
| # ones |
| use_const_ref_for_mutable_tensors: bool |
| |
| # Whether or not to omit automatic generation of a DeviceGuard |
| device_guard: bool |
| |
| # How to emit automatic generation of device check |
| device_check: DeviceCheckType |
| |
| # What python module to put the function in |
| python_module: Optional[str] |
| |
| # TODO: figure out what this does |
| category_override: Optional[str] |
| |
| # If no variants are specified in native_functions.yaml, this is |
| # assumed to be {'function'}. |
| variants: Set[Variant] |
| |
| # Whether or not we should skip generating registrations for |
| # this kernel. This is a bit of a double-edged sword, as manual |
| # registrations don't participate in codegen-based selective build! |
| manual_kernel_registration: bool |
| |
| # Whether or not to skip generating TensorMethod/Functions bindings |
| # for this kernel. Technically, this doesn't actually skip generating |
| # the binding; instead, the binding gets generated to __dispatch_{funcname} |
| # so you can make use of the normal binding if you need it. |
| manual_cpp_binding: bool |
| |
| # The location in the YAML file where this native function entry was |
| # defined. This is for conveniently reporting error messages! |
| loc: "Location" |
| |
| # A list of operators that are expected to be auto-generated for this NativeFunction. |
| # Note: This list isn't actually directly used by the codegen to generate anything. |
| # Instead, the codegen figures out what operators to generate purely based off of |
| # function schema, and uses the autogen declarations to error check. |
| # We expect every NativeFunction that gets auto-generated to be explicitly called out |
| # in native_functions.yaml. |
| autogen: List["OperatorName"] |
| |
| # If non-empty, this kernel is subject to ufunc codegen. |
| # Sorted by ufunc_key |
| ufunc_inner_loop: Dict[UfuncKey, "UfuncInnerLoop"] |
| |
| # Whether or not this out function is a "structured kernel". Structured |
| # kernels are defined a little differently from normal kernels; in |
| # particular, their shape checking logic is defined separately from |
| # the kernel. Only out functions can be structured; other functions |
| # delegate to the out function using the structured_delegate keyword. |
| # Every structured kernel must have at least an out and a functional |
| # variant. |
| structured: bool |
| |
| # Whether or not this non-out function is a structured kernel, defined |
| # in terms of the out kernel referenced by the string here. |
| structured_delegate: Optional["OperatorName"] |
| |
| # Only valid for structured kernels. Specifies an alternative parent class |
| # to inherit from when defining the meta class for the structured |
| # operator. This will usually be TensorIteratorBase. This also |
| # changes the semantics of set_output to call the parent class. |
| structured_inherits: Optional[str] |
| |
| # Structured kernels can declare elements as "precomputed". These elements |
| # are returned by the meta function in one struct and passed to the impl |
| # function in lieu of certain kernel arguments that these precomputed |
| # elements supersede. Information about the names and types of these |
| # precomputed elements and how they correspond to kernel arguments is stored |
| # in this member, if applicable. |
| precomputed: Optional["Precompute"] |
| |
| # Argument names whose default should be excluded from the C++ interface. |
| # Intended for resolving overload ambiguities between signatures. |
| cpp_no_default_args: Set[str] |
| |
| # Note [Abstract ATen methods] |
| # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| # An abstract ATen method is one whose dispatch differs between |
| # types. These are implemented in derived types (with a |
| # standard (throwing) definition in Type). A concrete ATen |
| # method is one which has the same dispatch for all types; |
| # we just implement it in the base Type. This is exposed |
| # in Declarations.yaml via a field named 'abstract'. |
| is_abstract: bool |
| |
| # Whether or not the NativeFunction contains a backend-agnostic kernel |
| has_composite_implicit_autograd_kernel: bool |
| has_composite_explicit_autograd_kernel: bool |
| has_composite_explicit_autograd_non_functional_kernel: bool |
| |
| # Tags are used to describe semantic information about (groups of) operators |
| # that isn't easily inferable directly from the operator's schema. |
| tags: Set[str] |
| |
| # NB: The benefit of defining a dataclass is that we automatically get |
| # a constructor defined for all the fields we specify. No need |
| # to explicitly write it out. |
| |
| # We parse both the NativeFunction and backend-specific information about it, which is stored in a corresponding BackendIndex. |
| @staticmethod |
| def from_yaml( |
| ei: Dict[str, object], |
| loc: "Location", |
| valid_tags: Set[str], |
| ignore_keys: Optional[Set[DispatchKey]] = None, |
| ) -> Tuple[ |
| "NativeFunction", Dict[DispatchKey, Dict["OperatorName", "BackendMetadata"]] |
| ]: |
| """ |
| Parse a NativeFunction from a dictionary as directly parsed |
| from native_functions.yaml |
| """ |
| e = ei.copy() |
| |
| funcs = e.pop("func") |
| assert isinstance(funcs, str), f"not a str: {funcs}" |
| # only support one level of namespace. E.g., aten::add |
| namespace_helper = NamespaceHelper.from_namespaced_entity( |
| namespaced_entity=funcs, max_level=1 |
| ) |
| namespace = namespace_helper.get_cpp_namespace(default="aten") |
| func = FunctionSchema.parse(namespace_helper.entity_name) |
| |
| cpp_no_default_args_list = e.pop("cpp_no_default_args", []) |
| assert isinstance(cpp_no_default_args_list, list) |
| cpp_no_default_args = set(cpp_no_default_args_list) |
| |
| use_const_ref_for_mutable_tensors = e.pop( |
| "use_const_ref_for_mutable_tensors", False |
| ) |
| assert isinstance(use_const_ref_for_mutable_tensors, bool) |
| |
| variants_s = e.pop("variants", "function") |
| assert isinstance(variants_s, str) |
| variants: Set[Variant] = set() |
| for v in variants_s.split(", "): |
| if v == "function": |
| variants.add(Variant.function) |
| elif v == "method": |
| variants.add(Variant.method) |
| else: |
| raise AssertionError(f"illegal variant {v}") |
| |
| manual_kernel_registration = e.pop("manual_kernel_registration", False) |
| assert isinstance( |
| manual_kernel_registration, bool |
| ), f"not a bool: {manual_kernel_registration}" |
| |
| manual_cpp_binding = e.pop("manual_cpp_binding", False) |
| assert isinstance(manual_cpp_binding, bool), f"not a bool: {manual_cpp_binding}" |
| |
| device_guard = e.pop("device_guard", True) |
| assert isinstance(device_guard, bool), f"not a bool: {device_guard}" |
| |
| device_check_s = e.pop("device_check", None) |
| assert device_check_s is None or isinstance( |
| device_check_s, str |
| ), f"not a str: {device_check_s}" |
| device_check: DeviceCheckType |
| if device_check_s is None: |
| device_check = DeviceCheckType.ExactSame |
| else: |
| device_check = DeviceCheckType[device_check_s] |
| |
| structured = e.pop("structured", False) |
| assert isinstance(structured, bool), f"not a bool: {structured}" |
| |
| structured_delegate_s = e.pop("structured_delegate", None) |
| assert structured_delegate_s is None or isinstance( |
| structured_delegate_s, str |
| ), f"not a str: {structured_delegate_s}" |
| assert structured_delegate_s is None or "::" not in structured_delegate_s, ( |
| "namespace is not supported in structured delegate," |
| " using the same namespace as the native function" |
| ) |
| structured_delegate: Optional[OperatorName] = None |
| if structured_delegate_s is not None: |
| structured_delegate = OperatorName.parse(structured_delegate_s) |
| |
| structured_inherits = e.pop("structured_inherits", None) |
| assert structured_inherits is None or isinstance( |
| structured_inherits, str |
| ), f"not a str: {structured_inherits}" |
| assert structured_inherits is None or "::" not in structured_inherits, ( |
| "namespace is not supported in structured inherits," |
| " using the same namespace as the native function" |
| ) |
| |
| python_module = e.pop("python_module", None) |
| assert python_module is None or isinstance( |
| python_module, str |
| ), f"not a str: {python_module}" |
| assert ( |
| python_module is None or Variant.method not in variants |
| ), "functions in modules cannot be methods" |
| |
| category_override = e.pop("category_override", None) |
| assert category_override is None or isinstance( |
| category_override, str |
| ), f"not a str: {category_override}" |
| |
| precomputed_dict = e.pop("precomputed", None) |
| assert precomputed_dict is None or structured is True |
| precomputed = Precompute.parse(precomputed_dict) if precomputed_dict else None |
| |
| tags_s = e.pop("tags", "") |
| assert isinstance(tags_s, str) |
| tags: Set[str] = set() |
| if len(tags_s) > 0: |
| assert len(valid_tags) > 0 |
| for t in tags_s.split(", "): |
| # TODO: verify that the tag is valid and has an entry in tags.yaml |
| if t in valid_tags: |
| tags.add(t) |
| else: |
| raise AssertionError(f"illegal tag {t}") |
| assert isinstance(tags, set) |
| |
| from torchgen.api import cpp |
| |
| raw_dispatch = e.pop("dispatch", None) |
| assert raw_dispatch is None or isinstance(raw_dispatch, dict), e |
| dispatch: Dict[DispatchKey, BackendMetadata] = {} |
| if raw_dispatch is not None: |
| assert not manual_kernel_registration, ( |
| "cannot specify both manual_kernel_registration and dispatch; with " |
| "manual registration, dispatch has no effect!" |
| ) |
| redundant_composite_implicit_autograd = False |
| for ks, v in raw_dispatch.items(): |
| if ks == "__line__": |
| continue # not worth tracking line numbers for dispatch entries |
| assert isinstance(ks, str), e |
| for k in ks.split(","): |
| dispatch_key = DispatchKey.parse(k.strip()) |
| if ignore_keys and dispatch_key in ignore_keys: |
| continue |
| assert dispatch_key in dispatch_keys, ( |
| f"Dispatch key {dispatch_key} of kernel {v} " |
| "is not a supported dispatch key." |
| ) |
| # We only allow at most 2 levels of namespace for kernels. |
| # We will append "native" to a custom kernel namespace. |
| namespace_helper = NamespaceHelper.from_namespaced_entity( |
| v, max_level=2 |
| ) |
| kernel_namespace = namespace_helper.get_cpp_namespace(default="at") |
| # Why is 'structured' included? External backends (e.g. |
| # XLA) opt into which ops are structured independently |
| # of which in-tree ops are structured |
| dispatch[dispatch_key] = BackendMetadata( |
| kernel=namespace_helper.entity_name, |
| structured=structured |
| and is_structured_dispatch_key(dispatch_key), |
| cpp_namespace=(kernel_namespace + "::native"), |
| ) |
| if ( |
| dispatch_key is DispatchKey.CompositeImplicitAutograd |
| and v == cpp.name(func) |
| ): |
| redundant_composite_implicit_autograd = True |
| |
| assert not (len(dispatch) == 1 and redundant_composite_implicit_autograd), ( |
| "unnecessary dispatch table for this function; just delete the dispatch " |
| "key entirely" |
| ) |
| # if a function is a structured delegate, deleting the dispatch |
| # table is NOT semantics preserving |
| assert structured_delegate or dispatch.keys() != { |
| DispatchKey.CompositeImplicitAutograd |
| }, ( |
| f"unexpected name for singleton CompositeImplicitAutograd dispatch entry: expected {cpp.name(func)} " |
| f"but got {dispatch[DispatchKey.CompositeImplicitAutograd]}. Rename your implementation to the expected " |
| "name, then delete the dispatch table" |
| ) |
| elif not structured and structured_delegate is None: |
| name = str(func.name.name) |
| assert not ( |
| name.startswith("new_") |
| or name.endswith("_like") |
| # TODO: maybe it's better to test the return |
| or ( |
| func.arguments.tensor_options |
| and not func.arguments.has_tensor_arg() |
| ) |
| ), ( |
| f"expected {name} to have a CompositeExplicitAutograd " |
| "dispatch entry, but there was no dispatch table. Factory functions " |
| "should not have implicit dispatch as they should not be decomposed " |
| "for __torch_dispatch__" |
| ) |
| dispatch[DispatchKey.CompositeImplicitAutograd] = BackendMetadata( |
| cpp.name(func), structured=False, cpp_namespace=DEFAULT_KERNEL_NAMESPACE |
| ) |
| |
| composites_in_dispatch = [ |
| d |
| for d in dispatch |
| if d == DispatchKey.CompositeExplicitAutograd |
| or d == DispatchKey.CompositeExplicitAutogradNonFunctional |
| or d == DispatchKey.CompositeImplicitAutograd |
| ] |
| |
| assert len(composites_in_dispatch) <= 1, ( |
| "cannot specify more than one of CompositeExplicitAutograd, CompositeExplicitAutogradNonFunctional, " |
| "or CompositeImplicitAutograd on a single kernel; each " |
| "strictly subsumes the other. If you wanted to provide an explicit autograd " |
| "implementation, specify CompositeExplicitAutograd; otherwise specify CompositeImplicitAutograd only" |
| ) |
| |
| autogen_str = e.pop("autogen", "") |
| assert isinstance(autogen_str, str) |
| autogen = ( |
| [] |
| if autogen_str == "" |
| else [OperatorName.parse(x) for x in autogen_str.split(", ")] |
| ) |
| |
| raw_ufunc_inner_loop = e.pop("ufunc_inner_loop", {}) |
| ufunc_inner_loop = {} |
| if isinstance(raw_ufunc_inner_loop, str): |
| ufunc_inner_loop[UfuncKey.Generic] = UfuncInnerLoop.parse( |
| raw_ufunc_inner_loop, UfuncKey.Generic |
| ) |
| elif isinstance(raw_ufunc_inner_loop, dict): |
| for k, vo in raw_ufunc_inner_loop.items(): |
| if k == "__line__": |
| continue |
| assert isinstance(k, str), f"ufunc_inner_loop key is not a str: {k}" |
| assert isinstance(vo, str), f"ufunc_inner_loop value is not a str: {vo}" |
| ufunc_key = UfuncKey.parse(k) |
| ufunc_inner_loop[ufunc_key] = UfuncInnerLoop.parse(vo, ufunc_key) |
| else: |
| raise AssertionError( |
| f"ufunc_inner_loop not str or dict: {raw_ufunc_inner_loop}" |
| ) |
| # Program the BackendIndex for the implicit dispatch entry from ufunc |
| if ufunc_inner_loop: |
| assert structured, "ufunc must be structured" |
| |
| # Delay import ufunc here to avoid circular import issue |
| # See: https://github.com/pytorch/pytorch/issues/81294 |
| import torchgen.api.ufunc as ufunc |
| |
| for dispatch_key in UFUNC_DISPATCH_KEYS: |
| assert ( |
| dispatch_key not in dispatch |
| ), f"ufunc should not have explicit dispatch entry for {dispatch_key}" |
| dispatch[dispatch_key] = BackendMetadata( |
| kernel=ufunc.schema_kernel_name(func, dispatch_key), |
| structured=True, |
| cpp_namespace=DEFAULT_KERNEL_NAMESPACE, |
| ) |
| |
| if structured_delegate: |
| # Structured functions MUST have a dispatch table |
| is_abstract = True |
| else: |
| is_abstract = dispatch.keys() != {DispatchKey.CompositeImplicitAutograd} |
| |
| has_composite_implicit_autograd_kernel = ( |
| DispatchKey.CompositeImplicitAutograd in dispatch.keys() |
| ) |
| has_composite_explicit_autograd_kernel = ( |
| DispatchKey.CompositeExplicitAutograd in dispatch.keys() |
| ) |
| has_composite_explicit_autograd_non_functional_kernel = ( |
| DispatchKey.CompositeExplicitAutogradNonFunctional in dispatch.keys() |
| ) |
| |
| # We aren't going to store dispatch metadata inline in NativeFunctions; |
| # instead it is separately indexed by backend (so other backends can |
| # add more dispatch entries after the fact). Reindex the individual |
| # metadata by OperatorName! |
| backend_metadata = {k: {func.name: v} for k, v in dispatch.items()} |
| |
| # don't care if it exists or not; make it easier to use this function |
| # with other yaml parsers that aren't setting __line__ in the dict |
| e.pop("__line__", None) |
| assert not e, f"leftover entries: {e}" |
| |
| # Asserts that we can't do in post_init, because they rely on backend-specific info |
| if structured_delegate is not None: |
| for key in STRUCTURED_DISPATCH_KEYS: |
| assert key not in dispatch, ( |
| f"if structured_delegate, then must not have {key} in dispatch dictionary " |
| "(it is delegated!)" |
| ) |
| |
| return ( |
| NativeFunction( |
| func=func, |
| use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors, |
| variants=variants, |
| structured=structured, |
| structured_delegate=structured_delegate, |
| structured_inherits=structured_inherits, |
| precomputed=precomputed, |
| autogen=autogen, |
| ufunc_inner_loop=ufunc_inner_loop, |
| manual_kernel_registration=manual_kernel_registration, |
| manual_cpp_binding=manual_cpp_binding, |
| python_module=python_module, |
| category_override=category_override, |
| device_guard=device_guard, |
| device_check=device_check, |
| loc=loc, |
| cpp_no_default_args=cpp_no_default_args, |
| is_abstract=is_abstract, |
| has_composite_implicit_autograd_kernel=has_composite_implicit_autograd_kernel, |
| has_composite_explicit_autograd_kernel=has_composite_explicit_autograd_kernel, |
| has_composite_explicit_autograd_non_functional_kernel=has_composite_explicit_autograd_non_functional_kernel, |
| tags=tags, |
| namespace=namespace, |
| ), |
| backend_metadata, |
| ) |
| |
| def symints_to_ints(self) -> "NativeFunction": |
| return dataclasses.replace(self, func=self.func.symints_to_ints()) |
| |
| def validate_unstructured(self) -> None: |
| # TODO: probably better to accumulate these errors and report them all |
| # at once |
| assert not self.structured, ( |
| "This function is structured, but there was " |
| "no valid functional variant of it." |
| ) |
| assert self.structured_delegate, ( |
| "This function delegates to another structured out function, " |
| "but no valid function was found (the delegate may not exist, or it has the wrong type)" |
| ) |
| |
| # __post_init__ functions in dataclasses can be used to do extra |
| # validation after construction. |
| # |
| # Notice that we don't do any type validation here. In fact, we |
| # rely exclusively on mypy to check if you've done types correctly! |
| # Validation is for nontrivial invariants that cannot be (conveniently) |
| # encoded in the type system. |
| def __post_init__(self) -> None: |
| if self.func.arguments.out: |
| assert self.variants == {Variant.function}, ( |
| "Native functions with out arguments MUST " |
| "be declared with only function variant; e.g., variants: function; " |
| "otherwise you will tickle a Python argument binding bug " |
| "(which usually manifests itself as the result variable being undefined.)" |
| ) |
| if self.structured: |
| assert self.func.kind() == SchemaKind.out, ( |
| "Put structured field on the out= " |
| "variant of a function; did you mean structured_delegate?" |
| ) |
| assert ( |
| self.device_guard |
| ), "device_guard: False is not respected by structured kernels" |
| if self.structured_delegate: |
| assert self.func.kind() != SchemaKind.out, ( |
| "structured_delegate field not allowed " |
| "on out= functions; did you mean structured?" |
| ) |
| assert ( |
| self.device_guard |
| ), "device_guard: False is not respected by structured kernels" |
| # Technically, with the asserts above, this assert can never |
| # fire |
| assert not ( |
| self.structured and self.structured_delegate |
| ), "Cannot have both structured and structured_delegate on function" |
| defaulted_arguments = { |
| a.name for a in self.func.schema_order_arguments() if a.default is not None |
| } |
| invalid_args = set.difference(self.cpp_no_default_args, defaulted_arguments) |
| assert len(invalid_args) == 0, f"Invalid cpp_no_default_args: {invalid_args}" |
| if self.structured_inherits is not None: |
| assert ( |
| self.structured |
| ), "structured_inherits requires structured: True" |
| if str(self.func.name).startswith("_foreach"): |
| assert self.device_check == DeviceCheckType.NoCheck, ( |
| "foreach kernels fall back to the slow path when tensors are on different devices, " |
| "so device_check is not allowed to be enabled" |
| ) |
| |
| @property |
| def has_composite_kernel(self) -> bool: |
| return ( |
| self.has_composite_implicit_autograd_kernel |
| or self.has_composite_explicit_autograd_kernel |
| or self.has_composite_explicit_autograd_non_functional_kernel |
| ) |
| |
| @property |
| def is_view_op(self) -> bool: |
| rets = self.func.returns |
| is_non_mutating_view = len(rets) > 0 and any( |
| r.annotation is not None and not r.annotation.is_write for r in rets |
| ) |
| is_inplace_view = "inplace_view" in self.tags |
| is_wildcard_view = any( |
| inp.annotation is not None and inp.annotation.alias_set_after != "" |
| for inp in self.func.schema_order_arguments() |
| ) |
| return is_non_mutating_view or is_inplace_view or is_wildcard_view |
| |
| @property |
| def view_schema_kind(self) -> ViewSchemaKind: |
| if self.is_view_op and self.func.name.name.inplace: |
| assert "inplace_view" in self.tags |
| return ViewSchemaKind.aliasing_inplace |
| if self.is_view_op: |
| return ViewSchemaKind.aliasing |
| else: |
| return ViewSchemaKind.non_aliasing |
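| # Illustrative examples: "transpose(Tensor(a) self, int dim0, int dim1) -> Tensor(a)" |
| # is classified as ViewSchemaKind.aliasing; its in-place counterpart transpose_ |
| # (tagged inplace_view) as ViewSchemaKind.aliasing_inplace; and a plain op such as |
| # "add.Tensor" as ViewSchemaKind.non_aliasing. |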
| |
| @property |
| def root_name(self) -> str: |
| return self.func.name.name.base |
| |
| |
| SchemaKind = Enum("SchemaKind", ("functional", "inplace", "out", "mutable", "scratch")) |
| |
| # A structured kernel is guaranteed to have a functional and out variant, and |
| # optionally an inplace variant. |
| # |
| # NB: we create NativeFunctionsGroup *even if* the function is not |
| # actually annotated structured. Test the structured boolean to see if it |
| # actually is structured or not. |
| @dataclass(frozen=True) |
| class NativeFunctionsGroup: |
| functional: NativeFunction |
| inplace: Optional[NativeFunction] |
| mutable: Optional[NativeFunction] |
| out: NativeFunction |
| |
| @property |
| def structured(self) -> bool: |
| # Whether or not the operator has a meta() function. This information is backend-agnostic. |
| return self.out.structured |
| |
| def __post_init__(self) -> None: |
| test_sig: FunctionSchema = self.functional.func.signature() |
| for f in self.functions(): |
| if test_sig != f.func.signature(): |
| raise AssertionError( |
| "NativeFunctionsGroup constructed from two NativeFunctions " |
| f"that don't have matching signatures: {test_sig} != {f.func.signature()}" |
| ) |
| assert self.functional.func.kind() == SchemaKind.functional |
| assert self.out.func.kind() == SchemaKind.out |
| assert self.functional.namespace == self.out.namespace |
| if self.inplace is not None: |
| assert self.inplace.func.kind() == SchemaKind.inplace |
| assert self.inplace.namespace == self.functional.namespace |
| |
| if self.mutable is not None: |
| assert self.mutable.func.kind() == SchemaKind.mutable |
| assert self.mutable.namespace == self.functional.namespace |
| # See Note [Overload Ambiguity With Functional Variants] |
| assert self.functional.func.name.name.functional_overload |
| |
| if self.structured: |
| # For now, structured composite kernels are not supported (need some |
| # design work to figure out how to make the composite case work) |
| assert not self.out.has_composite_implicit_autograd_kernel |
| |
| assert self.functional.structured_delegate == self.out.func.name, ( |
| f"{self.functional.func.name} delegates to {self.functional.structured_delegate} " |
| f"but its actual delegate is {self.out.func.name}" |
| ) |
| if self.inplace is not None: |
| assert self.inplace.structured_delegate == self.out.func.name |
| |
| generated_fns = [ |
| str(f.func.name) for f in self.functions() if "generated" in f.tags |
| ] |
| generated_fns_str = ", ".join(str(x) for x in generated_fns) |
| expected_generated_fns = f.autogen |
| expected_generated_fns_str = ", ".join(str(x) for x in expected_generated_fns) |
| if len(expected_generated_fns) == 0 and len(generated_fns) > 0: |
| raise RuntimeError( |
| f"The codegen expects to be able to generate '{generated_fns_str}'." |
| " In order to generate them however, we expect them to be called out explicitly in the yaml." |
| f" Please add an 'autogen: {generated_fns_str}' line to the entry for {str(f.func.name)}" |
| ) |
| if expected_generated_fns_str != generated_fns_str: |
| raise RuntimeError( |
| f"The codegen expects to be able to generate '{generated_fns_str}'." |
| f" To do so, it expects a line: 'autogen: {generated_fns_str}'." |
| f" Instead, it found 'autogen: {expected_generated_fns_str}'" |
| ) |
| |
| def signature(self) -> "FunctionSchema": |
| return self.out.func.signature() |
| |
| def functions(self) -> Iterator[NativeFunction]: |
| yield self.functional |
| yield self.out |
| if self.inplace is not None: |
| yield self.inplace |
| if self.mutable is not None: |
| yield self.mutable |
| |
| @property |
| def root_name(self) -> str: |
| return self.functional.root_name |
| |
| @staticmethod |
| def from_dict( |
| d: Dict[SchemaKind, NativeFunction] |
| ) -> Optional["NativeFunctionsGroup"]: |
| assert d |
| if len(d) == 1: |
| return None |
| d = dict(d) # non-destructive updates please |
| functional = d.pop(SchemaKind.functional, None) |
| inplace = d.pop(SchemaKind.inplace, None) |
| mutable = d.pop(SchemaKind.mutable, None) |
| out = d.pop(SchemaKind.out, None) |
| assert not d |
| assert functional is not None |
| # There are a few operators which only have functional/inplace variants; |
| # these don't count as structured for our purposes here |
| if out is None: |
| return None |
| # assuming all variants have the same namespace |
| return NativeFunctionsGroup( |
| functional=functional, |
| inplace=inplace, |
| mutable=mutable, |
| out=out, |
| ) |
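| # For example (with illustrative NativeFunction values): passing a dict with the |
| # functional, inplace and out variants of an op returns a fully populated group, |
| # while a dict containing only a functional entry returns None, since there is no |
| # out= variant to group against. |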
| |
| |
| @dataclass(frozen=True) |
| class BackendMetadata: |
| # The name of the backend kernel, for a given operator, |
| # for in-tree backends. These names come directly from the 'dispatch' field |
| # in native_functions.yaml. The dispatch entry is optional; in that |
| # case, that is equivalent to having written: |
| # |
| # dispatch: |
| # CompositeImplicitAutograd: $operator_name |
| kernel: str |
| # Whether or not the operator has a structured kernel implemented, for this particular backend. |
| # For in-tree backends, they all have the same value for structured; this is listed |
| # in native_functions.yaml. |
| # However, external backends like XLA can independently toggle which ops are structured. |
| structured: bool |
| |
| # The namespace for kernels, default value: DEFAULT_KERNEL_NAMESPACE |
| cpp_namespace: str |
| |
| |
| @dataclass(frozen=True) |
| class UfuncInnerLoop: |
| name: str |
| supported_dtypes: Set[ScalarType] |
| # key is stored here because it affects the semantics of name, |
| # so it's helpful to have them together for further processing |
| ufunc_key: UfuncKey |
| |
| @staticmethod |
| def parse(value: str, ufunc_key: UfuncKey) -> "UfuncInnerLoop": |
| name, supported_dtypes_str = value.split(" ", 1) |
| assert supported_dtypes_str[0] == "(" |
| assert supported_dtypes_str[-1] == ")" |
| supported_dtypes = set() |
| for k in supported_dtypes_str[1:-1].split(", "): |
| supported_dtypes |= ScalarType.parse_set(k) |
| return UfuncInnerLoop( |
| name=name, supported_dtypes=supported_dtypes, ufunc_key=ufunc_key |
| ) |
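| # Hedged illustration of the accepted syntax: a yaml value such as |
| #   "add (AllAndComplex, BFloat16, Half)" |
| # parsed under UfuncKey.Generic produces UfuncInnerLoop(name="add", |
| # ufunc_key=UfuncKey.Generic) with supported_dtypes equal to the AllAndComplex |
| # class plus BFloat16 and Half. |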
| |
| |
| # BackendIndex represents a backend. |
| # The BackendIndex encodes per-operator information that is potentially different |
| # for each backend. The most obvious example is the name of the kernel |
| # (the 'dispatch' entry in native_functions.yaml). |
| # However, there can be other examples of different backends having different information. |
| # External backends can choose to opt their kernels to be structured independently from in-tree backends, |
| # which means that this information isn't inherently tied to a NativeFunction; it's different per backend. |
| @dataclass(frozen=True) |
| class BackendIndex: |
| dispatch_key: DispatchKey |
| # Mainly important for structured kernels, this determines which variant in the operator group is used to implement the others. |
| # All in-tree ops use out kernels, while XLA uses functional kernels. |
| use_out_as_primary: bool |
| # Whether the backend requires a device guard, and device checks. |
| # For in-tree backends, this is currently just CUDA/HIP |
| # For out-of-tree backends, this is currently just Intel XPU |
| device_guard: bool |
| # Whether the backend is in-tree (CPU/CUDA) or out-of-tree (XLA) |
| external: bool |
| # Other backend-specific information that is on a per-operator basis |
| index: Dict["OperatorName", BackendMetadata] |
| |
| @staticmethod |
| def grow_index( |
| parent_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]], |
| child_index: Dict[DispatchKey, Dict["OperatorName", BackendMetadata]], |
| ) -> None: |
| for k, v in child_index.items(): |
| for op_name, metadata in v.items(): |
| assert ( |
| op_name not in parent_index[k] |
| ), f"duplicate operator {op_name} for dispatch key {k}" |
| parent_index[k][op_name] = metadata |
| |
| def primary(self, g: NativeFunctionsGroup) -> NativeFunction: |
| if self.use_out_as_primary: |
| return g.out |
| else: |
| return g.functional |
| |
| def has_kernel(self, g: Union[NativeFunction, NativeFunctionsGroup]) -> bool: |
| m = self.get_kernel(g) |
| return m is not None |
| |
| def get_kernel( |
| self, g: Union[NativeFunction, NativeFunctionsGroup] |
| ) -> Optional[BackendMetadata]: |
| if isinstance(g, NativeFunction): |
| f = g |
| elif isinstance(g, NativeFunctionsGroup): |
| f = self.primary(g) |
| else: |
| assert_never(f) |
| if f.func.name not in self.index: |
| return None |
| return self.index[f.func.name] |
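| # Sketch of a lookup (illustrative values): for a BackendIndex with |
| # dispatch_key=DispatchKey.CPU, get_kernel(f) may return something like |
| # BackendMetadata(kernel="add_out", structured=True, cpp_namespace="at::native"), |
| # or None if this backend provides no kernel for f. |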
| |
| def native_function_class_name(self) -> Optional[str]: |
| if self.external: |
| return f"{str(self.dispatch_key)}NativeFunctions" |
| else: |
| # TODO: This discrepancy isn't required; we could also generate |
| # a class for in-tree kernels. It'll just require carefully |
| # updating every kernel definition + callsite of every in-tree aten kernel. |
| return None |
| |
| |
| # The function schema is undoubtedly the most important data structure |
| # in all of the codegen, as it defines the type signature for operators, |
| # and most of the code generation we do is type directed (e.g., look at |
| # the types, decide what to do; think about how we code-generate |
| # C++ function stubs!) |
| # |
| # We will also see in this class the general structure for how we model |
| # data in this code generation. A few notable properties to point out |
| # ahead of time: |
| # |
| # - These dataclasses are a *lossless* representation of the strings |
| # they are parsed from. In fact, we assert that given the |
| # information stored in the dataclass, we can exactly reconstruct |
| # the string we parsed from (and assert this inside the parse |
| # definition). There are a few reasons for this: |
| # |
| # - If you find that it is difficult to reconstruct the string |
| # given a dataclass, that is a clue that your data |
| # representation is wrong. |
| # |
| # - It helps ensure that all relevant information is present |
| # in the dataclass, so that downstream users aren't tempted |
| # to reparse the original string to get some information |
| # that was omitted. |
| # |
| # - It forces you to represent the data in-memory in the same way |
| # it is recorded textually, which makes the dataclasses easier |
| # to understand for someone who is familiar with the |
| # textual format. (As a tradeoff, it means you have to model |
| # the syntax, even when it is inconvenient. But maybe that means |
| # the syntax is bad!) If you don't understand the internal |
| # representation, go look at the printing code to see how |
| # it maps onto the surface syntax! |
| # |
| # - It makes it easy to test the parsing code, as parsing code |
| # that is inconsistent with the string code will fail early |
| # and loudly. (As a tradeoff, it makes the parsing code a bit |
| # brittle; in particular, with trivial whitespace changes you |
| # are likely to trigger an assert error.) |
| # |
| # In general, try to make the __str__ code as simple as possible |
| # (even at the cost of more complex parsing logic.) Additionally, |
| # try to minimize redundancy in data representation. (Precomputed |
| # fields are OK though: they are defined as a simple function on |
| # the canonical representation in question.) |
| # |
| # - These dataclasses are all frozen; once constructed their |
| # values never change. This makes it easy to tell where any |
| # given data came from: just look to the constructor. As a |
| # tradeoff, you can't easily "decorate" a schema with extra |
| # information from a post-facto analysis. We impose this |
| # restriction to make these structures more understandable. |
| # |
| @dataclass(frozen=True) |
| class FunctionSchema: |
| # The name of the operator this function schema describes. |
| name: "OperatorName" |
| |
| arguments: "Arguments" |
| |
| # TODO: Need to handle collisions with argument names at some point |
| returns: Tuple["Return", ...] |
| |
| def schema_order_arguments(self) -> Iterator["Argument"]: |
| return itertools.chain( |
| self.arguments.flat_positional, |
| self.arguments.flat_kwarg_only, |
| self.arguments.out, |
| ) |
| |
| decl_re = re.compile(r"(?P<name>[^\(]+)\((?P<args>.*)\) -> (?P<returns>.*)") |
| |
| def symints_to_ints(self) -> "FunctionSchema": |
| return dataclasses.replace(self, arguments=self.arguments.symints_to_ints()) |
| |
| @staticmethod |
| def parse(func: str) -> "FunctionSchema": |
| # We should probably get a proper parser here |
| decls = FunctionSchema.decl_re.findall(func) |
| assert len(decls) == 1, f"Invalid function schema: {func}" |
| ops, args, return_decl = decls[0] |
| name = OperatorName.parse(ops) |
| arguments = Arguments.parse(args) |
| returns = parse_returns(return_decl) |
| r = FunctionSchema(name=name, arguments=arguments, returns=returns) |
| assert str(r) == func, f"{str(r)} != {func}" |
| return r |
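| # For example (using a representative schema string): |
| #   FunctionSchema.parse("add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor") |
| # produces a schema whose name is add.Tensor, whose kwarg-only argument alpha has |
| # default 1, and which round-trips: str(parsed) equals the original string. |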
| |
| def returns_are_aliased(self) -> bool: |
| # We assert earlier that schemas can't have a mix of aliased and non-aliased returns |
| return any( |
| r |
| for r in self.returns |
| if r.annotation is not None and r.annotation.is_write |
| ) |
| |
| def __post_init__(self) -> None: |
| for arg, ret in zip(self.arguments.out, self.returns): |
| assert arg.annotation == ret.annotation, ( |
| "Out arguments must have matching return Tensor; furthermore, " |
| "the ith-argument needs to correspond to the ith return" |
| ) |
| # We also enforce that if you have any mutable, positional args, then they are not returned. |
| # This makes it easier to group these functions properly with their functional/out= counterparts. |
| for a in self.arguments.post_self_positional_mutable: |
| assert not any( |
| a.annotation == r.annotation for r in self.returns |
| ), f"If you have a schema with mutable positional args, we expect them to not be returned. schema: {str(self)}" |
| # Invariant: we expect out arguments to appear as keyword arguments in the schema. |
| # This means that all mutable returns should be aliased to a keyword argument |
| # (except for "self", which we explicitly don't treat as an out argument because of its use in methods) |
| # See Note [is_out_fn] |
| out_and_self = list(self.arguments.out) + [ |
| arg for arg in self.arguments.flat_positional if arg.name == "self" |
| ] |
| mutable_returns = [ |
| ret |
| for ret in self.returns |
| if ret.annotation is not None and ret.annotation.is_write |
| ] |
| immutable_returns = [ |
| ret |
| for ret in self.returns |
| if ret.annotation is None or not ret.annotation.is_write |
| ] |
| # Some assertions: We don't want any functions with a return type of "-> (Tensor(a!), Tensor)", |
| # because: |
| # (1) It's more annoying to handle properly |
| # (2) It's unnecessary - you can't method-chain on the first (mutated) output because it's part of a tuple. |
| # Instead, we expect the (a!) argument to not be returned. |
| assert ( |
| len(mutable_returns) == 0 or len(immutable_returns) == 0 |
| ), f"NativeFunctions must have either only mutable returns, or only immutable returns. Found: {str(self)}" |
| for ret in mutable_returns: |
| assert any([ret.annotation == arg.annotation for arg in out_and_self]), ( |
| 'All mutable returns must be aliased either to a keyword argument, or to "self". ' |
| "Did you forget to mark an out argument as keyword-only?" |
| ) |
| if self.arguments.out: |
| # out= ops that return their mutable inputs are only really useful for method chaining. |
| # And method chaining is only really useful if the thing you're returning is a plain Tensor. |
| # So ideally, we'd enforce that out= ops with a single plain mutable tensor should return the tensor, |
| # and all other types of out= op schemas should return void. |
| # There are a bunch of existing out= ops that return tuples of tensors though, so we're stuck with allowing that. |
| if any(a.type != BaseType(BaseTy.Tensor) for a in self.arguments.out): |
| assert len(self.returns) == 0, ( |
| "out= ops that accept tensor lists as out arguments " |
| "are expected to have no return type (since you can't do method chaining on them)" |
| ) |
| else: |
| # mutable keyword arguments whose name has the _scratch_ prefix are |
| # scratch tensors for memory planning and should not be returned |
| assert len( |
| [ |
| arg |
| for arg in self.arguments.out |
| if not arg.name.startswith("_scratch_") |
| ] |
| ) == len( |
| self.returns |
| ), "Must return as many arguments as there are out arguments, or no return at all" |
| |
| if self.name.name.inplace: |
| self_a = self.arguments.self_arg |
| assert ( |
| self_a |
| and self_a.argument.annotation |
| and self_a.argument.annotation.is_write |
| ) |
| if self_a.argument.type == BaseType(BaseTy.Tensor): |
| # All inplace ops with an ordinary `Tensor self` argument should return self, |
| # to allow for method chaining. |
| assert ( |
| len(self.returns) == 1 |
| and self.returns[0].annotation == self_a.argument.annotation |
| ) |
| else: |
| # You can't method chain on non-tensor self arguments though (like a List[Tensor]) |
| # so in all other cases we expect the return type to be none. |
| assert len(self.returns) == 0 |
| |
| if self.arguments.tensor_options is not None: |
| assert self.kind() == SchemaKind.functional, ( |
| "Found an operator that is not functional, but has tensor options arguments. " |
| "This is not allowed; tensor options arguments are only allowed for factory functions. " |
| f"schema: {str(self)}" |
| ) |
| if self.is_functional_fn(): |
| assert self.kind() == SchemaKind.functional, ( |
| "Found an operator that is not functional, but its overload contains the string 'functional'. " |
| "This is a special keyword in the codegen; please use a different overload name. " |
| f"schema: {str(self)}" |
| ) |
| |
| def is_functional_fn(self) -> bool: |
| return "functional" in self.name.overload_name |
| |
| def is_symint_fn(self) -> bool: |
| # TODO: make this more robust |
| return "SymInt" in self.name.overload_name |
| |
| def is_out_fn(self) -> bool: |
| # Note [is_out_fn] |
| # |
| # out functions are the variants which take an explicit out= argument |
| # to populate into. We need to know if a schema corresponds to an |
| # out function for several reasons: |
| # |
| # - They codegen differently in C++ API |
| # - codegen to at::add_out rather than at::add |
| # - out argument is moved to front of C++ argument list |
| # |
| # out functions are DEFINED to be any function with a keyword-only |
| # argument that is mutable. In principle, this could lead to a |
| # false positive if you define a function that mutates a |
| # kwarg only argument, but this isn't the "true" output of this |
| # function. A more robust definition that would work in this |
| # case would also look at: |
| # |
| # - The output types. Out functions take in the arguments |
| # they mutate and then return them again; this is sort |
| # of "definitionally" what makes something an out function. |
| # Historically, we DO check this for consistency. |
| # - Correspondence with pure variant. An out function |
| # should have a signature equivalent to its pure variant, |
| # but just with extra kwargs for the output elements. This |
| # is difficult to actually check for and historically |
| # we only do this check in tools/ |
| return bool(self.arguments.out) |
| |
| def kind(self) -> SchemaKind: |
| """ |
| What kind of schema is this? A functional schema is one |
| that returns a newly allocated output; an inplace schema |
| modifies the self argument inplace; an out schema writes |
| the result into an explicitly provided out argument. |
| """ |
| is_out = bool(self.arguments.out) |
| is_scratch = bool( |
| [arg for arg in self.arguments.out if arg.name.startswith("_scratch_")] |
| ) |
| is_inplace = self.name.name.inplace |
| is_mutable = any( |
| a.annotation is not None and a.annotation.is_write |
| for a in self.arguments.post_self_positional |
| ) |
| assert not (is_out and is_inplace) |
| # out= and inplace schemas can also have post_self_positional mutable args, |
| # but we give precedence to out= and inplace when deciding the schema kind. |
| # Tradeoff: we probably don't want to have to teach codegen that looks at inplace ops |
| # to also worry about mutable post_self_positional arguments, |
| # but it seems like a much bigger lift to classify them as having a new schema kind. |
| # The number of ops that fit in this strange category is small enough that |
| # we can probably manually write code for them instead of forcing the codegen to handle them. |
| if is_inplace: |
| return SchemaKind.inplace |
| elif is_scratch: |
| assert ( |
| is_out |
| ), "invariant: all scratch operators are expected to be out= operators too" |
| return SchemaKind.scratch |
| elif is_out: |
| assert ( |
| not is_scratch |
| ), "We should not categorize a scratch op as an out variant. Check if the order of if statements are expected!" |
| return SchemaKind.out |
| elif is_mutable: |
| return SchemaKind.mutable |
| else: |
| return SchemaKind.functional |
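| # Illustrative classification of representative schemas: |
| #   "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" -> functional |
| #   "add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)" -> inplace |
| #   "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)" -> out |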
| |
| # For every return: |
| # - If the return aliases an input, we return the input name |
| # - Otherwise, we return None. |
| # If return names were enforced to be consistent with aliasing information, then we wouldn't need this. |
| def aliased_return_names(self) -> List[Optional[str]]: |
| outs: List[Optional[str]] = [] |
| for r in self.returns: |
| aliased_args = [ |
| a |
| for a in self.arguments.flat_all |
| if a.annotation is not None and a.annotation == r.annotation |
| ] |
| if len(aliased_args) == 0: |
| outs.append(None) |
| elif len(aliased_args) == 1: |
| outs.append(aliased_args[0].name) |
| else: |
| aliased_names = ", ".join(a.name for a in aliased_args) |
| raise AssertionError( |
| f"Found a return ({r.name}) that aliases multiple inputs ({aliased_names})" |
| ) |
| return outs |
| |
| def signature( |
| self, |
| *, |
| strip_default: bool = False, |
| strip_view_copy_name: bool = False, |
| keep_return_names: bool = False, |
| ) -> "FunctionSchema": |
| """ |
| Certain schemas are 'related', in that they are simply |
| inplace/out/functional versions of the same function. This method |
| factors these schemas into the "core" functional signature which |
| is equal across all versions. |
| |
| Here is the normalization that happens to the schema to convert |
| it to a signature: |
| - The overload name is stripped (name is retained, since |
| it expresses semantic content about what the function does) |
| - Inplace is set False |
| - Out arguments are stripped |
| - Mutable post_self_positional args are converted to returns |
| - Mutability annotations are stripped (this is sound |
| because you cannot overload on mutability annotation) |
| - Return names are stripped since they are not overloadable and |
| some variants have return names but some not |
| - TensorOptions are dropped |
| because out= variants of factory functions don't include them |
| (and we want to be able to pair up factory functions with their out variants) |
| |
| Finally, we want to be able to pair up related "view" operators with their |
| corresponding "view_copy" operators. We do this by optionally |
| stripping the trailing "_copy" from the base name. |
| |
| Example of a mutable op before and after: |
| |
| f.func (Mutable operator): |
| _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) # noqa: B950 |
| |
| f.func (Corresponding functional operator): |
| _fused_moving_avg_obs_fq_helper.functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) # noqa: B950 |
| |
| f.func.signature() output: |
| _fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) # noqa: B950 |
| """ |
| |
| def strip_ret_annotation(r: Return) -> Return: |
| return Return( |
| name=r.name if keep_return_names else None, |
| type=r.type, |
| annotation=None, |
| ) |
| |
| base_name = self.name.name.base |
| if strip_view_copy_name and base_name.endswith("_copy"): |
| base_name = base_name[: -len("_copy")] # only strip the trailing "_copy" |
| |
| # find mutable inputs that are not originally returned, and convert them to returns |
| returns_from_mutable_inputs = tuple( |
| # When we're grouping functions we strip the return names, |
| # but when we're generating the actual functional variants then we follow |
| # a convention for what to name the returns |
| Return( |
| name=f"{a.name}_out" if keep_return_names else None, |
| type=a.type, |
| annotation=None, |
| ) |
| for a in itertools.chain( |
| # Order is important here (otherwise e.g. inplace with mutable args |
| # and out= with mutable args won't have the same signature) |
| [self.arguments.self_arg.argument] |
| if self.arguments.self_arg is not None |
| else [], |
| self.arguments.out, |
| self.arguments.post_self_positional, |
| ) |
| if a.annotation is not None |
| and a.annotation.is_write |
| and not any(a.annotation == r.annotation for r in self.returns) |
| ) |
| original_returns = tuple(map(strip_ret_annotation, self.returns)) |
| # Ordering is important here. We expect the "mutable input" returns to come last. |
| returns = original_returns + returns_from_mutable_inputs |
| |
| args_sig = self.arguments.signature(strip_default=strip_default) |
| # See Note [arange.start_step schema] |
| if str(self.name) == "arange.start_step": |
| args_sig = Arguments.parse( |
| str(args_sig).replace("Scalar step", "Scalar step=1") |
| ) |
| # See Note [bernoulli.p schema] |
| if str(self.name) == "bernoulli.p": |
| args_sig = Arguments.parse(str(args_sig).replace("float p", "float p=0.5")) |
| |
| return FunctionSchema( |
| name=OperatorName( |
| name=BaseOperatorName( |
| base=base_name, |
| inplace=False, |
| dunder_method=self.name.name.dunder_method, |
| ), |
| overload_name="", # stripped |
| ), |
| arguments=args_sig, |
| returns=returns, |
| ) |
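| |
| # For example (illustrative, assuming the standard add schemas from |
| # native_functions.yaml): add.Tensor, add_.Tensor, and add.out all |
| # normalize to the same signature: |
| # "add(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" |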
| |
| def view_signature(self) -> "FunctionSchema": |
| return self.signature(strip_view_copy_name=True) |
| |
| def with_name(self, name: "OperatorName") -> "FunctionSchema": |
| return FunctionSchema( |
| name=name, |
| arguments=self.arguments, |
| returns=self.returns, |
| ) |
| |
| @property |
| def modifies_arguments(self) -> bool: |
| return self.kind() in [SchemaKind.inplace, SchemaKind.out, SchemaKind.mutable] |
| |
| def __str__(self) -> str: |
| all_arguments_str = str(self.arguments) |
| if len(self.returns) == 1: |
| returns = str(self.returns[0]) # omit parentheses |
| else: |
| returns = "(" + ", ".join(map(str, self.returns)) + ")" |
| return f"{self.name}({all_arguments_str}) -> {returns}" |
| |
| |
| # Here is the rest of the data model, described more briefly. |
| |
| # Simplified version for what actually shows up in built-ins. |
| # Look at alias_info.h for the expanded syntax. If you need the full |
| # structure, you'll also need to make this class recursive so it can |
| # be lined up with the type components. For primitives this isn't |
| # really necessary. |
| @dataclass(frozen=True) |
| class Annotation: |
| # Typically only has one element. Stored as a tuple (rather than a set) |
| # so we can conveniently assume it is canonically ordered. |
| alias_set: Tuple[str, ...] |
| is_write: bool |
| alias_set_after: str |
| |
| @staticmethod |
| def parse(ann: str) -> "Annotation": |
| # Only handling afterSet == Wildcard for now |
| becomes_wildcard_index = ann.find(" -> *") |
| if becomes_wildcard_index != -1: |
| after_set = "*" |
| # TODO: handle the " -> *" suffix directly in the regex instead of splicing it out |
| m = re.match( |
| r"^([a-z])(!?)(!?)$", |
| ann[:becomes_wildcard_index] |
| + ann[becomes_wildcard_index + len(" -> *") :], |
| ) |
| else: |
| after_set = "" |
| m = re.match(r"^([a-z])(!?)(!?)$", ann) |
| assert m is not None, f"unrecognized alias annotation {ann}" |
| alias_set = (m.group(1),) |
| is_write = m.group(2) == "!" |
| r = Annotation( |
| alias_set=alias_set, is_write=is_write, alias_set_after=after_set |
| ) |
| assert str(r) == ann, f"{r} != {ann}" |
| return r |
| |
| def __str__(self) -> str: |
| alias_set = "|".join(self.alias_set) |
| if self.alias_set_after: |
| alias_set = f"{alias_set} -> {self.alias_set_after}" |
| is_write = "!" if self.is_write else "" |
| return f"{alias_set}{is_write}" |
| |
| |
| # The base class for the type system. This is also loosely modeled |
| # off of jit_type.h, but we've simplified the hierarchy to focus |
| # in on the aspects of the type system that matter for code generation |
| # (for example, there's no SingleElementType subclass anymore). |
| # You never actually construct a Type; usually it's going to be one |
| # of the subclasses. If Python had ADTs this would be one! |
| @dataclass(frozen=True) |
| class Type: |
| @staticmethod |
| def parse(t: str) -> "Type": |
| r = Type._parse(t) |
| assert str(r) == t, f"{r} != {t}" |
| return r |
| |
| @staticmethod |
| def _parse(t: str) -> "Type": |
| m = re.match(r"^(.+)\?$", t) |
| if m is not None: |
| return OptionalType(Type.parse(m.group(1))) |
| m = re.match(r"^(.+)\[([0-9]+)?\]$", t) |
| if m is not None: |
| size = int(m.group(2)) if m.group(2) is not None else None |
| return ListType(elem=Type.parse(m.group(1)), size=size) |
| try: |
| return BaseType(BaseTy[t]) |
| except KeyError: |
| raise RuntimeError(f"unrecognized type {t}") |
| |
| def __str__(self) -> str: |
| raise NotImplementedError |
| |
| # WARNING: These concepts are not very well-defined. For example, |
| # is "int?" nullable? How about "int?[]". They are defined |
| # so we can conveniently generate legacy Declarations.yaml but |
| # really we should probably just remove these at some point |
| |
| def is_tensor_like(self) -> bool: |
| raise NotImplementedError |
| |
| def is_nullable(self) -> bool: |
| raise NotImplementedError |
| |
| def is_list_like(self) -> Optional["ListType"]: |
| raise NotImplementedError |
| |
| def symint_to_int(self) -> "Type": |
| raise NotImplementedError |
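| |
| # Illustrative examples of Type.parse (each round-trips through str()): |
| # Type.parse("Tensor") -> BaseType(BaseTy.Tensor) |
| # Type.parse("int[2]") -> ListType(elem=BaseType(BaseTy.int), size=2) |
| # Type.parse("Tensor?[]") -> ListType(elem=OptionalType(BaseType(BaseTy.Tensor)), size=None) |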
| |
| |
| # Base types are simple, atomic types with no further structure |
| BaseTy = Enum( |
| "BaseTy", |
| ( |
| "Generator", |
| "ScalarType", |
| "Tensor", |
| "int", |
| "Dimname", |
| "DimVector", |
| "float", |
| "str", |
| "bool", |
| "Layout", |
| "Device", |
| "Scalar", |
| "MemoryFormat", |
| "QScheme", |
| "Storage", |
| "Stream", |
| "SymInt", |
| "ConstQuantizerPtr", # TODO: rename |
| ), |
| ) |
| |
| |
| @dataclass(frozen=True) |
| class BaseType(Type): |
| name: BaseTy |
| |
| def __str__(self) -> str: |
| return f"{self.name.name}" |
| |
| def is_tensor_like(self) -> bool: |
| return self.name == BaseTy.Tensor |
| |
| def is_nullable(self) -> bool: |
| return False |
| |
| def symint_to_int(self) -> "BaseType": |
| if self.name == BaseTy.SymInt: |
| return BaseType(BaseTy.int) |
| return self |
| |
| def is_list_like(self) -> Optional["ListType"]: |
| return None |
| |
| |
| # Optional types may be specified, or may also be validly given None |
| @dataclass(frozen=True) |
| class OptionalType(Type): |
| elem: Type |
| |
| def __str__(self) -> str: |
| return f"{self.elem}?" |
| |
| def is_tensor_like(self) -> bool: |
| return self.elem.is_tensor_like() |
| |
| def is_nullable(self) -> bool: |
| return True |
| |
| def symint_to_int(self) -> "Type": |
| return dataclasses.replace(self, elem=self.elem.symint_to_int()) |
| |
| def is_list_like(self) -> Optional["ListType"]: |
| return self.elem.is_list_like() |
| |
| |
| # List types specify that we may have multiples of an element. We |
| # also support explicit sizes on list types, but these have |
| # some nontrivial semantics! (However, for C++ API purposes, explicit |
| # sizes are mostly erased from the type system.) |
| # |
| # DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g., |
| # int[] elaborates differently than bool[3]! |
| @dataclass(frozen=True) |
| class ListType(Type): |
| elem: Type |
| size: Optional[int] |
| |
| def __str__(self) -> str: |
| size = f"{self.size}" if self.size else "" |
| return f"{self.elem}[{size}]" |
| |
| def is_tensor_like(self) -> bool: |
| return self.elem.is_tensor_like() |
| |
| def is_nullable(self) -> bool: |
| return self.elem.is_nullable() |
| |
| def symint_to_int(self) -> "ListType": |
| return ListType(self.elem.symint_to_int(), self.size) |
| |
| def is_list_like(self) -> Optional["ListType"]: |
| return self |
| |
| |
| @dataclass(frozen=True) |
| class Argument: |
| # NB: I didn't put kwarg_only as a boolean field here, unlike |
| # c10::Argument, so that printing works correctly |
| |
| name: str |
| type: Type |
| default: Optional[str] |
| |
| # The semantics of the annotation field are a little strange. |
| # |
| # Alias annotations parametrize Tensors (since Tensors are the only things |
| # that can alias.) This motivates why I write Tensor(a!)? (and not, for |
| # example, Tensor?(a!)), because the (a!) describes aliasing on the tensor, |
| # which may be optional (i.e., the alias annotation should bind first to |
| # Tensor, before the optional postfix annotation). |
| # |
| # However, despite being a property of Tensor, we (and c10::Argument) |
| # store the annotation at the top level of the Argument, rather than |
| # inside the embedded Tensor type. In the C++ version of this |
| # class, we then go to great lengths to mimic the type |
| # structure in the annotation structure so we can correlate |
| # annotations with types. |
| # |
| # Now, it turns out, in all applications in code generation, the |
| # structure of annotated types is very simple. So we just hard |
| # code it here. But if we ever do get anything more complex, this |
| # model will have to change! |
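| # |
| # For example (illustrative): Argument.parse("Tensor(a!) self") yields |
| # name="self", type=BaseType(BaseTy.Tensor) and annotation=Annotation.parse("a!"), |
| # while Argument.parse("Tensor(a!)? grad") (a hypothetical argument) yields |
| # an OptionalType(BaseType(BaseTy.Tensor)) carrying the same annotation. |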
| annotation: Optional[Annotation] |
| |
| @staticmethod |
| def parse(arg: str) -> "Argument": |
| name: str |
| default: Optional[str] |
| type_and_annot, name_and_default = arg.rsplit(" ", 1) |
| if "=" in name_and_default: |
| name, default = name_and_default.split("=") |
| else: |
| name = name_and_default |
| default = None |
| # TODO: deduplicate annotation matching with Return |
| match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot) |
| annotation: Optional[Annotation] |
| if match: |
| # If you update this, make sure the __str__ still works too |
| assert match.group(2) in [ |
| "", |
| "?", |
| "[]", |
| ], "unrecognized alias analysis form with Tensor" |
| type_s = "Tensor" + match.group(2) |
| annotation = Annotation.parse(match.group(1)) |
| else: |
| type_s = type_and_annot |
| annotation = None |
| type = Type.parse(type_s) |
| r = Argument( |
| name=name, |
| type=type, |
| default=default, |
| annotation=annotation, |
| ) |
| assert str(r) == arg, f"{str(r)} != {arg}" |
| return r |
| |
| @property |
| def is_write(self) -> bool: |
| return self.annotation is not None and self.annotation.is_write |
| |
| def symint_to_int(self) -> "Argument": |
| return dataclasses.replace(self, type=self.type.symint_to_int()) |
| |
| def __str__(self) -> str: |
| type = f"{self.type}" |
| if self.annotation: |
| assert type in ["Tensor", "Tensor?", "Tensor[]"] |
| type = type.replace("Tensor", f"Tensor({self.annotation})") |
| if self.name is None: |
| return type |
| else: |
| mb_default = "" |
| if self.default: |
| mb_default = f"={self.default}" |
| return f"{type} {self.name}{mb_default}" |
| |
| |
| @dataclass(frozen=True) |
| class Return: |
| name: Optional[str] |
| type: Type |
| annotation: Optional[Annotation] |
| |
| @staticmethod |
| def parse(arg: str) -> "Return": |
| name: Optional[str] |
| if " " in arg: |
| type_and_annot, name = arg.rsplit(" ", 1) |
| else: |
| type_and_annot = arg |
| name = None |
| match = re.match(r"Tensor\((.+)\)(.*)", type_and_annot) |
| annotation: Optional[Annotation] |
| if match: |
| # If you update this, make sure the __str__ still works too |
| assert match.group(2) in [ |
| "", |
| "?", |
| "[]", |
| ], "unrecognized alias analysis form with Tensor" |
| type_s = "Tensor" + match.group(2) |
| annotation = Annotation.parse(match.group(1)) |
| else: |
| type_s = type_and_annot |
| annotation = None |
| type = Type.parse(type_s) |
| r = Return( |
| name=name, |
| type=type, |
| annotation=annotation, |
| ) |
| assert str(r) == arg, f"{str(r)} != {arg}" |
| return r |
| |
| @property |
| def is_write(self) -> bool: |
| return self.annotation is not None and self.annotation.is_write |
| |
| def __str__(self) -> str: |
| type = f"{self.type}" |
| if self.annotation: |
| assert type in ["Tensor", "Tensor?", "Tensor[]"] |
| type = type.replace("Tensor", f"Tensor({self.annotation})") |
| if self.name is None: |
| return type |
| else: |
| return f"{type} {self.name}" |
| |
| |
| # Represents the self argument for functions that may be methods |
| @dataclass(frozen=True) |
| class SelfArgument: |
| argument: Argument |
| |
| |
| # Bundle of arguments that represent a TensorOptions. This is mostly |
| # relevant for the public C++ API but we bake it into the core data |
| # model because other APIs often have to interact with it |
| @dataclass(frozen=True) |
| class TensorOptionsArguments: |
| dtype: Argument |
| layout: Argument |
| device: Argument |
| pin_memory: Argument |
| |
| def all(self) -> Sequence[Argument]: |
| return [self.dtype, self.layout, self.device, self.pin_memory] |
| |
| |
| @dataclass(frozen=True) |
| class Arguments: |
| # pre_self_positional is usually empty, but is notably non-empty |
| # for where.self, where the condition argument comes before the |
| # self argument |
| pre_self_positional: Tuple[Argument, ...] |
| self_arg: Optional[SelfArgument] |
| post_self_positional: Tuple[Argument, ...] |
| |
| pre_tensor_options_kwarg_only: Tuple[Argument, ...] |
| tensor_options: Optional[TensorOptionsArguments] |
| # post_tensor_options is typically memory format, which should be |
| # part of tensor options but isn't right now, and is usually |
| # placed after the tensor options arguments |
| post_tensor_options_kwarg_only: Tuple[Argument, ...] |
| |
| # Unlike in the previous codegen, we have factored out 'out' arguments |
| # in the canonical representation, removing them from kwarg |
| # arguments. This choice is justified by numerous downstream |
| # transformations which treat out arguments specially; additionally, |
| # you can see that canonicity is not violated! |
| out: Tuple[Argument, ...] # these are also kwarg-only |
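| |
| # For example (illustrative): Arguments.parse on the add.out argument list, |
| # "Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out", produces |
| # self_arg=SelfArgument(self), post_self_positional=(other,), |
| # pre_tensor_options_kwarg_only=(alpha,), tensor_options=None and out=(out,), |
| # with the remaining fields empty. |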
| |
| @property |
| def flat_non_out(self) -> Sequence[Argument]: |
| ret: List[Argument] = [] |
| ret.extend(self.flat_positional) |
| ret.extend(self.flat_kwarg_only) |
| return ret |
| |
| @property |
| def flat_positional(self) -> Sequence[Argument]: |
| ret: List[Argument] = [] |
| ret.extend(self.pre_self_positional) |
| if self.self_arg is not None: |
| ret.append(self.self_arg.argument) |
| ret.extend(self.post_self_positional) |
| return ret |
| |
| @property |
| def post_self_positional_mutable(self) -> Sequence[Argument]: |
| return [a for a in self.post_self_positional if a.is_write] |
| |
| # NB: doesn't contain out arguments |
| @property |
| def flat_kwarg_only(self) -> Sequence[Argument]: |
| ret: List[Argument] = [] |
| ret.extend(self.pre_tensor_options_kwarg_only) |
| if self.tensor_options is not None: |
| ret.extend(self.tensor_options.all()) |
| ret.extend(self.post_tensor_options_kwarg_only) |
| return ret |
| |
| @property |
| def flat_all(self) -> Sequence[Argument]: |
| ret: List[Argument] = [] |
| ret.extend(self.flat_positional) |
| ret.extend(self.flat_kwarg_only) |
| ret.extend(self.out) |
| return ret |
| |
| @property |
| def non_out( |
| self, |
| ) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]: |
| ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = [] |
| ret.extend(self.positional) |
| ret.extend(self.kwarg_only) |
| return ret |
| |
| @property |
| def positional(self) -> Sequence[Union[Argument, SelfArgument]]: |
| ret: List[Union[Argument, SelfArgument]] = [] |
| ret.extend(self.pre_self_positional) |
| if self.self_arg is not None: |
| ret.append(self.self_arg) |
| ret.extend(self.post_self_positional) |
| return ret |
| |
| @property |
| def kwarg_only(self) -> Sequence[Union[Argument, TensorOptionsArguments]]: |
| ret: List[Union[Argument, TensorOptionsArguments]] = [] |
| ret.extend(self.pre_tensor_options_kwarg_only) |
| if self.tensor_options is not None: |
| ret.append(self.tensor_options) |
| ret.extend(self.post_tensor_options_kwarg_only) |
| return ret |
| |
| @property |
| def all(self) -> Sequence[Union[Argument, SelfArgument, TensorOptionsArguments]]: |
| ret: List[Union[Argument, SelfArgument, TensorOptionsArguments]] = [] |
| ret.extend(self.positional) |
| ret.extend(self.kwarg_only) |
| ret.extend(self.out) |
| return ret |
| |
| def mutable_arg_names(self) -> List[str]: |
| return [ |
| a.name |
| for a in self.flat_all |
| if a.annotation is not None and a.annotation.is_write |
| ] |
| |
| def symints_to_ints(self) -> "Arguments": |
| arguments = self |
| |
| if arguments.self_arg: |
| arguments = dataclasses.replace( |
| arguments, |
| pre_self_positional=[ |
| x.symint_to_int() for x in arguments.pre_self_positional |
| ], |
| ) |
| |
| if self.tensor_options: |
| arguments = dataclasses.replace( |
| arguments, |
| post_tensor_options_kwarg_only=[ |
| x.symint_to_int() for x in arguments.post_tensor_options_kwarg_only |
| ], |
| ) |
| |
| arguments = dataclasses.replace( |
| arguments, |
| post_self_positional=[ |
| x.symint_to_int() for x in arguments.post_self_positional |
| ], |
| pre_tensor_options_kwarg_only=[ |
| x.symint_to_int() for x in arguments.pre_tensor_options_kwarg_only |
| ], |
| ) |
| |
| return arguments |
| |
| def has_tensor_arg(self) -> bool: |
| return any(a.type.is_tensor_like() for a in self.flat_non_out) |
| |
| def signature(self, *, strip_default: bool = False) -> "Arguments": |
| # dataclasses.replace could be used here, but it is less |
| # type safe so for now I've opted to type everything out |
| def strip_arg_annotation(a: Argument) -> Argument: |
| return Argument( |
| name=a.name, |
| type=a.type, |
| default=a.default if not strip_default else None, |
| annotation=None, |
| ) |
| |
| return Arguments( |
| pre_self_positional=tuple( |
| map(strip_arg_annotation, self.pre_self_positional) |
| ), |
| self_arg=SelfArgument(strip_arg_annotation(self.self_arg.argument)) |
| if self.self_arg is not None |
| else None, |
| post_self_positional=tuple( |
| map(strip_arg_annotation, self.post_self_positional) |
| ), |
| # Since TensorOptions are dropped, the post_tensor_options_kwargs are |
| # converted to pre_tensor_options_kwargs |
| pre_tensor_options_kwarg_only=tuple( |
| map(strip_arg_annotation, self.pre_tensor_options_kwarg_only) |
| ) |
| + tuple(map(strip_arg_annotation, self.post_tensor_options_kwarg_only)), |
| # TensorOptions are dropped in signature, |
| # so we can pair factory functions with their out= variants. |
| tensor_options=None, |
| post_tensor_options_kwarg_only=tuple(), |
| # out arguments are dropped in signature |
| out=(), |
| ) |
| |
| def remove_self_annotation(self) -> "Arguments": |
| assert self.self_arg is not None |
| return dataclasses.replace( |
| self, |
| self_arg=SelfArgument( |
| dataclasses.replace(self.self_arg.argument, annotation=None) |
| ), |
| ) |
| |
| def with_out_args(self, outs: List[Argument]) -> "Arguments": |
| assert len(self.out) == 0 |
| return dataclasses.replace( |
| self, |
| out=tuple(outs), |
| ) |
| |
| @staticmethod |
| def _preparse(args: str) -> Tuple[List[Argument], List[Argument], List[Argument]]: |
| positional: List[Argument] = [] |
| kwarg_only: List[Argument] = [] |
| out: List[Argument] = [] |
| arguments_acc = positional |
| |
| # TODO: Use a real parser here; this will get bamboozled |
| # by signatures that contain things like std::array<bool, 2> (note the space) |
| for arg in args.split(", "): |
| if not arg: |
| continue |
| if arg == "*": |
| assert ( |
| arguments_acc is positional |
| ), "invalid syntax: kwarg-only specifier * can only occur once" |
| arguments_acc = kwarg_only |
| continue |
| parg = Argument.parse(arg) |
| # Currently, we rely directly on the invariant that there are NO |
| # kwarg-only mutating arguments. If you want to relax this, |
| # we will need a more semantic way of matching that takes |
| # into account return arguments. In that case, you will have |
| # to manage out computation a level up, in FunctionSchema. See Note |
| # [is_out_fn] |
| if parg.annotation is not None and parg.annotation.is_write: |
| if arguments_acc is positional: |
| pass # do nothing |
| elif arguments_acc is kwarg_only: |
| arguments_acc = out |
| else: |
| assert arguments_acc is not out |
| arguments_acc.append(parg) |
| |
| return positional, kwarg_only, out |
| |
| @staticmethod |
| def parse(args: str) -> "Arguments": |
| """ |
| Input: 'int x, int y, int z' |
| """ |
| |
| # We do this in two phases. First we parse into three |
| # main categories: positional, kwarg_only, out. |
| # Then, we reparse positional and kwarg_only to separate |
| # out the self argument and tensor options arguments. |
| |
| positional, kwarg_only, out = Arguments._preparse(args) |
| |
| # Split self argument |
| self_ix = None |
| for i, a in enumerate(positional): |
| if a.name == "self": |
| self_ix = i |
| break |
| pre_self_positional: List[Argument] |
| self_arg: Optional[SelfArgument] |
| post_self_positional: List[Argument] |
| if self_ix is not None: |
| pre_self_positional = positional[:self_ix] |
| self_arg = SelfArgument(positional[self_ix]) |
| post_self_positional = positional[self_ix + 1 :] |
| else: |
| pre_self_positional = [] |
| self_arg = None |
| post_self_positional = positional |
| |
| # Group tensor options arguments |
| pre_tensor_options_kwarg_only: List[Argument] = [] |
| tensor_options: Optional[TensorOptionsArguments] = None |
| post_tensor_options_kwarg_only: List[Argument] = [] |
| kwarg_only_acc = pre_tensor_options_kwarg_only |
| |
| def pred(name: str, ty: Type) -> Callable[[Argument], bool]: |
| return lambda a: a.name == name and a.type in [ty, OptionalType(ty)] |
| |
| predicates = [ # order matters |
| pred("dtype", Type.parse("ScalarType")), |
| pred("layout", Type.parse("Layout")), |
| pred("device", Type.parse("Device")), |
| pred("pin_memory", Type.parse("bool")), |
| ] |
| |
| i = 0 |
| while i < len(kwarg_only): |
| # If there is enough space... |
| if i <= len(kwarg_only) - len(predicates): |
| # And the next len(predicates) arguments look like TensorOptions arguments |
| if all( |
| p(a) |
| for p, a in zip(predicates, kwarg_only[i : i + len(predicates)]) |
| ): |
| assert kwarg_only_acc is pre_tensor_options_kwarg_only |
| # Group them together as one argument |
| tensor_options = TensorOptionsArguments( |
| dtype=kwarg_only[i], |
| layout=kwarg_only[i + 1], |
| device=kwarg_only[i + 2], |
| pin_memory=kwarg_only[i + 3], |
| ) |
| i += len(predicates) |
| kwarg_only_acc = post_tensor_options_kwarg_only |
| continue |
| kwarg_only_acc.append(kwarg_only[i]) |
| i += 1 |
| |
| return Arguments( |
| pre_self_positional=tuple(pre_self_positional), |
| self_arg=self_arg, |
| post_self_positional=tuple(post_self_positional), |
| pre_tensor_options_kwarg_only=tuple(pre_tensor_options_kwarg_only), |
| tensor_options=tensor_options, |
| post_tensor_options_kwarg_only=tuple(post_tensor_options_kwarg_only), |
| out=tuple(out), |
| ) |
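| |
| # For example (illustrative): given a factory-style kwarg list such as |
| # "Tensor self, *, ScalarType? dtype=None, Layout? layout=None, |
| # Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None", |
| # the dtype/layout/device/pin_memory arguments are folded into a single |
| # TensorOptionsArguments and memory_format lands in post_tensor_options_kwarg_only. |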
| |
| def __str__(self) -> str: |
| all_arguments: List[str] = [] |
| all_arguments.extend(map(str, self.flat_positional)) |
| if self.flat_kwarg_only or self.out: |
| all_arguments.append("*") |
| all_arguments.extend(map(str, self.flat_kwarg_only)) |
| all_arguments.extend(map(str, self.out)) |
| return ", ".join(all_arguments) |
| |
| def __post_init__(self) -> None: |
| # TODO: These invariants are weirdly asymmetric? |
| # TODO: Fancier types? |
| if self.self_arg is None: |
| assert not self.pre_self_positional |
| if self.tensor_options is None: |
| assert not self.post_tensor_options_kwarg_only |
| |
| # We don't allow any of the following to have argument annotations, |
| # to keep things simple. |
| mutable_pre_self_positionals = [ |
| a |
| for a in self.pre_self_positional |
| if a.annotation is not None and a.annotation.is_write |
| ] |
| assert ( |
| len(mutable_pre_self_positionals) == 0 |
| ), "mutable pre_self_positional arguments are not currently supported in the schema" |
| |
| |
| # Names that validly are __iXXX__ indicating inplace operations. |
| # Taken from https://www.python.org/dev/peps/pep-0203/#new-methods |
| # NB: PyTorch hasn't actually implemented all of these |
| AUGMENTED_ASSIGNMENT_NAMES = [ |
| "add", |
| "sub", |
| "mul", |
| "div", |
| "mod", |
| "pow", |
| "lshift", |
| "rshift", |
| "and", |
| "xor", |
| "or", |
| ] |
| |
| # A BaseOperatorName is what we think of as the operator name, without |
| # the overload name. Unusually, we don't represent this as just a |
| # string; instead, we directly represent a few important semantic |
| # bits of information we derive from the string: namely whether |
| # or not it's inplace (add_) and whether or not it's a double-underscore |
| # method (__add__) |
| @dataclass(frozen=True) |
| class BaseOperatorName: |
| base: str |
| inplace: bool |
| dunder_method: bool |
| # Note [Overload Ambiguity With Functional Variants] |
| # A handful of operators have both a "mutable" and a "functional" variant. |
| # (native_batch_norm would be a good example, although it isn't set up this way today). |
| # For those operators, the mutable and functional variant take in the same set of |
| # arguments, but have different alias annotations. |
| # This makes it ambiguous when you try to resolve an OverloadPacket into an overload, |
| # given a set of input arguments. |
| # |
| # So instead of making the "functional" variant in this case a real overload, e.g: |
| # native_batch_norm (mutable variant) |
| # native_batch_norm.functional (functional variant) |
| # we make it a new base operator, |
| # native_batch_norm_functional (functional variant) |
| # |
| # In an ideal world, we would probably invert this so the operators were: |
| # native_batch_norm.mutable (mutable variant) |
| # native_batch_norm (functional variant) |
| # |
| # Doing that is BC-breaking though, so we're stuck with the above modeling. |
| functional_overload: bool = False |
| |
| @staticmethod |
| def parse(op: str) -> "BaseOperatorName": |
| assert op != "" |
| assert not op.endswith("_out"), ( |
| "_out suffix is reserved and not permitted for operator names; " |
| "did you mean to specify an out overload name instead?" |
| ) |
| m = re.match(r"^__([^_]+)__$", op) |
| if m is not None: |
| dunder_method = True |
| base = m.group(1) |
| if any(base == f"i{n}" for n in AUGMENTED_ASSIGNMENT_NAMES): |
| inplace = True |
| base = base[1:] |
| else: |
| inplace = False |
| # Temporary: this is not intrinsically true, but it has historically |
| # been true for the dunder methods we support (if we ever got, say, |
| # __int__, this would be wrong!) |
| assert base[0] != "i" |
| else: |
| dunder_method = False |
| base = op |
| if base[-1] == "_": |
| inplace = True |
| base = base[:-1] |
| else: |
| inplace = False |
| |
| # See Note [Overload Ambiguity With Functional Variants] |
| functional_suffix = "_functional" |
| if base.endswith(functional_suffix): |
| functional_overload = True |
| base = base[: -len(functional_suffix)] |
| # This seems complicated and unnecessary, so we ban dunder methods for now |
| # on ops that have a functional + mutable variant (like native_batch_norm). |
| assert not dunder_method and not inplace |
| else: |
| functional_overload = False |
| |
| r = BaseOperatorName( |
| base=base, |
| inplace=inplace, |
| dunder_method=dunder_method, |
| functional_overload=functional_overload, |
| ) |
| assert str(r) == op, f"{str(r)} != {op}" |
| return r |
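| |
| # A few illustrative parses: |
| # BaseOperatorName.parse("add") -> base="add", inplace=False, dunder_method=False |
| # BaseOperatorName.parse("add_") -> base="add", inplace=True |
| # BaseOperatorName.parse("__iadd__") -> base="add", inplace=True, dunder_method=True |
| # BaseOperatorName.parse("foo_functional") -> base="foo", functional_overload=True |
| # (foo is a made-up name used only for illustration) |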
| |
| def __str__(self) -> str: |
| if self.dunder_method: |
| i = "i" if self.inplace else "" |
| return f"__{i}{self.base}__" |
| else: |
| i = ( |
| "_" |
| if self.inplace |
| else "_functional" |
| if self.functional_overload |
| else "" |
| ) |
| return f"{self.base}{i}" |
| |
| |
| # Operator name is the base operator name along with the (typically not |
| # user visible) overload string. |
| @dataclass(frozen=True) |
| class OperatorName: |
| name: BaseOperatorName |
| overload_name: str |
| |
| @staticmethod |
| def parse(op_name: str) -> "OperatorName": |
| if "." in op_name: |
| name, overload_name = op_name.split(".", 1) |
| else: |
| name = op_name |
| overload_name = "" |
| r = OperatorName(name=BaseOperatorName.parse(name), overload_name=overload_name) |
| assert str(r) == op_name, f"{str(r)} != {op_name}" |
| return r |
| |
| def __str__(self) -> str: |
| if self.overload_name: |
| return f"{self.name}.{self.overload_name}" |
| else: |
| return f"{self.name}" |
| |
| # NB: This must be synchronized with the naming scheme in |
| # aten/src/ATen/templates/Operators.h |
| # Given a function schema "aten::op.overload(...)", |
| # If there is no overload name, this returns f"{op}" |
| # If there is an overload name, this returns f"{op}_{overload}" |
| def unambiguous_name(self) -> str: |
| if self.overload_name: |
| return f"{self.name}_{self.overload_name}" |
| else: |
| return f"{self.name}" |
| |
| def remove_inplace(self) -> "OperatorName": |
| return OperatorName( |
| name=BaseOperatorName( |
| base=self.name.base, |
| inplace=False, |
| dunder_method=self.name.dunder_method, |
| ), |
| overload_name=self.overload_name, |
| ) |
| |
| def with_overload(self, overload: str) -> "OperatorName": |
| return OperatorName( |
| name=BaseOperatorName( |
| base=self.name.base, |
| inplace=False, |
| dunder_method=self.name.dunder_method, |
| ), |
| overload_name=overload, |
| ) |
| |
| |
| def gets_generated_out_inplace_wrapper( |
| f: NativeFunction, g: NativeFunctionsGroup, b: BackendIndex |
| ) -> bool: |
| return ( |
| f.func.kind() is not SchemaKind.functional |
| and not b.has_kernel(f) |
| and b.has_kernel(g.functional) |
| ) |
| |
| |
| # NativeFunction objects that are views (f.is_view_op returns True) |
| # are added into a `NativeFunctionsViewGroup`, which we can use to |
| # easily access the generated (optional) view_copy NativeFunction. |
| # It's convenient to group them together, so we pair them up in NativeFunctionsViewGroup. |
| # See Note [Codegen'd {view}_copy Operators] |
| # |
| # One property of this representation is that in order for a view-like op to be part of |
| # a NativeFunctionsViewGroup, the "aliasing" version of that view op must exist. |
| # There's one case where that doesn't happen: we have a non-aliasing `narrow_copy.out` op, |
| # but don't have a corresponding aliasing `narrow.out` op. |
| # This means that `narrow_copy.out` won't be part of any NativeFunctionsViewGroup. |
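| # |
| # For example (illustrative, assuming the usual transpose entries in |
| # native_functions.yaml): one group pairs up view=transpose.int, |
| # view_inplace=transpose_ and view_copy=transpose_copy.int. |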
| @dataclass(frozen=True) |
| class NativeFunctionsViewGroup: |
| view: NativeFunction |
| # Note: the {view}_copy operator is optional because we currently don't generate copy variants |
| # for all view ops. Notably, we don't generate them for CompositeImplicitAutograd views |
| # (we already get them "for free" through decomposition) |
| view_copy: Optional[NativeFunction] |
| # view_inplace ops are also optional, but every view_inplace op should have an out-of-place variant. |
| view_inplace: Optional[NativeFunction] |
| |
| def __post_init__(self) -> None: |
| assert self.view.is_view_op |
| if self.view_copy is None: |
| assert not gets_generated_view_copy(self.view), ( |
| f"{str(self.view.func.name)} appears to be a new operator that aliases its inputs." |
| " The codegen expects you to add a corresponding operator to native_functions.yaml:" |
| f" {get_view_copy_name(self.view)!s}." |
| " See Note [view_copy NativeFunctions] for details." |
| ) |
| else: |
| assert self.view_copy.func.name.name.base.endswith("_copy") |
| assert self.view.func.signature() == self.view_copy.func.signature( |
| strip_view_copy_name=True |
| ) |
| assert "view_copy" in self.view_copy.tags, ( |
| f"{str(self.view_copy.func.name), str(self.view.tags)} appears to be a view_copy operator. The codegen expects" |
| " view_copy operators to be annotated with the 'view_copy' tag in native_functions.yaml." |
| " See Note [view_copy NativeFunction] for details." |
| ) |
| if self.view_inplace is not None: |
| assert self.view.func.signature() == self.view_inplace.func.signature() |
| |
| if self.view.has_composite_implicit_autograd_kernel: |
| if self.view_inplace is not None: |
| assert self.view_inplace.has_composite_implicit_autograd_kernel, ( |
| f"{str(self.view.func.name)} and {str(self.view_inplace.func.name)} must either" |
| " both have CompositeImplicitAutograd kernels, or both not have composite kernels." |
| ) |
| |
| def functions(self, *, include_copy: bool = True) -> Iterator[NativeFunction]: |
| yield self.view |
| if self.view_inplace is not None: |
| yield self.view_inplace |
| if self.view_copy is not None and include_copy: |
| yield self.view_copy |
| |
| @property |
| def root_name(self) -> str: |
| return self.view.root_name |
| |
| @property |
| def composite(self) -> bool: |
| # We currently assert that the "group" is consistent. |
| # If the view op is composite, then its view_inplace op is too. |
| return self.view.has_composite_implicit_autograd_kernel |
| |
| |
| def gets_generated_view_copy(f: NativeFunction) -> bool: |
| # Only aliasing (view) operators get a copy variant. |
| if not f.is_view_op: |
| return False |
| # We don't need to bother generating copy variants for CompositeImplicitAutograd ops, |
| # because we can let them decompose into base view ops. |
| if f.has_composite_implicit_autograd_kernel: |
| return False |
| # We also don't need to generate copy variants for inplace views. |
| if "inplace_view" in f.tags: |
| return False |
| return True |
| |
| |
| # Given a NativeFunction that corresponds to a view op, |
| # returns the OperatorName of the corresponding "copy" variant of the op. |
| def get_view_copy_name(f: NativeFunction) -> "OperatorName": |
| # Right now, when asking for a view op's corresponding "view_copy" name |
| # we assert for sanity that the op is allowed to have a generated view_copy variant. |
| # (We can do this because "gets_generated_view_copy()" tells us which ops get a generated view_copy op.) |
| # However, narrow_copy() already exists as an op directly in native_functions.yaml, |
| # so narrow is hardcoded here for now to maintain the assert. |
| # We could also just get rid of the assert. |
| list_of_ops_with_explicit_view_copy_operators = ["narrow"] |
| if str(f.func.name) not in list_of_ops_with_explicit_view_copy_operators: |
| assert gets_generated_view_copy(f) |
| |
| base_name = f"{f.func.name.name.base}_copy" |
| view_copy_name = OperatorName( |
| name=BaseOperatorName( |
| base=base_name, inplace=False, dunder_method=f.func.name.name.dunder_method |
| ), |
| overload_name=f.func.name.overload_name, |
| ) |
| return view_copy_name |
| |
| |
| # Helper functions for parsing argument lists (both inputs and returns) |
| |
| |
| def parse_returns(return_decl: str) -> Tuple[Return, ...]: |
| """ |
| Input: '()' |
| Output: () |
| """ |
| if return_decl == "()": |
| return () |
| if return_decl[0] == "(" and return_decl[-1] == ")": |
| return_decl = return_decl[1:-1] |
| return tuple(Return.parse(arg) for arg in return_decl.split(", ")) |
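| |
| |
| # For example (illustrative): |
| # parse_returns("()") -> () |
| # parse_returns("Tensor") -> a single unnamed Return |
| # parse_returns("(Tensor values, Tensor indices)") -> two named Returns |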
| |
| |
| # A Precompute instance consists of a map from kernel argument name |
| # to the list of Argument instances that should replace that |
| # kernel argument in the impl function. |
| @dataclass(frozen=True) |
| class Precompute: |
| # A map from kernel argument name -> a list of precomputed |
| # elements that replaces/supersedes it. |
| replace: Dict[str, List[Argument]] |
| # List of precomputed args added without replacement |
| add: List[Argument] |
| |
| @staticmethod |
| def parse(src: object) -> "Precompute": |
| assert isinstance(src, list) |
| |
| # src is a list of strings of the format: |
| # {kernel param name} -> {replacement decl}[, {replacement decl}, ...] |
| # [{add decl}[, {add decl}, ...]] |
| # The last line is optional and contains the precomputed parameters that are |
| # added without replacement. |
| # The other lines are parsed to get the names of which precomputed elements |
| # should replace which kernel arguments. |
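| # |
| # For example (illustrative), the entry |
| # ["kernel_size -> int kH, int kW", "int numel"] |
| # parses to replace={"kernel_size": [int kH, int kW]} and add=[int numel]. |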
| add_args = [] |
| if " -> " not in src[-1]: |
| add_list = src[-1].split(",") |
| add_args = [Argument.parse(name.strip()) for name in add_list] |
| src = src[:-1] |
| |
| replace = {} |
| for raw_replace_item in src: |
| assert isinstance(raw_replace_item, str) |
| assert " -> " in raw_replace_item, ( |
| "precomputed parameters without replacement" |
| " are allowed only in the last line" |
| ) |
| |
| arg, with_list_raw = raw_replace_item.split(" -> ") |
| with_list = with_list_raw.split(",") |
| with_list_args = [Argument.parse(name.strip()) for name in with_list] |
| replace[arg] = with_list_args |
| |
| r = Precompute(replace=replace, add=add_args) |
| assert r.to_list() == src, "r.to_list() != src" |
| return r |
| |
| def __post_init__(self) -> None: |
| # The precompute template parameters are upper-case, so if an added or |
| # replacement argument name were also all upper-case it would be ambiguous. |
| for a in self.add: |
| assert a.name.upper() != a.name |
| for args in self.replace.values(): |
| for a in args: |
| assert a.name.upper() != a.name |
| |
| def to_list(self) -> List[str]: |
| replace_list = [] |
| for kernel_param, replacement_params in self.replace.items(): |
| replacements = ", ".join(str(param) for param in replacement_params) |
| replace_list.append(f"{kernel_param} -> {replacements}") |
| |
| return replace_list |