import contextlib
import dataclasses
import functools
import inspect
import typing
import weakref
from torchgen.model import FunctionSchema, OperatorName, SchemaKind
import torch
import torch._C as _C
import torch.library as library
import torch.utils._pytree as pytree
"""
There are various APIs for defining custom-operator-like things in PyTorch:
- [user-facing] autograd.Function (Python)
- [user-facing] custom_op (Python)
- [for power users] torch.library (Python)
- [for power users] TORCH_LIBRARY (C++)
This file contains the implementation for a Simple Custom Operator API (CustomOp).
Using CustomOp, you are able to define a custom operator and implement interactions
between the CustomOp and various PyTorch subsystems, including all the subsystems
that are necessary for a custom operator to work with torch.compile (i.e.,
autograd, FakeTensor, functionalization).
CustomOp is positioned as being safer and easier to use than
torch.library/TORCH_LIBRARY, which require deep understanding of PyTorch internals.
In addition, it supports torch.compile better than, and is in general more
comprehensive than, autograd.Function, which only supports implementing
gradient computation and vmap rules.
"""
__all__ = ["custom_op", "CustomOp", "get_ctx", "AbstractImplCtx"]
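# Maps a user-facing device type string to the corresponding dispatch key name
# used when registering a kernel for that device (see CustomOp.impl).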
SUPPORTED_DEVICE_TYPE_TO_KEY = {
"cpu": "CPU",
"cuda": "CUDA",
}
# We will not let users register CustomOps with anything that could look like
# PyTorch internals to avoid confusion.
RESERVED_NS = {
"prim",
"prims",
"aten",
"at",
"torch",
"pytorch",
}
def custom_op(
qualname: str, manual_schema: typing.Optional[str] = None
) -> typing.Callable:
r"""Creates a new CustomOp object.
In PyTorch, defining an op (short for "operator") is a two-step process:
- we need to define (create) the op
- we need to implement behavior for how the operator interacts with
various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
This entrypoint defines the CustomOp object (the first step);
you must then perform the second step by calling various methods on
the CustomOp object.
This API is used as a decorator (see examples).
Arguments:
qualname (str): Should be a string that looks like
"namespace::operator_name". Operators in PyTorch need a namespace to
avoid name collisions; a given operator may only be created once.
If you are writing a Python library, we recommend that the namespace be
the name of your top-level module. The operator_name must be
the same as the name of the function you pass to custom_op
(see examples).
manual_schema (Optional[str]): Each PyTorch operator needs a schema that
tells PyTorch the types of the inputs/outputs. If None (default),
we will infer the schema from the type annotations on the function
(see examples). Otherwise, if you don't want to use type annotations,
you may provide us the schema string.
Example::
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> # Step 1: define the CustomOp.
>>> # We need to provide the decorator a "prototype function"
>>> # (a function with Python ellipses as the body).
>>> @custom_op("mylibrary::numpy_sin")
>>> def numpy_sin(x: Tensor) -> Tensor:
>>> ...
>>>
>>> # numpy_sin is now an instance of class CustomOp
>>> print(type(numpy_sin))
>>>
>>> # Step 2: Register an implementation for various PyTorch subsystems
>>>
>>> # Register an implementation for CPU tensors
>>> @numpy_sin.impl('cpu')
>>> def numpy_sin_impl_cpu(x):
>>> return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Register an implementation for CUDA tensors
>>> @numpy_sin.impl('cuda')
>>> def numpy_sin_impl_cuda(x):
>>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
>>>
>>> x = torch.randn(3)
>>> numpy_sin(x) # calls numpy_sin_impl_cpu
>>>
>>> x_cuda = x.cuda()
>>> numpy_sin(x_cuda) # calls numpy_sin_impl_cuda
"""
def inner(func):
if not inspect.isfunction(func):
raise ValueError(
f"custom_op(...)(func): Expected `func` to be a Python "
f"function, got: {type(func)}"
)
ns, name = parse_namespace(qualname)
validate_namespace(ns)
if func.__name__ != name:
raise ValueError(
f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
f"to have name '{name}' but got '{func.__name__}'. "
f"Please either change the name of `func` or the qualname that "
f"is passed to `custom_op`"
)
schema = infer_schema(func) if manual_schema is None else manual_schema
schema_str = f"{name}{schema}"
function_schema = FunctionSchema.parse(schema_str)
validate_schema(function_schema)
if manual_schema is not None:
validate_function_matches_schema(function_schema, func)
lib = library.Library(ns, "FRAGMENT")
lib.define(schema_str)
ophandle = find_ophandle_or_throw(ns, function_schema.name)
result = CustomOp(lib, ns, function_schema.name, ophandle, _private_access=True)
result.__name__ = func.__name__
result.__module__ = func.__module__
result.__doc__ = func.__doc__
# NYI: autograd not supported
# In the near future we will either directly use the
# autograd_not_implemented kernels or make those the default fallback
# for the Autograd and ADInplaceOrView keys. Both of those are a bit tricky.
library.impl(lib, result._opname, "Autograd")(
get_autograd_not_implemented_kernel(weakref.proxy(result))
)
return result
return inner
# Global dictionary holding references to all CustomOp objects
# Yes, it keeps all CustomOps alive (see NOTE [CustomOp lifetime])
# Used to query the CustomOp associated with a specific C++ dispatcher operator.
# An example usage is FakeTensor: FakeTensor checks if a specific operator
# has an implementation registered via the CustomOp API.
# Indexed by qualname (e.g. "mylibrary::foo")
global_registry: typing.Dict[str, "CustomOp"] = {}
class CustomOp:
r"""Class for custom operators in PyTorch.
Use the CustomOp API to create user-defined custom operators that behave
just like regular PyTorch operators (e.g. torch.sin, torch.mm) when it
comes to various PyTorch subsystems (like torch.compile).
To construct a `CustomOp`, use `custom_op`.
"""
def __init__(self, lib, cpp_ns, operator_name, ophandle, *, _private_access=False):
super(CustomOp, self).__init__()
if not _private_access:
raise RuntimeError(
"The CustomOp constructor is private and we do not guarantee "
"BC for it. Please use custom_op(...) to create a CustomOp object"
)
name = f"{cpp_ns}::{str(operator_name.name)}"
self._cpp_ns = cpp_ns
self._lib: library.Library = lib
self._ophandle: _C._DispatchOperatorHandle = ophandle
# Has the name of the op, e.g. "foo". We cache here for convenience.
self._opname: str = str(operator_name)
# this is _opname but with namespace. e.g. "custom::foo"
self._qualname: str = name
self.__name__ = None # mypy requires this
self._abstract_impl: typing.Optional[FuncAndLocation] = None
global_registry[self._qualname] = self
def _destroy(self):
# NOTE: [CustomOp lifetime]
# A CustomOp, once created, lives forever. The mechanism is that the
# global registry holds a reference to it. However, to make testing
# easier, we want to be able to destroy CustomOp objects.
# CustomOp._destroy does the job, though it leaves the CustomOp
# in a garbage state.
del self._lib
opnamespace = getattr(torch.ops, self._cpp_ns)
if hasattr(opnamespace, self._opname):
delattr(opnamespace, self._opname)
del global_registry[self._qualname]
def __repr__(self):
return f'<CustomOp(op="{self._qualname}")>'
def __call__(self, *args, **kwargs):
# Bypass torch.ops.* and directly do OperatorHandle::callBoxed.
# Using torch.ops.* is a bit of a pain (it can be slow and it has lifetime
# issues from caching operators that make testing CustomOp difficult).
result = _C._dispatch_call_boxed(self._ophandle, *args, **kwargs)
return result
def impl(
self, device_types: typing.Union[str, typing.Iterable[str]]
) -> typing.Callable:
r"""Register an implementation for a device type for this CustomOp object.
If the CustomOp is passed multiple Tensor inputs with different device
types, it will dispatch to the registered implementation for the highest
priority device type among those present.
The supported device types, in order of priority, are ['cuda', 'cpu'].
This API is used as a decorator (see examples).
Arguments:
device_types (str or Iterable[str]): the device type(s) to register the function for.
Examples::
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> @custom_op("mylibrary::numpy_sin")
>>> def numpy_sin(x: Tensor) -> Tensor:
>>> ...
>>>
>>> # Register an implementation for CPU Tensors
>>> @numpy_sin.impl('cpu')
>>> def numpy_sin_impl_cpu(x):
>>> return torch.from_numpy(np.sin(x.numpy()))
>>>
>>> # Register an implementation for CUDA Tensors
>>> @numpy_sin.impl('cuda')
>>> def numpy_sin_impl_cuda(x):
>>> return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
>>>
>>> x = torch.randn(3)
>>> numpy_sin(x) # calls numpy_sin_impl_cpu
>>>
>>> x_cuda = x.cuda()
>>> numpy_sin(x_cuda) # calls numpy_sin_impl_cuda
"""
if isinstance(device_types, str):
device_types = [device_types]
for device_type in device_types:
validate_device_type(device_type)
def inner(f):
for device_type in set(device_types):
dispatch_key = SUPPORTED_DEVICE_TYPE_TO_KEY[device_type]
library.impl(self._lib, self._opname, dispatch_key)(f)
return f
return inner
def impl_factory(self) -> typing.Callable:
r"""Register an implementation for a factory function."""
def inner(f):
library.impl(self._lib, self._opname, "BackendSelect")(f)
return f
return inner
def impl_abstract(self) -> typing.Callable:
r"""Register an abstract implementation for this operator.
An "abstract implementation" specifies the behavior of this operator on
Tensors that carry no data. Given some input Tensors with certain properties
(sizes/strides/storage_offset/device), it specifies what the properties of
the output Tensors are.
The abstract implementation has the same signature as the operator.
It is run for both FakeTensors and meta tensors. To write an abstract
implementation, assume that all Tensor inputs to the operator are
regular CPU/CUDA/Meta tensors, but they do not have storage, and
you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
The abstract implementation must consist of only PyTorch operations
(and may not directly access the storage or data of any input or
intermediate Tensors).
This API is used as a decorator (see examples).
Examples::
>>> import numpy as np
>>> from torch import Tensor
>>>
>>> # Example 1: an operator without data-dependent output shape
>>> @custom_op('mylibrary::custom_linear')
>>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
>>> ...
>>>
>>> @custom_linear.impl_abstract()
>>> def custom_linear_abstract(x, weight, bias):
>>> assert x.dim() == 2
>>> assert weight.dim() == 2
>>> assert bias.dim() == 1
>>> assert x.shape[1] == weight.shape[1]
>>> assert weight.shape[0] == bias.shape[0]
>>> assert x.device == weight.device
>>>
>>> return (x @ weight.t()) + bias
>>>
>>> # Example 2: an operator with data-dependent output shape
>>> @custom_op('mylibrary::custom_nonzero')
>>> def custom_nonzero(x: Tensor) -> Tensor:
>>> ...
>>>
>>> @custom_nonzero.impl_abstract()
>>> def custom_nonzero_abstract(x):
>>> # Number of nonzero-elements is data-dependent.
>>> # Since we cannot peek at the data in an abstract impl,
>>> # we use the ctx object to construct a new symint that
>>> # represents the data-dependent size.
>>> ctx = torch._custom_op.get_ctx()
>>> nnz = ctx.create_unbacked_symint()
>>> shape = [x.dim(), nnz]
>>> result = x.new_empty(shape, dtype=torch.long)
>>> return result
>>>
>>> @custom_nonzero.impl(['cpu', 'cuda'])
>>> def custom_nonzero_impl(x):
>>> x_np = x.cpu().numpy()
>>> res = np.stack(np.nonzero(x_np), axis=1)
>>> # unbacked symbolic ints in PyTorch must be >= 2, so we
>>> # reject results with fewer than 2 nonzero elements.
>>> if res.shape[0] <= 1:
>>> raise RuntimeError("not supported")
>>> return torch.tensor(res, device=x.device)
"""
def inner(f):
frame = inspect.stack()[1]
if self._abstract_impl is not None:
raise RuntimeError(
f"Attempting to register an abstract impl for operator {self._qualname} "
f"that already has an abstract impl registered from Python at "
f"{self._abstract_impl.location}. This is not supported."
)
new_location = f"{frame.filename}:{frame.lineno}"
# FakeTensor will look at _abstract_impl
self._abstract_impl = FuncAndLocation(f, new_location)
qualname = self._qualname
# Handle DispatchKey.Meta registration
@functools.wraps(f)
def f_with_ctx(*args, **kwargs):
def error_on_ctx():
raise RuntimeError(
f"Attempted to call get_ctx() for the meta implementation "
f"for {qualname}."
f"You have presumably called get_ctx() because the operator "
f"has a data-dependent output shape; if so, there is no "
f"such meta implementation and this error is the correct "
f"behavior. Otherwise, please remove the call to get_ctx() "
f"in the implementation registered with impl_abstract "
f"at {new_location}"
)
with set_ctx_getter(error_on_ctx):
return f(*args, **kwargs)
self._lib.impl(self._opname, f_with_ctx, "Meta")
return f
return inner
@dataclasses.dataclass
class FuncAndLocation:
func: typing.Callable
location: str
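# Looks up the C++ dispatcher's OperatorHandle for "cpp_ns::name" (throwing if
# it does not exist) so that CustomOp.__call__ can invoke the operator boxed,
# bypassing torch.ops.*.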
def find_ophandle_or_throw(cpp_ns: str, operator_name: OperatorName):
overload_name = (
"" if operator_name.overload_name is None else operator_name.overload_name
)
return _C._dispatch_find_schema_or_throw(
f"{cpp_ns}::{str(operator_name.name)}", overload_name
)
def validate_namespace(ns: str) -> None:
if "." in ns:
raise ValueError(
f'custom_op(..., ns="{ns}"): expected ns to not contain any . (and be a '
f"valid variable name)"
)
if ns in RESERVED_NS:
raise ValueError(
f"custom_op(..., ns='{ns}'): '{ns}' is a reserved namespace, "
f"please choose something else. "
)
def validate_schema(schema: FunctionSchema) -> None:
# Coming in the future. Requires us to have correct logic for
# the ADInplaceOrView key
if schema.kind() != SchemaKind.functional:
raise ValueError(
f"custom_op does not support non-functional function schema. Got: {schema}"
)
rets = schema.returns
is_non_mutating_view = len(rets) > 0 and any(
r.annotation is not None and not r.annotation.is_write for r in rets
)
if is_non_mutating_view:
raise ValueError(f"custom_op does not support view functions. Got: {schema}")
# Just seems weird so banning for now
if not schema.returns:
raise ValueError(
f"custom_op does not support function schema with no outputs. Got: {schema}"
)
# For simplicity: don't allow self arguments
if schema.arguments.self_arg is not None:
raise ValueError(
f"custom_op does not support arguments named 'self'. Please "
f"rename your argument. Got: {schema}"
)
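# Splits a qualname into (namespace, name), e.g. "mylibrary::foo" ->
# ("mylibrary", "foo").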
def parse_namespace(namespaced_entity: str) -> typing.Tuple[str, str]:
names = namespaced_entity.split("::", 1)
if len(names) != 2:
raise ValueError(f"Expected there to be a namespace in {namespaced_entity}.")
return names[0], names[1]
def validate_device_type(device_type: str) -> None:
if device_type not in SUPPORTED_DEVICE_TYPE_TO_KEY:
raise ValueError(
f"CustomOp.impl(device_types=[{device_type}, ...]): we only support device_type "
f"in {SUPPORTED_DEVICE_TYPE_TO_KEY.keys()}."
)
def get_autograd_not_implemented_kernel(custom_op) -> typing.Callable:
def autograd_not_implemented(*args, **kwargs):
if pytree.tree_any(
lambda x: isinstance(x, torch.Tensor) and x.requires_grad, (args, kwargs)
):
raise RuntimeError("Autograd has not been implemented for operator")
guard = _C._AutoDispatchBelowAutograd()
try:
return custom_op(*args, **kwargs)
finally:
del guard
return autograd_not_implemented
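# The schema language has no equivalent for positional-only parameters,
# *args, or **kwargs, so we only accept positional-or-keyword and
# keyword-only parameters.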
def supported_param(param: inspect.Parameter) -> bool:
return param.kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
)
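# Checks that `func`'s signature agrees with `manual_schema`: same argument
# names and positional/keyword-only kinds, and no default values on either
# side. For example, a schema "(Tensor x, *, int n) -> Tensor" requires
# `def func(x, *, n)` (with no type annotations).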
def validate_function_matches_schema(
schema: FunctionSchema, func: typing.Callable
) -> None:
sig = inspect.signature(func)
if not all(supported_param(p) for _, p in sig.parameters.items()):
raise ValueError(
f"custom_op(..., manual_schema)(func): positional-only args, "
f"varargs, and kwargs are not supported. Please rewrite `func` "
f"to not have them. Got `func` with signature: {sig}"
)
if (
any(
p.annotation is not inspect.Parameter.empty
for _, p in sig.parameters.items()
)
or sig.return_annotation is not inspect.Signature.empty
):
raise ValueError(
f"custom_op(..., manual_schema)(func): When passing in a manual "
f"schema, we expect `func` to have no type annotations to avoid "
f"ambiguity. Got `func` with signature: {sig}"
)
positional = [
(name, param)
for name, param in sig.parameters.items()
if param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
]
kwargonly = [
(name, param)
for name, param in sig.parameters.items()
if param.kind == inspect.Parameter.KEYWORD_ONLY
]
def error():
raise ValueError(
f"custom_op(..., manual_schema)(func): When passing in a manual "
f"schema, we expect `func`'s signature to match `manual_schema` "
f"(aside from type annotations). "
f"func's signature: {sig}, manual_schema: {schema}"
)
def error_default_args():
raise ValueError(
f"custom_op(..., manual_schema)(func): "
f"neither func nor manual_schema should have default "
f"arguments. Got "
f"func's signature: {sig}, manual_schema: {schema}"
)
def compare(sig_args, schema_args):
if len(sig_args) != len(schema_args):
error()
for (name, param), arg in zip(sig_args, schema_args):
if name != arg.name:
error()
if param.default is not inspect.Parameter.empty or arg.default is not None:
error_default_args()
compare(positional, schema.arguments.flat_positional)
compare(kwargonly, schema.arguments.flat_kwarg_only)
def get_none():
return None
global_ctx_getter: typing.Callable = get_none
# NOTE [ctx inside the fake implementation]
# If a user has an operator with data-dependent output shape, then when writing
# a fake implementation they must query the current ctx and use methods on the
# ctx to construct a new unbacked symint.
#
# This is done via us setting the global_ctx_getter function every time a fake
# implementation is invoked.
def get_ctx() -> "AbstractImplCtx":
"""get_ctx() returns the current AbstractImplCtx object.
Calling ``get_ctx()`` is only valid inside of an abstract implementation.
"""
return global_ctx_getter()
@contextlib.contextmanager
def set_ctx_getter(ctx_getter):
global global_ctx_getter
prev = global_ctx_getter
try:
global_ctx_getter = ctx_getter
yield
finally:
global_ctx_getter = prev
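# Illustrative sketch (not the actual FakeTensor code): a caller that wants
# get_ctx() to work inside an abstract impl would install a getter around the
# call, roughly:
#
#   ctx = AbstractImplCtx(shape_env, op)
#   with set_ctx_getter(lambda: ctx):
#       result = custom_op._abstract_impl.func(*args, **kwargs)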
class AbstractImplCtx:
"""
Context object for writing abstract implementations for custom operators.
"""
def __init__(self, _shape_env, _op):
self._shape_env = _shape_env
self._op = _op
def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt:
"""Constructs a new symint (symbolic int) representing a data-dependent value.
This is useful for writing the abstract implementation (which is necessary
for torch.compile) for a CustomOp where an output Tensor has a size
that depends on the data of the input Tensors.
Args:
min (int): A statically known inclusive lower bound for this symint.
min must be at least 2 due to implementation details of
torch.compile. Default: 2.
max (Optional[int]): A statically known inclusive upper bound for this
symint. Default: None
.. warning::
It is important that the ``min`` and ``max`` (if not None) values are set
correctly; otherwise, there will be undefined behavior under
torch.compile. The default value of ``min`` is 2 due to torch.compile
specializing on 0/1 sizes.
You must also verify that your implementation on concrete Tensors
(e.g. CPU/CUDA) only returns Tensors where the size corresponding
to the symint respects these constraints.
The easiest way to do this is to add an assertion in the CPU/CUDA/etc
implementation that the size follows these bounds.
Example::
>>> # an operator with data-dependent output shape
>>> @custom_op("mylibrary::custom_nonzero")
>>> def custom_nonzero(x: Tensor) -> Tensor:
>>> ...
>>>
>>> @custom_nonzero.impl_abstract()
>>> def custom_nonzero_abstract(x):
>>> # Number of nonzero-elements is data-dependent
>>> ctx = torch._custom_op.get_ctx()
>>> nnz = ctx.create_unbacked_symint()
>>> shape = [x.dim(), nnz]
>>> result = x.new_empty(shape, dtype=torch.long)
>>> return result
>>>
>>> @custom_nonzero.impl(['cpu', 'cuda'])
>>> def custom_nonzero_impl(x):
>>> x_np = x.cpu().numpy()
>>> res = np.stack(np.nonzero(x_np), axis=1)
>>> # the size associated with ctx.create_unbacked_symint()
>>> # must be constrained in the same way, so we add an assertion here.
>>> if res.shape[0] < 2 or res.shape[0] > x.numel():
>>> raise RuntimeError("not supported")
>>> return torch.tensor(res, device=x.device)
"""
if (
self._shape_env is None
or not self._shape_env.allow_dynamic_output_shape_ops
):
raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op)
if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt):
raise ValueError(
f"ctx.create_unbacked_symint(min={min}, max={max}): expected "
f"min and max to be statically known ints but got SymInt. "
f"This is not supported."
)
if min < 2:
raise ValueError(
f"ctx.create_unbacked_symint(min={min}, ...): expected min to be "
f"greater than or equal to 2. PyTorch only supports new "
f"data-dependent sizes of >= 2"
)
result = self._shape_env.create_unbacked_symint()
torch.fx.experimental.symbolic_shapes.constrain_range(result, min=min, max=max)
return result
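# Infers a schema string from the prototype function's type annotations.
# For example, `def foo(x: Tensor, n: int) -> Tensor` produces
# "(Tensor x, SymInt n) -> Tensor" (int maps to SymInt; see
# get_supported_param_types below).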
def infer_schema(prototype_function: typing.Callable) -> str:
sig = inspect.signature(prototype_function)
def error_fn(what):
raise ValueError(
f"custom_op(...)(func): {what} " f"Got func with signature {sig})"
)
params = [
parse_param(name, param, error_fn) for name, param in sig.parameters.items()
]
ret = parse_return(sig.return_annotation, error_fn)
return f"({', '.join(params)}) -> {ret}"
def parse_param(name, param, error_fn):
if not supported_param(param):
error_fn("We do not support positional-only args, varargs, or varkwargs.")
if param.annotation is inspect.Parameter.empty:
error_fn(f"Parameter {name} must have a type annotation.")
if param.annotation not in SUPPORTED_PARAM_TYPES.keys():
error_fn(
f"Parameter {name} has unsupported type {param.annotation}. "
f"The valid types are: {SUPPORTED_PARAM_TYPES.keys()}."
)
if param.default is not inspect.Parameter.empty:
error_fn(
f"Parameter {name} has a default value; this is not supported. "
f"If you want to use default values then create a function with "
f"default values that calls the CustomOp"
)
return f"{SUPPORTED_PARAM_TYPES[param.annotation]} {name}"
def derived_types(
base_type, cpp_type, list_base, optional_base_list, optional_list_base
):
result = [
(base_type, cpp_type),
(typing.Optional[base_type], f"{cpp_type}?"),
]
if list_base:
result.append((typing.Tuple[base_type, ...], f"{cpp_type}[]"))
if optional_base_list:
result.append((typing.Tuple[typing.Optional[base_type], ...], f"{cpp_type}?[]"))
if optional_list_base:
result.append((typing.Optional[typing.Tuple[base_type, ...]], f"{cpp_type}[]?"))
return result
def get_supported_param_types():
data = [
# (python type, schema type, type[] variant, type?[] variant, type[]? variant)
(torch.Tensor, "Tensor", True, True, False),
(int, "SymInt", True, False, True),
(float, "float", True, False, True),
(bool, "bool", True, False, True),
(str, "str", False, False, False),
(torch.types.Number, "Scalar", True, False, False),
(torch.dtype, "ScalarType", False, False, False),
(torch.device, "Device", False, False, False),
]
result = []
for line in data:
result.extend(derived_types(*line))
return dict(result)
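# Maps the return annotation to a schema return string: Tensor -> "Tensor",
# Tuple[Tensor, Tensor] -> "(Tensor, Tensor)"; anything else is rejected.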
def parse_return(annotation, error_fn):
if annotation is torch.Tensor:
return "Tensor"
origin = typing.get_origin(annotation)
if origin is not tuple:
error_fn(
"Expected output of func to be type annotated as either Tensor "
"or a Tuple of known size of one or more tensors"
)
args = typing.get_args(annotation)
for arg in args:
if arg is not torch.Tensor:
error_fn(
"Expected output of func to be type annotated as either Tensor "
"or a Tuple of known size of one or more tensors"
)
return "(" + ", ".join(["Tensor"] * len(args)) + ")"
SUPPORTED_PARAM_TYPES = get_supported_param_types()