blob: 5ad062fc979b4a3af579ed4c85ae4c82a356edd7 [file] [log] [blame]
from __future__ import annotations
import itertools
from contextlib import contextmanager
from itertools import chain
from threading import local
from typing import Any, Callable, TYPE_CHECKING, Union
from unittest.mock import patch
import sympy
from torch._inductor.utils import IndentedBuffer
from torch.fx.graph import inplace_methods, magic_methods
from .utils import reduction_num_outputs, sympy_str, sympy_symbol
if TYPE_CHECKING:
from torch._inductor.graph import GraphLowering
# Per-thread storage shared by every Virtualized slot; each slot keeps its
# current value under its own attribute name (Virtualized._key).
threadlocal = local()


class Virtualized:
    """A global variable that is redirected through thread-local storage.

    Each instance names one slot. ``_set_handler`` installs a value for the
    current thread and returns a context manager that restores the previous
    state on exit. ``_get_handler`` reads the installed value, falling back
    to a freshly built ``default()`` when nothing is installed. Attribute
    access on the instance is forwarded to the current value. This lets
    codegen swap in different op implementations.
    """

    def __init__(self, vname: str, default):
        """
        Args:
            vname: slot name; the thread-local attribute used is
                ``__torchinductor_<vname>``.
            default: zero-argument callable (a class in every use in this
                file) invoked to build a fallback whenever the slot is unset.
        """
        self._key: str = f"__torchinductor_{vname}"
        self._default = default

    def _set_handler(self, value):
        """Install ``value`` for the current thread and return a restorer.

        The value is installed immediately, before the returned context
        manager is entered, so calling this without ``with`` leaves it
        installed. Exiting the returned context manager restores exactly the
        previous state. If the slot was unset before, it becomes unset again.
        """
        key = self._key
        missing = object()
        prior = getattr(threadlocal, key, missing)
        setattr(threadlocal, key, value)

        @contextmanager
        def ctx():
            try:
                yield
            finally:
                if prior is not missing:
                    setattr(threadlocal, key, prior)
                else:
                    # Return the slot to "unset" instead of pinning a
                    # default() instance into it. Otherwise state stored on
                    # the fallback (e.g. NullKernelHandler's sets) would start
                    # persisting across later, unrelated uses of the slot.
                    try:
                        delattr(threadlocal, key)
                    except AttributeError:
                        pass

        return ctx()

    def _get_handler(self):
        """Return the value installed for this thread.

        If nothing is installed, return a new ``default()`` object. A fresh
        object is built on every such call.
        """
        try:
            return getattr(threadlocal, self._key)
        except AttributeError:
            return self._default()

    def __getattr__(self, name):
        # Forward any attribute not defined on Virtualized to the current value.
        return getattr(self._get_handler(), name)
class NullHandler:
    """Fallback value for a Virtualized slot with nothing installed.

    It defines no attributes of its own, so reading an attribute through an
    unset slot that uses this fallback raises AttributeError.
    """
class NullKernelHandler(NullHandler):
    """
    Fallback for ``V.kernel`` while no kernel is in the context, which happens
    when codegening the wrapper. The DeferredLine class still needs to access
    ``V.kernel.removed_buffers`` then. ``removed_buffers`` and
    ``inplaced_to_remove`` are therefore created up front as empty sets. That
    way callers do not need to call ``getattr`` with a default value, which
    would silently hide a typo in the attribute name.
    """

    def __init__(self):
        super().__init__()
        self.removed_buffers, self.inplaced_to_remove = set(), set()
def _arg_str(a) -> str:
    """Render one op argument as text for MockHandler.

    sympy expressions are rendered with ``sympy_str``; anything else uses
    ``str``.
    """
    return sympy_str(a) if isinstance(a, sympy.Expr) else str(a)
class MockHandler:
    """Ops handler that renders each op call as a string of source text.

    Any op not defined on the class is handled by ``__getattr__``, which
    returns a function producing ``"ops.<name>(<args>)"``. ``_init_cls``
    additionally installs one static method per entry of torch.fx's
    ``magic_methods`` and ``inplace_methods`` tables. Each of those formats
    its positional arguments into that entry's format string.
    """

    def __getattr__(self, name):
        if name == "name":
            return "MockHandler"

        def inner(*args, **kwargs):
            fargs = [_arg_str(a) for a in args]
            # Keyword values go through _arg_str as well, so a sympy
            # expression renders the same whether passed positionally or by
            # keyword (previously keywords used plain str()).
            fargs.extend(f"{k}={_arg_str(v)}" for k, v in kwargs.items())
            return f"ops.{name}({', '.join(fargs)})"

        return inner

    @staticmethod
    def masked(mask, body, other) -> str:
        """Render a masked op; ``body`` is called eagerly and its result embedded."""
        return f"ops.masked({mask}, {body()}, {other})"

    @staticmethod
    def indirect_indexing(index_var, size, check=True) -> sympy.Symbol:
        """Return a sympy symbol named after the parenthesized ``index_var``.

        ``size`` and ``check`` are accepted for interface compatibility but
        are not used.
        """
        return sympy_symbol(f"({str(index_var)})")

    @classmethod
    def _init_cls(cls):
        """Install a static formatting method for each operator table entry."""

        def make_handler(format_string):
            # A separate factory call per entry gives each handler its own
            # format_string binding (avoids the late-binding closure pitfall).
            @staticmethod  # type: ignore[misc]
            def inner(*args):
                return format_string.format(*args)

            return inner

        for name, format_string in chain(
            magic_methods.items(), inplace_methods.items()
        ):
            setattr(cls, name, make_handler(format_string))
class KernelFormatterHandler:
    """Ops handler that records each op as a ``tmpN = ...`` line.

    Every op is forwarded to ``parent_handler``. The text it returns is
    written to ``self.output`` as an assignment to a fresh temporary, and the
    temporary's name is handed back so later ops refer to it.
    """

    def __init__(self, parent_handler):
        self.parent_handler = parent_handler
        self.output = IndentedBuffer(1)
        self.var_counter = itertools.count()

    @staticmethod
    def ir_to_string(ir_fn, index, rindex=None) -> str:
        """Render ``ir_fn`` as the source text of a pseudo ``inner_fn``.

        The text starts with a ``def inner_fn(...)`` header. Each non-empty
        index argument is then unpacked into its entries, with integer
        entries shown as ``_``. Finally ``ir_fn`` is traced with a
        KernelFormatterHandler (wrapping a MockHandler) installed as the ops
        handler, with ``FlexibleLayout.allow_indexing`` patched to True.
        ``rindex`` is only passed through when it is not None.
        """
        from .ir import FlexibleLayout

        if rindex is None:
            params = [("index", index)]
        else:
            params = [("index", index), ("rindex", rindex)]
        param_names = [pname for pname, _ in params]
        call_args = [pvalue for _, pvalue in params]

        formatter = KernelFormatterHandler(MockHandler())
        with formatter.output.indent(-1):
            formatter.output.writeline(f"def inner_fn({', '.join(param_names)}):")

        for pname, pvalue in params:
            if not pvalue:
                continue
            targets = [
                "_" if isinstance(entry, (int, sympy.Integer)) else str(entry)
                for entry in pvalue
            ]
            formatter.output.writeline(f"{', '.join(targets)} = {pname}")

        with V.set_ops_handler(formatter), patch.object(
            FlexibleLayout, "allow_indexing", True
        ):
            traced = ir_fn(*call_args)
        return formatter.getvalue(traced)

    def __getattr__(self, name) -> Callable[..., str]:
        def inner(*args, **kwargs):
            rendered = getattr(self.parent_handler, name)(*args, **kwargs)
            if name == "indirect_indexing":
                # Index expressions are passed back untouched (MockHandler
                # returns a sympy symbol here) instead of being bound to a temp.
                return rendered
            temp_name = f"tmp{next(self.var_counter)}"
            self.output.writeline(f"{temp_name} = {rendered}")
            return temp_name

        return inner

    def reduction(
        self, dtype, src_dtype, reduction_type, value
    ) -> Union[tuple[str, ...], str]:
        """Bind a reduction's output(s) to fresh temporaries.

        Returns a tuple of names when the reduction has several outputs,
        otherwise the single name.
        """
        rendered = self.parent_handler.reduction(
            dtype, src_dtype, reduction_type, value
        )
        output_count = reduction_num_outputs(reduction_type)
        temp_names = [f"tmp{next(self.var_counter)}" for _ in range(output_count)]
        self.output.writeline(f"{','.join(temp_names)} = {rendered}")
        if output_count > 1:
            return tuple(temp_names)
        return temp_names[0]

    def getvalue(self, result):
        """Append ``return <result>`` and return the accumulated text."""
        self.output.writeline(f"return {result}")
        return self.output.getvalue()
class WrapperHandler:
    """Handler that forwards every attribute it does not define to ``inner``.

    ``__getattr__`` only runs for names that ordinary lookup did not find.
    Methods defined on a subclass therefore take precedence over the wrapped
    handler.
    """

    def __init__(self, inner):
        self._inner = inner

    def __getattr__(self, item):
        if item == "_inner":
            # Only reached when _inner was never assigned, e.g. on an instance
            # created via __new__ by copy/pickle. Without this guard,
            # self._inner below would re-enter __getattr__ until
            # RecursionError instead of raising AttributeError.
            raise AttributeError(item)
        return getattr(self._inner, item)
# Install MockHandler's operator-formatting static methods once, at import.
MockHandler._init_cls()

# Thread-local slots read through ``V`` and ``ops``. The second argument is
# the fallback factory used while nothing is installed for the current thread.
_ops = Virtualized("ops", MockHandler)
_kernel = Virtualized("kernel", NullKernelHandler)

# Slots whose fallback is the attribute-less NullHandler.
_graph = Virtualized("graph", NullHandler)
_real_inputs = Virtualized("real_inputs", NullHandler)
_fake_mode = Virtualized("fake_mode", NullHandler)
_debug = Virtualized("debug", NullHandler)
_interpreter = Virtualized("interpreter", NullHandler)
_aot_compilation = Virtualized("aot_compilation", NullHandler)
_current_node = Virtualized("current_node", NullHandler)
class OpsValue:
    """The return type of most ops calls.

    This exists so we can overload magic methods, and write mathematical
    expressions much more fluently. So instead of
        ops.add(ops.mul(ops.mul(ops.sub(ops.mul(_Ap2, x), _Ap3), x), x), _1)
    we can write
        (_Ap2 * x - _Ap3) * x * x + _1

    Reflected operators are defined as well, so a plain value may appear on
    the left (e.g. ``1 - x``). Operand order is preserved when calling the
    corresponding op.
    """

    value: Any

    def __init__(self, value):
        self.value = value

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return f"OpsValue({self.value!r})"

    def __add__(self, other):
        return ops.add(self, other)

    def __radd__(self, other):
        return ops.add(other, self)

    def __mul__(self, other):
        return ops.mul(self, other)

    def __rmul__(self, other):
        return ops.mul(other, self)

    def __sub__(self, other):
        return ops.sub(self, other)

    def __rsub__(self, other):
        return ops.sub(other, self)

    def __neg__(self):
        return ops.neg(self)

    def __truediv__(self, other):
        return ops.truediv(self, other)

    def __rtruediv__(self, other):
        return ops.truediv(other, self)

    def __floordiv__(self, other):
        return ops.floordiv(self, other)

    def __rfloordiv__(self, other):
        return ops.floordiv(other, self)

    def __mod__(self, other):
        return ops.mod(self, other)

    def __rmod__(self, other):
        return ops.mod(other, self)

    def __pow__(self, other):
        return ops.pow(self, other)

    def __rpow__(self, other):
        return ops.pow(other, self)
class OpsWrapper:
    """Front end for the current ops handler that traffics in OpsValue.

    Arguments are unwrapped from OpsValue (recursively through lists and
    tuples) before reaching the handler. Results come back wrapped in
    OpsValue, so the magic methods can be used to write mathematical
    expressions fluently.
    """

    def __getattr__(self, name):
        def inner(*args, **kwargs):
            unwrap = OpsWrapper._unwrap
            raw_args = [unwrap(arg) for arg in args]
            raw_kwargs = {key: unwrap(val) for key, val in kwargs.items()}
            handler_fn = getattr(_ops, name)
            return OpsWrapper._wrap(handler_fn(*raw_args, **raw_kwargs))

        return inner

    @staticmethod
    def _unwrap(x):
        """Strip OpsValue wrappers; lists/tuples become tuples of unwrapped items."""
        if isinstance(x, OpsValue):
            return x.value
        if isinstance(x, (list, tuple)):
            return tuple(OpsWrapper._unwrap(item) for item in x)
        return x

    @staticmethod
    def _wrap(x):
        """Wrap a result in OpsValue; a list/tuple becomes a tuple of OpsValues."""
        if not isinstance(x, (list, tuple)):
            return OpsValue(x)
        return tuple(map(OpsValue, x))

    @staticmethod
    def indirect_indexing(index, size, check=True):
        # The handler produces a sympy value here rather than an IR value, so
        # the result is returned as-is instead of being wrapped in OpsValue.
        return _ops.indirect_indexing(OpsWrapper._unwrap(index), size, check)
# Presumably aliased so _V.ops can name the handler type without it
# resolving to _V's own ``MockHandler`` class attribute.
_MockHandler = MockHandler

# Module-level entry point for emitting ops; results are OpsValue instances.
ops = OpsWrapper()
class _V:
    """Namespace behind the module-level ``V``.

    Class attributes expose the handler classes and the ``set_*``/``get_*``
    accessors of each thread-local slot. Each ``set_*`` installs a value
    immediately and returns a context manager that restores the previous
    value. The properties read each slot's current value for the calling
    thread.
    """

    MockHandler = MockHandler
    KernelFormatterHandler = KernelFormatterHandler
    WrapperHandler = WrapperHandler

    set_ops_handler: Callable[[Any], Any] = _ops._set_handler
    get_ops_handler: Callable[[], Any] = _ops._get_handler
    set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
    set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
    get_real_inputs: Callable[[], Any] = _real_inputs._get_handler
    set_fake_mode: Callable[[Any], Any] = _fake_mode._set_handler
    get_fake_mode: Callable[[], Any] = _fake_mode._get_handler
    set_kernel_handler: Callable[[Any], Any] = _kernel._set_handler
    set_debug_handler: Callable[[Any], Any] = _debug._set_handler
    set_interpreter_handler: Callable[[Any], Any] = _interpreter._set_handler
    set_aot_compilation: Callable[[Any], Any] = _aot_compilation._set_handler
    get_aot_compilation: Callable[[], Any] = _aot_compilation._get_handler
    set_current_node: Callable[[Any], Any] = _current_node._set_handler
    get_current_node: Callable[[], Any] = _current_node._get_handler

    @property
    def ops(self) -> _MockHandler:
        """The operator handler specific to the current codegen task"""
        return _ops._get_handler()

    @property
    def graph(self) -> GraphLowering:
        """The graph currently being generated"""
        return _graph._get_handler()

    @property
    def real_inputs(self):
        """non-fake example inputs"""
        return _real_inputs._get_handler()

    @property
    def fake_mode(self):
        """The fake mode installed via ``set_fake_mode`` (a NullHandler if none)"""
        return _fake_mode._get_handler()

    @property
    def kernel(self):
        """The kernel currently being generated"""
        return _kernel._get_handler()

    @property
    def debug(self):
        """The debug handler installed via ``set_debug_handler``"""
        return _debug._get_handler()

    @property
    def interpreter(self):
        """The interpreter installed via ``set_interpreter_handler``"""
        return _interpreter._get_handler()

    @property
    def aot_compilation(self):
        """The value installed via ``set_aot_compilation``"""
        return _aot_compilation._get_handler()

    @property
    def current_node(self):
        """The node installed via ``set_current_node``"""
        return _current_node._get_handler()


# Single shared instance through which the rest of inductor reads the slots.
V = _V()