import functools
import logging
import math
import numbers
import torch
import torch._decomp as decomp
from torch import Tensor
from torch._decomp import core_aten_decompositions, get_decompositions
from torch._decomp.decompositions import pw_cast_for_opmath
from torch.utils._mode_utils import no_dispatch
from . import config, utils
log = logging.getLogger(__name__)
aten = torch.ops.aten
inductor_decompositions = get_decompositions(
[
aten.arange,
aten.bitwise_and_,
aten.bitwise_or_,
aten.clamp_min_,
aten.flip,
aten.lcm,
aten.linalg_vector_norm,
aten.sin_,
aten.sqrt_,
aten.std,
aten.std_mean,
aten._to_copy,
aten.tril_indices,
aten.triu_indices,
aten.unsafe_split,
]
)
decompositions = {**core_aten_decompositions(), **inductor_decompositions}
def register_decomposition(ops):
for op in [ops] if callable(ops) else ops:
if op in decompositions:
log.warning(f"duplicate decomp: {ops}")
return decomp.register_decomposition(ops, decompositions)
@register_decomposition([aten.clamp])
@pw_cast_for_opmath
def clamp(x, min=None, max=None):
if min is not None:
x = x.clamp_min(min)
if max is not None:
x = x.clamp_max(max)
return x
# TorchInductor-only decomposition. It should not be upstreamed to core.
# See https://github.com/pytorch/torchdynamo/pull/1120
@register_decomposition([aten.floor_divide.default])
def floordiv(a, b):
return aten.div.Tensor_mode(a, b, rounding_mode="floor")
def get_alignment_size(x):
    if x.dtype in (torch.float16, torch.bfloat16):
        return 8
    elif x.dtype == torch.float32:
        return 4
else:
return 0
def check_device(a: Tensor, b: Tensor):
return a.is_cuda and b.is_cuda
def get_padded_length(x, alignment_size):
if alignment_size == 0 or x % alignment_size == 0:
return 0
return int((x // alignment_size + 1) * alignment_size) - x
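# Illustrative sketch (not part of the lowering logic): with alignment_size == 8,
# get_padded_length(30, 8) evaluates (30 // 8 + 1) * 8 - 30 == 2, i.e. a dim of
# size 30 would be zero-padded up to 32, while get_padded_length(32, 8) returns 0
# because the dim is already aligned.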
def pad_dim(x, padded_length, dim):
if padded_length == 0:
return x
pad = x.new_zeros(*x.shape[:dim], padded_length, *x.shape[dim + 1 :])
return torch.cat([x, pad], dim=dim)
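# Illustrative sketch (assumed shapes): pad_dim(torch.ones(3, 5), 3, 1) concatenates
# a (3, 3) block of zeros along dim 1 and returns a (3, 8) tensor whose last three
# columns are zero; pad_dim(x, 0, dim) returns x unchanged.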
@register_decomposition([aten.addmm])
def addmm(input, mat1, mat2, *, beta=1, alpha=1):
if (
config.shape_padding
and check_device(mat1, mat2)
and should_pad_bench(mat1, mat2, torch.ops.aten.addmm, input=input)
):
m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0:
return pad_addmm(
input, mat1, mat2, m_padded_length, k_padded_length, n_padded_length
)
return NotImplemented # go directly to lowering
def pad_addmm(input, mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
    # The addmm decomposition with padding will go through pad_addmm multiple times
    # if more than one dimension needs to be padded.
if k_padded_length != 0:
mat1 = pad_dim(mat1, k_padded_length, 1)
mat2 = pad_dim(mat2, k_padded_length, 0)
elif n_padded_length != 0:
mat2 = pad_dim(mat2, n_padded_length, 1)
elif m_padded_length != 0:
mat1 = pad_dim(mat1, m_padded_length, 0)
if input is not None and k_padded_length == 0:
if n_padded_length != 0:
if input.dim() == 2:
input = pad_dim(input, n_padded_length, 1)
elif input.dim() == 1:
input = pad_dim(input, n_padded_length, 0)
elif m_padded_length != 0 and input.dim() == 2:
input = pad_dim(input, m_padded_length, 0)
if k_padded_length != 0:
return torch.ops.aten.addmm(input, mat1, mat2)
elif n_padded_length != 0:
return torch.ops.aten.addmm(input, mat1, mat2)[:, :-n_padded_length]
else:
return torch.ops.aten.addmm(input, mat1, mat2)[:-m_padded_length, :]
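# Illustrative sketch (assumed shapes): for float16 inputs where mat1 and mat2 are
# both (30, 30), a first pass pads K from 30 to 32 and re-issues aten.addmm. The
# re-issued op hits the addmm decomposition above again, so if benchmarking still
# favors padding, later passes pad N and then M, slicing the extra rows and columns
# back off the result as in the branches above.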
def should_pad_bench(mat1, mat2, op, input=None):
assert utils.has_triton()
from triton.testing import do_bench
with no_dispatch():
if op is torch.ops.aten.mm or op is torch.ops.aten.addmm:
m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
elif op is torch.ops.aten.bmm:
m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
else:
return False
if m_padded_length == k_padded_length == n_padded_length == 0:
return False
mat1 = torch.randn_like(mat1)
mat2 = torch.randn_like(mat2)
warmup = 5
rep = 100
if op is torch.ops.aten.bmm or op is torch.ops.aten.mm:
ori_time = do_bench(
lambda: op(mat1, mat2), warmup=warmup, rep=rep, fast_flush=True
)[0]
else:
if input is not None:
input = torch.randn_like(input)
ori_time = do_bench(
lambda: op(input, mat1, mat2), warmup=warmup, rep=rep, fast_flush=True
)[0]
mat1_pad = torch.randn_like(mat1)
mat2_pad = torch.randn_like(mat2)
if op is torch.ops.aten.addmm:
input_pad = None
if input is not None and input.is_cuda:
input_pad = torch.randn_like(input)
pad_time = do_bench(
lambda: pad_addmm(
input_pad,
mat1_pad,
mat2_pad,
m_padded_length,
k_padded_length,
n_padded_length,
),
warmup=warmup,
rep=rep,
fast_flush=True,
)[0]
elif op is torch.ops.aten.mm:
pad_time = do_bench(
lambda: pad_mm(
mat1_pad,
mat2_pad,
m_padded_length,
k_padded_length,
n_padded_length,
),
warmup=warmup,
rep=rep,
fast_flush=True,
)[0]
else:
pad_time = do_bench(
lambda: pad_bmm(
mat1_pad,
mat2_pad,
m_padded_length,
k_padded_length,
n_padded_length,
),
warmup=warmup,
rep=rep,
fast_flush=True,
)[0]
        # Shape padding introduces additional memory ops. Based on microbenchmarks,
        # 1.1x represents a reasonable tradeoff between the performance improvement
        # from shape padding and the overhead of the additional memory ops.
# TODO: Build a learned model which would be better than this heuristic
return ori_time > pad_time * 1.1
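# Illustrative sketch of the 1.1x threshold above (assumed timings): if the unpadded
# op benchmarks at ori_time == 1.00ms and the padded variant at pad_time == 0.85ms,
# then 1.00 > 0.85 * 1.1 == 0.935 and padding is chosen; with pad_time == 0.95ms the
# comparison is 1.00 > 1.045, which is False, so the op is lowered unpadded.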
@register_decomposition([aten.mm])
def mm_decomp(mat1, mat2):
if (
config.shape_padding
and check_device(mat1, mat2)
and should_pad_bench(mat1, mat2, torch.ops.aten.mm)
):
m_padded_length = get_padded_length(mat1.shape[0], get_alignment_size(mat1))
k_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
n_padded_length = get_padded_length(mat2.shape[1], get_alignment_size(mat2))
if m_padded_length != 0 or k_padded_length != 0 or n_padded_length != 0:
return pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length)
return NotImplemented # go directly to lowering
def pad_mm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
    # mm_decomp will go through pad_mm multiple times if more than one dimension
    # needs to be padded.
if k_padded_length != 0:
mat1 = pad_dim(mat1, k_padded_length, 1)
mat2 = pad_dim(mat2, k_padded_length, 0)
return torch.ops.aten.mm(mat1, mat2)
elif n_padded_length != 0:
mat2 = pad_dim(mat2, n_padded_length, 1)
return torch.ops.aten.mm(mat1, mat2)[:, :-n_padded_length]
else:
mat1 = pad_dim(mat1, m_padded_length, 0)
return torch.ops.aten.mm(mat1, mat2)[:-m_padded_length, :]
@register_decomposition([aten.bmm])
def bmm_decomp(mat1, mat2):
if (
config.shape_padding
and check_device(mat1, mat2)
and should_pad_bench(mat1, mat2, torch.ops.aten.bmm)
):
m_padded_length = get_padded_length(mat1.shape[1], get_alignment_size(mat1))
k_padded_length = get_padded_length(mat1.shape[2], get_alignment_size(mat1))
n_padded_length = get_padded_length(mat2.shape[2], get_alignment_size(mat2))
if k_padded_length != 0 or n_padded_length != 0 or m_padded_length != 0:
            return pad_bmm(
                mat1, mat2, m_padded_length, k_padded_length, n_padded_length
            )
return NotImplemented # go directly to lowering
def pad_bmm(mat1, mat2, m_padded_length, k_padded_length, n_padded_length):
    # bmm_decomp will go through pad_bmm multiple times if more than one dimension
    # needs to be padded.
if k_padded_length != 0:
mat1 = pad_dim(mat1, k_padded_length, 2)
mat2 = pad_dim(mat2, k_padded_length, 1)
return torch.ops.aten.bmm(mat1, mat2)
elif n_padded_length != 0:
mat2 = pad_dim(mat2, n_padded_length, 2)
return torch.ops.aten.bmm(mat1, mat2)[:, :, :-n_padded_length].contiguous()
else:
mat1 = pad_dim(mat1, m_padded_length, 1)
return torch.ops.aten.bmm(mat1, mat2)[:, :-m_padded_length, :].contiguous()
@register_decomposition([aten.convolution_backward])
def convolution_backward(
grad_output,
input,
weight,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
output_mask,
):
if not output_mask[2] or grad_output.device.type != "cuda":
return NotImplemented
grad_bias = aten.sum(grad_output, [0] + list(range(2, grad_output.dim())))
grad_inp, grad_weight, _ = aten.convolution_backward(
grad_output,
input,
weight,
bias_sizes,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
[output_mask[0], output_mask[1], False],
)
return (grad_inp, grad_weight, grad_bias)
@register_decomposition([aten.log2])
def log2(x):
return torch.log(x) * (1.0 / math.log(2.0))
@register_decomposition([aten.round.decimals])
def round_dec(x, decimals=0):
ten_pow_decimals = 10.0**decimals
return aten.round(x * ten_pow_decimals) * (1.0 / ten_pow_decimals)
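# Illustrative sketch (assumed value): round_dec(torch.tensor(1.2345), 2) computes
# aten.round(123.45) * 0.01 == 1.23, up to floating-point rounding error.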
@register_decomposition([aten.all.default])
def all(input):
return torch.logical_not(torch.any(torch.logical_not(input)))
@register_decomposition([aten.all.dim])
def all_dim(input, dim, keepdim=False):
    return torch.logical_not(torch.any(torch.logical_not(input), dim, keepdim))
# NB: this decomposition is not stride accurate, do not put it in the main
# library
@register_decomposition(aten.copy)
def copy(self, src, non_blocking=False):
intermediate = src.to(self, non_blocking)
if self.size() != intermediate.size():
return aten.expand_copy.default(intermediate, self.size())
else:
return intermediate
@register_decomposition([aten.baddbmm])
def baddbmm(self, batch1, batch2, beta=1, alpha=1):
result = torch.bmm(batch1, batch2)
if not isinstance(alpha, numbers.Number) or alpha != 1:
result = result * alpha
if not isinstance(beta, numbers.Number) or beta != 1:
self = self * beta
return self + result
@register_decomposition([aten.conj_physical])
def conj_physical(self):
assert not self.is_complex(), "TODO: implement this"
return self
@register_decomposition([aten.lift, aten.detach_])
def lift(self):
return self
@register_decomposition([aten.bernoulli.default])
def bernoulli(self, *, generator=None):
assert generator is None
return torch.rand_like(self, dtype=torch.float32) < self
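# Illustrative sketch (assumed input): for self = torch.full((1000,), 0.8), the
# comparison above draws a uniform sample per element and returns a bool tensor in
# which roughly 80% of entries are True, i.e. an elementwise Bernoulli(0.8) draw.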
"""
Some decomps result in differences from eager related to randomness.
We put these decomps in a separate table `extra_random_decomps` to allow
turning them on and off via `config.fallback_random`.
"""
extra_random_decomps = get_decompositions(
[
aten.native_dropout,
aten.cauchy,
aten.cauchy_,
aten.exponential,
aten.exponential_,
aten.geometric,
aten.geometric_,
aten.normal,
aten.normal_,
aten.normal_functional,
aten.log_normal,
aten.log_normal_,
aten.uniform_,
]
)
register_extra_random_decomp = functools.partial(
decomp.register_decomposition, registry=extra_random_decomps
)
@register_extra_random_decomp([aten.bernoulli_])
def bernoulli_(self, p=0.5):
return self.copy_(torch.rand_like(self, dtype=torch.float32) < p)
@register_extra_random_decomp([aten.bernoulli.p])
def bernoulli_p(self, p=0.5, *, generator=None):
assert generator is None
return torch.rand_like(self, dtype=torch.float32) < p
@functools.lru_cache(None)
def fast_random_decomps():
return {**decompositions, **extra_random_decomps}
def select_decomp_table():
"""decomps can change based on config"""
if config.fallback_random:
return decompositions
return fast_random_decomps()
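# Illustrative usage sketch (assumed call site, not defined in this file): a tracing
# or lowering pass would typically fetch the active table via
#     decomp_table = select_decomp_table()
# so that when config.fallback_random is set, the entries from extra_random_decomps
# are left out and those random ops fall back to eager kernels, avoiding the
# randomness differences mentioned above.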