| import collections |
| import contextlib |
| import dataclasses |
| import functools |
| import itertools |
| import logging |
| import re |
| import textwrap |
| import traceback |
| from contextlib import nullcontext |
| from enum import Enum |
| from functools import partial |
| from inspect import signature |
| from typing import Any, Callable, ClassVar, Dict, List, Optional, Set, Tuple, Union |
| from unittest.mock import patch |
| |
| import sympy |
| from sympy import Expr, Integer, simplify |
| |
| import torch._logging |
| |
| import torch.fx |
| import torch.utils._pytree as pytree |
| from torch._dynamo.utils import identity |
| from torch._prims_common import ( |
| compute_required_storage_length, |
| is_boolean_dtype, |
| is_float_dtype, |
| make_channels_last_strides_for, |
| make_contiguous_strides_for, |
| ) |
| from torch.fx.experimental.symbolic_shapes import FloorDiv |
| |
| from . import config, dependencies |
| from .codegen.common import index_prevent_reordering |
| from .cuda_properties import get_device_properties |
| from .dependencies import extract_read_writes, var_builder |
| from .utils import ( |
| argsort, |
| cache_on_self, |
| convert_shape_to_inductor, |
| convert_shape_to_symint, |
| developer_warning, |
| pad_listlike, |
| sympy_dot, |
| sympy_product, |
| sympy_subs, |
| sympy_symbol, |
| ) |
| from .virtualized import ops, V |
| |
| log = logging.getLogger(__name__) |
| indent = functools.partial(textwrap.indent, prefix=" ") |
| aten = torch.ops.aten |
| |
| """ [Note: Inductor IR] |
| |
| Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each |
| lowering is registered to a particular aten operator, and expects inputs that |
| correspond to the aten schema. However, in place of torch Tensor inputs, lowerings |
| expect Inductor TensorBox inputs. |
| |
| TensorBox IR represents torch tensors. Tensors are sometimes single objects owning |
| storage, and sometimes views of another Tensor's storage. Mutating tensor operations |
| (such as add_()) affect the underlying storage and any associated views. Other operations |
| (such as .t_()) update metadata about the current view but don't modify the underlying storage. |
| |
| To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer. |
| |
| TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor |
| output from an operation. But just as torch.Tensors take different forms, TensorBox IR can |
| reference View IR or directly reference StorageBox IRs. |
| |
| Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops) |
| may take an existing TensorBox and point it to a new underlying View IR. |
| |
| Tensors that directly own storage are represented as a chain of: |
| TensorBox -> StorageBox -> Buffer |
| where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout. |
| |
| If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer |
| (leaving the old buffer unmodified and functionalizing the operation). |
| |
| Tensors backed by views add one more indirection to the IR. |
| TensorBox -> View -> StorageBox -> Buffer |
| In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox. |
| """ |
| |
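| # A rough sketch (illustrative only; real lowerings live in lowering.py) of what a |
| # simple pointwise lowering builds, assuming `x` is an existing TensorBox input: |
| # |
| #   x_loader = x.make_loader() |
| #   def inner_fn(index): |
| #       return ops.add(x_loader(index), ops.constant(1, torch.float32)) |
| #   out = Pointwise.create( |
| #       device=x.get_device(), dtype=x.get_dtype(), |
| #       inner_fn=inner_fn, ranges=x.get_size(), |
| #   )  # -> TensorBox(StorageBox(Pointwise(...))) |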
| |
| def validate_ir(node_or_nodes): |
| def _check_tensorbox(nodes): |
| # Could expand this to check deeper properties |
| # (e.g. TensorBox points to View or StorageBox) |
| if isinstance(nodes, (list, tuple)): |
| for node in nodes: |
| _check_tensorbox(node) |
| else: |
| assert isinstance( |
| nodes, |
| ( |
| DynamicScalar, |
| TensorBox, |
| RandSeedBuffer, |
| sympy.Symbol, |
| sympy.core.relational.Relational, |
| Expr, |
| ), |
| ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]" |
| |
| # Be picky about the accepted data structure (don't use pytree here) |
| _check_tensorbox(node_or_nodes) |
| |
| |
| def inverse_reorder(order): |
| inv_order = dict(zip(order, range(len(order)))) |
| |
| def reindex(index): |
| assert len(index) == len(inv_order) |
| return [index[inv_order[i]] for i in range(len(index))] |
| |
| return reindex |
| |
| |
| def same_reorder(order): |
| def reindex(index): |
| assert len(index) == len(order) |
| return [index[order[i]] for i in range(len(index))] |
| |
| return reindex |
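| # Illustrative sketch of the two reindexing helpers above, with order = [2, 0, 1]: |
| #   inverse_reorder(order)([a, b, c]) -> [b, c, a]  # applies the inverse permutation |
| #   same_reorder(order)([a, b, c])    -> [c, a, b]  # applies the permutation directly |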
| |
| |
| def fuse_reindexing(reindex1, reindex2): |
| def reindex(index): |
| return reindex1(reindex2(index)) |
| |
| return reindex |
| |
| |
| def stride_order2fill_order(order): |
| """ |
| Convert stride order to fill order |
| For channel last format, |
| stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0] |
| """ |
| lookup = {pos: idx for idx, pos in enumerate(order)} |
| fill_order = [lookup[i] for i in range(len(order))] |
| return fill_order |
| |
| |
| def get_stride_order(seq): |
| """ |
| Convert strides to stride order |
| """ |
| sorted_idx = argsort(seq) |
| out = [None for _ in range(len(seq))] |
| for i, elem in enumerate(sorted_idx): |
| out[elem] = i |
| return out |
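| # Illustrative example tying the two helpers above together (strides assumed): |
| #   get_stride_order([1000, 1, 100, 10])  == [3, 0, 2, 1]  # dim 1 varies fastest, dim 0 slowest |
| #   stride_order2fill_order([3, 0, 2, 1]) == [1, 3, 2, 0]  # fill dim 1 first, then 3, 2, 0 |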
| |
| |
| def ir_node_to_tensor(x, guard_shape=True): |
| if x is None: |
| return None |
| if not guard_shape: |
| shape_fn = V.graph.sizevars.size_hint |
| else: |
| shape_fn = identity |
| size = [shape_fn(s) for s in x.get_size()] |
| if is_storage_and_layout(x): |
| stride = [shape_fn(s) for s in x.get_layout().stride] |
| else: |
| stride = make_contiguous_strides_for(size) |
| dtype = x.get_dtype() |
| device = x.get_device() |
| size = convert_shape_to_symint(size) |
| stride = convert_shape_to_symint(stride) |
| t = torch.empty_strided( |
| size=size, stride=stride, dtype=dtype, device=device |
| ).zero_() |
| return t |
| |
| |
| class ModularIndexing(sympy.Function): |
| """ |
| ModularIndexing(a, b, c) => (a // b) % c |
| """ |
| |
| nargs = (3,) |
| is_integer = True |
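| # For example, flattening an index i over a (4, 8) view of a contiguous buffer is |
| # expressed (roughly) as row = FloorDiv(i, 8) and col = ModularIndexing(i, 1, 8), |
| # i.e. (i // 1) % 8. |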
| |
| @classmethod |
| def eval(cls, base, divisor, modulus): |
| if base == 0 or modulus == 1: |
| return sympy.Integer(0) |
| |
| if ( |
| isinstance(base, sympy.Integer) |
| and isinstance(divisor, sympy.Integer) |
| and isinstance(modulus, sympy.Integer) |
| ): |
| return (base // divisor) % modulus |
| |
| if divisor != 1: |
| gcd = sympy.gcd(base, divisor) |
| if gcd != 1: |
| return ModularIndexing( |
| simplify(base / gcd), simplify(divisor / gcd), modulus |
| ) |
| |
| if isinstance(base, sympy.Add): |
| new_terms = [] |
| all_positive = True |
| for term in base.args: |
| if sympy.gcd(term, modulus * divisor) != modulus * divisor: |
| if (isinstance(term, sympy.Integer) and term < 0) or ( |
| isinstance(term, sympy.Mul) |
| and isinstance(term.args[0], sympy.Integer) |
| and term.args[0] < 0 |
| ): |
| # workaround for https://github.com/openai/triton/issues/619: |
| # if there are negative terms, // produces the wrong result |
| # TODO: if https://github.com/openai/triton/issues/619 is fixed, |
| # this optimization would become valid |
| all_positive = False |
| break |
| else: |
| new_terms.append(term) |
| |
| if len(new_terms) != len(base.args) and all_positive: |
| return ModularIndexing(sum(new_terms), divisor, modulus) |
| |
| if isinstance(base, FloorDiv): |
| return ModularIndexing(base.args[0], base.args[1] * divisor, modulus) |
| |
| |
| class CleanDiv(FloorDiv): |
| """ |
| Div where we can assume no rounding. |
| This is to enable future optimizations. |
| """ |
| |
| pass |
| |
| |
| class CeilDiv(sympy.Function): |
| """ |
| Div used in indexing that rounds up. |
| """ |
| |
| is_integer = True |
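| # Illustrative: CeilDiv(10, 4) lowers to FloorDiv(10 + 3, 4) == 3, while |
| # CeilDiv(12, 4) becomes CleanDiv(12, 4) == 3 because the division is exact. |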
| |
| def __new__(cls, base, divisor): |
| if sympy.gcd(base, divisor) == divisor: |
| return CleanDiv(base, divisor) |
| else: |
| return FloorDiv(base + (divisor - 1), divisor) |
| |
| |
| def get_device_type(x): |
| if getattr(x, "get_device", None): |
| return get_device_type(x.get_device()) |
| if isinstance(x, torch.device): |
| return x.type |
| return None |
| |
| |
| def is_triton(x): |
| return get_device_type(x) == "cuda" |
| |
| |
| def is_cpu(x): |
| return get_device_type(x) == "cpu" |
| |
| |
| @dataclasses.dataclass |
| class IRNode: |
| _current_origins: ClassVar[Set[Any]] = set() |
| |
| @staticmethod |
| @contextlib.contextmanager |
| def current_origins(origins: Set[torch.fx.Node]): |
| old = IRNode._current_origins |
| IRNode._current_origins = old | origins |
| try: |
| yield |
| finally: |
| IRNode._current_origins = old |
| |
| def __post_init__(self): |
| self.origins = set(self._current_origins) |
| self.traceback = traceback.format_stack() if config.debug_ir_traceback else None |
| |
| def get_traceback(self): |
| return self.traceback |
| |
| def common_repr(self): |
| origins = f"origins={getattr(self, 'origins', '')}" |
| if len(origins) > 64: |
| # this can get *very* long |
| origins = f"{origins[:61]}..." |
| return [origins] |
| |
| def str_helper(self, lines): |
| lines = lines + self.common_repr() |
| lines = indent(",\n".join(map(str, lines))) |
| return f"{type(self).__name__}(\n{lines}\n)" |
| |
| def is_user_of(self, name): |
| return any(name == dep.name for dep in self.get_reads()) |
| |
| def get_numel(self): |
| return sympy_product(self.get_size()) |
| |
| |
| @dataclasses.dataclass |
| class Loops(IRNode): |
| device: torch.device |
| dtype: torch.dtype |
| inner_fn: Callable |
| ranges: List[Expr] |
| |
| def __str__(self, names=("ranges",)): |
| return self.str_helper( |
| [ |
| f"'{self.device.type}'", |
| str(self.dtype), |
| self.inner_fn_str(), |
| ] |
| + [f"{name}={getattr(self, name)}" for name in names] |
| + [f"origin_node={self.origin_node!r}"] |
| ) |
| |
| def __post_init__(self): |
| super().__post_init__() |
| self.origin_node = None |
| |
| __repr__ = __str__ |
| |
| def get_dtype(self): |
| return self.dtype |
| |
| def get_device(self): |
| return self.device |
| |
| def get_origin_node(self): |
| return self.origin_node |
| |
| def get_size(self): |
| return self.ranges |
| |
| def is_extern(self): |
| return False |
| |
| @classmethod |
| def create(cls, *args, **kwargs): |
| origin_node = kwargs.pop("origin_node", None) |
| tb = kwargs.pop("traceback", None) |
| r = cls(*args, **kwargs) |
| r.origin_node = origin_node |
| r.traceback = ( |
| tb or traceback.format_stack() if config.debug_ir_traceback else None |
| ) |
| return TensorBox.create(r) |
| |
| @staticmethod |
| def _index(ranges, prefix="i"): |
| return [ |
| sympy.Integer(0) if s == 1 else sympy_symbol(f"{prefix}{n}") |
| for n, s in enumerate(ranges) |
| ] |
| |
| @cache_on_self |
| def inner_fn_str(self): |
| index = self._index(self.ranges) |
| return V.KernelFormatterHandler.ir_to_string(self.inner_fn, index) |
| |
| def is_zero_elements(self): |
| return any(r == 0 for r in self.ranges) |
| |
| @cache_on_self |
| def get_reads(self): |
| with patch.object(FlexibleLayout, "allow_indexing", True): |
| if self.get_reduction_type(): |
| return extract_read_writes( |
| self.make_loader(), |
| self.get_size(), |
| self.get_reduction_size(), |
| ).reads |
| else: |
| return extract_read_writes( |
| self.make_loader(), |
| self.get_size(), |
| ).reads |
| |
| |
| class Pointwise(Loops): |
| def make_loader(self): |
| return self.inner_fn |
| |
| def get_reduction_size(self): |
| return [] |
| |
| def get_reduction_type(self): |
| return None |
| |
| def store_output(self, output_name, indexer, vars): |
| return ops.store(output_name, indexer(vars), self.inner_fn(vars)) |
| |
| def constant_to_device(self, device): |
| """Move this to a given device. Requires that all reads are to constants.""" |
| loader = self.make_loader() |
| loader = patch.object(ConstantBuffer, "override_device", device)(loader) |
| return Pointwise(device, self.dtype, loader, self.ranges) |
| |
| |
| @dataclasses.dataclass |
| class Scatter(Pointwise): |
| output_indexer: Callable[[List[Expr]], Expr] |
| scatter_mode: Optional[str] = None |
| |
| def constant_to_device(self, device): |
| """Move this to a given device. Requires that all reads are to constants.""" |
| loader = self.make_loader() |
| loader = patch.object(ConstantBuffer, "override_device", device)(loader) |
| return Scatter( |
| device, |
| self.dtype, |
| loader, |
| self.ranges, |
| self.output_indexer, |
| self.scatter_mode, |
| ) |
| |
| def store_output(self, output_name, indexer, vars): |
| return ops.store( |
| output_name, |
| indexer(self.output_indexer(vars)), |
| self.inner_fn(vars), |
| mode=self.scatter_mode, |
| ) |
| |
| |
| class ReductionHint(Enum): |
| INNER = 0 |
| OUTER = 1 |
| OUTER_TINY = 2 |
| DEFAULT = 3 |
| |
| |
| class TileHint(Enum): |
| SQUARE = 0 |
| DEFAULT = 1 |
| |
| |
| @dataclasses.dataclass |
| class Reduction(Loops): |
| reduction_ranges: List[Expr] |
| reduction_type: str |
| # self.dtype represents the dst dtype |
| src_dtype: torch.dtype |
| reduction_hint: ReductionHint |
| |
| def __str__(self): |
| return Loops.__str__( |
| self, names=("ranges", "reduction_ranges", "reduction_type") |
| ) |
| |
| __repr__ = __str__ |
| |
| def get_reduction_size(self): |
| return self.reduction_ranges |
| |
| def get_reduction_type(self): |
| return self.reduction_type |
| |
| def store_reduction(self, output_name, indexer, vars, reduction_vars): |
| return ops.reduction( |
| output_name, |
| self.dtype, |
| self.src_dtype, |
| self.reduction_type, |
| indexer(vars), |
| self.inner_fn(vars, reduction_vars), |
| ) |
| |
| def index_length(self): |
| return len(self.ranges) + len(self.reduction_ranges) |
| |
| @cache_on_self |
| def inner_fn_str(self): |
| index = self._index(self.ranges) |
| rindex = self._index(self.reduction_ranges, "r") |
| return V.KernelFormatterHandler.ir_to_string( |
| self.inner_fn, |
| index, |
| rindex, |
| ) |
| |
| def constant_to_device(self, device): |
| """Move this to a given device. Requires that all reads are to constants.""" |
| loader = self.make_loader() |
| loader = patch.object(ConstantBuffer, "override_device", device)(loader) |
| return Reduction( |
| device, |
| self.dtype, |
| loader, |
| self.ranges, |
| self.reduction_ranges, |
| self.reduction_type, |
| self.src_dtype, |
| ReductionHint.DEFAULT, |
| ) |
| |
| @staticmethod |
| def num_splits( |
| device, |
| dst_dtype, |
| src_dtype, |
| inner_fn, |
| ranges, |
| reduction_ranges, |
| reduction_type, |
| reduction_numel, |
| ): |
| num_sm = get_device_properties(device).multi_processor_count |
| min_elements_per_thread = 32 |
| max_elements_per_thread = 512 |
| threads_per_sm = 2048 |
| min_elements_per_device = min_elements_per_thread * num_sm * threads_per_sm |
| max_elements_per_device = max_elements_per_thread * num_sm * threads_per_sm |
| |
| def inner_reduction_splits(reduction_numel_hint, numel_hint): |
| # use a heuristic close to eager mode for split inner reductions |
| # we leak reduction autotune configs here and will need to refactor to avoid this later |
| num_warps = 8 |
| num_threads = 32 * num_warps |
| if numel_hint >= 2 * num_sm: # don't split if there are enough outputs |
| return 1 |
| if reduction_numel_hint <= 8192: |
| return 1 |
| if reduction_numel_hint * numel_hint <= min_elements_per_device: |
| split_size = min_elements_per_thread |
| elif reduction_numel_hint * numel_hint < max_elements_per_device: |
| target_blocks = num_sm * threads_per_sm // (2 * num_threads) |
| blocks_per_output = (target_blocks + numel_hint - 1) // numel_hint |
| tmp_split_size = ( |
| reduction_numel_hint + num_threads * blocks_per_output - 1 |
| ) // (num_threads * blocks_per_output) |
| divisors = sympy.divisors(reduction_numel_hint) |
| closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) |
| if abs(closest - tmp_split_size) < 30: |
| # prefer even splits, but never smaller than min_elements_per_thread |
| split_size = max(closest, min_elements_per_thread) |
| else: |
| split_size = tmp_split_size |
| else: |
| divisors = sympy.divisors(reduction_numel_hint) |
| closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) |
| if abs(closest - max_elements_per_thread) < 50: |
| # prefer even splits |
| split_size = closest |
| else: |
| split_size = max_elements_per_thread |
| return (reduction_numel_hint + split_size * num_threads - 1) // ( |
| split_size * num_threads |
| ) |
| |
| def outer_reduction_splits(reduction_numel_hint, numel_hint): |
| # TODO: the current best heuristic uses an XBLOCK (corresponding to numel_hint) of 128; |
| # extend this to even smaller numbers of outputs |
| num_warps = 8 |
| num_threads = num_warps * 32 |
| rvals_per_thread = 4 # comes from heuristics, refactor to not leak here |
| xvals_per_block = 128 |
| xblocks = (numel_hint + xvals_per_block - 1) // xvals_per_block |
| if reduction_numel_hint * numel_hint < min_elements_per_device: |
| split_size = min_elements_per_thread |
| elif reduction_numel_hint * numel_hint < max_elements_per_device: |
| target_blocks = num_sm * threads_per_sm // (num_threads) |
| target_blocks = (target_blocks + xblocks - 1) // xblocks |
| tmp_split_size = ( |
| reduction_numel_hint + rvals_per_thread * target_blocks - 1 |
| ) // (rvals_per_thread * target_blocks) |
| divisors = sympy.divisors(reduction_numel_hint) |
| closest = min(divisors, key=lambda x: abs(x - tmp_split_size)) |
| if abs(tmp_split_size - closest) < 20: |
| split_size = max(closest, min_elements_per_thread) |
| else: |
| split_size = tmp_split_size |
| else: |
| divisors = sympy.divisors(reduction_numel_hint) |
| closest = min(divisors, key=lambda x: abs(x - max_elements_per_thread)) |
| if abs(closest - max_elements_per_thread) < 50: |
| # prefer even splits |
| split_size = closest |
| else: |
| split_size = max_elements_per_thread |
| |
| return (reduction_numel_hint + rvals_per_thread * split_size - 1) // ( |
| rvals_per_thread * split_size |
| ) |
| |
| reduction_numel_hint = V.graph.sizevars.size_hint(reduction_numel) |
| numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges)) |
| # easy cases |
| if numel_hint == 1: |
| return ReductionHint.INNER, inner_reduction_splits( |
| reduction_numel_hint, numel_hint |
| ) |
| if ( |
| reduction_numel_hint <= min_elements_per_thread |
| or numel_hint >= num_sm * 2 * 32 |
| ): |
| return ReductionHint.DEFAULT, 1 |
| |
| r = Reduction( |
| device, |
| dst_dtype, |
| inner_fn, |
| ranges, |
| reduction_ranges, |
| reduction_type, |
| src_dtype, |
| ReductionHint.DEFAULT, |
| ) |
| |
| def get_read_indices(r): |
| cb = ComputedBuffer( |
| name=None, |
| layout=FlexibleLayout( |
| device=r.get_device(), |
| dtype=r.get_dtype(), |
| size=r.get_size(), |
| ), |
| data=r, |
| ) |
| read_writes = cb.get_read_writes() |
| # try finding the full size producer |
| # TODO this will fail for something like ((1, N) * (N, 1)).sum() |
| # this could also be wrong for producers with different contiguity, but we hope those cases are rare |
| range_vars = [ |
| r |
| for r in read_writes.range_vars |
| if isinstance(r, sympy.Expr) and not isinstance(r, sympy.Number) |
| ] |
| indices = [] |
| changed = False |
| for md in sorted(read_writes.reads, key=lambda x: x.name): |
| if all(r in md.index.free_symbols for r in range_vars): |
| indices.append(md.index) |
| if md.name in V.graph.name_to_buffer: |
| buf = V.graph.name_to_buffer[md.name] |
| original_stride = buf.layout.stride |
| buf.decide_layout() |
| if buf.layout.stride != original_stride: |
| changed = True |
| return indices, changed |
| |
| indices, changed = get_read_indices(r) |
| if changed: |
| indices, _ = get_read_indices(r) |
| |
| if len(indices) == 0: |
| # TODO determine splits when all inputs are broadcast |
| return ReductionHint.DEFAULT, 1 |
| |
| (_, reduction_vars), ranges = dependencies.index_vars_squeeze( |
| r.get_size(), r.get_reduction_size() |
| ) |
| num_outer = 0 |
| num_inner = 0 |
| for i in indices: |
| i = V.graph.sizevars.simplify_with_ranges(i, ranges) |
| strides = V.graph.sizevars.stride_hints(i, reduction_vars, ranges.keys()) |
| outer = all(s > 1 for s in strides) |
| if outer: |
| num_outer += 1 |
| else: |
| num_inner += 1 |
| if num_inner > num_outer: |
| return ReductionHint.INNER, inner_reduction_splits( |
| reduction_numel_hint, numel_hint |
| ) |
| else: |
| return ReductionHint.OUTER, outer_reduction_splits( |
| reduction_numel_hint, numel_hint |
| ) |
| |
| @staticmethod |
| def _unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type): |
| """Convert inner_fn from a reduction to a pointwise""" |
| reduction_ranges = [ |
| V.graph.sizevars.guard_static_shape(x) for x in reduction_ranges |
| ] |
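| # Illustrative: for reduction_ranges == [4] and reduction_type == "sum", the |
| # returned fn(index) evaluates roughly to |
| #   ops.add(ops.add(ops.add(inner_fn(index, (0,)), inner_fn(index, (1,))), |
| #           inner_fn(index, (2,))), inner_fn(index, (3,))) |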
| |
| if reduction_type == "sum": |
| |
| def combine_fn(a, b): |
| return ops.add(a, b) |
| |
| elif reduction_type == "prod": |
| |
| def combine_fn(a, b): |
| return ops.mul(a, b) |
| |
| elif reduction_type == "xor_sum": |
| |
| def combine_fn(a, b): |
| return ops.bitwise_xor(a, b) |
| |
| elif reduction_type == "min": |
| |
| def combine_fn(a, b): |
| return ops.minimum(a, b) |
| |
| elif reduction_type == "max": |
| |
| def combine_fn(a, b): |
| return ops.maximum(a, b) |
| |
| elif reduction_type == "any": |
| |
| def combine_fn(a, b): |
| return ops.logical_or(a, b) |
| |
| elif reduction_type == "argmin": |
| |
| def combine_fn(a, b): |
| a_value, a_index = a |
| b_value, b_index = b |
| mask = ops.lt(b_value, a_value) |
| a_isnan = ops.ne(a_value, a_value) |
| b_isnan = ops.ne(b_value, b_value) |
| mask = ops.logical_or(mask, ops.gt(b_isnan, a_isnan)) |
| |
| return ( |
| ops.where(mask, b_value, a_value), |
| ops.where(mask, b_index, a_index), |
| ) |
| |
| elif reduction_type == "argmax": |
| |
| def combine_fn(a, b): |
| a_value, a_index = a |
| b_value, b_index = b |
| mask = ops.gt(b_value, a_value) |
| a_isnan = ops.ne(a_value, a_value) |
| b_isnan = ops.ne(b_value, b_value) |
| mask = ops.logical_or(mask, ops.gt(b_isnan, a_isnan)) |
| |
| return ( |
| ops.where(mask, b_value, a_value), |
| ops.where(mask, b_index, a_index), |
| ) |
| |
| else: |
| raise NotImplementedError(f"unknown reduction_type={reduction_type}") |
| |
| def fn(index): |
| return functools.reduce( |
| combine_fn, |
| ( |
| value_fn(index, rindex) |
| for rindex in itertools.product( |
| *[range(x) for x in reduction_ranges] |
| ) |
| ), |
| ) |
| |
| if reduction_type in ("argmin", "argmax"): |
| flatten_index = FixedLayout( |
| None, |
| None, |
| reduction_ranges, |
| FlexibleLayout.contiguous_strides(reduction_ranges), |
| ).make_indexer() |
| |
| def value_fn(index, rindex): |
| rindex = [sympy.expand(i) for i in rindex] |
| return ( |
| inner_fn(index, rindex), |
| ops.index_expr(flatten_index(rindex), torch.int64), |
| ) |
| |
| return lambda index: fn(index)[1] |
| else: |
| value_fn = inner_fn |
| return fn |
| |
| @classmethod |
| def create( |
| cls, |
| device: torch.device, |
| dst_dtype: torch.dtype, |
| src_dtype: torch.dtype, |
| inner_fn: Callable, |
| ranges: List[Expr], |
| reduction_ranges: List[Expr], |
| reduction_type: str, |
| reduction_hint: ReductionHint = ReductionHint.DEFAULT, |
| ): |
| reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges)) |
| |
| if reduction_numel == 0: |
| # N.B. This is a hack to generate the literal of the given type |
| # Ideally, we should be fixing `def constant` in triton.py |
| # but it breaks due to hardcoded dtypes in other places |
| def py_cnst(val): |
| return ( |
| bool(val) |
| if dst_dtype == torch.bool |
| else float(val) |
| if dst_dtype.is_floating_point |
| else int(val) |
| ) |
| |
| rtypes_to_inits = { |
| "sum": py_cnst(0), |
| "xor_sum": py_cnst(0), |
| "prod": py_cnst(1), |
| "any": py_cnst(0), |
| # "all" is desugared to `!any(!val)` |
| } |
| |
| assert ( |
| reduction_type in rtypes_to_inits.keys() |
| ), f"{reduction_type} not supported for zero-dimension tensors!" |
| |
| def const_fn(index): |
| return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) |
| |
| return Pointwise.create( |
| device=device, |
| dtype=src_dtype, |
| inner_fn=const_fn, |
| ranges=list(ranges), |
| ) |
| |
| if reduction_numel == 1: |
| # this reduction is actually a pointwise op |
| if reduction_type in ("argmin", "argmax"): |
| |
| def fn(index): |
| return ops.constant(0, dst_dtype) |
| |
| else: |
| |
| def fn(index): |
| reduction_index = [sympy.Integer(0) for _ in reduction_ranges] |
| return inner_fn(index, reduction_index) |
| |
| return Pointwise.create(device, dst_dtype, fn, ranges) |
| |
| if ( |
| isinstance(reduction_numel, sympy.Integer) |
| and V.graph.sizevars.size_hint(reduction_numel) |
| < config.unroll_reductions_threshold |
| and sympy_product(ranges) != 1 |
| ): |
| return Pointwise.create( |
| device, |
| dst_dtype, |
| cls._unroll_reduction_fn(inner_fn, reduction_ranges, reduction_type), |
| ranges, |
| ) |
| |
| split_reduction = ( |
| is_triton(device) |
| and reduction_type |
| not in { |
| "argmax", |
| "argmin", |
| } |
| and config.split_reductions |
| ) |
| if split_reduction: |
| # TODO(voz): dedup with sizevar util introduced in other PR |
| def _is_static(x): |
| return isinstance(x, (int, sympy.Integer)) |
| |
| split_reduction = ( |
| all(_is_static(r) for r in ranges) |
| and all(_is_static(r) for r in reduction_ranges) |
| and _is_static(reduction_numel) |
| ) |
| |
| if split_reduction: |
| # triton doesn't support reduce to single element well, so break it up |
| hint, split = cls.num_splits( |
| device, |
| dst_dtype, |
| src_dtype, |
| inner_fn, |
| ranges, |
| reduction_ranges, |
| reduction_type, |
| reduction_numel, |
| ) |
| # intermediate reduction in split can contain complex indexing, |
| # and num_splits will fail to correctly set the hint |
| # reuse the passed hint if available |
| if reduction_hint == ReductionHint.DEFAULT: |
| reduction_hint = hint |
| if split > 1: |
| # triton doesn't support reduce to single element well, so break it up |
| return cls.create_multilayer( |
| device, |
| dst_dtype, |
| src_dtype, |
| inner_fn, |
| ranges, |
| reduction_ranges, |
| reduction_type, |
| split, |
| reduction_hint, |
| ) |
| |
| return TensorBox.create( |
| Reduction( |
| device, |
| dst_dtype, |
| inner_fn, |
| ranges, |
| reduction_ranges, |
| reduction_type, |
| src_dtype, |
| reduction_hint, |
| ) |
| ) |
| |
| @staticmethod |
| def default_value(reduction_type, dtype): |
| if reduction_type in {"max", "argmax"}: |
| if is_float_dtype(dtype): |
| return float("-inf") |
| elif is_boolean_dtype(dtype): |
| return 0 |
| else: |
| return torch.iinfo(dtype).min |
| if reduction_type in {"min", "argmin"}: |
| if is_float_dtype(dtype): |
| return float("inf") |
| elif is_boolean_dtype(dtype): |
| return 1 |
| else: |
| return torch.iinfo(dtype).max |
| |
| return { |
| "sum": 0, |
| "prod": 1, |
| "xor_sum": 0, |
| "any": 0, |
| }[reduction_type] |
| |
| @classmethod |
| def create_multilayer( |
| cls, |
| device: torch.device, |
| dst_dtype: torch.dtype, |
| src_dtype: torch.dtype, |
| inner_fn: Callable, |
| ranges: List[Expr], |
| reduction_ranges: List[Expr], |
| reduction_type: str, |
| split: int, |
| reduction_hint: ReductionHint, |
| ): |
| """ |
| Break a large reduction up into multiple smaller reductions |
| recursively |
| """ |
| reduction_numel = sympy_product(reduction_ranges) |
| |
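| # Illustrative sketch (shapes assumed): summing 65536 elements with split == 64 |
| # first computes an intermediate reduction of shape [*ranges, 64], each entry |
| # reducing over block_size == 1024 elements, then reduces that intermediate over [64]. |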
| # TODO(jansel): convert this to dynamic shapes |
| # TODO(jansel): realize the reduction so we can do dynamic indexing |
| reduction_ranges = [ |
| sympy.Integer(V.graph.sizevars.guard_static_shape(s)) |
| for s in reduction_ranges |
| ] |
| reduction_numel = sympy.Integer( |
| V.graph.sizevars.guard_static_shape(reduction_numel) |
| ) |
| |
| if V.graph.sizevars.size_hint(reduction_numel) % split == 0: |
| need_mask = False |
| else: |
| need_mask = True |
| |
| split = sympy.Integer(split) |
| block_size = FloorDiv(reduction_numel + (split - 1), split) |
| |
| reindex = View.dynamic_reshape_indexer(reduction_ranges, [reduction_numel]) |
| |
| def wrapper_fn(index, reduction_index): |
| (reduction_index,) = reduction_index |
| *new_index, reduction_block = index |
| indices = block_size * reduction_block + reduction_index |
| |
| def body(): |
| return inner_fn(new_index, reindex([indices])) |
| |
| if need_mask: |
| mask = ops.lt( |
| ops.index_expr(indices, torch.int32), |
| ops.index_expr(reduction_numel, torch.int32), |
| ) |
| return ops.masked( |
| mask, body, cls.default_value(reduction_type, dst_dtype) |
| ) |
| else: |
| return body() |
| |
| # triton will automatically compute reductions in fp32 if reducing over fp16/bf16 |
| # within the kernel. keep the intermediate in fp32 so as to keep the whole reduction |
| # in fp32 and not reduce precision by breaking up the kernel into multiple layers |
| intermediate_dtype = ( |
| dst_dtype |
| if dst_dtype not in (torch.float16, torch.bfloat16) |
| else torch.float |
| ) |
| intermediate = Reduction.create( |
| device, |
| intermediate_dtype, |
| src_dtype, |
| wrapper_fn, |
| [*ranges, split], |
| [block_size], |
| reduction_type, |
| reduction_hint, |
| ) |
| intermediate.realize() |
| intermediate_loader = intermediate.make_loader() |
| |
| def intermediate_fn(index, reduction_index): |
| return intermediate_loader([*index, *reduction_index]) |
| |
| numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges)) |
| if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER: |
| reduction_hint = ReductionHint.OUTER_TINY |
| if ( |
| split <= 1024 |
| and numel_hint <= 256 |
| and reduction_hint == ReductionHint.OUTER |
| ): |
| reduction_hint = ReductionHint.OUTER_TINY |
| return TensorBox.create( |
| Reduction( |
| device, |
| dst_dtype, |
| intermediate_fn, |
| ranges, |
| [split], |
| reduction_type, |
| src_dtype, |
| reduction_hint, |
| ) |
| ) |
| |
| |
| def is_storage_and_layout(x): |
| try: |
| as_storage_and_layout(x, freeze=False) |
| return True |
| except NotImplementedError: |
| return False |
| |
| |
| def is_contiguous_storage_and_layout(x): |
| try: |
| buffer, layout = as_storage_and_layout(x, freeze=False) |
| return layout.is_contiguous() |
| except NotImplementedError: |
| return False |
| |
| |
| def as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=None): |
| """Try to simplify x into a StorageBox and a Layout""" |
| if isinstance(x, TensorBox): |
| return as_storage_and_layout( |
| x.data, |
| freeze=freeze, |
| want_contiguous=want_contiguous, |
| stride_order=stride_order, |
| ) |
| if isinstance(x, StorageBox) and isinstance(x.data, Buffer): |
| if freeze: |
| if want_contiguous: |
| x.data.freeze_layout() |
| assert x.data.layout.is_contiguous() |
| elif stride_order is not None: |
| x.data.freeze_layout_with_stride_order(stride_order) |
| else: |
| x.data.decide_layout() |
| return x, x.data.layout |
| if isinstance(x, ReinterpretView): |
| # making the base of x contiguous or stride_ordered will not necessarily make |
| # the ReinterpretView contiguous or stride_ordered either, so don't pass along those arguments |
| buffer, _ = as_storage_and_layout( |
| x.data, |
| freeze=freeze, |
| ) |
| return buffer, x.layout |
| raise NotImplementedError |
| |
| |
| as_contiguous_storage_and_layout = functools.partial( |
| as_storage_and_layout, want_contiguous=True |
| ) |
| |
| |
| def is_stride_order_storage_and_layout(x, stride_order): |
| try: |
| buffer, layout = as_storage_and_layout(x, freeze=False) |
| return layout.is_stride_ordered(stride_order) |
| except NotImplementedError: |
| return False |
| |
| |
| @dataclasses.dataclass |
| class BaseView(IRNode): |
| data: IRNode |
| |
| def get_dtype(self): |
| return self.data.get_dtype() |
| |
| def get_device(self): |
| return self.data.get_device() |
| |
| def get_origin_node(self): |
| return None |
| |
| def get_name(self): |
| return self.data.get_name() |
| |
| def mark_reuse(self, users): |
| return self.data.mark_reuse(users) |
| |
| def has_exceeded_max_reads(self): |
| return self.data.has_exceeded_max_reads() |
| |
| def realize(self): |
| return self.data.realize() |
| |
| def realize_hint(self): |
| return self.data.realize_hint() |
| |
| def get_storage_numel(self): |
| return self.data.get_storage_numel() |
| |
| def is_extern(self): |
| return self.data.is_extern() |
| |
| @cache_on_self |
| def get_reads(self): |
| with patch.object(FlexibleLayout, "allow_indexing", True): |
| return extract_read_writes( |
| self.make_loader(), |
| self.get_size(), |
| ).reads |
| |
| def unwrap_view(self): |
| x = self |
| while isinstance(x, BaseView): |
| x = x.data |
| return x |
| |
| def constant_to_device(self, device): |
| """Move this to a given device. Requires that all reads are to constants.""" |
| loader = self.make_loader() |
| loader = patch.object(ConstantBuffer, "override_device", device)(loader) |
| return Pointwise(device, self.get_dtype(), loader, self.get_size()) |
| |
| |
| @dataclasses.dataclass |
| class ExpandView(BaseView): |
| size: List[Expr] |
| |
| @staticmethod |
| def _normalize_size(x, new_size): |
| """Replace `-1` with correct sizes""" |
| new_size = list(map(sympy.expand, new_size)) |
| old_size = x.get_size() |
| old_size = [None] * (len(new_size) - len(old_size)) + list(old_size) |
| assert len(new_size) == len(old_size) |
| for i in range(len(new_size)): |
| if new_size[i] == -1: |
| assert old_size[i] is not None |
| new_size[i] = old_size[i] |
| return new_size |
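| # Illustrative: for x of size [4, 1], _normalize_size(x, [3, -1, -1]) == [3, 4, 1]. |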
| |
| @classmethod |
| def create(cls, x, new_size): |
| new_size = cls._normalize_size(x, new_size) |
| |
| if is_storage_and_layout(x): |
| storage, old_layout = as_storage_and_layout(x) |
| skip = len(new_size) - len(old_layout.size) |
| assert skip >= 0 |
| new_stride = [sympy.Integer(0)] * skip |
| for stride, size in zip(old_layout.stride, old_layout.size): |
| new_stride.append(stride if size != 1 else sympy.Integer(0)) |
| new_layout = FixedLayout( |
| old_layout.device, |
| old_layout.dtype, |
| list(new_size), |
| new_stride, |
| old_layout.offset, |
| ) |
| return ReinterpretView(storage, new_layout) |
| |
| return ExpandView(x, new_size) |
| |
| def get_size(self): |
| return self.size |
| |
| def make_loader(self): |
| target = self.get_size() |
| actual = self.data.get_size() |
| skip = len(target) - len(actual) |
| inner = self.data.make_loader() |
| |
| def load(index): |
| index = list(index[skip:]) |
| assert len(index) == len(actual) |
| for i in range(len(actual)): |
| if actual[i] == 1: |
| # zero out broadcast dimension |
| index[i] = sympy.Integer(0) |
| return inner(index) |
| |
| return load |
| |
| |
| @dataclasses.dataclass |
| class PermuteView(BaseView): |
| dims: List[Expr] |
| |
| @classmethod |
| def create(cls, x, dims): |
| dims = cls._map_neg_dims(dims) |
| assert set(dims) == set(range(len(dims))) |
| |
| if is_storage_and_layout(x): |
| storage, old_layout = as_storage_and_layout(x) |
| new_layout = FixedLayout( |
| old_layout.device, |
| old_layout.dtype, |
| [old_layout.size[i] for i in dims], |
| [old_layout.stride[i] for i in dims], |
| old_layout.offset, |
| ) |
| return ReinterpretView(storage, new_layout) |
| |
| return PermuteView(x, dims) |
| |
| @classmethod |
| def _map_neg_dims(cls, dims): |
| return [dim if dim >= 0 else len(dims) + dim for dim in dims] |
| |
| def get_size(self): |
| assert set(self._map_neg_dims(self.dims)) == set(range(len(self.dims))) |
| size = self.data.get_size() |
| return [size[i] for i in self.dims] |
| |
| def make_loader(self): |
| inner = self.data.make_loader() |
| inv = {j: i for i, j in enumerate(self.dims)} |
| inv = [inv[i] for i in range(len(self.dims))] |
| assert set(inv) == set(range(len(self.dims))) |
| |
| def load(index): |
| index = [index[i] for i in inv] |
| return inner(index) |
| |
| return load |
| |
| |
| class SqueezeView(BaseView): |
| @classmethod |
| def create(cls, x, *, dim=None): |
| if is_storage_and_layout(x): |
| storage, old_layout = as_storage_and_layout(x) |
| new_size = [] |
| new_stride = [] |
| if dim is not None: |
| assert isinstance(dim, int), "expected integer dim argument" |
| assert 0 <= dim and dim < len(old_layout.size) |
| |
| for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)): |
| if dim is None: |
| if size != 1: |
| new_size.append(size) |
| new_stride.append(stride) |
| else: |
| if i != dim: |
| new_size.append(size) |
| new_stride.append(stride) |
| else: |
| assert size == 1, "expected squeezed size to be 1" |
| |
| new_layout = FixedLayout( |
| old_layout.device, |
| old_layout.dtype, |
| new_size, |
| new_stride, |
| old_layout.offset, |
| ) |
| return ReinterpretView(storage, new_layout) |
| |
| if dim is None: |
| # redirect to a generic view |
| return View.create(x, [s for s in x.get_size() if s != 1]) |
| else: |
| assert x.get_size()[dim] == 1 |
| return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim]) |
| |
| @staticmethod |
| def squeezer(size: Tuple[sympy.Expr, ...]): |
| new_size = [s for s in size if s != 1] |
| not_one = [i for i, s in enumerate(size) if s != 1] |
| length = len(size) |
| |
| def reindex(index: List[sympy.Expr]) -> List[sympy.Expr]: |
| assert len(index) == len(not_one), f"{index} {not_one}" |
| new_index = [sympy.Integer(0)] * length |
| for idx, s in zip(not_one, index): |
| new_index[idx] = s |
| return tuple(new_index) |
| |
| return new_size, reindex |
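| # Illustrative: squeezer((1, 5, 1, 7)) returns new_size == [5, 7] and a reindex |
| # such that reindex([i, j]) == (0, i, 0, j). |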
| |
| def __init__(self, data): |
| raise AssertionError("use SqueezeView.create()") |
| |
| |
| @dataclasses.dataclass |
| class View(BaseView): |
| size: List[Expr] |
| reindex: Callable |
| |
| def make_indexer(self): |
| base_indexer = self.data.make_indexer() |
| |
| def indexer(idx): |
| return base_indexer(self.reindex(idx)) |
| |
| return indexer |
| |
| @staticmethod |
| def handle_negative_index(idx, size): |
| idx = sympy.expand(idx) |
| size = sympy.expand(size) |
| sizevars = V.graph.sizevars |
| if sizevars.size_hint(idx) < 0: |
| sizevars.guard_lt(idx, 0) |
| idx = idx + size |
| return idx |
| |
| def reindex_str(self): |
| index_old = [sympy_symbol(f"i{n}") for n in range(len(self.size))] |
| index_new = list(self.reindex(index_old)) |
| return f"lambda {', '.join(map(str, index_old))}: {index_new}" |
| |
| def __str__(self): |
| return self.str_helper( |
| [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"] |
| ) |
| |
| __repr__ = __str__ |
| |
| @classmethod |
| def create(cls, x, new_size): |
| assert isinstance(new_size, (tuple, list)) |
| old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size) |
| |
| # Skip pointless views |
| if V.graph.sizevars.statically_known_list_equals(old_size, new_size): |
| return x |
| |
| if 0 in new_size and is_storage_and_layout(x): |
| storage, old_layout = as_storage_and_layout(x, freeze=False) |
| new_layout = FixedLayout( |
| old_layout.device, |
| old_layout.dtype, |
| new_size, |
| FlexibleLayout.contiguous_strides(new_size), |
| old_layout.offset, |
| ) |
| return ReinterpretView(storage, new_layout) |
| # TODO: add a new FixedTransferLayout class whose output layout is constrained by the input layout |
| elif is_contiguous_storage_and_layout(x) and not isinstance( |
| x.data, ExternKernelAlloc |
| ): |
| storage, old_layout = as_contiguous_storage_and_layout(x) |
| new_layout = FixedLayout( |
| old_layout.device, |
| old_layout.dtype, |
| new_size, |
| FlexibleLayout.contiguous_strides(new_size), |
| old_layout.offset, |
| ) |
| return ReinterpretView(storage, new_layout) |
| |
| reindex = cls.dynamic_reshape_indexer(old_size, new_size) |
| return cls(x, tuple(new_size), reindex) |
| |
| @staticmethod |
| def resolve_negative_size(old_size, new_size): |
| new_size = [V.graph.sizevars.simplify(x) for x in new_size] |
| old_size = [V.graph.sizevars.simplify(x) for x in old_size] |
| |
| new_size = list(new_size) |
| for i in range(len(new_size)): |
| if new_size[i] == -1: |
| new_size[i] = sympy.Integer(1) |
| new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size)) |
| break |
| |
| V.graph.sizevars.guard_equals(sympy_product(old_size), sympy_product(new_size)) |
| return old_size, new_size |
| |
| @classmethod |
| def dynamic_reshape_indexer(cls, old_size, new_size): |
| try: |
| reindex = cls._dynamic_reshape_indexer(old_size, new_size) |
| except (AssertionError, IndexError): |
| # optimistic algorithm failed, let's do a fallback |
| flat = [sympy_product(old_size)] |
| reindex1 = cls._dynamic_reshape_indexer(old_size, flat) |
| reindex2 = cls._dynamic_reshape_indexer(flat, new_size) |
| reindex = fuse_reindexing(reindex1, reindex2) |
| return reindex |
| |
| @staticmethod |
| def _dynamic_reshape_indexer(old_size, new_size): |
| """ |
| Perform a reshape entirely by modifying indexing math |
| """ |
| size_hint = V.graph.sizevars.size_hint |
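| # Illustrative sketch: for old_size == [2, 6] and new_size == [2, 3, 2] the |
| # returned reindex maps (i0, i1, i2) -> (i0, 2*i1 + i2), re-linearizing the new |
| # trailing dims into the old innermost dim. |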
| vars = [sympy_symbol(f"view{i}") for i in range(len(new_size))] |
| |
| stack_new = list(zip(vars, new_size)) |
| stack_old = list(old_size) |
| |
| view_expr = [] |
| while stack_new and stack_old: |
| size_old = stack_old.pop() |
| var, size_new = stack_new.pop() |
| if size_old == 1: |
| view_expr.append(sympy.Integer(0)) |
| stack_new.append((var, size_new)) # re-add |
| elif size_new == 1: |
| stack_old.append(size_old) # re-add |
| elif size_hint(size_new) == size_hint(size_old): |
| view_expr.append(var) |
| V.graph.sizevars.guard_equals(size_new, size_old) |
| elif size_hint(size_new) < size_hint(size_old): |
| while size_hint(size_new) < size_hint(size_old): |
| var2, size_new2 = stack_new.pop() |
| var = var2 * size_new + var |
| size_new = size_new * size_new2 |
| view_expr.append(var) |
| V.graph.sizevars.guard_equals(size_new, size_old) |
| elif size_hint(size_new) > size_hint(size_old): |
| divisor = sympy.Integer(1) |
| modulus = size_old |
| view_expr.append(ModularIndexing(var, divisor, modulus)) |
| divisor = divisor * modulus |
| while size_hint(size_new) > size_hint(size_old): |
| modulus = stack_old.pop() |
| view_expr.append(ModularIndexing(var, divisor, modulus)) |
| divisor = divisor * modulus |
| size_old = size_old * modulus |
| V.graph.sizevars.guard_equals(size_new, size_old) |
| else: |
| raise AssertionError() |
| |
| while stack_old: |
| size_old = stack_old.pop() |
| V.graph.sizevars.guard_equals(size_old, 1) |
| view_expr.append(sympy.Integer(0)) |
| |
| while stack_new: |
| var, size_new = stack_new.pop() |
| V.graph.sizevars.guard_equals(size_new, 1) |
| |
| view_expr = list(reversed(view_expr)) |
| assert len(view_expr) == len(old_size) |
| |
| def reindex(index): |
| assert len(index) == len(vars), (len(index), len(vars)) |
| replacements = dict(zip(vars, index)) |
| return tuple(sympy_subs(x, replacements) for x in view_expr) |
| |
| return reindex |
| |
| def get_size(self): |
| return self.size |
| |
| def make_loader(self): |
| def load(index): |
| return inner(self.reindex(index)) |
| |
| inner = self.data.make_loader() |
| return load |
| |
| |
| @dataclasses.dataclass |
| class ReinterpretView(BaseView): |
| """Pretend our storage has a different layout""" |
| |
| layout: "Layout" |
| |
| def __post_init__(self): |
| super().__post_init__() |
| if isinstance(self.data, BaseView): |
| self.data = self.data.unwrap_view() |
| |
| def __str__(self): |
| return self.str_helper( |
| [ |
| self.data, |
| self.layout, |
| ] |
| ) |
| |
| __repr__ = __str__ |
| |
| def get_name(self): |
| return self.data.get_name() |
| |
| def get_device(self): |
| return self.layout.device |
| |
| def get_origin_node(self): |
| return None |
| |
| def get_dtype(self): |
| return self.layout.dtype |
| |
| def get_size(self): |
| return list(self.layout.size) |
| |
| def get_stride(self): |
| return list(self.layout.stride) |
| |
| def make_loader(self): |
| def loader(index): |
| indexer = self.layout.make_indexer() |
| return ops.load(self.get_name(), indexer(index)) |
| |
| return loader |
| |
| def make_indexer(self): |
| return self.layout.make_indexer() |
| |
| def get_layout(self): |
| return self.layout |
| |
| def freeze_layout(self): |
| pass |
| |
| def codegen_reference(self): |
| size = V.graph.wrapper_code.codegen_shape_tuple(self.layout.size) |
| stride = V.graph.wrapper_code.codegen_shape_tuple(self.layout.stride) |
| offset = V.graph.wrapper_code.codegen_sizevar(self.layout.offset) |
| namespace = V.graph.wrapper_code.namespace |
| if offset != "0": |
| return ( |
| f"{namespace}as_strided({self.get_name()}, {size}, {stride}, {offset})" |
| ) |
| return f"{namespace}as_strided({self.get_name()}, {size}, {stride})" |
| |
| |
| class SliceView(View): |
| @classmethod |
| def create(cls, x, dim, start, end, step=1): |
| step = sympy.expand(step) |
| assert step > 0 |
| try: |
| if start == 0 and end >= 2**63 - 1 and step == 1: |
| return x |
| except TypeError: |
| pass |
| |
| sizevars = V.graph.sizevars |
| new_size = list(x.get_size()) |
| |
| start = cls.handle_negative_index(start, new_size[dim]) |
| end = cls.handle_negative_index(end, new_size[dim]) |
| |
| end = sizevars.guard_min(end, new_size[dim]) |
| start = sizevars.guard_min(sizevars.guard_min(start, new_size[dim]), end) |
| if start == 0 and sizevars.size_hint(end - new_size[dim]) == 0 and step == 1: |
| sizevars.guard_equals(end, new_size[dim]) |
| return x |
| |
| new_size[dim] = FloorDiv(end - start + (step - 1), step) |
| |
| if is_storage_and_layout(x): |
| # Fast path |
| storage, old_layout = as_storage_and_layout(x) |
| new_stride = list(old_layout.stride) |
| new_stride[dim] = new_stride[dim] * step |
| new_layout = FixedLayout( |
| old_layout.device, |
| old_layout.dtype, |
| new_size, |
| new_stride, |
| old_layout.offset + old_layout.stride[dim] * start, |
| ) |
| return ReinterpretView(storage, new_layout) |
| |
| def reindex(index): |
| assert len(index) == len(new_size), f"wrong ndim {index} {new_size}" |
| index = list(index) |
| index[dim] = index[dim] * step + start |
| return index |
| |
| # redirect to a generic view |
| return SliceView(x, size=new_size, reindex=reindex) |
| |
| |
| class BaseConstant(IRNode): |
| def get_size(self): |
| return () |
| |
| def get_dtype(self): |
| return self.dtype |
| |
| def get_device(self): |
| return self.device |
| |
| def get_origin_node(self): |
| return None |
| |
| def mark_reuse(self, users): |
| pass |
| |
| def has_exceeded_max_reads(self): |
| return False |
| |
| def get_reads(self): |
| return () |
| |
| def is_extern(self): |
| return False |
| |
| |
| @dataclasses.dataclass |
| class Constant(BaseConstant): |
| value: Any |
| dtype: torch.dtype |
| device: torch.device |
| |
| def make_loader(self): |
| def loader(index): |
| return ops.constant(self.value, self.dtype) |
| |
| return loader |
| |
| def realize(self): |
| pass |
| |
| |
| @dataclasses.dataclass |
| class IndexingConstant(BaseConstant): |
| index: Any |
| dtype: torch.dtype |
| device: torch.device |
| |
| def make_loader(self): |
| def loader(index): |
| return ops.index_expr(self.index, self.dtype) |
| |
| return loader |
| |
| |
| @dataclasses.dataclass |
| class Layout(IRNode): |
| def __init__( |
| self, |
| device: torch.device, |
| dtype: torch.dtype, |
| size: List[Expr], |
| stride: List[Expr], |
| offset: Expr = Integer(0), |
| ): |
| assert stride is None or len(size) == len( |
| stride |
| ), f"size={size}, stride={stride}" |
| self.device = device |
| self.dtype = dtype |
| assert all(isinstance(s, (Expr, int)) for s in size) |
| self.size = size |
| self._stride = stride |
| self.offset = offset |
| |
| @property |
| def stride(self): |
| return self._stride |
| |
| def __str__(self): |
| offset = "" |
| if self.offset != 0: |
| offset = f", offset={self.offset}" |
| return ( |
| f"{type(self).__name__}('{self.device.type}', {self.dtype}, " |
| f"size={self.size}, stride={self.stride}{offset})" |
| ) |
| |
| __repr__ = __str__ |
| |
| def is_contiguous(self): |
| for left, right, size in zip( |
| self.stride, FlexibleLayout.contiguous_strides(self.size), self.size |
| ): |
| if size != 1 and left != right: |
| return False |
| return True |
| |
| def is_channels_last_contiguous(self): |
| ndim = len(self.size) |
| if ndim not in [4, 5]: |
| return False |
| for left, right, size in zip( |
| self.stride, make_channels_last_strides_for(self.size), self.size |
| ): |
| if size != 1 and left != right: |
| return False |
| return True |
| |
| def is_transposed(self): |
| for left, right, size in zip( |
| self.stride, |
| reversed(FlexibleLayout.contiguous_strides(self.size)), |
| self.size, |
| ): |
| if size != 1 and left != right: |
| return False |
| return True |
| |
| def is_stride_ordered(self, order): |
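| # e.g. channels-last NCHW strides [60, 1, 15, 3] (sizes assumed (2, 3, 4, 5)) |
| # reordered by order == [3, 0, 2, 1] give [1, 3, 15, 60], which is ascending, |
| # so is_stride_ordered([3, 0, 2, 1]) would return True for that layout. |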
| assert len(self.stride) == len(order) |
| # reorder the stride given order |
| stride_ordered = [None] * len(order) |
| for i in range(len(order)): |
| stride_ordered[order[i]] = V.graph.sizevars.size_hint(self.stride[i]) |
| # check if it is in ascending order |
| for i in range(len(order) - 1): |
| if stride_ordered[i] > stride_ordered[i + 1]: |
| return False |
| return True |
| |
| def is_channels_last_stride_ordered(self): |
| # create the channels_last stride order (for NCHW/NCDHW, C has the smallest stride). |
| order = [0] + list(reversed(range(1, len(self.stride) - 1))) |
| order = [len(order)] + order |
| return self.is_stride_ordered(order) |
| |
| def as_fixed(self): |
| return FixedLayout( |
| self.device, |
| self.dtype, |
| self.size, |
| self.stride, |
| self.offset, |
| ) |
| |
| def make_indexer(self): |
| assert ( |
| FlexibleLayout.allow_indexing |
| ), f"convert {type(self).__name__} to FixedLayout first" |
| return self.as_fixed().make_indexer() |
| |
| def __eq__(self, other) -> bool: |
| return ( |
| self.device == other.device |
| and self.dtype == other.dtype |
| and self.size == other.size |
| and self.stride == other.stride |
| and self.offset == other.offset |
| ) |
| |
| def storage_size(self) -> sympy.Expr: |
| return compute_required_storage_length(self.size, self.stride, self.offset) |
| |
| |
| class FixedLayout(Layout): |
| """A Tensor layout we cannot change""" |
| |
| def __init__( |
| self, |
| device: torch.device, |
| dtype: torch.dtype, |
| size: List[Expr], |
| stride: List[Expr] = None, |
| offset: Expr = Integer(0), |
| ): |
| if stride is None: |
| stride = FlexibleLayout.contiguous_strides(size) |
| super().__init__( |
| device, |
| dtype, |
| size, |
| stride, |
| offset, |
| ) |
| |
| def make_indexer(self): |
| """A closure containing math to read a given element""" |
| |
| def indexer(index): |
| assert len(index) == len(self.stride) == len(self.size) |
| result = self.offset |
| for idx, stride, sz in zip(index, self.stride, self.size): |
| if sz != 1: |
| result = result + idx * stride |
| return result |
| |
| return indexer |
| |
| |
| class FlexibleLayout(Layout): |
| """A Tensor layout we are allowed to change""" |
| |
| allow_indexing = False |
| |
| @staticmethod |
| def contiguous_strides(sizes): |
| if len(sizes) == 0: |
| return [] |
| reversed_strides = [sympy.Integer(1)] |
| for size in reversed(sizes[1:]): |
| reversed_strides.append(size * reversed_strides[-1]) |
| return list(reversed(reversed_strides)) |
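| # Illustrative: contiguous_strides([2, 3, 4]) == [12, 4, 1]. |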
| |
| @staticmethod |
| def fill_ordered(sizes, order): |
| """ |
| Create a stride based on the order the dimensions should be filled in. |
| |
| In this format, channels last would be: |
| [1, 3, 2, 0] |
| """ |
| assert set(range(len(sizes))) == set(order) |
| next_stride = sympy.Integer(1) |
| strides = [None] * len(order) |
| |
| for i in order: |
| strides[i] = next_stride |
| next_stride = next_stride * sizes[i] |
| return strides |
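| # Illustrative: fill_ordered([2, 3, 4, 5], [1, 3, 2, 0]) == [60, 1, 15, 3], |
| # i.e. the channels-last strides for an NCHW tensor of size (2, 3, 4, 5). |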
| |
| @staticmethod |
| def stride_ordered(sizes, order): |
| """ |
| Create a stride based on the sorted order of a permuted range. |
| |
| In this format, channels last would be: |
| [3, 0, 2, 1] |
| """ |
| assert set(range(len(sizes))) == set(order) |
| fill_order = stride_order2fill_order(order) |
| return FlexibleLayout.fill_ordered(sizes, fill_order) |
| |
| @staticmethod |
| def same_ordered(sizes, stride): |
| """ |
| Create strides that have the same stride order as the given strides |
| |
| For example, if given stride is [1000, 1, 100, 10], |
| the fill order should be [1, 3, 2, 0] |
| """ |
| assert len(sizes) == len(stride) |
| stride = [V.graph.sizevars.size_hint(x) for x in stride] |
| fill_order = sorted(range(len(stride)), key=stride.__getitem__) |
| return FlexibleLayout.fill_ordered(sizes, fill_order) |
| |
| def as_stride_order(self, order): |
| return FixedLayout( |
| self.device, |
| self.dtype, |
| self.size, |
| self.stride_ordered(self.size, order), |
| self.offset, |
| ) |
| |
| def as_fill_order(self, order): |
| return FixedLayout( |
| self.device, |
| self.dtype, |
| self.size, |
| self.fill_ordered(self.size, order), |
| self.offset, |
| ) |
| |
| def as_same_order(self, stride): |
| return FixedLayout( |
| self.device, |
| self.dtype, |
| self.size, |
| self.same_ordered(self.size, stride), |
| self.offset, |
| ) |
| |
| def __init__(self, device, dtype, size, stride_order=None): |
| if stride_order: |
| strides = FlexibleLayout.fill_ordered(size, stride_order) |
| else: |
| strides = FlexibleLayout.contiguous_strides(size) |
| super().__init__(device, dtype, size, strides) |
| |
| |
| class AliasedLayout(Layout): |
| """Shares the same storage as another tensor""" |
| |
| def __init__(self, view: "ReinterpretView"): |
| layout = view.get_layout() |
| super().__init__( |
| layout.device, |
| layout.dtype, |
| layout.size, |
| layout.stride, |
| ) |
| self.view = view |
| |
| def make_indexer(self): |
| return self.as_fixed().make_indexer() |
| |
| def maybe_guard_aligned(self): |
| offset = self.view.get_layout().offset |
| if offset == 0: |
| return True |
| from .compile_fx import ALIGNMENT |
| |
| return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT) |
| |
| |
| class MutationLayout(Layout): |
| def __init__(self, target: IRNode): |
| super().__init__( |
| target.get_device(), |
| target.get_dtype(), |
| target.get_size(), |
| None, # type: ignore[arg-type] |
| ) |
| self.target = target |
| |
| @Layout.stride.getter |
| def stride(self): |
| return self.real_layout().stride |
| |
| def storage_size(self) -> sympy.Expr: |
| return self.real_layout().storage_size() |
| |
| def real_layout(self): |
| def unwrap_views(target): |
| if isinstance(target, MutationLayout): |
| return unwrap_views(target.target) |
| if isinstance(target, BaseView): |
| return unwrap_views(target.unwrap_view()) |
| if isinstance(target, MutableBox): |
| return unwrap_views(target.data) |
| return target |
| |
| return unwrap_views(self.target).layout |
| |
| @classmethod |
| def realize_into(cls, src, dst): |
| dst.realize() |
| V.graph.realize_users_of(dst.get_name()) |
| |
| if isinstance(src, TensorBox): |
| src = src.data |
| |
| if not isinstance(src, StorageBox) or src.is_user_of(dst.get_name()): |
| need_copy = True |
| else: |
| src.realize() |
| need_copy = not isinstance(src.data.layout, FlexibleLayout) |
| |
| if need_copy: |
| src = Pointwise.create( |
| device=src.get_device(), |
| dtype=src.get_dtype(), |
| inner_fn=src.make_loader(), |
| ranges=[ |
| V.graph.sizevars.guard_equals(a, b) |
| for a, b in zip(src.get_size(), dst.get_size()) |
| ], |
| ).data |
| src.realize() |
| |
| assert isinstance(src.data.layout, FlexibleLayout) |
| src.data.layout = MutationLayout(dst) |
| return src.data |
| |
| def as_fixed(self): |
| return self |
| |
| def make_indexer(self): |
| return self.target.make_indexer() |
| |
| |
| @dataclasses.dataclass |
| class Buffer(IRNode): |
| name: str |
| layout: Layout |
| |
| def __post_init__(self): |
| super().__post_init__() |
| self.origin_node = None |
| |
| def make_indexer(self): |
| return self.layout.make_indexer() |
| |
| def get_name(self): |
| assert self.name |
| return self.name |
| |
| def get_device(self): |
| return self.layout.device |
| |
| def get_origin_node(self): |
| return self.origin_node |
| |
| def get_dtype(self): |
| return getattr(self.layout, "dtype", None) |
| |
| def get_size(self): |
| return list(self.layout.size) |
| |
| def get_stride(self): |
| return list(self.layout.stride) |
| |
| def get_layout(self): |
| return self.layout |
| |
| def get_storage_numel(self): |
| return self.get_numel() |
| |
| def is_extern(self): |
| return False |
| |
| def freeze_layout(self): |
| if not isinstance(self.layout, (MultiOutputLayout, AliasedLayout)): |
| self.layout = self.layout.as_fixed() |
| |
| def freeze_layout_with_stride_order(self, order): |
| assert isinstance(self.layout, FlexibleLayout) |
| self.layout = self.layout.as_stride_order(order) |
| |
| def freeze_layout_with_fill_order(self, order): |
| assert isinstance(self.layout, FlexibleLayout) |
| self.layout = self.layout.as_fill_order(order) |
| |
| def freeze_layout_with_same_order(self, stride): |
| assert isinstance(self.layout, FlexibleLayout) |
| self.layout = self.layout.as_same_order(stride) |
| |
| def make_loader(self): |
| def loader(index): |
| indexer = self.layout.make_indexer() |
| return ops.load(self.name, indexer(index)) |
| |
| return loader |
| |
| def is_no_op(self): |
| return False |
| |
| def codegen_reference(self): |
| return self.get_name() |
| |
| def decide_layout(self): |
| pass |
| |
| def get_alias_names(self): |
| if isinstance(self.layout, AliasedLayout): |
| return [self.layout.view.get_name()] |
| return () |
| |
| def get_mutation_names(self): |
| if isinstance(self.layout, MutationLayout): |
| return [self.layout.target.get_name()] |
| return () |
| |
| @cache_on_self |
| def get_read_writes(self): |
| with patch.object(FlexibleLayout, "allow_indexing", True): |
| return extract_read_writes( |
| self.make_loader(), |
| self.get_size(), |
| ) |
| |
| def get_reads(self): |
| return self.get_read_writes().reads |
| |
| def realize(self): |
| pass |
| |
| |
| class InputBuffer(Buffer): |
| pass |
| |
| |
| class ConstantBuffer(InputBuffer): |
| override_device = None |
| |
| def make_loader(self): |
| def loader(index): |
| indexer = self.layout.make_indexer() |
| return ops.load( |
| V.graph.constant_name(self.name, self.override_device), indexer(index) |
| ) |
| |
| return loader |
| |
| def constant_to_device(self, device): |
| return ConstantBuffer(V.graph.constant_name(self.name, device), self.layout) |
| |
| |
| class RandSeedBuffer(ConstantBuffer): |
| def codegen_reference(self): |
| # Clone makes sure that if we pass this from forwards to backwards, |
| # the value does not get clobbered by the time backwards is run. |
| return self.get_name() + ".clone()" |
| |
| |
| class NoneAsConstantBuffer(IRNode): |
| def codegen_reference(self): |
| return V.graph.wrapper_code.none_str |
| |
| |
| class ShapeAsConstantBuffer(IRNode): |
| def __init__(self, shape): |
| super().__init__() |
| self.shape = shape |
| |
| def codegen_reference(self): |
| from torch._inductor.codegen.wrapper import pexpr |
| |
| expr = pexpr(V.graph.sizevars.simplify(self.shape)) |
| if V.graph.cpp_wrapper: |
| # wrap scalar to 0-d tensor for cpp wrapper |
| return f"torch::tensor({expr})" |
| else: |
| return expr |
| |
| |
| @dataclasses.dataclass |
| class ComputedBuffer(Buffer): |
| data: Loops |
| |
| @cache_on_self |
| def get_read_writes(self): |
| with patch.object(FlexibleLayout, "allow_indexing", True): |
| if self.data.get_reduction_type(): |
| return extract_read_writes( |
| self.get_store_function(), |
| self.data.get_size(), |
| self.data.get_reduction_size(), |
| ) |
| else: |
| return extract_read_writes( |
| self.get_store_function(), |
| self.data.get_size(), |
| ) |
| |
| def get_store_function(self): |
| indexer = self.layout.as_fixed().make_indexer() |
| if self.data.get_reduction_type(): |
| return partial(self.data.store_reduction, self.name, indexer) |
| else: |
| return partial(self.data.store_output, self.name, indexer) |
| |
| def get_fill_order(self): |
| """ |
| If our layout is still flexible, try to determine the stride order based on stride orders of reads. |
| |
| TODO(jansel): A better algorithm here would look at downstream consumers of this |
| value and try to do global graph-level layout optimization. |
| This is also something just begging to be autotuned. |
| """ |
| if isinstance(self.layout, FlexibleLayout): |
| (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze( |
| self.data.get_size(), self.data.get_reduction_size() |
| ) |
| reads = self.get_read_writes().reads |
| reads_bufs = [ |
| V.graph.name_to_buffer[r.name] |
| if r.name in V.graph.name_to_buffer.keys() |
| else None |
| for r in reads |
| ] |
| # only consider reads to buffer of same size |
| reads = [ |
| sympy_subs( |
| r.index, {v: sympy.Integer(0) for v in reduction_vars if v != 0} |
| ) |
| for r in reads |
| ] |
| |
| if reads: |
| stride_lengths = [ |
| V.graph.sizevars.stride_hints(expr, index_vars) for expr in reads |
| ] |
| from .scheduler import pick_loop_order |
| |
| return pick_loop_order(stride_lengths, self.get_size()) |
| |
| return None |
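| # Rough illustration of get_fill_order: if the only read of this buffer has |
| # channels-last stride hints (e.g. [H*W*C, 1, W*C, C] for an NCHW-shaped index), |
| # pick_loop_order tends to pick a fill order that makes this output |
| # channels-last as well. This is a heuristic, not a guarantee. |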
| |
| def decide_layout(self): |
| if isinstance(self.layout, FlexibleLayout): |
| order = self.get_fill_order() |
| if order: |
| self.freeze_layout_with_fill_order(order) |
| else: |
| self.freeze_layout() |
| |
| def simplify_and_reorder(self): |
| """ |
| This is a main place where we do loop transformations in a |
| backend-agnostic way. |
| |
| Here we: |
| 1) Remove any 1 dimensions |
| 2) Fuse contiguous dimensions together |
| 3) Reorder dimensions based on stride orders |
| """ |
| args, var_ranges = dependencies.index_vars_squeeze( |
| self.data.get_size(), self.data.get_reduction_size(), prefix="q" |
| ) |
| with patch.object(ConstantBuffer, "override_device", self.get_device()): |
| body = LoopBody( |
| self.get_store_function(), |
| (args if self.get_reduction_type() else args[:1]), |
| var_ranges, |
| ) |
| index_formulas = [*body.indexing_exprs.values()] |
| reads_bufs = [ |
| V.graph.name_to_buffer[reads_name] |
| if reads_name in V.graph.name_to_buffer.keys() |
| else None |
| for reads_name in body.reads_name2expr.keys() |
| ] |
| memory_addrs = [ |
| *body.reads_name2expr.values(), |
| *body.writes_name2expr.values(), |
| ] |
| index_vars = [] |
| reduce_vars = [] |
| index_size = [] |
| reduce_size = [] |
| for v, s in var_ranges.items(): |
| if v in args[0]: |
| assert not reduce_vars |
| index_vars.append(v) |
| index_size.append(s) |
| else: |
| assert v in args[1] |
| reduce_vars.append(v) |
| reduce_size.append(s) |
| |
| # pick up the iteration reordering (iter_reordering_reindex) recorded by upstream reads |
| reordering_reindex = [same_reorder(range(len(index_vars)))] * len(memory_addrs) |
| for i, reads_buf in enumerate(reads_bufs): |
| if isinstance(reads_buf, ComputedBuffer) and hasattr( |
| reads_buf, "iter_reordering_reindex" |
| ): |
| reordering_reindex[i] = reads_buf.iter_reordering_reindex |
| |
| def simplify_and_reorder(x_vars, support_vars, sizes, reordering_reindex=None): |
| sizes, reindex0, reindex1 = self._apply_loop_reordering( |
| x_vars, support_vars, sizes, memory_addrs, reordering_reindex |
| ) |
| # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1] |
| x_vars = reindex0(x_vars) |
| sizes, reindex2, prune = V.graph.sizevars._simplify_loops( |
| x_vars, |
| sizes, |
| index_prevent_reordering(index_formulas, x_vars, sizes), |
| ) |
| x_vars = prune(x_vars) |
| # sizes, reindex1, prune = _simplify_loops(x_vars, sizes, index_formulas) |
| # x_vars = prune(x_vars) |
| # sizes, reindex2 = self._apply_loop_reordering(x_vars, sizes, memory_addrs) |
| reindex = fuse_reindexing(reindex1, reindex2) |
| return sizes, reindex, reindex1 |
| |
| support_vars = index_vars + reduce_vars |
| iter_ranges, iter_reindex, iter_reordering_reindex = simplify_and_reorder( |
| index_vars, support_vars, index_size, reordering_reindex |
| ) |
| reduce_ranges, reduce_reindex, _ = simplify_and_reorder( |
| reduce_vars, support_vars, reduce_size |
| ) |
| |
| # remember the reordering if no loop collapsing happened. |
| if len(iter_ranges) == len(index_vars): |
| self.iter_reordering_reindex = iter_reordering_reindex |
| # retrace the loop body with simplification and reordering applied |
| (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( |
| iter_ranges, reduce_ranges, prefix="z" |
| ) |
| body = LoopBody( |
| body, [iter_reindex(iter_vars), reduce_reindex(reduce_vars)], var_ranges |
| ) |
| return (iter_ranges, reduce_ranges), body |
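| # Worked example (illustrative only) of the three transformations above for a |
| # pointwise op over size [4, 1, 8, 16] that is read contiguously: |
| #   1) the size-1 dimension is dropped  -> [4, 8, 16] |
| #   2) contiguous dimensions are fused  -> [512] |
| #   3) reordering only kicks in when the reads prefer a different stride order |
| #      (e.g. channels-last inputs), in which case the loop order follows them. |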
| |
| @staticmethod |
| def _apply_loop_reordering( |
| index_vars, |
| support_vars, |
| sizes, |
| memory_addrs, |
| reordering_reindex=None, |
| priority_idx=None, |
| ): |
| """ |
| Shuffle the order of loops around to hopefully improve performance. |
| """ |
| from .scheduler import pick_loop_order |
| |
| if priority_idx is None: |
| priority_idx = [] |
| |
| try: |
| strides = [ |
| V.graph.sizevars.stride_hints(expr, index_vars, support_vars) |
| for expr in memory_addrs |
| ] |
| assert len(strides) == len(memory_addrs) and len(strides[0]) == len( |
| index_vars |
| ) |
| # consider both layout(strides) and reordering(reordering_reindex) |
| if reordering_reindex is not None: |
| for i in range(len(memory_addrs)): |
| try: |
| strides[i] = reordering_reindex[i](strides[i]) |
| # if len(order) != len(strides), do not reorder |
| except AssertionError: |
| pass |
| order = list(reversed(pick_loop_order(strides, sizes, priority_idx))) |
| except Exception: |
| if config.debug: |
| log.warning( |
| "Did not simplify complex index:\n%s\n%s", |
| dict(zip(index_vars, sizes)), |
| memory_addrs, |
| ) |
| order = list(range(len(sizes))) |
| sizes = [sizes[i] for i in order] |
| return sizes, same_reorder(order), inverse_reorder(order) |
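| # Small example of the return convention (assuming same_reorder/inverse_reorder |
| # build permutation callables, as their names suggest): with order = [2, 0, 1], |
| # sizes [a, b, c] are returned as [c, a, b]; same_reorder(order) applies that |
| # permutation to the loop variables and inverse_reorder(order) undoes it. |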
| |
| def get_reduction_size(self): |
| return self.data.get_reduction_size() |
| |
| def get_reduction_type(self): |
| return self.data.get_reduction_type() |
| |
| def is_no_op(self): |
| return self.data.is_zero_elements() |
| |
| def should_allocate(self): |
| return True |
| |
| def constant_to_device(self, device): |
| """Move this to a given device. Requires that all reads are to constants.""" |
| return self.data.constant_to_device(device) |
| |
| |
| class TemplateBuffer(Buffer): |
| """ |
| Represents a Triton (and, in the future, other kinds of) template operator |
| that we can fuse an epilogue onto. |
| """ |
| |
| def __init__(self, layout, inputs, make_kernel_render): |
| super().__init__(name=None, layout=layout) |
| self.inputs = InputsKernel.unwrap_storage(inputs) |
| self.make_kernel_render = make_kernel_render |
| self.name = V.graph.register_buffer(self) |
| |
| def get_read_writes(self): |
| return self.normalized_read_writes() |
| |
| @cache_on_self |
| def normalized_read_writes(self): |
| name = self.get_name() |
| indexer = self.layout.make_indexer() |
| |
| def dummy(index, rindex): |
| assert len(rindex) == 0 |
| return ops.store(name, indexer(index), "fake") |
| |
| deps = dependencies.extract_read_writes( |
| dummy, self.get_size(), (), normalize=True |
| ) |
| deps.reads = {dependencies.StarDep(x.get_name()) for x in self.inputs} |
| return deps |
| |
| def get_reduction_size(self): |
| return 1 |
| |
| def get_reduction_type(self): |
| return None |
| |
| def is_no_op(self): |
| return False |
| |
| def should_allocate(self): |
| return True |
| |
| def simplify_and_reorder(self): |
| return ( |
| ( |
| self.get_size(), |
| (), |
| ), |
| None, |
| ) |
| |
| |
| @dataclasses.dataclass |
| class InputsKernel(Buffer): |
| inputs: List[Buffer] |
| |
| def get_read_writes(self): |
| return dependencies.ReadWrites( |
| {dependencies.StarDep(x.get_name()) for x in self.inputs}, |
| {dependencies.StarDep(self.get_name())}, |
| set(), |
| [], |
| None, |
| op_counts=collections.Counter(), |
| ) |
| |
| @staticmethod |
| def unwrap_storage(inputs): |
| inputs_new = [] |
| for x in inputs: |
| if isinstance(x, TensorBox): |
| x = x.data |
| if isinstance(x, StorageBox): |
| x = x.data |
| if isinstance(x, BaseView) and not isinstance(x, ReinterpretView): |
| x = ExternKernel.realize_input(x) |
| assert isinstance(x, (Buffer, ReinterpretView)), x |
| inputs_new.append(x) |
| return inputs_new |
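| # Illustrative sketch: for a TensorBox(StorageBox(buf)) input, unwrap_storage |
| # returns [buf]; a plain view input is first realized into a ReinterpretView |
| # via ExternKernel.realize_input, so kernels always see Buffers or |
| # ReinterpretViews. |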
| |
| def is_extern(self): |
| return True |
| |
| |
| class NopKernel(InputsKernel): |
| def is_no_op(self): |
| return True |
| |
| |
| class ConcatKernel(NopKernel): |
| """ |
| There isn't actually a real kernel for concat; we just change the |
| storage for the upstream data. |
| """ |
| |
| @classmethod |
| def create(cls, inputs, dim): |
| device = inputs[0].get_device() |
| dtype = inputs[0].get_dtype() |
| new_size = list(inputs[0].get_size()) |
| offsets_start = [0] |
| offsets_end = [new_size[dim]] |
| assert 0 <= dim < len(new_size) |
| for i in range(1, len(inputs)): |
| input_size = inputs[i].get_size() |
| offsets_start.append(new_size[dim]) |
| assert len(input_size) == len(new_size) |
| assert inputs[i].get_dtype() == dtype |
| assert inputs[i].get_device() == device |
| for j in range(len(new_size)): |
| if j == dim: |
| new_size[j] = new_size[j] + input_size[j] |
| else: |
| new_size[j] = V.graph.sizevars.guard_equals( |
| new_size[j], input_size[j] |
| ) |
| offsets_end.append(new_size[dim]) |
| |
| output_stride = FlexibleLayout.contiguous_strides(new_size) |
| # If any of the inputs is in CL format, use CL format for the output |
| for i in range(len(inputs)): |
| x = inputs[i] |
| if is_storage_and_layout(x): |
| layout = x.get_layout() |
| if ( |
| isinstance(layout, FixedLayout) |
| and layout.is_channels_last_contiguous() |
| ): |
| # use CL stride for the output |
| output_stride = make_channels_last_strides_for(new_size) |
| break |
| |
| kernel = ConcatKernel( |
| name=None, |
| layout=FixedLayout( |
| device=device, |
| dtype=dtype, |
| size=new_size, |
| stride=output_stride, |
| ), |
| inputs=[], |
| ) |
| kernel = StorageBox(kernel) |
| for i in range(len(inputs)): |
| kernel.data.inputs.append( |
| cls.realize_into( |
| inputs[i], |
| SliceView.create(kernel, dim, offsets_start[i], offsets_end[i]), |
| ) |
| ) |
| kernel.data.name = V.graph.register_buffer(kernel.data) |
| kernel.data.inputs = cls.unwrap_storage(kernel.data.inputs) |
| |
| return kernel |
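| # Rough sketch of what create() builds for cat([a, b], dim=0) with a: [2, 4] |
| # and b: [3, 4] (hypothetical inputs): a [5, 4] output buffer is allocated and |
| # each input is realized into a SliceView covering rows 0:2 and 2:5; when an |
| # input's layout is flexible it is given AliasedLayout so it can write directly |
| # into its slice instead of going through an extra copy. |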
| |
| @classmethod |
| def realize_into(cls, src, dst): |
| # Attempt to turn this into a ReinterpretView rather than assert. |
| # This has concessions around layout, since as_storage_and_layout |
| # can cause us to go from flexible to fixed layout. |
| if not isinstance(dst, ReinterpretView): |
| if is_storage_and_layout(dst): |
| storage, layout = as_storage_and_layout(dst) |
| dst = ReinterpretView(storage, layout) |
| assert isinstance(dst, ReinterpretView), dst |
| if isinstance(src, TensorBox): |
| # unwrap a TensorBox |
| return cls.realize_into(src.data, dst) |
| if isinstance(src, StorageBox): |
| src.realize() |
| # ExternKernelAlloc has specific requirements for its output layout, so create a copy |
| if isinstance(src.data.layout, FlexibleLayout) and not isinstance( |
| src.data, ExternKernelAlloc |
| ): |
| src.data.layout = AliasedLayout(dst) |
| return src.data |
| # introduce a copy |
| pw = Pointwise.create( |
| device=src.get_device(), |
| dtype=src.get_dtype(), |
| inner_fn=src.make_loader(), |
| ranges=[ |
| V.graph.sizevars.guard_equals(a, b) |
| for a, b in zip(src.get_size(), dst.get_size()) |
| ], |
| ) |
| return cls.realize_into(pw, dst) |
| |
| def should_allocate(self): |
| return True |
| |
| |
| @dataclasses.dataclass |
| class ExternKernel(InputsKernel): |
| constant_args: Tuple[Any, ...] = () |
| kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) |
| output_view: Optional[ReinterpretView] = None |
| |
| def decide_layout(self): |
| if isinstance(self.layout, FlexibleLayout): |
| self.apply_constraint() |
| self.freeze_layout() |
| |
| def codegen(self, wrapper): |
| raise NotImplementedError() |
| |
| @staticmethod |
| def copy_input(x): |
| pw = Pointwise.create( |
| device=x.get_device(), |
| dtype=x.get_dtype(), |
| inner_fn=x.make_loader(), |
| ranges=x.get_size(), |
| origin_node=x.get_origin_node(), |
| traceback=x.get_traceback(), |
| ) |
| pw.realize() |
| return pw |
| |
| @classmethod |
| def process_kernel(cls, kernel, *args, **kwargs): |
| binded_args = signature(kernel).bind(*args, **kwargs).arguments |
| args_flat, args_spec = pytree.tree_flatten(binded_args) |
| |
| is_arg_tensor = [] |
| tensor_args = [] |
| non_tensor_args = [] |
| for arg in args_flat: |
| is_arg_tensor.append(isinstance(arg, IRNode)) |
| if is_arg_tensor[-1]: |
| tensor_args.append(arg) |
| else: |
| if isinstance(arg, sympy.Expr): |
| arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None) |
| non_tensor_args.append(arg) |
| |
| def unflatten_args(new_tensor_args, new_non_tensor_args): |
| result = [] |
| it_tensors = iter(new_tensor_args) |
| it_non_tensors = iter(new_non_tensor_args) |
| for is_tensor in is_arg_tensor: |
| if is_tensor: |
| result.append(next(it_tensors)) |
| else: |
| result.append(next(it_non_tensors)) |
| result = pytree.tree_unflatten(result, args_spec) |
| return result.get("args", []), result.get("kwargs", {}) |
| |
| tensor_args = [cls.realize_input(x) for x in tensor_args] |
| |
| # freeze layout otherwise our output stride calculation might |
| # become incorrect |
| for x in tensor_args: |
| if is_storage_and_layout(x): |
| as_storage_and_layout(x, freeze=True) |
| |
| # We don't have generic shape formulas, so just burn in the |
| # shapes and run an example input. |
| # TODO(jansel): replace this with dynamic shape formulas |
| example_args = [] |
| |
| # We need to retain the constant values of fake tensors that we originally |
| # propagated the graph with, because for some operators running without a |
| # constant would trigger an error / DataDependentException |
| for x in tensor_args: |
| if x.get_name() in V.graph.constants: |
| example_args.append(V.graph.constants[x.get_name()]) |
| else: |
| example_args.append(ir_node_to_tensor(x, guard_shape=True)) |
| |
| new_args, new_kwargs = unflatten_args(example_args, non_tensor_args) |
| example_output = kernel(*new_args, **new_kwargs) |
| |
| return example_output, tensor_args, non_tensor_args, unflatten_args |
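| # In rough terms: process_kernel binds the call like the aten schema would, |
| # splits IRNode args from constants, realizes the IR inputs, and runs the |
| # kernel once on example (fake) tensors; the resulting example_output's sizes, |
| # strides and dtypes are what callers use to build the output layout(s). |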
| |
| @classmethod |
| def convert_to_reinterpret_view(cls, x): |
| """ |
| In order to pass this to an extern kernel we need a |
| ReinterpretView not a View. This allows us to avoid some |
| unneeded copies. |
| """ |
| assert isinstance(x, BaseView) |
| if isinstance(x, ReinterpretView): |
| return x |
| |
| x.unwrap_view().freeze_layout() |
| rw = extract_read_writes(x.make_loader(), x.get_size(), normalize=False) |
| assert len(rw.reads) == 1 |
| |
| index = V.graph.sizevars.simplify_with_ranges( |
| list(rw.reads)[0].index, rw.var_ranges |
| ) |
| strides = V.graph.sizevars.stride_vars(index, rw.range_vars) |
| offset = V.graph.sizevars.offset_var(index, rw.range_vars) |
| expected = sympy_dot(rw.range_vars, strides) + offset |
| |
| if index != expected: |
| log.debug( |
| "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s", |
| strides, |
| offset, |
| index, |
| ) |
| raise NotImplementedError() |
| |
| return ReinterpretView( |
| data=x.data, |
| layout=FixedLayout( |
| device=x.get_device(), |
| dtype=x.get_dtype(), |
| size=x.get_size(), |
| stride=strides, |
| offset=offset, |
| ), |
| ) |
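| # Worked example (illustrative): for x = transpose view of a contiguous [4, 8] |
| # buffer, the single read index is i1*8 + i0, so stride_vars/offset_var recover |
| # strides [1, 8] and offset 0, and the view becomes a ReinterpretView of size |
| # [8, 4] with stride [1, 8] instead of forcing a copy. |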
| |
| @classmethod |
| def realize_input(cls, x): |
| if x is None: |
| return NoneAsConstantBuffer() |
| if isinstance(x, (sympy.Expr, sympy.Rel, int)): |
| return ShapeAsConstantBuffer(x) |
| if isinstance(x, Constant): |
| return V.graph.add_tensor_constant( |
| torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device()) |
| ) |
| if isinstance(x, ConstantBuffer): |
| return x |
| if isinstance(x, TensorBox): |
| return cls.realize_input(x.data) |
| if isinstance(x, ReinterpretView): |
| return x |
| if isinstance(x, BaseView): |
| x.realize() |
| if is_storage_and_layout(x.unwrap_view()) and not isinstance( |
| x.unwrap_view().data, ExternKernelAlloc |
| ): |
| try: |
| return cls.convert_to_reinterpret_view(x) |
| except NotImplementedError: |
| pass |
| if isinstance(x, StorageBox): |
| # TODO(jansel): impose layout preference on realized buffer |
| x.realize() |
| return x |
| return cls.copy_input(x) |
| |
| @classmethod |
| def require_stride1(cls, x): |
| if is_storage_and_layout(x): |
| if len(x.get_stride()) == 0: |
| return x |
| for stride in x.get_stride(): |
| if stride == 1: |
| return x |
| return cls.copy_input(x) |
| |
| @classmethod |
| def require_stride_order(cls, x, order): |
| if x.get_numel() == 0: # Layout doesn't matter |
| return x |
| |
| # require x to have a layout whose strides are ordered according to `order` |
| if is_storage_and_layout(x): |
| if isinstance(x.get_layout(), FlexibleLayout): |
| # freeze the FlexibleLayout into a FixedLayout with the given stride_order |
| as_storage_and_layout( |
| x, freeze=True, want_contiguous=False, stride_order=order |
| ) |
| return x |
| elif isinstance( |
| x.get_layout(), FixedLayout |
| ) and x.get_layout().is_stride_ordered(order): |
| return x |
| elif isinstance(x.get_layout(), MutationLayout): |
| if isinstance(x.get_layout().real_layout(), FlexibleLayout): |
| raise AssertionError( |
| "the MutationLayout's real layout shouldn't be FlexibleLayout" |
| ) |
| elif isinstance( |
| x.get_layout().real_layout(), FixedLayout |
| ) and x.get_layout().real_layout().is_stride_ordered(order): |
| return x |
| |
| # TODO - Storage to InputBuffer |
| if isinstance(x, InputBuffer) and x.get_layout().is_stride_ordered(order): |
| return x |
| x = cls.copy_input(x) |
| as_storage_and_layout(x, freeze=True, want_contiguous=False, stride_order=order) |
| assert is_stride_order_storage_and_layout(x, order) |
| return x |
| |
| @classmethod |
| def require_contiguous(cls, x): |
| return cls.require_stride_order(x, list(reversed(range(len(x.get_size()))))) |
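| # Note on the convention: the `order` passed here ranks each dimension's stride |
| # from smallest (0) upward, so require_contiguous asks for [ndim-1, ..., 1, 0], |
| # e.g. [2, 1, 0] for a 3-D tensor, meaning the last dimension has the smallest |
| # stride (standard row-major layout). |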
| |
| def apply_constraint(self): |
| pass |
| |
| def codegen_const_args(self): |
| return map(V.graph.wrapper_code.val_to_str, self.constant_args) |
| |
| def codegen_args(self): |
| args = [x.codegen_reference() for x in self.inputs] |
| args.extend(self.codegen_const_args()) |
| return args |
| |
| def codegen_kwargs(self): |
| kwargs = [] |
| if self.kwargs: |
| if V.graph.cpp_wrapper: |
| # TODO: use native_functions.yaml as the ground truth |
| assert ( |
| self.ordered_kwargs_for_cpp_kernel |
| ), "ordered_kwargs_for_cpp_kernel has to be provided" |
| for arg_name in self.ordered_kwargs_for_cpp_kernel: |
| assert arg_name in self.kwargs, ( |
| "arg %s not found in self.kwargs" % arg_name |
| ) |
| v = self.kwargs.get(arg_name) |
| kwargs.append(V.graph.wrapper_code.val_to_str(v)) |
| else: |
| kwargs = [ |
| f"{k}={V.graph.wrapper_code.val_to_str(v)}" |
| for k, v in self.kwargs.items() |
| ] |
| return kwargs |
| |
| def codegen_size_asserts(self, wrapper): |
| if config.size_asserts and not V.graph.cpp_wrapper: |
| size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size()) |
| stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride()) |
| wrapper.writeline( |
| f"assert_size_stride({self.get_name()}, {size}, {stride})" |
| ) |
| |
| def get_group_stride(self): |
| """ |
| get output sizes and strides, for template_codegen |
| """ |
| _size = self.get_size() |
| _stride = self.get_stride() |
| # iter_ranges = _size of output tensor, reduce_range = [] because no reduction |
| return [_size, []], _stride |
| |
| def canonicalize(self): |
| """ |
| Manually get canonicalization of the output index |
| """ |
| # manually generate index formula for conv |
| sizevars = V.graph.sizevars |
| sizes = self.get_size() |
| strides = self.get_stride() |
| strides = [sizevars.size_hint(x) for x in strides] |
| index_vars = [sympy_symbol(f"d{i}") for i in range(len(sizes))] |
| # reorder index vars according to stride |
| index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True) |
| lookup = {pos: idx for idx, pos in enumerate(index_order)} |
| order = [lookup[i] for i in range(len(lookup))] |
| index_vars = [index_vars[i] for i in order] |
| indexer = self.make_indexer() |
| index = indexer(index_vars) |
| |
| new_sizes, reindex, prune = V.graph.sizevars._simplify_loops( |
| index_vars, sizes, [index] |
| ) |
| |
| # assign new variables each dimension to deal with numbering mismatches |
| # d0, d1, d2 could become d0, d2 -- which won't match d0, d1 |
| _, add_var = var_builder("c") |
| replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes]))) |
| |
| index = sympy_subs(sympy.expand(index), replacement) |
| return index, tuple(new_sizes) |
| |
| def __str__(self): |
| lines = [ |
| f"{field.name}={getattr(self, field.name)}" |
| for field in dataclasses.fields(self) |
| ] |
| lines.append(f"origin_node={self.origin_node!r}") |
| return self.str_helper(lines) |
| |
| __repr__ = __str__ |
| |
| |
| @dataclasses.dataclass |
| class ExternKernelOut(ExternKernel): |
| output_view: Optional[ReinterpretView] = None |
| |
| def codegen(self, wrapper): |
| args = [*self.codegen_args(), *self.codegen_kwargs()] |
| wrapper.generate_extern_kernel_out( |
| self.output_view, |
| self.codegen_reference(), |
| args, |
| self.kernel, |
| ) |
| |
| def __init__( |
| self, |
| layout, |
| inputs, |
| constant_args=(), |
| kwargs=None, |
| output_view=None, |
| kernel=None, |
| cpp_kernel=None, |
| ordered_kwargs_for_cpp_kernel=(), |
| ): |
| super().__init__( |
| None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {} |
| ) |
| self.output_view = output_view |
| self.name = V.graph.register_buffer(self) |
| self.kernel = cpp_kernel if V.graph.cpp_wrapper else kernel |
| self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel |
| |
| def should_allocate(self): |
| return True |
| |
| |
| class ExternKernelAlloc(ExternKernel): |
| def codegen(self, wrapper): |
| args = [*self.codegen_args(), *self.codegen_kwargs()] |
| V.graph.wrapper_code.generate_extern_kernel_alloc( |
| self.get_name(), self.kernel, args, self.get_origin_node() |
| ) |
| if isinstance(self.layout, Layout): |
| self.codegen_size_asserts(wrapper) |
| |
| def __init__( |
| self, |
| layout, |
| inputs, |
| constant_args=(), |
| kwargs=None, |
| kernel=None, |
| cpp_kernel=None, |
| ordered_kwargs_for_cpp_kernel=(), |
| ): |
| super().__init__( |
| None, layout, self.unwrap_storage(inputs), constant_args, kwargs or {} |
| ) |
| self.name = V.graph.register_buffer(self) |
| self.kernel = cpp_kernel if V.graph.cpp_wrapper else kernel |
| self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel |
| |
| def should_allocate(self): |
| return False |
| |
| def apply_constraint(self): |
| raise NotImplementedError |
| |
| |
| class InplaceBernoulliFallback(ExternKernel): |
| """ |
| This needs to be a custom class to handle mutation properly |
| """ |
| |
| kernel = "aten.bernoulli_" |
| |
| def codegen(self, wrapper): |
| (x,) = [t.codegen_reference() for t in self.inputs] |
| wrapper.writeline( |
| f"{self.kernel}({x}, {', '.join(map(repr, self.constant_args))})" |
| ) |
| |
| def should_allocate(self): |
| return False |
| |
| def get_mutation_names(self): |
| assert isinstance(self.layout, MutationLayout) |
| return (self.layout.target.get_name(),) |
| |
| def __init__(self, x, *constant_args): |
| super().__init__( |
| None, |
| MutationLayout(x), |
| self.unwrap_storage([x]), |
| constant_args, |
| ) |
| self.name = V.graph.register_buffer(self) |
| |
| |
| class ScatterFallback(ExternKernel): |
| """ |
| This needs to be a custom class to handle mutation properly. |
| This class handles both aten.scatter_ and aten.scatter_reduce_. |
| It also handles the case where `src` is a scalar. |
| """ |
| |
| def codegen(self, wrapper): |
| if self.src_is_tensor: |
| (x, index, src) = [t.codegen_reference() for t in self.inputs] |
| else: |
| (x, index) = [t.codegen_reference() for t in self.inputs] |
| src = self.constant_args[1] |
| line = f"{self.kernel}({x}, {self.constant_args[0]}, {index}, {src}" |
| if self.kernel == "aten.scatter_": |
| if self.kwargs["reduce"]: |
| line += f", reduce={repr(self.kwargs['reduce'])}" |
| else: |
| line += ", ".join([""] + self.codegen_kwargs()) |
| line += ")" |
| wrapper.writeline(line) |
| |
| def should_allocate(self): |
| return False |
| |
| def __init__( |
| self, |
| fn, |
| x, |
| dim: int, |
| index, |
| src, |
| *, |
| reduce: Optional[str] = None, |
| include_self: bool = True, |
| ): |
| assert fn in {"aten.scatter_", "aten.scatter_reduce_"} |
| self.kernel = fn |
| self.src_is_tensor = isinstance(src, TensorBox) |
| if self.src_is_tensor: |
| tensors = [self.realize_input(t) for t in [x, index, src]] |
| constant_args = [dim] |
| else: |
| tensors = [self.realize_input(t) for t in [x, index]] |
| constant_args = [dim, src] |
| super().__init__( |
| None, |
| MutationLayout(x), |
| self.unwrap_storage(tensors), |
| constant_args, |
| {"reduce": reduce, "include_self": include_self}, |
| ) |
| self.name = V.graph.register_buffer(self) |
| |
| |
| class IndexPutFallback(ExternKernel): |
| """ |
| This needs to be a custom class to handle mutation and indices properly |
| """ |
| |
| def codegen(self, wrapper): |
| (x, values, *valid_indices) = [t.codegen_reference() for t in self.inputs] |
| indices = [] |
| iter_valid_indices = iter(valid_indices) |
| for i, _ in enumerate(self.indices): |
| if self.indices[i] is not None: |
| indices.append(next(iter_valid_indices)) |
| else: |
| indices.append(V.graph.wrapper_code.none_str) |
| |
| indices = f"{V.graph.wrapper_code.open_bracket}{', '.join(indices)}{V.graph.wrapper_code.closed_bracket}" |
| args = [x, indices, values, *self.codegen_const_args()] |
| wrapper.writeline(wrapper.wrap_kernel_call(self.kernel, args)) |
| |
| def should_allocate(self): |
| return False |
| |
| def __init__(self, x, indices, values, accumulate): |
| self.indices = indices |
| valid_indices = [i for i in indices if i is not None] |
| tensors = [self.realize_input(x) for x in [x, values, *valid_indices]] |
| super().__init__( |
| None, |
| MutationLayout(x), |
| self.unwrap_storage(tensors), |
| [accumulate], |
| ) |
| self.name = V.graph.register_buffer(self) |
| self.kernel = "at::index_put_" if V.graph.cpp_wrapper else "aten.index_put_" |
| |
| |
| class DeviceCopy(ExternKernelOut): |
| @classmethod |
| def create(cls, x, device): |
| if not x.is_extern() and all( |
| (r.name in V.graph.constants and hasattr(r, "index")) for r in x.get_reads() |
| ): |
| return x.constant_to_device(device) |
| |
| V.graph.device_types.add(device.type) |
| V.graph.add_device_idx(device.index) |
| V.graph.device_types.add(x.get_device().type) |
| V.graph.add_device_idx(x.get_device().index) |
| |
| developer_warning("DeviceCopy in input program") |
| return DeviceCopy( |
| FlexibleLayout( |
| device=device, |
| dtype=x.get_dtype(), |
| size=x.get_size(), |
| ), |
| [cls.realize_input(x)], |
| ) |
| |
| def codegen(self, wrapper): |
| args = self.codegen_args() |
| assert len(args) == 1 |
| if self.output_view: |
| wrapper.writeline( |
| f"{self.output_view.codegen_reference()}.copy_({args[0]}){V.graph.wrapper_code.ending}" |
| ) |
| else: |
| wrapper.writeline( |
| f"{self.codegen_reference()}.copy_({args[0]}){V.graph.wrapper_code.ending}" |
| ) |
| |
| |
| class DynamicScalar(IRNode): |
| """ |
| The result of a call to aten._local_scalar_dense. |
| |
| This is not yet implemented. The one model (so far) that calls this |
| (fastNLP_Bert) does not actually use the result, so we expect this |
| node to be dead-code eliminated. |
| """ |
| |
| def get_reads(self): |
| return () |
| |
| |
| @dataclasses.dataclass |
| class FallbackKernel(ExternKernelAlloc): |
| def __init__( |
| self, |
| layout, |
| kernel, |
| tensor_args, |
| nontensor_args, |
| unflatten_args, |
| kwargs=None, |
| ): |
| super().__init__( |
| layout, |
| tuple(tensor_args), |
| tuple(nontensor_args), |
| ) |
| if getattr(torch.ops.aten, kernel.__name__, None) is kernel: |
| self.kernel = ( |
| f"at::{kernel.__name__}" |
| if V.graph.cpp_wrapper |
| else f"aten.{kernel.__name__}" |
| ) |
| else: |
| assert ( |
| not V.graph.cpp_wrapper |
| ), f"{kernel.__name__} is not supported with cpp wrapper" |
| self.kernel = ( |
| f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}" |
| ) |
| self.unflatten_args = unflatten_args |
| self.kwargs = {} if kwargs is None else kwargs |
| V.graph.warn_fallback(self.kernel) |
| |
| def codegen_args(self): |
| @dataclasses.dataclass |
| class Shim: |
| ref: Any |
| |
| def __repr__(self): |
| return self.ref |
| |
| tensor_args = [Shim(x.codegen_reference()) for x in self.inputs] |
| args, kwargs = self.unflatten_args(tensor_args, self.constant_args) |
| args = [V.graph.wrapper_code.val_to_str(x) for x in args] |
| # let self.codegen_kwargs handle kwargs |
| self.kwargs.update(kwargs) |
| return args |
| |
| @classmethod |
| def create(cls, kernel, *args, **kwargs): |
| fake_incorrect_kernels = ( |
| aten._fft_r2c.default, |
| aten._fft_r2c.out, |
| aten._fft_c2r.default, |
| aten._fft_c2c.default, |
| aten._fft_c2c.out, |
| aten._linalg_svd.default, |
| aten._linalg_svd.U, |
| aten._fused_moving_avg_obs_fq_helper_functional, |
| ) |
| context = ( |
| V.graph.fake_mode if kernel not in fake_incorrect_kernels else nullcontext() |
| ) |
| with context: |
| ( |
| example_output, |
| tensor_args, |
| non_tensor_args, |
| unflatten_args, |
| ) = cls.process_kernel(kernel, *args, **kwargs) |
| |
| assert tensor_args or isinstance( |
| example_output, torch.Tensor |
| ), "Not sure where to find device info" |
| packed = FallbackKernel( |
| MultiOutputLayout( |
| tensor_args[0].get_device() if tensor_args else example_output.device |
| ), |
| kernel, |
| tensor_args, |
| non_tensor_args, |
| unflatten_args, |
| ) |
| |
| def generate_output(output, indices): |
| if isinstance(output, (list, tuple)): |
| return type(output)( |
| generate_output(output[i], indices + [(type(output), i)]) |
| for i in range(len(output)) |
| ) |
| elif isinstance(output, torch.Tensor): |
| return MultiOutput( |
| FixedLayout( |
| output.device, |
| output.dtype, |
| convert_shape_to_inductor(output.size()), |
| convert_shape_to_inductor(output.stride()), |
| ), |
| packed, |
| indices, |
| ) |
| elif isinstance(output, int): |
| return output |
| else: |
| assert output is None, "FallbackKernel output type is not supported" |
| return None |
| |
| return generate_output(example_output, []) |
| |
| def apply_constraint(self): |
| return super().apply_constraint() |
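| # Sketch of the create() -> MultiOutput relationship: for a fallback op that |
| # returns a (values, indices) tuple, the packed FallbackKernel gets a |
| # MultiOutputLayout and create() returns one MultiOutput per tensor with |
| # indices like [(tuple, 0)] and [(tuple, 1)]; MultiOutput.codegen then renders |
| # the corresponding element access (std::get<i>(...) under the cpp wrapper). |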
| |
| |
| @dataclasses.dataclass |
| class MultiOutputLayout(IRNode): |
| device: torch.device |
| |
| |
| class MultiOutput(ExternKernel): |
| def codegen_list_tuple_access(self, basename, indices): |
| if len(indices) > 0: |
| itype, i = indices[0] |
| if itype == list: |
| return self.codegen_list_tuple_access(f"{basename}[{i}]", indices[1:]) |
| elif itype == tuple: |
| # cpp wrapper code needs to use std::get<> to access a tuple |
| tuple_access = V.graph.wrapper_code.codegen_tuple_access( |
| basename, str(i) |
| ) |
| return self.codegen_list_tuple_access(tuple_access, indices[1:]) |
| else: |
| raise AssertionError("non supported index type") |
| else: |
| return basename |
| |
| def codegen(self, wrapper): |
| line = V.graph.wrapper_code.declare |
| line += f"{self.get_name()} = {self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices)}" |
| line += V.graph.wrapper_code.ending |
| V.graph.wrapper_code.writeline(line) |
| self.codegen_size_asserts(V.graph.wrapper_code) |
| |
| def __init__(self, layout, input, indices: List[Tuple]): |
| super().__init__(None, layout, [input], ()) |
| self.name = V.graph.register_buffer(self) |
| self.indices = indices |
| |
| def should_allocate(self): |
| return False |
| |
| |
| def _prepare_convolution_fusion_create( |
| cls, |
| x: "TensorBox", |
| weight: "TensorBox", |
| bias: "TensorBox", |
| padding: List[int], |
| stride: List[int], |
| dilation: List[int], |
| groups: int, |
| transposed: bool = False, |
| output_padding: Optional[List[int]] = None, |
| ): |
| """ |
| This helper prepares the inputs, layout and constant args for a convolution |
| post-op fusion's create function: it decides the output layout (channels first |
| or channels last), realizes the inputs, and enforces the required stride order |
| on them. It only supports the CPU device since the conv post-op fusion kernels |
| are only supported on CPU right now. |
| """ |
| |
| # Port from aten/src/ATen/native/ConvUtils.h: _conv_input_size |
| def _conv_input_size( |
| output_size, weight_size, padding, output_padding, stride, dilation, groups |
| ): |
| assert len(output_size) == len(weight_size), "Expect input dim == weight dim" |
| dim = len(output_size) |
| assert dim > 2, "Expect input dim > 2" |
| |
| BATCH_DIM = 0 |
| WEIGHT_INPUT_CHANNELS_DIM = 1 |
| input_size = [] |
| input_size.append(output_size[BATCH_DIM]) |
| input_size.append(weight_size[WEIGHT_INPUT_CHANNELS_DIM] * groups) |
| for d in range(2, dim): |
| kernel = (weight_size[d] - 1) * dilation[d - 2] + 1 |
| input_size_d = ( |
| (output_size[d] - 1) * stride[d - 2] |
| - (padding[d - 2] * 2) |
| + kernel |
| + output_padding[d - 2] |
| ) |
| input_size.append(input_size_d) |
| return list(map(int, input_size)) |
| |
| # The size of prepacked_weight is the prepacked weight size of deconv: |
| # Groups > 1: [g*o, i/g, ...] |
| # Groups == 1: [o, i, ...] |
| # Returns original weight size in [i, o, ...] |
| def _original_deconv_weight_size( |
| prepacked_weight, |
| groups, |
| ): |
| prepacked_weight_size = prepacked_weight.size() |
| dim = len(prepacked_weight_size) |
| assert dim > 2, "Expect weight dim > 2" |
| if groups > 1: |
| weight_size = [] |
| weight_size.append(prepacked_weight_size[1] * groups) |
| weight_size.append(prepacked_weight_size[0] / groups) |
| for d in range(2, dim): |
| weight_size.append(prepacked_weight_size[d]) |
| else: |
| weight_size = prepacked_weight.transpose(0, 1).size() |
| return weight_size |
| |
| x.realize() |
| weight.realize() |
| with V.graph.fake_mode: |
| x_fake = ir_node_to_tensor(x, guard_shape=True) |
| weight_fake = ir_node_to_tensor(weight, guard_shape=True) |
| dims = len(x_fake.size()) - 2 |
| assert 0 < len(padding) <= dims |
| assert 0 < len(dilation) <= dims |
| assert 0 < len(stride) <= dims |
| padding = pad_listlike(padding, dims) |
| dilation = pad_listlike(dilation, dims) |
| stride = pad_listlike(stride, dims) |
| if output_padding is None: |
| output_padding = pad_listlike([0], dims) |
| else: |
| assert 0 < len(output_padding) <= dims |
| output_padding = pad_listlike(output_padding, dims) |
| assert isinstance(groups, int) |
| if transposed: |
| # When transposed, the size of the prepacked oneDNN weight is different |
| # from that of the PyTorch weight, so we cannot run aten convolution with |
| # it. Instead, we infer the output size from the input params here: |
| weight_size = _original_deconv_weight_size(weight_fake, groups) |
| input_size = x_fake.size() |
| output_size = _conv_input_size( |
| input_size, |
| weight_size, |
| padding, |
| output_padding, |
| stride, |
| dilation, |
| groups, |
| ) |
| else: |
| bias_fake = ( |
| ir_node_to_tensor(bias, guard_shape=True) if bias is not None else bias |
| ) |
| output = torch.ops.aten.convolution( |
| x_fake, |
| weight_fake, |
| bias_fake, |
| stride, |
| padding, |
| dilation, |
| transposed, |
| output_padding, |
| groups, |
| ) |
| output_size = output.size() |
| |
| req_stride_order = [0] + list(reversed(range(1, len(stride) + 1))) |
| req_stride_order = [len(req_stride_order)] + req_stride_order |
| output_stride = make_channels_last_strides_for(output_size) |
| |
| x = cls.require_stride_order(x, req_stride_order) |
| assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" |
| inputs = [x, weight] |
| |
| kernel_layout = FixedLayout( |
| x.get_device(), |
| x.get_dtype(), |
| convert_shape_to_inductor(output_size), |
| convert_shape_to_inductor(output_stride), |
| ) |
| constant_args = [padding, stride, dilation, groups] |
| if transposed: |
| constant_args.insert(1, output_padding) |
| |
| if bias is not None: |
| inputs.append(bias) |
| else: |
| constant_args.insert(0, bias) |
| return inputs, constant_args, kernel_layout, req_stride_order |
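| # Worked example for _conv_input_size (one spatial dim, illustrative numbers): |
| # output 5, weight 3, stride 2, padding 1, dilation 1, output_padding 0 gives |
| # kernel = (3 - 1) * 1 + 1 = 3 and input = (5 - 1) * 2 - 2 * 1 + 3 + 0 = 9, |
| # i.e. the usual transposed-convolution output-size formula. |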
| |
| |
| class ConvolutionUnary(ExternKernelAlloc): |
| kernel = "torch.ops.mkldnn._convolution_pointwise" |
| |
| def __init__( |
| self, |
| layout, |
| inputs, |
| constant_args=(), |
| kernel="torch.ops.mkldnn._convolution_pointwise", |
| ): |
| super().__init__(layout, inputs, constant_args) |
| self.kernel = kernel |
| |
| def codegen(self, wrapper): |
| wrapper.writeline( |
| f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" |
| ) |
| if isinstance(self.layout, Layout): |
| self.codegen_size_asserts(wrapper) |
| |
| @classmethod |
| def create( |
| cls, |
| x: "TensorBox", |
| weight: "TensorBox", |
| bias: "TensorBox", |
| padding_: List[int], |
| stride_: List[int], |
| dilation_: List[int], |
| groups: int, |
| attr, |
| scalars, |
| algorithm, |
| ): |
| kernel = "torch.ops.mkldnn._convolution_pointwise" |
| (inputs, constant_args, kernel_layout, _) = _prepare_convolution_fusion_create( |
| cls, x, weight, bias, padding_, stride_, dilation_, groups |
| ) |
| constant_args = constant_args + [attr, scalars, algorithm] |
| return ConvolutionUnary( |
| layout=kernel_layout, |
| inputs=inputs, |
| constant_args=constant_args, |
| kernel=kernel, |
| ) |
| |
| |
| class ConvolutionBinary(ExternKernelAlloc): |
| kernel = "torch.ops.mkldnn._convolution_pointwise.binary" |
| |
| def __init__( |
| self, |
| layout, |
| inputs, |
| constant_args=(), |
| kernel="torch.ops.mkldnn._convolution_pointwise.binary", |
| ): |
| super().__init__(layout, inputs, constant_args) |
| self.kernel = kernel |
| |
| def codegen(self, wrapper): |
| wrapper.writeline( |
| f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" |
| ) |
| if isinstance(self.layout, Layout): |
| self.codegen_size_asserts(wrapper) |
| |
| @classmethod |
| def create( |
| cls, |
| x: "TensorBox", |
| other: "TensorBox", |
| weight: "TensorBox", |
| bias: "TensorBox", |
| padding_: List[int], |
| stride_: List[int], |
| dilation_: List[int], |
| groups: int, |
| binary_attr: str, |
| binary_alpha: Optional[float], |
| unary_attr: Optional[str], |
| unary_scalars: Optional[List], |
| unary_algorithm: Optional[str], |
| ): |
| kernel = "torch.ops.mkldnn._convolution_pointwise.binary" |
| ( |
| inputs, |
| constant_args, |
| kernel_layout, |
| req_stride_order, |
| ) = _prepare_convolution_fusion_create( |
| cls, x, weight, bias, padding_, stride_, dilation_, groups |
| ) |
| other = cls.require_stride_order(other, req_stride_order) |
| inputs.insert(1, other) |
| constant_args = constant_args + [ |
| binary_attr, |
| binary_alpha, |
| unary_attr, |
| unary_scalars, |
| unary_algorithm, |
| ] |
| return ConvolutionBinary( |
| layout=kernel_layout, |
| inputs=inputs, |
| constant_args=constant_args, |
| kernel=kernel, |
| ) |
| |
| |
| class ConvolutionBinaryInplace(ExternKernelAlloc): |
| kernel = "torch.ops.mkldnn._convolution_pointwise_.binary" |
| |
| def __init__( |
| self, |
| kernel_layout, |
| inputs, |
| constant_args=(), |
| kernel="torch.ops.mkldnn._convolution_pointwise_.binary", |
| ): |
| super().__init__(kernel_layout, inputs, constant_args) |
| self.kernel = kernel |
| |
| def codegen(self, wrapper): |
| wrapper.writeline( |
| f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" |
| ) |
| |
| def get_mutation_names(self): |
| assert isinstance(self.layout, MutationLayout) |
| return (self.layout.target.get_name(),) |
| |
| @classmethod |
| def create( |
| cls, |
| x: "TensorBox", |
| other: "TensorBox", |
| weight: "TensorBox", |
| bias: "TensorBox", |
| padding_: List[int], |
| stride_: List[int], |
| dilation_: List[int], |
| groups: int, |
| binary_attr: str, |
| binary_alpha: Optional[float], |
| unary_attr: Optional[str], |
| unary_scalars: Optional[List], |
| unary_algorithm: Optional[str], |
| ): |
| kernel = "torch.ops.mkldnn._convolution_pointwise_.binary" |
| ( |
| inputs, |
| constant_args, |
| _, |
| req_stride_order, |
| ) = _prepare_convolution_fusion_create( |
| cls, x, weight, bias, padding_, stride_, dilation_, groups |
| ) |
| other = cls.require_stride_order(other, req_stride_order) |
| V.graph.realize_users_of(other.get_name()) |
| inputs.insert(1, other) |
| constant_args = constant_args + [ |
| binary_attr, |
| binary_alpha, |
| unary_attr, |
| unary_scalars, |
| unary_algorithm, |
| ] |
| return ConvolutionBinaryInplace( |
| kernel_layout=MutationLayout(inputs[1]), |
| inputs=inputs, |
| constant_args=constant_args, |
| kernel=kernel, |
| ) |
| |
| |
| class MKLPackedLinear(ExternKernelAlloc): |
| def __init__( |
| self, |
| layout, |
| inputs, |
| constant_args=(), |
| ): |
| super().__init__( |
| layout, |
| inputs, |
| constant_args, |
| None, |
| kernel="torch.ops.mkl._mkl_linear", |
| cpp_kernel="mkl::_mkl_linear", |
| ) |
| self.cpp_kernel_key = "mkl_linear" |
| self.cpp_op_schema = """ |
| at::Tensor( |
| const at::Tensor& self, |
| const at::Tensor& mkl_weight_t, |
| const at::Tensor& origin_weight_t, |
| const c10::optional<at::Tensor>& bias_opt, |
| const int64_t prepack_batch_size)""" |
| |
| def codegen(self, wrapper): |
| wrapper.generate_fusion_ops_code( |
| self.get_name(), |
| self.kernel, |
| self.codegen_args(), |
| self.cpp_op_schema, |
| self.cpp_kernel_key, |
| ) |
| |
| @classmethod |
| def create(cls, x, packed_w, orig_w, batch_size): |
| x = cls.require_stride1(cls.realize_input(x)) |
| orig_w = cls.require_stride1(cls.realize_input(orig_w)) |
| *m, _ = x.get_size() |
| oc, _ = orig_w.get_size() |
| output_size = list(m) + [oc] |
| output_stride = make_contiguous_strides_for(output_size) |
| inputs = [x, packed_w, orig_w] |
| constant_args = [None, batch_size] |
| |
| return MKLPackedLinear( |
| layout=FixedLayout( |
| x.get_device(), x.get_dtype(), output_size, output_stride |
| ), |
| inputs=inputs, |
| constant_args=constant_args, |
| ) |
| |
| |
| class LinearUnary(ExternKernelAlloc): |
| def __init__( |
| self, |
| layout, |
| inputs, |
| constant_args=(), |
| ): |
| super().__init__( |
| layout, |
| inputs, |
| constant_args, |
| None, |
| kernel="torch.ops.mkldnn._linear_pointwise", |
| cpp_kernel="mkldnn::_linear_pointwise", |
| ) |
| self.cpp_kernel_key = "linear_pointwise" |
| self.cpp_op_schema = """ |
| at::Tensor( |
| const at::Tensor& input_t, |
| const at::Tensor& weight_t, |
| const c10::optional<at::Tensor>& bias_opt, |
| c10::string_view attr, |
| torch::List<c10::optional<at::Scalar>> scalars, |
| c10::optional<c10::string_view> algorithm)""" |
| |
| def codegen(self, wrapper): |
| wrapper.generate_fusion_ops_code( |
| self.get_name(), |
| self.kernel, |
| self.codegen_args(), |
| self.cpp_op_schema, |
| self.cpp_kernel_key, |
| ) |
| |
| @classmethod |
| def create(cls, x, w, b, attr, scalars, algorithm): |
| x = cls.require_stride1(cls.realize_input(x)) |
| w = cls.require_stride1(cls.realize_input(w)) |
| |
| *m, ic = x.get_size() |
| oc, ic = w.get_size() |
| |
| inputs = [x, w] |
| constant_args = [attr, scalars if scalars else [-1], algorithm] |
| if b is not None: |
| b = cls.require_stride1(cls.realize_input(b)) |
| inputs.append(b) |
| else: |
| constant_args.insert(0, None) |
| |
| return LinearUnary( |
| layout=FlexibleLayout( |
| device=x.get_device(), |
| dtype=x.get_dtype(), |
| size=list(m) + [oc], |
| ), |
| inputs=inputs, |
| constant_args=constant_args, |
| ) |
| |
| def apply_constraint(self): |
| pass |
| |
| |
| class LinearBinary(ExternKernelAlloc): |
| kernel = "torch.ops.mkldnn._linear_pointwise.binary" |
| |
| def __init__( |
| self, |
| layout, |
| inputs, |
| constant_args=(), |
| ): |
| super().__init__( |
| layout, |
| inputs, |
| constant_args, |
| None, |
| kernel="torch.ops.mkldnn._linear_pointwise.binary", |
| cpp_kernel="mkldnn::_linear_pointwise", |
| ) |
| self.cpp_kernel_overload_name = "binary" |
| self.cpp_kernel_key = "linear_pointwise_binary" |
| self.cpp_op_schema = """ |
| at::Tensor( |
| const at::Tensor& input_t, |
| const at::Tensor& other_t, |
| const at::Tensor& weight_t, |
| const c10::optional<at::Tensor>& bias_opt, |
| c10::string_view attr) |
| """ |
| |
| def codegen(self, wrapper): |
| wrapper.generate_fusion_ops_code( |
| self.get_name(), |
| self.kernel, |
| self.codegen_args(), |
| self.cpp_op_schema, |
| self.cpp_kernel_key, |
| self.cpp_kernel_overload_name, |
| ) |
| |
| @classmethod |
| def create(cls, x, y, w, b, attr): |
| x = cls.require_stride1(cls.realize_input(x)) |
| y = cls.require_stride1(cls.realize_input(y)) |
| w = cls.require_stride1(cls.realize_input(w)) |
| |
| *m, ic = x.get_size() |
| oc, ic = w.get_size() |
| |
| inputs = [x, y, w] |
| constant_args = [attr] |
| if b is not None: |
| b = cls.require_stride1(cls.realize_input(b)) |
| inputs.append(b) |
| else: |
| constant_args.insert(0, b) |
| |
| return LinearBinary( |
| layout=FlexibleLayout( |
| device=x.get_device(), |
| dtype=x.get_dtype(), |
| size=list(m) + [oc], |
| ), |
| inputs=inputs, |
| constant_args=constant_args, |
| ) |
| |
| def apply_constraint(self): |
| pass |
| |
| |
| class ConvolutionTransposeUnary(ExternKernelAlloc): |
| kernel = "torch.ops.mkldnn._convolution_transpose_pointwise" |
| |
| def __init__( |
| self, |
| layout, |
| inputs, |
| constant_args=(), |
| kernel="torch.ops.mkldnn._convolution_transpose_pointwise", |
| ): |
| super().__init__(layout, inputs, constant_args) |
| self.kernel = kernel |
| |
| def codegen(self, wrapper): |
| wrapper.writeline( |
| f"{self.get_name()} = {self.kernel}({', '.join(self.codegen_args())})" |
| ) |
| |
| @classmethod |
| def create( |
| cls, |
| x: "TensorBox", |
| weight: "TensorBox", |
| bias: "TensorBox", |
| padding_: List[int], |
| output_padding_: List[int], |
| stride_: List[int], |
| dilation_: List[int], |
| groups_: int, |
| attr, |
| scalars, |
| algorithm, |
| ): |
| kernel = "torch.ops.mkldnn._convolution_transpose_pointwise" |
| transposed = True |
| ( |
| inputs, |
| constant_args, |
| kernel_layout, |
| _, |
| ) = _prepare_convolution_fusion_create( |
| cls, |
| x, |
| weight, |
| bias, |
| padding_, |
| stride_, |
| dilation_, |
| groups_, |
| transposed, |
| output_padding_, |
| ) |
| constant_args = constant_args + [attr, scalars, algorithm] |
| return ConvolutionTransposeUnary( |
| layout=kernel_layout, |
| inputs=inputs, |
| constant_args=constant_args, |
| kernel=kernel, |
| ) |
| |
| |
| @dataclasses.dataclass |
| class MutableBox(IRNode): |
| """ |
| TensorBox / StorageBox allow in-place mutation of Tensors |
| """ |
| |
| data: IRNode |
| |
| def __getattr__(self, name): |
| fn = getattr(self.data, name) |
| if callable(fn): |
| return fn |
| raise AttributeError(f"{type(self.data).__name__}.{name} not callable") |
| |
| @property |
| def layout(self): |
| return self.data.layout |
| |
| def __str__(self): |
| if isinstance(self.data, MutableBox): |
| line0 = f"{type(self).__name__}({type(self.data).__name__}(" |
| endl = "))" |
| inner = self.data.data |
| else: |
| line0 = f"{type(self).__name__}(" |
| inner = self.data |
| endl = ")" |
| |
| lines = [ |
| line0, |
| indent(str(inner)), |
| endl, |
| ] |
| return "\n".join(lines) |
| |
| __repr__ = __str__ |
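| # Note: MutableBox.__getattr__ only forwards *callable* attributes to .data |
| # (e.g. box.get_size() resolves to the wrapped node's method); non-callable |
| # attributes must go through .data directly or be exposed as properties, as |
| # `layout` is above. |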
| |
| |
| class TensorBox(MutableBox): |
| @staticmethod |
| def create(data): |
| return TensorBox(StorageBox(data)) |
| |
| |
| class StorageBox(MutableBox): |
| def is_input_buffer(self): |
| if isinstance(self.data, (InputBuffer, ReinterpretView)): |
| return self.data.get_name() in V.graph.graph_inputs |
| return False |
| |
| def realize(self): |
| if isinstance( |
| self.data, |
| ( |
| ComputedBuffer, |
| InputsKernel, |
| InputBuffer, |
| ReinterpretView, |
| TemplateBuffer, |
| ), |
| ): |
| return self.data.get_name() |
| assert isinstance(self.data, (Pointwise, Reduction)), type(self.data) |
| origin_node = self.data.get_origin_node() |
| traceback = self.data.get_traceback() |
| self.data = ComputedBuffer( |
| name=None, |
| layout=FlexibleLayout( |
| device=self.data.get_device(), |
| dtype=self.data.get_dtype(), |
| size=self.data.get_size(), |
| ), |
| data=self.data, |
| ) |
| self.data.name = V.graph.register_buffer(self.data) |
| self.data.origins = self.origins |
| self.data.origin_node = origin_node |
| self.data.traceback = traceback |
| return self.data.name |
| |
| def realize_hint(self): |
| """ |
| Called on buffers we expect to be forced to realize later. |
| """ |
| if ( |
| isinstance(self.data, (Pointwise, Reduction)) |
| and self.num_reads() > 1 |
| and self.is_pointwise_non_scalar_tensor_num_reads_larger_than_one() |
| ): |
| self.realize() |
| |
| def has_exceeded_max_reads(self): |
| return isinstance(self.data, Pointwise) and ( |
| self.num_reads() > config.realize_acc_reads_threshold |
| or len(self.inner_fn_str()) > config.realize_bytes_threshold |
| ) |
| |
| def mark_reuse(self, users): |
| """ |
| A heuristic to decide if we should realize a tensor |
| that is used multiple times. |
| """ |
| |
| def should_realize_on_cpu(loops: Union[Pointwise, Reduction]): |
| """ |
| The heuristic for realizing the reused result of heavy ops on CPU |
| """ |
| heavy_ops = ["exp"] # a list of heavy ops |
| fn_str = loops.inner_fn_str() |
| return any((op + "(") in fn_str for op in heavy_ops) |
| |
| if ( |
| users > 1 |
| and isinstance(self.data, (Pointwise, Reduction)) |
| and ( |
| self.num_reads() > config.realize_reads_threshold |
| or len(self.inner_fn_str()) > config.realize_bytes_threshold |
| or (is_cpu(self.data) and should_realize_on_cpu(self.data)) |
| ) |
| ): |
| self.realize() |
| |
| @cache_on_self |
| def num_reads(self): |
| data = self.data |
| if isinstance(data, (InputsKernel, InputBuffer, ReinterpretView)): |
| return 1 |
| if isinstance(data, ComputedBuffer): |
| read_writes = data.get_read_writes() |
| else: |
| assert isinstance(data, (Pointwise, Reduction)), type(data) |
| read_writes = ComputedBuffer( |
| name=None, |
| layout=FlexibleLayout( |
| device=data.get_device(), |
| dtype=data.get_dtype(), |
| size=data.get_size(), |
| ), |
| data=data, |
| ).get_read_writes() |
| return len(read_writes.reads) |
| |
| @cache_on_self |
| def is_pointwise_non_scalar_tensor_num_reads_larger_than_one(self): |
| # Skip the check for non Pointwise instances |
| return ( |
| (sum(read.index != 0 for read in self.data.get_reads()) > 1) |
| if isinstance(self.data, Pointwise) |
| else True |
| ) |
| |
| |
| class InterpreterShim(torch.fx.Interpreter): |
| @staticmethod |
| @functools.lru_cache(None) |
| def _dummy_gm(): |
| return torch.fx.symbolic_trace(identity) |
| |
| def __init__(self, graph, submodules): |
| # call super() with a placeholder to avoid constructing a |
| # GraphModule which is very expensive (it does codegen). |
| super().__init__(self._dummy_gm(), garbage_collect_values=False) |
| self.module = self |
| self.graph = graph |
| self.submodules = submodules |
| self.extra_traceback = False |
| self.fetch_attr = submodules.__getitem__ |
| self.current_node = None |
| |
| def run_node(self, n: torch.fx.Node) -> Any: |
| self.current_node = n |
| return super().run_node(n) |
| |
| def run(self, *args, **kwargs): |
| with V.set_interpreter_handler(self): |
| return super().run(*args, **kwargs) |
| |
| |
| class LoopBody: |
| """ |
| Captures the body of a Loops subclass into an FX graph. Persists any |
| indexing simplifications and makes it easier to analyze loop bodies. |
| """ |
| |
| def __init__(self, fn, args, var_ranges): |
| super().__init__() |
| self.var_ranges = var_ranges |
| self.indexing_exprs = {} |
| self.indexing_exprs_name = {} |
| self.reads = [] |
| self.writes = [] |
| self.reads_name2expr = {} |
| self.writes_name2expr = {} |
| self.other = [] |
| self.submodules = {"get_index": self.get_index} |
| self.subblocks = {} |
| self.indirect_vars = [] |
| self.indirect_max_sizes = [] |
| self.indirect_new = {} |
| self.root_block = LoopBodyBlock(self, fn, args) |
| self.indexing = None |
| |
| def debug_str(self): |
| lines = [f"var_ranges = {dict(self.var_ranges)}"] |
| lines.extend([f"{name} = {val}" for name, val in self.indexing_exprs.items()]) |
| lines.extend( |
| [ |
| block.debug_str(name) |
| for name, block in itertools.chain( |
| [("body", self.root_block)], self.subblocks.items() |
| ) |
| ] |
| ) |
| return "\n".join(lines) |
| |
| def add_index_expr(self, expr: sympy.Expr, category, buf_name): |
| getattr(self, category).append(expr) |
| if buf_name is not None: |
| getattr(self, f"{category}_name2expr")[buf_name] = expr |
| if expr not in self.indexing_exprs_name: |
| name = f"index{len(self.indexing_exprs)}" |
| self.indexing_exprs_name[expr] = name |
| self.indexing_exprs[name] = expr |
| return self.indexing_exprs_name[expr] |
| |
| def add_submodule(self, block, prefix): |
| """Not actually for nn.Modules, but subblocks in generated code are mapped to FX call_module opcodes""" |
| if prefix[-1].isnumeric() and prefix not in self.submodules: |
| name = prefix |
| else: |
| name = f"{prefix}{len(self.submodules)}" |
| self.submodules[name] = block |
| return name |
| |
| def add_indirect(self, size): |
| name = f"indirect{len(self.indirect_vars)}" |
| var = sympy_symbol(name) |
| self.indirect_vars.append(var) |
| self.indirect_max_sizes.append(size) |
| return var |
| |
| def replace_indirect(self, old, new): |
| """Swap in a variable used in indirect indexing""" |
| if str(old) == str(new): |
| return |
| self.indirect_new[old] = new |
| self.indexing = {k: sympy_subs(v, {old: new}) for k, v in self.indexing.items()} |
| |
| def get_index(self, name): |
| return self.indexing[name] |
| |
| def __call__(self, *indices): |
| index = list(itertools.chain(*indices)) |
| assert len(index) == len(self.var_ranges), (index, self.var_ranges) |
| assert all(v not in self.var_ranges for v in index) |
| replacements = dict(zip(self.var_ranges.keys(), index)) |
| self.indexing = { |
| name: sympy_subs(expr, replacements) |
| for name, expr in self.indexing_exprs.items() |
| } |
| result = self.root_block() |
| self.indexing = None |
| return result |
| |
| |
| class LoopBodyBlock: |
| """ |
| Captures the body of a Loops subclass into an FX graph. |
| In normal cases there will be a 1:1 mapping between LoopBody and |
    LoopBodyBlock, however in the case of ops.masked() the masked-out
| operations will manifest as an extra LoopBodyBlock. |
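
    For example (illustrative), a body containing

        ops.masked(mask, lambda: ops.load("buf0", idx), other)

    is captured as a ``masked_subblock`` call_module node in the parent graph,
    while the operations inside the lambda are recorded as a separate
    LoopBodyBlock in ``LoopBody.subblocks``.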
| """ |
| |
| def __init__(self, body: LoopBody, fn: Callable, args: List[Any]): |
| self.body = body |
| |
| def add_index(expr, category, buf_name=None): |
| return tracer.create_proxy( |
| "call_module", |
| "get_index", |
| (self.body.add_index_expr(expr, category, buf_name),), |
| {}, |
| ) |
| |
| class CaptureIndexing(V.WrapperHandler): |
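            # note: `self` below is the enclosing LoopBodyBlock being built
            # (class bodies can read enclosing function locals), so this sets
            # `name` on the block rather than on the handler class.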
| self.name = "CaptureIndexing" |
| |
| def load(self, name: str, index: sympy.Expr): |
| index = add_index(index, "reads", name) |
| return self._inner.load(name, index) |
| |
| def store(self, name, index, value, mode=None): |
| index = add_index(index, "writes", name) |
| return self._inner.store(name, index, value, mode) |
| |
| def reduction(self, name, dtype, src_dtype, reduction_type, index, value): |
| index = add_index(index, "writes", name) |
| return self._inner.reduction( |
| name, dtype, src_dtype, reduction_type, index, value |
| ) |
| |
| def index_expr(self, index, dtype): |
| if isinstance(index, (int, sympy.Integer)): |
| return ops.constant(int(index), dtype) |
| index = add_index(index, "other") |
| return self._inner.index_expr(index, dtype) |
| |
| @staticmethod |
| def masked(mask_proxy, masked_body: Callable, other_proxy): |
| """ |
| Recursively capture the masked out body in another LoopBodyBlock |
| """ |
| |
| def shim(mask, other): |
| return V.ops.masked(mask, subblock, other) |
| |
| name = self.body.add_submodule(shim, "masked_subblock") |
| subblock = LoopBodyBlock(self.body, masked_body, []) |
| self.body.subblocks[name] = subblock |
| return tracer.create_proxy( |
| "call_module", name, (mask_proxy, other_proxy), {} |
| ) |
| |
| @staticmethod |
| def indirect_indexing(index_proxy, size): |
| """ |
| Flow data from tensors into indexing formulas. |
| Introduce a call_module to update the indexing. |
| """ |
| |
| def set_indirect(new_var): |
| self.body.replace_indirect( |
| var, V.ops.indirect_indexing(new_var, size) |
| ) |
| |
| var = self.body.add_indirect(size) |
| tracer.create_proxy( |
| "call_module", |
| self.body.add_submodule(set_indirect, f"set_{var}"), |
| (index_proxy,), |
| {}, |
| ) |
| return var |
| |
| tracer = torch.fx.Tracer() |
| tracer.graph = torch.fx.Graph(tracer_cls=tracer.__class__) |
| proxy_ops = tracer.create_proxy("placeholder", "ops", (), {}) |
| from .sizevars import SimplifyIndexing |
| |
| with V.set_ops_handler( |
| SimplifyIndexing(CaptureIndexing(proxy_ops), self.body.var_ranges) |
| ): |
| tracer.create_proxy("output", "output", (fn(*args),), {}) |
| self.graph = tracer.graph |
| |
| def __call__(self): |
| graph = self.graph |
| submodules = self.body.submodules |
| |
| return InterpreterShim(graph, submodules).run(V.get_ops_handler()) |
| |
| def debug_str(self, name="block"): |
| code = torch.fx.GraphModule(self.body.submodules, self.graph).code |
| return re.sub( |
| # strip `; del var0` suffixes to make output prettier |
| r";[^\n]*", |
| "", |
| code.strip().replace("def forward(", f"def {name}("), |
| ) |
| |
| |
| class Wait(ExternKernelAlloc): |
| """ |
| Wait should not be used by itself. It should always be constructed in tandem |
    with a collective op that produces a 'work' object to wait on.
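
    A sketch of the intended usage (hypothetical lowering code; ``collective``
    is a TensorBox wrapping a CollectiveKernel such as AllReduce):

        waited = TensorBox.create(Wait.create(collective))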
| """ |
| |
| def __init__( |
| self, |
| layout, |
| inputs, |
| constant_args=(), |
| ): |
| super().__init__(layout, inputs, constant_args) |
| |
| def should_allocate(self): |
| return False |
| |
| def codegen(self, wrapper): |
| wrapper.add_import_once( |
| "from torch.distributed._functional_collectives import _wait_tensor" |
| ) |
| (input_collective,) = [t.codegen_reference() for t in self.inputs] |
| wrapper.writeline(f"{input_collective} = _wait_tensor({input_collective})") |
| |
| # wait op still needs to produce a 'buffer' that represents the tensor output. |
| # this is a symbolic gesture, and it gets handled by WrapperCodegen. |
| # codegen outputs a '# reuse' line that assigns the input buffer here ('input_collective') |
| # to a new name (`self.get_name()`) and `del`s the old name. |
| wrapper.writeline(f"{self.get_name()} = {input_collective}") |
| |
| @classmethod |
| def create(cls, collective_op: "TensorBox"): |
| # TODO(whc) i'm not sure what's going on here, this probably means I missed something upstream |
| collective_op.decide_layout() |
| return Wait( |
| layout=collective_op.get_layout(), |
| inputs=[collective_op], |
| ) |
| |
| def get_alias_names(self): |
| # Signal to codegen that our output buffer isn't safe to reuse |
| return [self.inputs[0].codegen_reference()] |
| |
| |
| class CollectiveKernel(ExternKernel): |
| """ |
    Each CollectiveKernel should follow these patterns:
| - it writes into a given output buffer |
| - the kernel delegates into c10d processgroup, which returns a 'work' obj |
| - the work obj is registered via _register_tensor_work so it can be waited on later |
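
    Subclasses only implement ``codegen_collective``; the base ``codegen`` emits
    the imports, the process-group lookup and the ``_register_tensor_work`` call.
    A minimal sketch (``dist.my_collective`` is a hypothetical op):

        class MyCollective(CollectiveKernel):
            def codegen_collective(self, wrapper, output_name, input_names):
                wrapper.writeline(
                    f"{output_name}_work = dist.my_collective("
                    f"{output_name}, {input_names[0]}, "
                    f"async_op=True, group={output_name}_pg)"
                )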
| """ |
| |
| def __init__(self, layout, inputs, constant_args): |
| super().__init__(None, layout, inputs, constant_args) |
| self.name = V.graph.register_buffer(self) |
| |
| def should_allocate(self): |
| return True |
| |
| def codegen_collective(self, wrapper, output_name, input_names): |
| # factor so the boilerplate can be handled in CollectiveKernel.codegen |
| raise NotImplementedError("Must implement") |
| |
| def codegen(self, wrapper): |
| wrapper.add_import_once("import torch.distributed as dist") |
| wrapper.add_import_once( |
| "from torch.distributed._functional_collectives import _str_to_reduce_op, _register_tensor_work" |
| ) |
| wrapper.add_import_once( |
| "from torch.distributed.distributed_c10d import _find_or_create_pg_by_ranks_and_tag" |
| ) |
| |
| # extract references to our args in string form for codegen output |
| input_names = [t.codegen_reference() for t in self.inputs] |
| output_name = self.get_name() |
| tag, ranks, group_size = self.constant_args |
| |
| # TODO: avoid more than one ref of the same pg (even though they are cached inside the api) |
| wrapper.writeline( |
| f"{output_name}_pg = _find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})" |
| ) |
| |
| self.codegen_collective(wrapper, output_name, input_names) |
| |
| wrapper.writeline(f"_register_tensor_work({output_name}, {output_name}_work)") |
| |
| |
| class MultiOutputNoSizeAssert(MultiOutput): |
| """ |
    Extracts a partial output from a multi-output op.
    Works like MultiOutput but doesn't assert the size; the correctness of the
    size must be guaranteed by the op emitting this node.
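
    For instance (hypothetical buffer names), with ``self.index == "[0]"``
    codegen emits ``buf1 = buf0[0]``.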
| """ |
| |
| def codegen(self, wrapper): |
| wrapper.writeline( |
| f"{self.get_name()} = {self.inputs[0].get_name()}{self.index}" |
| ) |
| |
| |
| class InPlaceHint(ExternKernel): |
| """ |
    Helper op to encode an in/out argument that is made in-place whenever possible.
    Wrap the input of your in-place op to enable this behavior.

    The design is based on two key decisions:
    - this node is responsible for allocating the in/out buffer used by the collective.
      This is controlled by the ``should_allocate`` method, which returns True here and
      False for the collective node.
    - the scheduler special-cases this node and enables it to reuse its input.
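
    A minimal sketch of wrapping an input ``t`` (mirroring AllReduceCoalesced.create
    below):

        hint = InPlaceHint(
            FlexibleLayout(t.get_device(), t.get_dtype(), t.get_size()), t
        )
        wrapped = TensorBox.create(hint)  # pass `wrapped` to the collective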
| """ |
| |
| def codegen(self, wrapper): |
| input_name = self.inputs[0].codegen_reference() |
| output_name = self.get_name() |
| if not wrapper.did_reuse(self, self.inputs[0]): |
| wrapper.writeline(f"{output_name}.copy_({input_name}) #no reuse") |
| |
| def __init__(self, layout, input): |
| input = self.realize_input(input) |
| super().__init__(None, layout, self.unwrap_storage([input]), ()) |
| self.name = V.graph.register_buffer(self) |
| |
| def should_allocate(self): |
| return True |
| |
| |
| class AllReduceCoalesced(ExternKernel): |
| def __init__(self, layout, inputs, constant_args, reduce_op): |
| super().__init__(None, layout, inputs, constant_args) |
| self.reduce_op = reduce_op |
| self.name = V.graph.register_buffer(self) |
| |
| def should_allocate(self): |
| return False |
| |
| @classmethod |
| def create( |
| cls, |
| inputs: List["TensorBox"], |
| reduce_op: str, |
| tag: str, |
| ranks: List[int], |
| group_size: int, |
| ): |
| res = [] |
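        # `res` collects the InPlaceHint wrappers (one per input, appended in
        # wrap_input) followed by one MultiOutputNoSizeAssert view per output
        # of the packed collective; the whole list is returned to the caller.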
| |
| def wrap_input(var): |
| nonlocal res |
| op = InPlaceHint( |
| FlexibleLayout(var.get_device(), var.get_dtype(), var.get_size()), var |
| ) |
| res.append(op) |
| return TensorBox.create(op) |
| |
| inputs = list(map(wrap_input, inputs)) |
| |
| layout = MultiOutputLayout(inputs[0].get_device()) |
| |
| packed = AllReduceCoalesced( |
| layout=layout, |
| inputs=inputs, |
| constant_args=[tag, ranks, group_size], |
| reduce_op=reduce_op, |
| ) |
| for i, in_t in enumerate(inputs): |
| res.append( |
| MultiOutputNoSizeAssert( |
| FlexibleLayout( |
| in_t.get_device(), in_t.get_dtype(), in_t.get_size() |
| ), |
| packed, |
| f"[{i}]", |
| ) |
| ) |
| return res |
| |
| def codegen(self, wrapper): |
| wrapper.add_import_once("import torch.distributed as dist") |
| wrapper.add_import_once( |
| "from torch.distributed._functional_collectives import _str_to_reduce_op, _register_tensor_work" |
| ) |
| wrapper.add_import_once( |
| "from torch.distributed.distributed_c10d import _find_or_create_pg_by_ranks_and_tag" |
| ) |
| |
| output_name = self.get_name() |
| tag, ranks, group_size = self.constant_args |
| |
| wrapper.writeline( |
| f"{output_name}_pg = _find_or_create_pg_by_ranks_and_tag('{tag}', {ranks}, {group_size})" |
| ) |
| |
        inputs = [inp.codegen_reference() for inp in self.inputs]
| |
| wrapper.writeline(f"{output_name} = [{','.join(inputs)}] ") |
| |
| wrapper.writeline( |
| f"{output_name}_work = dist.all_reduce_coalesced(" |
| f"{output_name}, " |
| f"op=_str_to_reduce_op('{str(self.reduce_op)}'), " |
| f"group={output_name}_pg, " |
| "async_op=True)" |
| ) |
| wrapper.writeline(f"_register_tensor_work({output_name}, {output_name}_work)") |
| |
| |
| class AllReduce(CollectiveKernel): |
| def __init__(self, layout, inputs, constant_args, reduce_op): |
| super().__init__(layout, inputs, constant_args) |
| self.reduce_op = reduce_op |
| |
| @classmethod |
| def create( |
| cls, x: "TensorBox", reduce_op: str, tag: str, ranks: List[int], group_size: int |
| ): |
| x = cls.realize_input(x) |
| |
| # is there a difference between literally using x.data.layout below, vs |
| # creating a new one that has the same properties? |
| new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), x.get_size()) |
| |
| return AllReduce( |
| layout=new_layout, |
| inputs=[x], |
| constant_args=[tag, ranks, group_size], |
| reduce_op=reduce_op, |
| ) |
| |
| def codegen_collective(self, wrapper, output_name, input_names): |
| # We must copy our input buffer sometimes, but the scheduler will help us find opportunities |
| # to reuse the input buffer. (This requires no other users of the input buffer.) |
| if not wrapper.did_reuse(self, self.inputs[0]): |
| wrapper.writeline(f"{output_name}.copy_({input_names[0]})") |
| |
| # At this point, output_name points to a buffer that is either |
| # (1) the input buffer, which we're allowed to inplace modify |
| # (2) a freshly allocated buffer, which we've copied the input into above |
| wrapper.writeline( |
| f"{output_name}_work = dist.all_reduce(" |
| f"{output_name}, async_op=True, group={output_name}_pg, op=_str_to_reduce_op('{str(self.reduce_op)}'))" |
| ) |
| |
| |
| class AllGatherIntoTensor(CollectiveKernel): |
| def __init__(self, layout, inputs, constant_args): |
| super().__init__(layout, inputs, constant_args) |
| |
| @classmethod |
| def create(cls, x: "TensorBox", tag: str, ranks: List[int], group_size: int): |
| x = cls.realize_input(x) |
| |
| # is there a difference between literally using x.data.layout below, vs |
| # creating a new one that has the same properties? |
| new_size = x.get_size() |
| new_size[0] *= group_size |
| new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), new_size) |
| |
        # The c10d op also returns a 'work' handle, but Inductor's scheduler doesn't need to
        # know about it: codegen registers it via _register_tensor_work, and nobody should
        # consume the output of this collective except 'Wait', which we control here.
| return AllGatherIntoTensor( |
| layout=new_layout, |
| inputs=[x], |
| constant_args=[tag, ranks, group_size], |
| ) |
| |
    def codegen_collective(self, wrapper, output_name, input_names):
        # output_name points to a freshly allocated buffer that the collective
        # writes into directly; the base class registers the returned work obj.
        wrapper.writeline(
            f"{output_name}_work = dist.all_gather_into_tensor("
            f"{output_name}, {input_names[0]}, async_op=True, group={output_name}_pg)"
        )
| |
| |
| class ReduceScatterTensor(CollectiveKernel): |
| def __init__(self, layout, inputs, constant_args, reduce_op): |
| super().__init__(layout, inputs, constant_args) |
| self.reduce_op = reduce_op |
| |
| @classmethod |
| def create( |
| cls, |
| x: "TensorBox", |
| reduce_op: str, |
| tag: str, |
| ranks: List[int], |
| group_size: int, |
| ): |
| x = cls.realize_input(x) |
| |
| # is there a difference between literally using x.data.layout below, vs |
| # creating a new one that has the same properties? |
| new_size = x.get_size() |
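        # Note: dist.reduce_scatter_tensor requires the scatter dimension to be
        # divisible by group_size, so this symbolic division is expected to stay exact.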
| new_size[0] /= group_size |
| new_layout = FlexibleLayout(x.get_device(), x.get_dtype(), new_size) |
| |
| return ReduceScatterTensor( |
| layout=new_layout, |
| inputs=[x], |
| constant_args=[tag, ranks, group_size], |
| reduce_op=reduce_op, |
| ) |
| |
| def codegen_collective(self, wrapper, output_name, input_names): |
| wrapper.writeline( |
| f"{output_name}_work = dist.reduce_scatter_tensor(" |
| f"{output_name}, {input_names[0]}, " |
| f"async_op=True, group={output_name}_pg, op=_str_to_reduce_op('{str(self.reduce_op)}'))" |
| ) |