| """ |
| This file contains canonical definitions for our symbol naming conventions, |
| across torch.fx.experimental.symbolic_shapes and torch._inductor. The |
| intention is: |
| |
| 1. To make it easily greppable where all the sites we use a prefix are |
| 2. Make it possible to easily tell if we can introduce a new prefix without |
| introducing a conflict |
| |
| You can occasionally test if prefixes have been hardcoded by renaming prefixes |
| in this file and seeing what breaks. |
| """ |
| |
| from enum import auto, Enum |
| from typing import Sequence, Union |
| |
| import sympy |
| |
| |
class SymT(Enum):
    """Kinds of symbols used across symbolic_shapes and Inductor.

    Every member is mapped to a short name prefix in ``prefix_str``.
    ``make_symbol`` builds symbol names from that prefix, and
    ``symbol_is_type`` classifies a symbol by checking which prefix its name
    starts with. Members are numbered with ``auto()``, so their values depend
    on declaration order.
    """

    # NOTE(review): the next four members carry no "Inductor:" tag, so they
    # are presumably the symbolic_shapes kinds -- confirm. prefix_str marks
    # SIZE and UNBACKED_INT as integer symbols.
    SIZE = auto()
    FLOAT = auto()
    UNBACKED_INT = auto()
    UNBACKED_FLOAT = auto()
    # Inductor: The intermediates in inner_fn tmp0, one generated per ops call.
    # If one of these shows up in an indexing expression, that means an
    # indirect load is happening.
    TMP = auto()
    # Inductor: Placeholder variable that is later replaced with TMP
    INDIRECT = auto()
    # Inductor: Some size expressions are replaced with a precomputed size ps0
    # which is computed host side, and then directly reused in the kernel, so
    # we don't repeatedly recompute it on device.
    PRECOMPUTED_SIZE = auto()
    # Inductor: An indexing variable i0 in loops IR which ranges over non-reduced
    # dim in the loop
    INDEX = auto()
    # Inductor: A reduction indexing r0 variable in loops IR which ranges over
    # reduced dim in the loop
    RINDEX = auto()
    # Inductor: In templated kernels torch._inductor.kernel, we have a hook to
    # store the final output and append epilogue fusions. To do this, we must
    # know what the indexes the outputs range over. NB: These will also
    # advertise as INDEX, because their prefix "idx" starts with INDEX's "i"
    # and symbol_is_type is a plain string-prefix test. This is... probably OK?
    TEMPLATE_INDEX = auto()
    # Inductor: iteration domain for blockIdx.x/blockIdx.y
    XBLOCK = auto()
    YBLOCK = auto()
    # Inductor: this is used solely for dynamic_reshape_indexer
    VIEW = auto()
| |
| |
# Maps each SymT to the string that starts the names of symbols of that kind
# (make_symbol appends an integer index to it). symbol_is_type classifies a
# symbol purely by checking which of these strings its name starts with.
#
# Invariant: there must not be a prefix which is a prefix of another string,
# as this introduces ambiguity.
# NOTE(review): this invariant is currently violated. "i" (INDEX) is a prefix
# of both "idx" (TEMPLATE_INDEX) and "indirect" (INDIRECT), so
# symbol_is_type(sym, SymT.INDEX) is also True for those symbols. Prefixes may
# be hardcoded elsewhere, so renaming one here can break other code; change
# with care.
prefix_str = {
    SymT.SIZE: "s",  # integer
    SymT.UNBACKED_INT: "u",  # integer
    # Prefix z here is chosen to avoid false aliasing in symbol_is_type test:
    # e.g. a bare "uf" would start with "u" and be reported as UNBACKED_INT.
    # DO NOT add a "z" type, since "z" would then be a prefix of both entries
    # below. You also need to avoid conflicts on these prefixes but this is
    # somewhat easier to manage
    SymT.FLOAT: "zf",
    SymT.UNBACKED_FLOAT: "zuf",
    SymT.TMP: "tmp",
    SymT.PRECOMPUTED_SIZE: "ps",
    SymT.INDEX: "i",
    SymT.RINDEX: "r",
    SymT.TEMPLATE_INDEX: "idx",
    SymT.XBLOCK: "x",
    SymT.YBLOCK: "y",
    SymT.INDIRECT: "indirect",  # false aliasing: also matches INDEX's "i"
    SymT.VIEW: "view",
}
| |
| |
def make_symbol(prefix: SymT, idx: int, **kwargs) -> sympy.Symbol:
    """Create a sympy Symbol named after a symbol kind and an index.

    The name is the ``prefix_str`` entry for ``prefix`` immediately followed
    by ``idx``. For example, ``SymT.SIZE`` with ``3`` yields ``s3``.

    Any extra keyword arguments are passed straight through to
    ``sympy.Symbol``; presumably these are sympy assumptions such as
    ``integer=True``.
    """
    # TODO: maybe put the assumptions here directly
    head = prefix_str[prefix]
    name = f"{head}{idx}"
    return sympy.Symbol(name, **kwargs)
| |
| |
# ``sym`` is annotated as sympy.Basic (wider than strictly needed) because
# free_symbols is typed as containing Basic rather than Symbol.
def symbol_is_type(sym: sympy.Basic, prefix: Union[SymT, Sequence[SymT]]) -> bool:
    """Report whether ``sym`` is a symbol of the given kind(s).

    ``prefix`` is either a single SymT or a sequence of them. The result is
    True when the symbol's name starts with the ``prefix_str`` entry of any
    requested kind. The test is a plain string-prefix match, so e.g. an
    "idx0" symbol also counts as INDEX ("i").

    When assertions are enabled, an AssertionError is raised if ``sym`` is
    not a ``sympy.Symbol``.
    """
    assert isinstance(sym, sympy.Symbol)
    # Treat a lone SymT as a one-element group so both forms share one path.
    kinds = (prefix,) if isinstance(prefix, SymT) else prefix
    wanted = tuple(prefix_str[kind] for kind in kinds)
    return sym.name.startswith(wanted)
| |
| |
def free_symbol_is_type(
    e: sympy.Expr, prefix: Union[SymT, Sequence[SymT]]
) -> bool:
    """Return True if any free symbol of ``e`` is of the given kind(s).

    ``prefix`` is forwarded unchanged to ``symbol_is_type`` for every free
    symbol, so it accepts the same forms: a single SymT or a sequence of
    SymT. Returns False when ``e`` has no free symbols.
    """
    # The annotation mirrors symbol_is_type's. It stays Sequence rather than
    # Iterable because ``prefix`` is re-iterated once per free symbol; a
    # one-shot iterator would be exhausted after the first symbol.
    return any(symbol_is_type(v, prefix) for v in e.free_symbols)