import collections
import contextlib
import dataclasses
import functools
import itertools
import logging
import re
import textwrap
import traceback
from contextlib import nullcontext
from enum import Enum
from functools import partial
from inspect import signature
from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union
from unittest.mock import patch
import sympy
from sympy import Expr, Integer, simplify
import torch._logging
import torch.fx
import torch.utils._pytree as pytree
from torch._dynamo.utils import identity
from torch._prims_common import (
compute_required_storage_length,
is_boolean_dtype,
is_float_dtype,
make_channels_last_strides_for,
make_contiguous_strides_for,
)
from torch.fx.experimental.symbolic_shapes import FloorDiv
from . import config, dependencies
from .codegen.common import index_prevent_reordering
from .cuda_properties import get_device_properties
from .dependencies import extract_read_writes, var_builder
from .utils import (
argsort,
cache_on_self,
convert_shape_to_inductor,
convert_shape_to_symint,
developer_warning,
pad_listlike,
sympy_dot,
sympy_product,
sympy_subs,
sympy_symbol,
)
from .virtualized import ops, V
log = logging.getLogger(__name__)
indent = functools.partial(textwrap.indent, prefix=" ")
aten = torch.ops.aten
""" [Note: Inductor IR]
Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each
lowering is registered to a particular aten operator, and expects inputs that
correspond to the aten schema. However, in place of torch Tensor inputs, lowerings
expect Inductor TensorBox inputs.
TensorBox IR represents torch tensors. Tensors are sometimes single objects owning
storage, and sometimes views of another Tensor's storage. Mutating tensor operations
(such as add_()) affect the underlying storage and any associated views. Other operations
(such as .t_()) update metadata about the current view but don't modify the underlying storage.
To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer.
TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor
output from an operation. But just as torch.Tensors take different forms, TensorBox IR can
reference View IR or directly reference StorageBox IRs.
Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops)
may take an existing TensorBox and point it to a new underlying View IR.
Tensors that directly own storage are represented as a chain of:
TensorBox -> StorageBox -> Buffer
where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout.
If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer
(leaving the old buffer unmodified and functionalizing the operation).
Tensors backed by views add one more indirection to the IR.
TensorBox -> View -> StorageBox -> Buffer
In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
"""
def validate_ir(node_or_nodes):
def _check_tensorbox(nodes):
# Could expand this to check deeper properties
# (e.g. TensorBox points to View or StorageBox)
if isinstance(nodes, (List, Tuple)):
for node in nodes:
_check_tensorbox(node)
else:
assert isinstance(
nodes,
(
DynamicScalar,
TensorBox,
RandSeedBuffer,
sympy.Symbol,
sympy.core.relational.Relational,
Expr,
),
), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
# Be picky about the accepted data structure (don't use pytree here)
_check_tensorbox(node_or_nodes)
def inverse_reorder(order):
inv_order = dict(zip(order, range(len(order))))
def reindex(index):
assert len(index) == len(inv_order)
return [index[inv_order[i]] for i in range(len(index))]
return reindex
def same_reorder(order):
def reindex(index):
assert len(index) == len(order)
return [index[order[i]] for i in range(len(index))]
return reindex
def fuse_reindexing(reindex1, reindex2):
def reindex(index):
return reindex1(reindex2(index))
return reindex
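# Illustrative example of the helpers above: with order = [2, 0, 1],
#   same_reorder(order)([a, b, c])    == [c, a, b]
#   inverse_reorder(order)([a, b, c]) == [b, c, a]
# so composing one with the other via fuse_reindexing recovers the original index.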
def stride_order2fill_order(order):
"""
Convert stride order to fill order
For the channels-last format,
stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0]
"""
lookup = {pos: idx for idx, pos in enumerate(order)}
fill_order = [lookup[i] for i in range(len(order))]
return fill_order
def get_stride_order(seq):
"""
Convert strides to stride order
"""
sorted_idx = argsort(seq)
out = [None for _ in range(len(seq))]
for i, elem in enumerate(sorted_idx):
out[elem] = i
return out
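# Illustrative example: channels-last-style strides such as [1000, 1, 100, 10] give
# get_stride_order(...) == [3, 0, 2, 1], and stride_order2fill_order([3, 0, 2, 1])
# recovers the fill order [1, 3, 2, 0] mentioned above.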
def ir_node_to_tensor(x, guard_shape=True):
if x is None:
return None
if not guard_shape:
shape_fn = V.graph.sizevars.size_hint
else:
shape_fn = identity
size = [shape_fn(s) for s in x.get_size()]
if is_storage_and_layout(x):
stride = [shape_fn(s) for s in x.get_layout().stride]
else:
stride = make_contiguous_strides_for(size)
dtype = x.get_dtype()
device = x.get_device()
size = convert_shape_to_symint(size)
stride = convert_shape_to_symint(stride)
t = torch.empty_strided(
size=size, stride=stride, dtype=dtype, device=device
).zero_()
return t
class ModularIndexing(sympy.Function):
"""
ModularIndexing(a, b, c) => (a // b) % c
"""
nargs = (3,)
is_integer = True
@classmethod
def eval(cls, base, divisor, modulus):
if base == 0 or modulus == 1:
return sympy.Integer(0)
if (
isinstance(base, sympy.Integer)
and isinstance(divisor, sympy.Integer)
and isinstance(modulus, sympy.Integer)
):
return (base // divisor) % modulus
if divisor != 1:
gcd = sympy.gcd(base, divisor)
if gcd != 1:
return ModularIndexing(
simplify(base / gcd), simplify(divisor / gcd), modulus
)
if isinstance(base, sympy.Add):
new_terms = []
all_positive = True
for term in base.args:
if sympy.gcd(term, modulus * divisor) != modulus * divisor:
if (isinstance(term, sympy.Integer) and term < 0) or (
isinstance(term, sympy.Mul)
and isinstance(term.args[0], sympy.Integer)
and term.args[0] < 0
):
# workaround for https://github.com/openai/triton/issues/619,
# if there are negative terms, // produces the wrong result
# TODO if https://github.com/openai/triton/issues/619 is fixed
# this optimization would become valid
all_positive = False
break
else:
new_terms.append(term)
if len(new_terms) != len(base.args) and all_positive:
return ModularIndexing(sum(new_terms), divisor, modulus)
if isinstance(base, FloorDiv):
return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)
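# Quick sanity checks (illustrative only):
#   ModularIndexing(10, 2, 3) evaluates to (10 // 2) % 3 == 2
#   ModularIndexing(4*i, 2, 8) simplifies via the gcd rule above to ModularIndexing(2*i, 1, 8)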
class CleanDiv(FloorDiv):
"""
Div where we can assume no rounding.
This is to enable future optimizations.
"""
pass
class CeilDiv(sympy.Function):
"""
Div used in indexing that rounds up.
"""
is_integer = True
def __new__(cls, base, divisor):
if sympy.gcd(base, divisor) == divisor:
return CleanDiv(base, divisor)
else:
return FloorDiv(base + (divisor - 1), divisor)
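# Illustrative example: CeilDiv(10, 4) becomes FloorDiv(10 + 3, 4) == 3, while
# CeilDiv(8, 4) becomes CleanDiv(8, 4) == 2 because the division is known to be exact.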
def get_device_type(x):
if getattr(x, "get_device", None):
return get_device_type(x.get_device())
if isinstance(x, torch.device):
return x.type
return None
def is_triton(x):
return get_device_type(x) == "cuda"
def is_cpu(x):
return get_device_type(x) == "cpu"
@dataclasses.dataclass
class IRNode:
_current_origins: ClassVar[Set[Any]] = set()
@staticmethod
@contextlib.contextmanager
def current_origins(origins: Set[torch.fx.Node]):
old = IRNode._current_origins
IRNode._current_origins = old | origins
try:
yield
finally:
IRNode._current_origins = old
def __post_init__(self):
self.origins = set(self._current_origins)
self.traceback = traceback.format_stack() if config.debug_ir_traceback else None
def get_traceback(self):
return self.traceback
def common_repr(self):
origins = f"origins={getattr(self, 'origins', '')}"
if len(origins) > 64:
# this can get *very* long
origins = f"{origins[:61]}..."
return [origins]
def str_helper(self, lines):
lines = lines + self.common_repr()
lines = indent(",\n".join(map(str, lines)))
return f"{type(self).__name__}(\n{lines}\n)"
def is_user_of(self, name):
return any(name == dep.name for dep in self.get_reads())
def get_numel(self):
return sympy_product(self.get_size())
@dataclasses.dataclass
class Loops(IRNode):
device: torch.device
dtype: torch.dtype
inner_fn: Callable
ranges: List[Expr]
def __str__(self, names=("ranges",)):
return self.str_helper(
[
f"'{self.device.type}'",
str(self.dtype),
self.inner_fn_str(),
]
+ [f"{name}={getattr(self, name)}" for name in names]
+ [f"origin_node={self.origin_node!r}"]
)
def __post_init__(self):
super().__post_init__()
self.origin_node = None
__repr__ = __str__
def get_dtype(self):
return self.dtype
def get_device(self):
return self.device
def get_origin_node(self):
return self.origin_node
def get_size(self):
return self.ranges
def is_extern(self):
return False
@classmethod
def create(cls, *args, **kwargs):
origin_node = kwargs.pop("origin_node", None)
tb = kwargs.pop("traceback", None)
r = cls(*args, **kwargs)
r.origin_node = origin_node
r.traceback = (
tb or traceback.format_stack() if config.debug_ir_traceback else None
)
return TensorBox.create(r)
@staticmethod
def _index(ranges, prefix="i"):
return [
sympy.Integer(0) if s == 1 else sympy_symbol(f"{prefix}{n}")
for n, s in enumerate(ranges)
]
@cache_on_self
def inner_fn_str(self):
index = self._index(self.ranges)
return V.KernelFormatterHandler.ir_to_string(self.inner_fn, index)
def is_zero_elements(self):
return any(r == 0 for r in self.ranges)
@cache_on_self
def get_reads(self):
with patch.object(FlexibleLayout, "allow_indexing", True):
if self.get_reduction_type():
return extract_read_writes(
self.make_loader(),
self.get_size(),
self.get_reduction_size(),
).reads
else:
return extract_read_writes(
self.make_loader(),
self.get_size(),
).reads
class Pointwise(Loops):
def make_loader(self):
return self.inner_fn
def get_reduction_size(self):
return []
def get_reduction_type(self):
return None
def store_output(self, output_name, indexer, vars):
return ops.store(output_name, indexer(vars), self.inner_fn(vars))
def constant_to_device(self, device):
"""Move this to a given device. Requires that all reads are to constants."""
loader = self.make_loader()
loader = patch.object(ConstantBuffer, "override_device", device)(loader)
return Pointwise(device, self.dtype, loader, self.ranges)
@dataclasses.dataclass
class Scatter(Pointwise):
output_indexer: Callable[[List[Expr]], Expr]
scatter_mode: Optional[str] = None
def constant_to_device(self, device):
"""Move this to a given device. Requires that all reads are to constants."""
loader = self.make_loader()
loader = patch.object(ConstantBuffer, "override_device", device)(loader)
return Scatter(
device,
self.dtype,
loader,
self.ranges,
self.output_indexer,
self.scatter_mode,
)
def store_output(self, output_name, indexer, vars):
return ops.store(
output_name,
indexer(self.output_indexer(vars)),
self.inner_fn(vars),
mode=self.scatter_mode,
)
class ReductionHint(Enum):
INNER = 0
OUTER = 1
OUTER_TINY = 2
DEFAULT = 3
class TileHint(Enum):
SQUARE = 0
DEFAULT = 1
@dataclasses.dataclass
class Reduction(Loops):
reduction_ranges: List[Expr]
reduction_type: str
# self.dtype represents the dst dtype
src_dtype: torch.dtype
reduction_hint: ReductionHint
def __str__(self):
return Loops.__str__(
self, names=("ranges", "reduction_ranges", "reduction_type")
)
__repr__ = __str__
def get_reduction_size(self):
return self.reduction_ranges
def get_reduction_type(self):
return self.reduction_type
def store_reduction(self, output_name, indexer, vars, reduction_vars):
return ops.reduction(
output_name,
self.dtype,
self.src_dtype,
self.reduction_type,
indexer(vars),
self.inner_fn(vars, reduction_vars),
)
def index_length(self):
return len(self.ranges) + len(self.reduction_ranges)
@cache_on_self
def inner_fn_str(self):
index = self._index(self.ranges)
rindex = self._index(self.reduction_ranges, "r")
return V.KernelFormatterHandler.ir_to_string(
self.inner_fn,
index,
rindex,
)
def constant_to_device(self, device):
"""Move this to a given device. Requires that all reads are to constants."""
loader = self.make_loader()
loader = patch.object(ConstantBuffer, "override_device", device)(loader)
return Reduction(
device,
self.dtype,
loader,
self.ranges,
self.reduction_ranges,
self.reduction_type,
self.src_dtype,
ReductionHint.DEFAULT,
)
@staticmethod
def num_splits(
device,
dst_dtype,
src_dtype,
inner_fn,
ranges,
reduction_ranges,
reduction_type,
reduction_numel,
):
num_sm = get_device_properties(device).multi_processor_count
min_elements_per_thread = 32
max_elements_per_thread = 512
threads_per_sm = 2048
min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm
max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm
def inner_reduction_splits(reduction_numel_hint, numel_hint):
# use heuristics that are close to eager mode for split inner reductions
# we leak reduction autotune configs here, and will need to refactor to avoid this later
num_warps = 8
num_threads = 32 * num_warps
if numel_hint >= 2 * num_sm: # don't split if there are enough outputs
return 1
if reduction_numel_hint <= 8192:
return 1
if reduction_numel_hint * numel_hint <= min_elements_per_device:
split_size = min_elements_per_thread
elif reduction_numel_hint * numel_hint < max_elements_per_device:
target_blocks = num_sm * threads_per_sm // (2 * num_threads)
blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint
tmp_split_size = (
reduction_numel_hint + num_threads * blocks_per_output - 1
) // (num_threads * blocks_per_output)
divisors = sympy.divisors(reduction_numel_hint)
closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
if abs(closest - tmp_split_size) < 30:
# prefer even splits, but never smaller than min_elements_per_thread
split_size = max(closest, min_elements_per_thread)
else:
split_size = tmp_split_size
else:
divisors = sympy.divisors(reduction_numel_hint)
closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
if abs(closest - max_elements_per_thread) < 50:
# prefer even splits
split_size = closest
else:
split_size = max_elements_per_thread
return (reduction_numel_hint + split_size * num_threads - 1) // (
split_size * num_threads
)
def outer_reduction_splits(reduction_numel_hint, numel_hint):
# TODO: the best heuristic currently uses XBLOCK (corresponding to numel_hint) = 128;
# extend it to even smaller numbers of outputs
num_warps = 8
num_threads = num_warps * 32
rvals_per_thread = 4 # comes from heuristics, refactor to not leak here
xvals_per_block = 128
xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block
if reduction_numel_hint * numel_hint < min_elements_per_device:
split_size = min_elements_per_thread
elif reduction_numel_hint * numel_hint < max_elements_per_device:
target_blocks = num_sm * threads_per_sm // (num_threads)
target_blocks = (target_blocks + xblocks - 1) // xblocks
tmp_split_size = (
reduction_numel_hint + rvals_per_thread * target_blocks - 1
) // (rvals_per_thread * target_blocks)
divisors = sympy.divisors(reduction_numel_hint)
closest = min(divisors, key=lambda x: abs(x - tmp_split_size))
if abs(tmp_split_size - closest) < 20:
split_size = max(closest, min_elements_per_thread)
else:
split_size = tmp_split_size
else:
divisors = sympy.divisors(reduction_numel_hint)
closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread))
if abs(closest - max_elements_per_thread) < 50:
# prefer even splits
split_size = closest
else:
split_size = max_elements_per_thread
return (reduction_numel_hint + rvals_per_thread * split_size - 1) // (
rvals_per_thread * split_size
)
reduction_numel_hint = V.graph.sizevars.size_hint(reduction_numel)
numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
# easy cases
if numel_hint == 1:
return ReductionHint.INNER, inner_reduction_splits(
reduction_numel_hint, numel_hint
)
if (
reduction_numel_hint <= min_elements_per_thread
or numel_hint >= num_sm * 2 * 32
):
return ReductionHint.DEFAULT, 1
r = Reduction(
device,
dst_dtype,
inner_fn,
ranges,
reduction_ranges,
reduction_type,
src_dtype,
ReductionHint.DEFAULT,
)
def get_read_indices(r):
cb = ComputedBuffer(
name=None,
layout=FlexibleLayout(
device=r.get_device(),
dtype=r.get_dtype(),
size=r.get_size(),
),
data=r,
)
read_writes = cb.get_read_writes()
# try finding the full size producer
# TODO this will fail for something like ((1, N) * (N, 1)).sum()
# this could also be wrong for producers with different contiguity, but we hope those cases are rare
range_vars = [
r
for r in read_writes.range_vars
if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number)
]
indices = []
changed = False
for md in sorted(read_writes.reads, key=lambda x: x.name):
if all(r in md.index.free_symbols for r in range_vars):
indices.append(md.index)
if md.name in V.graph.name_to_buffer:
buf = V.graph.name_to_buffer[md.name]
original_stride = buf.layout.stride
buf.decide_layout()
if buf.layout.stride != original_stride:
changed = True
return indices, changed
indices, changed = get_read_indices(r)
if changed:
indices, _ = get_read_indices(r)
if len(indices) == 0:
# TODO determine splits when all inputs are broadcast
return ReductionHint.DEFAULT, 1
(_, reduction_vars), ranges = dependencies.index_vars_squeeze(
r.get_size(), r.get_reduction_size()
)
num_outer = 0
num_inner = 0
for i in indices:
i = V.graph.sizevars.simplify_with_ranges(i, ranges)
strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys())
outer = all(s > 1 for s in strides)
if outer:
num_outer += 1
else:
num_inner += 1
if num_inner > num_outer:
return ReductionHint.INNER, inner_reduction_splits(
reduction_numel_hint, numel_hint
)
else:
return ReductionHint.OUTER, outer_reduction_splits(
reduction_numel_hint, numel_hint
)
@staticmethod
def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type):
"""Convert inner_fn from a reduction to an pointwise"""
reduction_ranges = [
V.graph.sizevars.guard_static_shape(x) for x in reduction_ranges
]
if reduction_type == "sum":
def combine_fn(a, b):
return ops.add(a, b)
elif reduction_type == "prod":
def combine_fn(a, b):
return ops.mul(a, b)
elif reduction_type == "xor_sum":
def combine_fn(a, b):
return ops.bitwise_xor(a, b)
elif reduction_type == "min":
def combine_fn(a, b):
return ops.minimum(a, b)
elif reduction_type == "max":
def combine_fn(a, b):
return ops.maximum(a, b)
elif reduction_type == "any":
def combine_fn(a, b):
return ops.logical_or(a, b)
elif reduction_type == "argmin":
def combine_fn(a, b):
a_value, a_index = a
b_value, b_index = b
mask = ops.lt(b_value, a_value)
a_isnan = ops.ne(a_value, a_value)
b_isnan = ops.ne(b_value, b_value)
mask = ops.logical_or(mask, ops.gt(b_isnan, a_isnan))
return (
ops.where(mask, b_value, a_value),
ops.where(mask, b_index, a_index),
)
elif reduction_type == "argmax":
def combine_fn(a, b):
a_value, a_index = a
b_value, b_index = b
mask = ops.gt(b_value, a_value)
a_isnan = ops.ne(a_value, a_value)
b_isnan = ops.ne(b_value, b_value)
mask = ops.logical_or(mask, ops.gt(b_isnan, a_isnan))
return (
ops.where(mask, b_value, a_value),
ops.where(mask, b_index, a_index),
)
else:
raise NotImplementedError(f"unknown reduction_type={reduction_type}")
def fn(index):
return functools.reduce(
combine_fn,
(
value_fn(index, rindex)
for rindex in itertools.product(
*[range(x) for x in reduction_ranges]
)
),
)
if reduction_type in ("argmin", "argmax"):
flatten_index = FixedLayout(
None,
None,
reduction_ranges,
FlexibleLayout.contiguous_strides(reduction_ranges),
).make_indexer()
def value_fn(index, rindex):
rindex = [sympy.expand(i) for i in rindex]
return (
inner_fn(index, rindex),
ops.index_expr(flatten_index(rindex), torch.int64),
)
return lambda index: fn(index)[1]
else:
value_fn = inner_fn
return fn
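# Illustrative sketch of the unrolling above: for reduction_type="sum" and
# reduction_ranges=[2], fn(index) expands to
#     ops.add(inner_fn(index, (0,)), inner_fn(index, (1,)))
# i.e. the reduction loop is fully unrolled into pointwise ops.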
@classmethod
def create(
cls,
device: torch.device,
dst_dtype: torch.dtype,
src_dtype: torch.dtype,
inner_fn: Callable,
ranges: List[Expr],
reduction_ranges: List[Expr],
reduction_type: str,
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
):
reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
if reduction_numel == 0:
# N.B. This is a hack to generate the literal of the given type
# Ideally, we should be fixing `def constant` in triton.py
# but it breaks due to hardcoded dtypes in other places
def py_cnst(val):
return (
bool(val)
if dst_dtype == torch.bool
else float(val)
if dst_dtype.is_floating_point
else int(val)
)
rtypes_to_inits = {
"sum": py_cnst(0),
"xor_sum": py_cnst(0),
"prod": py_cnst(1),
"any": py_cnst(0),
# "all" is desugared to `!any(!val)`
}
assert (
reduction_type in rtypes_to_inits.keys()
), f"{reduction_type} not supported for zero-dimension tensors!"
def const_fn(index):
return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)
return Pointwise.create(
device=device,
dtype=src_dtype,
inner_fn=const_fn,
ranges=list(ranges),
)
if reduction_numel == 1:
# this reduction is actually a pointwise op
if reduction_type in ("argmin", "argmax"):
def fn(index):
return ops.constant(0, dst_dtype)
else:
def fn(index):
reduction_index = [sympy.Integer(0) for _ in reduction_ranges]
return inner_fn(index, reduction_index)
return Pointwise.create(device, dst_dtype, fn, ranges)
if (
isinstance(reduction_numel, sympy.Integer)
and V.graph.sizevars.size_hint(reduction_numel)
< config.unroll_reductions_threshold
and sympy_product(ranges) != 1
):
return Pointwise.create(
device,
dst_dtype,
cls._unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type),
ranges,
)
split_reduction = (
is_triton(device)
and reduction_type
not in {
"argmax",
"argmin",
}
and config.split_reductions
)
if split_reduction:
# TODO(voz): dedup with sizevar util introduced in other PR
def _is_static(x):
return isinstance(x, (int, sympy.Integer))
split_reduction = (
all(_is_static(r) for r in ranges)
and all(_is_static(r) for r in reduction_ranges)
and _is_static(reduction_numel)
)
if split_reduction:
# triton doesn't support reduce to single element well, so break it up
hint, split = cls.num_splits(
device,
dst_dtype,
src_dtype,
inner_fn,
ranges,
reduction_ranges,
reduction_type,
reduction_numel,
)
# intermediate reduction in split can contain complex indexing,
# and num_splits will fail to correctly set the hint
# reuse the passed hint if available
if reduction_hint == ReductionHint.DEFAULT:
reduction_hint = hint
if split > 1:
# triton doesn't support reduce to single element well, so break it up
return cls.create_multilayer(
device,
dst_dtype,
src_dtype,
inner_fn,
ranges,
reduction_ranges,
reduction_type,
split,
reduction_hint,
)
return TensorBox.create(
Reduction(
device,
dst_dtype,
inner_fn,
ranges,
reduction_ranges,
reduction_type,
src_dtype,
reduction_hint,
)
)
@staticmethod
def default_value(reduction_type, dtype):
if reduction_type in {"max", "argmax"}:
if is_float_dtype(dtype):
return float("-inf")
elif is_boolean_dtype(dtype):
return 0
else:
return torch.iinfo(dtype).min
if reduction_type in {"min", "argmin"}:
if is_float_dtype(dtype):
return float("inf")
elif is_boolean_dtype(dtype):
return 1
else:
return torch.iinfo(dtype).max
return {
"sum": 0,
"prod": 1,
"xor_sum": 0,
"any": 0,
}[reduction_type]
@classmethod
def create_multilayer(
cls,
device: torch.device,
dst_dtype: torch.dtype,
src_dtype: torch.dtype,
inner_fn: Callable,
ranges: List[Expr],
reduction_ranges: List[Expr],
reduction_type: str,
split: int,
reduction_hint: ReductionHint,
):
"""
Break a large reduction up into multiple smaller reductions
recursively
"""
reduction_numel = sympy_product(reduction_ranges)
# TODO(jansel): convert this to dynamic shapes
# TODO(jansel): realize the reduction so we can do dynamic indexing
reduction_ranges = [
sympy.Integer(V.graph.sizevars.guard_static_shape(s))
for s in reduction_ranges
]
reduction_numel = sympy.Integer(
V.graph.sizevars.guard_static_shape(reduction_numel)
)
if V.graph.sizevars.size_hint(reduction_numel) % split == 0:
need_mask = False
else:
need_mask = True
split = sympy.Integer(split)
block_size = FloorDiv(reduction_numel + (split - 1), split)
reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel])
def wrapper_fn(index, reduction_index):
(reduction_index,) = reduction_index
*new_index, reduction_block = index
indices = block_size * reduction_block + reduction_index
def body():
return inner_fn(new_index, reindex([indices]))
if need_mask:
mask = ops.lt(
ops.index_expr(indices, torch.int32),
ops.index_expr(reduction_numel, torch.int32),
)
return ops.masked(
mask, body, cls.default_value(reduction_type, dst_dtype)
)
else:
return body()
# triton will automatically compute reductions in fp32 if reducing over fp16/bf16
# within the kernel. keep the intermediate in fp32 so as to keep the whole reduction
# in fp32 and not reduce precision by breaking up the kernel into multiple layers
intermediate_dtype = (
dst_dtype
if dst_dtype not in (torch.float16, torch.bfloat16)
else torch.float
)
intermediate = Reduction.create(
device,
intermediate_dtype,
src_dtype,
wrapper_fn,
[*ranges, split],
[block_size],
reduction_type,
reduction_hint,
)
intermediate.realize()
intermediate_loader = intermediate.make_loader()
def intermediate_fn(index, reduction_index):
return intermediate_loader([*index, *reduction_index])
numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER:
reduction_hint = ReductionHint.OUTER_TINY
if (
split <= 1024
and numel_hint <= 256
and reduction_hint == ReductionHint.OUTER
):
reduction_hint = ReductionHint.OUTER_TINY
return TensorBox.create(
Reduction(
device,
dst_dtype,
intermediate_fn,
ranges,
[split],
reduction_type,
src_dtype,
reduction_hint,
)
)
def is_storage_and_layout(x):
try:
as_storage_and_layout(x, freeze=False)
return True
except NotImplementedError:
return False
def is_contiguous_storage_and_layout(x):
try:
buffer, layout = as_storage_and_layout(x, freeze=False)
return layout.is_contiguous()
except NotImplementedError:
return False
def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=None):
"""Try to simplify x into a StorageBox and a Layout"""
if isinstance(x, TensorBox):
return as_storage_and_layout(
x.data,
freeze=freeze,
want_contiguous=want_contiguous,
stride_order=stride_order,
)
if isinstance(x, StorageBox) and isinstance(x.data, Buffer):
if freeze:
if want_contiguous:
x.data.freeze_layout()
assert x.data.layout.is_contiguous()
elif stride_order is not None:
x.data.freeze_layout_with_stride_order(stride_order)
else:
x.data.decide_layout()
return x, x.data.layout
if isinstance(x, ReinterpretView):
# making the base of x contiguous or stride_ordered will not necessarily make
# the ReinterpretView contiguous or stride_ordered either, so don't pass along those arguments
buffer, _ = as_storage_and_layout(
x.data,
freeze=freeze,
)
return buffer, x.layout
raise NotImplementedError
as_contiguous_storage_and_layout = functools.partial(
as_storage_and_layout, want_contiguous=True
)
def is_stride_order_storage_and_layout(x, stride_order):
try:
buffer, layout = as_storage_and_layout(x, freeze=False)
return layout.is_stride_ordered(stride_order)
except NotImplementedError:
return False
@dataclasses.dataclass
class BaseView(IRNode):
data: IRNode
def get_dtype(self):
return self.data.get_dtype()
def get_device(self):
return self.data.get_device()
def get_origin_node(self):
return None
def get_name(self):
return self.data.get_name()
def mark_reuse(self, users):
return self.data.mark_reuse(users)
def has_exceeded_max_reads(self):
return self.data.has_exceeded_max_reads()
def realize(self):
return self.data.realize()
def realize_hint(self):
return self.data.realize_hint()
def get_storage_numel(self):
return self.data.get_storage_numel()
def is_extern(self):
return self.data.is_extern()
@cache_on_self
def get_reads(self):
with patch.object(FlexibleLayout, "allow_indexing", True):
return extract_read_writes(
self.make_loader(),
self.get_size(),
).reads
def unwrap_view(self):
x = self
while isinstance(x, BaseView):
x = x.data
return x
def constant_to_device(self, device):
"""Move this to a given device. Requires that all reads are to constants."""
loader = self.make_loader()
loader = patch.object(ConstantBuffer, "override_device", device)(loader)
return Pointwise(device, self.get_dtype(), loader, self.get_size())
@dataclasses.dataclass
class ExpandView(BaseView):
size: List[Expr]
@staticmethod
def _normalize_size(x, new_size):
"""Replace `-1` with correct sizes"""
new_size = list(map(sympy.expand, new_size))
old_size = x.get_size()
old_size = [None] * (len(new_size) - len(old_size)) + list(old_size)
assert len(new_size) == len(old_size)
for i in range(len(new_size)):
if new_size[i] == -1:
assert old_size[i] is not None
new_size[i] = old_size[i]
return new_size
@classmethod
def create(cls, x, new_size):
new_size = cls._normalize_size(x, new_size)
if is_storage_and_layout(x):
storage, old_layout = as_storage_and_layout(x)
skip = len(new_size) - len(old_layout.size)
assert skip >= 0
new_stride = [sympy.Integer(0)] * skip
for stride, size in zip(old_layout.stride, old_layout.size):
new_stride.append(stride if size != 1 else sympy.Integer(0))
new_layout = FixedLayout(
old_layout.device,
old_layout.dtype,
list(new_size),
new_stride,
old_layout.offset,
)
return ReinterpretView(storage, new_layout)
return ExpandView(x, new_size)
def get_size(self):
return self.size
def make_loader(self):
target = self.get_size()
actual = self.data.get_size()
skip = len(target) - len(actual)
inner = self.data.make_loader()
def load(index):
index = list(index[skip:])
assert len(index) == len(actual)
for i in range(len(actual)):
if actual[i] == 1:
# zero out broadcast dimension
index[i] = sympy.Integer(0)
return inner(index)
return load
@dataclasses.dataclass
class PermuteView(BaseView):
dims: List[Expr]
@classmethod
def create(cls, x, dims):
dims = cls._map_neg_dims(dims)
assert set(dims) == set(range(len(dims)))
if is_storage_and_layout(x):
storage, old_layout = as_storage_and_layout(x)
new_layout = FixedLayout(
old_layout.device,
old_layout.dtype,
[old_layout.size[i] for i in dims],
[old_layout.stride[i] for i in dims],
old_layout.offset,
)
return ReinterpretView(storage, new_layout)
return PermuteView(x, dims)
@classmethod
def _map_neg_dims(cls, dims):
return [dim if dim >= 0 else len(dims) + dim for dim in dims]
def get_size(self):
assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims)))
size = self.data.get_size()
return [size[i] for i in self.dims]
def make_loader(self):
inner = self.data.make_loader()
inv = {j: i for i, j in enumerate(self.dims)}
inv = [inv[i] for i in range(len(self.dims))]
assert set(inv) == set(range(len(self.dims)))
def load(index):
index = [index[i] for i in inv]
return inner(index)
return load
class SqueezeView(BaseView):
@classmethod
def create(cls, x, *, dim=None):
if is_storage_and_layout(x):
storage, old_layout = as_storage_and_layout(x)
new_size = []
new_stride = []
if dim is not None:
assert isinstance(dim, int), "expected integer dim argument"
assert 0 <= dim and dim < len(old_layout.size)
for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)):
if dim is None:
if size != 1:
new_size.append(size)
new_stride.append(stride)
else:
if i != dim:
new_size.append(size)
new_stride.append(stride)
else:
assert size == 1, "expected squeezed size to be 1"
new_layout = FixedLayout(
old_layout.device,
old_layout.dtype,
new_size,
new_stride,
old_layout.offset,
)
return ReinterpretView(storage, new_layout)
if dim is None:
# redirect to a generic view
return View.create(x, [s for s in x.get_size() if s != 1])
else:
assert x.get_size()[dim] == 1
return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim])
@staticmethod
def squeezer(size: Tuple[sympy.Expr, ...]):
new_size = [s for s in size if s != 1]
not_one = [i for i, s in enumerate(size) if s != 1]
length = len(size)
def reindex(index: List[sympy.Expr]) -> List[sympy.Expr]:
assert len(index) == len(not_one), f"{index} {not_one}"
new_index = [sympy.Integer(0)] * length
for idx, s in zip(not_one, index):
new_index[idx] = s
return tuple(new_index)
return new_size, reindex
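# Illustrative example: squeezer((2, 1, 3)) returns new_size == [2, 3] and a reindex
# such that reindex([i0, i1]) == (i0, 0, i1), re-inserting the squeezed dimension.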
def __init__(self, data):
raise AssertionError("use SqueezeView.create()")
@dataclasses.dataclass
class View(BaseView):
size: List[Expr]
reindex: Callable
def make_indexer(self):
base_indexer = self.data.make_indexer()
def indexer(idx):
return base_indexer(self.reindex(idx))
return indexer
@staticmethod
def handle_negative_index(idx, size):
idx = sympy.expand(idx)
size = sympy.expand(size)
sizevars = V.graph.sizevars
if sizevars.size_hint(idx) < 0:
sizevars.guard_lt(idx, 0)
idx = idx + size
return idx
def reindex_str(self):
index_old = [sympy_symbol(f"i{n}") for n in range(len(self.size))]
index_new = list(self.reindex(index_old))
return f"lambda {', '.join(map(str, index_old))}: {index_new}"
def __str__(self):
return self.str_helper(
[self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
)
__repr__ = __str__
@classmethod
def create(cls, x, new_size):
assert isinstance(new_size, (tuple, list))
old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size)
# Skip pointless views
if V.graph.sizevars.statically_known_list_equals(old_size, new_size):
return x
if 0 in new_size and is_storage_and_layout(x):
storage, old_layout = as_storage_and_layout(x, freeze=False)
new_layout = FixedLayout(
old_layout.device,
old_layout.dtype,
new_size,
FlexibleLayout.contiguous_strides(new_size),
old_layout.offset,
)
return ReinterpretView(storage, new_layout)
# TODO: add a FixedTransferLayout class whose output layout is constrained by the input layout
elif is_contiguous_storage_and_layout(x) and not isinstance(
x.data, ExternKernelAlloc
):
storage, old_layout = as_contiguous_storage_and_layout(x)
new_layout = FixedLayout(
old_layout.device,
old_layout.dtype,
new_size,
FlexibleLayout.contiguous_strides(new_size),
old_layout.offset,
)
return ReinterpretView(storage, new_layout)
reindex = cls.dynamic_reshape_indexer(old_size, new_size)
return cls(x, tuple(new_size), reindex)
@staticmethod
def resolve_negative_size(old_size, new_size):
new_size = [V.graph.sizevars.simplify(x) for x in new_size]
old_size = [V.graph.sizevars.simplify(x) for x in old_size]
new_size = list(new_size)
for i in range(len(new_size)):
if new_size[i] == -1:
new_size[i] = sympy.Integer(1)
new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
break
V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size))
return old_size, new_size
@classmethod
def dynamic_reshape_indexer(cls, old_size, new_size):
try:
reindex = cls._dynamic_reshape_indexer(old_size, new_size)
except (AssertionError, IndexError):
# the optimistic algorithm failed, let's do a fallback
flat = [sympy_product(old_size)]
reindex1 = cls._dynamic_reshape_indexer(old_size, flat)
reindex2 = cls._dynamic_reshape_indexer(flat, new_size)
reindex = fuse_reindexing(reindex1, reindex2)
return reindex
@staticmethod
def _dynamic_reshape_indexer(old_size, new_size):
"""
Perform a reshape entirely by modifying indexing math
"""
size_hint = V.graph.sizevars.size_hint
vars = [sympy_symbol(f"view{i}") for i in range(len(new_size))]
stack_new = list(zip(vars, new_size))
stack_old = list(old_size)
view_expr = []
while stack_new and stack_old:
size_old = stack_old.pop()
var, size_new = stack_new.pop()
if size_old == 1:
view_expr.append(sympy.Integer(0))
stack_new.append((var, size_new)) # re-add
elif size_new == 1:
stack_old.append(size_old) # re-add
elif size_hint(size_new) == size_hint(size_old):
view_expr.append(var)
V.graph.sizevars.guard_equals(size_new, size_old)
elif size_hint(size_new) < size_hint(size_old):
while size_hint(size_new) < size_hint(size_old):
var2, size_new2 = stack_new.pop()
var = var2 * size_new + var
size_new = size_new * size_new2
view_expr.append(var)
V.graph.sizevars.guard_equals(size_new, size_old)
elif size_hint(size_new) > size_hint(size_old):
divisor = sympy.Integer(1)
modulus = size_old
view_expr.append(ModularIndexing(var, divisor, modulus))
divisor = divisor * modulus
while size_hint(size_new) > size_hint(size_old):
modulus = stack_old.pop()
view_expr.append(ModularIndexing(var, divisor, modulus))
divisor = divisor * modulus
size_old = size_old * modulus
V.graph.sizevars.guard_equals(size_new, size_old)
else:
raise AssertionError()
while stack_old:
size_old = stack_old.pop()
V.graph.sizevars.guard_equals(size_old, 1)
view_expr.append(sympy.Integer(0))
while stack_new:
var, size_new = stack_new.pop()
V.graph.sizevars.guard_equals(size_new, 1)
view_expr = list(reversed(view_expr))
assert len(view_expr) == len(old_size)
def reindex(index):
assert len(index) == len(vars), (len(index), len(vars))
replacements = dict(zip(vars, index))
return tuple(sympy_subs(x, replacements) for x in view_expr)
return reindex
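# Illustrative examples of the indexer above: reshaping old_size=[2, 3] to new_size=[6]
# yields reindex([i]) == (ModularIndexing(i, 3, 2), ModularIndexing(i, 1, 3)), i.e. the
# flat index i maps back to (i // 3 % 2, i % 3); the reverse reshape [6] -> [2, 3]
# yields reindex([i0, i1]) == (3*i0 + i1,).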
def get_size(self):
return self.size
def make_loader(self):
def load(index):
return inner(self.reindex(index))
inner = self.data.make_loader()
return load
@dataclasses.dataclass
class ReinterpretView(BaseView):
"""Pretend our storage has a different layout"""
layout: "Layout"
def __post_init__(self):
super().__post_init__()
if isinstance(self.data, BaseView):
self.data = self.data.unwrap_view()
def __str__(self):
return self.str_helper(
[
self.data,
self.layout,
]
)
__repr__ = __str__
def get_name(self):
return self.data.get_name()
def get_device(self):
return self.layout.device
def get_origin_node(self):
return None
def get_dtype(self):
return self.layout.dtype
def get_size(self):
return list(self.layout.size)
def get_stride(self):
return list(self.layout.stride)
def make_loader(self):
def loader(index):
indexer = self.layout.make_indexer()
return ops.load(self.get_name(), indexer(index))
return loader
def make_indexer(self):
return self.layout.make_indexer()
def get_layout(self):
return self.layout
def freeze_layout(self):
pass
def codegen_reference(self):
size = V.graph.wrapper_code.codegen_shape_tuple(self.layout.size)
stride = V.graph.wrapper_code.codegen_shape_tuple(self.layout.stride)
offset = V.graph.wrapper_code.codegen_sizevar(self.layout.offset)
namespace = V.graph.wrapper_code.namespace
if offset != "0":
return (
f"{namespace}as_strided({self.get_name()}, {size}, {stride}, {offset})"
)
return f"{namespace}as_strided({self.get_name()}, {size}, {stride})"
class SliceView(View):
@classmethod
def create(cls, x, dim, start, end, step=1):
step = sympy.expand(step)
assert step > 0
try:
if start == 0 and end >= 2**63 - 1 and step == 1:
return x
except TypeError:
pass
sizevars = V.graph.sizevars
new_size = list(x.get_size())
start = cls.handle_negative_index(start, new_size[dim])
end = cls.handle_negative_index(end, new_size[dim])
end = sizevars.guard_min(end, new_size[dim])
start = sizevars.guard_min(sizevars.guard_min(start, new_size[dim]), end)
if start == 0 and sizevars.size_hint(end - new_size[dim]) == 0 and step == 1:
sizevars.guard_equals(end, new_size[dim])
return x
new_size[dim] = FloorDiv(end - start + (step - 1), step)
if is_storage_and_layout(x):
# Fast path
storage, old_layout = as_storage_and_layout(x)
new_stride = list(old_layout.stride)
new_stride[dim] = new_stride[dim] * step
new_layout = FixedLayout(
old_layout.device,
old_layout.dtype,
new_size,
new_stride,
old_layout.offset + old_layout.stride[dim] * start,
)
return ReinterpretView(storage, new_layout)
def reindex(index):
assert len(index) == len(new_size), f"wrong ndim {index} {new_size}"
index = list(index)
index[dim] = index[dim] * step + start
return index
# redirect to a generic view
return SliceView(x, size=new_size, reindex=reindex)
class BaseConstant(IRNode):
def get_size(self):
return ()
def get_dtype(self):
return self.dtype
def get_device(self):
return self.device
def get_origin_node(self):
return None
def mark_reuse(self, users):
pass
def has_exceeded_max_reads(self):
return False
def get_reads(self):
return ()
def is_extern(self):
return False
@dataclasses.dataclass
class Constant(BaseConstant):
value: Any
dtype: torch.dtype
device: torch.device
def make_loader(self):
def loader(index):
return ops.constant(self.value, self.dtype)
return loader
def realize(self):
pass
@dataclasses.dataclass
class IndexingConstant(BaseConstant):
index: Any
dtype: torch.dtype
device: torch.device
def make_loader(self):
def loader(index):
return ops.index_expr(self.index, self.dtype)
return loader
@dataclasses.dataclass
class Layout(IRNode):
def __init__(
self,
device: torch.device,
dtype: torch.dtype,
size: List[Expr],
stride: List[Expr],
offset: Expr = Integer(0),
):
assert stride is None or len(size) == len(
stride
), f"size={size}, stride={stride}"
self.device = device
self.dtype = dtype
assert all(isinstance(s, (Expr, int)) for s in size)
self.size = size
self._stride = stride
self.offset = offset
@property
def stride(self):
return self._stride
def __str__(self):
offset = ""
if self.offset != 0:
offset = f", offset={self.offset}"
return (
f"{type(self).__name__}('{self.device.type}', {self.dtype}, "
f"size={self.size}, stride={self.stride}{offset})"
)
__repr__ = __str__
def is_contiguous(self):
for left, right, size in zip(
self.stride, FlexibleLayout.contiguous_strides(self.size), self.size
):
if size != 1 and left != right:
return False
return True
def is_channels_last_contiguous(self):
ndim = len(self.size)
if ndim not in [4, 5]:
return False
for left, right, size in zip(
self.stride, make_channels_last_strides_for(self.size), self.size
):
if size != 1 and left != right:
return False
return True
def is_transposed(self):
for left, right, size in zip(
self.stride,
reversed(FlexibleLayout.contiguous_strides(self.size)),
self.size,
):
if size != 1 and left != right:
return False
return True
def is_stride_ordered(self, order):
assert len(self.stride) == len(order)
# reorder the stride given order
stride_ordered = [None] * len(order)
for i in range(len(order)):
stride_ordered[order[i]] = V.graph.sizevars.size_hint(self.stride[i])
# check if it is in ascending order
for i in range(len(order) - 1):
if stride_ordered[i] > stride_ordered[i + 1]:
return False
return True
def is_channels_last_stride_ordered(self):
# create the channels_last stride order (for NCHW/NCDHW, C comes first in the order).
order = [0] + list(reversed(range(1, len(self.stride) - 1)))
order = [len(order)] + order
return self.is_stride_ordered(order)
def as_fixed(self):
return FixedLayout(
self.device,
self.dtype,
self.size,
self.stride,
self.offset,
)
def make_indexer(self):
assert (
FlexibleLayout.allow_indexing
), f"convert {type(self).__name__} to FixedLayout first"
return self.as_fixed().make_indexer()
def __eq__(self, other) -> bool:
return (
self.device == other.device
and self.dtype == other.dtype
and self.size == other.size
and self.stride == other.stride
and self.offset == other.offset
)
def storage_size(self) -> sympy.Expr:
return compute_required_storage_length(self.size, self.stride, self.offset)
class FixedLayout(Layout):
"""A Tensor layout we cannot change"""
def __init__(
self,
device: torch.device,
dtype: torch.dtype,
size: List[Expr],
stride: List[Expr] = None,
offset: Expr = Integer(0),
):
if stride is None:
stride = FlexibleLayout.contiguous_strides(size)
super().__init__(
device,
dtype,
size,
stride,
offset,
)
def make_indexer(self):
"""A closure containing math to read a given element"""
def indexer(index):
assert len(index) == len(self.stride) == len(self.size)
result = self.offset
for idx, stride, sz in zip(index, self.stride, self.size):
if sz != 1:
result = result + idx * stride
return result
return indexer
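# Illustrative example: with size=[2, 1, 3], stride=[3, 99, 1] and offset=0, the
# indexer above returns 3*i0 + i2; size-1 dimensions contribute nothing.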
class FlexibleLayout(Layout):
"""A Tensor layout we are allowed to change"""
allow_indexing = False
@staticmethod
def contiguous_strides(sizes):
if len(sizes) == 0:
return []
reversed_strides = [sympy.Integer(1)]
for size in reversed(sizes[1:]):
reversed_strides.append(size * reversed_strides[-1])
return list(reversed(reversed_strides))
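# e.g. (illustrative only) contiguous_strides([2, 3, 4]) == [12, 4, 1]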
@staticmethod
def fill_ordered(sizes, order):
"""
Create a stride based on the order the dimensions should be filled in.
In this format, channels last would be:
[1, 3, 2, 0]
"""
assert set(range(len(sizes))) == set(order)
next_stride = sympy.Integer(1)
strides = [None] * len(order)
for i in order:
strides[i] = next_stride
next_stride = next_stride * sizes[i]
return strides
@staticmethod
def stride_ordered(sizes, order):
"""
Create a stride based on the sorted order of a permuted range.
In this format, channels last would be:
[3, 0, 2, 1]
"""
assert set(range(len(sizes))) == set(order)
fill_order = stride_order2fill_order(order)
return FlexibleLayout.fill_ordered(sizes, fill_order)
@staticmethod
def same_ordered(sizes, stride):
"""
Create a stride that has the same stride order as the given stride.
For example, if the given stride is [1000, 1, 100, 10],
the fill order should be [1, 3, 2, 0]
"""
assert len(sizes) == len(stride)
stride = [V.graph.sizevars.size_hint(x) for x in stride]
fill_order = sorted(range(len(stride)), key=stride.__getitem__)
return FlexibleLayout.fill_ordered(sizes, fill_order)
def as_stride_order(self, order):
return FixedLayout(
self.device,
self.dtype,
self.size,
self.stride_ordered(self.size, order),
self.offset,
)
def as_fill_order(self, order):
return FixedLayout(
self.device,
self.dtype,
self.size,
self.fill_ordered(self.size, order),
self.offset,
)
def as_same_order(self, stride):
return FixedLayout(
self.device,
self.dtype,
self.size,
self.same_ordered(self.size, stride),
self.offset,
)
def __init__(self, device, dtype, size, stride_order=None):
if stride_order:
strides = FlexibleLayout.fill_ordered(size, stride_order)
else:
strides = FlexibleLayout.contiguous_strides(size)
super().__init__(device, dtype, size, strides)
class AliasedLayout(Layout):
"""Shares the same storage as another tensor"""
def __init__(self, view: "ReinterpretView"):
layout = view.get_layout()
super().__init__(
layout.device,
layout.dtype,
layout.size,
layout.stride,
)
self.view = view
def make_indexer(self):
return self.as_fixed().make_indexer()
def maybe_guard_aligned(self):
offset = self.view.get_layout().offset
if offset == 0:
return True
from .compile_fx import ALIGNMENT
return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
class MutationLayout(Layout):
def __init__(self, target: IRNode):
super().__init__(
target.get_device(),
target.get_dtype(),
target.get_size(),
None, # type: ignore[arg-type]
)
self.target = target
@Layout.stride.getter
def stride(self):
return self.real_layout().stride
def storage_size(self) -> sympy.Expr:
return self.real_layout().storage_size()
def real_layout(self):
def unwrap_views(target):
if isinstance(target, MutationLayout):
return unwrap_views(target.target)
if isinstance(target, BaseView):
return unwrap_views(target.unwrap_view())
if isinstance(target, MutableBox):
return unwrap_views(target.data)
return target
return unwrap_views(self.target).layout
@classmethod
def realize_into(cls, src, dst):
dst.realize()
V.graph.realize_users_of(dst.get_name())
if isinstance(src, TensorBox):
src = src.data
if not isinstance(src, StorageBox) or src.is_user_of(dst.get_name()):
need_copy = True
else:
src.realize()
need_copy = not isinstance(src.data.layout, FlexibleLayout)
if need_copy:
src = Pointwise.create(
device=src.get_device(),
dtype=src.get_dtype(),
inner_fn=src.make_loader(),
ranges=[
V.graph.sizevars.guard_equals(a, b)
for a, b in zip(src.get_size(), dst.get_size())
],
).data
src.realize()
assert isinstance(src.data.layout, FlexibleLayout)
src.data.layout = MutationLayout(dst)
return src.data
def as_fixed(self):
return self
def make_indexer(self):
return self.target.make_indexer()
@dataclasses.dataclass
class Buffer(IRNode):
name: str
layout: Layout
def __post_init__(self):
super().__post_init__()
self.origin_node = None
def make_indexer(self):
return self.layout.make_indexer()
def get_name(self):
assert self.name
return self.name
def get_device(self):
return self.layout.device
def get_origin_node(self):
return self.origin_node
def get_dtype(self):
return getattr(self.layout, "dtype", None)
def get_size(self):
return list(self.layout.size)
def get_stride(self):
return list(self.layout.stride)
def get_layout(self):
return self.layout
def get_storage_numel(self):
return self.get_numel()
def is_extern(self):
return False
def freeze_layout(self):
if not isinstance(self.layout, (MultiOutputLayout, AliasedLayout)):
self.layout = self.layout.as_fixed()
def freeze_layout_with_stride_order(self, order):
assert isinstance(self.layout, FlexibleLayout)
self.layout = self.layout.as_stride_order(order)
def freeze_layout_with_fill_order(self, order):
assert isinstance(self.layout, FlexibleLayout)
self.layout = self.layout.as_fill_order(order)
def freeze_layout_with_same_order(self, stride):
assert isinstance(self.layout, FlexibleLayout)
self.layout = self.layout.as_same_order(stride)
def make_loader(self):
def loader(index):
indexer = self.layout.make_indexer()
return ops.load(self.name, indexer(index))
return loader
def is_no_op(self):
return False
def codegen_reference(self):
return self.get_name()
def decide_layout(self):
pass
def get_alias_names(self):
if isinstance(self.layout, AliasedLayout):
return [self.layout.view.get_name()]
return ()
def get_mutation_names(self):
if isinstance(self.layout, MutationLayout):
return [self.layout.target.get_name()]
return ()
@cache_on_self
def get_read_writes(self):
with patch.object(FlexibleLayout, "allow_indexing", True):
return extract_read_writes(
self.make_loader(),
self.get_size(),
)
def get_reads(self):
return self.get_read_writes().reads
def realize(self):
pass
class InputBuffer(Buffer):
pass
class ConstantBuffer(InputBuffer):
override_device = None
def make_loader(self):
def loader(index):
indexer = self.layout.make_indexer()
return ops.load(
V.graph.constant_name(self.name, self.override_device), indexer(index)
)
return loader
def constant_to_device(self, device):
return ConstantBuffer(V.graph.constant_name(self.name, device), self.layout)
class RandSeedBuffer(ConstantBuffer):
def codegen_reference(self):
# Clone makes sure if we pass this from forwards to backwards
# the value does not get clobbered by the time backwards is run.
return self.get_name() + ".clone()"
class NoneAsConstantBuffer(IRNode):
def codegen_reference(self):
return V.graph.wrapper_code.none_str
class ShapeAsConstantBuffer(IRNode):
def __init__(self, shape):
super().__init__()
self.shape = shape
def codegen_reference(self):
from torch._inductor.codegen.wrapper import pexpr
expr = pexpr(V.graph.sizevars.simplify(self.shape))
if V.graph.cpp_wrapper:
# wrap scalar to 0-d tensor for cpp wrapper
return f"torch::tensor({expr})"
else:
return expr
@dataclasses.dataclass
class ComputedBuffer(Buffer):
data: Loops
@cache_on_self
def get_read_writes(self):
with patch.object(FlexibleLayout, "allow_indexing", True):
if self.data.get_reduction_type():
return extract_read_writes(
self.get_store_function(),
self.data.get_size(),
self.data.get_reduction_size(),
)
else:
return extract_read_writes(
self.get_store_function(),
self.data.get_size(),
)
def get_store_function(self):
indexer = self.layout.as_fixed().make_indexer()
if self.data.get_reduction_type():
return partial(self.data.store_reduction, self.name, indexer)
else:
return partial(self.data.store_output, self.name, indexer)
def get_fill_order(self):
"""
If our layout is still flexible, try to determine the stride order based on stride orders of reads.
TODO(jansel): A better algorithm here would look at downstream consumers of this
value and try to do global graph-level layout optimization.
This is also something just begging to be autotuned.
"""
if isinstance(self.layout, FlexibleLayout):
(index_vars, reduction_vars), _ = dependencies.index_vars_squeeze(
self.data.get_size(), self.data.get_reduction_size()
)
reads = self.get_read_writes().reads
reads_bufs = [
V.graph.name_to_buffer[r.name]
if r.name in V.graph.name_to_buffer.keys()
else None
for r in reads
]
# only consider reads to buffers of the same size
reads = [
sympy_subs(
r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0}
)
for r in reads
]
if reads:
stride_lengths = [
V.graph.sizevars.stride_hints(expr, index_vars) for expr in reads
]
from .scheduler import pick_loop_order
return pick_loop_order(stride_lengths, self.get_size())
return None
def decide_layout(self):
if isinstance(self.layout, FlexibleLayout):
order = self.get_fill_order()
if order:
self.freeze_layout_with_fill_order(order)
else:
self.freeze_layout()
def simplify_and_reorder(self):
"""
This is a main place where we do loop transformations in a
backend-agnostic way.
Here we:
1) Remove any 1 dimensions
2) Fuse contiguous dimensions together
3) Reorder dimensions based on stride orders
"""
args, var_ranges = dependencies.index_vars_squeeze(
self.data.get_size(), self.data.get_reduction_size(), prefix="q"
)
with patch.object(ConstantBuffer, "override_device", self.get_device()):
body = LoopBody(
self.get_store_function(),
(args if self.get_reduction_type() else args[:1]),
var_ranges,
)
index_formulas = [*body.indexing_exprs.values()]
reads_bufs = [
V.graph.name_to_buffer[reads_name]
if reads_name in V.graph.name_to_buffer.keys()
else None
for reads_name in body.reads_name2expr.keys()
]
memory_addrs = [
*body.reads_name2expr.values(),
*body.writes_name2expr.values(),
]
index_vars = []
reduce_vars = []
index_size = []
reduce_size = []
for v, s in var_ranges.items():
if v in args[0]:
assert not reduce_vars
index_vars.append(v)
index_size.append(s)
else:
assert v in args[1]
reduce_vars.append(v)
reduce_size.append(s)
# start with identity reorderings; reads produced by ComputedBuffers may override these below with their iter_reordering_reindex
reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs)
for i, reads_buf in enumerate(reads_bufs):
if isinstance(reads_buf, ComputedBuffer) and hasattr(
reads_buf, "iter_reordering_reindex"
):
reordering_reindex[i] = reads_buf.iter_reordering_reindex
def simplify_and_reorder(x_vars, support_vars, sizes, reordering_reindex=None):
sizes, reindex0, reindex1 = self._apply_loop_reordering(
x_vars, support_vars, sizes, memory_addrs, reordering_reindex
)
# for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1]
x_vars = reindex0(x_vars)
sizes, reindex2, prune = V.graph.sizevars._simplify_loops(
x_vars,
sizes,
index_prevent_reordering(index_formulas, x_vars, sizes),
)
x_vars = prune(x_vars)
# sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas)
# x_vars = prune(x_vars)
# sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs)
reindex = fuse_reindexing(reindex1, reindex2)
return sizes, reindex, reindex1
support_vars = index_vars + reduce_vars
iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder(
index_vars, support_vars, index_size, reordering_reindex
)
reduce_ranges, reduce_reindex, _ = simplify_and_reorder(
reduce_vars, support_vars, reduce_size
)
# remember the reordering only if there was no loop collapsing.
if len(iter_ranges) == len(index_vars):
self.iter_reordering_reindex = iter_reordering_reindex
# retrace the loop body with simplification and reordering applied
(iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
iter_ranges, reduce_ranges, prefix="z"
)
body = LoopBody(
body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges
)
return (iter_ranges, reduce_ranges), body
@staticmethod
def _apply_loop_reordering(
index_vars,
support_vars,
sizes,
memory_addrs,
reordering_reindex=None,
priority_idx=None,
):
"""
Shuffle the order of loops around to hopefully improve performance.
"""
from .scheduler import pick_loop_order
if priority_idx is None:
priority_idx = []
try:
strides = [
V.graph.sizevars.stride_hints(expr, index_vars, support_vars)
for expr in memory_addrs
]
assert len(strides) == len(memory_addrs) and len(strides[0]) == len(
index_vars
)
# consider both layout(strides) and reordering(reordering_reindex)
if reordering_reindex is not None:
for i in range(len(memory_addrs)):
try:
strides[i] = reordering_reindex[i](strides[i])
# if len(order) != len(strides), do not reorder
except AssertionError:
pass
order = list(reversed(pick_loop_order(strides, sizes, priority_idx)))
except Exception:
if config.debug:
log.warning(
"Did not simplify complex index:\n%s\n%s",
dict(zip(index_vars, sizes)),
memory_addrs,
)
order = list(range(len(sizes)))
sizes = [sizes[i] for i in order]
return sizes, same_reorder(order), inverse_reorder(order)
def get_reduction_size(self):
return self.data.get_reduction_size()
def get_reduction_type(self):
return self.data.get_reduction_type()
def is_no_op(self):
return self.data.is_zero_elements()
def should_allocate(self):
return True
def constant_to_device(self, device):
"""Move this to a given device. Requires that all reads are to constants."""
return self.data.constant_to_device(device)
class TemplateBuffer(Buffer):
"""
Represents a Triton template operator (other kinds may be added in the future)
that we can fuse an epilogue onto.
"""
def __init__(self, layout, inputs, make_kernel_render):
super().__init__(name=None, layout=layout)
self.inputs = InputsKernel.unwrap_storage(inputs)
self.make_kernel_render = make_kernel_render
self.name = V.graph.register_buffer(self)
def get_read_writes(self):
return self.normalized_read_writes()
@cache_on_self
def normalized_read_writes(self):
name = self.get_name()
indexer = self.layout.make_indexer()
def dummy(index, rindex):
assert len(rindex) == 0
return ops.store(name, indexer(index), "fake")
deps = dependencies.extract_read_writes(
dummy, self.get_size(), (), normalize=True
)
deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs}
return deps
def get_reduction_size(self):
return 1
def get_reduction_type(self):
return None
def is_no_op(self):
return False
def should_allocate(self):
return True
def simplify_and_reorder(self):
return (
(
self.get_size(),
(),
),
None,
)
@dataclasses.dataclass
class InputsKernel(Buffer):
inputs: List[Buffer]
def get_read_writes(self):
return dependencies.ReadWrites(
{dependencies.StarDep(x.get_name()) for x in self.inputs},
{dependencies.StarDep(self.get_name())},
set(),
[],
None,
op_counts=collections.Counter(),
)
@staticmethod
def unwrap_storage(inputs):
inputs_new = []
for x in inputs:
if isinstance(x, TensorBox):
x = x.data
if isinstance(x, StorageBox):
x = x.data
if isinstance(x, BaseView) and not isinstance(x, ReinterpretView):
x = ExternKernel.realize_input(x)
assert isinstance(x, (Buffer, ReinterpretView)), x
inputs_new.append(x)
return inputs_new
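# Illustrative sketch: a TensorBox(StorageBox(buf)) input unwraps to the
# underlying `buf`, while a non-ReinterpretView view is first realized via
# ExternKernel.realize_input, so only Buffers and ReinterpretViews remain.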
def is_extern(self):
return True
class NopKernel(InputsKernel):
def is_no_op(self):
return True
class ConcatKernel(NopKernel):
"""
There isn't actually a real kernel for concat; we just change the
storage for the upstream data.
"""
@classmethod
def create(cls, inputs, dim):
device = inputs[0].get_device()
dtype = inputs[0].get_dtype()
new_size = list(inputs[0].get_size())
offsets_start = [0]
offsets_end = [new_size[dim]]
assert 0 <= dim < len(new_size)
for i in range(1, len(inputs)):
input_size = inputs[i].get_size()
offsets_start.append(new_size[dim])
assert len(input_size) == len(new_size)
assert inputs[i].get_dtype() == dtype
assert inputs[i].get_device() == device
for j in range(len(new_size)):
if j == dim:
new_size[j] = new_size[j] + input_size[j]
else:
new_size[j] = V.graph.sizevars.guard_equals(
new_size[j], input_size[j]
)
offsets_end.append(new_size[dim])
output_stride = FlexibleLayout.contiguous_strides(new_size)
# If any of the inputs is in CL format, use CL format for the output
for i in range(len(inputs)):
x = inputs[i]
if is_storage_and_layout(x):
layout = x.get_layout()
if (
isinstance(layout, FixedLayout)
and layout.is_channels_last_contiguous()
):
# use CL stride for the output
output_stride = make_channels_last_strides_for(new_size)
break
kernel = ConcatKernel(
name=None,
layout=FixedLayout(
device=device,
dtype=dtype,
size=new_size,
stride=output_stride,
),
inputs=[],
)
kernel = StorageBox(kernel)
for i in range(len(inputs)):
kernel.data.inputs.append(
cls.realize_into(
inputs[i],
SliceView.create(kernel, dim, offsets_start[i], offsets_end[i]),
)
)
kernel.data.name = V.graph.register_buffer(kernel.data)
kernel.data.inputs = cls.unwrap_storage(kernel.data.inputs)
return kernel
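# Illustrative usage sketch (hypothetical inputs): for realized inputs `a` and
# `b` with sizes [2, 3] and [4, 3],
#     out = ConcatKernel.create([a, b], dim=0)
# returns a StorageBox over a ConcatKernel of size [6, 3]; each input is
# realized directly into the matching SliceView of the shared output storage.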
@classmethod
def realize_into(cls, src, dst):
# Attempt to turn this into a ReinterpretView rather than assert.
# This has concessions around layout, since as_storage_and_layout
# can cause us to go from a flexible to a fixed layout.
if not isinstance(dst, ReinterpretView):
if is_storage_and_layout(dst):
storage, layout = as_storage_and_layout(dst)
dst = ReinterpretView(storage, layout)
assert isinstance(dst, ReinterpretView), dst
if isinstance(src, TensorBox):
# unwrap a TensorBox
return cls.realize_into(src.data, dst)
if isinstance(src, StorageBox):
src.realize()
# ExternKernelAlloc has specific requirements for output layout, so create a copy instead of aliasing
if isinstance(src.data.layout, FlexibleLayout) and not isinstance(
src.data, ExternKernelAlloc
):
src.data.layout = AliasedLayout(dst)
return src.data
# introduce a copy
pw = Pointwise.create(
device=src.get_device(),
dtype=src.get_dtype(),
inner_fn=src.make_loader(),
ranges=[
V.graph.sizevars.guard_equals(a, b)
for a, b in zip(src.get_size(), dst.get_size())
],
)
return cls.realize_into(pw, dst)
def should_allocate(self):
return True
@dataclasses.dataclass
class ExternKernel(InputsKernel):
constant_args: Tuple[Any, ...] = ()
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
output_view: Optional[ReinterpretView] = None
def decide_layout(self):
if isinstance(self.layout, FlexibleLayout):
self.apply_constraint()
self.freeze_layout()
def codegen(self, wrapper):
raise NotImplementedError()
@staticmethod
def copy_input(x):
pw = Pointwise.create(
device=x.get_device(),
dtype=x.get_dtype(),
inner_fn=x.make_loader(),
ranges=x.get_size(),
origin_node=x.get_origin_node(),
traceback=x.get_traceback(),
)
pw.realize()
return pw
@classmethod
def process_kernel(cls, kernel, *args, **kwargs):
binded_args = signature(kernel).bind(*args, **kwargs).arguments
args_flat, args_spec = pytree.tree_flatten(binded_args)
is_arg_tensor = []
tensor_args = []
non_tensor_args = []
for arg in args_flat:
is_arg_tensor.append(isinstance(arg, IRNode))
if is_arg_tensor[-1]:
tensor_args.append(arg)
else:
if isinstance(arg, sympy.Expr):
arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
non_tensor_args.append(arg)
def unflatten_args(new_tensor_args, new_non_tensor_args):
result = []
it_tensors = iter(new_tensor_args)
it_non_tensors = iter(new_non_tensor_args)
for is_tensor in is_arg_tensor:
if is_tensor:
result.append(next(it_tensors))
else:
result.append(next(it_non_tensors))
result = pytree.tree_unflatten(result, args_spec)
return result.get("args", []), result.get("kwargs", {})
tensor_args = [cls.realize_input(x) for x in tensor_args]
# freeze layout otherwise our output stride calculation might
# become incorrect
for x in tensor_args:
if is_storage_and_layout(x):
as_storage_and_layout(x, freeze=True)
# We don't have generic shape formulas, so just burn in the
# shapes and run an example input.
# TODO(jansel): replace this with dynamic shape formulas
example_args = []
# We need to retain the constant values of fake tensors that we originally
# propagated the graph with, because for some operators running without a
# constant would trigger an error / DataDependentException
for x in tensor_args:
if x.get_name() in V.graph.constants:
example_args.append(V.graph.constants[x.get_name()])
else:
example_args.append(ir_node_to_tensor(x, guard_shape=True))
new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
example_output = kernel(*new_args, **new_kwargs)
return example_output, tensor_args, non_tensor_args, unflatten_args
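# Illustrative sketch: the bound arguments are flattened with pytree so IRNode
# inputs land in tensor_args and everything else (with sympy.Exprs converted to
# SymInts) lands in non_tensor_args; unflatten_args rebuilds (args, kwargs) in
# the original structure, which lets us run the kernel on example inputs to
# discover the output's shape, dtype and device.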
@classmethod
def convert_to_reinterpret_view(cls, x):
"""
In order to pass this to an extern kernel we need a
ReinterpretView, not a View. This allows us to avoid some
unneeded copies.
"""
assert isinstance(x, BaseView)
if isinstance(x, ReinterpretView):
return x
x.unwrap_view().freeze_layout()
rw = extract_read_writes(x.make_loader(), x.get_size(), normalize=False)
assert len(rw.reads) == 1
index = V.graph.sizevars.simplify_with_ranges(
list(rw.reads)[0].index, rw.var_ranges
)
strides = V.graph.sizevars.stride_vars(index, rw.range_vars)
offset = V.graph.sizevars.offset_var(index, rw.range_vars)
expected = sympy_dot(rw.range_vars, strides) + offset
if index != expected:
log.debug(
"convert_to_reinterpret_view failed: stride=%s offset=%s index=%s",
strides,
offset,
index,
)
raise NotImplementedError()
return ReinterpretView(
data=x.data,
layout=FixedLayout(
device=x.get_device(),
dtype=x.get_dtype(),
size=x.get_size(),
stride=strides,
offset=offset,
),
)
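# Illustrative sketch: if the single read index simplifies to d1 + 3*d0 over
# range vars (d0, d1), stride_vars recovers strides [3, 1] and offset 0, so the
# view can be expressed as a ReinterpretView with that FixedLayout; any index
# that is not an affine strided access raises NotImplementedError above.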
@classmethod
def realize_input(cls, x):
if x is None:
return NoneAsConstantBuffer()
if isinstance(x, (sympy.Expr, sympy.Rel, int)):
return ShapeAsConstantBuffer(x)
if isinstance(x, Constant):
return V.graph.add_tensor_constant(
torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
)
if isinstance(x, ConstantBuffer):
return x
if isinstance(x, TensorBox):
return cls.realize_input(x.data)
if isinstance(x, ReinterpretView):
return x
if isinstance(x, BaseView):
x.realize()
if is_storage_and_layout(x.unwrap_view()) and not isinstance(
x.unwrap_view().data, ExternKernelAlloc
):
try:
return cls.convert_to_reinterpret_view(x)
except NotImplementedError:
pass
if isinstance(x, StorageBox):
# TODO(jansel): impose layout preference on realized buffer
x.realize()
return x
return cls.copy_input(x)
@classmethod
def require_stride1(cls, x):
if is_storage_and_layout(x):
if len(x.get_stride()) == 0:
return x
for stride in x.get_stride():
if stride == 1:
return x
return cls.copy_input(x)
@classmethod
def require_stride_order(cls, x, order):
if x.get_numel() == 0: # Layout doesn't matter
return x
# require x to have the layout as strided_ordered as order
if is_storage_and_layout(x):
if isinstance(x.get_layout(), FlexibleLayout):
# fix flexiblelayout to be FixedLayout with stride_order
as_storage_and_layout(
x, freeze=True, want_contiguous=False, stride_order=order
)
return x
elif isinstance(
x.get_layout(), FixedLayout
) and x.get_layout().is_stride_ordered(order):
return x
elif isinstance(x.get_layout(), MutationLayout):
if isinstance(x.get_layout().real_layout(), FlexibleLayout):
raise AssertionError(
"the MutationLayout's real layout shouldn't be FlexibleLayout"
)
elif isinstance(
x.get_layout().real_layout(), FixedLayout
) and x.get_layout().real_layout().is_stride_ordered(order):
return x
# TODO - Storage to InputBuffer
if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order):
return x
x = cls.copy_input(x)
as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order)
assert is_stride_order_storage_and_layout(x, order)
return x
@classmethod
def require_contiguous(cls, x):
return cls.require_stride_order(x, list(reversed(range(len(x.get_size())))))
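# Illustrative sketch: for a 4-D tensor, require_contiguous requests stride
# order [3, 2, 1, 0], i.e. dim 0 has the largest stride and the last dim has
# stride 1, which matches a C-contiguous layout.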
def apply_constraint(self):
pass
def codegen_const_args(self):
return map(V.graph.wrapper_code.val_to_str, self.constant_args)
def codegen_args(self):
args = [x.codegen_reference() for x in self.inputs]
args.extend(self.codegen_const_args())
return args
def codegen_kwargs(self):
kwargs = []
if self.kwargs:
if V.graph.cpp_wrapper:
# TODO: use native_functions.yaml as the ground truth
assert (
self.ordered_kwargs_for_cpp_kernel
), "ordered_kwargs_for_cpp_kernel has to be provided"
for arg_name in self.ordered_kwargs_for_cpp_kernel:
assert arg_name in self.kwargs, (
"arg %s not found in self.kwargs" % arg_name
)
v = self.kwargs.get(arg_name)
kwargs.append(V.graph.wrapper_code.val_to_str(v))
else:
kwargs = [
f"{k}={V.graph.wrapper_code.val_to_str(v)}"
for k, v in self.kwargs.items()
]
return kwargs
def codegen_size_asserts(self, wrapper):
if config.size_asserts and not V.graph.cpp_wrapper:
size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
wrapper.writeline(
f"assert_size_stride({self.get_name()}, {size}, {stride})"
)
def get_group_stride(self):
"""
get output sizes and strides, for template_codegen
"""
_size = self.get_size()
_stride = self.get_stride()
# iter_ranges = _size of the output tensor, reduce_ranges = [] because there is no reduction
return [_size, []], _stride
def canonicalize(self):
"""
Manually get canonicalization of the output index
"""
# manually generate index formula for conv
sizevars = V.graph.sizevars
sizes = self.get_size()
strides = self.get_stride()
strides = [sizevars.size_hint(x) for x in strides]
index_vars = [sympy_symbol(f"d{i}") for i in range(len(sizes))]
# reorder index vars according to stride
index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
lookup = {pos: idx for idx, pos in enumerate(index_order)}
order = [lookup[i] for i in range(len(lookup))]
index_vars = [index_vars[i] for i in order]
indexer = self.make_indexer()
index = indexer(index_vars)
new_sizes, reindex, prune = V.graph.sizevars._simplify_loops(
index_vars, sizes, [index]
)
# assign new variables each dimension to deal with numbering mismatches
# d0, d1, d2 could become d0, d2 -- which won't match d0, d1
_, add_var = var_builder("c")
replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
index = sympy_subs(sympy.expand(index), replacement)
return index, tuple(new_sizes)
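# Illustrative sketch: canonicalize assigns d0 to the largest-stride dimension,
# simplifies the resulting loops, and renames the surviving variables to
# c0, c1, ..., yielding a canonical (index expression, sizes) description of
# the output's memory access pattern.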
def __str__(self):
lines = [
f"{field.name}={getattr(self, field.name)}"
for field in dataclasses.fields(self)
]
lines.append(f"origin_node={self.origin_node!r}")
return self.str_helper(lines)
__repr__ = __str__
@dataclasses.dataclass
class ExternKernelOut(ExternKernel):
output_view: Optional[ReinterpretView] = None
def codegen(self, wrapper):
args = [*self.codegen_args(), *self.codegen_kwargs()]
wrapper.generate_extern_kernel_out(
self.output_view,
self.codegen_reference(),
args,
self.kernel,
)
def __init__(
self,
layout,
inputs,
constant_args=(),
kwargs=None,
output_view=None,
kernel=None,
cpp_kernel=None,
ordered_kwargs_for_cpp_kernel=(),
):
super().__init__(
None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
)
self.output_view = output_view
self.name = V.graph.register_buffer(self)
self.kernel = cpp_kernel if V.graph.cpp_wrapper else kernel
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
def should_allocate(self):
return True
class ExternKernelAlloc(ExternKernel):
def codegen(self, wrapper):
args = [*self.codegen_args(), *self.codegen_kwargs()]
V.graph.wrapper_code.generate_extern_kernel_alloc(
self.get_name(), self.kernel, args, self.get_origin_node()
)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
def __init__(
self,
layout,
inputs,
constant_args=(),
kwargs=None,
kernel=None,
cpp_kernel=None,
ordered_kwargs_for_cpp_kernel=(),
):
super().__init__(
None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {}
)
self.name = V.graph.register_buffer(self)
self.kernel = cpp_kernel if V.graph.cpp_wrapper else kernel
self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
def should_allocate(self):
return False
def apply_constraint(self):
raise NotImplementedError
class InplaceBernoulliFallback(ExternKernel):
"""
This needs to be a custom class to handle mutation properly
"""
kernel = "aten.bernoulli_"
def codegen(self, wrapper):
(x,) = [t.codegen_reference() for t in self.inputs]
wrapper.writeline(
f"{self.kernel}({x}, {', '.join(map(repr, self.constant_args))})"
)
def should_allocate(self):
return False
def get_mutation_names(self):
assert isinstance(self.layout, MutationLayout)
return (self.layout.target.get_name(),)
def __init__(self, x, *constant_args):
super().__init__(
None,
MutationLayout(x),
self.unwrap_storage([x]),
constant_args,
)
self.name = V.graph.register_buffer(self)
class ScatterFallback(ExternKernel):
"""
This needs to be a custom class to handle mutation properly.
This class handles both aten.scatter_ and aten.scatter_reduce_.
It also handles the case where `src` is a scalar properly.
"""
def codegen(self, wrapper):
if self.src_is_tensor:
(x, index, src) = [t.codegen_reference() for t in self.inputs]
else:
(x, index) = [t.codegen_reference() for t in self.inputs]
src = self.constant_args[1]
line = f"{self.kernel}({x}, {self.constant_args[0]}, {index}, {src}"
if self.kernel == "aten.scatter_":
if self.kwargs["reduce"]:
line += f", reduce={repr(self.kwargs['reduce'])}"
else:
line += ", ".join([""] + self.codegen_kwargs())
line += ")"
wrapper.writeline(line)
def should_allocate(self):
return False
def __init__(
self,
fn,
x,
dim: int,
index,
src,
*,
reduce: Optional[str] = None,
include_self: bool = True,
):
assert fn in {"aten.scatter_", "aten.scatter_reduce_"}
self.kernel = fn
self.src_is_tensor = isinstance(src, TensorBox)
if self.src_is_tensor:
tensors = [self.realize_input(t) for t in [x, index, src]]
constant_args = [dim]
else:
tensors = [self.realize_input(t) for t in [x, index]]
constant_args = [dim, src]
super().__init__(
None,
MutationLayout(x),
self.unwrap_storage(tensors),
constant_args,
{"reduce": reduce, "include_self": include_self},
)
self.name = V.graph.register_buffer(self)
class IndexPutFallback(ExternKernel):
"""
This needs to be a custom class to handle mutation and indices properly
"""
def codegen(self, wrapper):
(x, values, *valid_indices) = [t.codegen_reference() for t in self.inputs]
indices = []
iter_valid_indices = iter(valid_indices)
for i, _ in enumerate(self.indices):
if self.indices[i] is not None:
indices.append(next(iter_valid_indices))
else:
indices.append(V.graph.wrapper_code.none_str)
indices = f"{V.graph.wrapper_code.open_bracket}{', '.join(indices)}{V.graph.wrapper_code.closed_bracket}"
args = [x, indices, values, *self.codegen_const_args()]
wrapper.writeline(wrapper.wrap_kernel_call(self.kernel, args))
def should_allocate(self):
return False
def __init__(self, x, indices, values, accumulate):
self.indices = indices
valid_indices = [i for i in indices if i is not None]
tensors = [self.realize_input(x) for x in [x, values, *valid_indices]]
super().__init__(
None,
MutationLayout(x),
self.unwrap_storage(tensors),
[accumulate],
)
self.name = V.graph.register_buffer(self)
self.kernel = "at::index_put_" if V.graph.cpp_wrapper else "aten.index_put_"
class DeviceCopy(ExternKernelOut):
@classmethod
def create(cls, x, device):
if not x.is_extern() and all(
(r.name in V.graph.constants and hasattr(r, "index")) for r in x.get_reads()
):
return x.constant_to_device(device)
V.graph.device_types.add(device.type)
V.graph.add_device_idx(device.index)
V.graph.device_types.add(x.get_device().type)
V.graph.add_device_idx(x.get_device().index)
developer_warning("DeviceCopy in input program")
return DeviceCopy(
FlexibleLayout(
device=device,
dtype=x.get_dtype(),
size=x.get_size(),
),
[cls.realize_input(x)],
)
def codegen(self, wrapper):
args = self.codegen_args()
assert len(args) == 1
if self.output_view:
wrapper.writeline(
f"{self.output_view.codegen_reference()}.copy_({args[0]}){V.graph.wrapper_code.ending}"
)
else:
wrapper.writeline(
f"{self.codegen_reference()}.copy_({args[0]}){V.graph.wrapper_code.ending}"
)
class DynamicScalar(IRNode):
"""
The result of a call to aten._local_scalar_dense.
This is not yet implemented. The one model (so far) that calls this
(fastNLP_Bert) does not actually use the result. So we expect this
node to get dead code eliminated.
"""
def get_reads(self):
return ()
@dataclasses.dataclass
class FallbackKernel(ExternKernelAlloc):
def __init__(
self,
layout,
kernel,
tensor_args,
nontensor_args,
unflatten_args,
kwargs=None,
):
super().__init__(
layout,
tuple(tensor_args),
tuple(nontensor_args),
)
if getattr(torch.ops.aten, kernel.__name__, None) is kernel:
self.kernel = (
f"at::{kernel.__name__}"
if V.graph.cpp_wrapper
else f"aten.{kernel.__name__}"
)
else:
assert (
not V.graph.cpp_wrapper
), f"{kernel.__name__} is not supported with cpp wrapper"
self.kernel = (
f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"
)
self.unflatten_args = unflatten_args
self.kwargs = {} if kwargs is None else kwargs
V.graph.warn_fallback(self.kernel)
def codegen_args(self):
@dataclasses.dataclass
class Shim:
ref: Any
def __repr__(self):
return self.ref
tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
args = [V.graph.wrapper_code.val_to_str(x) for x in args]
# let self.codegen_kwargs handle kwargs
self.kwargs.update(kwargs)
return args
@classmethod
def create(cls, kernel, *args, **kwargs):
fake_incorrect_kernels = (
aten._fft_r2c.default,
aten._fft_r2c.out,
aten._fft_c2r.default,
aten._fft_c2c.default,
aten._fft_c2c.out,
aten._linalg_svd.default,
aten._linalg_svd.U,
aten._fused_moving_avg_obs_fq_helper_functional,
)
context = (
V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext()
)
with context:
(
example_output,
tensor_args,
non_tensor_args,
unflatten_args,
) = cls.process_kernel(kernel, *args, **kwargs)
assert tensor_args or isinstance(
example_output, torch.Tensor
), "Not sure where to find device info"
packed = FallbackKernel(
MultiOutputLayout(
tensor_args[0].get_device() if tensor_args else example_output.device
),
kernel,
tensor_args,
non_tensor_args,
unflatten_args,
)
def generate_output(output, indices):
if isinstance(output, (list, tuple)):
return type(output)(
generate_output(output[i], indices + [(type(output), i)])
for i in range(len(output))
)
elif isinstance(output, torch.Tensor):
return MultiOutput(
FixedLayout(
output.device,
output.dtype,
convert_shape_to_inductor(output.size()),
convert_shape_to_inductor(output.stride()),
),
packed,
indices,
)
elif isinstance(output, int):
return output
else:
assert output is None, "FallbackKernel output type is not supported"
return None
return generate_output(example_output, [])
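# Illustrative sketch: if the fallback op returns (tensor_a, [tensor_b]),
# generate_output produces (MultiOutput(..., packed, [(tuple, 0)]),
# [MultiOutput(..., packed, [(tuple, 1), (list, 0)])]), so each tensor output
# gets its own IR node indexing into the packed FallbackKernel result.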
def apply_constraint(self):
return super().apply_constraint()
@dataclasses.dataclass
class MultiOutputLayout(IRNode):
device: torch.device
class MultiOutput(ExternKernel):
def codegen_list_tuple_access(self, basename, indices):
if len(indices) > 0:
itype, i = indices[0]
if itype == list:
return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:])
elif itype == tuple:
# cpp wrapper code needs to use std::get<> to access a tuple
tuple_access = V.graph.wrapper_code.codegen_tuple_access(
basename, str(i)
)
return self.codegen_list_tuple_access(tuple_access, indices[1:])
else:
raise AssertionError("non supported index type")
else:
return basename
def codegen(self, wrapper):
line = V.graph.wrapper_code.declare
line += f"{self.get_name()} = {self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices)}"
line += V.graph.wrapper_code.ending
V.graph.wrapper_code.writeline(line)
self.codegen_size_asserts(V.graph.wrapper_code)
def __init__(self, layout, input, indices: List[Tuple]):
super().__init__(None, layout, [input], ())
self.name = V.graph.register_buffer(self)
self.indices = indices
def should_allocate(self):
return False
def _prepare_convolution_fusion_create(
cls,
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding: List[int],
stride: List[int],
dilation: List[int],
groups: int,
transposed: bool = False,
output_padding: Optional[List[int]] = None,
):
"""
This is a helper function that prepares the inputs, layout and constant args
for a convolution post-op fusion's create function, including deciding the
output layout (channels first or channels last) and realizing the inputs. It
only supports the CPU device, since the conv post-op fusion kernels are only
supported on CPU right now.
"""
# Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size
def _conv_input_size(
output_size, weight_size, padding, output_padding, stride, dilation, groups
):
assert len(output_size) == len(weight_size), "Expect input dim == weight dim"
dim = len(output_size)
assert dim > 2, "Expect input dim > 2"
BATCH_DIM = 0
WEIGHT_INPUT_CHANNELS_DIM = 1
input_size = []
input_size.append(output_size[BATCH_DIM])
input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups)
for d in range(2, dim):
kernel = (weight_size[d] - 1) * dilation[d - 2] + 1
input_size_d = (
(output_size[d] - 1) * stride[d - 2]
- (padding[d - 2] * 2)
+ kernel
+ output_padding[d - 2]
)
input_size.append(input_size_d)
return list(map(int, input_size))
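# Illustrative arithmetic (hypothetical sizes): with output_size [1, 8, 5],
# weight_size [8, 4, 3], stride [2], padding [1], dilation [1],
# output_padding [1] and groups == 1: kernel = (3 - 1) * 1 + 1 = 3 and
# input_size_d = (5 - 1) * 2 - 2 + 3 + 1 = 10, so the result is [1, 4, 10].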
# The size of prepacked_weight is the prepacked weight size of deconv:
# Groups > 1: [g*o, i/g, ...]
# Groups == 1: [o, i, ...]
# Returns original weight size in [i, o, ...]
def _original_deconv_weight_size(
prepacked_weight,
groups,
):
prepacked_weight_size = prepacked_weight.size()
dim = len(prepacked_weight_size)
assert dim > 2, "Expect weight dim > 2"
if groups > 1:
weight_size = []
weight_size.append(prepacked_weight_size[1] * groups)
weight_size.append(prepacked_weight_size[0] / groups)
for d in range(2, dim):
weight_size.append(prepacked_weight_size[d])
else:
weight_size = prepacked_weight.transpose(0, 1).size()
return weight_size
x.realize()
weight.realize()
with V.graph.fake_mode:
x_fake = ir_node_to_tensor(x, guard_shape=True)
weight_fake = ir_node_to_tensor(weight, guard_shape=True)
dims = len(x_fake.size()) - 2
assert 0 < len(padding) <= dims
assert 0 < len(dilation) <= dims
assert 0 < len(stride) <= dims
padding = pad_listlike(padding, dims)
dilation = pad_listlike(dilation, dims)
stride = pad_listlike(stride, dims)
if output_padding is None:
output_padding = pad_listlike([0], dims)
else:
assert 0 < len(output_padding) <= dims
output_padding = pad_listlike(output_padding, dims)
assert isinstance(groups, int)
if transposed:
# When transposed, the size of the prepacked oneDNN weight is different
# from the PyTorch weight. We're not able to run aten conv with such
# size. We infer the output size from the input params here:
weight_size = _original_deconv_weight_size(weight_fake, groups)
input_size = x_fake.size()
output_size = _conv_input_size(
input_size,
weight_size,
padding,
output_padding,
stride,
dilation,
groups,
)
else:
bias_fake = (
ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias
)
output = torch.ops.aten.convolution(
x_fake,
weight_fake,
bias_fake,
stride,
padding,
dilation,
transposed,
output_padding,
groups,
)
output_size = output.size()
req_stride_order = [0] + list(reversed(range(1, len(stride) + 1)))
req_stride_order = [len(req_stride_order)] + req_stride_order
output_stride = make_channels_last_strides_for(output_size)
x = cls.require_stride_order(x, req_stride_order)
assert x.get_device().type == "cpu" and weight.get_device().type == "cpu"
inputs = [x, weight]
kernel_layout = FixedLayout(
x.get_device(),
x.get_dtype(),
convert_shape_to_inductor(output_size),
convert_shape_to_inductor(output_stride),
)
constant_args = [padding, stride, dilation, groups]
if transposed:
constant_args.insert(1, output_padding)
if bias is not None:
inputs.append(bias)
else:
constant_args.insert(0, bias)
return inputs, constant_args, kernel_layout, req_stride_order
class ConvolutionUnary(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._convolution_pointwise"
def __init__(
self,
layout,
inputs,
constant_args=(),
kernel="torch.ops.mkldnn._convolution_pointwise",
):
super().__init__(layout, inputs, constant_args)
self.kernel = kernel
def codegen(self, wrapper):
wrapper.writeline(
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
@classmethod
def create(
cls,
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
groups: int,
attr,
scalars,
algorithm,
):
kernel = "torch.ops.mkldnn._convolution_pointwise"
(inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create(
cls, x, weight, bias, padding_, stride_, dilation_, groups
)
constant_args = constant_args + [attr, scalars, algorithm]
return ConvolutionUnary(
layout=kernel_layout,
inputs=inputs,
constant_args=constant_args,
kernel=kernel,
)
class ConvolutionBinary(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._convolution_pointwise.binary"
def __init__(
self,
layout,
inputs,
constant_args=(),
kernel="torch.ops.mkldnn._convolution_pointwise.binary",
):
super().__init__(layout, inputs, constant_args)
self.kernel = kernel
def codegen(self, wrapper):
wrapper.writeline(
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
)
if isinstance(self.layout, Layout):
self.codegen_size_asserts(wrapper)
@classmethod
def create(
cls,
x: "TensorBox",
other: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
groups: int,
binary_attr: str,
binary_alpha: Optional[float],
unary_attr: Optional[str],
unary_scalars: Optional[List],
unary_algorithm: Optional[str],
):
kernel = "torch.ops.mkldnn._convolution_pointwise.binary"
(
inputs,
constant_args,
kernel_layout,
req_stride_order,
) = _prepare_convolution_fusion_create(
cls, x, weight, bias, padding_, stride_, dilation_, groups
)
other = cls.require_stride_order(other, req_stride_order)
inputs.insert(1, other)
constant_args = constant_args + [
binary_attr,
binary_alpha,
unary_attr,
unary_scalars,
unary_algorithm,
]
return ConvolutionBinary(
layout=kernel_layout,
inputs=inputs,
constant_args=constant_args,
kernel=kernel,
)
class ConvolutionBinaryInplace(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
def __init__(
self,
kernel_layout,
inputs,
constant_args=(),
kernel="torch.ops.mkldnn._convolution_pointwise_.binary",
):
super().__init__(kernel_layout, inputs, constant_args)
self.kernel = kernel
def codegen(self, wrapper):
wrapper.writeline(
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
)
def get_mutation_names(self):
assert isinstance(self.layout, MutationLayout)
return (self.layout.target.get_name(),)
@classmethod
def create(
cls,
x: "TensorBox",
other: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
groups: int,
binary_attr: str,
binary_alpha: Optional[float],
unary_attr: Optional[str],
unary_scalars: Optional[List],
unary_algorithm: Optional[str],
):
kernel = "torch.ops.mkldnn._convolution_pointwise_.binary"
(
inputs,
constant_args,
_,
req_stride_order,
) = _prepare_convolution_fusion_create(
cls, x, weight, bias, padding_, stride_, dilation_, groups
)
other = cls.require_stride_order(other, req_stride_order)
V.graph.realize_users_of(other.get_name())
inputs.insert(1, other)
constant_args = constant_args + [
binary_attr,
binary_alpha,
unary_attr,
unary_scalars,
unary_algorithm,
]
return ConvolutionBinaryInplace(
kernel_layout=MutationLayout(inputs[1]),
inputs=inputs,
constant_args=constant_args,
kernel=kernel,
)
class MKLPackedLinear(ExternKernelAlloc):
def __init__(
self,
layout,
inputs,
constant_args=(),
):
super().__init__(
layout,
inputs,
constant_args,
None,
kernel="torch.ops.mkl._mkl_linear",
cpp_kernel="mkl::_mkl_linear",
)
self.cpp_kernel_key = "mkl_linear"
self.cpp_op_schema = """
at::Tensor(
const at::Tensor& self,
const at::Tensor& mkl_weight_t,
const at::Tensor& origin_weight_t,
const c10::optional<at::Tensor>& bias_opt,
const int64_t prepack_batch_size)"""
def codegen(self, wrapper):
wrapper.generate_fusion_ops_code(
self.get_name(),
self.kernel,
self.codegen_args(),
self.cpp_op_schema,
self.cpp_kernel_key,
)
@classmethod
def create(cls, x, packed_w, orig_w, batch_size):
x = cls.require_stride1(cls.realize_input(x))
orig_w = cls.require_stride1(cls.realize_input(orig_w))
*m, _ = x.get_size()
oc, _ = orig_w.get_size()
output_size = list(m) + [oc]
output_stride = make_contiguous_strides_for(output_size)
inputs = [x, packed_w, orig_w]
constant_args = [None, batch_size]
return MKLPackedLinear(
layout=FixedLayout(
x.get_device(), x.get_dtype(), output_size, output_stride
),
inputs=inputs,
constant_args=constant_args,
)
class LinearUnary(ExternKernelAlloc):
def __init__(
self,
layout,
inputs,
constant_args=(),
):
super().__init__(
layout,
inputs,
constant_args,
None,
kernel="torch.ops.mkldnn._linear_pointwise",
cpp_kernel="mkldnn::_linear_pointwise",
)
self.cpp_kernel_key = "linear_pointwise"
self.cpp_op_schema = """
at::Tensor(
const at::Tensor& input_t,
const at::Tensor& weight_t,
const c10::optional<at::Tensor>& bias_opt,
c10::string_view attr,
torch::List<c10::optional<at::Scalar>> scalars,
c10::optional<c10::string_view> algorithm)"""
def codegen(self, wrapper):
wrapper.generate_fusion_ops_code(
self.get_name(),
self.kernel,
self.codegen_args(),
self.cpp_op_schema,
self.cpp_kernel_key,
)
@classmethod
def create(cls, x, w, b, attr, scalars, algorithm):
x = cls.require_stride1(cls.realize_input(x))
w = cls.require_stride1(cls.realize_input(w))
*m, ic = x.get_size()
oc, ic = w.get_size()
inputs = [x, w]
constant_args = [attr, scalars if scalars else [-1], algorithm]
if b is not None:
b = cls.require_stride1(cls.realize_input(b))
inputs.append(b)
else:
constant_args.insert(0, None)
return LinearUnary(
layout=FlexibleLayout(
device=x.get_device(),
dtype=x.get_dtype(),
size=list(m) + [oc],
),
inputs=inputs,
constant_args=constant_args,
)
def apply_constraint(self):
pass
class LinearBinary(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._linear_pointwise.binary"
def __init__(
self,
layout,
inputs,
constant_args=(),
):
super().__init__(
layout,
inputs,
constant_args,
None,
kernel="torch.ops.mkldnn._linear_pointwise.binary",
cpp_kernel="mkldnn::_linear_pointwise",
)
self.cpp_kernel_overload_name = "binary"
self.cpp_kernel_key = "linear_pointwise_binary"
self.cpp_op_schema = """
at::Tensor(
const at::Tensor& input_t,
const at::Tensor& other_t,
const at::Tensor& weight_t,
const c10::optional<at::Tensor>& bias_opt,
c10::string_view attr)
"""
def codegen(self, wrapper):
wrapper.generate_fusion_ops_code(
self.get_name(),
self.kernel,
self.codegen_args(),
self.cpp_op_schema,
self.cpp_kernel_key,
self.cpp_kernel_overload_name,
)
@classmethod
def create(cls, x, y, w, b, attr):
x = cls.require_stride1(cls.realize_input(x))
y = cls.require_stride1(cls.realize_input(y))
w = cls.require_stride1(cls.realize_input(w))
*m, ic = x.get_size()
oc, ic = w.get_size()
inputs = [x, y, w]
constant_args = [attr]
if b is not None:
b = cls.require_stride1(cls.realize_input(b))
inputs.append(b)
else:
constant_args.insert(0, b)
return LinearBinary(
layout=FlexibleLayout(
device=x.get_device(),
dtype=x.get_dtype(),
size=list(m) + [oc],
),
inputs=inputs,
constant_args=constant_args,
)
def apply_constraint(self):
pass
class ConvolutionTransposeUnary(ExternKernelAlloc):
kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"
def __init__(
self,
layout,
inputs,
constant_args=(),
kernel="torch.ops.mkldnn._convolution_transpose_pointwise",
):
super().__init__(layout, inputs, constant_args)
self.kernel = kernel
def codegen(self, wrapper):
wrapper.writeline(
f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})"
)
@classmethod
def create(
cls,
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
output_padding_: List[int],
stride_: List[int],
dilation_: List[int],
groups_: int,
attr,
scalars,
algorithm,
):
kernel = "torch.ops.mkldnn._convolution_transpose_pointwise"
transposed = True
(
inputs,
constant_args,
kernel_layout,
_,
) = _prepare_convolution_fusion_create(
cls,
x,
weight,
bias,
padding_,
stride_,
dilation_,
groups_,
transposed,
output_padding_,
)
constant_args = constant_args + [attr, scalars, algorithm]
return ConvolutionTransposeUnary(
layout=kernel_layout,
inputs=inputs,
constant_args=constant_args,
kernel=kernel,
)
@dataclasses.dataclass
class MutableBox(IRNode):
"""
TensorBox / StorageBox allow in-place mutation of Tensors
"""
data: IRNode
def __getattr__(self, name):
fn = getattr(self.data, name)
if callable(fn):
return fn
raise AttributeError(f"{type(self.data).__name__}.{name} not callable")
@property
def layout(self):
return self.data.layout
def __str__(self):
if isinstance(self.data, MutableBox):
line0 = f"{type(self).__name__}({type(self.data).__name__}("
endl = "))"
inner = self.data.data
else:
line0 = f"{type(self).__name__}("
inner = self.data
endl = ")"
lines = [
line0,
indent(str(inner)),
endl,
]
return "\n".join(lines)
__repr__ = __str__
class TensorBox(MutableBox):
@staticmethod
def create(data):
return TensorBox(StorageBox(data))
class StorageBox(MutableBox):
def is_input_buffer(self):
if isinstance(self.data, (InputBuffer, ReinterpretView)):
return self.data.get_name() in V.graph.graph_inputs
return False
def realize(self):
if isinstance(
self.data,
(
ComputedBuffer,
InputsKernel,
InputBuffer,
ReinterpretView,
TemplateBuffer,
),
):
return self.data.get_name()
assert isinstance(self.data, (Pointwise, Reduction)), type(self.data)
origin_node = self.data.get_origin_node()
traceback = self.data.get_traceback()
self.data = ComputedBuffer(
name=None,
layout=FlexibleLayout(
device=self.data.get_device(),
dtype=self.data.get_dtype(),
size=self.data.get_size(),
),
data=self.data,
)
self.data.name = V.graph.register_buffer(self.data)
self.data.origins = self.origins
self.data.origin_node = origin_node
self.data.traceback = traceback
return self.data.name
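# Illustrative sketch: after realize(), a StorageBox that wrapped a Pointwise
# or Reduction instead wraps a named ComputedBuffer (e.g. "buf3"), so later
# consumers read from that buffer rather than re-inlining the computation.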
def realize_hint(self):
"""
Called on buffers we expect to be forced to realize later.
"""
if (
isinstance(self.data, (Pointwise, Reduction))
and self.num_reads() > 1
and self.is_pointwise_non_scalar_tensor_num_reads_larger_than_one()
):
self.realize()
def has_exceeded_max_reads(self):
return isinstance(self.data, Pointwise) and (
self.num_reads() > config.realize_acc_reads_threshold
or len(self.inner_fn_str()) > config.realize_bytes_threshold
)
def mark_reuse(self, users):
"""
A heuristic to decide if we should realize a tensor
that is used multiple times.
"""
def should_realize_on_cpu(loops: Union[Pointwise, Reduction]):
"""
The heuristic for realizing the reused result of heavy ops on CPU.
"""
heavy_ops = ["exp"] # a list of heavy ops
fn_str = loops.inner_fn_str()
return any((op + "(") in fn_str for op in heavy_ops)
if (
users > 1
and isinstance(self.data, (Pointwise, Reduction))
and (
self.num_reads() > config.realize_reads_threshold
or len(self.inner_fn_str()) > config.realize_bytes_threshold
or (is_cpu(self.data) and should_realize_on_cpu(self.data))
)
):
self.realize()
@cache_on_self
def num_reads(self):
data = self.data
if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)):
return 1
if isinstance(data, ComputedBuffer):
read_writes = data.get_read_writes()
else:
assert isinstance(data, (Pointwise, Reduction)), type(data)
read_writes = ComputedBuffer(
name=None,
layout=FlexibleLayout(
device=data.get_device(),
dtype=data.get_dtype(),
size=data.get_size(),
),
data=data,
).get_read_writes()
return len(read_writes.reads)
@cache_on_self
def is_pointwise_non_scalar_tensor_num_reads_larger_than_one(self):
# Skip the check for non Pointwise instances
return (
(sum(read.index != 0 for read in self.data.get_reads()) > 1)
if isinstance(self.data, Pointwise)
else True
)
class InterpreterShim(torch.fx.Interpreter):
@staticmethod
@functools.lru_cache(None)
def _dummy_gm():
return torch.fx.symbolic_trace(identity)
def __init__(self, graph, submodules):
# call super() with a placeholder to avoid constructing a
# GraphModule which is very expensive (it does codegen).
super().__init__(self._dummy_gm(), garbage_collect_values=False)
self.module = self
self.graph = graph
self.submodules = submodules
self.extra_traceback = False
self.fetch_attr = submodules.__getitem__
self.current_node = None
def run_node(self, n: torch.fx.Node) -> Any:
self.current_node = n
return super().run_node(n)
def run(self, *args, **kwargs):
with V.set_interpreter_handler(self):
return super().run(*args, **kwargs)
class LoopBody:
"""
Captures the body of a Loops subclass into an FX graph. Persists any
indexing simplifications and makes it easier to analyze loop bodies.
"""
def __init__(self, fn, args, var_ranges):
super().__init__()
self.var_ranges = var_ranges
self.indexing_exprs = {}
self.indexing_exprs_name = {}
self.reads = []
self.writes = []
self.reads_name2expr = {}
self.writes_name2expr = {}
self.other = []
self.submodules = {"get_index": self.get_index}
self.subblocks = {}
self.indirect_vars = []
self.indirect_max_sizes = []
self.indirect_new = {}
self.root_block = LoopBodyBlock(self, fn, args)
self.indexing = None
def debug_str(self):
lines = [f"var_ranges = {dict(self.var_ranges)}"]
lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()])
lines.extend(
[
block.debug_str(name)
for name, block in itertools.chain(
[("body", self.root_block)], self.subblocks.items()
)
]
)
return "\n".join(lines)
def add_index_expr(self, expr: sympy.Expr, category, buf_name):
getattr(self, category).append(expr)
if buf_name is not None:
getattr(self, f"{category}_name2expr")[buf_name] = expr
if expr not in self.indexing_exprs_name:
name = f"index{len(self.indexing_exprs)}"
self.indexing_exprs_name[expr] = name
self.indexing_exprs[name] = expr
return self.indexing_exprs_name[expr]
def add_submodule(self, block, prefix):
"""Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes"""
if prefix[-1].isnumeric() and prefix not in self.submodules:
name = prefix
else:
name = f"{prefix}{len(self.submodules)}"
self.submodules[name] = block
return name
def add_indirect(self, size):
name = f"indirect{len(self.indirect_vars)}"
var = sympy_symbol(name)
self.indirect_vars.append(var)
self.indirect_max_sizes.append(size)
return var
def replace_indirect(self, old, new):
"""Swap in a variable used in indirect indexing"""
if str(old) == str(new):
return
self.indirect_new[old] = new
self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()}
def get_index(self, name):
return self.indexing[name]
def __call__(self, *indices):
index = list(itertools.chain(*indices))
assert len(index) == len(self.var_ranges), (index, self.var_ranges)
assert all(v not in self.var_ranges for v in index)
replacements = dict(zip(self.var_ranges.keys(), index))
self.indexing = {
name: sympy_subs(expr, replacements)
for name, expr in self.indexing_exprs.items()
}
result = self.root_block()
self.indexing = None
return result
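# Illustrative sketch: body([z0, z1], [r0]) substitutes z0, z1, r0 for the
# recorded iteration variables in every indexing expression and then re-runs
# the traced FX graph against the current V.ops handler.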
class LoopBodyBlock:
"""
Captures the body of a Loops subclass into an FX graph.
In normal cases there will be a 1:1 mapping between LoopBody and
LoopBodyBlock; however, in the case of ops.masked() the masked-out
operations will manifest as an extra LoopBodyBlock.
"""
def __init__(self, body: LoopBody, fn: Callable, args: List[Any]):
self.body = body
def add_index(expr, category, buf_name=None):
return tracer.create_proxy(
"call_module",
"get_index",
(self.body.add_index_expr(expr, category, buf_name),),
{},
)
class CaptureIndexing(V.WrapperHandler):
self.name = "CaptureIndexing"
def load(self, name: str, index: sympy.Expr):
index = add_index(index, "reads", name)
return self._inner.load(name, index)
def store(self, name, index, value, mode=None):
index = add_index(index, "writes", name)
return self._inner.store(name, index, value, mode)
def reduction(self, name, dtype, src_dtype, reduction_type, index, value):
index = add_index(index, "writes", name)
return self._inner.reduction(
name, dtype, src_dtype, reduction_type, index, value
)
def index_expr(self, index, dtype):
if isinstance(index, (int, sympy.Integer)):
return ops.constant(int(index), dtype)
index = add_index(index, "other")
return self._inner.index_expr(index, dtype)
@staticmethod
def masked(mask_proxy, masked_body: Callable, other_proxy):
"""
Recursively capture the masked out body in another LoopBodyBlock
"""
def shim(mask, other):
return V.ops.masked(mask, subblock, other)
name = self.body.add_submodule(shim, "masked_subblock")
subblock = LoopBodyBlock(self.body, masked_body, [])
self.body.subblocks[name] = subblock
return tracer.create_proxy(
"call_module", name, (mask_proxy, other_proxy), {}
)
@staticmethod
def indirect_indexing(index_proxy, size):
"""
Flow data from tensors into indexing formulas.
Introduce a call_module to update the indexing.
"""
def set_indirect(new_var):
self.body.replace_indirect(
var, V.ops.indirect_indexing(new_var, size)
)
var = self.body.add_indirect(size)
tracer.create_proxy(
"call_module",
self.body.add_submodule(set_indirect, f"set_{var}"),
(index_proxy,),
{},
)
return var
tracer = torch.fx.Tracer()
tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__)
proxy_ops = tracer.create_proxy("placeholder", "ops", (), {})
from .sizevars import SimplifyIndexing
with V.set_ops_handler(
SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges)
):
tracer.create_proxy("output", "output", (fn(*args),), {})
self.graph = tracer.graph
def __call__(self):
graph = self.graph
submodules = self.body.submodules
return InterpreterShim(graph, submodules).run(V.get_ops_handler())
def debug_str(self, name="block"):
code = torch.fx.GraphModule(self.body.submodules, self.graph).code
return re.sub(
# strip `; del var0` suffixes to make output prettier
r";[^\n]*",
"",
code.strip().replace("def forward(", f"def {name}("),
)
class Wait(ExternKernelAlloc):
"""
Wait should not be used by itself. It should always be constructed in tandem
with a collective op that produces a work to wait on.
"""
def __init__(
self,
layout,
inputs,
constant_args=(),
):
super().__init__(layout, inputs, constant_args)
def should_allocate(self):
return False
def codegen(self, wrapper):
wrapper.add_import_once(
"from torch.distributed._functional_collectives import _wait_tensor"
)
(input_collective,) = [t.codegen_reference() for t in self.inputs]
wrapper.writeline(f"{input_collective} = _wait_tensor({input_collective})")
# The wait op still needs to produce a 'buffer' that represents the tensor output.
# This is a symbolic gesture, and it gets handled by WrapperCodegen:
# codegen outputs a '# reuse' line that assigns the input buffer here ('input_collective')
# to a new name (`self.get_name()`) and `del`s the old name.
wrapper.writeline(f"{self.get_name()} = {input_collective}")
@classmethod
def create(cls, collective_op: "TensorBox"):
# TODO(whc) i'm not sure what's going on here, this probably means I missed something upstream
collective_op.decide_layout()
return Wait(
layout=collective_op.get_layout(),
inputs=[collective_op],
)
def get_alias_names(self):
# Signal to codegen that our output buffer isn't safe to reuse
return [self.inputs[0].codegen_reference()]
class CollectiveKernel(ExternKernel):
"""
Each CollectiveKernel should follow these patterns:
- it writes into a given output buffer
- the kernel delegates to the c10d process group, which returns a 'work' object
- the work object is registered via _register_tensor_work so it can be waited on later
"""
def __init__(self, layout, inputs, constant_args):
super().__init__(None, layout, inputs, constant_args)
self.name = V.graph.register_buffer(self)
def should_allocate(self):
return True
def codegen_collective(self, wrapper, output_name, input_names):
# factor so the boilerplate can be handled in CollectiveKernel.codegen
raise NotImplementedError("Must implement")
def codegen(self, wrapper):
wrapper.add_import_once("import torch.distributed as dist")
wrapper.add_import_once(
"from torch.distributed._functional_collectives import _str_to_reduce_op, _register_tensor_work"
)
wrapper.add_import_once(
"from torch.distributed.distributed_c10d import _find_or_create_pg_by_ranks_and_tag"
)
# extract references to our args in string form for codegen output
input_names = [t.codegen_reference() for t in self.inputs]
output_name = self.get_name()
tag, ranks, group_size = self.constant_args
# TODO: avoid more than one ref of the same pg (even though they are cached inside the api)
wrapper.writeline(
f"{output_name}_pg = _find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})"
)
self.codegen_collective(wrapper, output_name, input_names)
wrapper.writeline(f"_register_tensor_work({output_name}, {output_name}_work)")
class MultiOutputNoSizeAssert(MultiOutput):
"""
Extract a partial output from a multi-output op.
Works like MultiOutput but doesn't assert size; this must be a property guaranteed by the op emitting this.
"""
def codegen(self, wrapper):
wrapper.writeline(
f"{self.get_name()} = {self.inputs[0].get_name()}{self.index}"
)
class InPlaceHint(ExternKernel):
"""
Helper op to encode an in/out argument that tries to make it inplace whenever possible.
Wrap the input of your inplace op to enable this behavior.
The design is based on two key decisions:
- this node is responsible for allocating the in/out buffer used by the collective.
This is controlled by the ``should_allocate`` method, which returns True here and
False for the collective node
- the scheduler special-cases this node and enables it to reuse its input.
"""
def codegen(self, wrapper):
input_name = self.inputs[0].codegen_reference()
output_name = self.get_name()
if not wrapper.did_reuse(self, self.inputs[0]):
wrapper.writeline(f"{output_name}.copy_({input_name}) #no reuse")
def __init__(self, layout, input):
input = self.realize_input(input)
super().__init__(None, layout, self.unwrap_storage([input]), ())
self.name = V.graph.register_buffer(self)
def should_allocate(self):
return True
class AllReduceCoalesced(ExternKernel):
def __init__(self, layout, inputs, constant_args, reduce_op):
super().__init__(None, layout, inputs, constant_args)
self.reduce_op = reduce_op
self.name = V.graph.register_buffer(self)
def should_allocate(self):
return False
@classmethod
def create(
cls,
inputs: List["TensorBox"],
reduce_op: str,
tag: str,
ranks: List[int],
group_size: int,
):
res = []
def wrap_input(var):
nonlocal res
op = InPlaceHint(
FlexibleLayout(var.get_device(), var.get_dtype(), var.get_size()), var
)
res.append(op)
return TensorBox.create(op)
inputs = list(map(wrap_input, inputs))
layout = MultiOutputLayout(inputs[0].get_device())
packed = AllReduceCoalesced(
layout=layout,
inputs=inputs,
constant_args=[tag, ranks, group_size],
reduce_op=reduce_op,
)
for i, in_t in enumerate(inputs):
res.append(
MultiOutputNoSizeAssert(
FlexibleLayout(
in_t.get_device(), in_t.get_dtype(), in_t.get_size()
),
packed,
f"[{i}]",
)
)
return res
def codegen(self, wrapper):
wrapper.add_import_once("import torch.distributed as dist")
wrapper.add_import_once(
"from torch.distributed._functional_collectives import _str_to_reduce_op, _register_tensor_work"
)
wrapper.add_import_once(
"from torch.distributed.distributed_c10d import _find_or_create_pg_by_ranks_and_tag"
)
output_name = self.get_name()
tag, ranks, group_size = self.constant_args
wrapper.writeline(
f"{output_name}_pg = _find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})"
)
inputs = []
for inp in self.inputs:
inputs.append(inp.codegen_reference())
wrapper.writeline(f"{output_name} = [{','.join(inputs)}] ")
wrapper.writeline(
f"{output_name}_work = dist.all_reduce_coalesced("
f"{output_name}, "
f"op=_str_to_reduce_op('{str(self.reduce_op)}'), "
f"group={output_name}_pg, "
"async_op=True)"
)
wrapper.writeline(f"_register_tensor_work({output_name}, {output_name}_work)")
class AllReduce(CollectiveKernel):
def __init__(self, layout, inputs, constant_args, reduce_op):
super().__init__(layout, inputs, constant_args)
self.reduce_op = reduce_op
@classmethod
def create(
cls, x: "TensorBox", reduce_op: str, tag: str, ranks: List[int], group_size: int
):
x = cls.realize_input(x)
# is there a difference between literally using x.data.layout below, vs
# creating a new one that has the same properties?
new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), x.get_size())
return AllReduce(
layout=new_layout,
inputs=[x],
constant_args=[tag, ranks, group_size],
reduce_op=reduce_op,
)
def codegen_collective(self, wrapper, output_name, input_names):
# We must copy our input buffer sometimes, but the scheduler will help us find opportunities
# to reuse the input buffer. (This requires no other users of the input buffer.)
if not wrapper.did_reuse(self, self.inputs[0]):
wrapper.writeline(f"{output_name}.copy_({input_names[0]})")
# At this point, output_name points to a buffer that is either
# (1) the input buffer, which we're allowed to inplace modify
# (2) a freshly allocated buffer, which we've copied the input into above
wrapper.writeline(
f"{output_name}_work = dist.all_reduce("
f"{output_name}, async_op=True, group={output_name}_pg, op=_str_to_reduce_op('{str(self.reduce_op)}'))"
)
class AllGatherIntoTensor(CollectiveKernel):
def __init__(self, layout, inputs, constant_args):
super().__init__(layout, inputs, constant_args)
@classmethod
def create(cls, x: "TensorBox", tag: str, ranks: List[int], group_size: int):
x = cls.realize_input(x)
# is there a difference between literally using x.data.layout below, vs
# creating a new one that has the same properties?
new_size = x.get_size()
new_size[0] *= group_size
new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), new_size)
# The collective returns a 'work' object, but Inductor's scheduler doesn't need to know
# about that; we just pretend for scheduling purposes that the work obj is a 1-elem tensor.
# Nobody should consume the output of the collective except 'Wait', which we control here.
return AllGatherIntoTensor(
layout=new_layout,
inputs=[x],
constant_args=[tag, ranks, group_size],
)
def codegen_collective(self, wrapper, output_name, input_names):
# At this point, output_name points to a fresh buffer; the base class
# codegen registers the returned work object via _register_tensor_work.
wrapper.writeline(
f"{output_name}_work = dist.all_gather_into_tensor("
f"{output_name}, {input_names[0]}, async_op=True, group={output_name}_pg)"
)
class ReduceScatterTensor(CollectiveKernel):
def __init__(self, layout, inputs, constant_args, reduce_op):
super().__init__(layout, inputs, constant_args)
self.reduce_op = reduce_op
@classmethod
def create(
cls,
x: "TensorBox",
reduce_op: str,
tag: str,
ranks: List[int],
group_size: int,
):
x = cls.realize_input(x)
# is there a difference between literally using x.data.layout below, vs
# creating a new one that has the same properties?
new_size = x.get_size()
new_size[0] /= group_size
new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), new_size)
return ReduceScatterTensor(
layout=new_layout,
inputs=[x],
constant_args=[tag, ranks, group_size],
reduce_op=reduce_op,
)
def codegen_collective(self, wrapper, output_name, input_names):
wrapper.writeline(
f"{output_name}_work = dist.reduce_scatter_tensor("
f"{output_name}, {input_names[0]}, "
f"async_op=True, group={output_name}_pg, op=_str_to_reduce_op('{str(self.reduce_op)}'))"
)