import torch
from torch import Tensor
from torch._decomp import register_decomposition
from enum import Enum
from typing import Tuple, Optional, List, Callable
import torch.nn.functional as F
import functools
from torch.utils._pytree import tree_map, tree_flatten
import torch._prims.utils as utils
from torch._prims.wrappers import out_wrapper_multi
# None of these functions are publicly accessible; get at them
# from torch._decomp
__all__: List[str] = []
aten = torch.ops.aten
class Reduction(Enum):
NONE = 0
MEAN = 1
SUM = 2
# This wraps a decomposition and performs type promotion within it, depending on the
# strategy provided. We're currently re-using ELEMENTWISE_TYPE_PROMOTION_KIND, although
# some of the usages are on non-elementwise ops; those uses still need to be validated.
def type_casts(f: Callable, type_promotion: utils.ELEMENTWISE_TYPE_PROMOTION_KIND):
@functools.wraps(f)
def inner(*args, **kwargs):
flat_args = [x for x in tree_flatten((args, kwargs))[0] if isinstance(x, Tensor)]
computation_dtype, result_dtype = utils.elementwise_dtypes(*flat_args,
type_promotion_kind=type_promotion)
# TODO: pretty sure this is not quite right
def increase_prec(x):
if isinstance(x, Tensor):
return x.to(computation_dtype)
else:
return x
def decrease_prec(x):
if isinstance(x, Tensor):
return x.to(result_dtype)
else:
return x
r = f(*tree_map(increase_prec, args), **tree_map(increase_prec, kwargs))
return tree_map(decrease_prec, r)
return inner
pw_cast_for_opmath = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
reduction_complex_to_real = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT)
pw_cast_for_int_to_real = functools.partial(type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT)
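# Illustrative sketch of how these wrappers are meant to be used (comment only, not
# executed; `my_decomp` is a hypothetical name, not an op defined in this file): with
# DEFAULT promotion, float16/bfloat16 tensor inputs are upcast to their opmath dtype
# (float32) for the computation, and the result is cast back to the original dtype.
#
#   @pw_cast_for_opmath
#   def my_decomp(x: Tensor, y: Tensor):
#       return x * y  # computed in float32 even if x and y are float16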
# This expands x until x.dim() == dim. Might be useful as an operator
def _unsqueeze_to_dim(x: Tensor, dim: int):
for _ in range(dim - x.dim()):
x = x.unsqueeze(-1)
return x
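# For example (illustrative shapes only): a weight of shape [C] passed with dim=3 comes
# back as [C, 1, 1], which broadcasts against an [N, C, H, W] tensor along the channel
# dimension (this is how native_group_norm and native_batch_norm below use it).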
@register_decomposition(aten.tanh_backward)
@pw_cast_for_opmath
def tanh_backward(out_grad: Tensor, y: Tensor):
return out_grad * (1 - y * y).conj_physical()
@register_decomposition(aten.sigmoid_backward)
@pw_cast_for_opmath
def sigmoid_backward(out_grad: Tensor, y: Tensor):
return out_grad * (y * (1 - y)).conj_physical()
@register_decomposition(aten.softplus_backward)
@pw_cast_for_opmath
def softplus_backward(out_grad: Tensor, x: Tensor, beta: float, threshold: float):
z = (x * beta).exp()
return torch.where((x * beta) > threshold, out_grad, out_grad * z / (z + 1.0))
@register_decomposition(aten.elu)
@pw_cast_for_opmath
def elu(
self: Tensor, alpha: float = 1, scale: float = 1, input_scale: float = 1
) -> Tensor:
negcoef = alpha * scale
poscoef = scale
negiptcoef = input_scale
return torch.where(
self > 0, self * poscoef, (torch.exp(self * negiptcoef) - 1) * negcoef
)
@register_decomposition(aten.elu_backward)
@pw_cast_for_opmath
def elu_backward(
grad_output: Tensor,
alpha: float,
scale: float,
input_scale: float,
is_result: bool,
self_or_result: Tensor,
):
negcoef = alpha * scale
poscoef = scale
negiptcoef = input_scale
if is_result:
return torch.where(
self_or_result <= 0,
grad_output * negiptcoef * (self_or_result + negcoef),
            grad_output * poscoef,
)
else:
return torch.where(
self_or_result <= 0,
grad_output * negiptcoef * negcoef * torch.exp(self_or_result * negiptcoef),
grad_output * poscoef,
)
@register_decomposition(aten.hardsigmoid)
@pw_cast_for_opmath
def hardsigmoid(self: Tensor) -> Tensor:
return torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6
@register_decomposition(aten.hardsigmoid_backward)
@pw_cast_for_opmath
def hardsigmoid_backward(grad_output: Tensor, self: Tensor):
return torch.where(
(self > -3.0) & (self < 3.0),
grad_output * (1.0 / 6.0),
grad_output.new_zeros(()),
)
@register_decomposition(aten.hardtanh_backward)
@pw_cast_for_opmath
def hardtanh_backward(
grad_output: Tensor, self: Tensor, min_val: float, max_val: float
):
return torch.where(
(self <= min_val) | (self >= max_val), grad_output.new_zeros(()), grad_output
)
@register_decomposition(aten.hardshrink_backward)
@pw_cast_for_opmath
def hardshrink_backward(grad_out: Tensor, self: Tensor, lambd: float):
return torch.where(
(self >= -lambd) & (self <= lambd), grad_out.new_zeros(()), grad_out
)
@register_decomposition(aten.hardswish)
@pw_cast_for_opmath
def hardswish(self: Tensor) -> Tensor:
return self * torch.clamp(torch.clamp(self + 3, min=0), max=6) / 6
@register_decomposition(aten.hardswish_backward)
@pw_cast_for_opmath
def hardswish_backward(grad_output: Tensor, self: Tensor) -> Tensor:
return torch.where(
self < -3,
grad_output.new_zeros(()),
torch.where(self <= 3, grad_output * ((self / 3) + 0.5), grad_output),
)
@register_decomposition(aten.threshold_backward)
@pw_cast_for_opmath
def threshold_backward(grad_output: Tensor, self: Tensor, threshold: float):
return torch.where(self <= threshold, grad_output.new_zeros(()), grad_output)
@register_decomposition(aten.leaky_relu_backward)
@pw_cast_for_opmath
def leaky_relu_backward(
grad_output: Tensor, self: Tensor, negative_slope: float, self_is_result: bool
):
return torch.where(self > 0, grad_output, grad_output * negative_slope)
@register_decomposition(aten.gelu_backward)
@pw_cast_for_opmath
def gelu_backward(grad: Tensor, self: Tensor, approximate: str = "none"):
M_SQRT2 = 1.41421356237309504880
M_SQRT1_2 = 0.70710678118654752440
M_2_SQRTPI = 1.12837916709551257390
if approximate == 'tanh':
kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
kKappa = 0.044715
x_sq = self * self
x_cube = x_sq * self
inner = kBeta * (self + kKappa * x_cube)
tanh_inner = torch.tanh(inner)
left = 0.5 * self
right = 1 + tanh_inner
left_derivative = 0.5 * right
tanh_derivative = 1 - tanh_inner * tanh_inner
inner_derivative = kBeta * (1 + 3 * kKappa * x_sq)
right_derivative = left * tanh_derivative * inner_derivative
return grad * (left_derivative + right_derivative)
else:
kAlpha = M_SQRT1_2
kBeta = M_2_SQRTPI * M_SQRT1_2 * 0.5
cdf = 0.5 * (1 + torch.erf(self * kAlpha))
pdf = kBeta * torch.exp(self * self * -0.5)
return grad * (cdf + self * pdf)
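# For reference (illustrative derivation, comment only): the 'tanh' branch above
# differentiates the approximation gelu(x) ~= 0.5 * x * (1 + tanh(kBeta * (x + kKappa * x^3)))
# with kBeta = sqrt(2 / pi) = M_SQRT2 * M_2_SQRTPI * 0.5, while the exact branch uses
# d/dx [x * Phi(x)] = Phi(x) + x * phi(x), i.e. cdf + self * pdf above.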
@register_decomposition(aten.mish_backward)
@pw_cast_for_opmath
def mish_backward(grad_output: Tensor, input: Tensor):
input_tanh_softplus = torch.tanh(F.softplus(input))
input_sigmoid = torch.sigmoid(input)
out = input * input_sigmoid * (1 - input_tanh_softplus * input_tanh_softplus)
return grad_output * (input_tanh_softplus + out)
@register_decomposition(aten.silu)
@pw_cast_for_opmath
def silu(self: Tensor) -> Tensor:
return self * torch.sigmoid(self)
@register_decomposition(aten.silu_backward)
@pw_cast_for_opmath
def silu_backward(grad_output: Tensor, self: Tensor) -> Tensor:
sigmoid = 1 / (1 + torch.exp(-self))
return grad_output * sigmoid * (1 + self * (1 - sigmoid))
@register_decomposition(aten.softshrink_backward)
def softshrink_backward(grad_output: Tensor, self: Tensor, lambd: float) -> Tensor:
return torch.where(
(self >= -lambd) & (self <= lambd), grad_output.new_zeros(()), grad_output
)
@register_decomposition(aten.prelu_backward)
@pw_cast_for_opmath
def prelu_backward(
grad_output: Tensor, self: Tensor, weight: Tensor
) -> Tuple[Tensor, Tensor]:
# Logic is more complicated than I would like. Basically, weight can either
# be a scalar or a vector of size [C], and in the forward pass it's
# broadcast against [N, C, ...]. So now, we need to do the corresponding
# reduction, which is harder than we'd like...
cur_weight = weight
for _ in range(2, grad_output.dim()):
cur_weight = cur_weight.unsqueeze(-1)
input_grad = torch.where(self > 0, grad_output, cur_weight * grad_output)
weight_grad_collector = torch.where(
self > 0, grad_output.new_zeros(()), self * grad_output
)
out = weight_grad_collector.sum_to_size(cur_weight.shape)
while out.dim() > weight.dim():
out = out.squeeze(-1)
return (input_grad, out)
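# Illustrative shapes for the reduction above (comment only): for self of shape
# [N, C, H, W] and weight of shape [C], cur_weight is unsqueezed to [C, 1, 1],
# weight_grad_collector is reduced with sum_to_size([C, 1, 1]), and the trailing
# singleton dims are squeezed off so the returned weight gradient matches weight's
# original shape [C].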
@register_decomposition(aten.rrelu_with_noise_backward)
@pw_cast_for_opmath
def rrelu_with_noise_backward(
grad_output: Tensor,
self: Tensor,
noise: Tensor,
lower: float,
upper: float,
training: bool,
self_is_result: bool,
) -> Tensor:
if training and upper - lower > 1e-6:
return grad_output.mul(noise)
else:
negative_slope = (lower + upper) / 2
return aten.leaky_relu_backward(grad_output, self, negative_slope, self_is_result)
@register_decomposition(aten.log_sigmoid_backward)
@pw_cast_for_opmath
def log_sigmoid_backward(grad_output: Tensor, self: Tensor, buffer: Tensor) -> Tensor:
in_negative = self < 0
max_deriv = torch.where(in_negative, 1, 0)
sign = torch.where(in_negative, 1, -1)
z = torch.exp(-torch.abs(self))
return grad_output * (max_deriv - sign * (z / (1 + z)))
    # CPU has a special formula that uses buffer, but it is disabled here for convenience's sake:
# return (max_deriv - sign * (buffer / (1 + buffer))) * grad_output
def apply_loss_reduction(loss: Tensor, reduction: int):
if reduction == Reduction.MEAN.value:
return torch.mean(loss)
elif reduction == Reduction.SUM.value:
return torch.sum(loss)
else:
return loss
def to_real_dtype(dtype: torch.dtype):
    if dtype == torch.complex32:
        return torch.float16
    elif dtype == torch.complex64:
        return torch.float32
    elif dtype == torch.complex128:
        return torch.float64
    else:
        # Non-complex dtypes are already real; return them unchanged so callers
        # can cast unconditionally.
        return dtype
# TODO: None of these loss castings are quite correct, see
# https://github.com/pytorch/pytorch/issues/76870. Also, the ATen kernels
# perform the pointwise portion in opmath, but don't maintain it between the
# pointwise portion and the reduction
@register_decomposition(aten.l1_loss)
def l1_loss(
self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
loss = (self - target).abs()
    # PyTorch semantics result in the output of l1_loss having the real dtype
    # corresponding to self. This may not happen without an explicit cast if, say,
    # self is complex64 and target is float64, which would make loss float64.
float_type = to_real_dtype(self.dtype)
return apply_loss_reduction(loss, reduction).to(float_type)
@register_decomposition(aten.l1_loss_backward)
@pw_cast_for_opmath
def l1_loss_backward(
grad_output: Tensor,
self: Tensor,
target: Tensor,
reduction: int = Reduction.MEAN.value,
):
sign = torch.sign(self - target)
norm = sign / self.numel() if reduction == Reduction.MEAN.value else sign
return grad_output * norm
@register_decomposition(aten.mse_loss)
@pw_cast_for_opmath
def mse_loss(
self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
loss = (self - target) ** 2
return apply_loss_reduction(loss, reduction)
@register_decomposition(aten.mse_loss_backward)
@pw_cast_for_opmath
def mse_loss_backward(
grad_output: Tensor, input: Tensor, target: Tensor, reduction: int
):
norm = 2.0 / input.numel() if reduction == Reduction.MEAN.value else 2.0
return norm * (input - target) * grad_output
@register_decomposition(aten.huber_loss)
@pw_cast_for_opmath
def huber_loss(
self: Tensor,
target: Tensor,
reduction: int = Reduction.MEAN.value,
delta: float = 1.0,
) -> Tensor:
assert delta > 0, "huber_loss does not support non-positive values for delta."
z = (self - target).abs()
loss = torch.where(z < delta, 0.5 * z * z, delta * (z - 0.5 * delta))
return apply_loss_reduction(loss, reduction)
@register_decomposition(aten.huber_loss_backward)
@pw_cast_for_opmath
def huber_loss_backward(
grad_output: Tensor, self: Tensor, target: Tensor, reduction: int, delta: float
):
norm = 1.0 / self.numel() if reduction == Reduction.MEAN.value else 1.0
x = self - target
return torch.where(
x < -delta,
-norm * grad_output * delta,
torch.where(x > delta, norm * grad_output * delta, norm * x * grad_output),
)
def _nll_loss_backward(
grad_output: Tensor,
self: Tensor,
target: Tensor,
weight: Optional[Tensor],
reduction: int,
ignore_index: int,
total_weight: Tensor,
) -> Tensor:
channel_dim = 0 if self.dim() < 2 else 1
if reduction == Reduction.MEAN.value:
grad_output = grad_output / total_weight
target = target.unsqueeze(channel_dim)
grad_input = torch.zeros_like(self)
grad_input = torch.scatter(grad_input, channel_dim, target, -1.0)
if grad_input.dim() > grad_output.dim() > 0:
grad_output = grad_output.unsqueeze(channel_dim)
if weight is not None:
new_shape = [1 for _ in range(self.dim())]
new_shape[channel_dim] = weight.shape[0]
weight = weight.reshape(new_shape)
grad_output = grad_output * weight
has_ignore_index = ignore_index >= 0
if has_ignore_index:
ignore_index_mask = target != ignore_index
grad_output = grad_output * ignore_index_mask
return grad_input * grad_output
@register_decomposition(aten.glu_backward)
@pw_cast_for_opmath
def glu_backward(grad_output: Tensor, self: Tensor, dim: int) -> Tensor:
assert self.dim() > 0, "glu does not support 0-dimensional tensors"
wrap_dim = utils.canonicalize_dim(self.dim(), dim)
nIn = self.size(wrap_dim)
assert nIn % 2 == 0, f"Halving dimension must be even, but dimension {wrap_dim} is size {nIn}"
inputSize = nIn // 2
firstHalf = self.narrow(wrap_dim, 0, inputSize)
secondHalf = self.narrow(wrap_dim, inputSize, inputSize)
gradInputFirstHalf = torch.sigmoid(secondHalf)
gradInputSecondHalf = (1.0 - gradInputFirstHalf) * gradInputFirstHalf * firstHalf * grad_output
gradInputFirstHalf = gradInputFirstHalf * grad_output
return torch.cat([gradInputFirstHalf, gradInputSecondHalf], dim=wrap_dim)
@register_decomposition(aten.nll_loss_backward)
def nll_loss_backward(
grad_output: Tensor,
self: Tensor,
target: Tensor,
weight: Optional[Tensor],
reduction: int,
ignore_index: int,
total_weight: Tensor,
) -> Tensor:
assert 0 <= self.dim() <= 2, "input tensor should be 1D or 2D"
assert (
target.dim() <= 1
), "0D or 1D target tensor expected, multi-target not supported"
no_batch_dim = self.dim() == 1 and target.dim() == 0
assert no_batch_dim or (
self.shape[0] == target.shape[0]
), f"size mismatch (got input: {self.shape}, target: {target.shape})"
    assert total_weight.numel() == 1, (
        "expected total_weight to be a single element tensor, got: "
        f"{total_weight.shape} ({total_weight.numel()} elements)"
    )
assert (
weight is None or weight.numel() == self.shape[-1]
), "weight tensor should be defined either for all or no classes"
if reduction == Reduction.NONE.value and self.dim() == 2:
assert grad_output.dim() == 1 and grad_output.shape[0] == self.shape[0], (
f"Expected a tensor of dimension 1 and tensor.size[0] == {self.shape[0]} but "
f"got: dimension {grad_output.dim()} and tensor.size[0] == {grad_output.shape[0]}"
)
else:
assert (
grad_output.dim() <= 1 and grad_output.numel() == 1
), f"Expected a single element grad_output tensor, but got: {grad_output.shape}"
return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)
@register_decomposition(aten.nll_loss2d_backward)
def nll_loss2d_backward(
grad_output: Tensor,
self: Tensor,
target: Tensor,
weight: Optional[Tensor],
reduction: int,
ignore_index: int,
total_weight: Tensor,
) -> Tensor:
assert (
self.dim() == 4
), f"only batches of spatial inputs supported (4D tensors), but got input of dimension: {self.dim()}"
assert (
target.dim() == 3
), f"only batches of spatial targets supported (3D tensors) but got targets of dimension: {target.dim()}"
    assert (
        self.shape[0] == target.shape[0] and self.shape[2] == target.shape[1] and self.shape[3] == target.shape[2]
    ), f"size mismatch (got input: {self.shape}, target: {target.shape})"
    assert (
        total_weight.numel() == 1
    ), (
        "expected total_weight to be a single element tensor, "
        f"got: {total_weight.shape} ({total_weight.numel()} elements)"
    )
return _nll_loss_backward(grad_output, self, target, weight, reduction, ignore_index, total_weight)
@register_decomposition(aten.binary_cross_entropy)
@pw_cast_for_opmath
def binary_cross_entropy(
self: Tensor,
target: Tensor,
weight: Optional[Tensor] = None,
reduction: int = Reduction.MEAN.value,
) -> Tensor:
# We cannot currently model this without introducing data-dependent control flow
# TORCH_CHECK(
# (input_val >= 0) && (input_val <= 1),
# "all elements of input should be between 0 and 1"
# )
loss = (target - 1) * torch.maximum(
torch.log(1 - self), self.new_full((), -100)
) - target * torch.maximum(torch.log(self), self.new_full((), -100))
if weight is not None:
loss = loss * weight
return apply_loss_reduction(loss, reduction)
@register_decomposition(aten.binary_cross_entropy_backward)
@pw_cast_for_opmath
def binary_cross_entropy_backward(
grad_output: Tensor,
self: Tensor,
target: Tensor,
weight: Optional[Tensor] = None,
reduction: int = Reduction.MEAN.value,
) -> Tensor:
EPSILON = 1e-12
result = grad_output * (self - target) / torch.clamp(self * (1 - self), min=EPSILON)
if weight is not None:
result = result * weight
if reduction == Reduction.MEAN.value:
result = result / self.numel()
return result
@register_decomposition(aten._euclidean_dist)
def _euclidean_dist(x1: Tensor, x2: Tensor) -> Tensor:
x1_norm = x1.pow(2).sum(-1, True)
x1_pad = torch.ones_like(x1_norm, memory_format=torch.contiguous_format)
x2_norm = x2.pow(2).sum(-1, True)
x2_pad = torch.ones_like(x2_norm, memory_format=torch.contiguous_format)
x1_ = torch.cat([x1.mul(-2), x1_norm, x1_pad], -1)
x2_ = torch.cat([x2, x2_pad, x2_norm], -1)
result = x1_.matmul(x2_.mT)
return result.clamp_min(0).sqrt()
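# Illustrative derivation (comment only): for each pair (i, j) the matmul above computes
#   ||x1_i - x2_j||^2 = ||x1_i||^2 - 2 * <x1_i, x2_j> + ||x2_j||^2
# because x1 is augmented to [-2 * x1, ||x1||^2, 1] and x2 to [x2, 1, ||x2||^2]; the
# squared distances are then clamped at zero (to guard against round-off) and square-rooted.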
@register_decomposition(aten.slice_backward)
def slice_backward(
grad_output: Tensor,
input_sizes: List[int],
dim: int,
start: int,
end: int,
step: int,
):
grad_input = grad_output.new_zeros(input_sizes)
return torch.slice_scatter(grad_input, grad_output, dim, start, end, step)
@register_decomposition(aten.select_backward)
def select_backward(grad_output: Tensor, input_sizes: List[int], dim: int, index: int):
grad_input = grad_output.new_zeros(input_sizes)
return torch.select_scatter(grad_input, grad_output, dim, index)
@register_decomposition(aten.diagonal_backward)
def diagonal_backward(
grad_output: Tensor, input_sizes: List[int], offset: int, dim1: int, dim2: int
):
grad_input = grad_output.new_zeros(input_sizes)
return torch.diagonal_scatter(grad_input, grad_output, offset, dim1, dim2)
@register_decomposition(aten._softmax_backward_data)
@pw_cast_for_opmath
def _softmax_backward_data(
grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
):
new_grad = grad_output * output
return new_grad - output * torch.sum(new_grad, dim=dim, keepdim=True)
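# Illustrative derivation (comment only): with s = softmax(x) along `dim` and incoming
# gradient g, the vector-Jacobian product is
#   dL/dx_i = s_i * (g_i - sum_j g_j * s_j)
# which is exactly new_grad - output * sum(new_grad) above, where new_grad = g * s.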
@register_decomposition(aten._log_softmax_backward_data)
@pw_cast_for_opmath
def _log_softmax_backward_data(
grad_output: Tensor, output: Tensor, dim: int, input_dtype: int
):
grad_input = grad_output - torch.exp(output) * torch.sum(
grad_output, dim=dim, keepdim=True
)
return grad_input
# TODO: the type annotations on arguments are not quite right
@register_decomposition(aten.im2col_backward)
def im2col_backward(
grad_output: Tensor,
input_size: List[int],
kernel_size: List[int],
dilation: List[int],
padding: List[int],
stride: List[int],
) -> Tensor:
return F.fold(grad_output, input_size, kernel_size, dilation, padding, stride) # type: ignore[arg-type]
@register_decomposition(aten.col2im_backward)
def col2im_backward(
grad_output: Tensor,
kernel_size: List[int],
dilation: List[int],
padding: List[int],
stride: List[int],
) -> Tensor:
return F.unfold(grad_output, kernel_size, dilation, padding, stride) # type: ignore[arg-type]
@register_decomposition(aten.masked_fill.Scalar)
def masked_fill_Scalar(self: Tensor, mask: Tensor, value: float) -> Tensor:
return torch.where(mask, utils.dtype_to_type(self.dtype)(value), self)
@register_decomposition(aten.masked_fill.Tensor)
def masked_fill_Tensor(self: Tensor, mask: Tensor, value: Tensor) -> Tensor:
return torch.where(mask, value, self)
@register_decomposition(aten.native_dropout_backward)
@pw_cast_for_opmath
def native_dropout_backward(grad_output: Tensor, mask: Tensor, scale: float):
return grad_output * (mask.type_as(grad_output) * scale)
@register_decomposition(aten.logit_backward.default)
@pw_cast_for_opmath
def logit_backward(
grad_output: Tensor, self: Tensor, eps: Optional[float] = None
) -> Tensor:
if eps is not None:
lo = eps
hi = 1.0 - lo
return torch.where(
torch.logical_and(self >= lo, self <= hi),
grad_output / (self * (1.0 - self)),
self.new_zeros(()),
)
else:
return torch.where(
torch.logical_and(self >= 0.0, self <= 1.0),
grad_output / (self * (1.0 - self)),
self.new_full((), float("nan")),
)
@register_decomposition(aten.native_dropout)
def native_dropout(input: Tensor, p: float, train: Optional[bool]):
if train:
bool_mask = torch.rand_like(input) > p
res = bool_mask * input * float(1.0 / (1.0 - p))
return (res, bool_mask)
else:
return (input, torch.ones_like(input, dtype=torch.bool))
# TODO: Correct the type promotion semantics
@register_decomposition(aten._softmax)
@pw_cast_for_opmath
def _softmax(x: Tensor, dim: int, half_to_float: bool):
x_max = torch.max(x, dim, keepdim=True)[0]
unnormalized = torch.exp(x - x_max)
return unnormalized / torch.sum(unnormalized, dim, keepdim=True)
# TODO: Correct the type promotion semantics
@register_decomposition(aten._log_softmax)
@pw_cast_for_opmath
def _log_softmax(x: Tensor, dim: int, half_to_float: bool):
x_max = torch.max(x, dim, keepdim=True)[0]
shifted = x - x_max
shifted_logsumexp = torch.log(torch.sum(torch.exp(shifted), dim, keepdim=True))
return shifted - shifted_logsumexp
@register_decomposition(aten.addcdiv)
@pw_cast_for_opmath
def addcdiv(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
return self + value * (tensor1 / tensor2)
# Remove special case when https://github.com/pytorch/pytorch/pull/72949 is landed.
@register_decomposition(aten.addcmul)
@pw_cast_for_opmath
def addcmul(self: Tensor, tensor1: Tensor, tensor2: Tensor, value: float = 1):
if self.is_floating_point() or self.is_complex():
return self + value * tensor1 * tensor2
else:
return self + int(value) * tensor1 * tensor2
@register_decomposition(aten.rsub.Tensor)
def rsub_Tensor(self: Tensor, other: Tensor, alpha: float = 1) -> Tensor:
return torch.sub(other, self, alpha=alpha)
@register_decomposition(aten.rsub.Scalar)
def rsub_Scalar(self: Tensor, other: float, alpha: float = 1) -> Tensor:
return torch.sub(other, self, alpha=alpha)
@register_decomposition(aten.embedding)
def embedding(
weight: Tensor,
indices: Tensor,
padding_idx: int = -1,
scale_grad_by_freq: bool = False,
sparse: bool = False,
) -> Tensor:
assert weight.dim() == 2, "'weight' must be 2-D"
# TODO: Assert not ported over yet
# auto indices_arg = TensorArg(indices, "indices", 1);
# checkScalarTypes("embedding", indices_arg, {kLong, kInt});
if indices.dim() == 1:
return weight.index_select(0, indices)
size = list(indices.shape)
for d in weight.shape[1:]:
size.append(d)
return weight.index_select(0, indices.reshape(-1)).view(size)
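# Illustrative shapes (hypothetical check, not executed here):
#   weight = torch.randn(10, 4)
#   idx = torch.tensor([[1, 2, 3], [4, 5, 6]])
#   embedding(weight, idx).shape  # torch.Size([2, 3, 4])
# i.e. the flattened index_select yields [6, 4], which is viewed back to [2, 3, 4].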
# TODO: Correct the type promotion semantics
@register_decomposition(aten.embedding_dense_backward)
def embedding_dense_backward(
grad_output: Tensor,
indices: Tensor,
num_weights: int,
padding_idx: int,
scale_grad_by_freq: bool,
):
numel = indices.numel()
grad = grad_output.view(numel, grad_output.size(-1))
grad_weight = grad_output.new_zeros((num_weights, grad_output.shape[-1]))
indices_rank1 = indices.view(numel)
if scale_grad_by_freq:
counts = indices.new_zeros((num_weights,))
ones = indices.new_ones((numel,))
counts = counts.index_put([indices_rank1], ones, accumulate=True)
grad_weights_scale = counts[indices_rank1]
grad = grad / grad_weights_scale.unsqueeze(1)
skip_padding = (indices_rank1 != padding_idx).unsqueeze(1)
skip_padding = skip_padding.expand_as(grad)
zero_grad = torch.full_like(grad, 0)
return grad_weight.index_put(
[indices_rank1], torch.where(skip_padding, grad, zero_grad), accumulate=True
)
def prod(x: List[int]):
r = 1
for i in x:
r *= i
return r
@register_decomposition(aten.split_with_sizes)
def split_with_sizes(
self: Tensor, split_sizes: List[int], dim: int = 0
) -> List[Tensor]:
num_splits = len(split_sizes)
splits = []
start_idx = 0
for i in range(num_splits):
length = split_sizes[i]
splits.append(self.narrow(dim, start_idx, length))
start_idx += length
return splits
@register_decomposition(aten.split.Tensor)
def split(self: Tensor, split_size: int, dim: int = 0) -> List[Tensor]:
input_sizes = self.shape
dim_size = input_sizes[dim]
if split_size == 0:
assert dim_size == 0
return [self]
chunks = (dim_size + split_size - 1) // split_size
split_sizes = [split_size for i in range(chunks)]
split_sizes[chunks - 1] = split_size - (split_size * chunks - dim_size)
return torch.split(self, split_sizes, dim)
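# Worked example (illustrative only): dim_size = 10 and split_size = 3 give chunks = 4 and
# split_sizes = [3, 3, 3, 1], since the last chunk is 3 - (3 * 4 - 10) = 1. This mirrors
#   torch.arange(10).split(3)
# which returns pieces of sizes (3, 3, 3, 1).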
# TODO: this doesn't appear to have enough precision in bfloat16
@register_decomposition(aten.addmm)
@pw_cast_for_opmath
def addmm(self: Tensor, mat1: Tensor, mat2: Tensor, beta: int = 1, alpha: int = 1):
if not self.is_floating_point() and not self.is_complex():
beta = int(beta)
alpha = int(alpha)
out = alpha * torch.mm(mat1, mat2)
if beta == 0:
return out
return beta * self + out
# This computes the mean and variance along the specified normalization dims,
# then normalizes along those dims. It returns the normalized output together with
# the mean and rstd (reciprocal standard deviation) over those dims. Note that it
# intentionally leaves the outputs upcast to the computation dtype.
# Example:
# input: [2, 3, 4, 5], norm_dims: [1, 3]
# mean: [2, 1, 4, 1]
def normalize(input, norm_dims, eps):
computation_dtype = utils.get_computation_dtype(input.dtype)
input_acc = input.to(dtype=computation_dtype)
biased_var = torch.var(input_acc, dim=norm_dims, unbiased=False, keepdim=True)
mean = torch.mean(input_acc, dim=norm_dims, keepdim=True)
rstd = torch.rsqrt(biased_var + eps)
out = ((input - mean) * rstd)
return out, mean, rstd
@register_decomposition(aten.native_layer_norm.default)
def native_layer_norm(
input: Tensor,
normalized_shape: List[int],
weight: Optional[Tensor],
bias: Optional[Tensor],
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
computation_dtype = utils.get_computation_dtype(input.dtype)
axis = input.dim() - len(normalized_shape)
if prod(list(input.shape[:axis])) == 0:
mean = input.new_zeros((0,), dtype=computation_dtype)
rstd = input.new_zeros((0,), dtype=computation_dtype)
out = input
else:
reduction_dims = list(range(axis, input.dim()))
out, mean, rstd = normalize(input, reduction_dims, eps)
if weight is not None:
out = out * weight
if bias is not None:
out = out + bias
out = out.to(dtype=input.dtype)
if input.device.type == 'cpu':
mean = mean.to(dtype=input.dtype)
rstd = rstd.to(dtype=input.dtype)
return (out, mean, rstd)
@register_decomposition(aten.native_group_norm.default, disable_meta=True)
def native_group_norm(
input: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
N: int,
C: int,
HxW: int,
group: int,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
orig_shape = input.shape
input = input.view(N, group, C // group, HxW)
reduction_dims = [2, 3]
out, mean, rstd = normalize(input, reduction_dims, eps)
mean = _squeeze_multiple(mean, reduction_dims)
rstd = _squeeze_multiple(rstd, reduction_dims)
out = out.view(orig_shape)
if weight is not None:
weight = _unsqueeze_to_dim(weight, out.dim() - 1)
out = out * weight
if bias is not None:
bias = _unsqueeze_to_dim(bias, out.dim() - 1)
out = out + bias
out = out.to(dtype=input.dtype)
mean = mean.to(dtype=input.dtype)
rstd = rstd.to(dtype=input.dtype)
return (out, mean, rstd)
def _maybe_cast(x: Optional[Tensor], dtype) -> Optional[Tensor]:
if x is not None:
return x.to(dtype)
return x
# TODO: Take a closer look at the type promotion semantics
@register_decomposition(aten.native_layer_norm_backward)
def native_layer_norm_backward(
grad_out: Tensor,
input: Tensor,
normalized_shape: List[int],
mean: Tensor,
rstd: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
output_mask: List[bool],
) -> Tuple[Optional[Tensor], Optional[Tensor], Optional[Tensor]]:
input_shape = input.shape
input_ndim = input.dim()
computation_dtype = utils.get_computation_dtype(input.dtype)
grad_out_cast, input_cast, weight_cast, bias_cast = [
x.to(computation_dtype) if x is not None else x
for x in (grad_out, input, weight, bias)
]
assert grad_out_cast is not None
axis = input_ndim - len(normalized_shape)
inner_dims = input_shape[axis:]
outer_dims = input_shape[:axis]
inner_dim_indices: List[int] = []
outer_dim_indices: List[int] = []
for i in range(input_ndim):
if i >= axis:
inner_dim_indices.append(i)
else:
outer_dim_indices.append(i)
N = prod(inner_dims) # type: ignore[arg-type]
M = prod(outer_dims) # type: ignore[arg-type]
if M <= 0 or N <= 0:
return (
input.new_zeros(input_shape),
input.new_zeros(input_shape[axis:]),
input.new_zeros(input_shape[axis:]),
)
x_hat = (input_cast - mean) * rstd
if weight_cast is not None:
grad_x_hat = grad_out_cast * weight_cast
else:
grad_x_hat = grad_out_cast
a = grad_x_hat * N
b = torch.sum(grad_x_hat, inner_dim_indices, True)
c1 = torch.mul(grad_x_hat, x_hat)
c2 = torch.sum(c1, inner_dim_indices, True)
c3 = torch.mul(x_hat, c2)
inner = a - b - c3
d_input: Optional[Tensor] = None
d_weight: Optional[Tensor] = None
d_bias: Optional[Tensor] = None
if output_mask[0]:
d_input = (rstd / N) * inner
if output_mask[1] and weight_cast is not None:
if len(outer_dim_indices) > 0:
d_weight = torch.sum(
grad_out_cast * x_hat, outer_dim_indices, False
)
else:
d_weight = grad_out_cast * x_hat
if output_mask[2] and bias_cast is not None:
if len(outer_dim_indices) > 0:
d_bias = torch.sum(grad_out_cast, outer_dim_indices, False)
else:
d_bias = grad_out_cast
return _maybe_cast(d_input, input.dtype), _maybe_cast(d_weight, input.dtype), _maybe_cast(d_bias, input.dtype)
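# Illustrative derivation (comment only): writing x_hat = (x - mean) * rstd and
# g_hat = grad_out * weight, the input gradient over each normalized slice of size N is
#   d_input = (rstd / N) * (N * g_hat - sum(g_hat) - x_hat * sum(g_hat * x_hat))
# which is the (rstd / N) * (a - b - c3) expression assembled above.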
@register_decomposition(aten.native_batch_norm)
def native_batch_norm(
input: Tensor,
weight: Optional[Tensor],
bias: Optional[Tensor],
running_mean: Optional[Tensor],
running_var: Optional[Tensor],
training: bool,
momentum: float,
eps: float,
) -> Tuple[Tensor, Tensor, Tensor]:
reduction_dims = [0] + list(range(2, input.dim()))
computation_dtype = utils.get_computation_dtype(input.dtype)
if training:
output, mean, rstd = normalize(input, reduction_dims, eps)
save_mean = _squeeze_multiple(mean, reduction_dims)
save_rstd = _squeeze_multiple(rstd, reduction_dims)
if running_mean is not None:
running_mean.copy_(momentum * save_mean + (1 - momentum) * running_mean)
if running_var is not None:
n = input.numel() / input.shape[1]
# This doesn't strictly match eager's numerics, which accumulates var sum and then directly applies the correction
# But... that would require re-implementing var here, for negligible numerics gain on a tensor whose
# numerics probably don't matter.
unbiased_var = torch.var(input, reduction_dims, unbiased=False) * (n / (n - 1))
running_var.copy_(momentum * unbiased_var + (1 - momentum) * running_var)
else:
assert running_mean is not None and running_var is not None
running_mean = running_mean.to(dtype=computation_dtype)
running_var = running_var.to(dtype=computation_dtype)
mean = running_mean
invstd = 1 / (torch.sqrt(running_var + eps))
# Very annoying inconsistency where CPU and CUDA give different shapes
if input.device.type != "cpu":
save_mean = running_mean
save_rstd = invstd
else:
save_mean = input.new_zeros((0,))
save_rstd = input.new_zeros((0,))
mean = _unsqueeze_to_dim(mean, input.dim() - 1)
invstd = _unsqueeze_to_dim(invstd, input.dim() - 1)
output = ((input - mean) * invstd)
if weight is None:
weight = input.new_ones(())
if bias is None:
bias = input.new_zeros(())
weight = _unsqueeze_to_dim(weight, input.dim() - 1)
bias = _unsqueeze_to_dim(bias, input.dim() - 1)
output = output * weight + bias
if input.device.type == 'cpu':
save_mean = save_mean.to(dtype=input.dtype)
save_rstd = save_rstd.to(dtype=input.dtype)
return output.to(dtype=input.dtype), save_mean, save_rstd
@register_decomposition(aten.clamp_min)
def clamp_min(self: Tensor, min: float):
return torch.clamp(self, min=min)
@register_decomposition(aten.clamp_max)
def clamp_max(self: Tensor, max: float):
return torch.clamp(self, max=max)
@register_decomposition(aten._fused_dropout)
@pw_cast_for_opmath
def _fused_dropout_decomposition(input, p, generator=None):
mask = (torch.rand_like(input) < p).to(dtype=torch.uint8)
res = mask.type_as(input) * input * (1.0 / p)
return (res, mask)
@register_decomposition(aten.logical_not)
def logical_not(self: Tensor) -> Tensor:
return ~self.to(dtype=torch.bool)
@register_decomposition(aten.xlogy.Tensor)
@pw_cast_for_int_to_real
def xlogy(self: Tensor, other: Tensor) -> Tensor:
return aten.where(aten.isnan(self),
self,
aten.where(self == aten.new_zeros(self, ()),
aten.new_zeros(self, ()),
self * aten.log(other)))
@register_decomposition(aten.var.correction)
@reduction_complex_to_real
def var_correction(
x: Tensor,
dims: Optional[List[int]],
correction: Optional[int] = None,
keepdim: bool = False,
):
if dims is None:
dims = []
if x.is_complex():
# For complex, calculate variance of real and imaginary components
# separately then add to get overall variance.
real_in = x.real
var_real = torch.var(real_in, dims, correction=correction, keepdim=keepdim)
imag_in = x.imag
var_imag = torch.var(imag_in, dims, correction=correction, keepdim=keepdim)
return var_real + var_imag
if correction is None:
correction = 0
if len(dims) == 0:
n = prod(x.shape) # type: ignore[arg-type]
else:
n = 1
for dim in dims:
n *= x.shape[dim]
mean = torch.mean(x, dims, True)
sub = x - mean
sq = sub * sub
sum = torch.sum(sq, dims, keepdim)
if correction:
n = n - correction
return sum / n
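# Illustrative note (comment only): correction = 1 recovers the usual unbiased
# (Bessel-corrected) estimator sum((x - mean)^2) / (n - 1), while correction = 0
# (the default here) gives the biased/population variance sum((x - mean)^2) / n.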
@register_decomposition(aten.std.correction)
@reduction_complex_to_real
def std_decomposition(
x: Tensor, dims: List[int], correction: int = 0, keepdim: bool = False
):
return torch.sqrt(torch.var(x, dims, correction=correction, keepdim=keepdim))
# Questionable decompositions
# This is only valid if we're running the graph without autograd, such as if the backward pass has been traced.
# Note that this decomposition causes issues with in-place ops
@register_decomposition(aten.detach, disable_meta=True)
def detach_decomposition(x):
return x
@register_decomposition(aten.cudnn_batch_norm)
def cudnn_batch_norm(
input: Tensor,
weight: Tensor,
bias: Optional[Tensor],
running_mean: Optional[Tensor],
running_var: Optional[Tensor],
training: bool,
exponential_average_factor: float,
epsilon: float,
):
a, b, c = aten.native_batch_norm(
input,
weight,
bias,
running_mean,
running_var,
training,
exponential_average_factor,
epsilon,
)
    # cuDNN returns the running mean and variance when training is True
if training:
return (a, b, c, input.new_zeros((0,), dtype=torch.uint8))
return (a, input.new_zeros((0,)), input.new_zeros((0,)), input.new_zeros((0,), dtype=torch.uint8))
@register_decomposition(aten.cudnn_batch_norm_backward)
def cudnn_batch_norm_backward(
input: Tensor,
grad_output: Tensor,
weight: Tensor,
running_mean: Optional[Tensor],
running_var: Optional[Tensor],
save_mean: Optional[Tensor],
save_var: Optional[Tensor],
epsilon: float,
reserveSpace: Tensor,
):
return aten.native_batch_norm_backward(
grad_output,
input,
weight,
running_mean,
running_var,
save_mean,
save_var,
True,
epsilon,
[True, True, True],
)
@register_decomposition(aten.transpose.int)
def transpose_int(self: Tensor, dim0: int, dim1: int) -> Tensor:
dim0, dim1 = utils.canonicalize_dims(self.dim(), (dim0, dim1)) # type: ignore[misc]
if self.dim() <= 1:
return self
if dim0 == dim1:
return self
perm = list(range(self.dim()))
perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
return torch.permute(self, perm)
def _squeeze_multiple(self: Tensor, dims: List[int]) -> Tensor:
ndim = self.dim()
wrapped_dims = utils.canonicalize_dims(ndim, dims)
assert isinstance(wrapped_dims, tuple)
for idx in range(ndim - 1, -1, -1):
if idx in wrapped_dims:
self = self.squeeze(idx)
return self
@register_decomposition(aten.logsumexp.default)
@pw_cast_for_int_to_real
def logsumexp(self: Tensor, dim: List[int], keepdim: bool = False) -> Tensor:
if self.numel() == 0:
return torch.sum(torch.exp(self), dim, keepdim).log()
maxes = torch.amax(self, dim, keepdim=True)
maxes_squeezed = maxes if keepdim else _squeeze_multiple(maxes, dim)
maxes_squeezed = torch.masked_fill(maxes_squeezed, maxes_squeezed.abs() == float('inf'), 0)
result = torch.sum(torch.exp(self - maxes), dim, keepdim)
return result.log().add(maxes_squeezed)
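# Illustrative identity (comment only): this is the standard max-subtraction trick,
#   logsumexp(x) = m + log(sum(exp(x - m))),  m = max(x)
# which keeps exp() from overflowing; the masked_fill above zeroes out infinite maxes so
# that a +/-inf offset is not added back onto the reduced result.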
@register_decomposition(aten.trace.default)
def trace(self: Tensor) -> Tensor:
return torch.sum(torch.diag(self))
# nb: Should use acc_t, not op_math
@register_decomposition(aten.log_sigmoid_forward)
@out_wrapper_multi('output', 'buffer')
@pw_cast_for_opmath
def log_sigmoid_forward(self: Tensor) -> Tuple[Tensor, Tensor]:
min = torch.minimum(self.new_zeros(()), self)
z = torch.exp(-torch.abs(self))
if self.is_cuda:
buffer = self.new_zeros((0,))
else:
buffer = z
return min - torch.log1p(z), buffer
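# Illustrative identity (comment only): the first output above is the numerically stable form
#   log(sigmoid(x)) = min(0, x) - log1p(exp(-|x|))
# and the second output follows the code above: CPU keeps the buffer z for the backward
# pass, while CUDA returns an empty placeholder tensor.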
# This implementation matches torch.ops.aten.norm.
# torch.ops.aten.norm only supports a numeric p; it does not support the Frobenius or nuclear norms.
# For the 2-norm and -2 matrix norm it does not compute singular values; it just computes
# the norm the same way it does for any other p.
@register_decomposition([aten.norm.Scalar, aten.norm.ScalarOpt_dim])
@reduction_complex_to_real
def norm(self: Tensor, p: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False):
if dim is None:
dim = []
if p == 0:
return (self != 0).sum(dim, keepdim=keepdim)
elif p == float('inf'):
return self.abs().amax(dim, keepdim=keepdim)
elif p == -float('inf'):
return self.abs().amin(dim, keepdim=keepdim)
def fast_pow(x, ord):
if ord == 1.0:
return x
elif ord == 2.0:
return x.square()
elif ord == 0.5:
return x.sqrt()
else:
return x.pow(ord)
if not (p % 2.0 == 0.0 and utils.is_float_dtype(self.dtype)):
self = self.abs()
return fast_pow(fast_pow(self, p).sum(dim, keepdim=keepdim), 1.0 / p)