# Module for defining "primitive" operations executable by nvFuser. This list
# exists to decouple the main set of primitives from the ones that provide a
# lowering of the op to nvFuser's Python interface. For the most part,
# torch.ops.nvprims is a subset of the primitives in torch.ops.prims, but
# additional primitives can be added in the future for the corresponding
# higher-level torch/aten functions.
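
# A minimal usage sketch (illustrative only, not part of this module's API):
# once register_nvprims() below has been called, the registered ops are
# available through the torch.ops.nvprims namespace like any other
# torch.library op, and each overload carries an `impl_nvfuser` attribute with
# its lowering to nvFuser's FusionDefinition, e.g.
#
#   t = torch.randn(4, device="cuda")
#   out = torch.ops.nvprims.add(t, t)
#   lowering = torch.ops.nvprims.add.default.impl_nvfuser
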
from typing import Any, Dict, Optional
import torch
from torch._prims_common import (
    DimsSequenceType,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    getnvFuserDtype,
    make_contiguous_strides_for,
    ShapeType,
    TensorLikeType,
)
from torch._prims_common.wrappers import (
    backwards_not_supported,
    elementwise_type_promotion_wrapper,
)
nvprim_namespace = "nvprims"
nvprim = torch.library.Library(nvprim_namespace, "DEF")
nvprim_impl = torch.library.Library(
    nvprim_namespace, "IMPL", "CompositeExplicitAutograd"
)
nvprim_implicit_impl = torch.library.Library(
    nvprim_namespace, "IMPL", "CompositeImplicitAutograd"
)
nvprim_autograd_impl = torch.library.Library(nvprim_namespace, "IMPL", "Autograd")
nvprim_meta_impl = torch.library.Library(nvprim_namespace, "IMPL", "Meta")
nvprim_names = [
"abs",
"acos",
"asin",
"atan",
"atanh",
"cos",
"cosh",
"bitwise_not",
"ceil",
"erf",
"erfc",
"exp",
"expm1",
"floor",
"imag",
"isfinite",
"lgamma",
"log",
"log1p",
"log2",
"log10",
"real",
"reciprocal",
"neg",
"round",
"rsqrt",
"sign",
"sin",
"sinh",
"sqrt",
"tan",
"tanh",
"transpose",
"trunc",
"add",
"atan2",
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"div",
"eq",
"fmod",
"ge",
"gt",
"le",
"lt",
"mul",
"ne",
"pow",
"remainder",
"sub",
"squeeze",
"view_of",
"broadcast_in_dim",
"where",
"convert_element_type",
"sum",
"var",
"amax",
"amin",
]
_nvfuser_impls: Dict[str, Any] = {}
_nvfuser_unary_ops = {
"abs",
"acos",
"asin",
"atan",
"atanh",
"cos",
"cosh",
"bitwise_not",
"ceil",
"erf",
"erfc",
"exp",
"expm1",
"floor",
"imag",
"isfinite",
"lgamma",
"log",
"log1p",
"log2",
"log10",
"reciprocal",
"neg",
"real",
"round",
"rsqrt",
"sign",
"sin",
"sinh",
"sqrt",
"tan",
"tanh",
"trunc",
}
def _assert_nvfuser_op_exists(fname: str):
    try:
        from torch._C._nvfuser import FusionDefinition as fd  # type: ignore[import]

        assert getattr(fd.Operators, fname)
    except ImportError:
        # Not all PyTorch builds have nvfuser
        pass
for fname in _nvfuser_unary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a):
    return fd.ops.{fname}(a)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )
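
# For illustration only: for fname == "sin", the loop above generates the
# equivalent of
#
#   def _sin_nvfuser(fd, a):
#       return fd.ops.sin(a)
#
#   _nvfuser_impls["sin"] = _sin_nvfuser
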
_nvfuser_binary_ops = {
"add",
"atan2",
"bitwise_and",
"bitwise_or",
"bitwise_xor",
"div",
"eq",
"fmod",
"ge",
"gt",
"le",
"lt",
"mul",
"ne",
"pow",
"remainder",
"sub",
}
for fname in _nvfuser_binary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a, b):
    return fd.ops.{fname}(a, b)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )
_nvfuser_ternary_ops = {
    "where",
}

for fname in _nvfuser_ternary_ops:
    exec(
        f"""
# Ensure that the nvfuser implementation exists
_assert_nvfuser_op_exists("{fname}")

def _{fname}_nvfuser(fd, a, b, c):
    return fd.ops.{fname}(a, b, c)  # type: ignore[attr-defined]

_nvfuser_impls["{fname}"] = _{fname}_nvfuser
"""
    )
def _native_batch_norm_nvfuser(
    fd, input, weight, bias, running_mean, running_var, training, momentum, eps
):
    if weight is None:
        weight = fd.define_null_tensor()
    if bias is None:
        bias = fd.define_null_tensor()
    if running_mean is None:
        running_mean = fd.define_null_tensor()
    if running_var is None:
        running_var = fd.define_null_tensor()
    return fd.ops.batch_norm(
        input,
        weight,
        bias,
        running_mean,
        running_var,
        training,
        momentum,
        eps,
    )
def _broadcast_in_dim_nvfuser(
    fd: Any,
    a: TensorLikeType,
    shape: ShapeType,
    broadcast_dimensions: ShapeType,
):
    return fd.ops.broadcast_in_dim(a, shape, broadcast_dimensions)  # type: ignore[attr-defined]
def _convert_element_type_nvfuser(fd: Any, a: TensorLikeType, dtype: torch.dtype):
    nvfuser_dtype = getnvFuserDtype(dtype)
    return fd.ops.cast(a, nvfuser_dtype)  # type: ignore[attr-defined]
def _transpose_nvfuser(fd, a, permutation):
    return fd.ops.permute(a, permutation)  # type: ignore[attr-defined]
def _squeeze_nvfuser(fd, a, a_shape, dimensions):
    # Squeeze one dimension at a time, starting from the largest index so that
    # the remaining indices stay valid as the shape shrinks.
    for idx in reversed(sorted(dimensions)):
        a = fd.ops.squeeze(a, a_shape, idx)
        a_shape = a_shape[:idx] + a_shape[idx + 1 :]
    return a
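
# Illustrative example (hypothetical values): squeezing dimensions (0, 2) of a
# tensor with shape [1, 4, 1, 5] removes index 2 first ([1, 4, 5]) and then
# index 0 ([4, 5]); iterating in increasing order would invalidate index 2
# after the first squeeze.
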
def _view_of_nvfuser(fd, a):
    return fd.ops.set(a)
def _sum_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    output_dtype = torch._C._nvfuser.DataType.Null
    return fd.ops.sum(a, dims, keep_dims, output_dtype)
def _var_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
    *,
    correction: int,
):
    keep_dims = False
    return fd.ops.var(a, dims, correction, keep_dims)
def _var_mean_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
    unbiased: Optional[bool] = None,
    keepdim: bool = False,
    *,
    correction: int,
):
    # The unbiased arg shouldn't be set when this function is called
    assert unbiased is None
    # Ignore the keepdim arg, because currently it's automatically converted
    # into nvFuser's symbolic scalar; keepdim is handled by the reference
    # implementation instead.
    keepdim = False
    return fd.ops.var_mean(a, dims, correction, keepdim)
def _rand_like_nvfuser(fd: Any, a: TensorLikeType):
    return fd.ops.rand_like(a)
def _amax_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    return fd.ops.max(a, dims, keep_dims)
def _amin_nvfuser(
    fd: Any,
    a: TensorLikeType,
    dims: DimsSequenceType,
):
    keep_dims = False
    return fd.ops.min(a, dims, keep_dims)
_nvfuser_impls["native_batch_norm"] = _native_batch_norm_nvfuser
_nvfuser_impls["broadcast_in_dim"] = _broadcast_in_dim_nvfuser
_nvfuser_impls["convert_element_type"] = _convert_element_type_nvfuser
_nvfuser_impls["transpose"] = _transpose_nvfuser
_nvfuser_impls["squeeze"] = _squeeze_nvfuser
_nvfuser_impls["view_of"] = _view_of_nvfuser
_nvfuser_impls["rand_like"] = _rand_like_nvfuser
_nvfuser_impls["sum"] = _sum_nvfuser
_nvfuser_impls["var"] = _var_nvfuser
_nvfuser_impls["var_mean"] = _var_mean_nvfuser
_nvfuser_impls["amax"] = _amax_nvfuser
_nvfuser_impls["amin"] = _amin_nvfuser
def register_native_batch_norm():
"""This function is used to register the native_batch_norm function in torch.ops.nvprims module."""
name = "native_batch_norm"
nvprim.define(
f"{name}(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, "
+ "bool training, float momentum, float eps)"
+ " -> (Tensor, Tensor, Tensor)"
)
def _prim_impl(
input, weight, bias, running_mean, running_var, training, momentum, eps
):
return torch.native_batch_norm(
input, weight, bias, running_mean, running_var, training, momentum, eps
)
nvprim_impl.impl(name, _prim_impl)
nvprim_autograd_impl.impl(
name, backwards_not_supported(torch.ops.nvprims.native_batch_norm.default)
)
prim_packet = torch.ops.nvprims.native_batch_norm
prim = prim_packet.default
for p in (prim_packet, prim):
p.__doc__ = "Computes batch normalization."
p.impl_nvfuser = _nvfuser_impls["native_batch_norm"]
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
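
# A minimal usage sketch (illustrative, assuming a CUDA build with nvFuser):
#
#   register_native_batch_norm()
#   x = torch.randn(2, 3, 4, device="cuda")
#   out, save_mean, save_invstd = torch.ops.nvprims.native_batch_norm(
#       x, None, None, None, None, True, 0.1, 1e-5
#   )
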
def register_rand_like():
name = "rand_like"
nvprim.define(
"rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, "
+ "Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"
)
def _meta_rand_like(
self,
*,
dtype=None,
layout=None,
device=None,
pin_memory=None,
memory_format=None,
):
strides = make_contiguous_strides_for(self.shape)
return torch._prims.TensorMeta(
self,
shape=self.shape,
strides=strides,
dtype=dtype,
device=device,
)
def _prim_impl(
self,
*,
dtype=None,
layout=None,
device=None,
pin_memory=None,
memory_format=None,
):
return torch.rand_like(
self,
dtype=dtype,
layout=layout,
device=device,
pin_memory=pin_memory,
memory_format=memory_format,
)
nvprim_impl.impl(name, _prim_impl)
nvprim_meta_impl.impl(name, _meta_rand_like)
prim_packet = getattr(torch.ops.nvprims, name)
prim = prim_packet.default
nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
for p in (prim_packet, prim):
p.__doc__ = "Computes rand_like"
p.impl_nvfuser = _nvfuser_impls["rand_like"]
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
def register_var_mean():
"""This function is used to register the var_mean function in torch.ops.nvprims module."""
name = "var_mean.main"
# This overload must be default for correct dispatching of var_mean(Tensor, bool)
nvprim.define("var_mean(Tensor inp, bool unbiased) -> (Tensor, Tensor)")
# This signature tries to combine several overloads of the torch.var_mean function into one overload.
nvprim.define(
f"{name}(Tensor inp, int[1]? dim=None, bool? unbiased=None, bool keepdim=False, *, int? correction=None)"
+ " -> (Tensor, Tensor)"
)
# This function is used for device="meta" Tensors.
def _meta_var_mean(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
if torch._prims_common.is_complex_dtype(inp.dtype):
output_dtype = torch._prims_common.corresponding_real_dtype(inp.dtype)
else:
output_dtype = inp.dtype
var = torch._prims._reduction_meta(inp, dim, output_dtype=output_dtype)
mean = torch._prims._reduction_meta(inp, dim, output_dtype=inp.dtype)
if keepdim:
output_shape = [
inp.shape[i] if i not in dim else 1 for i in range(inp.ndim)
]
broadcast_dims = [i for i in range(inp.ndim) if i not in dim]
var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
mean = torch.ops.nvprims.broadcast_in_dim(
mean, output_shape, broadcast_dims
)
return (var, mean)
# This function is used under _AutoDispatchBelowAutograd context
def _prim_impl(inp, dim=None, unbiased=None, keepdim=False, *, correction=None):
correction = torch._prims_common.set_correction(unbiased, correction)
return torch.var_mean(inp, dim, correction=correction, keepdim=keepdim)
nvprim_impl.impl(name, _prim_impl)
nvprim_meta_impl.impl(name, _meta_var_mean)
prim_packet = torch.ops.nvprims.var_mean
prim = prim_packet.main
def _unbiased_overload_impl(inp, unbiased):
return prim(inp, dim=None, unbiased=unbiased)
nvprim_implicit_impl.impl("var_mean", _unbiased_overload_impl)
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def _var_mean_ref(a, dim=None, unbiased=None, keepdim=False, *, correction=None):
correction = torch._prims_common.set_correction(unbiased, correction)
# reduces over all dimensions if dim=() is passed
if dim == () or dim == []:
dim = None
dim = torch._prims_common.reduction_dims(a.shape, dim)
# For complex tensors eager computes the variance as the sum of variances of
# the real and imaginary parts
# TODO: Creating a complex tensor from real and imaginary parts is not supported
if torch._prims_common.is_complex_dtype(a.dtype):
raise NotImplementedError("Complex tensors are not supported")
var_mean = prim(a, dim, correction=correction)
if keepdim:
output_shape = [a.shape[i] if i not in dim else 1 for i in range(a.ndim)]
broadcast_dims = [i for i in range(a.ndim) if i not in dim]
var, mean = var_mean
var = torch.ops.nvprims.broadcast_in_dim(var, output_shape, broadcast_dims)
mean = torch.ops.nvprims.broadcast_in_dim(
mean, output_shape, broadcast_dims
)
var_mean = (var, mean)
return var_mean
def _var_mean_autograd(
a, dim=None, unbiased=None, keepdim=False, *, correction=None
):
# This wrapper is needed to convert prims calls inside
# elementwise_type_promotion_wrapper to nvprims calls
from torch._prims.context import NvfuserPrimsMode
with NvfuserPrimsMode():
return backwards_not_supported(_var_mean_ref)(
a, dim, unbiased, keepdim, correction=correction
)
nvprim_autograd_impl.impl(name, _var_mean_autograd)
for p in (prim_packet, prim):
p.__doc__ = "Computes the variance and mean of x over the list of dimensions specified in the dim argument"
p.impl_nvfuser = _nvfuser_impls["var_mean"]
p.return_type = torch._prims_common.RETURN_TYPE.NEW # type: ignore[attr-defined]
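
# A minimal usage sketch (illustrative, assuming a CUDA build with nvFuser):
#
#   register_var_mean()
#   x = torch.randn(8, 16, device="cuda")
#   var, mean = torch.ops.nvprims.var_mean(x, True)  # default overload
#   var, mean = torch.ops.nvprims.var_mean.main(x, [1], correction=1)
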
def register_nvprims():
"""Registers all nvFuser primitives in the torch.ops.nvprims module."""
register_var_mean()
register_native_batch_norm()
register_rand_like()
for name in nvprim_names:
main_prim = getattr(torch.ops.prims, name)
nvprim.define(main_prim.schema)
nvprim_impl.impl(name, main_prim.prim_impl)
nvprim_meta_impl.impl(name, main_prim.prim_meta_impl)
prim_packet = getattr(torch.ops.nvprims, name)
prim = prim_packet.default
nvprim_autograd_impl.impl(name, backwards_not_supported(prim))
for p in (prim_packet, prim):
p.__doc__ = main_prim.__doc__
p.impl_nvfuser = _nvfuser_impls[name]
p.return_type = main_prim.return_type # type: ignore[attr-defined]