import functools
import logging
import math
import numbers
from itertools import chain
import torch
import torch._decomp as decomp
import torch.ao.quantization.fx._decomposed  # registers the quantized_decomposed ops used below
from torch import Tensor
from torch._decomp import core_aten_decompositions, get_decompositions
from torch._decomp.decompositions import pw_cast_for_opmath
from torch._decomp.decompositions_for_rng import extra_random_decomps
from torch.utils._mode_utils import no_dispatch
from . import config, utils
log = logging.getLogger(__name__)
aten = torch.ops.aten
prims = torch.ops.prims
quantized_decomposed = torch.ops.quantized_decomposed
inductor_decompositions = get_decompositions(
[
aten.arange,
aten.bitwise_and_,
aten.bitwise_or_,
aten.clamp_min_,
aten.flip,
aten.lcm,
aten.linalg_vector_norm,
aten.sin_,
aten.sqrt_,
aten.std,
aten.std_mean,
aten._to_copy,
aten.tril_indices,
aten.triu_indices,
aten.unsafe_split,
]
)
decompositions = {**core_aten_decompositions(), **inductor_decompositions}
def register_decomposition(ops):
for op in [ops] if callable(ops) else ops:
if op in decompositions:
            log.warning("duplicate decomp: %s", op)
return decomp.register_decomposition(ops, decompositions)
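# Note on usage (mirrors the registrations below): register_decomposition accepts
# either a single OpOverload or a list of them, warns if an op already has an
# entry, and records the decomposition in the module-level `decompositions` table.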
@register_decomposition(aten._unsafe_view.default)
def _unsafe_view(self, size):
# this makes pattern matching easier
return self.view(size)
# TODO: for now, inductor doesn't handle asserts
# because the condition is a SymBool that appears as a tensor in the graph.
@register_decomposition([aten._assert_async.msg])
def assert_async_msg_decomp(tensor, msg):
return
@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(x, min=None, max=None):
if min is not None:
x = x.clamp_min(min)
if max is not None:
x = x.clamp_max(max)
return x
# TorchInductor-only decomposition. It should not be taken to core.
# See https://github.com/pytorch/torchdynamo/pull/1120
@register_decomposition([aten.floor_divide.default])
def floordiv(a, b):
return aten.div.Tensor_mode(a, b, rounding_mode="floor")
# Not really sure how to put this into the main library. PrimTorch wants
# empty_permuted to go to the prim, and typically users don't really want
# to decompose to empty_strided (but inductor is OK with it, because we are
# cool with strides and everything goes to empty_strided)
@register_decomposition([aten.empty_permuted.default])
def empty_permuted(size, physical_layout, **kwargs):
perm = [0] * len(size)
for p, l in enumerate(physical_layout):
perm[l] = p
return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)
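# Worked example (illustrative): size=[2, 3, 4] with physical_layout=[2, 0, 1]
# allocates torch.empty([4, 2, 3]) in physical order and permutes it with
# perm=[1, 2, 0], yielding a logically [2, 3, 4]-shaped tensor whose strides
# follow the requested physical layout.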
def get_alignment_size(x):
    if x.dtype in (torch.float16, torch.bfloat16):
        return 8
    elif x.dtype == torch.float32:
        return 4
else:
return 0
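# Note: 8 elements of fp16/bf16 and 4 elements of fp32 are both 16 bytes, so the
# matmul padding below aligns dimensions to 16-byte boundaries (the granularity
# tensor-core kernels generally prefer); other dtypes are left unpadded.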
def check_device(a: Tensor, b: Tensor):
return a.is_cuda and b.is_cuda
def is_symbolic(a: Tensor, b: Tensor):
return any(
isinstance(x, torch.SymInt)
for x in chain(a.size(), a.stride(), b.size(), b.stride())
)
def get_padded_length(x, alignment_size):
if alignment_size == 0 or x % alignment_size == 0:
return 0
return int((x // alignment_size + 1) * alignment_size) - x
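# Worked example: get_padded_length(10, 8) == 6 (10 is padded up to 16), while
# get_padded_length(16, 8) == 0 because 16 is already a multiple of the alignment.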
def pad_dim(x, padded_length, dim):
if padded_length == 0:
return x
pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
return torch.cat([x, pad], dim=dim)
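# Example (illustrative): pad_dim(x, 3, 1) on a (2, 5) tensor concatenates a
# (2, 3) block of zeros along dim 1, producing a (2, 8) tensor.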
@register_decomposition([aten.addmm])
def addmm(input, mat1, mat2, *, beta=1, alpha=1):
if (
config.shape_padding
and check_device(mat1, mat2)
and not is_symbolic(mat1, mat2)
and should_pad_bench(mat1, mat2, torch.ops.aten.addmm, input=input)
):
m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0:
return pad_addmm(
input, mat1, mat2, m_padded_length, k_padded_length, n_padded_length
)
return NotImplemented # go directly to lowering
def pad_addmm(input, mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
    # The addmm decomposition with padding will go through pad_addmm multiple
    # times if multiple dimensions need to be padded
if k_padded_length != 0:
mat1 = pad_dim(mat1, k_padded_length, 1)
mat2 = pad_dim(mat2, k_padded_length, 0)
elif n_padded_length != 0:
mat2 = pad_dim(mat2, n_padded_length, 1)
elif m_padded_length != 0:
mat1 = pad_dim(mat1, m_padded_length, 0)
if input is not None and k_padded_length == 0:
if n_padded_length != 0:
if input.dim() == 2:
input = pad_dim(input, n_padded_length, 1)
elif input.dim() == 1:
input = pad_dim(input, n_padded_length, 0)
elif m_padded_length != 0 and input.dim() == 2:
input = pad_dim(input, m_padded_length, 0)
if k_padded_length != 0:
return torch.ops.aten.addmm(input, mat1, mat2)
elif n_padded_length != 0:
return torch.ops.aten.addmm(input, mat1, mat2)[:, :-n_padded_length]
else:
return torch.ops.aten.addmm(input, mat1, mat2)[:-m_padded_length, :]
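# Note: padding K requires no slicing of the result because the appended zero
# columns of mat1 multiply the appended zero rows of mat2 and contribute nothing
# to the (unchanged) M x N output; padding M or N changes the output shape, so
# the extra rows/columns are sliced back off above.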
def should_pad_bench(mat1, mat2, op, input=None):
assert utils.has_triton()
import triton.testing
do_bench = functools.partial(
triton.testing.do_bench,
warmup=5,
rep=100,
fast_flush=True,
)
with no_dispatch():
if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
elif op is torch.ops.aten.bmm:
m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
else:
return False
if m_padded_length == k_padded_length == n_padded_length == 0:
return False
mat1 = torch.randn_like(mat1)
mat2 = torch.randn_like(mat2)
if op is torch.ops.aten.bmm or op is torch.ops.aten.mm:
ori_time = do_bench(
lambda: op(mat1, mat2),
)
else:
if input is not None:
input = torch.randn_like(input)
ori_time = do_bench(
lambda: op(input, mat1, mat2),
)
mat1_pad = torch.randn_like(mat1)
mat2_pad = torch.randn_like(mat2)
if op is torch.ops.aten.addmm:
input_pad = None
if input is not None and input.is_cuda:
input_pad = torch.randn_like(input)
pad_time = do_bench(
lambda: pad_addmm(
input_pad,
mat1_pad,
mat2_pad,
m_padded_length,
k_padded_length,
n_padded_length,
),
)
elif op is torch.ops.aten.mm:
pad_time = do_bench(
lambda: pad_mm(
mat1_pad,
mat2_pad,
m_padded_length,
k_padded_length,
n_padded_length,
),
)
else:
pad_time = do_bench(
lambda: pad_bmm(
mat1_pad,
mat2_pad,
m_padded_length,
k_padded_length,
n_padded_length,
),
)
# Shape padding introduces additional memory ops. Based on microbenchmarks, 1.1x represents a reasonable
# tradeoff between performance improvement from shape padding and overhead from additional memory ops
# TODO: Build a learned model which would be better than this heuristic
return ori_time > pad_time * 1.1
@register_decomposition([aten.mm])
def mm_decomp(mat1, mat2):
if (
config.shape_padding
and check_device(mat1, mat2)
and not is_symbolic(mat1, mat2)
and should_pad_bench(mat1, mat2, torch.ops.aten.mm)
):
m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0:
return pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length)
return NotImplemented # go directly to lowering
def pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
    # mm_decomp will go through pad_mm multiple times if multiple dimensions
    # need to be padded
if k_padded_length != 0:
mat1 = pad_dim(mat1, k_padded_length, 1)
mat2 = pad_dim(mat2, k_padded_length, 0)
return torch.ops.aten.mm(mat1, mat2)
elif n_padded_length != 0:
mat2 = pad_dim(mat2, n_padded_length, 1)
return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length]
else:
mat1 = pad_dim(mat1, m_padded_length, 0)
return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :]
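# Illustrative shapes: with mat1 (30, 10), mat2 (10, 40) and k_padded_length=6,
# pad_mm computes a (30, 16) @ (16, 40) matmul; the zero padding along K leaves
# the (30, 40) result numerically unchanged.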
@register_decomposition([aten.bmm])
def bmm_decomp(mat1, mat2):
if (
config.shape_padding
and check_device(mat1, mat2)
and not is_symbolic(mat1, mat2)
and should_pad_bench(mat1, mat2, torch.ops.aten.bmm)
):
m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
if k_padded_length != 0 or n_padded_length != 0 or m_padded_length != 0:
            return pad_bmm(
                mat1, mat2, m_padded_length, k_padded_length, n_padded_length
            )
return NotImplemented # go directly to lowering
def pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
    # bmm_decomp will go through pad_bmm multiple times if multiple dimensions
    # need to be padded
if k_padded_length != 0:
mat1 = pad_dim(mat1, k_padded_length, 2)
mat2 = pad_dim(mat2, k_padded_length, 1)
return torch.ops.aten.bmm(mat1, mat2)
elif n_padded_length != 0:
mat2 = pad_dim(mat2, n_padded_length, 2)
return torch.ops.aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous()
else:
mat1 = pad_dim(mat1, m_padded_length, 1)
return torch.ops.aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous()
@register_decomposition([aten.convolution_backward])
def convolution_backward(
grad_output,
input,
weight,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_mask,
):
if not output_mask[2] or grad_output.device.type != "cuda":
return NotImplemented
grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
grad_inp, grad_weight, _ = aten.convolution_backward(
grad_output,
input,
weight,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
[output_mask[0], output_mask[1], False],
)
return (grad_inp, grad_weight, grad_bias)
@register_decomposition([aten.log2])
def log2(x):
return torch.log(x) * (1.0 / math.log(2.0))
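# log2(x) = ln(x) / ln(2); multiplying by the precomputed reciprocal
# 1 / ln(2) ≈ 1.4427 turns the per-element division into a cheaper multiply.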
@register_decomposition([aten.round.decimals])
def round_dec(x, decimals=0):
ten_pow_decimals = 10.0**decimals
return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)
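# Worked example: round_dec(x, 2) scales by 100, rounds, then scales back, so an
# element 3.14159 becomes round(314.159) / 100 = 3.14.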
@register_decomposition([aten.all.default])
def all(input):
return torch.logical_not(torch.any(torch.logical_not(input)))
@register_decomposition([aten.all.dim])
def all_dim(input, dim, keepdim=False):
return torch.logical_not(torch.any(torch.logical_not(input), dim, keepdim))
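# Both overloads use De Morgan's law, all(x) == not any(not x), presumably so
# that `all` can reuse the existing `any` reduction lowering.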
# NB: this decomposition is not stride-accurate; do not put it in the main
# library
@register_decomposition(aten.copy)
def copy(self, src, non_blocking=False):
intermediate = src.to(self, non_blocking)
if self.size() != intermediate.size():
return aten.expand_copy.default(intermediate, self.size())
else:
return intermediate
@register_decomposition([aten.baddbmm])
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
result = torch.bmm(batch1, batch2)
if not isinstance(alpha, numbers.Number) or alpha != 1:
result = result * alpha
if beta == 0:
return result
if not isinstance(beta, numbers.Number) or beta != 1:
self = self * beta
return self + result
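# Decomposes baddbmm(self, batch1, batch2, beta, alpha) into
# beta * self + alpha * (batch1 @ batch2), skipping the scalar multiplies when
# alpha/beta are statically 1 and skipping `self` entirely when beta == 0.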
@register_decomposition([aten.cat.default])
def cat(tensors, dim=0):
if len(tensors) == 1:
return tensors[0].clone()
return NotImplemented
@register_decomposition([aten.conj_physical])
def conj_physical(self):
assert not self.is_complex(), "TODO: implement this"
return self
@register_decomposition([aten.lift, aten.detach_])
def lift(self):
return self
@register_decomposition([aten.bernoulli.default])
def bernoulli(self, *, generator=None):
assert generator is None
return torch.rand_like(self, dtype=torch.float32) < self
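# Bernoulli sampling via uniform comparison: for U ~ Uniform[0, 1),
# P(U < p) == p, so comparing uniform noise against the per-element
# probabilities in `self` yields Bernoulli(p) samples.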
@register_decomposition([aten.fmin, prims.fmin])
def fmin(self, other):
return torch.where(torch.isnan(other) | (other > self), self, other)
@register_decomposition([aten.fmax, prims.fmax])
def fmax(self, other):
return torch.where(torch.isnan(other) | (other < self), self, other)
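# fmin/fmax follow the C fmin/fmax convention: when exactly one operand is NaN,
# the non-NaN operand is returned (the `isnan(other)` branch keeps `self`, and a
# NaN `self` loses every comparison so `other` is returned).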
@register_decomposition([aten.narrow_copy])
def narrow_copy(self, dim, start, length):
return torch.narrow(self, dim, start, length).clone()
@register_decomposition([aten.expand_copy])
def expand_copy(self, size, *, implicit=False):
return aten.expand(self, size, implicit=implicit).clone()
@register_decomposition([aten.view_copy.default])
def view_copy_default(self, size):
return aten.view(self, size).clone()
@register_decomposition([aten.view_copy.dtype])
def view_copy_dtype(self, dtype):
return self.to(dtype).clone()
@register_decomposition([aten.native_dropout])
def native_dropout(input: Tensor, p: float, train: bool):
if not train or p == 0:
return (input, torch.ones_like(input, dtype=torch.bool))
if p == 1:
return (torch.zeros_like(input), torch.zeros_like(input, dtype=torch.bool))
# intentionally don't decompose to improve pattern matching
return NotImplemented
# The difference between quantize_per_tensor.default and quantize_per_tensor.tensor is
# whether scale and zero_point are scalars or scalar tensors
@register_decomposition(quantized_decomposed.quantize_per_tensor.default)
def quantize_per_tensor_default_decomp_impl(
input: torch.Tensor,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
inv_scale = 1.0 / scale
return torch.clamp(
torch.round(input * inv_scale) + zero_point, quant_min, quant_max
).to(dtype)
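# Standard affine quantization: q = clamp(round(x / scale) + zero_point,
# quant_min, quant_max), cast to the target integer dtype. The precomputed
# inv_scale replaces a per-element division with a multiply.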
# The difference between dequantize_per_tensor.default and dequantize_per_tensor.tensor is
# whether scale and zero_point are scalars or scalar tensors
@register_decomposition(quantized_decomposed.dequantize_per_tensor.default)
def dequantize_per_tensor_default_decomp_impl(
input: torch.Tensor,
scale: float,
zero_point: int,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
return (input.to(torch.float32) - zero_point) * scale
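# The matching affine dequantization: x_hat = (q - zero_point) * scale, computed
# in float32; quant_min, quant_max, and dtype are part of the op signature but
# are not needed to reconstruct the value.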
@register_decomposition(quantized_decomposed.quantize_per_tensor.tensor)
def quantize_per_tensor_tensor_decomp_impl(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
inv_scale = 1.0 / scale
return torch.clamp(
torch.round(input * inv_scale) + zero_point, quant_min, quant_max
).to(dtype)
@register_decomposition(quantized_decomposed.dequantize_per_tensor.tensor)
def dequantize_per_tensor_tensor_decomp_impl(
input: torch.Tensor,
scale: torch.Tensor,
zero_point: torch.Tensor,
quant_min: int,
quant_max: int,
dtype: torch.dtype,
) -> torch.Tensor:
return (input.to(torch.float32) - zero_point) * scale
@functools.lru_cache(None)
def fast_random_decomps():
return {**decompositions, **extra_random_decomps}
def select_decomp_table():
"""decomps can change based on config"""
if config.fallback_random:
return decompositions
return fast_random_decomps()