| # mypy: allow-untyped-defs |
| from __future__ import annotations |
| |
| import collections |
| import contextlib |
| import dataclasses |
| import enum |
| import functools |
| import inspect |
| import io |
| import itertools |
| import json |
| import logging |
| import math |
| import operator |
| import os |
| import platform |
| import shutil |
| import sys |
| import tempfile |
| import textwrap |
| import time |
| import unittest |
| from datetime import datetime |
| from io import StringIO |
| from pathlib import Path |
| from typing import ( |
| Any, |
| Callable, |
| Dict, |
| Generic, |
| Iterable, |
| List, |
| NamedTuple, |
| Optional, |
| Protocol, |
| Set, |
| Tuple, |
| TypeVar, |
| Union, |
| ValuesView, |
| ) |
| from typing_extensions import Concatenate, ParamSpec |
| from unittest import mock |
| |
| import sympy |
| |
| import torch |
| import torch._export |
| import torch.utils._pytree as pytree |
| from torch._dynamo.device_interface import get_interface_for_device |
| from torch._dynamo.utils import detect_fake_mode |
| from torch.autograd import DeviceType |
| from torch.autograd.profiler_util import EventList |
| from torch.fx.passes.shape_prop import ShapeProp |
| from torch.utils._sympy.functions import CeilDiv, CleanDiv, FloorDiv, ModularIndexing |
| from torch.utils._sympy.symbol import make_symbol, SymT |
| from torch.utils._sympy.value_ranges import bound_sympy, ValueRanges |
| from . import config |
| from .runtime.runtime_utils import cache_dir, ceildiv as runtime_ceildiv |
| |
| log = logging.getLogger(__name__) |
| |
| _T = TypeVar("_T") |
| VarRanges = Dict[sympy.Expr, sympy.Expr] |
| |
| GPU_ALIGN_BYTES = 16 |
| |
| ALIGN_BYTES = 64 |
| assert (ALIGN_BYTES & (ALIGN_BYTES - 1)) == 0 and ALIGN_BYTES >= 8, "must be power of 2" |
| |
| |
def _align(nbytes):
    """Round nbytes up to the nearest multiple of ALIGN_BYTES"""
    return (nbytes + ALIGN_BYTES - 1) & -ALIGN_BYTES
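
# Worked example (illustrative comment, not from the original file): the bit trick in
# _align relies on ALIGN_BYTES being a power of two. With ALIGN_BYTES == 64:
#   _align(0)  -> 0
#   _align(1)  -> 64
#   _align(64) -> 64
#   _align(65) -> 128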
| |
| |
| def _is_aligned(v: sympy.Expr): |
| """v can be statically proven to be a multiple of ALIGN_BYTES""" |
| if isinstance(v, (sympy.Add, sympy.Max)): |
| return all(map(_is_aligned, v.args)) |
| return isinstance(v, align) or sympy.gcd(v, ALIGN_BYTES) == ALIGN_BYTES |
| |
| |
| class align(sympy.Function): |
| """Symbolically round up to the nearest multiple of ALIGN_BYTES""" |
| |
| nargs = (1,) |
| is_integer = True |
| |
| @classmethod |
| def eval(cls, value): |
| if isinstance(value, (int, sympy.Integer)): |
| return _align(int(value)) |
| if _is_aligned(value): |
| return value |
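
# Illustrative sketch (assumed examples, not from the original file): `align` folds
# integer constants eagerly and stays symbolic when alignment cannot be proven.
#   align(100)   -> 128        (constant, rounded up by _align)
#   s = sympy.Symbol("s", integer=True)
#   align(s)     -> align(s)   (kept symbolic; s may not be a multiple of ALIGN_BYTES)
#   align(64*s)  -> 64*s       (gcd with ALIGN_BYTES proves alignment)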
| |
| |
| def do_bench_using_profiling(fn: Callable[[], Any], warmup=25, rep=100) -> float: |
| """ |
| Returns benchmark results by examining torch profiler events. |
| This could be more accurate as it doesn't count CPU side overhead. |
| However, this also requires manually excluding irrelevant event, e.g. |
| vectorized_elementwise_kernel which is used to fill L2 cache, |
| various CUDA events, etc, so could also be fragile. |
| """ |
| |
| fn() |
| torch.cuda.synchronize() |
| cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") |
| |
| # Estimate the runtime of the function |
| start_event = torch.cuda.Event(enable_timing=True) |
| end_event = torch.cuda.Event(enable_timing=True) |
| start_event.record() |
| for _ in range(5): |
| cache.zero_() |
| fn() |
| end_event.record() |
| torch.cuda.synchronize() |
| estimate_ms = start_event.elapsed_time(end_event) / 5 |
| |
| # compute number of warmup and repeat |
| n_warmup = max(1, int(warmup / estimate_ms)) |
| n_repeat = max(1, int(rep / estimate_ms)) |
| |
| # Warm-up |
| for _ in range(n_warmup): |
| fn() |
| |
| with torch.profiler.profile( |
| activities=[ |
| torch.profiler.ProfilerActivity.CUDA, |
| ] |
| ) as p: |
| # Benchmark |
| for i in range(n_repeat): |
| # we clear the L2 cache before each run |
| cache.zero_() |
| # record time of `fn` |
| fn() |
| # Record clocks |
| torch.cuda.synchronize() |
| |
| log.debug("raw events") |
| log.debug(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) |
| |
| filtered_events = EventList( |
| [ |
| event |
| for event in p.events() |
| if event.device_type == DeviceType.CUDA and event.name != "Context Sync" |
| ] |
| ) |
    if len(filtered_events) % n_repeat != 0:
        raise RuntimeError(
            "Failed to divide all profiling events into #repeat groups. "
            f"#CUDA events: {len(filtered_events)}, #repeats: {n_repeat}"
        )
    num_event_per_group = len(filtered_events) // n_repeat
| actual_events = EventList( |
| [ |
| event |
| for i, event in enumerate(filtered_events) |
| if i % num_event_per_group != 0 |
| ] |
| ) |
| actual_events._build_tree() |
| actual_events = actual_events.key_averages() |
| |
| log.debug("profiling time breakdown") |
| log.debug(actual_events.table(row_limit=-1)) |
| |
| res = sum(event.device_time_total for event in actual_events) / 1000.0 / n_repeat |
| log.debug("profiling results: %s ms", res) |
| return res |
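
# Usage sketch (hypothetical example, not from the original file; requires a CUDA device):
#   a = torch.randn(1024, 1024, device="cuda")
#   b = torch.randn(1024, 1024, device="cuda")
#   ms = do_bench_using_profiling(lambda: torch.mm(a, b))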
| |
| |
| @functools.lru_cache(None) |
| def has_torchvision_roi_align() -> bool: |
| try: |
| from torchvision.ops import roi_align # noqa: F401 |
| |
| torch._C._dispatch_has_kernel_for_dispatch_key("torchvision::nms", "Meta") |
| return roi_align is not None and hasattr( |
| getattr(torch.ops, "torchvision", None), "roi_align" |
| ) |
| except ImportError: |
| return False |
| except RuntimeError as e: |
| assert "torchvision::nms does not exist" in str(e) |
| return False |
| |
| |
| def decode_device(device: Union[Optional[torch.device], str]) -> torch.device: |
| if device is None: |
| return torch.tensor(0.0).device # default device |
| if isinstance(device, str): |
| device = torch.device(device) |
| if device.type not in ("cpu", "meta") and device.index is None: |
| device_interface = get_interface_for_device(device.type) |
| return torch.device(device.type, index=device_interface.Worker.current_device()) |
| return device |
| |
| |
| def sympy_product(it): |
| return functools.reduce(operator.mul, it, sympy.Integer(1)) |
| |
| |
| def sympy_dot(seq1, seq2): |
| assert len(seq1) == len(seq2) |
| return sympy.expand(sum(a * b for a, b in zip(seq1, seq2))) |
| |
| |
| def unique(it: Iterable[_T]) -> ValuesView[_T]: |
| return {id(x): x for x in it}.values() |
| |
| |
| def ceildiv( |
| numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] |
| ) -> Union[int, sympy.Expr]: |
| if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): |
| return CeilDiv(numer, denom) |
| # TODO: There is a bug in a call to this function, to repro: |
| # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy |
| # --amp --only YituTechConvBert --dynamic-shapes |
| assert isinstance(numer, int) and isinstance( |
| denom, int |
| ), f"{numer}: {type(numer)}, {denom}: {type(denom)}" |
| return runtime_ceildiv(numer, denom) |
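
# Illustrative examples (not from the original file):
#   ceildiv(10, 3)  -> 4  (plain ints take the runtime_ceildiv fast path)
#   ceildiv(n, 8), with n a sympy Symbol, dispatches to the symbolic CeilDiv helper.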
| |
| |
| def _type_of(key): |
    # Use this function to avoid a dependency on Triton during codegen.
    # Refer to the Triton implementation here:
    # https://github.com/openai/triton/blob/98b5945d2aef679e00ebca8e07c35c3658ec76de/python/triton/runtime/jit.py#L238
    # `None` is nullptr. Implicitly convert to *i8.
| if key is None: |
| return "*i8" |
| dtype_str = str(key).split(".")[-1] |
| tys = { |
| "bool": "i1", |
| "float8e4nv": "fp8e4nv", |
| "float8e5": "fp8e5", |
| "float8e4b15": "fp8e4b15", |
| "float8e4b15x4": "fp8e4b15x4", |
| "float8_e4m3fn": "fp8e4nv", |
| "float8_e5m2": "fp8e5", |
| "float16": "fp16", |
| "bfloat16": "bf16", |
| "float32": "fp32", |
| "float64": "fp64", |
| "int8": "i8", |
| "int16": "i16", |
| "int32": "i32", |
| "int64": "i64", |
| "uint8": "u8", |
| "uint16": "u16", |
| "uint32": "u32", |
| "uint64": "u64", |
| } |
    # a reinterpret can already produce a Triton type string; map each one to itself
| for v in list(tys.values()): |
| tys[v] = v |
| return key if isinstance(key, str) else f"*{tys[dtype_str]}" |
| |
| |
| def convert_shape_to_inductor( |
| lst: Iterable[Union[int, torch.SymInt]] |
| ) -> List[sympy.Expr]: |
| """ |
| Gets the shape and stride of a tensor. For non-symbolic tensors, this is |
| trivial. But for symbolic tensors, we need to map from SymIntNode into |
| sympy.Expr. |
| """ |
| return [ |
| i.node.expr if isinstance(i, torch.SymInt) else sympy.Integer(i) for i in lst |
| ] |
| |
| |
| def convert_shape_to_symint( |
| lst: Iterable[Union[int, sympy.Expr]] |
| ) -> List[Union[int, torch.SymInt]]: |
| """ |
| Takes a list of shapes from Inductor and converts them into symints (or just |
| ints if all shapes are static). |
| """ |
| from .virtualized import V |
| |
| return [ |
| i |
| if isinstance(i, int) |
| else int(i) |
| if isinstance(i, sympy.Integer) |
| else V.graph.sizevars.shape_env.create_symintnode(i, hint=None) |
| for i in lst |
| ] |
| |
| |
| def is_view(op: torch._ops.OpOverload): |
| """ |
| Does this op overload have aliasing |
| """ |
| assert isinstance(op, torch._ops.OpOverload) |
| return any(a.alias_info is not None for a in op._schema.arguments) |
| |
| |
| def is_pointwise_use(use): |
    if use.op != "call_function":
| return False |
| |
| if not ( |
| isinstance(use.target, torch._ops.OpOverload) or use.target is operator.getitem |
| ): |
| return False |
| |
| if use.target is operator.getitem or is_view(use.target): |
| return all(is_pointwise_use(u) for u in use.users) |
| |
| return torch.Tag.pointwise in use.target.tags |
| |
| |
| def gen_gm_and_inputs(target, args, kwargs): |
| g = torch.fx.Graph() |
| g_args = [] |
| a_args = [] |
| for n, arg in enumerate(args): |
| if isinstance(arg, torch.Tensor): |
| g_args.append(g.placeholder(f"arg{n}")) |
| a_args.append(arg) |
| else: |
| g_args.append(arg) |
| assert all(not isinstance(x, torch.Tensor) for x in kwargs.values()) |
| node = g.call_function(target, tuple(g_args), kwargs) |
| if ( |
| len(target._schema.returns) == 1 |
| and str(target._schema.returns[0].type) == "Tensor" |
| ): |
| node = (node,) |
| g.output(node) |
| |
| gm = torch.fx.GraphModule({}, g) |
| return gm, a_args |
| |
| |
| def synchronize(device: str = "cuda"): |
| if device == "cpu": |
| return |
| device_interface = get_interface_for_device(device) |
| if device_interface.is_available(): |
| device_interface.synchronize() |
| |
| |
| def timed( |
| model: Callable[..., Any], example_inputs, times: int = 1, device: str = "cuda" |
| ) -> float: |
| synchronize(device) |
| torch.manual_seed(1337) |
| t0 = time.perf_counter() |
| for _ in range(times): |
| result = model(*example_inputs) |
| synchronize(device) |
| t1 = time.perf_counter() |
| # GC the result after timing |
| assert result is not None # type: ignore[possibly-undefined] |
| return t1 - t0 |
| |
| |
| def print_performance( |
| fn, args=(), times=10, repeat=10, baseline=1.0, device: str = "cuda" |
| ): |
| timings = torch.tensor([timed(fn, args, times, device) for _ in range(repeat)]) |
| took = torch.median(timings) / times |
| print(f"{took / baseline:.6f}") |
| return took |
| |
| |
| def precompute_method(obj: Any, method: str): |
| """Replace obj.method() with a new method that returns a precomputed constant.""" |
| result = getattr(obj, method)() |
| setattr(obj, method, lambda: result) |
| |
| |
| def precompute_methods(obj: Any, methods: List[str]): |
| """Replace methods with new methods that returns a precomputed constants.""" |
| for method in methods: |
| precompute_method(obj, method) |
| |
| |
| def cmp(a, b) -> int: |
| return int(a > b) - int(a < b) |
| |
| |
| def pad_listlike(x, size): |
| if len(x) == 1: |
| return type(x)([x[0]]) * size |
| else: |
| return x |
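
# Illustrative examples (not from the original file):
#   pad_listlike([3], 2)    -> [3, 3]
#   pad_listlike((3,), 2)   -> (3, 3)
#   pad_listlike([1, 2], 2) -> [1, 2]  (length != 1, returned unchanged)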
| |
| |
| # Used to ensure that iterating over a set is deterministic |
| def tuple_sorted(x): |
| if len(x) == 0: |
| return [] |
| |
| def sort_func(elem): |
| if isinstance(elem, str): |
| return elem |
| else: |
| # We expect `elem` to be `scheduler.BaseSchedulerNode` type here, |
| # but we are not able to do isinstance assert because of circular dependency |
| return elem.get_name() |
| |
| return sorted(x, key=sort_func) |
| |
| |
| P = ParamSpec("P") |
| RV = TypeVar("RV", covariant=True) |
| |
| |
| class CachedMethod(Protocol, Generic[P, RV]): |
| @staticmethod |
| def clear_cache(self) -> None: |
| ... |
| |
| def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: |
| ... |
| |
| |
| # See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature |
| def cache_on_self(fn: Callable[Concatenate[Any, P], RV]) -> CachedMethod[P, RV]: |
| key = f"__{fn.__name__}_cache" |
| |
| @functools.wraps(fn) |
| def wrapper(self): |
| if not hasattr(self, key): |
| setattr(self, key, fn(self)) |
| return getattr(self, key) |
| |
| def clear_cache(self): |
| if hasattr(self, key): |
| delattr(self, key) |
| |
| wrapper.clear_cache = clear_cache # type: ignore[attr-defined] |
| return wrapper # type: ignore[return-value] |
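
# Usage sketch (hypothetical class, not from the original file):
#
#   class Node:
#       @cache_on_self
#       def get_name(self) -> str:
#           return compute_name(self)  # hypothetical helper; runs once per instance
#
#   n = Node()
#   n.get_name()               # computes and stores the result under "__get_name_cache"
#   n.get_name()               # returns the cached value
#   n.get_name.clear_cache(n)  # drops the per-instance cache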
| |
| |
| def aggregate_origins(node_schedule): |
| from . import ir |
| |
| if isinstance(node_schedule, list): |
| return functools.reduce( |
| operator.or_, |
| [ |
| node.node.origins |
| for node in node_schedule |
| if hasattr(node, "node") and node.node |
| ], |
| set(), |
| ) |
| elif isinstance(node_schedule, ir.ExternKernel): |
| return node_schedule.origins |
| else: |
| return set() |
| |
| |
| def get_fused_kernel_name(node_schedule, descriptive_names): |
| all_origins = aggregate_origins(node_schedule) |
| if descriptive_names == "original_aten": |
| # Bases the kernel name off of the top-level aten operator (i.e. pre-decompositions) |
| sources = [ |
| origin.meta["original_aten"]._overloadpacket.__name__ |
| for origin in all_origins |
| if origin.op == "call_function" |
| and "original_aten" in origin.meta |
| and origin.meta["original_aten"] is not None |
| ] |
| sources = sorted(set(sources)) |
| elif descriptive_names == "torch": |
| # Bases the kernel name off of the top-level "torch" operator (i.e. post-dynamo graph) |
| sources = [] |
| for origin in all_origins: |
| if origin.op == "call_function" and "source_fn_stack" in origin.meta: |
| source_fn = origin.meta["source_fn_stack"][-1] |
| if isinstance(source_fn[1], str): |
| sources.append(source_fn[1]) |
| else: |
| sources.append(source_fn[1].__name__) |
| sources = sorted(set(sources)) |
| elif descriptive_names == "inductor_node": |
| sources = [ |
| origin.name for origin in all_origins if origin.op == "call_function" |
| ] |
| else: |
| raise NotImplementedError |
| return "_".join(["fused"] + sources) |
| |
| |
| def get_kernel_metadata(node_schedule, wrapper): |
| all_origins = aggregate_origins(node_schedule) |
| inductor_nodes = [origin for origin in all_origins if origin.op == "call_function"] |
| |
| from_node_dict = collections.defaultdict(list) |
| original_aten_dict = collections.defaultdict(list) |
| for node in inductor_nodes: |
| if "original_aten" in node.meta and node.meta["original_aten"] is not None: |
| key = str(node.meta["original_aten"]._overloadpacket) |
| original_aten_dict[key].append(node.name) |
| if "from_node" in node.meta: |
| key = node.meta["from_node"][0][0] |
| from_node_dict[key].append(node.name) |
| metadata = ( |
| f"{wrapper.comment} Source Nodes: [{', '.join(sorted(from_node_dict.keys()))}], " |
| f"Original ATen: [{', '.join(sorted(original_aten_dict.keys()))}]" |
| ) |
| # trace back to original node here |
| detailed_metadata = [] |
| for original_node, nodes in sorted(from_node_dict.items()): |
| detailed_metadata.append( |
| f"{wrapper.comment} {original_node} => {', '.join(sorted(nodes))}" |
| ) |
| return metadata, "\n".join(detailed_metadata) |
| |
| |
| def dominated_nodes( |
| initial_queue: Iterable[torch.fx.Node], skip_filter=None |
| ) -> Set[torch.fx.Node]: |
| """Returns the set of nodes whose values depend on those within initial_queue""" |
| initial_queue = list(initial_queue) |
| dominated_set = set(initial_queue) |
| |
| while initial_queue: |
| node = initial_queue.pop() |
| for user in node.users: |
| if skip_filter and skip_filter(user): |
| continue |
| if user not in dominated_set: |
| dominated_set.add(user) |
| initial_queue.append(user) |
| |
| return dominated_set |
| |
| |
| def gather_origins(args, kwargs): |
| from . import ir |
| |
| def is_unrealized_node(n): |
| if isinstance(n, ir.TensorBox): |
| return is_unrealized_node(n.data) |
| if isinstance(n, ir.StorageBox): |
| return is_unrealized_node(n.data) |
| return isinstance(n, ir.IRNode) and isinstance(n, ir.Pointwise) |
| |
| kwarg_origins = [val.origins for val in kwargs.values() if is_unrealized_node(val)] |
| arg_origins = [arg.origins for arg in args if is_unrealized_node(arg)] |
| return set(itertools.chain(*arg_origins, *kwarg_origins)) |
| |
| |
| def sympy_str(expr: sympy.Expr) -> str: |
| """ |
| Normal sympy str is very slow, this is a lot faster. The result are |
| somewhat worse, as it doesn't do as much simplification. So don't |
| use this for final codegen. |
| """ |
| if isinstance(expr, sympy.Symbol): |
| return expr.name |
| if isinstance(expr, sympy.Add): |
| return " + ".join(map(sympy_str, expr.args)) |
| if isinstance(expr, sympy.Mul): |
| return " * ".join(map(sympy_str, expr.args)) |
| |
| if isinstance(expr, (ModularIndexing, CleanDiv, FloorDiv)): |
| return f"{expr.func.__name__}({', '.join(map(sympy_str, expr.args))})" |
| return str(expr) |
| |
| |
| def get_bounds_index_expr(index): |
| from .virtualized import V |
| |
| # If this expression does not come from an FX node, we compute its bounds |
| if ( |
| config.compute_all_bounds |
| and (fx_node := getattr(V.interpreter, "current_node", None)) |
| and fx_node.target != "index_expr" |
| ): |
| return bound_sympy(index) |
| else: |
| return ValueRanges.unknown() |
| |
| |
| def sympy_index_symbol_with_prefix(prefix: SymT, idx: int) -> sympy.Symbol: |
| """ |
| Used to generate an integer-nonnegative symbol. |
| """ |
| # This should never be used for creating shape/stride symbols, as those |
| # should all be allocated before Inductor. |
| assert prefix != SymT.SIZE |
| # NOTE: shape symbols are positive (> 0), but index variables are only |
| # non-negative (>= 0). |
| return make_symbol(prefix, idx, integer=True, nonnegative=True) |
| |
| |
| def generate_assert(check): |
| return (check or config.debug_index_asserts) and config.assert_indirect_indexing |
| |
| |
| def sympy_index_symbol(name: str) -> sympy.Symbol: |
| """ |
| Used to generate an integer-nonnegative symbol. |
| """ |
| # This should never be used for creating shape/stride symbols, as those |
| # should all be allocated before Inductor. |
| assert name[0] != "s" |
| # NOTE: shape symbols are positive (> 0), but index variables are only |
| # non-negative (>= 0). |
| return sympy.Symbol(name, integer=True, nonnegative=True) |
| |
| |
| def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr: |
| """ |
| When the passed replacement symbol v is a string, it is converted to a symbol with name v that |
| have the same replaced expression integer and nonnegative properties. |
| """ |
| |
| def to_symbol(replaced, replacement): |
| assert isinstance(replaced, sympy.Expr) |
| if isinstance(replacement, str): |
| return sympy.Symbol( |
| replacement, |
| integer=replaced.is_integer, # type: ignore[attr-defined] |
| nonnegative=replaced.is_nonnegative, # type: ignore[attr-defined] |
| ) |
| else: |
| return replacement |
| |
| # xreplace is faster than subs, but is way more picky |
| return sympy.sympify(expr).xreplace( |
| {k: to_symbol(k, v) for k, v in replacements.items()} |
| ) |
| |
| |
| def is_symbolic(a: Any) -> bool: |
| return isinstance(a, torch.SymInt) or ( |
| isinstance(a, torch.Tensor) |
| and any(is_symbolic(x) for x in itertools.chain(a.size(), a.stride())) |
| ) |
| |
| |
| def any_is_symbolic(*args: Any) -> bool: |
| return any(is_symbolic(a) for a in args) |
| |
| |
| def get_first_incompatible_cudagraph_node(gm): |
| from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols |
| |
| forbidden_set = { |
| "aten._fused_moving_avg_obs_fq_helper.default", |
| "aten._fused_moving_avg_obs_fq_helper_functional.default", |
| "aten.multinomial.default", |
| "fbgemm.dense_to_jagged.default", |
| "fbgemm.jagged_to_padded_dense.default", |
| "run_and_save_rng_state", |
| "run_with_rng_state", |
| "aten._local_scalar_dense", |
| # Technically, it's not necessary to ban this, because an |
| # assert_scalar with constant arguments can be validly run |
| # with CUDA graphs, but the operator is also pointless with |
| # constant arguments, so might as well ban |
| "aten._assert_scalar", |
| } |
| if torch.are_deterministic_algorithms_enabled(): |
| forbidden_set.update( |
| { |
| "aten._unsafe_index_put.default", |
| "aten.index_put.default", |
| "aten.index_put_.default", |
| "aten.scatter.src", |
| "aten.scatter.reduce", |
| "aten.scatter.value_reduce", |
| "aten.scatter_add_", |
| "aten.scatter_add.default", |
| "aten.scatter_reduce.two", |
| "aten.scatter_reduce_.two", |
| "aten.scatter_reduce.two_out", |
| } |
| ) |
| for node in gm.graph.nodes: |
| if str(node.target) in forbidden_set: |
| return node |
| if (val := node.meta.get("val")) is not None and free_unbacked_symbols(val): |
| return node |
| return None |
| |
| |
| def has_incompatible_cudagraph_ops(gm): |
| return get_first_incompatible_cudagraph_node(gm) is not None |
| |
| |
| def output_node(gm: torch.fx.GraphModule): |
| """Get the output node from an FX graph""" |
| last_node = next(iter(reversed(gm.graph.nodes))) |
| assert last_node.op == "output" |
| return last_node |
| |
| |
| _registered_caches: List[Any] = [] |
| |
| |
| def clear_on_fresh_inductor_cache(obj: Any): |
| """ |
| Use this decorator to register any caches that should be cache_clear'd |
| with fresh_inductor_cache(). |
| """ |
| if not hasattr(obj, "cache_clear") or not callable(obj.cache_clear): |
| raise AttributeError(f"{obj} does not have a cache_clear method") |
| |
| _registered_caches.append(obj) |
| return obj |
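
# Usage sketch (hypothetical function, not from the original file): anything exposing a
# callable `cache_clear` qualifies, e.g. an lru_cache-wrapped function.
#
#   @clear_on_fresh_inductor_cache
#   @functools.lru_cache(None)
#   def _expensive_lookup(key):
#       ...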
| |
| |
| def clear_inductor_caches(): |
| """ |
| Clear all registered caches. |
| """ |
| for obj in _registered_caches: |
| obj.cache_clear() |
| |
| |
| @contextlib.contextmanager |
| def fresh_inductor_cache(cache_entries=None): |
| """ |
    Context manager that provides a clean tmp cache dir for inductor.

    Optionally, pass an empty dict as 'cache_entries' to have it populated with the
    filenames and sizes of the Triton cache files generated with this cache instance.
| """ |
| clear_inductor_caches() |
| |
| inductor_cache_dir = tempfile.mkdtemp() |
| try: |
| with mock.patch.dict( |
| os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir} |
| ): |
| triton_cache_dir = os.path.join(inductor_cache_dir, "triton") |
| with mock.patch.dict(os.environ, {"TRITON_CACHE_DIR": triton_cache_dir}): |
| yield |
| if isinstance(cache_entries, dict): |
| assert len(cache_entries) == 0, "expected empty cache_entries dict" |
| if os.path.exists(triton_cache_dir): |
| files = os.listdir(triton_cache_dir) |
| cache_entries.update( |
| { |
| f: os.path.getsize(os.path.join(triton_cache_dir, f)) |
| for f in files |
| if ".lock" not in f |
| } |
| ) |
| shutil.rmtree(inductor_cache_dir) |
| except Exception: |
| log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir) |
| raise |
| finally: |
| clear_inductor_caches() |
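
# Usage sketch (hypothetical example, not from the original file):
#
#   entries: Dict[str, int] = {}
#   with fresh_inductor_cache(entries):
#       torch.compile(fn)(*example_inputs)
#   # `entries` now maps Triton cache filenames to their sizes in bytes.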
| |
| |
| def argsort(seq) -> List[int]: |
| # preserve original order for equal strides |
| getter = seq.__getitem__ |
| a_r = range(len(seq)) |
| return list(reversed(sorted(a_r, key=getter, reverse=True))) # noqa: C413 |
| |
| |
| @functools.lru_cache(8) |
| def get_dtype_size(dtype): |
| return torch.empty((), dtype=dtype).element_size() |
| |
| |
| class LineContext(NamedTuple): |
| context: Any |
| |
| |
| class IndentedBuffer: |
| tabwidth = 4 |
| |
| def __init__(self, initial_indent=0): |
| self._lines = [] |
| self._indent = initial_indent |
| |
| def getvaluewithlinemap(self) -> tuple[str, list[tuple[int, LineContext]]]: |
| buf = StringIO() |
| p = 1 |
| linemap = [] |
| for line in self._lines: |
| if isinstance(line, DeferredLineBase): |
| line = line() |
| if line is None: |
| continue |
| elif isinstance(line, LineContext): |
| linemap.append((p, line.context)) |
| continue |
| assert isinstance(line, str) |
| buf.write(line) |
| buf.write("\n") |
| p += 1 + line.count("\n") |
| return buf.getvalue(), linemap |
| |
| def getvalue(self) -> str: |
| v, _ = self.getvaluewithlinemap() |
| return v |
| |
| def getrawvalue(self) -> str: |
| buf = StringIO() |
| for line in self._lines: |
| if isinstance(line, DeferredLineBase): |
| line = line() |
| if line is None: |
| continue |
| elif isinstance(line, LineContext): |
| continue |
| assert isinstance(line, str) |
| # backslash implies line continuation |
| if line.endswith("\\"): |
| buf.write(line[:-1]) |
| else: |
| buf.write(line) |
| buf.write("\n") |
| return buf.getvalue() |
| |
| def clear(self): |
| self._lines.clear() |
| |
| def __bool__(self): |
| return bool(self._lines) |
| |
| def prefix(self): |
| return " " * (self._indent * self.tabwidth) |
| |
| def newline(self): |
| self.writeline("\n") |
| |
| def writeline(self, line): |
| if isinstance(line, LineContext): |
| self._lines.append(line) |
| elif isinstance(line, DeferredLineBase): |
| self._lines.append(line.with_prefix(self.prefix())) |
| elif line.strip(): |
| self._lines.append(f"{self.prefix()}{line}") |
| else: |
| self._lines.append("") |
| |
| def writelines(self, lines): |
| for line in lines: |
| self.writeline(line) |
| |
| def indent(self, offset=1): |
| @contextlib.contextmanager |
| def ctx(): |
| self._indent += offset |
| try: |
| yield |
| finally: |
| self._indent -= offset |
| |
| return ctx() |
| |
| def do_indent(self, offset=1): |
| self._indent += offset |
| |
| def do_unindent(self, offset=1): |
| self._indent -= offset |
| |
| def splice(self, other_code, strip=False): |
| if isinstance(other_code, IndentedBuffer): |
| dedent = float("inf") |
| for line in other_code._lines: |
| if not isinstance(line, LineContext) and line: |
| dedent = min(dedent, len(line) - len(line.lstrip())) |
| if math.isinf(dedent): |
| dedent = 0 |
| for line in other_code._lines: |
| if isinstance(line, LineContext): |
| self._lines.append(line) |
| else: |
| IndentedBuffer.writeline(self, line[int(dedent) :]) |
| else: |
| other_code = textwrap.dedent(other_code) |
| if strip: |
| other_code = other_code.lstrip() |
| if not other_code: |
| return |
| other_code = other_code.rstrip() |
| for line in other_code.split("\n"): |
| self.writeline(line) |
| |
| def map(self, func: Callable[[Any], Any]) -> IndentedBuffer: |
| res = IndentedBuffer(initial_indent=self._indent) |
| res._lines = [func(line) for line in self._lines] |
| return res |
| |
| def __repr__(self): |
| return f"{type(self)}({self.getvalue()})" |
| |
| def __add__(self, other): |
| assert self._indent == other._indent |
| res = IndentedBuffer(initial_indent=self._indent) |
| res.writelines(self._lines) |
| res.writelines(other._lines) |
| return res |
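
# Usage sketch (hypothetical example, not from the original file):
#
#   buf = IndentedBuffer()
#   buf.writeline("def f(x):")
#   with buf.indent():
#       buf.writeline("return x + 1")
#   buf.getvalue()  # "def f(x):\n    return x + 1\n"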
| |
| |
| class FakeIndentedBuffer(IndentedBuffer): |
| def __init__(self): |
| super().__init__() |
| |
| def __getattribute__(self, name): |
| if name == "__class__": # Allow access to the class attribute |
| return object.__getattribute__(self, name) |
        raise RuntimeError(
            f"Tried to call self.{name} on FakeIndentedBuffer. This buffer "
            "is currently used on TritonTemplateKernel to prevent actual "
            "writes to the body without explicitly specifying the body with "
            "`TritonTemplateKernel.set_subgraph_body(name)`"
        )
| |
| |
| @contextlib.contextmanager |
| def restore_stdout_stderr(initial_stdout, initial_stderr): |
| try: |
| yield |
| finally: |
| sys.stdout = initial_stdout |
| sys.stderr = initial_stderr |
| |
| |
| class DeferredLineBase: |
| """A line that can be 'unwritten' at a later time""" |
| |
| def __init__(self, line): |
| if not line.strip(): |
| line = "" |
| self.line = line |
| |
| def __call__(self) -> Optional[str]: |
| """Returns either self.line or None to indicate the line has been 'unwritten'""" |
| raise NotImplementedError |
| |
| def _new_line(self, line: str) -> DeferredLineBase: |
| """Returns a new deferred line with the same condition""" |
| raise NotImplementedError |
| |
| def with_prefix(self, prefix): |
| return self._new_line(f"{prefix}{self.line}") |
| |
| def lstrip(self): |
| return self._new_line(self.line.lstrip()) |
| |
| def __getitem__(self, index): |
| return self._new_line(self.line[index]) |
| |
| def __bool__(self): |
| return bool(self.line) |
| |
| def __len__(self): |
| return len(self.line) |
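
# Minimal subclass sketch (hypothetical, not from the original file): a deferred line
# that is only emitted if a predicate still holds when the buffer is rendered.
#
#   class DeferredIfLine(DeferredLineBase):
#       def __init__(self, line, predicate):
#           super().__init__(line)
#           self.predicate = predicate
#
#       def __call__(self) -> Optional[str]:
#           return self.line if self.predicate() else None
#
#       def _new_line(self, line: str) -> "DeferredIfLine":
#           return DeferredIfLine(line, self.predicate)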
| |
| |
| @functools.lru_cache(None) |
| def is_big_gpu(index) -> bool: |
| min_sms = 68 # 3080 |
| avail_sms = torch.cuda.get_device_properties(index).multi_processor_count |
| if avail_sms < min_sms: |
| log.warning( |
| "Not enough SMs to use max_autotune_gemm mode", |
| extra={"min_sms": min_sms, "avail_sms": avail_sms}, |
| ) |
| return False |
| return True |
| |
| |
| def use_max_autotune() -> bool: |
| return ( |
| config.max_autotune or config.max_autotune_gemm or config.search_autotune_cache |
| ) |
| |
| |
| def _use_template_for_cuda(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool: |
| return ( |
| use_max_autotune() |
| and layout.device.type == "cuda" |
| and layout.dtype in allowed_layout_dtypes |
| and is_big_gpu(layout.device.index or 0) |
| ) |
| |
| |
| def _use_autotune_backend(backend: str) -> bool: |
| return backend.upper() in [ |
| x.strip() for x in config.max_autotune_gemm_backends.upper().split(",") |
| ] |
| |
| |
| def use_triton_template(layout, *, enable_int32=False): |
| layout_dtypes = [torch.float16, torch.bfloat16, torch.float32] |
| if enable_int32: |
| layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] |
| return _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend( |
| "TRITON" |
| ) |
| |
| |
| def use_cutlass_template(layout, m, n, k): |
| from .virtualized import V |
| |
| gemm_size = V.graph.sizevars.size_hint(m * n * k, fallback=-1) |
| if gemm_size <= 0 or gemm_size < config.cuda.cutlass_backend_min_gemm_size: |
| return False |
| from .codegen.cuda.cutlass_utils import try_import_cutlass |
| |
| # Do not use cutlass template on ROCm |
| if torch.version.hip: |
| return False |
| |
| layout_dtypes = [torch.float16, torch.bfloat16, torch.float32, torch.int32] |
| res = _use_template_for_cuda(layout, layout_dtypes) and _use_autotune_backend( |
| "CUTLASS" |
| ) |
| |
| if res: |
| if not try_import_cutlass(): |
| log.warning( |
| "Failed to import CUTLASS lib. Please check whether " |
| "_inductor.config.cuda.cutlass_dir is set correctly. " |
| "Skipping CUTLASS backend for now." |
| ) |
| return False |
| return res |
| |
| |
| def _use_template_for_cpu(layout): |
| return use_max_autotune() and layout.device.type == "cpu" |
| |
| |
| def use_cpp_packed_gemm_template(layout, mat1, mat2): |
| from . import ir |
| from .codegen.cpp_micro_gemm import create_micro_gemm |
| from .kernel.mm_common import mm_args |
| |
| if not _use_template_for_cpu(layout) or not _use_autotune_backend("CPP"): |
| return False |
| |
| if not config.cpp.weight_prepack: |
| return False |
| |
| layout_dtypes = [torch.float32] |
| m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2) |
| # TODO(jgong5): support dynamic shapes for n or k |
| if has_free_symbols((n, k)): |
| return False |
| if isinstance(mat2, ir.BaseView): |
| mat2 = mat2.unwrap_view() |
| micro_gemm = create_micro_gemm( |
| "micro_gemm", m, n, k, layout.dtype, num_threads=parallel_num_threads() |
| ) |
| # TODO(jgong5): support n % n_block_size != 0 |
| return ( |
| layout.dtype in layout_dtypes |
| and micro_gemm is not None |
| and n % micro_gemm.register_blocking[1] == 0 |
| and mat1.get_stride()[-1] == 1 # TODO(jgong5): support transposed input |
| and isinstance(mat2, ir.StorageBox) |
| and mat2.is_module_buffer() |
| ) |
| |
| |
| def use_aten_gemm_kernels(): |
| return not use_max_autotune() or _use_autotune_backend("ATEN") |
| |
| |
| class DebugDirManager: |
| counter = itertools.count(0) |
| prev_debug_name: str |
| |
| def __init__(self): |
| self.id = next(DebugDirManager.counter) |
| |
| def __enter__(self): |
| self.prev_debug_name = torch._dynamo.config.debug_dir_root |
| self.new_name = f"{self.prev_debug_name}_tmp_{self.id}" |
| torch._dynamo.config.debug_dir_root = self.new_name |
| |
| def __exit__(self, *args): |
| shutil.rmtree(self.new_name) |
| torch._dynamo.config.debug_dir_root = self.prev_debug_name |
| |
| |
| def run_and_get_code(fn, *args, **kwargs): |
| from .graph import GraphLowering |
| |
| compile_to_module = GraphLowering.compile_to_module |
| source_codes: List[str] = [] |
| |
| def patched_compile_to_module(self): |
| mod = compile_to_module(self) |
| with open(mod.__file__) as f: |
| source_codes.append(f.read()) |
| return mod |
| |
| # If FX code caching is enabled, a hit prevents getting the code. |
| with config.patch({"fx_graph_cache": False}): |
| with mock.patch.object( |
| GraphLowering, "compile_to_module", patched_compile_to_module |
| ): |
| torch._dynamo.reset() |
| result = fn(*args, **kwargs) |
| return result, source_codes |
| |
| |
| def get_code(fn, *args, **kwargs): |
| """Get the inductor-generated code, but skip any actual compilation or running.""" |
| from .graph import GraphLowering |
| |
| source_codes: List[str] = [] |
| |
| def patched_compile_to_module(self: GraphLowering): |
| class DummyModule: |
| """This is empty to replace the generated triton module""" |
| |
| def __init__(self): |
| pass |
| |
| def call(self, *args, **kwargs): |
| # Don't do anything when called |
| pass |
| |
| code, _ = ( |
| self.codegen_with_cpp_wrapper() if self.cpp_wrapper else self.codegen() |
| ) |
| # Skip all the actual compiling. |
| |
| source_codes.append(code) |
| return DummyModule() |
| |
| # If FX code caching is enabled, a hit prevents getting the code. |
| with config.patch({"fx_graph_cache": False}): |
| with mock.patch.object( |
| GraphLowering, "compile_to_module", patched_compile_to_module |
| ): |
| torch._dynamo.reset() |
            # Note: the result of fn is discarded; the patched compile returns a DummyModule
| _ = fn(*args, **kwargs) |
| |
| return source_codes |
| |
| |
| def get_triton_code(fn, *args, **kwargs): |
| source_codes = get_code(fn, *args, **kwargs) |
| # Can have two outputs if backwards was eagerly compiled |
    assert (
        1 <= len(source_codes) <= 2
    ), f"expected one or two code outputs, got {len(source_codes)}"
| return source_codes[0] |
| |
| |
| def run_and_get_triton_code(fn, *args, **kwargs): |
| _, source_codes = run_and_get_code(fn, *args, **kwargs) |
| # Can have two outputs if backwards was eagerly compiled |
    assert (
        1 <= len(source_codes) <= 2
    ), f"expected one or two code outputs, got {len(source_codes)}"
| return source_codes[0] |
| |
| |
| @contextlib.contextmanager |
| def override_lowering(aten_op, override_fn): |
| """ |
| Override the lowering of aten_op with override_fn. |
| The first argument of override_fn is the original lowering fn. |
| """ |
| from torch._inductor import lowering |
| |
| orig_fn = lowering.lowerings[aten_op] |
| try: |
| lowering.lowerings[aten_op] = functools.partial(override_fn, orig_fn) |
| yield |
| finally: |
| lowering.lowerings[aten_op] = orig_fn |
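
# Usage sketch (hypothetical example, not from the original file):
#
#   def logging_lowering(orig_lowering, *args, **kwargs):
#       log.debug("lowering with %d args", len(args))
#       return orig_lowering(*args, **kwargs)
#
#   with override_lowering(torch.ops.aten.add.Tensor, logging_lowering):
#       torch.compile(fn)(*example_inputs)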
| |
| |
| def add_scheduler_init_hook(pre_fn, post_fn=None): |
| """ |
| Add hook functions to be called at the beginning and end of Scheduler.__init__. |
| Used for unit tests. |
| """ |
| from torch._inductor.scheduler import Scheduler |
| |
| orig_fn = Scheduler.__init__ |
| |
| def wrapper(scheduler, nodes): |
| pre_fn(scheduler, nodes) |
| out = orig_fn(scheduler, nodes) |
| if post_fn: |
| post_fn(scheduler, nodes) |
| return out |
| |
| return unittest.mock.patch.object(Scheduler, "__init__", wrapper) |
| |
| |
| def developer_warning(msg): |
| """ |
| Warnings that will be actionable for PyTorch developers, but not |
| end users. Allows us to easily disable them in stable releases but |
| keep them on for nightly builds. |
| """ |
| if config.developer_warnings: |
| log.warning(msg) |
| else: |
| log.info(msg) |
| |
| |
| def get_benchmark_name(): |
| """ |
| An experimental API used only when config.benchmark_kernel is true. |
| |
| The benchmark name is only available at codegen time. So we can not |
| directly call it in benchmark_all_kernels which is run after codegen. |
| |
| The function assumes the argument after --only is the benchmark name. |
    It works for torchbench.py/huggingface.py/timm_models.py. But for ad-hoc
    scripts, this function may return None.

    There are 2 flavors of the --only argument we need to handle:
| 1. --only model_name |
| 2. --only=model_name |
| """ |
| try: |
| idx = sys.argv.index("--only") |
| if ( |
| idx + 1 < len(sys.argv) |
| and len(sys.argv[idx + 1]) > 0 |
| and sys.argv[idx + 1][0] != "-" |
| ): |
| return sys.argv[idx + 1] |
| except ValueError: |
| pass |
| |
| for arg in sys.argv: |
| if arg.startswith("--only="): |
| return arg[len("--only=") :] |
| |
| |
| def is_ones(items): |
| return all(x == 1 for x in items) |
| |
| |
| def is_zeros(items): |
| return all(x == 0 for x in items) |
| |
| |
| def is_cpu_device(inputs): |
| return all( |
| item.device == torch.device("cpu") |
| for item in inputs |
| if isinstance(item, torch.Tensor) |
| ) |
| |
| |
| def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype: |
| assert isinstance( |
| val, sympy.Expr |
| ), "only support sympy.Expr as input to get_sympy_Expr_dtype" |
| if val.is_integer: # type: ignore[attr-defined] |
| return torch.int64 |
| else: |
| return torch.float64 |
| |
| |
| @contextlib.contextmanager |
| def maybe_profile(should_profile, *args, **kwargs): |
| if should_profile: |
| with torch.profiler.profile(*args, **kwargs) as p: |
| yield p |
| else: |
| yield |
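
# Usage sketch (hypothetical example, not from the original file):
#
#   with maybe_profile(should_profile) as p:
#       run_workload()
#   if p is not None:  # p is None when should_profile is falsy
#       print(p.key_averages().table())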
| |
| |
| def parallel_num_threads(): |
| threads = config.cpp.threads |
| if threads < 1: |
| threads = torch.get_num_threads() |
| return threads |
| |
| |
| @functools.lru_cache(None) |
| def get_device_tflops(dtype): |
| from triton.testing import get_max_simd_tflops, get_max_tensorcore_tflops |
| |
| assert dtype in (torch.float16, torch.bfloat16, torch.float32) |
| |
| if inspect.signature(get_max_simd_tflops).parameters.get("clock_rate"): |
| # Triton API change in https://github.com/openai/triton/pull/2293 |
| from torch._utils_internal import max_clock_rate |
| |
| sm_clock = max_clock_rate() |
| if dtype in (torch.float16, torch.bfloat16): |
| return get_max_tensorcore_tflops(dtype, sm_clock) |
| |
| if torch.backends.cuda.matmul.allow_tf32: |
| return get_max_tensorcore_tflops(torch.float32, sm_clock) |
| else: |
| return get_max_simd_tflops(torch.float32, sm_clock) |
| else: |
| if dtype in (torch.float16, torch.bfloat16): |
| return get_max_tensorcore_tflops(dtype) |
| |
| if torch.backends.cuda.matmul.allow_tf32: |
| return get_max_tensorcore_tflops(torch.float32) |
| else: |
| return get_max_simd_tflops(torch.float32) |
| |
| |
| @functools.lru_cache(None) |
| def get_gpu_dram_gbps(): |
| from triton.testing import get_dram_gbps |
| |
| return get_dram_gbps() |
| |
| |
| def get_gpu_shared_memory(): |
| from triton.runtime import driver |
| |
| return driver.active.utils.get_device_properties(0).get("max_shared_mem", 0) |
| |
| |
| def is_welford_reduction(reduction_type): |
| return reduction_type.startswith("welford") |
| |
| |
| def reduction_num_outputs(reduction_type): |
| return 3 if is_welford_reduction(reduction_type) else 1 |
| |
| |
| def is_linux() -> bool: |
| return platform.system() == "Linux" |
| |
| |
| def has_free_symbols(itr: Iterable[Any]): |
| return any(isinstance(x, sympy.Expr) and not x.is_number for x in itr) |
| |
| |
| def is_dynamic(*args): |
| from . import ir |
| |
| for t in args: |
| if isinstance(t, ir.TensorBox): |
| if has_free_symbols(t.data.get_size()) or ( |
| hasattr(t.data, "get_stride") and has_free_symbols(t.data.get_stride()) |
| ): |
| return True |
| elif isinstance(t, (ir.StorageBox, ir.BaseView, ir.ComputedBuffer)): |
| assert hasattr(t, "get_size") and hasattr(t, "get_stride") |
| if has_free_symbols(t.get_size()) or has_free_symbols(t.get_stride()): |
| return True |
| elif not isinstance(t, ir.IRNode): |
| continue |
| else: |
| raise TypeError(f"unexpected type for is_dynamic {type(t)}") |
| |
| return False |
| |
| |
| # Placeholder strings used in triton codegen. |
| class Placeholder(enum.Enum): |
| # The placeholder for the actual name of a triton kernel. |
| # e.g. for "def triton_" it would be "triton_" |
| KERNEL_NAME = "KERNEL_NAME" |
| |
| # The descriptive name of the triton kernel; when unique_kernel_names = False, this |
| # placeholder will be replaced with a string with more information. |
| DESCRIPTIVE_NAME = "DESCRIPTIVE_NAME" |
| |
| |
| def pass_execution_and_save(func, gm, inp, msg): |
| from .pattern_matcher import stable_topological_sort |
| |
| with tempfile.NamedTemporaryFile( |
| mode="w", |
| encoding="utf-8", |
| delete=False, |
| ) as f: |
| before_io = io.StringIO() |
| after_io = io.StringIO() |
| ShapeProp(gm=gm, fake_mode=detect_fake_mode(inp)).propagate(*inp) |
| print(f"Before:\n{gm.graph}", file=f) |
| print(gm.graph, file=before_io) |
| start_time = datetime.now() |
| func(gm.graph) |
| time_elapsed = datetime.now() - start_time |
| # recompile graph |
| stable_topological_sort(gm.graph) |
| gm.graph.lint() |
| gm.recompile() |
| |
| print(f"After:\n{gm.graph}", file=f) |
| print(gm.graph, file=after_io) |
| t = before_io.getvalue() == after_io.getvalue() |
| log.info( |
| "%s, save before/after graph to %s, graph before/after are the same = %s, time elapsed = %s", |
| msg, |
| f.name, |
| t, |
| time_elapsed, |
| ) |
| |
| |
| def is_collective(node): |
| from . import ir |
| |
| return type(node) == ir._CollectiveKernel |
| |
| |
| def is_wait(node): |
| from . import ir |
| |
| return type(node) == ir._WaitKernel |
| |
| |
| def num_fw_fixed_arguments(dynamo_gm_num_inputs: int, aot_fw_gm_num_inputs: int): |
| "Computes the number of inputs to the aot fw graph which have fixed addresses (params and buffers)" |
| num_rng_seed_offset_inputs = ( |
| 2 if torch._functorch.config.functionalize_rng_ops else 0 |
| ) |
| return aot_fw_gm_num_inputs - dynamo_gm_num_inputs - num_rng_seed_offset_inputs |
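
# Worked example (illustrative, not from the original file): if dynamo saw 3 user inputs
# and the AOT forward graph has 10 inputs while functionalize_rng_ops is enabled, then
# 10 - 3 - 2 == 5 leading inputs are fixed (parameters and buffers).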
| |
| |
| def count_tangents(fx_g: torch.fx.GraphModule): |
| """ |
| Infers which inputs are static for a backwards graph |
| """ |
| |
| def is_saved_tensor(x): |
| return ( |
| "tangents" not in x.name |
| and "bwd_seed" not in x.name |
| and "bwd_base_offset" not in x.name |
| ) |
| |
| arg_count = 0 |
| static_arg_idxs = [] |
| for n in fx_g.graph.nodes: |
| if n.op == "placeholder": |
| if is_saved_tensor(n): |
| static_arg_idxs.append(arg_count) |
| arg_count += 1 |
| |
| assert static_arg_idxs == list(range(len(static_arg_idxs))) |
| return len(static_arg_idxs) |
| |
| |
| @dataclasses.dataclass |
| class BoxedBool: |
| value: bool |
| |
| def __bool__(self): |
| return self.value |
| |
| @staticmethod |
| def disable(obj): |
| if isinstance(obj, BoxedBool): |
| obj.value = False |
| return obj |
| return False |
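
# Usage sketch (hypothetical example, not from the original file): a BoxedBool can be
# shared by reference and flipped off later.
#
#   flag = BoxedBool(True)
#   BoxedBool.disable(flag)  # mutates in place; a plain bool input just returns False
#   bool(flag)               # -> False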
| |
| |
| @contextlib.contextmanager |
| def collect_defined_kernels(kernel_list): |
| from .codegen.wrapper import WrapperCodeGen |
| |
| orig_define_kernel = WrapperCodeGen.define_kernel |
| |
| def new_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs): |
| nonlocal kernel_list |
| kernel_list.append(kernel_code) |
| return orig_define_kernel(wrapper, name, kernel_code, metadata, *args, **kwargs) |
| |
| with unittest.mock.patch.object(WrapperCodeGen, "define_kernel", new_define_kernel): |
| yield |
| |
| |
| def get_cloned_parameter_buffer_name(name: str): |
| return name + "__original__" |
| |
| |
| def is_gpu(device: str): |
| return device in ["cuda", "xpu"] |
| |
| |
| def device_need_guard(device: str): |
| assert isinstance(device, str) |
| return is_gpu(device) |
| |
| |
| def needs_fallback_due_to_atomic_add_limitations(dtype): |
| # tl.atomic_add does NOT support the following types |
| return dtype in {torch.int64, torch.bool, torch.bfloat16} |
| |
| |
| def use_scatter_fallback( |
| op_overload: torch._ops.OpOverload, |
| reduction_type, |
| self_dtype, |
| src_dtype, |
| src_device_type, |
| src_is_tensor, |
| ): |
| reduce_ty = ( |
| "add" if op_overload.overloadpacket == torch.ops.aten.scatter_ else "sum" |
| ) |
| |
| return ( |
| reduction_type not in {None, reduce_ty} |
| or ( |
| src_is_tensor |
| and is_gpu(src_device_type) |
| and needs_fallback_due_to_atomic_add_limitations(src_dtype) |
| ) |
| or ( |
| op_overload.overloadpacket == torch.ops.aten.scatter_reduce_ |
| and reduction_type == "sum" |
| and src_is_tensor |
| and src_device_type == "cpu" |
| and config.cpp.fallback_scatter_reduce_sum |
| and (config.cpp.dynamic_threads or parallel_num_threads() != 1) |
| ) |
| or (reduction_type == reduce_ty and self_dtype in {torch.bool, torch.int64}) |
| or torch.are_deterministic_algorithms_enabled() |
| ) |
| |
| |
| def dump_node_schedule(node_schedule): |
| """ |
| An API that can be used in pdb to dump a node_schedule. |
    Right now it mainly dumps the read/write dependencies, but more can be added as needed.
| """ |
| from torch._inductor.codegen.simd import DisableReduction, EnableReduction |
| from torch._inductor.scheduler import SchedulerNode |
| |
| print(f"Node schedule with {len(node_schedule)} nodes") |
| for idx, node in enumerate(node_schedule): |
| print(f" {idx:3}:") |
| if node is EnableReduction: |
| print("enable reduction") |
| elif node is DisableReduction: |
| print("disable reduction") |
| elif isinstance(node, SchedulerNode): |
| is_red = node.is_reduction() |
| print(f"{'red' if is_red else 'pw'} scheduler node") |
| if is_red: |
| assert node.node is not None |
| print(f"original reduction hint {node.node.data.reduction_hint}") # type: ignore[attr-defined] |
| print("ReadDep:") |
| for dep in node.read_writes.reads: |
| print(dep) |
| print("WriteDep:") |
| for dep in node.read_writes.writes: |
| print(dep) |
| else: |
| raise RuntimeError(f"Unrecognized node type: {type(node)}") |
| |
| |
| def tensor_is_aligned(tensor: torch.Tensor): |
| # See Note: [Input Alignment handling in Inductor] |
| # Right now, we don't try to guard on the alignment of the storage offset. |
| # When this comment was written, non-symbolic storage_offsets are not guarded on |
| # but symbolic storage_offsets are. For consistency, we suppress guard creation |
| # upon performing this check: that ensures that we don't add recompiles when we |
| # add this logic. |
| return ( |
| tensor.storage_offset() * get_dtype_size(tensor.dtype) |
| ) % GPU_ALIGN_BYTES == 0 |
| |
| |
| def should_assume_input_aligned(example_input: torch.Tensor): |
| # See Note: [Input Alignment handling in Inductor] |
| |
| # right now, we only care about alignment for cuda tensors. |
| if not is_gpu(example_input.device.type): |
| return False |
| return config.assume_aligned_inputs or tensor_is_aligned(example_input) |
| |
| |
| def maybe_get_suppress_shape_guards_ctx(): |
| # Try to get TracingContext.try_get().fake_mode.shape_env.suppress_guards() |
| # If it's not available, return a nullcontext. |
| |
| # If we're dealing with cudagraphs, we might not have a tracing_context |
| tracing_context = torch._guards.TracingContext.try_get() |
| if not tracing_context: |
| return contextlib.nullcontext() |
| |
| # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode |
| shape_env = tracing_context.fake_mode.shape_env |
| if not shape_env: |
| return contextlib.nullcontext() |
| |
| return shape_env.suppress_guards() |
| |
| |
| def aoti_eager_cache_dir(namespace: str, device: str): |
| return Path(cache_dir()) / "aoti_eager" / namespace / device |
| |
| |
| def aoti_eager_op_conf_lock(op_func_name_with_overload: str): |
| from filelock import FileLock |
| |
| # Avoid circular import |
| from torch._inductor.codecache import get_lock_dir, LOCK_TIMEOUT |
| |
| op_conf_lock_file = f"{op_func_name_with_overload}.lock" |
| lock_dir = get_lock_dir() |
| return FileLock(os.path.join(lock_dir, op_conf_lock_file), timeout=LOCK_TIMEOUT) |
| |
| |
| def load_aoti_eager_cache(ns: str, op_func_name_with_overload: str, device_type: str): |
| device_kernel_cache = aoti_eager_cache_dir(ns, device_type) |
| op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json" |
| if not op_conf.exists(): |
| return [] |
| |
| with aoti_eager_op_conf_lock(op_func_name_with_overload): |
| with open(op_conf) as f: |
| json_data = json.load(f) |
| for item in json_data: |
                # Get the absolute path of the kernel library
| kernel_lib_abs_path = device_kernel_cache / item["kernel_path"] |
| item["kernel_path"] = kernel_lib_abs_path.as_posix() |
| |
| # Check if the kernel library exists |
| if not kernel_lib_abs_path.exists(): |
| return [] |
| |
| for metadata in item["meta_info"]: |
| assert not metadata[ |
| "is_dynamic" |
| ], "Only support static shape for now" |
| if metadata["device_type"] == "cpu": |
| metadata["device_index"] = -1 |
| metadata["dtype"] = getattr(torch, metadata["dtype"].split(".")[-1]) |
| |
| return json_data |
| |
| |
| def aoti_compile_with_persistent_cache( |
| ns: str, |
| op_func_name_with_overload: str, |
| device_type: str, |
| dynamic: bool, |
| f: Callable[..., Any], |
| args: Tuple[Any], |
| kwargs: Dict[str, Any], |
| *, |
| dynamic_shapes: Optional[Dict[str, Any]] = None, |
| options: Optional[Dict[str, Any]] = None, |
| remove_runtime_assertions: bool = False, |
| disable_constraint_solver: bool = False, |
| ): |
| """ |
| Compile the given function with persistent cache for AOTI eager mode. |
| """ |
| assert not dynamic, "Only support static shape for now" |
| type_to_torch_dtype = {int: torch.int32, float: torch.float, bool: torch.bool} |
| supported_scalar_types = tuple(type_to_torch_dtype.keys()) |
| flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs) |
| if not all( |
| isinstance(input, (supported_scalar_types, torch.Tensor)) |
| for input in flattened_inputs |
| ): |
| raise NotImplementedError("Only support tensor, int, float, bool for now") |
| |
| persistent_cache = aoti_eager_cache_dir(ns, device_type) |
| if not persistent_cache.exists(): |
| persistent_cache.mkdir(parents=True) |
| |
| persistent_cache_lib = persistent_cache / "lib" |
| if not persistent_cache_lib.exists(): |
| persistent_cache_lib.mkdir() |
| |
| with mock.patch.dict( |
| os.environ, |
| {"TORCHINDUCTOR_CACHE_DIR": persistent_cache_lib.absolute().as_posix()}, |
| ): |
| try: |
| kernel_lib_path = torch._export.aot_compile( |
| f, |
| args, |
| kwargs, |
| dynamic_shapes=dynamic_shapes, |
| options=options, |
| remove_runtime_assertions=remove_runtime_assertions, |
| disable_constraint_solver=disable_constraint_solver, |
                # Some operations may have non-Tensor parameters like int, float, bool. These
                # non-Tensor parameters will not be inputs of the graph. Therefore, we do
                # not need to keep the same signature.
| same_signature=False, |
| ) |
| |
| kernel_metadata_items = [] |
| for input in flattened_inputs: |
| # TODO(Eikan): To add dynamic support |
| metadata: Dict[str, Any] = {} |
| metadata["is_dynamic"] = dynamic |
| |
| if isinstance(input, torch.Tensor): |
| metadata["device_type"] = f"{input.device.type}" |
| if is_cpu_device([input]): |
| metadata["device_index"] = -1 |
| else: |
| metadata["device_index"] = input.device.index |
| metadata["dtype"] = f"{input.dtype}" |
| metadata["sizes"] = list(input.size()) |
| metadata["strides"] = list(input.stride()) |
| else: |
| assert isinstance(input, supported_scalar_types) |
| # Scalar tensor |
| metadata["device_type"] = device_type |
| metadata["device_index"] = -1 if device_type == "cpu" else 0 |
| metadata["dtype"] = f"{type_to_torch_dtype[type(input)]}" |
| metadata["sizes"] = [] |
| metadata["strides"] = [] |
| metadata["scalar_value"] = input |
| |
| kernel_metadata_items.append(metadata) |
| |
| kernel_meta_info: Dict[str, Any] = {} |
| kernel_meta_info["meta_info"] = kernel_metadata_items |
| kernel_meta_info["kernel_path"] = ( |
| Path(kernel_lib_path).relative_to(persistent_cache).as_posix() |
| ) |
| |
| json_data = [] |
| update_json = True |
| op_conf = persistent_cache / f"{op_func_name_with_overload}.json" |
| mode = "r" if op_conf.exists() else "w" |
| with aoti_eager_op_conf_lock(op_func_name_with_overload): |
| with open(op_conf, mode) as op_conf_file: |
| try: |
| json_data = json.load(op_conf_file) |
                    except Exception:
| json_data = [] |
| |
| assert isinstance(json_data, list) |
| for item in json_data: |
| assert isinstance(item, dict) |
| # Same kernel meta info already exists in the json file |
| if item["meta_info"] == kernel_metadata_items: |
| update_json = False |
| break |
| |
| if update_json: |
| json_data.append(kernel_meta_info) |
| with open(op_conf, "w") as op_conf_file: |
| json.dump(json_data, op_conf_file, indent=4) |
| |
| return kernel_lib_path |
        except Exception:
| return "" |
| |
| |
| def run_and_get_cpp_code(fn, *args, **kwargs): |
| # We use the patch context manager instead of using it as a decorator. |
| # In this way, we can ensure that the attribute is patched and unpatched correctly |
| # even if this run_and_get_cpp_code function is called multiple times. |
| with unittest.mock.patch.object(config, "debug", True): |
| torch._dynamo.reset() |
| |
| log_capture_string = io.StringIO() |
| ch = logging.StreamHandler(log_capture_string) |
| from torch._inductor.graph import output_code_log |
| |
| output_code_log.addHandler(ch) |
| prev_level = output_code_log.level |
| output_code_log.setLevel(logging.DEBUG) |
| result = fn(*args, **kwargs) |
| s = log_capture_string.getvalue() |
| output_code_log.setLevel(prev_level) |
| output_code_log.removeHandler(ch) |
| return result, s |