| # mypy: allow-untyped-defs |
| import abc |
| import contextlib |
| import ctypes |
| import importlib |
| import inspect |
| import sys |
| import types |
| from typing import Any, Callable, Dict, List, Set, Type, Union |
| |
| import torch |
| import torch.utils._pytree as pytree |
| from torch import _utils_internal |
| from torch._C import _dispatch_is_included_in_alias as is_included_in_alias, DispatchKey |
| from torch._functorch.pyfunctorch import dispatch_functorch |
| from torch.utils._python_dispatch import TorchDispatchMode |
| |
| |
| # Query `hasattr` only once. |
| _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags") |
| |
| |
| @contextlib.contextmanager |
| def dl_open_guard(): |
| """ |
| Context manager to set the RTLD_GLOBAL dynamic linker flag while we open a |
| shared library to load custom operators. |
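
    Example (an illustrative sketch; the library path is hypothetical)::

        with dl_open_guard():
            ctypes.CDLL("/path/to/libcustom_ops.so")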
| """ |
| if not _SET_GLOBAL_FLAGS: |
| yield |
| return |
| old_flags = sys.getdlopenflags() |
| sys.setdlopenflags(old_flags | ctypes.RTLD_GLOBAL) |
| try: |
| yield |
| finally: |
| sys.setdlopenflags(old_flags) |
| |
| |
| class OperatorBase: |
| """ |
| Base class for OpOverload (which represents C++ ATen operators) and HigherOrderOperator |
| (which represents Python-only operators that are unrepresentable in TorchScript). |
| """ |
| |
| def __init__(self): |
| # The dispatch cache precomputes a mapping of dispatch key that the |
| # dispatcher wants to dispatch to, to an actual implementation of the |
| # dispatch key. Confusingly, the actual implementation could *also* be a |
| # dispatch key, but in this case, this refers to the C++ kernel that |
| # was registered to some dispatch key. Aliases are permitted in the |
        # latter but not the former; for example, you might look up the
| # entry for AutogradCPU, and this maps you to the Autograd key for |
| # the generic autograd kernel that works for all devices. Since this |
| # is the Python dispatcher, you can also put an arbitrary Python |
| # callable to call instead. This handler gets precisely the |
| # args/kwargs that the operator was __call__'ed with. |
| # NB: This name is hard-coded in torch/csrc/autograd/python_variable.cpp |
| # for use with OpOverload; cache lookup is done entirely from C++ |
| # for speed. |
| # TODO: The cache is NOT currently used by HigherOrderOperator, but it should! |
| self._dispatch_cache: Dict[ |
| DispatchKey, Union[DispatchKey, Callable[..., Any]] |
| ] = {} |
| |
| # This table allows you to override the behavior of a particular |
| # dispatch key to call a custom Python function, rather than the |
| # ordinary C++ configured behavior. This is the raison d'etre of |
| # Python dispatcher: to let you program the dispatcher from Python |
| # in case you need something unusual, and don't want to clobber |
| # the existing registrations using the Python operator registration |
| # API. |
| self.py_kernels: Dict[DispatchKey, Callable[..., Any]] = {} |
| |
| # This table allows you to override the behavior of a particular |
| # operator for a particular TorchDispatchMode. In practice, |
| # we are using this mostly for ProxyTensorMode. Modes can be |
| # thought of as an open world extension of dispatch keys, so it |
| # makes sense that you should be able to register them, the same |
| # way you can register dispatch keys. |
| self.python_key_table: Dict[ |
| Union[Type[TorchDispatchMode], Type[torch.Tensor]], Callable[..., Any] |
| ] = {} |
| |
| # This table allows you to override the behavior of functorch |
| # transformations. NB: this currently only does something for |
| # HigherOrderOperator |
| self.functorch_table = {} |
| |
| def __call__(self, *args, **kwargs): |
| raise NotImplementedError |
| |
| def has_kernel_for_dispatch_key(self, k): |
| return k in self.py_kernels |
| |
| def has_kernel_for_any_dispatch_key(self, ks): |
| for k in self.py_kernels: |
| if not torch._C._dispatch_is_alias_key(k) and ks.has(k): |
| return True |
| return False |
| |
| def py_impl(self, k): |
| def inner(fn): |
| if inspect.isclass(k) and ( |
| issubclass(k, TorchDispatchMode) or issubclass(k, torch.Tensor) |
| ): |
| assert k not in self.python_key_table |
| # TODO(voz): Should we replace setting DispatchKey.Python entirely with setting mode keys? |
| self.python_key_table[k] = fn |
| self._dispatch_cache.clear() |
| return fn |
| |
| if isinstance(k, torch._C._functorch.TransformType): |
| assert k not in self.functorch_table |
| self.functorch_table[k] = fn |
| return fn |
| |
| assert isinstance(k, DispatchKey) |
| assert ( |
| k != DispatchKey.Python |
| ), "Please register a mode for the torch._C.DispatchKey.Python key instead." |
| |
| if k in self.py_kernels: |
| raise RuntimeError( |
| f"Trying to override a python impl for {k} on operator {self.name()}" |
| ) |
| self.py_kernels[k] = fn |
| self._dispatch_cache.clear() |
| return fn |
| |
| return inner |
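
    # Example: overriding how this operator handles a particular DispatchKey or
    # TorchDispatchMode from Python (an illustrative sketch; ``my_op`` and
    # ``MyMode`` are hypothetical and not defined in this module):
    #
    #     @my_op.py_impl(DispatchKey.CPU)
    #     def cpu_kernel(*args, **kwargs):
    #         ...  # called with exactly the args/kwargs my_op was called with
    #
    #     @my_op.py_impl(MyMode)  # MyMode subclasses TorchDispatchMode
    #     def mode_rule(mode, *args, **kwargs):
    #         ...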
| |
| # Registers an implementation to all **3** variants of functionalization that we have: |
| # - DispatchKey.Functionalize |
| # - functorch.TransformType.Functionalize |
| # - FunctionalTensorMode |
| # Example: |
| # @py_functionalize_impl |
| # def functionalize_rule(ctx, inner_f, *args): |
| # args_unwrapped = ctx.unwrap_tensors(args) |
| # with ctx.redispatch_to_next(): |
| # out = ctx.functionalize(inner_f)(*args_unwrapped) |
| # return ctx.wrap_tensors(out) |
| def py_functionalize_impl(self, fn): |
| from torch._subclasses.functional_tensor import ( |
| CppFunctionalizeAPI as _CppFunctionalizeAPI, |
| FunctorchFunctionalizeAPI as _FunctorchFunctionalizeAPI, |
| PythonFunctionalizeAPI as _PythonFunctionalizeAPI, |
| ) |
| |
| # Construct our three flavors of functionalization, |
| # each of which have slightly different wrap/unwrap/redispatch policies |
| def functionalize_dk_fn(*args, **kwargs): |
| return fn(_CppFunctionalizeAPI(), *args, **kwargs) |
| |
| def functionalize_dispatch_mode_fn(mode, *args, **kwargs): |
| return fn(_PythonFunctionalizeAPI(mode), *args, **kwargs) |
| |
| def functionalize_functorch_fn(interpreter, *args, **kwargs): |
| return fn(_FunctorchFunctionalizeAPI(interpreter), *args, **kwargs) |
| |
| self.py_impl(DispatchKey.Functionalize)(functionalize_dk_fn) |
| self.py_impl(torch._subclasses.functional_tensor.FunctionalTensorMode)( |
| functionalize_dispatch_mode_fn |
| ) |
| self.py_impl(torch._C._functorch.TransformType.Functionalize)( |
| functionalize_functorch_fn |
| ) |
| |
| return fn |
| |
| def name(self): |
| raise NotImplementedError |
| |
| |
| # Equivalent to computeDispatchTableEntryWithDebug |
| def resolve_key(op: OperatorBase, k: DispatchKey): # type: ignore[valid-type] |
| # 1. (Direct) operator registration |
| if op.has_kernel_for_dispatch_key(k): |
| return k |
| # 2.1 Use CompositeExplicitAutogradNonFunctional kernel if available |
| cand = DispatchKey.CompositeExplicitAutogradNonFunctional |
| if ( |
| k == DispatchKey.Undefined or is_included_in_alias(k, cand) |
| ) and op.has_kernel_for_dispatch_key(cand): |
| return cand |
| # 2.2 Use CompositeExplicitAutograd kernel if available |
| cand = DispatchKey.CompositeExplicitAutograd |
| if ( |
| k == DispatchKey.Undefined or is_included_in_alias(k, cand) |
| ) and op.has_kernel_for_dispatch_key(cand): |
| return cand |
| has_backend_kernel = op.has_kernel_for_any_dispatch_key( |
| torch._C._dispatch_get_backend_keyset_from_autograd(k) |
| ) or op.has_kernel_for_dispatch_key(DispatchKey.CompositeExplicitAutograd) |
| # 2.3. Use CompositeImplicitAutograd kernel if available |
| cand = DispatchKey.CompositeImplicitAutogradNestedTensor |
| if ( |
| (k != DispatchKey.Undefined and is_included_in_alias(k, cand)) |
| and op.has_kernel_for_dispatch_key(cand) |
| and not has_backend_kernel |
| ): |
| return cand |
| cand = DispatchKey.CompositeImplicitAutograd |
| if ( |
| k == DispatchKey.Undefined or is_included_in_alias(k, cand) |
| ) and op.has_kernel_for_dispatch_key(cand): |
| if k == DispatchKey.AutogradOther and op.has_kernel_for_any_dispatch_key( |
| torch._C._dispatch_autogradother_backends |
| ): |
| raise RuntimeError("ambiguous autogradother kernel") |
| elif not has_backend_kernel: |
| return cand |
| # 2.4. For autograd backend keys, use kernel from DispatchKey::Autograd if available |
| cand = DispatchKey.Autograd |
| if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand): |
| return cand |
| # 2.5 Use kernel from DispatchKey::FuncTorchBatchedDecomposition if available |
| cand = DispatchKey.FuncTorchBatchedDecomposition |
| if is_included_in_alias(k, cand) and op.has_kernel_for_dispatch_key(cand): |
| return cand |
| # Backend fallback |
| if torch._C._dispatch_has_backend_fallback(k): |
| # The dispatch key itself will implicitly route to backend fallback. |
| # This is probably not great for the pure Python implementation. |
| return k |
| raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}") |
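

# An illustrative sketch of how resolve_key behaves (``op`` here is a
# hypothetical OperatorBase whose only registration is a
# CompositeImplicitAutograd kernel): the earlier steps fall through and the
# composite decomposition is chosen.
#
#     resolve_key(op, DispatchKey.AutogradCPU)
#     # -> DispatchKey.CompositeImplicitAutograd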
| |
| |
| _higher_order_ops: Dict[str, "HigherOrderOperator"] = {} |
| |
| _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS = [ |
| DispatchKey.PythonDispatcher, # type: ignore[attr-defined] |
| DispatchKey.PythonTLSSnapshot, # type: ignore[attr-defined] |
| DispatchKey.ADInplaceOrView, |
| DispatchKey.BackendSelect, |
| DispatchKey.AutocastCPU, # type: ignore[attr-defined] |
| DispatchKey.AutocastCUDA, # type: ignore[attr-defined] |
| ] |
| |
| |
| class HigherOrderOperator(OperatorBase, abc.ABC): |
| # The HigherOrderOperator will appear as torch.ops.higher_order.{name} |
| # |
| # If you're creating a new HigherOrderOperator, please do not change the |
| # default. Adding operators to the global torch.ops namespace is a bad |
| # practice due to name collisions. |
| def __init__(self, name): |
| super().__init__() |
| if type(self) is HigherOrderOperator: |
| raise RuntimeError( |
| "Direct instantiation of HigherOrderOperator is not allowed. Please subclass it." |
| ) |
| self._name = name |
| |
        # Make _OpNamespace not scream; this whole name-based association needs a good hard look.
| self.__name__ = name |
| _higher_order_ops[name] = self |
| self._ns = "higher_order" |
| self.__module__ = "torch.ops.higher_order" |
| |
| self.non_fallthrough_keys = torch._C._dispatch_keyset_full() |
| |
| for dispatch_key in _HIGHER_ORDER_OP_DEFAULT_FALLTHROUGH_DISPATCH_KEYS: |
| self.fallthrough(dispatch_key) |
| |
        # [NOTE] We have to register a pre-dispatch key implementation
        # because sometimes HOPs use aot-dispatch tracing to detect certain
        # mutations. This is problematic when we are functionalizing a HOP
        # during pre-dispatch, because when the inner tracer starts it will see
        # that the PreDispatch key is still active. In that case, we just redispatch
        # to the next key. This is only safe to do when the PreDispatch mode stack has no
        # active modes.
| |
| def py_impl(self, k): |
| if isinstance(k, DispatchKey) and not self.non_fallthrough_keys.has(k): |
| self.non_fallthrough_keys = self.non_fallthrough_keys.add(k) |
| return super().py_impl(k) |
| |
| @property |
| def namespace(self): |
| return self._ns |
| |
| def fallthrough(self, dispatch_key): |
| self.non_fallthrough_keys = self.non_fallthrough_keys.remove(dispatch_key) |
| |
    # Use a positional-only argument to avoid naming collisions with custom op arguments
    # that are named "self".
| def dispatch(self, /, dispatch_key, *args, **kwargs): |
| from torch.utils._python_dispatch import _get_current_dispatch_mode |
| |
| if dispatch_key in self._dispatch_cache: |
| kernel = self._dispatch_cache[dispatch_key] |
| assert not isinstance(kernel, DispatchKey) |
| return kernel(*args, **kwargs) |
| |
| if dispatch_key == DispatchKey.FuncTorchDynamicLayerFrontMode: |
| return dispatch_functorch(self, args, kwargs) |
| |
| if dispatch_key == DispatchKey.Python: |
| # Keep the following 1:1 with handle_torch_function_no_python_arg_parser |
| # in torch/csrc/utils/python_arg_parser.cpp |
| |
| overloaded_args_list = [] |
| |
| def has_python_key(tensor): |
| return torch._C._dispatch_keys(tensor).has("Python") |
| |
| def check_overloaded(arg): |
| if isinstance(arg, torch.Tensor) and has_python_key(arg): |
| overloaded_args_list.append(arg) |
| |
| for arg in (*args, *kwargs.values()): |
| check_overloaded(arg) |
| if isinstance(arg, (list, tuple)): |
| for a in arg: |
| check_overloaded(a) |
| |
| overloaded_args = tuple(overloaded_args_list) |
| overloaded_types = tuple(type(arg) for arg in overloaded_args) |
| |
| # Step 1: dispatch on any user TorchDispatchModes |
| from torch.utils._python_dispatch import _pop_mode_temporarily |
| |
| curr_mode = _get_current_dispatch_mode() |
| if curr_mode is not None: |
| if type(curr_mode) in self.python_key_table: |
| handler = self.python_key_table[type(curr_mode)] |
| with _pop_mode_temporarily() as mode: |
| # "natural" calling convention: (mode, *args, **kwargs) |
| # TODO(rzou): we should support torch_dispatch calling convention too. |
| result = handler(mode, *args, **kwargs) |
| else: |
| raise NotImplementedError( |
| f"There was no rule registered for HOP {self._name} and mode {curr_mode}. " |
| f"We recommend filing an issue." |
| ) |
| if result is not NotImplemented: |
| return result |
| |
| # Step 2: dispatch on any subclasses |
| for arg in overloaded_args: |
| subclass_type = type(arg) |
| if ( |
| subclass_type.__torch_dispatch__ |
| == torch._C._disabled_torch_dispatch_impl |
| ): |
| continue |
| if subclass_type in self.python_key_table: |
| handler = self.python_key_table[subclass_type] |
| # "natural" calling convention: (*args, **kwargs) |
| # TODO(rzou): we should support torch_dispatch calling convention too. |
| result = handler(*args, **kwargs) |
| else: |
| raise NotImplementedError( |
| f"There was no rule registered for HOP {self._name} and subclass {subclass_type}. " |
| f"We recommend filing an issue." |
| ) |
| if result is not NotImplemented: |
| return result |
| |
| # All handlers returned NotImplemented |
            raise TypeError(
                f"Multiple dispatch failed for {self._name}. There was no registered handler "
                f"that did not return NotImplemented. Use HOP.py_impl to register some. "
                f"Tried mode: {curr_mode} and subclasses: "
| f"{[type(a) for a in overloaded_args]}" |
| ) |
| |
| functionality_key = torch._C._to_functionality_key(dispatch_key) # type: ignore[attr-defined] |
| if functionality_key == DispatchKey.PreDispatch: |
| from torch.utils._python_dispatch import _pop_mode_temporarily |
| |
| # The check for Python in the exclude set is so we properly respect `with no_dispatch()` |
| # calls inside of a mode. |
| if ( |
| _len_torch_dispatch_stack_pre_dispatch() > 0 |
| ) and not torch._C._dispatch_tls_is_dispatch_key_excluded( |
| DispatchKey.Python |
| ): |
| curr_mode = _get_current_dispatch_mode_pre_dispatch() |
| assert ( |
| curr_mode is not None |
| ), "Illegal invocation of dispatch on torch._C.DispatchKey.PreDispatch without a mode." |
| assert ( |
| type(curr_mode) in self.python_key_table |
| ), f"Current active mode {curr_mode} not registered" |
| handler = self.python_key_table[type(curr_mode)] |
| with _pop_mode_temporarily(functionality_key) as mode: |
| return handler(mode, *args, **kwargs) |
| |
| final_key = resolve_key(self, dispatch_key) |
| |
        # This can currently fail due to backend fallbacks. You just have to
        # register them by hand for HigherOrderOperator.
| if final_key not in self.py_kernels: |
| raise NotImplementedError( |
| f"could not find kernel for HigherOrderOperator {self._name} " |
| f"at dispatch key {final_key} (resolved from {dispatch_key})" |
| ) |
| |
        # [NOTE] We shouldn't cache the PreDispatch kernel here because the
        # pre-dispatch behavior differs depending on which modes are active.
        # We do the same thing for normal ops:
| # See Note [Not Caching Per-Dispatch-Key Mode Handlers] |
| if dispatch_key != DispatchKey.PreDispatch: |
| self._dispatch_cache[dispatch_key] = self.py_kernels[final_key] |
| kernel = self.py_kernels[final_key] |
| # It's illegal to register DispatchKey to py_kernels, since there's no |
| # C++ kernel to call into |
| assert not isinstance(kernel, DispatchKey) |
| return kernel(*args, **kwargs) |
| |
| @abc.abstractmethod |
| def __call__(self, /, *args, **kwargs): |
        # Dynamo already traces the body of a HigherOrderOperator beforehand when it
        # encounters one, so there is no need to trace into it here.
| from torch._dynamo import disable |
| |
| @disable |
| def wrapper(): |
| flat_args = _to_flat_tuple(args, kwargs) |
| if torch.overrides.has_torch_function(flat_args): |
| return torch.overrides.handle_torch_function( |
| self, flat_args, *args, **kwargs |
| ) |
| |
| dispatch_key_set = _compute_keyset(args, kwargs, self.non_fallthrough_keys) |
| return self.dispatch( |
| dispatch_key_set.highestPriorityTypeId(), *args, **kwargs |
| ) |
| |
| return wrapper() |
| |
| def __str__(self): |
| return f"{self.name()}" |
| |
| def name(self): |
| return self._name |
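

# Example: defining a new HigherOrderOperator (an illustrative sketch; ``MyHop``,
# ``my_hop``, and the kernel below are hypothetical and not part of this module):
#
#     class MyHop(HigherOrderOperator):
#         def __init__(self):
#             super().__init__("my_hop")
#
#         def __call__(self, fn, *args, **kwargs):
#             return super().__call__(fn, *args, **kwargs)
#
#     my_hop = MyHop()  # now reachable as torch.ops.higher_order.my_hop
#
#     @my_hop.py_impl(DispatchKey.CompositeExplicitAutograd)
#     def my_hop_dense(fn, *args, **kwargs):
#         return fn(*args, **kwargs)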
| |
| |
| def _to_flat_tuple(args, kwargs): |
| return pytree.arg_tree_leaves(*args, **kwargs) |
| |
| |
| def _compute_keyset(args, kwargs, non_fallthrough_keys): |
| tensors = _get_tensors(args, kwargs) |
| return key_extractor(tensors, non_fallthrough_keys) |
| |
| |
| def _get_tensors(args, kwargs): |
| flat_all = _to_flat_tuple(args, kwargs) |
| tensor_args = [t for t in flat_all if isinstance(t, torch.Tensor)] |
| return tuple(tensor_args) |
| |
| |
| # Note - this should maintain identical impl to the C++ dispatcher key extraction logic |
| # at ATen/core/dispatch/DispatchKeyExtractor.h |
| def key_extractor(tensors, key_mask): |
| key_set = torch._C._dispatch_tls_local_include_set() |
| for tensor in tensors: |
| key_set = key_set | torch._C._dispatch_keys(tensor) |
| key_set = key_set - torch._C._dispatch_tls_local_exclude_set() |
| key_set = key_set & key_mask |
| return key_set |
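

# An illustrative sketch: for a plain tensor argument, the extracted key set is
# the tensor's keys (plus the TLS include set, minus the TLS exclude set)
# intersected with the operator's non-fallthrough mask:
#
#     ks = key_extractor((torch.randn(3),), torch._C._dispatch_keyset_full())
#     ks.highestPriorityTypeId()  # the key the dispatcher would try first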
| |
| |
# Mode stack for PreDispatchKey.
# It should always have two keys, with
# priority given to FunctionalTensorMode and
# then ProxyTorchDispatchMode. It means that
# slot 0 belongs to ProxyTorchDispatchMode and
# slot 1 belongs to FunctionalTensorMode.
#
# SchemaCheckMode is separate from the other 2,
# and is only valid when the stack is empty.
# SchemaCheckMode is for testing purposes, and
# is meant to run in eager mode on concrete inputs,
# checking for incorrect schemas with regard to
# aliasing or mutating ops.
| class _ModeStackStateForPreDispatch: |
| def __init__(self): |
| self.__infra_modes = [None, None] |
| self._schema_check_mode = None |
| |
| def set(self, index, mode): |
| assert index < len(self.__infra_modes) |
| self.__infra_modes[index] = mode |
| |
| def get(self, index): |
| assert index < len(self.__infra_modes) |
| return self.__infra_modes[index] |
| |
| def count(self): |
| return len([i for i in self.__infra_modes if i is not None]) + int( |
| self._schema_check_mode is not None |
| ) |
| |
| |
| _mode_stack_state_for_pre_dispatch = _ModeStackStateForPreDispatch() |
| |
| |
| def unset_mode_pre_dispatch(mode_key, schema_check=False): |
| current_mode_stack_pre_dispatch = mode_stack_state_for_pre_dispatch() |
| assert mode_key is None or mode_key in ( |
| torch._C._TorchDispatchModeKey.PROXY, |
| torch._C._TorchDispatchModeKey.FUNCTIONAL, |
| ) |
| if schema_check: |
| assert mode_key is None |
| |
| def _unset_mode(): |
| if mode_key == torch._C._TorchDispatchModeKey.PROXY: |
| current_mode = current_mode_stack_pre_dispatch.get(0) |
| mode_stack_state_for_pre_dispatch().set(0, None) |
| return current_mode |
| elif mode_key == torch._C._TorchDispatchModeKey.FUNCTIONAL: |
| current_mode = current_mode_stack_pre_dispatch.get(1) |
| mode_stack_state_for_pre_dispatch().set(1, None) |
| return current_mode |
| else: |
| current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode |
| mode_stack_state_for_pre_dispatch()._schema_check_mode = None |
| return current_mode |
| |
| current_mode = _unset_mode() |
| |
| new_pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch() |
    # When we are unsetting a mode, we need to check if there is any
    # active mode left on the PreDispatch key. If there is nothing
    # active, we need to remove the PreDispatch key from the local dispatch
    # include set.
| if new_pre_dispatch_len == 0: |
| torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, False) |
| |
| return current_mode |
| |
| |
| def _set_mode_pre_dispatch(mode): |
| from torch._subclasses.functional_tensor import FunctionalTensorMode |
| from torch._subclasses.schema_check_mode import SchemaCheckMode |
| from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode |
| |
| assert isinstance( |
| mode, |
| ( |
| FunctionalTensorMode, |
| ProxyTorchDispatchMode, |
| SchemaCheckMode, |
| ), |
| ) |
| |
| previous_mode_stack_len = _len_torch_dispatch_stack_pre_dispatch() |
| if isinstance(mode, SchemaCheckMode): |
| current_mode = mode_stack_state_for_pre_dispatch()._schema_check_mode |
| if previous_mode_stack_len > 0: |
| raise AssertionError( |
| "SchemaCheckMode for pre-dispatch must be used exclusively, found other modes on the stack" |
| ) |
| mode_stack_state_for_pre_dispatch()._schema_check_mode = mode |
| elif isinstance(mode, FunctionalTensorMode): |
| current_mode = mode_stack_state_for_pre_dispatch().get(1) |
| assert current_mode is None |
| mode_stack_state_for_pre_dispatch().set(1, mode) |
| else: |
| current_mode = mode_stack_state_for_pre_dispatch().get(0) |
| assert current_mode is None |
| mode_stack_state_for_pre_dispatch().set(0, mode) |
| |
    # When we are setting a mode, we need to check if there was any
    # active mode on the PreDispatch key. If there was nothing
    # active before setting this mode, it means that the PreDispatch key
    # was turned off, so we need to turn it on again.
| if previous_mode_stack_len == 0: |
| torch._C._dispatch_tls_set_dispatch_key_included(DispatchKey.PreDispatch, True) |
| |
| |
| def _pop_mode_from_pre_dispatch(): |
| mode_stack = mode_stack_state_for_pre_dispatch() |
| pre_dispatch_len = _len_torch_dispatch_stack_pre_dispatch() |
| |
| if pre_dispatch_len == 0: |
| raise AssertionError("Trying to pop empty mode stack") |
| |
| if mode_stack._schema_check_mode is not None: |
| return unset_mode_pre_dispatch(None, schema_check=True) |
| if mode_stack.get(1) is not None: |
| return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.FUNCTIONAL) |
| if mode_stack.get(0) is not None: |
| return unset_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY) |
| |
| |
| def _len_torch_dispatch_stack_pre_dispatch(): |
| return mode_stack_state_for_pre_dispatch().count() |
| |
| |
| def _get_dispatch_mode_pre_dispatch(mode_key): |
| assert mode_key in ( |
| torch._C._TorchDispatchModeKey.PROXY, |
| torch._C._TorchDispatchModeKey.FUNCTIONAL, |
| ) |
| if mode_key == torch._C._TorchDispatchModeKey.PROXY: |
| return mode_stack_state_for_pre_dispatch().get(0) |
| else: |
| return mode_stack_state_for_pre_dispatch().get(1) |
| |
| |
| def _get_current_dispatch_mode_pre_dispatch(): |
| if mode_stack_state_for_pre_dispatch()._schema_check_mode is not None: |
| return mode_stack_state_for_pre_dispatch()._schema_check_mode |
| else: |
| stack_len = mode_stack_state_for_pre_dispatch().count() |
| if stack_len == 2: |
| return mode_stack_state_for_pre_dispatch().get(1) |
| if stack_len == 1: |
| return ( |
| mode_stack_state_for_pre_dispatch().get(1) |
| if mode_stack_state_for_pre_dispatch().get(1) is not None |
| else mode_stack_state_for_pre_dispatch().get(0) |
| ) |
| return None |
| |
| |
| def mode_stack_state_for_pre_dispatch(): |
| global _mode_stack_state_for_pre_dispatch |
| return _mode_stack_state_for_pre_dispatch |
| |
| |
| cached_ops: Set["OpOverload"] = set() |
| |
| |
| def add_cached_op(op_overload): |
| global cached_ops |
| cached_ops.add(op_overload) |
| |
| |
| def reset_cached_ops(): |
| global cached_ops |
| cached_ops.clear() |
| |
| |
| def get_cached_ops(): |
| global cached_ops |
| return cached_ops |
| |
| |
# Each OpOverload object contains a pointer to a specific operator overload and a pointer to its parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on an OpOverloadPacket.
| class OpOverload(OperatorBase): |
| def __init__(self, overloadpacket, op, op_dk, schema, tags): |
| super().__init__() |
| self._op = op |
| self._op_dk = op_dk |
| self._schema = schema |
| self._overloadpacket = overloadpacket |
| self._tags = tags |
| self._overloadname = ( |
| "default" if schema.overload_name == "" else schema.overload_name |
| ) |
| self._name = self._schema.name |
| if schema.overload_name: |
| self._name += "." + schema.overload_name |
| self.__name__ = f"{self._schema.name.split('::')[1]}.{self._overloadname}" |
| self.__module__ = overloadpacket.__module__ |
| op.__module__ = overloadpacket.__module__ |
| self.__qualname__ = self._name |
| self.__annotations__ = {} |
| # Only compute the OperatorHandle when we need it. Not all OpOverloads have |
| # OperatorHandles (the TorchScript ones don't...) |
| self._lazy_handle = None |
| |
| # If the OpOverload was constructed from a Library.def in Python. |
| self._defined_in_python = self.__qualname__ in torch.library._defs |
| |
| # Logic replicated from aten/src/ATen/native/MathBitsFallback.h |
| is_write = None |
| for a in self._schema.arguments: |
| if a.alias_info is None: |
| continue |
| if is_write is None: |
| is_write = a.alias_info.is_write |
| else: |
| # We will conservatively call mixed mutable/non-mutable |
| # aliased inputs as NOT a view |
| is_write = a.alias_info.is_write or is_write |
| self.is_view = is_write is not None and not is_write |
| |
| @property |
| def _namespace(self): |
| return self._schema.name.split("::")[0] |
| |
| @property |
| def _opname(self): |
| return self._schema.name.split("::")[1] |
| |
| @property |
| def _handle(self): |
| if self._lazy_handle is None: |
| self._lazy_handle = torch._C._dispatch_find_schema_or_throw( |
| self._schema.name, self._schema.overload_name |
| ) |
| return self._lazy_handle |
| |
    # It's a no-op since the OpOverload object is immutable and must be unique for a given op overload.
| def __deepcopy__(self, memo=None): |
| return self |
| |
| def __repr__(self): |
| return "<OpOverload(op='{}.{}', overload='{}')>".format( |
| *self._schema.name.split("::"), self._overloadname |
| ) |
| |
| # Use positional-only argument to avoid naming collision with aten ops arguments |
| # that are named "self". This way, all the aten ops can be called by kwargs. |
| def __call__(self, /, *args, **kwargs): |
| return self._op(*args, **kwargs) |
| |
| # Use positional-only argument to avoid naming collision with aten ops arguments |
| # that are named "self". This way, all the aten ops can be called by kwargs. |
| def redispatch(self, /, keyset, *args, **kwargs): |
| return self._handle.redispatch_boxed(keyset, *args, **kwargs) |
| |
| def __hash__(self): |
| return hash(self._op) |
| |
| # `my_namespace.my_op_name.overload_name` |
| def __str__(self): |
| return "{}.{}.{}".format(*self._schema.name.split("::"), self._overloadname) |
| |
| def has_kernel_for_dispatch_key(self, k): |
| return super().has_kernel_for_dispatch_key( |
| k |
| ) or torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), k) |
| |
| def has_kernel_for_any_dispatch_key(self, ks): |
| return torch._C._dispatch_has_kernel_for_any_dispatch_key( |
| self.name(), ks |
| ) or super().has_kernel_for_any_dispatch_key(ks) |
| |
| @property |
| def namespace(self): |
| return self._schema.name.split("::")[0] |
| |
| def _can_decompose(self): |
| dk = DispatchKey.CompositeImplicitAutograd |
| return dk in self.py_kernels or torch._C._dispatch_has_kernel_for_dispatch_key( |
| self.name(), dk |
| ) |
| |
| def decompose(self, *args, **kwargs): |
| dk = DispatchKey.CompositeImplicitAutograd |
| if dk in self.py_kernels: |
| # NB: This branch is not too necessary anymore, because we can |
| # apply Python CompositeImplicitAutograd *before* tracing |
| # using Python dispatcher (also taking advantage of the autograd |
| # formula). But it's included for completeness |
| return self.py_kernels[dk](*args, **kwargs) |
| elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk): |
| return self._op_dk(dk, *args, **kwargs) |
| else: |
| return NotImplemented |
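
    # An illustrative sketch of decompose (assumes aten.matmul keeps its
    # CompositeImplicitAutograd registration):
    #
    #     a, b = torch.randn(2, 3), torch.randn(3, 4)
    #     out = torch.ops.aten.matmul.default.decompose(a, b)
    #     # decompose returns NotImplemented for overloads that have no
    #     # CompositeImplicitAutograd kernel.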
| |
    # Remove a dispatch key from the dispatch cache. This will force it to get
    # recomputed the next time. Does nothing if the key is not cached.
    # WARNING: if you register a dispatch key to py_kernels of an OpOverload,
    # calling _uncache_dispatch on that key is NOT sufficient to apply your change,
    # because a single registration may affect MULTIPLE dispatch keys (e.g.,
    # registering Autograd affects AutogradCPU). _uncache_dispatch is to be used
    # only if you are specifically modifying how _get_dispatch handles a
    # particular input 'key'.
| def _uncache_dispatch(self, key): |
| self._dispatch_cache.pop(key, None) |
| |
| # This implements the pre-computation logic for the Python dispatcher. |
| def _get_dispatch(self, key): |
| # This is only called upon a cache miss |
| assert key not in self._dispatch_cache, f"{self} {key}" |
| |
| if key == DispatchKey.Python: |
| if not isinstance(self, TorchBindOpOverload) and not self.python_key_table: |
| self._dispatch_cache[key] = key |
| add_cached_op(self) |
| return key |
| |
| def handler(*args, **kwargs): |
| from torch.utils._python_dispatch import _get_current_dispatch_mode |
| |
| # TODO: We also need to handle tensor subclasses here |
| # TODO(voz): We should walk all the nodes here / turn it into a list, topmode is ok for now. |
| curr_mode = type(_get_current_dispatch_mode()) |
| assert ( |
| curr_mode is not None |
| ), "Illegal invocation of dispatch on torch._C.DispatchKey.Python without a mode." |
| |
| if curr_mode not in self.python_key_table: |
| if isinstance(self, TorchBindOpOverload): |
| with torch.utils._python_dispatch._pop_mode_temporarily() as mode: |
| return torch._library.utils.handle_dispatch_mode( |
| mode, self, *args, **kwargs |
| ) |
| else: |
| return self._op_dk(key, *args, **kwargs) |
| |
| with torch.utils._python_dispatch._pop_mode_temporarily() as mode: |
| return self.python_key_table[curr_mode](mode, *args, **kwargs) |
| |
| self._dispatch_cache[key] = handler |
| add_cached_op(self) |
| return handler |
| |
| functionality_key = torch._C._to_functionality_key(key) # type: ignore[attr-defined] |
| if functionality_key == DispatchKey.PreDispatch: |
| curr_stack_len = _len_torch_dispatch_stack_pre_dispatch() |
| # The check for Python in the exclude set is so we properly respect `with no_dispatch()` |
| # calls inside of a mode. |
| if ( |
| curr_stack_len > 0 |
| and not torch._C._dispatch_tls_is_dispatch_key_excluded( |
| DispatchKey.Python |
| ) |
| ): |
| |
| def handler(*args, **kwargs): |
| @contextlib.contextmanager |
| def _temporarily_pop_modes_from_pre_dispatch(): |
| top_mode = _pop_mode_from_pre_dispatch() |
| try: |
| yield top_mode |
| finally: |
| _set_mode_pre_dispatch(top_mode) |
| |
| with _temporarily_pop_modes_from_pre_dispatch() as curr_mode: |
| return torch._library.utils.handle_dispatch_mode( |
| curr_mode, self, *args, **kwargs |
| ) |
| |
| # Note [Not Caching Per-Dispatch-Key Mode Handlers] |
| # Note that we're not caching this handler. There isn't really a point, since the slow bit |
| # is the handler itself (in python). |
                # Also, not caching means that we don't have to reset the cache when any existing
                # modes go out of scope (which in and of itself takes time to loop through all operators).
| return handler |
| |
| final_key = resolve_key(self, key) |
| |
| # See Note [Not Caching Per-Dispatch-Key Mode Handlers] |
| cache_result = key != DispatchKey.PreDispatch |
| |
| # TODO: We could potentially have lots of debugging wrappers against |
| # dispatch keys; design some general registration mechanism instead of |
| # having if statement for each of them |
| if key == DispatchKey.Functionalize: |
| import torch._dispatch.python as pydispatch |
| |
| if pydispatch.CROSSREF_FUNCTIONALIZE: |
| handler = pydispatch.make_crossref_functionalize(self, final_key) |
| if cache_result: |
| self._dispatch_cache[key] = handler |
| add_cached_op(self) |
| return handler |
| |
| r = self.py_kernels.get(final_key, final_key) |
| if cache_result: |
| self._dispatch_cache[key] = r |
| add_cached_op(self) |
| return r |
| |
| def name(self): |
| return self._name |
| |
| @property |
| def overloadpacket(self): |
| return self._overloadpacket |
| |
| @property |
| def op(self): |
| return self._op |
| |
| @property |
| def tags(self): |
| return self._tags |
| |
| # TODO: add more methods to expose information about input and output arguments |
| |
| |
# TorchBindOpOverload is used for those custom ops which have at least one overload
# whose schema contains a torch.ScriptObject (i.e. custom class) input.
# A TorchBindOpOverload will skip the C++ dispatcher and be dispatched purely in Python
# when its inputs contain a FakeScriptObject, in a similar way to higher-order ops.
| class TorchBindOpOverload(OpOverload): |
| def _fallthrough_keys(self) -> List[DispatchKey]: |
| # TODO: we should be calling the fallback for these, but a fallthrough is almost close |
| # enough to the fallback in most cases that we care about. |
| _DEFAULT_FALLTHROUGH_KEYS = [ |
| DispatchKey.Autograd, |
| DispatchKey.AutogradCPU, |
| DispatchKey.AutogradCUDA, |
| DispatchKey.ADInplaceOrView, |
| DispatchKey.BackendSelect, |
| DispatchKey.PythonTLSSnapshot, |
| DispatchKey.PythonDispatcher, |
| ] |
| |
| def _may_use_fallthrough_instead_of_fallback(key: DispatchKey): |
| if torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), key): |
| return torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough( |
| self.name(), key |
| ) |
| |
| return ( |
| key not in self.py_kernels |
| or self.py_kernels[key] is torch.library.fallthrough_kernel |
| ) |
| |
| return [ |
| key |
| for key in _DEFAULT_FALLTHROUGH_KEYS |
| if _may_use_fallthrough_instead_of_fallback(key) |
| ] |
| |
| @contextlib.contextmanager |
| def _register_as_effectful_op_temporarily(self): |
| from torch._higher_order_ops.effects import ( |
| _EffectType, |
| _register_effectful_op, |
| SIDE_EFFECTS, |
| ) |
| |
| try: |
| if self not in SIDE_EFFECTS: |
| _register_effectful_op(self, _EffectType.ORDERED) |
| yield |
| finally: |
| if self in SIDE_EFFECTS: |
| del SIDE_EFFECTS[self] |
| |
| # Use positional-only argument to avoid naming collision with aten ops arguments |
| # that are named "self". This way, all the aten ops can be called by kwargs. |
| def __call__(self, /, *args, **kwargs): |
| if _must_dispatch_in_python(args, kwargs): |
            # When any inputs are FakeScriptObject, we need to
            # skip the C++ dispatcher and dispatch in Python through _get_dispatch of the
            # Python dispatcher, because the C++ dispatcher will check the schema and cannot
            # recognize FakeScriptObject.
            #
            # Note:
            # 1. We only register the torchbind op temporarily as an effectful op because we only want
            #    the effect token functionalization logic to be applied during tracing. Otherwise, the behavior
            #    of eagerly executing the op might change after tracing.
            # 2. We don't want to register all torchbind ops as effectful in the constructor because this might
            #    cause unexpected behavior for some autograd.profiler ops, e.g. profiler._record_function_exit._RecordFunction.
| with self._register_as_effectful_op_temporarily(): |
| return self._dispatch_in_python(args, kwargs, self._fallthrough_keys()) |
| return self._op(*args, **kwargs) |
| |
| def _dispatch_in_python(self, args, kwargs, fallthrough_keys): |
| non_fallthrough_keys = torch._C._dispatch_keyset_full() |
| for key in fallthrough_keys: |
| non_fallthrough_keys = non_fallthrough_keys.remove(key) |
| |
| dispatch_key_set = _compute_keyset(args, kwargs, non_fallthrough_keys) |
| dispatch_key = dispatch_key_set.highestPriorityTypeId() |
| |
| handler = ( |
| self._get_dispatch(dispatch_key) |
| if dispatch_key not in self._dispatch_cache |
| else self._dispatch_cache[dispatch_key] |
| ) |
| |
| if isinstance(handler, DispatchKey): |
            # fallthrough keys can be registered at runtime via torch.library.impl,
            # so we need to add this key to fallthrough_keys and re-dispatch.
| if torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough( |
| self.name(), dispatch_key |
| ): |
| return self._dispatch_in_python( |
| args, kwargs, fallthrough_keys + [dispatch_key] |
| ) |
| |
            raise RuntimeError(
                f"Torchbind op {self} received a FakeScriptObject input when dispatching {handler},"
                f" but no Python implementation was found."
                f" Please file an issue if you encounter this error."
                f" This error can happen when you export or compile the model."
                f" It can happen even if a C++ implementation for {dispatch_key}"
                f" has been registered. That's because FakeScriptObject lives purely in Python and cannot"
                f" work with a C++ implementation."
            )
| |
| assert isinstance(handler, Callable) # type: ignore[arg-type] |
| return handler(*args, **kwargs) |
| |
| |
| def _must_dispatch_in_python(args, kwargs): |
| return pytree.tree_any( |
| lambda obj: isinstance( |
| obj, torch._library.fake_class_registry.FakeScriptObject |
| ), |
| (args, kwargs), |
| ) |
| |
| |
| def _has_script_object_arg(schema: torch.FunctionSchema) -> bool: |
| return any(isinstance(arg.type, torch.ClassType) for arg in schema.arguments) |
| |
| |
# The OpOverloadPacket class contains a pointer to a base unresolved operator that doesn't correspond to a specific operator overload.
# You can obtain an OpOverload object through attribute query.
| class OpOverloadPacket: |
| def __init__(self, qualified_op_name, op_name, op, overload_names): |
| # These attributes are accessible on the object through the properties |
| # defined below but are immutable |
| self._qualified_op_name = qualified_op_name |
| self.__name__ = op_name |
| self._op = op |
| self._overload_names = overload_names |
| self._dir = [] |
| self._has_torchbind_op_overload = any( |
| _has_script_object_arg(schema) for schema in self._schemas.values() |
| ) |
| |
    # It's a no-op since the OpOverloadPacket object is immutable and must be unique for a given op.
| def __deepcopy__(self, memo=None): |
| return self |
| |
| def __repr__(self): |
| return "<OpOverloadPacket(op='{}.{}')>".format( |
| *self._qualified_op_name.split("::") |
| ) |
| |
| def __hash__(self): |
| return hash(self._op) |
| |
| def __str__(self): |
| return "{}.{}".format(*self._qualified_op_name.split("::")) |
| |
| @property |
| def op(self): |
| return self._op |
| |
| @property |
| def _schemas(self): |
| return { |
| overload_name: torch._C._get_schema(self._qualified_op_name, overload_name) |
| for overload_name in self._overload_names |
| } |
| |
| def __getattr__(self, key): |
| # It is not a valid op_name when __file__ is passed in |
| if key == "__file__": |
| return "torch.ops" |
| |
        # Ensure that queries for dunder attributes that do not exist on the
        # OpOverloadPacket, but instead exist on the self._op object, do not unnecessarily call
        # `_get_operation_overload` (which is an expensive operation).
        # This is done to prevent any potential slowdown. This can be extended
        # to other attributes, like `__name__`, that only exist on self._op and not on the
        # OpOverloadPacket.
        # This is ok since we are guaranteed that an overload name for an aten op can't start with '__'.
| try: |
| if key.startswith("__"): |
| return getattr(self._op, key) |
| except AttributeError: |
| # for consistency because it seems weird to |
| # throw an attribute error with a message containing |
| # an object name different from the one the attribute |
| # query was performed on. |
| raise AttributeError( |
| f"'{str(self)}' can't have an overload name beginning with '__' and the " |
| f"underlying op {str(self._op)} has no attribute {key} either." |
| ) from None |
| |
| try: |
| # This is ok since we are guaranteed that an overload name for an aten op can't be 'default' |
| use_key = "" if key == "default" else key |
| # TODO: disallow access to overloads registered by JIT |
| op_dk_tags = torch._C._get_operation_overload( |
| self._qualified_op_name, use_key |
| ) |
| if op_dk_tags is None: |
| raise AttributeError( |
| f"The underlying op of '{str(self)}' has no overload name '{key}'" |
| ) |
| |
| op_, op_dk_, tags = op_dk_tags |
| schema = torch._C._get_schema(self._qualified_op_name, use_key) |
| overload = ( |
| OpOverload(self, op_, op_dk_, schema, tags) |
| if not _has_script_object_arg(schema) |
| else TorchBindOpOverload(self, op_, op_dk_, schema, tags) |
| ) |
| # cache the overload object |
| setattr(self, key, overload) |
| self._dir.append(key) |
| return overload |
| except RuntimeError: |
| raise AttributeError( |
| f"The underlying op of '{str(self)}' has no overload name '{key}'" |
| ) from None |
| |
| def __iter__(self): |
| return iter(self._dir) |
| |
| # Use positional-only argument to avoid naming collision with aten ops arguments |
| # that are named "self". This way, all the aten ops can be called by kwargs. |
| def __call__(self, /, *args, **kwargs): |
| # overloading __call__ to ensure torch.ops.foo.bar() |
| # is still callable from JIT |
| # We save the function ptr as the `op` attribute on |
| # OpOverloadPacket to access it here. |
| |
        # Directly calling an OpOverloadPacket goes into C++, which will check
        # the schema and cause an error for torchbind ops when inputs contain a FakeScriptObject, so we
        # intercept it here and call the TorchBindOpOverload instead.
| if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs): |
| return _call_overload_packet_from_python(self, args, kwargs) |
| return self._op(*args, **(kwargs or {})) |
| |
| # TODO: use this to make a __dir__ |
| def overloads(self): |
| return [n if n else "default" for n in self._overload_names] |
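

# An illustrative sketch of how overloads are resolved from a packet (the listed
# overload names are examples, not an exhaustive guarantee):
#
#     packet = torch.ops.aten.add            # OpOverloadPacket
#     overload = torch.ops.aten.add.Tensor   # OpOverload for the "Tensor" overload
#     packet.overloads()                     # e.g. ["Tensor", "Scalar", "out", ...]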
| |
| |
| # Note - this mirrors the logic of the cpp_function defined in jit/python/init.cpp |
| # _jit_get_operations, which calls _get_operation_for_overload_or_packet. |
| def _call_overload_packet_from_python(op: OpOverloadPacket, args, kwargs): |
| # Re-use the torch function handling logic in cpp |
| torch_function_called, ret = torch._C._maybe_call_torch_function_for_op_packet( |
| op, *args, **kwargs |
| ) |
| |
| if torch_function_called: |
| return ret |
| |
    # The following mirrors getOpWithStack.
    # In C++, we do schema matching on the arguments and call toIValue to
    # check whether the arguments are valid. We need to do something similar here:
    # check against the schema whether each FakeScriptObject is the corresponding fake class
    # of the actual class used in the schema.
| exceptions = {} |
| found_op = None |
| for overload_name in op.overloads(): |
| op_overload = getattr(op, overload_name) |
| try: |
| _ = torch._C._check_schema_allow_fake_script_object( |
| op_overload._schema, *args, **kwargs |
| ) |
| found_op = op_overload |
| break |
| except RuntimeError as e: |
| exceptions[overload_name] = e |
| |
| if found_op: |
| return found_op(*args, **kwargs) |
| |
    err_msg = (
        f"Failed to match any TorchBindOpOverload of {op} with the following exceptions:\n"
    )
    for key, msg in exceptions.items():
        err_msg += f"Overload name {key}:\n {msg}\n"
| raise RuntimeError(err_msg) |
| |
| |
# Resolution of torch.fn is different from torch.ops.aten.fn.
# torch.fn uses the Python arg parser, matches with the
# appropriate schema, and calls into the unboxed version of the method.
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT:
# JIT creates a stack of all the overloads and then tries to match the
# correct one at runtime, and always calls into the boxed version of the method.
# Autograd codegen creates VariableType, TracerType,
# inplace-or-view type, and Python bindings.
# ATen codegen generates tensor methods for the tensor class.
| |
# _OpNamespace is a subclass of ModuleType because TorchScript
# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
# to work from script, we need to ensure ops and foo are modules.
| |
| |
| class _OpNamespace(types.ModuleType): |
| """ |
| An op namespace to dynamically bind Operators into Python. |
| |
| Say a user has created a custom Operator called "my_namespace::my_op". To |
| call this op, the user will write torch.ops.my_namespace.my_op(...). |
| At startup, this operation will not yet be bound into Python. Instead, the |
| following sequence of magic tricks will occur: |
| 1. `torch.ops.my_namespace` will invoke the `__getattr__` magic method |
| on the `torch.ops` object, which will create a new `_OpNamespace` |
| object called `my_namespace` and set it as an attribute on the `ops` |
| object. |
    2. `torch.ops.my_namespace.my_op` will then invoke `__getattr__` on
       the `my_namespace` object, which will retrieve the operation via
       `torch._C._jit_get_operation`, a function bound from C++, and then in a
       similar fashion bind this new object onto the `my_namespace` object.
| 3. `torch.ops.my_namespace.my_op(...)` then calls this new operation |
| and subsequent accesses will incur no further lookup (the namespace and |
| operation will already exist). |
| """ |
| |
| def __init__(self, name): |
| super().__init__("torch.ops." + name) |
| self.name = name |
| self._dir = [] |
| |
| def __iter__(self): |
| return iter(self._dir) |
| |
| def __getattr__(self, op_name): |
| # It is not a valid op_name when __file__ is passed in |
| if op_name == "__file__": |
| return "torch.ops" |
| elif op_name in ["__origin__", "__self__"]: |
| raise AttributeError( |
| f"Invalid attribute '{op_name}' for '_OpNamespace' '{self.name}'" |
| ) |
| |
| # Get the op `my_namespace::my_op` if available. This will also check |
| # for overloads and raise an exception if there are more than one. |
| namespace_name = self.name |
| qualified_op_name = f"{namespace_name}::{op_name}" |
| module_name = self.__module__ + "." + namespace_name |
| |
| try: |
| op, overload_names = _get_packet(qualified_op_name, module_name) |
| if op is None: |
| raise AttributeError( |
| f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" |
| ) |
| except RuntimeError as e: |
| # Turn this into AttributeError so getattr(obj, key, default) |
| # works (this is called by TorchScript with __origin__) |
| raise AttributeError( |
| f"'_OpNamespace' '{self.name}' object has no attribute '{op_name}'" |
| ) from e |
| |
| op.__module__ = module_name |
| opoverloadpacket = OpOverloadPacket( |
| qualified_op_name, op_name, op, overload_names |
| ) |
| opoverloadpacket.__module__ = self.__module__ + "." + namespace_name |
| # cache the opoverloadpacket to ensure that each op corresponds to |
| # a unique OpOverloadPacket object |
| setattr(self, op_name, opoverloadpacket) |
| self._dir.append(op_name) |
| return opoverloadpacket |
| |
| |
| def _get_packet(qualname, op_module): |
| op, overload_names = torch._C._jit_get_operation(qualname) |
| if op is not None: |
| # let the script frontend know that op is identical to the builtin op |
| # with qualified_op_name |
| torch.jit._builtins._register_builtin(op, qualname) |
| op.__module__ = op_module |
| return op, overload_names |
| |
| |
| def _refresh_packet(packet): |
| op, overload_names = _get_packet(packet._qualified_op_name, packet._op.__module__) |
| assert op is not None |
| packet._op = op |
| packet._overload_names = overload_names |
| |
| |
| class _PyOpNamespace(_OpNamespace): |
| def __init__(self, name, ops): |
| super().__init__(name) |
| self._ops = ops |
| |
| def __getattr__(self, name): |
| # Following _OpNamespace.__getattr__, we cache the op on the _PyOpNamespace object. |
| op = self._ops.get(name, None) |
| if op is None: |
| raise AttributeError( |
| f"'_PyOpNamespace' '{self.name}' object has no attribute '{name}'" |
| ) |
| setattr(self, name, op) |
| return op |
| |
| |
| class _Ops(types.ModuleType): |
| __file__ = "_ops.py" |
| |
| def __init__(self): |
| super().__init__("torch.ops") |
| self.loaded_libraries = set() |
| self._higher_order_op_namespace = _PyOpNamespace( |
| "torch.ops.higher_order", _higher_order_ops |
| ) |
| self._dir = [] |
| |
| def __getattr__(self, name): |
| # Check if the name is a HigherOrderOperator |
| if name == "higher_order": |
| return self._higher_order_op_namespace |
| |
| # Here we are creating `torch.ops.my_namespace` |
| namespace = _OpNamespace(name) |
| setattr(self, name, namespace) |
| self._dir.append(name) |
| return namespace |
| |
| def __iter__(self): |
| return iter(self._dir) |
| |
| def import_module(self, module): |
| """ |
| Imports a Python module that has torch.library registrations. |
| |
| Generally, to extend PyTorch with custom operators, a user will |
| create a Python module whose import triggers registration of |
| the custom operators via a torch.ops.load_library call or a call |
| to one or more torch.library.* APIs. |
| |
| It is unexpected for Python modules to have side effects, so some |
| linters and formatters will complain. Use this API to import Python |
| modules that contain these torch.library side effects. |
| |
| Args: |
| module (str): The name of the Python module to import |
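
        Example (the module name here is hypothetical)::

            torch.ops.import_module("my_project.custom_ops")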
| |
| """ |
| importlib.import_module(module) |
| |
| def load_library(self, path): |
| """ |
| Loads a shared library from the given path into the current process. |
| |
| The library being loaded may run global initialization code to register |
| custom operators with the PyTorch JIT runtime. This allows dynamically |
| loading custom operators. For this, you should compile your operator |
| and the static registration code into a shared library object, and then |
| call ``torch.ops.load_library('path/to/libcustom.so')`` to load the |
| shared object. |
| |
| After the library is loaded, it is added to the |
| ``torch.ops.loaded_libraries`` attribute, a set that may be inspected |
| for the paths of all libraries loaded using this function. |
| |
| Args: |
| path (str): A path to a shared library to load. |
| """ |
| if torch._running_with_deploy(): |
| return |
| |
| path = _utils_internal.resolve_library_path(path) |
| with dl_open_guard(): |
| # Import the shared library into the process, thus running its |
| # static (global) initialization code in order to register custom |
| # operators with the JIT. |
| ctypes.CDLL(path) |
| self.loaded_libraries.add(path) |
| |
| |
| # The ops "namespace" |
| ops: _Ops = _Ops() |