| """ |
| The torch package contains data structures for multi-dimensional |
| tensors and defines mathematical operations over these tensors. |
| Additionally, it provides utilities for efficient serialization of |
| Tensors and arbitrary types, among other useful tools. |
| |
| It has a CUDA counterpart that enables you to run your tensor computations |
| on an NVIDIA GPU with compute capability >= 3.0. |
| """ |
| |
| # mypy: allow-untyped-defs |
| |
| import builtins |
| import ctypes |
| import glob |
| import importlib |
| import inspect |
| import math |
| import os |
| import platform |
| import sys |
| import textwrap |
| import threading |
| from typing import ( |
| Any as _Any, |
| Callable as _Callable, |
| Dict as _Dict, |
| Optional as _Optional, |
| overload as _overload, |
| Set as _Set, |
| Tuple as _Tuple, |
| Type as _Type, |
| TYPE_CHECKING, |
| TypeVar as _TypeVar, |
| Union as _Union, |
| ) |
| from typing_extensions import ParamSpec as _ParamSpec, TypeGuard as _TypeGuard |
| |
| |
| if TYPE_CHECKING: |
| from .types import IntLikeType |
| |
| |
| # multipy/deploy sets this import before importing torch; this is the most |
| # reliable way we have to detect if we're running within deploy. |
| # https://github.com/pytorch/multipy/blob/d60f34ad38c371e441fe7ffdb77a3c3dda5a5d19/multipy/runtime/interpreter/interpreter_impl.cpp#L134-L137 |
| def _running_with_deploy() -> builtins.bool: |
| return sys.modules.get("torch._meta_registrations", None) is object |
| |
| |
| from torch._utils import ( |
| _functionalize_sync as _sync, |
| _import_dotted_name, |
| classproperty, |
| ) |
| from torch._utils_internal import ( |
| get_file_path, |
| prepare_multiprocessing_environment, |
| USE_GLOBAL_DEPS, |
| USE_RTLD_GLOBAL_WITH_LIBTORCH, |
| ) |
| |
| |
| # TODO(torch_deploy) figure out how to freeze version.py in fbcode build |
| if _running_with_deploy(): |
| __version__ = "torch-deploy-1.8" |
| else: |
| from torch.torch_version import __version__ as __version__ |
| |
| __all__ = [ |
| "BoolStorage", |
| "BoolTensor", |
| "ByteStorage", |
| "ByteTensor", |
| "CharStorage", |
| "CharTensor", |
| "DoubleStorage", |
| "DoubleTensor", |
| "FloatStorage", |
| "FloatTensor", |
| "GradScaler", |
| "IntStorage", |
| "IntTensor", |
| "LongStorage", |
| "LongTensor", |
| "ShortStorage", |
| "ShortTensor", |
| "SymBool", |
| "SymFloat", |
| "SymInt", |
| "Tensor", |
| "TypedStorage", |
| "UntypedStorage", |
| "are_deterministic_algorithms_enabled", |
| "autocast", |
| "chunk", |
| "compile", |
| "cond", |
| "enable_grad", |
| "export", |
| "get_default_device", |
| "get_deterministic_debug_mode", |
| "get_device_module", |
| "get_float32_matmul_precision", |
| "get_rng_state", |
| "inference_mode", |
| "initial_seed", |
| "is_deterministic_algorithms_warn_only_enabled", |
| "is_storage", |
| "is_tensor", |
| "is_warn_always_enabled", |
| "load", |
| "lobpcg", |
| "manual_seed", |
| "matmul", |
| "no_grad", |
| "rand", |
| "randn", |
| "save", |
| "seed", |
| "set_default_device", |
| "set_default_tensor_type", |
| "set_deterministic_debug_mode", |
| "set_float32_matmul_precision", |
| "set_printoptions", |
| "set_rng_state", |
| "set_warn_always", |
| "split", |
| "stack", |
| "sym_float", |
| "sym_int", |
| "sym_ite", |
| "sym_max", |
| "sym_min", |
| "sym_not", |
| "typename", |
| "unravel_index", |
| "use_deterministic_algorithms", |
| "vmap", |
| ] |
| |
| # Please keep this list sorted |
| assert __all__ == sorted(__all__) |
| |
| ################################################################################ |
| # Load the extension module |
| ################################################################################ |
| |
| if sys.platform == "win32": |
| |
| def _load_dll_libraries() -> None: |
| import sysconfig |
| |
| from torch.version import cuda as cuda_version |
| |
| pfiles_path = os.getenv("ProgramFiles", r"C:\Program Files") |
| py_dll_path = os.path.join(sys.exec_prefix, "Library", "bin") |
| th_dll_path = os.path.join(os.path.dirname(__file__), "lib") |
| usebase_path = os.path.join( |
| sysconfig.get_config_var("userbase"), "Library", "bin" |
| ) |
| |
| # When users create a virtualenv that inherits the base environment, |
| # we will need to add the corresponding library directory into |
| # DLL search directories. Otherwise, it will rely on `PATH` which |
| # is dependent on user settings. |
| if sys.exec_prefix != sys.base_exec_prefix: |
| base_py_dll_path = os.path.join(sys.base_exec_prefix, "Library", "bin") |
| else: |
| base_py_dll_path = "" |
| |
| dll_paths = [ |
| p |
| for p in (th_dll_path, py_dll_path, base_py_dll_path, usebase_path) |
| if os.path.exists(p) |
| ] |
| |
| if not builtins.any( |
| os.path.exists(os.path.join(p, "nvToolsExt64_1.dll")) for p in dll_paths |
| ): |
| nvtoolsext_dll_path = os.path.join( |
| os.getenv( |
| "NVTOOLSEXT_PATH", |
| os.path.join(pfiles_path, "NVIDIA Corporation", "NvToolsExt"), |
| ), |
| "bin", |
| "x64", |
| ) |
| else: |
| nvtoolsext_dll_path = "" |
| |
| if cuda_version and builtins.all( |
| not glob.glob(os.path.join(p, "cudart64*.dll")) for p in dll_paths |
| ): |
| cuda_version_1 = cuda_version.replace(".", "_") |
| cuda_path_var = "CUDA_PATH_V" + cuda_version_1 |
| default_path = os.path.join( |
| pfiles_path, "NVIDIA GPU Computing Toolkit", "CUDA", f"v{cuda_version}" |
| ) |
| cuda_path = os.path.join(os.getenv(cuda_path_var, default_path), "bin") |
| else: |
| cuda_path = "" |
| |
| dll_paths.extend( |
| p for p in (nvtoolsext_dll_path, cuda_path) if os.path.exists(p) |
| ) |
| |
| kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) |
| with_load_library_flags = hasattr(kernel32, "AddDllDirectory") |
| prev_error_mode = kernel32.SetErrorMode(0x0001) |
| |
| kernel32.LoadLibraryW.restype = ctypes.c_void_p |
| if with_load_library_flags: |
| kernel32.LoadLibraryExW.restype = ctypes.c_void_p |
| |
| for dll_path in dll_paths: |
| os.add_dll_directory(dll_path) |
| |
| try: |
| ctypes.CDLL("vcruntime140.dll") |
| ctypes.CDLL("msvcp140.dll") |
| ctypes.CDLL("vcruntime140_1.dll") |
| except OSError: |
| print( |
| textwrap.dedent( |
| """ |
| Microsoft Visual C++ Redistributable is not installed; this may lead to DLL load failures. |
| It can be downloaded at https://aka.ms/vs/16/release/vc_redist.x64.exe |
| """ |
| ).strip() |
| ) |
| |
| dlls = glob.glob(os.path.join(th_dll_path, "*.dll")) |
| path_patched = False |
| for dll in dlls: |
| is_loaded = False |
| if with_load_library_flags: |
| res = kernel32.LoadLibraryExW(dll, None, 0x00001100) |
| last_error = ctypes.get_last_error() |
| if res is None and last_error != 126: |
| err = ctypes.WinError(last_error) |
| err.strerror += ( |
| f' Error loading "{dll}" or one of its dependencies.' |
| ) |
| raise err |
| elif res is not None: |
| is_loaded = True |
| if not is_loaded: |
| if not path_patched: |
| os.environ["PATH"] = ";".join(dll_paths + [os.environ["PATH"]]) |
| path_patched = True |
| res = kernel32.LoadLibraryW(dll) |
| if res is None: |
| err = ctypes.WinError(ctypes.get_last_error()) |
| err.strerror += ( |
| f' Error loading "{dll}" or one of its dependencies.' |
| ) |
| raise err |
| |
| kernel32.SetErrorMode(prev_error_mode) |
| |
| _load_dll_libraries() |
| del _load_dll_libraries |
| |
| |
| def _preload_cuda_deps(lib_folder: str, lib_name: str) -> None: |
| """Preloads cuda deps if they could not be found otherwise.""" |
| # Should only be called on Linux if default path resolution have failed |
| assert platform.system() == "Linux", "Should only be called on Linux" |
| |
| lib_path = None |
| for path in sys.path: |
| nvidia_path = os.path.join(path, "nvidia") |
| if not os.path.exists(nvidia_path): |
| continue |
| candidate_lib_paths = glob.glob( |
| os.path.join(nvidia_path, lib_folder, "lib", lib_name) |
| ) |
| if candidate_lib_paths and not lib_path: |
| lib_path = candidate_lib_paths[0] |
| if lib_path: |
| break |
| if not lib_path: |
| raise ValueError(f"{lib_name} not found in the system path {sys.path}") |
| ctypes.CDLL(lib_path) |
| |
| |
| # See Note [Global dependencies] |
| def _load_global_deps() -> None: |
| if _running_with_deploy() or platform.system() == "Windows": |
| return |
| |
| # Determine the file extension based on the platform |
| lib_ext = ".dylib" if platform.system() == "Darwin" else ".so" |
| lib_name = f"libtorch_global_deps{lib_ext}" |
| here = os.path.abspath(__file__) |
| global_deps_lib_path = os.path.join(os.path.dirname(here), "lib", lib_name) |
| |
| try: |
| ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) |
| except OSError as err: |
| # Can only happen for wheels with CUDA libs as PyPI deps, |
| # as PyTorch is not purelib but nvidia-*-cu12 is |
| cuda_libs: _Dict[str, str] = { |
| "cublas": "libcublas.so.*[0-9]", |
| "cudnn": "libcudnn.so.*[0-9]", |
| "cuda_nvrtc": "libnvrtc.so.*[0-9]", |
| "cuda_runtime": "libcudart.so.*[0-9]", |
| "cuda_cupti": "libcupti.so.*[0-9]", |
| "cufft": "libcufft.so.*[0-9]", |
| "curand": "libcurand.so.*[0-9]", |
| "nvjitlink": "libnvJitLink.so.*[0-9]", |
| "cusparse": "libcusparse.so.*[0-9]", |
| "cusolver": "libcusolver.so.*[0-9]", |
| "nccl": "libnccl.so.*[0-9]", |
| "nvtx": "libnvToolsExt.so.*[0-9]", |
| } |
| is_cuda_lib_err = [ |
| lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0] |
| ] |
| if not is_cuda_lib_err: |
| raise err |
| for lib_folder, lib_name in cuda_libs.items(): |
| _preload_cuda_deps(lib_folder, lib_name) |
| ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) |
| |
| |
| if (USE_RTLD_GLOBAL_WITH_LIBTORCH or os.getenv("TORCH_USE_RTLD_GLOBAL")) and ( |
| _running_with_deploy() or platform.system() != "Windows" |
| ): |
| # Do it the hard way. You might want to load libtorch with RTLD_GLOBAL in a |
| # few circumstances: |
| # |
| # 1. You're in a build environment (e.g., fbcode) where |
| # libtorch_global_deps is not available, but you still need |
| # to get mkl to link in with RTLD_GLOBAL or it will just |
| # not work. |
| # |
| # 2. You're trying to run PyTorch under UBSAN and you need |
| # to ensure that only one copy of libtorch is loaded, so |
| # vptr checks work properly |
| # |
| # If you're using this setting, you must verify that all the libraries |
| # you load consistently use the same libstdc++, or you may have |
| # mysterious segfaults. |
| # |
| old_flags = sys.getdlopenflags() |
| sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY) |
| |
| from torch._C import * # noqa: F403 |
| |
| sys.setdlopenflags(old_flags) |
| del old_flags |
| |
| else: |
| # Easy way. You want this most of the time, because it will prevent |
| # C++ symbols from libtorch clobbering C++ symbols from other |
| # libraries, leading to mysterious segfaults. |
| # |
| # If building in an environment where libtorch_global_deps isn't available |
| # like parts of fbsource, but where RTLD_GLOBAL causes segfaults, you will |
| # want USE_RTLD_GLOBAL_WITH_LIBTORCH = False and USE_GLOBAL_DEPS = False |
| # |
| # See Note [Global dependencies] |
| if USE_GLOBAL_DEPS: |
| _load_global_deps() |
| from torch._C import * # noqa: F403 |
| |
| |
| class SymInt: |
| """ |
| Like an int (including magic methods), but redirects all operations to the |
| wrapped node. This is used in particular to symbolically record operations |
| in the symbolic shape workflow. |
| """ |
| |
| def __init__(self, node): |
| # This field MUST be named node; C++ binding code assumes that this |
| # class has a field named node that stores SymNode |
| self.node = node |
| |
| def __bool__(self): |
| return builtins.bool(self != 0) |
| |
| def __int__(self): |
| return self.node.int_() |
| |
| def __index__(self): |
| return self.node.int_() |
| |
| # Magic methods installed by torch.fx.experimental.sym_node |
| |
| def __round__(self, ndigits=None): |
| return self |
| |
| def __truediv__(self, other): |
| if isinstance(other, (builtins.float, SymFloat)): |
| return sym_float(self).__float_truediv__(other) |
| if not isinstance(other, (builtins.int, SymInt)): |
| return NotImplemented |
| return self.__int_truediv__(other) |
| |
| def __rtruediv__(self, other): |
| if isinstance(other, (builtins.float, SymFloat)): |
| return sym_float(self).__rfloat_truediv__(other) |
| if not isinstance(other, (builtins.int, SymInt)): |
| return NotImplemented |
| return self.__rint_truediv__(other) |
| |
| def __floordiv__(self, other): |
| if isinstance(other, (builtins.float, SymFloat)): |
| return sym_float(math.floor(sym_float(self) / other)) |
| if not isinstance(other, (builtins.int, SymInt)): |
| return NotImplemented |
| return self.__int_floordiv__(other) |
| |
| def __rfloordiv__(self, other): |
| if isinstance(other, (builtins.float, SymFloat)): |
| return sym_float(math.floor(other / sym_float(self))) |
| if not isinstance(other, (builtins.int, SymInt)): |
| return NotImplemented |
| return self.__rint_floordiv__(other) |
| |
| # nb: complex is impossible to handle correctly lol, with |
| # negative base and integral float need to diverge semantics and |
| # just always return complex. Neener neener pretend this problem |
| # doesn't exist |
| def __pow__(self, other): |
| if isinstance(other, (builtins.float, SymFloat)): |
| return sym_float(self).__pow__(other) |
| if not isinstance(other, (builtins.int, SymInt)): |
| return NotImplemented |
| # Guards! This guard is necessary because we need to know it to |
| # determine the output type of this operation |
| if other >= 0: |
| return self.__pow_by_natural__(other) |
| else: |
| # Mercifully, when the exponent is negative, Python just promotes |
| # to doubles and does a float pow: |
| # |
| # if (Py_SIZE(b) < 0 && c == NULL) { |
| # /* if exponent is negative and there's no modulus: |
| # return a float. This works because we know |
| # that this calls float_pow() which converts its |
| # arguments to double. */ |
| # Py_DECREF(a); |
| # Py_DECREF(b); |
| # return PyFloat_Type.tp_as_number->nb_power(v, w, x); |
| # } |
| return sym_float(self).__pow__(sym_float(other)) |
| |
| def __rpow__(self, other): |
| if isinstance(other, (builtins.float, SymFloat)): |
| return sym_float(self).__rpow__(other) |
| if not isinstance(other, (builtins.int, SymInt)): |
| return NotImplemented |
| if self >= 0: # self is exponent |
| return self.__rpow_by_natural__(other) |
| else: |
| return sym_float(self).__rpow__(sym_float(other)) |
| |
| def __eq__(self, other: object) -> builtins.bool: |
| raise TypeError("type stub not overridden") |
| |
| def __lt__(self, other) -> builtins.bool: |
| raise TypeError("type stub not overridden") |
| |
| def __gt__(self, other) -> builtins.bool: |
| raise TypeError("type stub not overridden") |
| |
| def __le__(self, other) -> builtins.bool: |
| raise TypeError("type stub not overridden") |
| |
| def __ge__(self, other) -> builtins.bool: |
| raise TypeError("type stub not overridden") |
| |
| def __add__(self, other) -> "SymInt": |
| raise TypeError("type stub not overridden") |
| |
| def __mod__(self, other: "IntLikeType") -> "SymInt": |
| raise TypeError("type stub not overridden") |
| |
| def __mul__(self, other) -> "SymInt": |
| raise TypeError("type stub not overridden") |
| |
| def __pow_by_natural__(self, other) -> "SymInt": |
| raise TypeError("type stub not overridden") |
| |
| def __rpow_by_natural__(self, other) -> "SymInt": |
| raise TypeError("type stub not overridden") |
| |
| def __int_truediv__(self, other) -> "SymFloat": |
| raise TypeError("type stub not overridden") |
| |
| def __rint_truediv__(self, other) -> "SymFloat": |
| raise TypeError("type stub not overridden") |
| |
| def __int_floordiv__(self, other) -> "SymFloat": |
| raise TypeError("type stub not overridden") |
| |
| def __rint_floordiv__(self, other) -> "SymFloat": |
| raise TypeError("type stub not overridden") |
| |
| def __sym_max__(self, other): |
| raise TypeError("type stub not overridden") |
| |
| def __sym_min__(self, other): |
| raise TypeError("type stub not overridden") |
| |
| def __sym_float__(self): |
| raise TypeError("type stub not overridden") |
| |
| def __neg__(self): |
| raise TypeError("type stub not overridden") |
| |
| def __sub__(self, other: "IntLikeType") -> "SymInt": |
| raise TypeError("type stub not overridden") |
| |
| def __repr__(self): |
| return self.node._graph_repr() |
| |
| def _sympy_(self): |
| return self.node.expr |
| |
| def __hash__(self) -> builtins.int: |
| if self.node.is_nested_int(): |
| return hash(self.node.nested_int()) |
| else: |
| # We could support constant SymInts as well, but not doing it for now |
| raise TypeError("unhashable type: non-nested SymInt") |
| # TODO: Force specialization |
| # This can't be done because the TypeError here is load bearing |
| # for einops |
| # https://github.com/arogozhnikov/einops/blob/6181e1e95dc58c00a3143c1726da1c6ee0463164/einops/einops.py#L237 |
| # return hash(builtins.int(self)) |
| |
| def as_integer_ratio(self) -> _Tuple["SymInt", builtins.int]: |
| """Represent this int as an exact integer ratio""" |
| return self, 1 |
| |
| def bit_length(self) -> builtins.int: |
| # TODO: A more relaxed guard is possible here, where you guard to |
| # allow all integer quantities which would result in the same bit |
| # length. We can also just make a dedicated Sympy function for |
| # computing this quantity and represent it symbolically. |
| return builtins.int(self).bit_length() |
| |
| def conjugate(self) -> "SymInt": |
| return self |
| |
| |
| class SymFloat: |
| """ |
| Like a float (including magic methods), but redirects all operations to the |
| wrapped node. This is used in particular to symbolically record operations |
| in the symbolic shape workflow. |
| """ |
| |
| def __init__(self, node): |
| # This field MUST be named node; C++ binding code assumes that this |
| # class has a field named node that stores SymNode |
| self.node = node |
| |
| def __truediv__(self, other): |
| if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): |
| return NotImplemented |
| return self.__float_truediv__(sym_float(other)) |
| |
| def __rtruediv__(self, other): |
| if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): |
| return NotImplemented |
| return self.__rfloat_truediv__(sym_float(other)) |
| |
| def __floordiv__(self, other): |
| if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): |
| return NotImplemented |
| return sym_float(math.floor(self / sym_float(other))) |
| |
| def __rfloordiv__(self, other): |
| if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): |
| return NotImplemented |
| return sym_float(math.floor(sym_float(other) / self)) |
| |
| def __bool__(self): |
| return self.node.bool_() |
| |
| def __float__(self): |
| return self.node.guard_float("", 0) |
| |
| # Symbolic power does NOT work with a negative base; this is to avoid |
| # potential complex outputs. |
| def __pow__(self, other): |
| if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): |
| return NotImplemented |
| torch._check(self >= 0) |
| return self.__float_pow__(other) |
| |
| def __rpow__(self, other): |
| if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): |
| return NotImplemented |
| torch._check(other >= 0) |
| return self.__rfloat_pow__(other) |
| |
| # Magic methods installed by torch.fx.experimental.sym_node |
| |
| def __eq__(self, other: object) -> builtins.bool: |
| raise TypeError("type stub not overridden") |
| |
| def __lt__(self, other) -> builtins.bool: |
| raise TypeError("type stub not overridden") |
| |
| def __gt__(self, other) -> builtins.bool: |
| raise TypeError("type stub not overridden") |
| |
| def __le__(self, other) -> builtins.bool: |
| raise TypeError("type stub not overridden") |
| |
| def __ge__(self, other) -> builtins.bool: |
| raise TypeError("type stub not overridden") |
| |
| def __float_pow__(self, other) -> "SymFloat": |
| raise TypeError("type stub not overridden") |
| |
| def __rfloat_pow__(self, other) -> "SymFloat": |
| raise TypeError("type stub not overridden") |
| |
| def __float_truediv__(self, other) -> "SymFloat": |
| raise TypeError("type stub not overridden") |
| |
| def __rfloat_truediv__(self, other) -> "SymFloat": |
| raise TypeError("type stub not overridden") |
| |
| def __trunc__(self): |
| raise TypeError("type stub not overridden") |
| |
| def __sym_max__(self, other): |
| raise TypeError("type stub not overridden") |
| |
| def __sym_min__(self, other): |
| raise TypeError("type stub not overridden") |
| |
| def __sym_int__(self): |
| raise TypeError("type stub not overridden") |
| |
| def is_integer(self): |
| """Return True if the float is an integer.""" |
| raise TypeError("type stub not overridden") |
| |
| def as_integer_ratio(self) -> _Tuple[builtins.int, builtins.int]: |
| """Represent this float as an exact integer ratio""" |
| return builtins.float(self).as_integer_ratio() |
| |
| def __repr__(self): |
| return self.node._graph_repr() |
| |
| def _sympy_(self): |
| return self.node.expr |
| |
| def __hash__(self): |
| return hash(builtins.float(self)) |
| |
| |
| class SymBool: |
| """ |
| Like a bool (including magic methods), but redirects all operations to the |
| wrapped node. This is used in particular to symbolically record operations |
| in the symbolic shape workflow. |
| |
| Unlike a regular bool, the regular boolean operators force extra guards instead |
| of evaluating symbolically. Use the bitwise operators instead to handle this. |
| """ |
| |
| def __init__(self, node): |
| # This field MUST be named node; C++ binding code assumes that this |
| # class has a field named node that stores SymNode |
| self.node = node |
| |
| def __bool__(self): |
| return self.node.bool_() |
| |
| def __int__(self): |
| return builtins.int(self.node.bool_()) |
| |
| # Magic methods installed by torch.fx.experimental.sym_node |
| def __and__(self, other) -> "SymBool": |
| raise TypeError("type stub not overridden") |
| |
| def __or__(self, other) -> "SymBool": |
| raise TypeError("type stub not overridden") |
| |
| # We very carefully define __sym_not__, and not a number of other |
| # plausible alternatives: |
| # |
| # - We do not override __not__ because this is not a real magic |
| # method; you cannot override the meaning of the not builtin in |
| # Python. We use the name 'sym_not' to clarify that in user code you |
| # cannot use the builtin not or operator.not_ or operator.__not__ and |
| # hit this magic method; you must use our custom sym_not operator. |
| # |
| # - We do not override the __invert__ method because SymBool is |
| # meant to be usable in situations where bool is expected. However, |
| # bitwise negation ~a does the wrong thing with booleans (because |
| # bool is a subclass of int, so ~1 = -2 which is not falseish.) |
| # This would be a giant footgun, so we get around it by defining |
| # our own operator. Note that bitwise and/or do the right thing, |
| # so we reuse the conventional operators there for readability. |
| # |
| def __sym_not__(self) -> "SymBool": |
| raise TypeError("type stub not overridden") |
| |
| def __sym_ite__(self, then_val, else_val): |
| raise TypeError("type stub not overridden") |
| |
| def __eq__(self, other) -> builtins.bool: |
| raise TypeError("type stub not overridden") |
| |
| def __repr__(self): |
| return self.node._graph_repr() |
| |
| def _sympy_(self): |
| return self.node.expr |
| |
| def __hash__(self): |
| if self.node.is_constant(): |
| return hash(self.node.bool_()) |
| else: |
| # Force specialization |
| return hash(builtins.bool(self)) |
| |
| |
| def sym_not(a): |
| r"""SymInt-aware utility for logical negation. |
| |
| Args: |
| a (SymBool or bool): Object to negate |
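| |
| Example (illustrative; a plain bool falls through to Python's ``not``, while |
| a SymBool dispatches to its ``__sym_not__`` method):: |
| |
| >>> torch.sym_not(True) |
| False |
| >>> torch.sym_not(False) |
| True |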
| """ |
| import sympy |
| |
| if overrides.has_torch_function_unary(a): |
| return overrides.handle_torch_function(sym_not, (a,), a) |
| if hasattr(a, "__sym_not__"): |
| return a.__sym_not__() |
| if isinstance(a, sympy.Basic): |
| return ~a # type: ignore[operator] |
| return not a |
| |
| |
| def sym_float(a): |
| r"""SymInt-aware utility for float casting. |
| |
| Args: |
| a (SymInt, SymFloat, or object): Object to cast |
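| |
| Example (illustrative; plain Python numbers fall through to ``builtins.float``):: |
| |
| >>> torch.sym_float(3) |
| 3.0 |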
| """ |
| if overrides.has_torch_function_unary(a): |
| return overrides.handle_torch_function(sym_float, (a,), a) |
| if isinstance(a, SymFloat): |
| return a |
| elif hasattr(a, "__sym_float__"): |
| return a.__sym_float__() |
| return builtins.float(a) # type: ignore[operator] |
| |
| |
| def sym_int(a): |
| r"""SymInt-aware utility for int casting. |
| |
| Args: |
| a (SymInt, SymFloat, or object): Object to cast |
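| |
| Example (illustrative; plain floats are truncated toward zero, while SymFloats |
| go through :func:`math.trunc`):: |
| |
| >>> torch.sym_int(3.7) |
| 3 |
| >>> torch.sym_int(-3.7) |
| -3 |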
| """ |
| if overrides.has_torch_function_unary(a): |
| return overrides.handle_torch_function(sym_int, (a,), a) |
| if isinstance(a, SymInt): |
| return a |
| elif isinstance(a, SymFloat): |
| return math.trunc(a) |
| return builtins.int(a) # type: ignore[operator] |
| |
| |
| def sym_max(a, b): |
| """ |
| SymInt-aware utility for max which avoids branching on a < b. |
| Unlike builtins.max(), this only works for int/float, and it always |
| promotes to float if any argument is float (unlike builtins.max, which |
| will faithfully preserve the type of the input argument). |
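| |
| Example (illustrative; note the promotion to float, unlike ``builtins.max``):: |
| |
| >>> torch.sym_max(3, 4) |
| 4 |
| >>> torch.sym_max(2, 3.0) |
| 3.0 |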
| """ |
| if overrides.has_torch_function((a, b)): |
| return overrides.handle_torch_function(sym_max, (a, b), a, b) |
| if isinstance(a, (SymInt, SymFloat)): |
| return a.__sym_max__(b) |
| elif isinstance(b, (SymInt, SymFloat)): |
| # Due to promotion semantics, this operator is commutative: |
| # max(1, 1.0) === max(1.0, 1) === 1.0 |
| return b.__sym_max__(a) |
| # TODO: Probably can make bool work too, just lazy |
| |
| all_types, float_types = __all_and_float_types() |
| |
| assert isinstance(a, all_types), type(a) |
| assert isinstance(b, all_types), type(b) |
| if isinstance(a, float_types) or isinstance(b, float_types): |
| return builtins.float(builtins.max(a, b)) |
| else: |
| return builtins.max(a, b) |
| |
| |
| def __all_and_float_types() -> _Tuple[_Tuple[_Type, ...], _Tuple[_Type, ...]]: |
| try: |
| import numpy as np |
| |
| all_types: _Tuple[_Type, ...] = ( |
| np.integer, |
| np.floating, |
| builtins.int, |
| builtins.float, |
| ) |
| float_types: _Tuple[_Type, ...] = (np.floating, builtins.float) |
| except ModuleNotFoundError: |
| all_types = (builtins.int, builtins.float) |
| float_types = (builtins.float,) |
| |
| return all_types, float_types |
| |
| |
| def sym_min(a, b): |
| """SymInt-aware utility for min().""" |
| if overrides.has_torch_function((a, b)): |
| return overrides.handle_torch_function(sym_min, (a, b), a, b) |
| if isinstance(a, (SymInt, SymFloat)): |
| return a.__sym_min__(b) |
| elif isinstance(b, (SymInt, SymFloat)): |
| return b.__sym_min__(a) |
| |
| all_types, float_types = __all_and_float_types() |
| |
| assert isinstance(a, all_types), type(a) |
| assert isinstance(b, all_types), type(b) |
| if isinstance(a, float_types) or isinstance(b, float_types): |
| return builtins.float(builtins.min(a, b)) |
| else: |
| return builtins.min(a, b) |
| |
| |
| # Drop-in replacements for math.sqrt, math.sin, math.cos, etc. |
| def _get_sym_math_fn(name): |
| def fn(a): |
| if overrides.has_torch_function_unary(a): |
| return overrides.handle_torch_function(fn, (a,), a) |
| if hasattr(a, f"__sym_{name}__"): |
| return getattr(a, f"__sym_{name}__")() |
| return getattr(math, name)(a) |
| |
| return fn |
| |
| |
| __fn, __name, __sym_name = None, "", "" |
| for __name in ( |
| "sqrt", |
| "cos", |
| "cosh", |
| "sin", |
| "sinh", |
| "tan", |
| "tanh", |
| "asin", |
| "acos", |
| "atan", |
| ): |
| __sym_name = f"_sym_{__name}" |
| __fn = _get_sym_math_fn(__name) |
| __fn.__qualname__ = __fn.__name__ = __sym_name |
| globals()[__sym_name] = __fn |
| |
| del __fn, __name, __sym_name, _get_sym_math_fn |
| |
| # Adding temporary shortcut |
| sym_sqrt = globals()["_sym_sqrt"] |
| __all__.append("sym_sqrt") |
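| |
| # For plain Python numbers the generated _sym_* helpers fall through to the |
| # math module (illustratively, _sym_sqrt(4.0) == 2.0); symbolic inputs |
| # dispatch to the matching __sym_<name>__ method installed by |
| # torch.fx.experimental.sym_node. |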
| |
| |
| def sym_ite(b, t, f): |
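| # Symbolic if-then-else. With a plain bool condition this reduces to a Python |
| # conditional expression (t if b else f); a SymBool condition dispatches to |
| # __sym_ite__ so the choice is recorded symbolically rather than forcing a |
| # guard. For example, sym_ite(True, 1, 2) == 1. |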
| if overrides.has_torch_function((b, t, f)): |
| return overrides.handle_torch_function(sym_ite, (b, t, f), b, t, f) |
| assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f) |
| if isinstance(b, SymBool): |
| return b.__sym_ite__(t, f) |
| return t if b else f |
| |
| |
| # Check to see if we can load C extensions, and if not provide some guidance |
| # on what the problem might be. |
| try: |
| # _initExtension is chosen (arbitrarily) as a sentinel. |
| from torch._C import _initExtension |
| except ImportError: |
| import torch._C as _C_for_compiled_check |
| |
| # The __file__ check only works for Python 3.7 and above. |
| if _C_for_compiled_check.__file__ is None: |
| raise ImportError( |
| textwrap.dedent( |
| """ |
| Failed to load PyTorch C extensions: |
| It appears that PyTorch has loaded the `torch/_C` folder |
| of the PyTorch repository rather than the C extensions which |
| are expected in the `torch._C` namespace. This can occur when |
| using the `install` workflow. e.g. |
| $ python setup.py install && python -c "import torch" |
| |
| This error can generally be solved using the `develop` workflow |
| $ python setup.py develop && python -c "import torch" # This should succeed |
| or by running Python from a different directory. |
| """ |
| ).strip() |
| ) from None |
| raise # If __file__ is not None the cause is unknown, so just re-raise. |
| |
| # The torch._C submodule is already loaded via `from torch._C import *` above |
| # Make an explicit reference to the _C submodule to appease linters |
| from torch import _C as _C |
| |
| |
| __name, __obj = "", None |
| for __name in dir(_C): |
| if __name[0] != "_" and not __name.endswith("Base"): |
| __all__.append(__name) |
| __obj = getattr(_C, __name) |
| if callable(__obj) or inspect.isclass(__obj): |
| if __obj.__module__ != __name__: # "torch" |
| # TODO: fix their module from C++ side |
| if __name not in { |
| "DisableTorchFunctionSubclass", |
| "DisableTorchFunction", |
| "Generator", |
| }: |
| __obj.__module__ = __name__ # "torch" |
| elif __name == "TensorBase": |
| # issue 109438 / pr 109940. Prevent TensorBase from being copied into torch. |
| delattr(sys.modules[__name__], __name) |
| |
| del __name, __obj |
| |
| if not TYPE_CHECKING: |
| # issue 38137 and python issue 43367. Submodules of a C extension are |
| # non-standard, and attributes of those submodules cannot be pickled since |
| # pickle expects to be able to import them as "from _C.sub import attr", |
| # which fails with "_C is not a package". |
| def _import_extension_to_sys_modules(module, memo=None): |
| if memo is None: |
| memo = set() |
| if module in memo: |
| return |
| memo.add(module) |
| module_name = module.__name__ |
| for name in dir(module): |
| member = getattr(module, name) |
| member_name = getattr(member, "__name__", "") |
| if inspect.ismodule(member) and member_name.startswith(module_name): |
| sys.modules.setdefault(member_name, member) |
| # Recurse for submodules (e.g., `_C._dynamo.eval_frame`) |
| _import_extension_to_sys_modules(member, memo) |
| |
| _import_extension_to_sys_modules(_C) |
| del _import_extension_to_sys_modules |
| |
| ################################################################################ |
| # Define basic utilities |
| ################################################################################ |
| |
| |
| def typename(obj: _Any, /) -> str: |
| """ |
| String representation of the type of an object. |
| |
| This function returns a fully qualified string representation of an object's type. |
| Args: |
| obj (object): The object whose type to represent |
| Returns: |
| str: the type of the object `obj` |
| Example: |
| >>> x = torch.tensor([1, 2, 3]) |
| >>> torch.typename(x) |
| 'torch.LongTensor' |
| >>> torch.typename(torch.nn.Parameter) |
| 'torch.nn.parameter.Parameter' |
| """ |
| if isinstance(obj, torch.Tensor): |
| return obj.type() |
| |
| module = getattr(obj, "__module__", "") or "" |
| qualname = "" |
| |
| if hasattr(obj, "__qualname__"): |
| qualname = obj.__qualname__ |
| elif hasattr(obj, "__name__"): |
| qualname = obj.__name__ |
| else: |
| module = obj.__class__.__module__ or "" |
| qualname = obj.__class__.__qualname__ |
| |
| if module in {"", "builtins"}: |
| return qualname |
| return f"{module}.{qualname}" |
| |
| |
| def is_tensor(obj: _Any, /) -> _TypeGuard["torch.Tensor"]: |
| r"""Returns True if `obj` is a PyTorch tensor. |
| |
| Note that this function is simply doing ``isinstance(obj, Tensor)``. |
| Using that ``isinstance`` check is better for typechecking with mypy, |
| and more explicit - so it's recommended to use that instead of |
| ``is_tensor``. |
| |
| Args: |
| obj (object): Object to test |
| Example:: |
| |
| >>> x = torch.tensor([1, 2, 3]) |
| >>> torch.is_tensor(x) |
| True |
| |
| """ |
| return isinstance(obj, torch.Tensor) |
| |
| |
| def is_storage(obj: _Any, /) -> _TypeGuard[_Union["TypedStorage", "UntypedStorage"]]: |
| r"""Returns True if `obj` is a PyTorch storage object. |
| |
| Args: |
| obj (object): Object to test |
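| |
| Example (illustrative; ``Tensor.untyped_storage`` returns a storage object):: |
| |
| >>> x = torch.tensor([1, 2, 3]) |
| >>> torch.is_storage(x) |
| False |
| >>> torch.is_storage(x.untyped_storage()) |
| True |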
| """ |
| return type(obj) in _storage_classes |
| |
| |
| _GLOBAL_DEVICE_CONTEXT = threading.local() |
| |
| |
| def get_default_device() -> "torch.device": |
| r"""Gets the default ``torch.Tensor`` to be allocated on ``device``""" |
| global _GLOBAL_DEVICE_CONTEXT |
| |
| if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"): |
| device = _GLOBAL_DEVICE_CONTEXT.device_context.device |
| if device.index is not None: |
| return device |
| else: |
| # TODO: Call like get_device_index() method corresponding to |
| # each device type |
| return torch.tensor([]).device |
| else: |
| return torch.device("cpu") |
| |
| |
| def set_default_device( |
| device: _Optional[_Union["torch.device", str, builtins.int]], |
| ) -> None: |
| """Sets the default ``torch.Tensor`` to be allocated on ``device``. This |
| does not affect factory function calls which are called with an explicit |
| ``device`` argument. Factory calls will be performed as if they |
| were passed ``device`` as an argument. |
| |
| To only temporarily change the default device instead of setting it |
| globally, use ``with torch.device(device):`` instead. |
| |
| The default device is initially ``cpu``. If you set the default tensor |
| device to another device (e.g., ``cuda``) without a device index, tensors |
| will be allocated on whatever the current device for that device type is, |
| even after :func:`torch.cuda.set_device` is called. |
| |
| .. warning:: |
| |
| This function imposes a slight performance cost on every Python |
| call to the torch API (not just factory functions). If this |
| is causing problems for you, please comment on |
| https://github.com/pytorch/pytorch/issues/92701 |
| |
| .. note:: |
| |
| This doesn't affect functions that create tensors that share the same memory as the input, like: |
| :func:`torch.from_numpy` and :func:`torch.frombuffer` |
| |
| Args: |
| device (device or string): the device to set as default |
| |
| Example:: |
| |
| >>> # xdoctest: +SKIP("requires cuda, changes global state") |
| >>> torch.get_default_device() |
| device(type='cpu') |
| >>> torch.set_default_device('cuda') # current device is 0 |
| >>> torch.get_default_device() |
| device(type='cuda', index=0) |
| >>> torch.set_default_device('cuda') |
| >>> torch.cuda.set_device('cuda:1') # current device is 1 |
| >>> torch.get_default_device() |
| device(type='cuda', index=1) |
| >>> torch.set_default_device('cuda:1') |
| >>> torch.get_default_device() |
| device(type='cuda', index=1) |
| |
| """ |
| global _GLOBAL_DEVICE_CONTEXT |
| if hasattr(_GLOBAL_DEVICE_CONTEXT, "device_context"): |
| device_context = _GLOBAL_DEVICE_CONTEXT.device_context |
| if device_context is not None: |
| device_context.__exit__(None, None, None) |
| |
| if device is None: |
| device_context = None |
| else: |
| from torch.utils._device import DeviceContext |
| |
| device_context = DeviceContext(device) |
| device_context.__enter__() |
| _GLOBAL_DEVICE_CONTEXT.device_context = device_context |
| |
| |
| def set_default_tensor_type(t: _Union[_Type["torch.Tensor"], str], /) -> None: |
| r""" |
| .. warning:: |
| |
| This function is deprecated as of PyTorch 2.1, please use :func:`torch.set_default_dtype()` and |
| :func:`torch.set_default_device()` as alternatives. |
| |
| Sets the default ``torch.Tensor`` type to floating point tensor type |
| ``t``. This type will also be used as default floating point type for |
| type inference in :func:`torch.tensor`. |
| |
| The default floating point tensor type is initially ``torch.FloatTensor``. |
| |
| Args: |
| t (type or string): the floating point tensor type or its name |
| |
| Example:: |
| |
| >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?") |
| >>> torch.tensor([1.2, 3]).dtype # initial default for floating point is torch.float32 |
| torch.float32 |
| >>> torch.set_default_tensor_type(torch.DoubleTensor) |
| >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor |
| torch.float64 |
| |
| """ |
| if isinstance(t, str): |
| t = _import_dotted_name(t) |
| _C._set_default_tensor_type(t) |
| |
| |
| def set_default_dtype(d: "torch.dtype", /) -> None: |
| r""" |
| |
| Sets the default floating point dtype to :attr:`d`. Supports floating point |
| dtypes as inputs. Other dtypes will cause torch to raise an exception. |
| |
| When PyTorch is initialized its default floating point dtype is torch.float32, |
| and the intent of set_default_dtype(torch.float64) is to facilitate NumPy-like |
| type inference. The default floating point dtype is used to: |
| |
| 1. Implicitly determine the default complex dtype. When the default floating type is float16, |
| the default complex dtype is complex32. For float32, the default complex dtype is complex64. |
| For float64, it is complex128. For bfloat16, an exception will be raised because |
| there is no corresponding complex type for bfloat16. |
| 2. Infer the dtype for tensors constructed using Python floats or complex Python |
| numbers. See examples below. |
| 3. Determine the result of type promotion between bool and integer tensors and |
| Python floats and complex Python numbers. |
| |
| Args: |
| d (:class:`torch.dtype`): the floating point dtype to make the default. |
| |
| Example: |
| >>> # xdoctest: +SKIP("Other tests may have changed the default type. Can we reset it?") |
| >>> # initial default for floating point is torch.float32 |
| >>> # Python floats are interpreted as float32 |
| >>> torch.tensor([1.2, 3]).dtype |
| torch.float32 |
| >>> # initial default for floating point is torch.complex64 |
| >>> # Complex Python numbers are interpreted as complex64 |
| >>> torch.tensor([1.2, 3j]).dtype |
| torch.complex64 |
| |
| >>> torch.set_default_dtype(torch.float64) |
| >>> # Python floats are now interpreted as float64 |
| >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor |
| torch.float64 |
| >>> # Complex Python numbers are now interpreted as complex128 |
| >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor |
| torch.complex128 |
| |
| >>> torch.set_default_dtype(torch.float16) |
| >>> # Python floats are now interpreted as float16 |
| >>> torch.tensor([1.2, 3]).dtype # a new floating point tensor |
| torch.float16 |
| >>> # Complex Python numbers are now interpreted as complex32 |
| >>> torch.tensor([1.2, 3j]).dtype # a new complex tensor |
| torch.complex32 |
| |
| """ |
| _C._set_default_dtype(d) |
| |
| |
| def use_deterministic_algorithms( |
| mode: builtins.bool, |
| *, |
| warn_only: builtins.bool = False, |
| ) -> None: |
| r"""Sets whether PyTorch operations must use "deterministic" |
| algorithms. That is, algorithms which, given the same input, and when |
| run on the same software and hardware, always produce the same output. |
| When enabled, operations will use deterministic algorithms when available, |
| and if only nondeterministic algorithms are available they will throw a |
| :class:`RuntimeError` when called. |
| |
| .. note:: This setting alone is not always enough to make an application |
| reproducible. Refer to :ref:`reproducibility` for more information. |
| |
| .. note:: :func:`torch.set_deterministic_debug_mode` offers an alternative |
| interface for this feature. |
| |
| The following normally-nondeterministic operations will act |
| deterministically when ``mode=True``: |
| |
| * :class:`torch.nn.Conv1d` when called on CUDA tensor |
| * :class:`torch.nn.Conv2d` when called on CUDA tensor |
| * :class:`torch.nn.Conv3d` when called on CUDA tensor |
| * :class:`torch.nn.ConvTranspose1d` when called on CUDA tensor |
| * :class:`torch.nn.ConvTranspose2d` when called on CUDA tensor |
| * :class:`torch.nn.ConvTranspose3d` when called on CUDA tensor |
| * :class:`torch.nn.ReplicationPad2d` when attempting to differentiate a CUDA tensor |
| * :func:`torch.bmm` when called on sparse-dense CUDA tensors |
| * :func:`torch.Tensor.__getitem__` when attempting to differentiate a CPU tensor |
| and the index is a list of tensors |
| * :func:`torch.Tensor.index_put` with ``accumulate=False`` |
| * :func:`torch.Tensor.index_put` with ``accumulate=True`` when called on a CPU |
| tensor |
| * :func:`torch.Tensor.put_` with ``accumulate=True`` when called on a CPU |
| tensor |
| * :func:`torch.Tensor.scatter_add_` when called on a CUDA tensor |
| * :func:`torch.gather` when called on a CUDA tensor that requires grad |
| * :func:`torch.index_add` when called on CUDA tensor |
| * :func:`torch.index_select` when attempting to differentiate a CUDA tensor |
| * :func:`torch.repeat_interleave` when attempting to differentiate a CUDA tensor |
| * :func:`torch.Tensor.index_copy` when called on a CPU or CUDA tensor |
| * :func:`torch.Tensor.scatter` when `src` type is Tensor and called on CUDA tensor |
| * :func:`torch.Tensor.scatter_reduce` when ``reduce='sum'`` or ``reduce='mean'`` and called on CUDA tensor |
| |
| The following normally-nondeterministic operations will throw a |
| :class:`RuntimeError` when ``mode=True``: |
| |
| * :class:`torch.nn.AvgPool3d` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.AdaptiveAvgPool2d` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.AdaptiveAvgPool3d` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.MaxPool3d` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.AdaptiveMaxPool2d` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.FractionalMaxPool2d` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.FractionalMaxPool3d` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.MaxUnpool1d` |
| * :class:`torch.nn.MaxUnpool2d` |
| * :class:`torch.nn.MaxUnpool3d` |
| * :func:`torch.nn.functional.interpolate` when attempting to differentiate a CUDA tensor |
| and one of the following modes is used: |
| |
| - ``linear`` |
| - ``bilinear`` |
| - ``bicubic`` |
| - ``trilinear`` |
| |
| * :class:`torch.nn.ReflectionPad1d` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.ReflectionPad2d` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.ReflectionPad3d` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.ReplicationPad1d` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.ReplicationPad3d` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.NLLLoss` when called on a CUDA tensor |
| * :class:`torch.nn.CTCLoss` when attempting to differentiate a CUDA tensor |
| * :class:`torch.nn.EmbeddingBag` when attempting to differentiate a CUDA tensor when |
| ``mode='max'`` |
| * :func:`torch.Tensor.put_` when ``accumulate=False`` |
| * :func:`torch.Tensor.put_` when ``accumulate=True`` and called on a CUDA tensor |
| * :func:`torch.histc` when called on a CUDA tensor |
| * :func:`torch.bincount` when called on a CUDA tensor and ``weights`` |
| tensor is given |
| * :func:`torch.kthvalue` when called on a CUDA tensor |
| * :func:`torch.median` with indices output when called on a CUDA tensor |
| * :func:`torch.nn.functional.grid_sample` when attempting to differentiate a CUDA tensor |
| * :func:`torch.cumsum` when called on a CUDA tensor when dtype is floating point or complex |
| * :func:`torch.Tensor.scatter_reduce` when ``reduce='prod'`` and called on CUDA tensor |
| * :func:`torch.Tensor.resize_` when called with a quantized tensor |
| |
| In addition, several operations fill uninitialized memory when this setting |
| is turned on and when |
| :attr:`torch.utils.deterministic.fill_uninitialized_memory` is turned on. |
| See the documentation for that attribute for more information. |
| |
| A handful of CUDA operations are nondeterministic if the CUDA version is |
| 10.2 or greater, unless the environment variable ``CUBLAS_WORKSPACE_CONFIG=:4096:8`` |
| or ``CUBLAS_WORKSPACE_CONFIG=:16:8`` is set. See the CUDA documentation for more |
| details: `<https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility>`_ |
| If one of these environment variable configurations is not set, a :class:`RuntimeError` |
| will be raised from these operations when called with CUDA tensors: |
| |
| * :func:`torch.mm` |
| * :func:`torch.mv` |
| * :func:`torch.bmm` |
| |
| Note that deterministic operations tend to have worse performance than |
| nondeterministic operations. |
| |
| .. note:: |
| |
| This flag does not detect or prevent nondeterministic behavior caused |
| by calling an inplace operation on a tensor with an internal memory |
| overlap or by giving such a tensor as the :attr:`out` argument for an |
| operation. In these cases, multiple writes of different data may target |
| a single memory location, and the order of writes is not guaranteed. |
| |
| Args: |
| mode (:class:`bool`): If True, makes potentially nondeterministic |
| operations switch to a deterministic algorithm or throw a runtime |
| error. If False, allows nondeterministic operations. |
| |
| Keyword args: |
| warn_only (:class:`bool`, optional): If True, operations that do not |
| have a deterministic implementation will throw a warning instead of |
| an error. Default: ``False`` |
| |
| Example:: |
| |
| >>> # xdoctest: +SKIP |
| >>> torch.use_deterministic_algorithms(True) |
| |
| # Forward mode nondeterministic error |
| >>> torch.randn(10, device='cuda').kthvalue(1) |
| ... |
| RuntimeError: kthvalue CUDA does not have a deterministic implementation... |
| |
| # Backward mode nondeterministic error |
| >>> torch.nn.AvgPool3d(1)(torch.randn(3, 4, 5, 6, requires_grad=True).cuda()).sum().backward() |
| ... |
| RuntimeError: avg_pool3d_backward_cuda does not have a deterministic implementation... |
| """ |
| _C._set_deterministic_algorithms(mode, warn_only=warn_only) |
| |
| |
| def are_deterministic_algorithms_enabled() -> builtins.bool: |
| r"""Returns True if the global deterministic flag is turned on. Refer to |
| :func:`torch.use_deterministic_algorithms` documentation for more details. |
| """ |
| return _C._get_deterministic_algorithms() |
| |
| |
| def is_deterministic_algorithms_warn_only_enabled() -> builtins.bool: |
| r"""Returns True if the global deterministic flag is set to warn only. |
| Refer to :func:`torch.use_deterministic_algorithms` documentation for more |
| details. |
| """ |
| return _C._get_deterministic_algorithms_warn_only() |
| |
| |
| def set_deterministic_debug_mode(debug_mode: _Union[builtins.int, str]) -> None: |
| r"""Sets the debug mode for deterministic operations. |
| |
| .. note:: This is an alternative interface for |
| :func:`torch.use_deterministic_algorithms`. Refer to that function's |
| documentation for details about affected operations. |
| |
| Args: |
| debug_mode(str or int): If "default" or 0, don't error or warn on |
| nondeterministic operations. If "warn" or 1, warn on |
| nondeterministic operations. If "error" or 2, error on |
| nondeterministic operations. |
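| |
| Example (illustrative; changes global state):: |
| |
| >>> # xdoctest: +SKIP("changes global state") |
| >>> torch.set_deterministic_debug_mode("warn") |
| >>> torch.get_deterministic_debug_mode() |
| 1 |
| >>> torch.set_deterministic_debug_mode("default") |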
| """ |
| |
| # NOTE: builtins.int is used here because int in this scope resolves |
| # to torch.int |
| if not isinstance(debug_mode, (builtins.int, str)): |
| raise TypeError(f"debug_mode must be str or int, but got {type(debug_mode)}") |
| |
| if isinstance(debug_mode, str): |
| if debug_mode == "default": |
| debug_mode = 0 |
| elif debug_mode == "warn": |
| debug_mode = 1 |
| elif debug_mode == "error": |
| debug_mode = 2 |
| else: |
| raise RuntimeError( |
| "invalid value of debug_mode, expected one of `default`, " |
| f"`warn`, `error`, but got {debug_mode}" |
| ) |
| |
| if debug_mode == 0: |
| _C._set_deterministic_algorithms(False) |
| elif debug_mode == 1: |
| _C._set_deterministic_algorithms(True, warn_only=True) |
| elif debug_mode == 2: |
| _C._set_deterministic_algorithms(True) |
| else: |
| raise RuntimeError( |
| "invalid value of debug_mode, expected 0, 1, or 2, " f"but got {debug_mode}" |
| ) |
| |
| |
| def get_deterministic_debug_mode() -> builtins.int: |
| r"""Returns the current value of the debug mode for deterministic |
| operations. Refer to :func:`torch.set_deterministic_debug_mode` |
| documentation for more details. |
| """ |
| |
| if _C._get_deterministic_algorithms(): |
| if _C._get_deterministic_algorithms_warn_only(): |
| return 1 |
| else: |
| return 2 |
| else: |
| return 0 |
| |
| |
| def get_float32_matmul_precision() -> str: |
| r"""Returns the current value of float32 matrix multiplication precision. Refer to |
| :func:`torch.set_float32_matmul_precision` documentation for more details. |
| """ |
| return _C._get_float32_matmul_precision() |
| |
| |
| def set_float32_matmul_precision(precision: str) -> None: |
| r"""Sets the internal precision of float32 matrix multiplications. |
| |
| Running float32 matrix multiplications in lower precision may significantly increase |
| performance, and in some programs the loss of precision has a negligible impact. |
| |
| Supports three settings: |
| |
| * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa |
| bits with 23 bits explicitly stored) for internal computations. |
| * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10 |
| mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers |
| (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication |
| algorithms are available. Otherwise float32 matrix multiplications are computed |
| as if the precision is "highest". See below for more information on the bfloat16 |
| approach. |
| * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa |
| bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm |
| using that datatype internally is available. Otherwise float32 |
| matrix multiplications are computed as if the precision is "high". |
| |
| When using "high" precision, float32 multiplications may use a bfloat16-based algorithm |
| that is more complicated than simply truncating to some smaller number of mantissa bits |
| (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored). Refer to [Henry2019]_ for a complete |
| description of this algorithm. To briefly explain here, the first step is to realize |
| that we can perfectly encode a single float32 number as the sum of three bfloat16 |
| numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the |
| same number of exponent bits). This means that the product of two float32 numbers can |
| be exactly given by the sum of nine products of bfloat16 numbers. We can then trade |
| accuracy for speed by dropping some of these products. The "high" precision algorithm |
| specifically keeps only the three most significant products, which conveniently excludes |
| all of the products involving the last 8 mantissa bits of either input. This means that |
| we can represent our inputs as the sum of two bfloat16 numbers rather than three. |
| Because bfloat16 fused-multiply-add (FMA) instructions are typically >10x faster than |
| float32 ones, it's faster to do three multiplications and two additions with bfloat16 |
| precision than it is to do a single multiplication with float32 precision. |
| |
| .. [Henry2019] http://arxiv.org/abs/1904.06376 |
| |
| .. note:: |
| |
| This does not change the output dtype of float32 matrix multiplications, |
| it controls how the internal computation of the matrix multiplication is performed. |
| |
| .. note:: |
| |
| This does not change the precision of convolution operations. Other flags, |
| like `torch.backends.cudnn.allow_tf32`, may control the precision of convolution |
| operations. |
| |
| .. note:: |
| |
| This flag currently only affects one native device type: CUDA. |
| If "high" or "medium" are set then the TensorFloat32 datatype will be used |
| when computing float32 matrix multiplications, equivalent to setting |
| `torch.backends.cuda.matmul.allow_tf32 = True`. When "highest" (the default) |
| is set then the float32 datatype is used for internal computations, equivalent |
| to setting `torch.backends.cuda.matmul.allow_tf32 = False`. |
| |
| Args: |
| precision(str): can be set to "highest" (default), "high", or "medium" (see above). |
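| |
| Example (illustrative; changes global state):: |
| |
| >>> # xdoctest: +SKIP("changes global state") |
| >>> torch.get_float32_matmul_precision() |
| 'highest' |
| >>> torch.set_float32_matmul_precision("high") |
| >>> torch.get_float32_matmul_precision() |
| 'high' |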
| |
| """ |
| _C._set_float32_matmul_precision(precision) |
| |
| |
| def set_warn_always(b: builtins.bool, /) -> None: |
| r"""When this flag is False (default) then some PyTorch warnings may only |
| appear once per process. This helps avoid excessive warning information. |
| Setting it to True causes these warnings to always appear, which may be |
| helpful when debugging. |
| |
| Args: |
| b (:class:`bool`): If True, force warnings to always be emitted. |
| If False, restore the default behavior. |
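| |
| Example (illustrative; changes global state):: |
| |
| >>> # xdoctest: +SKIP("changes global state") |
| >>> torch.set_warn_always(True) |
| >>> torch.is_warn_always_enabled() |
| True |
| >>> torch.set_warn_always(False) |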
| """ |
| _C._set_warnAlways(b) |
| |
| |
| def is_warn_always_enabled() -> builtins.bool: |
| r"""Returns True if the global warn_always flag is turned on. Refer to |
| :func:`torch.set_warn_always` documentation for more details. |
| """ |
| return _C._get_warnAlways() |
| |
| |
| ################################################################################ |
| # Define error checking functions |
| ################################################################################ |
| |
| # These error checking functions must be kept consistent with their C++ |
| # equivalents. Their C++ equivalents are mentioned where applicable. |
| |
| |
| def _check_with( |
| error_type, |
| cond: _Union[builtins.bool, SymBool], |
| message: _Callable[[], str], |
| ): # noqa: F811 |
| if not isinstance(cond, (builtins.bool, SymBool)): |
| raise TypeError(f"cond must be a bool, but got {type(cond)}") |
| |
| from torch.fx.experimental.symbolic_shapes import expect_true |
| |
| if expect_true(cond): |
| return |
| |
| # error_type must be a subclass of Exception and not subclass of Warning |
| assert issubclass(error_type, Exception) and not issubclass(error_type, Warning) |
| |
| if message is None: |
| message_evaluated = ( |
| "Expected cond to be True, but got False. (Could this error " |
| "message be improved? If so, please report an enhancement request " |
| "to PyTorch.)" |
| ) |
| |
| else: |
| if not callable(message): |
| raise TypeError("message must be a callable") |
| |
| message_evaluated = str(message()) |
| |
| raise error_type(message_evaluated) |
| |
| |
| def _check(cond, message=None): # noqa: F811 |
| r"""Throws error containing an optional message if the specified condition |
| is False. |
| |
| Error type: ``RuntimeError`` |
| |
| C++ equivalent: ``TORCH_CHECK`` |
| |
| Args: |
| cond (:class:`bool`): If False, throw error |
| |
| message (Callable, optional): Callable that returns either a string or |
| an object that has a ``__str__()`` method to be used as the error |
| message. Default: ``None`` |
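| |
| Example (illustrative):: |
| |
| >>> torch._check(True) |
| >>> torch._check(False, lambda: "expected a positive value") |
| Traceback (most recent call last): |
| ... |
| RuntimeError: expected a positive value |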
| """ |
| _check_with(RuntimeError, cond, message) |
| |
| |
| def _check_is_size(i, message=None): |
| """Checks that a given integer is a valid size (i.e., is non-negative). |
| You should use this over _check(i >= 0) because we can use the semantic |
| information (that i is a size) to make some further inferences in case |
| i is an unbacked SymInt. |
| |
| NB: Do NOT use this in contexts where a -1 size would be valid (indicating |
| to infer the size from context, or if you should wrap around or truncate). |
| Only use this if the only valid value is an honest to goodness size. |
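| |
| Example (illustrative; a plain non-negative int passes, while a negative value |
| raises a ``RuntimeError``):: |
| |
| >>> torch._check_is_size(5) |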
| """ |
| # This is responsible for the expect_true |
| _check(i >= 0, message) |
| from torch.fx.experimental.symbolic_shapes import _advise_is_size |
| |
| _advise_is_size(i) |
| |
| |
| def _check_index(cond, message=None): # noqa: F811 |
| r"""Throws error containing an optional message if the specified condition |
| is False. |
| |
| Error type: ``IndexError`` |
| |
| C++ equivalent: ``TORCH_CHECK_INDEX`` |
| |
| Args: |
| cond (:class:`bool`): If False, throw error |
| |
| message (Callable, optional): Callable that returns either a string or |
| an object that has a ``__str__()`` method to be used as the error |
| message. Default: ``None`` |
| """ |
| _check_with(IndexError, cond, message) |
| |
| |
| def _check_value(cond, message=None): # noqa: F811 |
| r"""Throws error containing an optional message if the specified condition |
| is False. |
| |
| Error type: ``ValueError`` |
| |
| C++ equivalent: ``TORCH_CHECK_VALUE`` |
| |
| Args: |
| cond (:class:`bool`): If False, throw error |
| |
| message (Callable, optional): Callable that returns either a string or |
| an object that has a ``__str__()`` method to be used as the error |
| message. Default: ``None`` |
| """ |
| _check_with(ValueError, cond, message) |
| |
| |
| def _check_type(cond, message=None): # noqa: F811 |
| r"""Throws error containing an optional message if the specified condition |
| is False. |
| |
| Error type: ``TypeError`` |
| |
| C++ equivalent: ``TORCH_CHECK_TYPE`` |
| |
| Args: |
| cond (:class:`bool`): If False, throw error |
| |
| message (Callable, optional): Callable that returns either a string or |
| an object that has a ``__str__()`` method to be used as the error |
| message. Default: ``None`` |
| """ |
| _check_with(TypeError, cond, message) |
| |
| |
| def _check_not_implemented(cond, message=None): # noqa: F811 |
| r"""Throws error containing an optional message if the specified condition |
| is False. |
| |
| Error type: ``NotImplementedError`` |
| |
| C++ equivalent: ``TORCH_CHECK_NOT_IMPLEMENTED`` |
| |
| Args: |
| cond (:class:`bool`): If False, throw error |
| |
| message (Callable, optional): Callable that returns either a string or |
| an object that has a ``__str__()`` method to be used as the error |
| message. Default: ``None`` |
| """ |
| _check_with(NotImplementedError, cond, message) |
| |
| |
| def _check_tensor_all_with(error_type, cond, message=None): # noqa: F811 |
| if not is_tensor(cond): |
| raise TypeError(f"cond must be a tensor, but got {type(cond)}") |
| |
| if not cond.dtype == torch.bool: |
| raise TypeError(f"cond tensor must have dtype torch.bool, but got {cond.dtype}") |
| |
| _check_with(error_type, cond._is_all_true().item(), message) # type: ignore[arg-type] |
| |
| |
| # C++ equivalent: `TORCH_CHECK_TENSOR_ALL` |
| def _check_tensor_all(cond, message=None): # noqa: F811 |
| r"""Throws error containing an optional message if the specified condition |
| is False. |
| |
| Error type: ``RuntimeError`` |
| |
| C++ equivalent: ``TORCH_CHECK_TENSOR_ALL`` |
| |
| Args: |
| cond (:class:`torch.Tensor`): Tensor of dtype ``torch.bool``. If any |
| element is ``False``, throw error |
| |
| message (Callable, optional): Callable that returns either a string or |
| an object that has a ``__str__()`` method to be used as the error |
| message. Default: ``None`` |
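
Example (illustrative)::

    cond = torch.tensor([True, True, False])
    torch._check_tensor_all(cond, lambda: "found a False element")  # raises RuntimeError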
| """ |
| _check_tensor_all_with(RuntimeError, cond, message) |
| |
| |
| ################################################################################ |
| # Define numeric constants |
| ################################################################################ |
| |
| # For Python Array API (https://data-apis.org/array-api/latest/API_specification/constants.html) and |
| # NumPy consistency (https://numpy.org/devdocs/reference/constants.html) |
| from math import e, inf, nan, pi |
| |
| |
| newaxis: None = None |
| |
| __all__.extend(["e", "pi", "nan", "inf", "newaxis"]) |
| |
| ################################################################################ |
| # Define Storage and Tensor classes |
| ################################################################################ |
| |
| from torch._tensor import Tensor # usort: skip |
| |
| # needs to be after torch.Tensor is defined to avoid circular dependencies |
| from torch import storage as storage # usort: skip |
| from torch.storage import ( |
| _LegacyStorage, |
| _StorageBase, |
| _warn_typed_storage_removal, |
| TypedStorage, |
| UntypedStorage, |
| ) |
| |
| |
| # NOTE: New <type>Storage classes should never be added. When adding a new |
| # dtype, use torch.storage.TypedStorage directly. |
| class ByteStorage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.uint8 |
| |
| |
| class DoubleStorage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.double |
| |
| |
| class FloatStorage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.float |
| |
| |
| class HalfStorage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.half |
| |
| |
| class LongStorage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.long |
| |
| |
| class IntStorage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.int |
| |
| |
| class ShortStorage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.short |
| |
| |
| class CharStorage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.int8 |
| |
| |
| class BoolStorage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.bool |
| |
| |
| class BFloat16Storage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.bfloat16 |
| |
| |
| class ComplexDoubleStorage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.cdouble |
| |
| |
| class ComplexFloatStorage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.cfloat |
| |
| |
| class QUInt8Storage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.quint8 |
| |
| |
| class QInt8Storage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.qint8 |
| |
| |
| class QInt32Storage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.qint32 |
| |
| |
| class QUInt4x2Storage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.quint4x2 |
| |
| |
| class QUInt2x4Storage(_LegacyStorage): |
| @classproperty |
| def dtype(self): |
| _warn_typed_storage_removal(stacklevel=3) |
| return self._dtype |
| |
| @classproperty |
| def _dtype(self): |
| return torch.quint2x4 |
| |
| |
| _storage_classes: _Set[_Type[_Union[TypedStorage, UntypedStorage]]] = { |
| UntypedStorage, |
| DoubleStorage, |
| FloatStorage, |
| LongStorage, |
| IntStorage, |
| ShortStorage, |
| CharStorage, |
| ByteStorage, |
| HalfStorage, |
| BoolStorage, |
| QUInt8Storage, |
| QInt8Storage, |
| QInt32Storage, |
| BFloat16Storage, |
| ComplexFloatStorage, |
| ComplexDoubleStorage, |
| QUInt4x2Storage, |
| QUInt2x4Storage, |
| TypedStorage, |
| } |
| |
| # The _tensor_classes set is initialized by the call to initialize_python_bindings. |
| _tensor_classes: _Set[_Type["torch.Tensor"]] = set() |
| |
| # If you edit these imports, please update torch/__init__.py.in as well |
| from torch import amp as amp, random as random, serialization as serialization |
| from torch._tensor_str import set_printoptions |
| from torch.amp import autocast, GradScaler |
| from torch.random import get_rng_state, initial_seed, manual_seed, seed, set_rng_state |
| from torch.serialization import load, save |
| |
| |
| ################################################################################ |
| # Initialize extension |
| ################################################################################ |
| |
| |
# The shared memory manager needs to know the exact location of the manager executable
| def _manager_path(): |
| if _running_with_deploy() or platform.system() == "Windows": |
| return b"" |
| path = get_file_path("torch", "bin", "torch_shm_manager") |
| prepare_multiprocessing_environment(get_file_path("torch")) |
| if not os.path.exists(path): |
| raise RuntimeError("Unable to find torch_shm_manager at " + path) |
| return path.encode("utf-8") |
| |
| |
| _C._initExtension(_manager_path()) |
| |
| del _manager_path |
| |
# Appease the type checker: it can't deal with direct setting of globals().
# Note that we will see "too many" functions when re-exporting this way; there
# is no good way to fix this problem. Perhaps we could redesign _VariableFunctions
# so that this import alone is sufficient.
| if TYPE_CHECKING: |
| # Some type signatures pulled in from _VariableFunctions here clash with |
| # signatures already imported. For now these clashes are ignored; see |
| # PR #43339 for details. |
| from torch._C._VariableFunctions import * # type: ignore[assignment, misc] # noqa: F403 |
| |
| # Fixup segment_reduce visibility |
| _segment_reduce = segment_reduce |
| del segment_reduce # noqa: F821 |
| |
| # Ops not to be exposed in `torch` namespace, |
| # mostly helper ops. |
| PRIVATE_OPS = ("unique_dim",) |
| |
| __name, __obj = "", None |
| for __name in dir(_C._VariableFunctions): |
| if __name.startswith("__") or __name in PRIVATE_OPS: |
| continue |
| __obj = getattr(_C._VariableFunctions, __name) |
| __obj.__module__ = __name__ # "torch" |
| # Hide some APIs that should not be public |
| if __name == "segment_reduce": |
# TODO: Once the undocumented FC window is passed, remove the line below
| globals()[__name] = __obj |
| __name = "_" + __name |
| globals()[__name] = __obj |
| if not __name.startswith("_"): |
| __all__.append(__name) |
| |
| del __name, __obj |
| |
| ################################################################################ |
| # Add torch.dtype instances to the public API |
| ################################################################################ |
| |
| import torch |
| |
| |
| __all__.extend( |
| name for name in dir(torch) if isinstance(getattr(torch, name), torch.dtype) |
| ) |
| |
| ################################################################################ |
# Import TorchDynamo's lazy APIs to avoid circular dependencies
| ################################################################################ |
| |
| # needs to be before from torch.functional import * to avoid circular dependencies |
| from torch._compile import _disable_dynamo # usort: skip |
| |
| ################################################################################ |
| # Import interface functions defined in Python |
| ################################################################################ |
| |
| # needs to be after the above ATen bindings so we can overwrite from Python side |
| from torch import _VF as _VF, functional as functional # usort: skip |
| from torch.functional import * # usort: skip # noqa: F403 |
| |
| ################################################################################ |
| # Remove unnecessary members |
| ################################################################################ |
| |
| del _StorageBase |
| del _LegacyStorage |
| |
| ################################################################################ |
| # Define _assert |
| ################################################################################ |
| |
| |
| # needs to be before the submodule imports to avoid circular dependencies |
| def _assert(condition, message): |
| r"""A wrapper around Python's assert which is symbolically traceable.""" |
| if type(condition) is not torch.Tensor and overrides.has_torch_function( |
| (condition,) |
| ): |
| return overrides.handle_torch_function( |
| _assert, (condition,), condition, message |
| ) |
| assert condition, message |
| |
| |
| ################################################################################ |
| # Import most common subpackages |
| ################################################################################ |
| |
| # Use the redundant form so that type checkers know that these are a part of |
| # the public API. The "regular" import lines are there solely for the runtime |
| # side effect of adding to the imported module's members for other users. |
| |
| # needs to be before import torch.nn as nn to avoid circular dependencies |
| from torch.autograd import ( # usort: skip |
| enable_grad as enable_grad, |
| inference_mode as inference_mode, |
| no_grad as no_grad, |
| set_grad_enabled as set_grad_enabled, |
| ) |
| |
| from torch import ( |
| __config__ as __config__, |
| __future__ as __future__, |
| _awaits as _awaits, |
| autograd as autograd, |
| backends as backends, |
| cpu as cpu, |
| cuda as cuda, |
| distributed as distributed, |
| distributions as distributions, |
| fft as fft, |
| futures as futures, |
| hub as hub, |
| jit as jit, |
| linalg as linalg, |
| mps as mps, |
| mtia as mtia, |
| multiprocessing as multiprocessing, |
| nested as nested, |
| nn as nn, |
| optim as optim, |
| overrides as overrides, |
| profiler as profiler, |
| sparse as sparse, |
| special as special, |
| testing as testing, |
| types as types, |
| utils as utils, |
| xpu as xpu, |
| ) |
| from torch.signal import windows as windows |
| |
| |
| # Quantized, sparse, AO, etc. should be last to get imported, as nothing |
| # is expected to depend on them. |
| from torch import ao as ao # usort: skip |
| |
| # nn.quant* depends on ao -- so should be after those. |
| import torch.nn.intrinsic |
| import torch.nn.qat |
| import torch.nn.quantizable |
| import torch.nn.quantized |
| |
| |
| _C._init_names(list(_storage_classes)) |
| |
| # attach docstrings to torch and tensor functions |
| from torch import _size_docs, _storage_docs, _tensor_docs, _torch_docs |
| |
| |
| del _torch_docs, _tensor_docs, _storage_docs, _size_docs |
| |
| |
| def compiled_with_cxx11_abi() -> builtins.bool: |
| r"""Returns whether PyTorch was built with _GLIBCXX_USE_CXX11_ABI=1""" |
| return _C._GLIBCXX_USE_CXX11_ABI |
| |
| |
| from torch import _library as _library, _ops as _ops |
| |
| |
| # Import the ops and classes "namespace" |
| from torch._ops import ops as ops # usort: skip |
| from torch._classes import classes as classes # usort: skip |
| |
| sys.modules.setdefault(f"{__name__}.ops", ops) |
| sys.modules.setdefault(f"{__name__}.classes", classes) |
| |
| # quantization depends on torch.fx and torch.ops |
| # Import quantization |
| from torch import quantization as quantization # usort: skip |
| |
| # Import the quasi random sampler |
| from torch import quasirandom as quasirandom # usort: skip |
| |
# If you are seeing this, it means that this call site was not checked to see
# whether the memory format could be preserved, and it fell back to the old
# default behaviour of contiguous
| legacy_contiguous_format = contiguous_format # defined by _C._initExtension() |
| |
| # Register fork handler to initialize OpenMP in child processes (see gh-28389) |
| from torch.multiprocessing._atfork import register_after_fork |
| |
| |
| register_after_fork(torch.get_num_threads) |
| del register_after_fork |
| |
| # Import tools that require fully imported torch (for applying |
| # torch.jit.script as a decorator, for instance): |
| from torch._lobpcg import lobpcg as lobpcg |
| |
| |
| # These were previously defined in native_functions.yaml and appeared on the |
| # `torch` namespace, but we moved them to c10 dispatch to facilitate custom |
| # class usage. We add these lines here to preserve backward compatibility. |
| quantized_lstm = ops.aten.quantized_lstm |
| quantized_gru = ops.aten.quantized_gru |
| |
| # Import experimental masked operations support. See |
| # [RFC-0016](https://github.com/pytorch/rfcs/pull/27) for more |
| # information. |
| from torch import masked as masked |
| |
| # Import removed ops with error message about removal |
| from torch._linalg_utils import ( # type: ignore[misc] |
| _symeig as symeig, |
| eig, |
| lstsq, |
| matrix_rank, |
| solve, |
| ) |
| from torch.utils.dlpack import from_dlpack, to_dlpack |
| |
| |
| class _TorchCompileInductorWrapper: |
| compiler_name = "inductor" |
| |
| def __init__(self, mode, options, dynamic): |
| self.config: _Dict[str, _Any] = {} |
| self.dynamic = dynamic |
| self.apply_mode(mode) |
| self.apply_options(options) |
| |
| if self.config.get("triton.cudagraphs", False): |
| os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1" |
| # FIXME: CUDA Graph does not work well with CUPTI teardown. |
| # 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11) |
| # 2) crashes on 2nd non-lazy CUPTI re-init after teardown (CUDA 12) |
| # Workaround: turn off CUPTI teardown when using CUDA Graphs. |
| os.environ["TEARDOWN_CUPTI"] = "0" |
| |
| def __eq__(self, other): |
| return ( |
| isinstance(other, _TorchCompileInductorWrapper) |
| and self.config == other.config |
| and self.dynamic == other.dynamic |
| ) |
| |
| def apply_mode(self, mode: _Optional[str]): |
| if mode is None or mode == "default": |
| pass |
| elif mode in {"reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"}: |
| from torch._inductor import list_mode_options |
| |
| self.apply_options(list_mode_options(mode, self.dynamic)) |
| else: |
| raise RuntimeError( |
| f"Unrecognized mode={mode}, should be one of: default, reduce-overhead, max-autotune, max-autotune-no-cudagraphs" |
| ) |
| |
| def apply_options(self, options: _Optional[_Dict[str, _Any]]): |
| if not options: |
| return |
| |
| from torch._inductor import config |
| |
| current_config: _Dict[str, _Any] = config.shallow_copy_dict() |
| |
| for key, val in options.items(): |
| attr_name = key.replace("-", "_") |
| if attr_name not in current_config: |
| raise RuntimeError( |
| f"Unexpected optimization option {key}, known options are {list(current_config.keys())}" |
| ) |
| if type(val) is not type(current_config[attr_name]): |
| val_type_str = type(val).__name__ |
| expected_type_str = type(current_config[attr_name]).__name__ |
| raise RuntimeError( |
| f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}" |
| ) |
| self.config[attr_name] = val |
| |
| def __call__(self, model_, inputs_): |
| from torch._inductor.compile_fx import compile_fx |
| |
| return compile_fx(model_, inputs_, config_patches=self.config) |
| |
| def get_compiler_config(self): |
| from torch._inductor.compile_fx import get_patched_config_dict |
| |
| return get_patched_config_dict(config_patches=self.config) |
| |
| def reset(self): |
| from torch._inductor import config |
| |
| if "triton.cudagraphs" in self.config or config.triton.cudagraphs: |
| if self.config.get("triton.cudagraphs", True): |
| from torch._inductor.cudagraph_trees import reset_cudagraph_trees |
| |
| reset_cudagraph_trees() |
| |
| |
| class _TorchCompileWrapper: |
| def __init__(self, backend, mode, options, dynamic): |
| from torch._dynamo.backends.registry import lookup_backend |
| |
| if isinstance(backend, str): |
| self.compiler_name = backend |
| elif hasattr(backend, "__name__"): |
| self.compiler_name = backend.__name__ |
| else: |
| self.compiler_name = str(backend) |
| self.dynamic = dynamic |
| self.compiler_fn = lookup_backend(backend) |
| self.kwargs = {} |
# only pass the args if they are non-empty
| if mode and mode != "default": |
| self.kwargs["mode"] = mode |
| if options: |
| self.kwargs["options"] = options |
| |
| def __eq__(self, other): |
| return ( |
| isinstance(other, _TorchCompileWrapper) |
| and self.compiler_fn == other.compiler_fn |
| and self.kwargs == other.kwargs |
| and self.dynamic == other.dynamic |
| ) |
| |
| def __call__(self, model_, inputs_): |
| return self.compiler_fn(model_, inputs_, **self.kwargs) |
| |
| def reset(self): |
| if hasattr(self.compiler_fn, "reset"): |
| self.compiler_fn.reset() |
| |
| |
| _InputT = _ParamSpec("_InputT") |
| _RetT = _TypeVar("_RetT") |
| |
| |
| @_overload |
| def compile( |
| model: _Callable[_InputT, _RetT], |
| *, |
| fullgraph: builtins.bool = False, |
| dynamic: _Optional[builtins.bool] = None, |
| backend: _Union[str, _Callable] = "inductor", |
| mode: _Union[str, None] = None, |
| options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None, |
| disable: builtins.bool = False, |
| ) -> _Callable[_InputT, _RetT]: ... |
| |
| |
| @_overload |
| def compile( |
| model: None = None, |
| *, |
| fullgraph: builtins.bool = False, |
| dynamic: _Optional[builtins.bool] = None, |
| backend: _Union[str, _Callable] = "inductor", |
| mode: _Union[str, None] = None, |
| options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None, |
| disable: builtins.bool = False, |
| ) -> _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]]: ... |
| |
| |
| def compile( |
| model: _Optional[_Callable] = None, |
| *, |
| fullgraph: builtins.bool = False, |
| dynamic: _Optional[builtins.bool] = None, |
| backend: _Union[str, _Callable] = "inductor", |
| mode: _Union[str, None] = None, |
| options: _Optional[_Dict[str, _Union[str, builtins.int, builtins.bool]]] = None, |
| disable: builtins.bool = False, |
| ) -> _Union[ |
| _Callable[[_Callable[_InputT, _RetT]], _Callable[_InputT, _RetT]], |
| _Callable[_InputT, _RetT], |
| ]: |
| """ |
| Optimizes given model/function using TorchDynamo and specified backend. |
If you are compiling a :class:`torch.nn.Module`, you can also use :meth:`torch.nn.Module.compile`
to compile the module in place without changing its structure.
| |
Concretely, for every frame executed within the compiled region, we will attempt
to compile it and cache the compiled result on the code object for future
use. A single frame may be compiled multiple times if previous compiled
results are not applicable for subsequent calls (this is called a "guard
failure"); you can use ``TORCH_LOGS=guards`` to debug these situations.
Multiple compiled results can be associated with a frame, up to
``torch._dynamo.config.cache_size_limit`` (which defaults to 8), at which
point we will fall back to eager execution. Note that compile caches are per
*code object*, not per frame; if you dynamically create multiple copies of a
function, they will all share the same code cache.
| |
| Args: |
| model (Callable): Module/function to optimize |
fullgraph (bool): If False (default), torch.compile attempts to discover compilable regions
| in the function that it will optimize. If True, then we require that the entire function be |
| capturable into a single graph. If this is not possible (that is, if there are graph breaks), |
| then this will raise an error. |
| dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt |
| to generate a kernel that is as dynamic as possible to avoid recompilations when |
| sizes change. This may not always work as some operations/optimizations will |
| force specialization; use TORCH_LOGS=dynamic to debug overspecialization. |
When this is False, we will NEVER generate dynamic kernels; we will always specialize.
| By default (None), we automatically detect if dynamism has occurred and compile a more |
| dynamic kernel upon recompile. |
| backend (str or Callable): backend to be used |
| |
| - "inductor" is the default backend, which is a good balance between performance and overhead |
| |
- Non-experimental in-tree backends can be seen with `torch._dynamo.list_backends()`
| |
| - Experimental or debug in-tree backends can be seen with `torch._dynamo.list_backends(None)` |
| |
| - To register an out-of-tree custom backend: |
| https://pytorch.org/docs/main/torch.compiler_custom_backends.html#registering-custom-backends |
| mode (str): Can be either "default", "reduce-overhead", "max-autotune" or "max-autotune-no-cudagraphs" |
| |
| - "default" is the default mode, which is a good balance between performance and overhead |
| |
| - "reduce-overhead" is a mode that reduces the overhead of python with CUDA graphs, |
| useful for small batches. Reduction of overhead can come at the cost of more memory |
| usage, as we will cache the workspace memory required for the invocation so that we |
| do not have to reallocate it on subsequent runs. Reduction of overhead is not guaranteed |
| to work; today, we only reduce overhead for CUDA only graphs which do not mutate inputs. |
| There are other circumstances where CUDA graphs are not applicable; use TORCH_LOG=perf_hints |
| to debug. |
| |
| - "max-autotune" is a mode that leverages Triton or template based matrix multiplications |
| on supported devices and Triton based convolutions on GPU. |
| It enables CUDA graphs by default on GPU. |
| |
| - "max-autotune-no-cudagraphs" is a mode similar to "max-autotune" but without CUDA graphs |
| |
| - To see the exact configs that each mode sets you can call `torch._inductor.list_mode_options()` |
| |
| options (dict): A dictionary of options to pass to the backend. Some notable ones to try out are |
| |
| - `epilogue_fusion` which fuses pointwise ops into templates. Requires `max_autotune` to also be set |
| |
| - `max_autotune` which will profile to pick the best matmul configuration |
| |
| - `fallback_random` which is useful when debugging accuracy issues |
| |
| - `shape_padding` which pads matrix shapes to better align loads on GPUs especially for tensor cores |
| |
- `triton.cudagraphs` which will reduce the overhead of Python with CUDA graphs
| |
| - `trace.enabled` which is the most useful debugging flag to turn on |
| |
| - `trace.graph_diagram` which will show you a picture of your graph after fusion |
| |
| - For inductor you can see the full list of configs that it supports by calling `torch._inductor.list_options()` |
| disable (bool): Turn torch.compile() into a no-op for testing |
| |
| Example:: |
| |
| @torch.compile(options={"triton.cudagraphs": True}, fullgraph=True) |
| def foo(x): |
| return torch.sin(x) + torch.cos(x) |
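
    # Non-decorator form (illustrative); ``mode`` values are described above:
    mod = torch.nn.Linear(8, 8)
    opt_mod = torch.compile(mod, mode="reduce-overhead")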
| |
| """ |
| _C._log_api_usage_once("torch.compile") |
| if sys.version_info >= (3, 13): |
| raise RuntimeError("Dynamo is not supported on Python 3.13+") |
| |
| # Decorator mode |
| if model is None: |
| |
| def fn(model: _Callable[_InputT, _RetT]) -> _Callable[_InputT, _RetT]: |
| if model is None: |
| raise RuntimeError("Model can't be None") |
| return compile( |
| model, |
| fullgraph=fullgraph, |
| dynamic=dynamic, |
| backend=backend, |
| mode=mode, |
| options=options, |
| disable=disable, |
| ) |
| |
| return fn |
| |
| if mode is not None and options is not None: |
| raise RuntimeError( |
| "Either mode or options can be specified, but both can't be specified at the same time." |
| ) |
| if mode is None and options is None: |
| mode = "default" |
| if backend == "inductor": |
| backend = _TorchCompileInductorWrapper(mode, options, dynamic) |
| else: |
| backend = _TorchCompileWrapper(backend, mode, options, dynamic) |
| |
| return torch._dynamo.optimize( |
| backend=backend, |
| nopython=fullgraph, |
| dynamic=dynamic, |
| disable=disable, |
| )(model) # type: ignore[return-value] |
| |
| |
| def _register_device_module(device_type, module): |
| r"""Register an external runtime module of the specific :attr:`device_type` |
| supported by torch. |
| |
After the :attr:`module` is registered correctly, the user can refer to
the external runtime module as part of torch via the attribute ``torch.xxx``.
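
Example (illustrative sketch; ``foo_module`` is a hypothetical runtime module)::

    torch._register_device_module("privateuseone", foo_module)
    # foo_module is now reachable as ``torch.privateuseone``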
| """ |
# Make sure the device_type represents a supported device type for torch.
| device_type = torch.device(device_type).type |
| m = sys.modules[__name__] |
| if hasattr(m, device_type): |
| raise RuntimeError( |
| f"The runtime module of '{device_type}' has already " |
| f"been registered with '{getattr(m, device_type)}'" |
| ) |
| setattr(m, device_type, module) |
| torch_module_name = ".".join([__name__, device_type]) |
| sys.modules[torch_module_name] = module |
| |
| |
| from torch import ( |
| export as export, |
| func as func, |
| library as library, |
| return_types as return_types, |
| ) |
| from torch._higher_order_ops import cond as cond, while_loop as while_loop |
| from torch.func import vmap as vmap |
| |
| |
| if not TYPE_CHECKING: |
| from torch import _meta_registrations |
| |
| # Enable CUDA Sanitizer |
| if "TORCH_CUDA_SANITIZER" in os.environ: |
| import torch.cuda._sanitizer as csan |
| |
| csan.enable_cuda_sanitizer() |
| |
| # Populate magic methods on SymInt and SymFloat |
| import torch.fx.experimental.sym_node |
| |
| |
| # Register MPS specific decomps |
| torch.backends.mps._init() |
| |
| if not _running_with_deploy(): |
| from torch import compiler as compiler |
| |
| class _TritonLibrary: |
| lib = torch.library.Library("triton", "DEF") |
| ops_table: _Dict[_Tuple[str, str], _Callable] = {} |
| |
| @classmethod |
| def registerOp(cls, op_key, full_schema, op_impl, dispatch_key): |
| if (op_key, dispatch_key) not in cls.ops_table: |
| cls.lib.define(full_schema) |
| cls.lib.impl("triton::" + op_key, op_impl, dispatch_key) |
| cls.ops_table[(op_key, dispatch_key)] = op_impl |
| |
| return cls.ops_table[(op_key, dispatch_key)] |
| |
| |
| # Deprecated attributes |
| _deprecated_attrs = { |
| "has_mps": torch.backends.mps.is_built, |
| "has_cuda": torch.backends.cuda.is_built, |
| "has_cudnn": torch.backends.cudnn.is_available, |
| "has_mkldnn": torch.backends.mkldnn.is_available, |
| } |
| |
| if TYPE_CHECKING: |
| # Import the following modules during type checking to enable code intelligence features, |
| # such as auto-completion in tools like pylance, even when these modules are not explicitly |
| # imported in user code. |
| from torch import ( |
| _dynamo as _dynamo, |
| _inductor as _inductor, |
| _subclasses as _subclasses, |
| onnx as onnx, |
| ) |
| |
| else: |
| _lazy_modules = { |
| "_dynamo", |
| "_inductor", |
| "_export", |
| # ONNX must be imported after _dynamo, _ops, _subclasses, fx, func and jit |
| "onnx", |
| } |
| |
| def __getattr__(name): |
| # Deprecated attrs |
| replacement = _deprecated_attrs.get(name) |
| if replacement is not None: |
| import warnings |
| |
| warnings.warn( |
| f"'{name}' is deprecated, please use '{replacement.__module__}.{replacement.__name__}()'", |
| stacklevel=2, |
| ) |
| return replacement() |
| |
| # Lazy modules |
| if name in _lazy_modules: |
| return importlib.import_module(f".{name}", __name__) |
| |
| raise AttributeError(f"module '{__name__}' has no attribute '{name}'") |
| |
| |
| def get_device_module(device: _Optional[_Union[torch.device, str]] = None): |
| """ |
Returns the module associated with a given device (e.g., ``torch.device('cuda')``, ``"mtia:0"``, ``"xpu"``, ...).
If no device is given, returns the module for the current accelerator, or for the CPU if no accelerator is present.
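
Example (illustrative; the "cuda" case assumes a CUDA-enabled build)::

    torch.get_device_module("cuda") is torch.cuda  # True
    torch.get_device_module(torch.device("cpu")) is torch.cpu  # True
    torch.get_device_module()  # module for the current accelerator, or ``torch.cpu``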
| """ |
| if isinstance(device, torch.device): |
| device_module_name = device.type |
| elif isinstance(device, str): |
| device_module_name = torch.device(device).type |
| elif device is None: |
# Use the default accelerator type; if no accelerator is available, this automatically returns the CPU device.
| device_module_name = torch._C._get_accelerator().type |
| else: |
| raise RuntimeError( |
| f"Invalid value of device '{device}', expect torch.device, str, or None" |
| ) |
| device_module = getattr(torch, device_module_name, None) |
| if device_module is None: |
| raise RuntimeError( |
| f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'." |
| ) |
| return device_module |
| |
| |
| def _constrain_as_size( |
| symbol, |
| min: _Optional[builtins.int] = None, |
| max: _Optional[builtins.int] = None, |
| ): |
| """ |
This indicates that a given int is size-like and can be used in any context where a size is expected.
You will typically use this when reading out integers from Tensors, e.g., ``max.item()`` or ``lengths.tolist()``,
which then need to be passed as sizes to tensor factory functions. Providing these assertions to PyTorch can help resolve
GuardOnDataDependentSymNode errors upon export, since we cannot guard on unbacked SymInts.
| |
This function has unusual semantics: in some circumstances in framework
code, we will treat this int as >= 2 (when we do a size-oblivious guard).
| This makes it easier to use the unbacked int in size contexts, |
| as we will often attempt to guard on a size being zero/one |
| (e.g., when computing the contiguity of a tensor, or testing if |
| broadcasting can occur), which will not work on unbacked SymInts. |
| However, if we conservatively assume that the size is not zero/one, we will |
| end up with a graph that will still work even if the size is zero/one. |
| |
| For more details, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit |
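
Example (illustrative sketch; ``mask`` is a hypothetical boolean tensor)::

    k = mask.sum().item()  # data-dependent value read out of a tensor
    torch._constrain_as_size(k, max=mask.numel())
    out = torch.empty(k)   # ``k`` can now be used wherever a size is expected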
| ``` |
| """ |
| torch.sym_constrain_range_for_size(symbol, min=min, max=max) |
| |
| |
| from torch import _logging |
| |
| |
| _logging._init_logs() |
| |
| |
| def _import_device_backends(): |
| """ |
Leverage the Python plugin mechanism to load out-of-tree device extensions.
| See this RFC: https://github.com/pytorch/pytorch/issues/122468 |
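
A hypothetical out-of-tree extension could advertise itself through the
``torch.backends`` entry-point group, e.g. in its ``setup.py`` (all names are
illustrative)::

    from setuptools import setup

    setup(
        name="torch_foo",
        entry_points={"torch.backends": ["torch_foo = torch_foo:_autoload"]},
    )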
| """ |
| from importlib.metadata import entry_points |
| |
| group_name = "torch.backends" |
| if sys.version_info < (3, 10): |
| backend_extensions = entry_points().get(group_name, ()) |
| else: |
| backend_extensions = entry_points(group=group_name) |
| |
| for backend_extension in backend_extensions: |
| try: |
| # Load the extension |
| entrypoint = backend_extension.load() |
| # Call the entrypoint |
| entrypoint() |
| except Exception as err: |
| raise RuntimeError( |
| f"Failed to load the backend extension: {backend_extension.name}. " |
| f"You can disable extension auto-loading with TORCH_DEVICE_BACKEND_AUTOLOAD=0." |
| ) from err |
| |
| |
| def _is_device_backend_autoload_enabled() -> builtins.bool: |
| """ |
Whether autoloading out-of-tree device extensions is enabled.
| The switch depends on the value of the environment variable |
| `TORCH_DEVICE_BACKEND_AUTOLOAD`. |
| |
| Returns: |
| bool: Whether to enable autoloading the extensions. Enabled by default. |
| |
| Examples: |
| >>> torch._is_device_backend_autoload_enabled() |
| True |
| """ |
| # enabled by default |
| return os.getenv("TORCH_DEVICE_BACKEND_AUTOLOAD", "1") == "1" |
| |
| |
| if _is_device_backend_autoload_enabled(): |
| _import_device_backends() |