| # mypy: allow-untyped-defs |
| """ |
| The weak_script annotation needs to be here instead of inside torch/jit/ so it |
| can be used in other places in torch/ (namely torch.nn) without running into |
| circular dependency problems |
| """ |
| |
| import ast |
| import builtins |
| import collections |
| import contextlib |
| import enum |
| import inspect |
| import io |
| import pickle |
| import sys |
| import textwrap |
| import threading |
| import types |
| import typing |
| import warnings |
| import weakref |
| from typing import ( |
| Any, |
| Callable, |
| Dict, |
| Final, |
| ForwardRef, |
| get_args, |
| get_origin, |
| List, |
| Optional, |
| Tuple, |
| Type, |
| Union, |
| ) |
| |
| import torch |
| |
| # This is needed. `torch._jit_internal` is imported before `torch.distributed.__init__`. |
| # Explicitly ask to import `torch.distributed.__init__` first. |
| # Otherwise, "AttributeError: module 'torch' has no attribute 'distributed'" is raised. |
| import torch.distributed.rpc |
| import torch.package._mangling as package_mangling |
| from torch._awaits import _Await |
| from torch._C import _Await as CAwait, Future as CFuture |
| from torch._sources import fake_range, get_source_lines_and_file, parse_def |
| from torch.futures import Future |
| |
| |
# Interpreter-version feature flags, evaluated once at import time.
IS_PY39_PLUS: Final[bool] = sys.version_info[:2] >= (3, 9)
IS_PY310_PLUS: Final[bool] = sys.version_info[:2] >= (3, 10)

# Runtime type of PEP 604 unions such as `int | str`. Where that type does not
# exist, an empty tuple is used so `isinstance(x, BuiltinUnionType)` is always
# False without needing a version check at every call site.
BuiltinUnionType: Union[Type, Tuple[Type, ...]]
if sys.version_info < (3, 10):
    BuiltinUnionType = ()
else:
    # NOTE: IS_PY310_PLUS doesn't work with mypy; it only narrows on a literal
    # `sys.version_info` comparison.
    # cf. https://mypy.readthedocs.io/en/stable/common_issues.html#python-version-and-system-platform-checks
    BuiltinUnionType = types.UnionType
| |
# Type of the low-level lock objects; `_dummy_thread` is the historical
# fallback for interpreters built without thread support.
LockType: Type
try:
    import _thread
except ImportError:
    import _dummy_thread  # type: ignore[import-not-found]

    LockType = _dummy_thread.LockType
else:
    LockType = _thread.LockType
| |
# Registry populated by `boolean_dispatch`: maps each generated wrapper
# function to a dict describing the two functions it chooses between.
# Weak keys let wrappers be garbage-collected normally.
boolean_dispatched: "weakref.WeakKeyDictionary[Callable, Dict[str, Callable]]" = weakref.WeakKeyDictionary()  # noqa: T484


# Filename prefix constant; presumably used for synthesized dataclass source
# elsewhere in torch.jit (not referenced in this chunk — verify).
FAKE_FILENAME_PREFIX = "__torch_jit_dataclass"
| |
| |
def is_final(ann) -> bool:
    """Return True if ``ann`` is ``Final`` (bare or subscripted) from
    ``typing`` or ``typing_extensions``."""
    if getattr(ann, "__module__", None) not in {"typing", "typing_extensions"}:
        return False
    # Subscripted form, e.g. `Final[int]`.
    if get_origin(ann) is Final:
        return True
    # Bare `Final` (a special form instance).
    return isinstance(ann, type(Final))
| |
| |
class BroadcastingListCls:
    """Placeholder that makes ``BroadcastingListN[...]`` subscriptable.

    Subscripting an instance always yields None; the subscript is ignored.
    """

    def __getitem__(self, types):
        return None
| |
| |
# mypy doesn't support parameters on types, so we have to explicitly type each
# list size. BroadcastingList2..BroadcastingList6 are all aliases of the same
# BroadcastingListCls instance as BroadcastingList1.
BroadcastingList1 = BroadcastingListCls()
# A comprehension keeps the loop variable out of the module namespace (a plain
# `for i in ...` loop would leave a stray global `i` behind).
globals().update({f"BroadcastingList{n}": BroadcastingList1 for n in range(2, 7)})
| |
| |
def is_scripting() -> bool:
    r"""
    Function that returns True when in compilation and False otherwise. This
    is useful especially with the @unused decorator to leave code in your
    model that is not yet TorchScript compatible.

    In eager Python this function always returns False.

    .. testcode::

        import torch

        @torch.jit.unused
        def unsupported_linear_op(x):
            return x

        def linear(x):
           if torch.jit.is_scripting():
              return torch.linear(x)
           else:
              return unsupported_linear_op(x)
    """
    return False
| |
| |
| # Retrieves a fully-qualified name (module hierarchy + classname) for a given obj. |
| def _qualified_name(obj, mangle_name=True) -> str: |
| # This special case allows us to override the qualified name on a type. |
| # It's currently used in conjunction with tracing, where we create a |
| # fake module to filter only supported attributes. However, since this |
| # new type is defined as a local class, we need a mechanism to override |
| # its qualname so it appears correctly in the TorchScript system. This, |
| # we set '_jit_override_qualname' with the original traced module's |
| # qualified name, which is picked up here |
| if hasattr(obj, "_jit_override_qualname"): |
| return obj._jit_override_qualname |
| # short-circuit in cases where the object already has a known qualified name |
| if isinstance(obj, torch._C.ScriptFunction): |
| return obj.qualified_name |
| |
| if getattr(obj, "__name__", None): |
| name = obj.__name__ |
| # Enum classes do not have `__name__` attr, instead they have `name`. |
| elif isinstance(obj, enum.Enum): |
| name = obj.name |
| else: |
| raise RuntimeError("Could not get name of python class object") |
| |
| if name == "<lambda>": |
| name = "_lambda" # make name a valid identifier |
| |
| module_name = obj.__module__ |
| |
| # If the module is actually a torchbind module, then we should short circuit |
| if module_name == "torch._classes": |
| return obj.qualified_name |
| |
| # The Python docs are very clear that `__module__` can be None, but I can't |
| # figure out when it actually would be. |
| if module_name is None: |
| raise RuntimeError( |
| f"Could not get qualified name for class '{name}': " |
| "__module__ can't be None." |
| ) |
| |
| # if getattr(sys.modules[module_name], name) is not obj: |
| # raise RuntimeError(f"Could not get qualified name for class '{name}': " |
| # f"the attr {name} on module {module_name} is not the class") |
| |
| # torch.package and TorchScript have separate mangling schemes to avoid |
| # name collisions from multiple packages. To avoid them interfering with |
| # each other, normalize the package manging here. |
| if package_mangling.is_mangled(module_name): |
| module_name = module_name.replace("<", "_") |
| module_name = module_name.replace(">", "_") |
| |
| # The PythonExceptionValue C++ class in torch/csrc/jit/python/python_sugared_value.h |
| # does not need mangle the python class name. |
| if mangle_name: |
| # __main__ is a builtin module, so rewrite it to "__torch__". |
| if module_name == "__main__": |
| module_name = "__torch__" |
| else: |
| # Everything else gets a "__torch__" prefix to avoid name collisions |
| # with the names of user values. |
| module_name = "__torch__." + module_name |
| |
| if "." in name: |
| raise RuntimeError( |
| f"Could not get qualified name for class '{name}': " |
| f"'{name}' is not a valid identifier" |
| ) |
| |
| return module_name + "." + name |
| |
| |
class SourceLoader:
    """In-memory map from a function object to source text registered for it."""

    def __init__(self):
        self.content = dict()

    def cache(self, fn, source):
        """Record ``source`` as the source text for ``fn``, replacing any prior entry."""
        self.content.update({fn: source})

    def get_source(self, fn):
        """Return the source cached for ``fn``, or None if nothing was cached."""
        return self.content.get(fn, None)


# Module-wide instance; `get_type_hint_captures` consults it before
# falling back to `inspect.getsource`.
loader = SourceLoader()
| |
| |
def createResolutionCallbackFromEnv(lookup_base):
    """
    Creates a resolution callback that will look up qualified names in an
    environment, starting with `lookup_base` for the base of any qualified
    names, then proceeding down the lookup chain with the resolved object.

    The returned callable takes a type-expression string such as
    ``"Dict[str, List[int]]"`` or ``"typing.Optional[int]"`` and returns the
    resolved object. It returns None when the string cannot be fully resolved,
    so that the caller can fall back to the C++ resolver.

    You should not use this directly, it should only be used from the other
    createResolutionCallbackFrom* functions.
    """

    def lookupInModule(qualified_name, module):
        # Walk dotted names one attribute at a time: "a.b.c" -> module.a.b.c
        if "." in qualified_name:
            base, remaining_pieces = qualified_name.split(".", maxsplit=1)
            module_value = getattr(module, base)
            return lookupInModule(remaining_pieces, module_value)
        else:
            return getattr(module, qualified_name)

    def parseNestedExpr(expr, module) -> Tuple[Any, int]:
        # Returns (resolved value, number of characters of `expr` consumed).
        i = 0
        while i < len(expr) and expr[i] not in (",", "[", "]"):
            i += 1

        # Special case logic for the empty Tuple as a subscript (used
        # in the type annotation `Tuple[()]`)
        if expr[:i] == "()":
            return (), i

        base = lookupInModule(expr[:i].strip(), module)
        # Explicit raise (not `assert`) so the check survives `python -O`.
        if base is None:
            raise ValueError(f"Unresolvable type {expr[:i]}")
        if i == len(expr) or expr[i] != "[":
            return base, i

        # expr[i] == "[": parse comma-separated subscript parts up to "]".
        parts = []
        while expr[i] != "]":
            i += 1  # skip the opening "[" or a separating ","
            part, part_len = parseNestedExpr(expr[i:], module)
            parts.append(part)
            i += part_len
        if len(parts) > 1:
            return base[tuple(parts)], i + 1
        else:
            return base[parts[0]], i + 1

    def parseExpr(expr, module):
        try:
            value, len_parsed = parseNestedExpr(expr, module)
            # Explicit raise (not `assert`): under `python -O` an assert would
            # be stripped and a partially-parsed expression would be returned.
            if len_parsed != len(expr):
                raise ValueError(
                    "whole expression was not parsed, falling back to c++ parser"
                )
            return value
        except Exception:
            # The python resolver fails in several cases in known unit tests,
            # and is intended to fall back gracefully to the c++ resolver in
            # general. For example, python 2 style annotations which are
            # frequent in our unit tests often fail with types e.g. int not
            # resolvable from the calling frame.
            return None

    return lambda expr: parseExpr(expr, lookup_base)
| |
| |
def createResolutionCallbackFromFrame(frames_up: int = 0):
    """
    Creates a function which, given a string variable name,
    returns the value of the variable in the scope of the caller of
    the function which called createResolutionCallbackFromFrame (by default).

    This is used to enable access in-scope Python variables inside
    TorchScript fragments.

    frames_up is number of additional frames to go up on the stack.
    The default value is 0, which correspond to the frame of the caller
    of createResolutionCallbackFromFrame. Also for example, if frames_up is set
    to 1, then the frame of the caller's caller of createResolutionCallbackFromFrame
    will be taken.

    Names are looked up in that frame's locals, then its globals, then
    ``builtins``; unknown names resolve to None.

    For example, the following program prints 2::

        def bar():
            cb = createResolutionCallbackFromFrame(1)
            print(cb("foo"))


        def baz():
            foo = 2
            bar()


        baz()
    """
    frame = inspect.currentframe()
    # Step from this function's own frame up `frames_up + 1` levels.
    steps_left = frames_up + 1
    while steps_left > 0:
        assert frame is not None
        frame = frame.f_back
        steps_left -= 1

    assert frame is not None
    caller_locals = frame.f_locals
    caller_globals = frame.f_globals

    class env:
        def __getattr__(self, key):
            for namespace in (caller_locals, caller_globals):
                if key in namespace:
                    return namespace[key]
            if key in dir(builtins):
                return getattr(builtins, key)
            return None

    return createResolutionCallbackFromEnv(env())
| |
| |
def get_closure(fn):
    """
    Get a dictionary of closed over variables from a function.

    The result starts as a copy of ``fn.__globals__``. It is then overlaid
    with the current values of ``fn``'s free (closed-over) variables, so a
    closure variable shadows a global of the same name.

    A free variable whose cell is still empty (the enclosing scope has not
    assigned it yet, or has deleted it) has no value to capture. It is left
    out of the result instead of raising ``ValueError``, and it still shadows
    any global of the same name, as in Python's own scoping.
    """
    captures = {}
    captures.update(fn.__globals__)

    # `__closure__` is None when the function has no free variables.
    for captured_name, cell in zip(fn.__code__.co_freevars, fn.__closure__ or ()):
        try:
            captures[captured_name] = cell.cell_contents
        except ValueError:
            # Empty cell: drop any same-named global it would shadow.
            captures.pop(captured_name, None)

    return captures
| |
| |
| # [local resolution in python] |
| # Depending on where a variable is defined, and where it is used, we may |
| # or may not be able to recover its value when recursively compiling a |
| # script function. Remember in the general case, a module or function is |
| # first defined and then later scripted. This means we do not have a |
| # chance to capture the active frames when the function is defined. Hence any |
| # name resolution has to happen later on the created closure. The way |
# Python captures type annotations restricts what we can recover. The
# following example illustrates the different cases:
| # |
| # class MyGlobalClass: |
| # ... |
| # def my_local_scope(): |
| # @torch.jit.script |
| # class MyClass: |
| # ... |
| # @torch.jit.script |
| # class MyClassUsedAsVar: |
| # ... |
| # def eg(x: MyClass, y: MyGlobalClass): |
| # a_local_capture : Foo |
| # return MyClassUsedAsVar(x) |
| # |
| # MyGlobalClass is defined in the __globals__ dictionary of function |
| # 'eg', so it is always recoverable. my_local_scope introduces a new local |
| # variable scope in the function. Classes defined here are only visible as |
| # local variables. For the case of MyClassUsedAsVar, it is captured |
| # because it is used as a variable inside the body of the function, and we |
| # can resolve it using the captures returned from `get_closure`. However, |
# the type annotations are not captured by the closure. In Python
# 3.0--3.9, the _value_ of MyClass and MyGlobalClass will be available as
# annotations on `eg`, but under postponed evaluation of annotations
# (PEP 563, `from __future__ import annotations`, at one point slated to
# become the default) they will be represented as strings and the values
# will no longer be present. Furthermore, since the body of `eg` does
# not reference those names, they do not appear in the list of closed over
# variables. In Python 2.x, type annotations are in comments, leading to a
# similar situation where their definitions are not available. We anticipate
# that most users will not run into this issue because their modules and
# functions will be defined at a global scope like MyGlobalClass. In cases
# where they are not, it is possible to work around issues by declaring the
# values global in the function.
# In Python 3.9, declaring a class as global will make it invisible to
# `inspect.getsource`, see https://bugs.python.org/issue42666 .
# This can be worked around by manually adding it to the `globals()` dictionary.
| |
| |
def createResolutionCallbackFromClosure(fn):
    """
    Create a resolutionCallback by introspecting the function instead of
    looking up the stack for the enclosing scope.

    Names resolve, in order, from ``fn``'s globals and closure cells, then
    the ``typing`` module, then ``builtins``; anything else resolves to None.
    """
    captured = get_closure(fn)

    class closure_lookup:
        # A class (rather than the raw dict) because the env-based resolver
        # navigates everything through `getattr` calls.
        def __getattr__(self, key):
            if key in captured:
                return captured[key]
            for fallback in (typing, builtins):
                if hasattr(fallback, key):
                    return getattr(fallback, key)
            return None

    return createResolutionCallbackFromEnv(closure_lookup())
| |
| |
def can_compile_class(cls) -> bool:
    """Return True if ``cls`` looks compilable as a TorchScript class.

    A class is rejected when:
    - it is marked ignored/unused/dropped;
    - it subclasses ``nn.Module``, ``tuple``, ``list`` or ``Exception``;
    - any routine in its ``__dict__`` lacks a ``__code__`` object, which
      suggests it is builtin or bound from C.
    """
    if is_ignored_fn(cls):
        return False

    # Built-in classes that are never compiled this way.
    if issubclass(cls, (torch.nn.Module, tuple, list, Exception)):
        return False

    routines = []
    for attr_name in cls.__dict__:
        if inspect.isroutine(getattr(cls, attr_name, None)):
            routines.append(getattr(cls, attr_name))
    return all(hasattr(routine, "__code__") for routine in routines)
| |
| |
def get_callable_argument_names(fn) -> List[str]:
    """
    Gets names of all POSITIONAL_OR_KEYWORD arguments for callable `fn`.

    Parameters of any other kind (positional-only, ``*args``, keyword-only,
    ``**kwargs``) are skipped, because they do not map to individual values
    with a keyword as name. An empty list is returned if
    ``inspect.signature`` cannot introspect `fn`.

    This is used by `torch.jit.trace` to assign meaningful argument names to
    traced functions and modules.

    Args:
        fn: A callable.
    Returns:
        Argument names: List[str], in declaration order.
    """
    # inspect.signature may fail (e.g. some builtins, non-callables); give up.
    try:
        callable_signature = inspect.signature(fn)
    except Exception:
        return []

    return [
        name
        for name, param in callable_signature.parameters.items()
        if param.kind == param.POSITIONAL_OR_KEYWORD
    ]
| |
| |
def get_annotation_str(annotation):
    """
    Convert an AST node containing a type annotation to the string present in the source
    that represents the same annotation.

    Handles names, dotted attributes, subscripts, tuples (joined with ","
    and no spaces), and constants. Returns None for any other node. It also
    returns None when any sub-node cannot be converted, rather than raising a
    ``TypeError`` while joining (e.g. ``Callable[[int], int]``, whose
    list-typed argument is not handled). Callers treat None as "leave this
    annotation to ScriptTypeParser".
    """
    if isinstance(annotation, ast.Name):
        return annotation.id
    elif isinstance(annotation, ast.Attribute):
        value_str = get_annotation_str(annotation.value)
        if value_str is None:
            return None
        return f"{value_str}.{annotation.attr}"
    elif isinstance(annotation, ast.Subscript):
        # In Python3.9+ subscript indicies are not wrapped in ast.Index
        subscript_slice = annotation.slice if IS_PY39_PLUS else annotation.slice.value  # type: ignore[attr-defined]
        value_str = get_annotation_str(annotation.value)
        slice_str = get_annotation_str(subscript_slice)
        if value_str is None or slice_str is None:
            return None
        return f"{value_str}[{slice_str}]"
    elif isinstance(annotation, ast.Tuple):
        elt_strs = [get_annotation_str(elt) for elt in annotation.elts]
        if any(elt_str is None for elt_str in elt_strs):
            return None
        return ",".join(elt_strs)
    elif isinstance(annotation, ast.Constant):
        return f"{annotation.value}"

    # If an AST node is not handled here, it's probably handled in ScriptTypeParser.
    return None
| |
| |
def get_type_hint_captures(fn):
    """
    Get a dictionary containing type resolution mappings necessary to resolve types
    for the literal annotations on 'fn'. These are not considered to be closed-over by fn
    and must be obtained separately (e.g. using this function).

    Annotations on every parameter kind (positional-only, regular,
    ``*args``, keyword-only, ``**kwargs``) and on the return are considered.

    Args:
        fn: A callable.
    Returns:
        A Dict[str, Any] containing a mapping from the literal annotations used on
        fn to the Python objects they refer to.
    Raises:
        OSError: if the source of ``fn`` is neither cached in ``loader`` nor
            retrievable with ``inspect.getsource``.
        RuntimeError: if the dedented source does not parse to exactly one
            ``def`` statement.
    """
    # First, get the source of the function. We need to parse it to find the
    # actual string names used to annotate the types, since inspect.signature()
    # only returns the class object an annotation refers to, not the string
    # name. Source is taken from `loader` if cached, otherwise from
    # inspect.getsource; if neither works (e.g. a function synthesized
    # dynamically at runtime), an OSError is raised.
    src = loader.get_source(fn)
    if src is None:
        try:
            src = inspect.getsource(fn)
        except OSError as e:
            raise OSError(
                f"Failed to get source for {fn} using inspect.getsource"
            ) from e

    # Gather a dictionary of parameter name -> type, skipping any parameters whose annotated
    # types are strings. These are only understood by TorchScript in the context of a type annotation
    # that refers to a class in its own definition, but trying to include a mapping for this in the result
    # function would cause infinite recursion because the class is currently being compiled.
    # In addition, there is logic in ScriptTypeParser to handle this.
    signature = inspect.signature(fn)
    name_to_type = {
        name: parameter.annotation
        for name, parameter in signature.parameters.items()
        if parameter.annotation is not inspect.Parameter.empty
        and not isinstance(parameter.annotation, str)
    }

    # Then, get the literal type annotations from the function declaration
    # by source inspection. This accounts for the case in which aliases are used
    # to annotate the arguments (e.g device_t = torch.device, and then d: device_t).
    # frontend.py cannot be used here because it includes _jit_internal, so use ast instead.
    a = ast.parse(textwrap.dedent(src))
    if len(a.body) != 1 or not isinstance(a.body[0], ast.FunctionDef):
        raise RuntimeError(f"Expected {fn} to be a function")
    f = a.body[0]

    # Prepare a dictionary of source annotation -> type, which will be the final result of this function,
    # by using the parsed AST (f) to reconstruct source annotations as strings for each parameter and mapping
    # them to the type object corresponding to the annotation via name_to_type using the parameter name.
    annotation_to_type = {}

    # `f.args.args` alone misses positional-only, keyword-only and variadic
    # parameters, whose annotations inspect.signature() does report.
    all_args = [*f.args.posonlyargs, *f.args.args, *f.args.kwonlyargs]
    for variadic in (f.args.vararg, f.args.kwarg):
        if variadic is not None:
            all_args.append(variadic)

    for arg in all_args:
        # Get the source type annotation string for this argument if possible.
        arg_annotation_str = (
            get_annotation_str(arg.annotation) if arg.annotation else None
        )

        # If the argument has no annotation or get_annotation_str cannot convert it to a string,
        # arg_annotation_str will be None. Skip this arg; ScriptTypeParser will probably handle
        # this in the latter case.
        if arg_annotation_str is None:
            continue

        # Insert {arg_annotation_str: type} into annotation_to_type if possible. One reason arg_name may not
        # be present in name_to_type is that the annotation itself is a string and not a type object
        # (common for self-referential annotations in classes). Once again, let ScriptTypeParser handle this.
        arg_name = arg.arg
        if arg_name in name_to_type:
            annotation_to_type[arg_annotation_str] = name_to_type[arg_name]

    # If there is a valid return annotation, include it in annotation_to_type. As with argument annotations,
    # the literal annotation has to be convertible to a string by get_annotation_str, and the actual type
    # of the annotation cannot be a string.
    literal_return_annotation = get_annotation_str(f.returns)
    valid_literal_annotation = literal_return_annotation is not None
    return_annotation = signature.return_annotation
    valid_return_annotation_type = (
        return_annotation is not inspect.Parameter.empty
        and not isinstance(return_annotation, str)
    )
    if valid_literal_annotation and valid_return_annotation_type:
        annotation_to_type[literal_return_annotation] = return_annotation

    return annotation_to_type
| |
| |
def createResolutionCallbackForClassMethods(cls):
    """
    Build a resolver from the closed-over variables and annotation captures
    of every Python-level method defined directly on ``cls``.

    The returned function maps a name to its captured value, falling back to
    ``builtins`` and finally None.
    """
    # On the class object, methods are plain functions (unbound), so
    # `inspect.ismethod` would be False; `isroutine` covers them.
    routines = []
    for attr_name in cls.__dict__:
        if inspect.isroutine(getattr(cls, attr_name)):
            routines.append(getattr(cls, attr_name))

    # Built-ins carry neither a global scope nor type hints. Skipping them is
    # needed for `enum.Enum` subclasses on Python 3.11, which add
    # `_new_member_` as an alias of `__new__`.
    python_methods = [
        routine
        for routine in routines
        if not inspect.isbuiltin(routine) and hasattr(routine, "__globals__")
    ]

    captures = {}
    for method in python_methods:
        captures.update(get_closure(method))
        captures.update(get_type_hint_captures(method))

    def lookup_in_class(key):
        return captures[key] if key in captures else getattr(builtins, key, None)

    return lookup_in_class
| |
| |
def boolean_dispatch(
    arg_name,
    arg_index,
    default,
    if_true,
    if_false,
    module_name,
    func_name,
):
    """
    Dispatches to either of 2 script functions based on a boolean argument.
    In TorchScript, the boolean argument must be constant so that the correct
    function to use can be determined at compile time.

    In eager Python, the returned wrapper reads the flag from
    ``kwargs[arg_name]``, else from ``args[arg_index]``, else uses
    ``default``. It then forwards all arguments to ``if_true`` or
    ``if_false``. The wrapper is recorded in ``boolean_dispatched``.

    Raises:
        RuntimeError: if both ``if_true`` and ``if_false`` have docstrings.
    """

    def fn(*args, **kwargs):
        if arg_name in kwargs:
            flag = kwargs[arg_name]
        elif arg_index < len(args):
            flag = args[arg_index]
        else:
            flag = default
        target = if_true if flag else if_false
        return target(*args, **kwargs)

    # Exactly one docstring is allowed; share it across both branches and the wrapper.
    true_doc = if_true.__doc__
    false_doc = if_false.__doc__
    if true_doc is not None and false_doc is not None:
        raise RuntimeError("only one function can have a docstring")
    if true_doc is None and false_doc is not None:
        if_true.__doc__ = false_doc
    elif false_doc is None and true_doc is not None:
        if_false.__doc__ = true_doc
    fn.__doc__ = true_doc if true_doc is not None else false_doc

    if module_name is not None:
        fn.__module__ = module_name
    if func_name is not None:
        fn.__name__ = func_name

    boolean_dispatched[fn] = dict(
        [
            ("if_true", if_true),
            ("if_false", if_false),
            ("index", arg_index),
            ("default", default),
            ("arg_name", arg_name),
        ]
    )
    return fn
| |
| |
class FunctionModifiers:
    """
    Used to denote the behavior of a function in TorchScript. See export() and
    ignore() for details.

    Each value is a descriptive string. The helpers in this module compare
    modifiers by identity (``is``).
    """

    # Compiled when reached from an exported function / forward.
    DEFAULT = "default (compile if called from a exported function / forward)"
    # Compiled even if nothing calls it.
    EXPORT = "export (compile this function even if nothing calls it)"
    # Left as a call into Python.
    IGNORE = "ignore (leave as a call to Python, cannot be torch.jit.save'd)"
    # Replaced with code that raises.
    UNUSED = "unused (ignored and replaced with raising of an exception)"
    COPY_TO_SCRIPT_WRAPPER = "if this method is not scripted, copy the python method onto the scripted model"
    _DROP = "_drop (function is fully ignored, declaration can be unscriptable)"
| |
| |
def export(fn):
    """
    This decorator indicates that a method on an ``nn.Module`` is used as an entry point into a
    :class:`ScriptModule` and should be compiled.

    ``forward`` implicitly is assumed to be an entry point, so it does not need this decorator.
    Functions and methods called from ``forward`` are compiled as they are seen
    by the compiler, so they do not need this decorator either.

    Example (using ``@torch.jit.export`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            def implicitly_compiled_method(self, x):
                return x + 99

            # `forward` is implicitly decorated with `@torch.jit.export`,
            # so adding it here would have no effect
            def forward(self, x):
                return x + 10

            @torch.jit.export
            def another_forward(self, x):
                # When the compiler sees this call, it will compile
                # `implicitly_compiled_method`
                return self.implicitly_compiled_method(x)

            def unused_method(self, x):
                return x - 20

        # `m` will contain compiled methods:
        #     `forward`
        #     `another_forward`
        #     `implicitly_compiled_method`
        # `unused_method` will not be compiled since it was not called from
        # any compiled methods and wasn't decorated with `@torch.jit.export`
        m = torch.jit.script(MyModule())
    """
    # Tag in place and hand back the same object.
    setattr(fn, "_torchscript_modifier", FunctionModifiers.EXPORT)  # noqa: B010
    return fn
| |
| |
def unused(fn):
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and replaced with the raising of an exception. This allows you
    to leave code in your model that is not yet TorchScript compatible and still
    export your model.

    For a ``property``, the getter (and the setter, if present) are marked and
    the property object itself is returned.

    Example (using ``@torch.jit.unused`` on a method)::

        import torch
        import torch.nn as nn


        class MyModule(nn.Module):
            def __init__(self, use_memory_efficient):
                super().__init__()
                self.use_memory_efficient = use_memory_efficient

            @torch.jit.unused
            def memory_efficient(self, x):
                import pdb

                pdb.set_trace()
                return x + 10

            def forward(self, x):
                # Use not-yet-scriptable memory efficient mode
                if self.use_memory_efficient:
                    return self.memory_efficient(x)
                else:
                    return x + 10


        m = torch.jit.script(MyModule(use_memory_efficient=False))
        m.save("m.pt")

        m = torch.jit.script(MyModule(use_memory_efficient=True))
        # exception raised
        m(torch.rand(100))
    """
    if isinstance(fn, property):
        accessors = [fn.fget]
        if fn.fset:
            accessors.append(fn.fset)
        for accessor in accessors:
            setattr(  # noqa: B010
                accessor, "_torchscript_modifier", FunctionModifiers.UNUSED
            )
        return fn

    fn._torchscript_modifier = FunctionModifiers.UNUSED
    return fn
| |
| |
| # No op context manager from python side |
| class _IgnoreContextManager(contextlib.AbstractContextManager): |
| def __init__(self, **kwargs): |
| pass |
| |
| def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: |
| pass |
| |
| |
def ignore(drop=False, **kwargs):
    """
    This decorator indicates to the compiler that a function or method should
    be ignored and left as a Python function. This allows you to leave code in
    your model that is not yet TorchScript compatible. If called from TorchScript,
    ignored functions will dispatch the call to the Python interpreter. Models with ignored
    functions cannot be exported; use :func:`@torch.jit.unused <torch.jit.unused>` instead.

    Example (using ``@torch.jit.ignore`` on a method)::

        import torch
        import torch.nn as nn


        class MyModule(nn.Module):
            @torch.jit.ignore
            def debugger(self, x):
                import pdb

                pdb.set_trace()

            def forward(self, x):
                x += 10
                # The compiler would normally try to compile `debugger`,
                # but since it is `@ignore`d, it will be left as a call
                # to Python
                self.debugger(x)
                return x


        m = torch.jit.script(MyModule())

        # Error! The call `debugger` cannot be saved since it calls into Python
        m.save("m.pt")

    Example (using ``@torch.jit.ignore(drop=True)`` on a method):

    .. testcode::

        import torch
        import torch.nn as nn

        class MyModule(nn.Module):
            @torch.jit.ignore(drop=True)
            def training_method(self, x):
                import pdb
                pdb.set_trace()

            def forward(self, x):
                if self.training:
                    self.training_method(x)
                return x

        m = torch.jit.script(MyModule())

        # This is OK since `training_method` is not saved, the call is replaced
        # with a `raise`.
        m.save("m.pt")

    .. testcleanup::

        import os
        os.remove('m.pt')
    """

    if callable(drop):
        # used without any args, so drop is actually a function
        #   @torch.jit.ignore
        #   def fn(...):
        fn = drop
        fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    if not isinstance(drop, bool):
        raise RuntimeError(
            "Argument to @torch.jit.ignore must be a bool or "
            f"a function but got {drop}"
        )

    # for backwards compat
    drop_on_export = kwargs.pop("drop_on_export", None)
    # stacklevel=2 attributes the warning to the caller's decorator line.
    if drop_on_export:
        warnings.warn(
            "ignore(drop_on_export=True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now.",
            category=FutureWarning,
            stacklevel=2,
        )

        drop = drop_on_export
    elif drop:
        warnings.warn(
            "ignore(True) has been deprecated. TorchScript will now drop the function "
            "call on compilation. Use torch.jit.unused now.",
            category=FutureWarning,
            stacklevel=2,
        )

    def decorator(fn):
        # drop=True behaves like @unused; otherwise the call stays in Python.
        if drop:
            fn._torchscript_modifier = FunctionModifiers.UNUSED
        else:
            fn._torchscript_modifier = FunctionModifiers.IGNORE
        return fn

    return decorator
| |
| |
def _drop(fn):
    """Mark ``fn`` with the ``_DROP`` modifier and return it unchanged."""
    setattr(fn, "_torchscript_modifier", FunctionModifiers._DROP)  # noqa: B010
    return fn
| |
| |
def _copy_to_script_wrapper(fn):
    """Mark ``fn`` with ``COPY_TO_SCRIPT_WRAPPER`` and return it unchanged."""
    setattr(  # noqa: B010
        fn, "_torchscript_modifier", FunctionModifiers.COPY_TO_SCRIPT_WRAPPER
    )
    return fn
| |
| |
def module_has_exports(mod):
    """Return True if any callable attribute of ``mod`` is marked EXPORT."""
    for attr_name in dir(mod):
        if not hasattr(mod, attr_name):
            continue
        candidate = getattr(mod, attr_name)
        if not callable(candidate):
            continue
        if get_torchscript_modifier(candidate) is FunctionModifiers.EXPORT:
            return True
    return False
| |
| |
# WARNING: should_drop is currently being used by our JIT code coverage plug-in to mark JIT'd code as covered. If you
# rename this function, please update references in tools/coverage_plugins_package/src/coverage_plugins/jit_plugin.py to
# allow JIT'd code to still be covered.
def should_drop(fn) -> bool:
    """Return True if ``fn`` is marked UNUSED or _DROP (False for non-callables)."""
    modifier = get_torchscript_modifier(fn)
    return any(
        modifier is dropped
        for dropped in (FunctionModifiers.UNUSED, FunctionModifiers._DROP)
    )
| |
| |
def is_ignored_fn(fn) -> bool:
    """Return True if ``fn`` is marked UNUSED, IGNORE or _DROP."""
    modifier = get_torchscript_modifier(fn)
    ignored_kinds = (
        FunctionModifiers.UNUSED,
        FunctionModifiers.IGNORE,
        FunctionModifiers._DROP,
    )
    # Identity comparison, matching how modifiers are checked elsewhere.
    return any(modifier is kind for kind in ignored_kinds)
| |
| |
def _is_drop_fn(fn) -> bool:
    """Return True if ``fn`` is marked with the _DROP modifier."""
    return get_torchscript_modifier(fn) is FunctionModifiers._DROP
| |
| |
def is_static_fn(cls, fn) -> bool:
    """Return True if attribute ``fn`` of ``cls`` is a ``staticmethod``.

    The raw attribute is inspected without triggering descriptors; a missing
    attribute yields False.
    """
    raw_attr = inspect.getattr_static(cls, fn, default=None)
    return isinstance(raw_attr, staticmethod)
| |
| |
def get_static_fn(cls, fn):
    """Return the plain function wrapped by the staticmethod ``cls.<fn>``."""
    descriptor = inspect.getattr_static(cls, fn)
    return descriptor.__func__
| |
| |
def get_torchscript_modifier(fn):
    """Return the TorchScript modifier attached to ``fn``.

    Returns None for non-callables. For bound methods (anything with
    ``__func__``), the underlying function is consulted. A callable with no
    modifier yields ``FunctionModifiers.DEFAULT``.
    """
    if callable(fn):
        target = getattr(fn, "__func__", fn)
        return getattr(target, "_torchscript_modifier", FunctionModifiers.DEFAULT)
    return None
| |
| |
def copy_torchscript_modifier(orig, new) -> None:
    """Copy ``orig``'s TorchScript modifier onto ``new`` (no-op if ``orig`` isn't callable)."""
    modifier = get_torchscript_modifier(orig)
    if modifier is not None:
        new._torchscript_modifier = modifier
| |
| |
| # overloading registration |
| # overloads get registered in this file, and compiled in torch/jit/__init__.py |
| # so that they can be imported in nn/functional.py without an import cycle |
| |
| # qualified_name => list[overload_functions] |
| _overloaded_fns: Dict[str, List[Callable]] = {} # noqa: T484 |
| |
| |
# Example text appended to overload-related error messages (see
# get_overload_no_implementation_error_message and _check_overload_body).
# Its exact whitespace is part of the user-visible message.
_OVERLOAD_EXAMPLE = """
    Example usage of overload function:
    @torch.jit._overload
    def my_function(x: type0) -> type0: # decl 1
        pass

    @torch.jit._overload
    def my_function(x: type1) -> type1: # decl 2
        pass

    def my_function(x): # implementation
        if isinstance(x, type0):
            return x
        elif isinstance(x, type1):
            return x
"""
| |
| |
def get_overload_no_implementation_error_message(kind, obj):
    """Build the error text for overload declarations that lack an implementation.

    Args:
        kind: human-readable kind of the overloaded object (interpolated verbatim).
        obj: the overloaded object; its source and location are quoted in the text.

    Returns:
        The message: header, quoted source of ``obj`` and ``_OVERLOAD_EXAMPLE``.
    """
    source_lines, first_lineno, file_name = get_source_lines_and_file(obj)
    header = (
        f'Implementation for the {kind} "{_qualified_name(obj)}" is missing. Please make '
        f"sure a definition is provided and defined after all overload declarations.\n"
        f'File "{file_name}", line {first_lineno}:\n'
    )
    quoted_source = "".join(source_lines)
    return header + quoted_source + "\n" + _OVERLOAD_EXAMPLE
| |
| |
def _check_overload_body(func):
    """Ensure an ``@torch.jit._overload`` declaration has a stub body.

    The body must be exactly one ``pass`` statement or one ``...`` expression.
    If the source of ``func`` cannot be read, only a warning is issued.

    Raises:
        RuntimeError: if the body is anything other than a single stub statement.
    """
    try:
        parsed_def = parse_def(func)
    except OSError:
        # Source may be unavailable (e.g. interactive definitions). This is only a
        # best-effort early check, so warn instead of failing.
        warnings.warn(
            f"Unable to retrieve source for @torch.jit._overload function: {func}."
        )
        return

    statements = parsed_def.ast.body[0].body

    def _is_stub(stmt):
        # ``pass`` ...
        if isinstance(stmt, ast.Pass):
            return True
        # ... or a bare ``...`` expression statement.
        return (
            isinstance(stmt, ast.Expr)
            and isinstance(stmt.value, ast.Constant)
            and stmt.value.value is Ellipsis
        )

    if len(statements) == 1 and _is_stub(statements[0]):
        return

    leading_source = "\n".join(parsed_def.source.split("\n")[:3])
    raise RuntimeError(
        "Only `pass` statement or `...` can be the body of overload declaration:\n"
        + leading_source
        + " <- Expecting `pass` or `...` here!\n"
        + _OVERLOAD_EXAMPLE
    )
| |
| |
def _overload(func):
    """Decorator: register ``func`` as an overload declaration and return it unchanged.

    The body of the declaration is validated first (see ``_check_overload_body``).
    Declarations are grouped by qualified name in ``_overloaded_fns``.
    """
    _check_overload_body(func)
    # The dict is only mutated, never rebound, so no ``global`` is required.
    _overloaded_fns.setdefault(_qualified_name(func), []).append(func)
    return func
| |
| |
def _get_fn_overloads(qual_name):
    """Return the overload declarations registered under ``qual_name``, or None."""
    try:
        return _overloaded_fns[qual_name]
    except KeyError:
        return None
| |
| |
def _clear_fn_overloads(qual_name) -> None:
    """Forget every overload declaration registered under ``qual_name``.

    Raises:
        KeyError: if nothing is registered under ``qual_name``.
    """
    _overloaded_fns.pop(qual_name)
| |
| |
def get_class_name_lineno(method) -> Tuple[str, int]:
    """Return ``(name, first_line)`` of the code two frames above this call.

    This is meant to be called directly from ``_overload_method``. Two frames up
    from here is the code that invoked that decorator, normally a class body, so
    the result is the class name and the line where the class body starts.
    ``method`` is accepted for interface compatibility but is not used.
    """
    frame = inspect.currentframe()
    # Hop 1 leaves this function; hop 2 leaves our direct caller.
    # One variable is reused so no reference to our own frame is kept.
    for _hop in range(2):
        assert frame is not None  # currentframe()/f_back are Optional
        frame = frame.f_back

    assert frame is not None
    code = frame.f_code
    return code.co_name, code.co_firstlineno
| |
| |
| # At the point the decorator is applied to class methods the method |
| # has no reference to its owning class. _qualified_name would not include |
| # the class it is defined in, so any methods with the same name in the same file |
| # would have the same _qualified_name, even if they were defined in different |
| # classes. This problem only exists in python 2. |
# We work around this by inspecting the stack frame to identify the class
# name, and raising an error when the same method name is overloaded in two
# classes with the same name in the same file.
| |
# qualified_name => class name => list[overload_functions]
# Entries are appended by ``_overload_method`` and read by ``_get_overloaded_methods``.
_overloaded_methods: Dict[str, Dict[str, List[Callable]]] = {}  # noqa: T484


# (qualified_name, class name) => class_fileno
# ``_overload_method`` records the line reported by ``get_class_name_lineno`` the
# first time a pair is seen. A later overload of the same pair coming from a
# different line raises, because it means two same-named classes share a module.
_overloaded_method_class_fileno: Dict[Tuple[str, str], int] = {}
| |
| |
def _overload_method(func):
    """Decorator: register ``func`` as an overload declaration of a class method.

    Declarations are grouped by qualified name and then by the name of the
    enclosing class. The class is found by inspecting the call stack.

    Raises:
        RuntimeError: if the same method name is overloaded in two different
            classes that share a name within the same module.
    """
    _check_overload_body(func)
    qual_name = _qualified_name(func)
    per_class = _overloaded_methods.setdefault(qual_name, {})

    # Must stay a direct call from this function: get_class_name_lineno walks a
    # fixed number of frames up to reach the class body.
    class_name, line_no = get_class_name_lineno(func)
    registry_key = (qual_name, class_name)

    if class_name in per_class:
        if _overloaded_method_class_fileno[registry_key] != line_no:
            raise RuntimeError(
                "Cannot currently overload the same method name in two different"
                " classes with the same name in the same module"
            )
    else:
        per_class[class_name] = []
        _overloaded_method_class_fileno[registry_key] = line_no

    per_class[class_name].append(func)
    return func
| |
| |
def _get_overloaded_methods(method, mod_class):
    """Return the overload declarations of ``method`` registered for ``mod_class``.

    Returns None when ``method`` has no ``__name__`` or when no overloads are
    registered for it under ``mod_class.__name__``.

    Raises:
        AssertionError: if the source line of ``method`` falls outside the source
            span of ``mod_class``. This happens when a class of the same name is
            redeclared in the same file, so the registry cannot tell them apart.
    """
    # TODO: __name__ not set for submodules in recursive script
    if not hasattr(method, "__name__"):
        return None
    qual_name = _qualified_name(method)
    class_name_map = _overloaded_methods.get(qual_name, None)
    if class_name_map is None:
        return None
    overloads = class_name_map.get(mod_class.__name__, None)
    if overloads is None:
        return None

    method_line_no = get_source_lines_and_file(method)[1]
    # Fetch the class source once: it provides both the start line and the length.
    class_source_lines, mod_class_fileno, _ = get_source_lines_and_file(mod_class)
    mod_end_fileno = mod_class_fileno + len(class_source_lines)
    if not (mod_class_fileno <= method_line_no <= mod_end_fileno):
        raise AssertionError(
            "Overloads are not useable when a module is redeclared within the same file: "
            + str(method)
        )
    return overloads
| |
| |
def is_tuple(ann) -> bool:
    """Report whether ``ann`` is a parameterized tuple annotation.

    Accepts ``typing.Tuple[...]`` and, on Python 3.9+, the builtin ``tuple[...]``.

    Raises:
        RuntimeError: if ``ann`` is the bare ``typing.Tuple``.
    """
    if ann is Tuple:
        raise_error_container_parameter_missing("Tuple")

    # Annotations that lack ``__module__`` are never treated as tuples.
    if not hasattr(ann, "__module__"):
        return False

    module_name = ann.__module__
    origin = get_origin(ann)
    if module_name == "typing":
        return origin is Tuple or origin is tuple
    return IS_PY39_PLUS and module_name == "builtins" and origin is tuple
| |
| |
def is_list(ann) -> bool:
    """Report whether ``ann`` is a parameterized list annotation.

    Accepts ``typing.List[T]`` and, on Python 3.9+, the builtin ``list[T]``.

    Raises:
        RuntimeError: if ``ann`` is the bare ``typing.List``.
    """
    if ann is List:
        raise_error_container_parameter_missing("List")

    if not hasattr(ann, "__module__"):
        return False

    module_name = ann.__module__
    origin = get_origin(ann)
    if module_name == "typing":
        return origin is List or origin is list
    return IS_PY39_PLUS and module_name == "builtins" and origin is list
| |
| |
def is_dict(ann) -> bool:
    """Report whether ``ann`` is a parameterized dict annotation.

    Accepts ``typing.Dict[K, V]`` and, on Python 3.9+, the builtin ``dict[K, V]``.

    Raises:
        RuntimeError: if ``ann`` is the bare ``typing.Dict``.
    """
    if ann is Dict:
        raise_error_container_parameter_missing("Dict")

    if not hasattr(ann, "__module__"):
        return False

    origin = get_origin(ann)
    builtin_generic = IS_PY39_PLUS and ann.__module__ == "builtins" and origin is dict
    typing_generic = ann.__module__ == "typing" and (origin is Dict or origin is dict)
    return builtin_generic or typing_generic
| |
| |
def is_union(ann):
    """Report whether ``ann`` is a union annotation.

    Recognizes ``typing.Union[...]`` (including ``Optional[...]``) and, where
    available, the builtin ``X | Y`` union type.

    Raises:
        RuntimeError: if ``ann`` is the bare ``typing.Union``.
    """
    if ann is Union:
        raise_error_container_parameter_missing("Union")

    if isinstance(ann, BuiltinUnionType):
        return True
    if not hasattr(ann, "__module__"):
        return False
    return ann.__module__ == "typing" and get_origin(ann) is Union
| |
| |
def is_optional(ann):
    """Report whether ``ann`` is an optional annotation.

    True for a two-member union that includes ``None``: ``Optional[T]``,
    ``Union[T, None]``, and (via ``is_union``) ``T | None`` where supported.

    Raises:
        RuntimeError: if ``ann`` is the bare ``typing.Optional``.
    """
    if ann is Optional:
        raise_error_container_parameter_missing("Optional")

    # ``Optional[T]`` is just ``Union[T, None]``. ``typing.get_origin`` reports
    # ``Union`` for it and never ``Optional`` itself, so the union check below
    # covers every spelling. The former ``get_origin(ann) is Optional`` check
    # could never succeed.
    if not is_union(ann):
        return False
    ann_args = get_args(ann)
    return len(ann_args) == 2 and (None in ann_args or type(None) in ann_args)
| |
| |
def is_future(ann) -> bool:
    """Report whether ``ann`` is a parameterized ``Future[T]`` annotation.

    Raises:
        RuntimeError: if ``ann`` is the bare ``Future`` without a contained type.
    """
    if ann is not Future:
        return get_origin(ann) is Future
    raise RuntimeError(
        "Attempted to use Future without a "
        "contained type. Please add a contained type, e.g. "
        "Future[int]"
    )
| |
| |
def is_await(ann) -> bool:
    """Report whether ``ann`` is ``_Await`` itself or a parameterized ``_Await[T]``."""
    return ann is _Await or get_origin(ann) is _Await
| |
| |
if torch.distributed.rpc.is_available():
    from torch._C._distributed_rpc import PyRRef
    from torch.distributed.rpc import RRef

    def is_rref(ann) -> bool:
        """Report whether ``ann`` is a parameterized ``RRef[T]`` annotation.

        Raises:
            RuntimeError: if ``ann`` is the bare ``RRef`` without a contained type.
        """
        if ann is RRef:
            raise RuntimeError(
                "Attempted to use RRef without a "
                "contained type. Please add a contained type, e.g. "
                "RRef[int]"
            )
        return get_origin(ann) is RRef

    def is_rref_instance(obj) -> bool:
        """Report whether ``obj`` is an RRef object (a ``PyRRef`` instance)."""
        return isinstance(obj, PyRRef)

else:

    def is_rref(ann) -> bool:
        # Without the RPC module there is no RRef type, so no annotation can be
        # one. Defined so that both branches expose the same helpers and callers
        # need not guard on RPC availability.
        return False

    def is_rref_instance(obj) -> bool:
        # If the RPC module doesn't exist then RRefs don't exist either.
        return False
| |
| |
| def _try_get_dispatched_fn(fn): |
| if not callable(fn): |
| return None |
| return boolean_dispatched.get(fn) |
| |
| |
def _get_named_tuple_properties(
    obj,
    loc: Optional[torch._C._jit_tree_views.SourceRange] = None,
    rcb=None,
):
    """Collect the information TorchScript needs to register a NamedTuple class.

    Args:
        obj: the NamedTuple class. It must subclass ``tuple`` and define
            ``_fields`` (checked with ``assert``).
        loc: source range used in error messages; defaults to ``fake_range()``.
        rcb: optional resolution callback used to resolve ``ForwardRef`` field
            annotations in the scope where ``obj`` was defined.

    Returns:
        A 4-tuple ``(name, field_names, field_types, defaults)``:

        - ``field_names`` is ``obj._fields``.
        - ``field_types`` holds one TorchScript type per field; unannotated fields
          get ``torch._C.TensorType.getInferred()``.
        - ``defaults`` holds the default values of the fields that have one,
          in field order.

    Raises:
        ValueError: if ``rcb`` cannot resolve a ``ForwardRef`` field annotation.
    """
    if loc is None:
        loc = fake_range()

    assert issubclass(obj, tuple) and hasattr(obj, "_fields")
    if hasattr(obj, "_field_defaults"):
        defaults = [
            obj._field_defaults[field]
            for field in obj._fields
            if field in obj._field_defaults
        ]
    else:
        defaults = []
    # On 3.10+ the recommended way to read annotations is `inspect.get_annotations`.
    # It does not include annotations inherited from base classes, so when the class
    # itself has none, fall back to the annotations of its direct base.
    if sys.version_info[:2] < (3, 10):
        obj_annotations = getattr(obj, "__annotations__", {})
    else:
        obj_annotations = inspect.get_annotations(obj)
        if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
            obj_annotations = inspect.get_annotations(obj.__base__)

    annotations = []
    for field in obj._fields:
        if field in obj_annotations:
            field_type = obj_annotations[field]
            # [Note: ForwardRef annotations in NamedTuple attributes]
            # NamedTuple types are slightly different from normal types.
            #
            # Normally, annotations are evaluated like this (during jit.script):
            # 1. Load strings of python code into c++ and parse.
            # 2. Get annotations as strings
            # 3. Use the PythonResolver's resolution callback (rcb) to convert
            #    the string into a python object
            # 4. We call into annotations.py:ann_to_type to convert python obj
            #    from step 3 into a type that torchscript understands.
            #
            # NamedTuples are more complicated, because they have sub-types.
            # Normally, once we have the NamedTuple type object from #3,
            # we can just look at the annotation literal values and use
            # ann_to_type directly on them.
            #
            # But sometimes, users will annotate with string literals, e.g.
            #    x: 'int'
            # This also happens with PEP563 (from __future__ import annotations)
            #
            # These annotations appear in the annotation dict as ForwardRef('int').
            #
            # Then, we need to convert the string into a python object. This
            # requires having local context for custom objects or imported types.
            # rcb() is what gives us this. So, we plumb rcb through the stack so
            # it can be used in this context for the if block below.
            #
            # FAQ:
            # - Why do we need this special handling for NamedTuple but string
            #   annotations work fine for normal types? Normally, we parse the
            #   string directly and then call rcb() directly from C++.
            # - Why not use ForwardRef._evaluate? For that, we need globals()
            #   and locals() for the local context where the NamedTuple was defined.
            #   rcb is what lets us look up into these. So, basically rcb does the
            #   hard work for us.
            if isinstance(field_type, ForwardRef) and rcb is not None:
                rcb_type = rcb(field_type.__forward_arg__)
                # rcb returns None if it can't find anything.
                if rcb_type is None:
                    raise ValueError(
                        f"Unknown type annotation: '{field_type}' in NamedTuple {obj.__name__}."
                        f" Likely due to partial support for ForwardRef parameters in NamedTuples, see #95858."
                        f" Issue occurred at {loc.highlight()}"
                    )
                field_type = rcb_type
            the_type = torch.jit.annotations.ann_to_type(field_type, loc, rcb)
            annotations.append(the_type)
        else:
            # Unannotated field: fall back to the inferred tensor type.
            annotations.append(torch._C.TensorType.getInferred())
    # NOTE(review): ``obj`` is a class here, so ``type(obj).__name__`` is the name
    # of its metaclass (typically "type"), not the NamedTuple's own name. Confirm
    # that no caller relies on this element before changing it.
    return type(obj).__name__, obj._fields, annotations, defaults
| |
| |
| def _create_named_tuple( |
| t, |
| unqual_name: str, |
| field_names: List[str], |
| defaults: Tuple[Any, ...], |
| ): |
| TupleType = collections.namedtuple(unqual_name, field_names, defaults=defaults) # type: ignore[call-arg, no-redef, misc] |
| return TupleType(*t) |
| |
| |
| @contextlib.contextmanager |
| def _disable_emit_hooks(): |
| hooks = torch._C._jit_get_emit_hooks() |
| torch._C._jit_set_emit_hooks(None, None) |
| try: |
| yield |
| finally: |
| torch._C._jit_set_emit_hooks(hooks[0], hooks[1]) |
| |
| |
def _disable_emit_hooks_decorator(_DecoratorContextManager) -> None:  # noqa: F811
    # NOTE(review): despite its name, this is a plain *function*, not a class
    # deriving from ``_DecoratorContextManager``. ``_DecoratorContextManager`` is
    # just an ordinary parameter here, and the nested ``__enter__``/``__exit__``
    # below are local functions that are defined and then discarded. Calling this
    # therefore does nothing and returns None. It looks like it was meant to be
    # ``class _disable_emit_hooks_decorator(_DecoratorContextManager):``; confirm
    # whether anything uses it before changing or removing it.
    def __enter__(self) -> None:
        # Save the current emit hooks and clear them.
        self.hooks = torch._C._jit_get_emit_hooks()
        torch._C._jit_set_emit_hooks(None, None)

    def __exit__(self, *args) -> None:
        # Restore the hooks saved in __enter__.
        torch._C._jit_set_emit_hooks(self.hooks[0], self.hooks[1])
| |
| |
| def _is_exception(obj) -> bool: |
| if not inspect.isclass(obj): |
| return False |
| return issubclass(obj, Exception) |
| |
| |
def raise_error_container_parameter_missing(target_type) -> None:
    """Always raise: the container annotation ``target_type`` lacks element types.

    Args:
        target_type: display name of the container (e.g. ``"List"``). ``"Dict"``
            gets a message with a two-parameter example.

    Raises:
        RuntimeError: unconditionally.
    """
    if target_type != "Dict":
        raise RuntimeError(
            f"Attempted to use {target_type} without a "
            "contained type. Please add a contained type, e.g. "
            f"{target_type}[int]"
        )
    raise RuntimeError(
        "Attempted to use Dict without "
        "contained types. Please add contained type, e.g. "
        "Dict[int, int]"
    )
| |
| |
def check_args_exist(target_type) -> None:
    """Reject a bare (unparameterized) container annotation.

    ``List``/``list``, ``Tuple``/``tuple``, ``Dict``/``dict`` and
    ``Optional``/``None`` raise ``RuntimeError`` via
    ``raise_error_container_parameter_missing``. Anything else is accepted.
    """
    bare_forms = (
        ("List", (List, list)),
        ("Tuple", (Tuple, tuple)),
        ("Dict", (Dict, dict)),
        ("Optional", (None, Optional)),
    )
    for label, forms in bare_forms:
        if any(target_type is form for form in forms):
            raise_error_container_parameter_missing(label)
            return
| |
| |
def check_empty_containers(obj) -> None:
    """Warn that the element type of an empty container can't be checked eagerly.

    In eager mode an empty ``list``/``dict``/``tuple`` matches any
    parameterization (``[]`` passes both ``List[int]`` and ``List[float]``), so a
    warning is emitted. Objects of other types are ignored.
    """
    # Only inspect genuine containers. Callers pass arbitrary user objects before
    # any type check, and ``obj == []`` on such objects (e.g. array-like types
    # with elementwise ``__eq__``) may raise or return a non-bool.
    if isinstance(obj, (list, dict, tuple)) and len(obj) == 0:
        warnings.warn(
            "The inner type of a container is lost when "
            "calling torch.jit.isinstance in eager mode. For "
            "example, List[int] would become list and "
            "therefore falsely return True for List[float] or"
            " List[str]."
        )
| |
| |
| # supports List/Dict/Tuple and Optional types |
| # TODO support future |
def container_checker(obj, target_type) -> bool:
    """Check ``obj`` against a parameterized container annotation, eagerly.

    Handles ``List[T]``, ``Dict[K, V]``, fixed-length ``Tuple[T1, T2, ...]`` and
    ``Union``/``Optional`` (including ``X | Y`` where supported). Nested
    container element types are checked recursively.

    Returns:
        True if ``obj`` matches ``target_type``. False on a mismatch, or when the
        annotation has no origin or an unsupported one.

    Raises:
        RuntimeError: via ``check_args_exist`` for a bare container annotation.
    """
    origin_type = get_origin(target_type)
    check_args_exist(target_type)
    if origin_type is None:
        return False
    elif origin_type is list or origin_type is List:
        check_empty_containers(obj)
        if not isinstance(obj, list):
            return False
        arg_type = get_args(target_type)[0]
        arg_origin = get_origin(arg_type)
        for el in obj:
            # Nested container element (e.g. List[List[str]]): recurse.
            if arg_origin:
                if not container_checker(el, arg_type):
                    return False
            elif not isinstance(el, arg_type):
                return False
        return True
    elif origin_type is Dict or origin_type is dict:
        check_empty_containers(obj)
        if not isinstance(obj, dict):
            return False
        key_type = get_args(target_type)[0]
        val_type = get_args(target_type)[1]
        # Loop-invariant: whether values need a recursive container check.
        val_origin = get_origin(val_type)
        for key, val in obj.items():
            # check if keys are of right type
            if not isinstance(key, key_type):
                return False
            if val_origin:
                if not container_checker(val, val_type):
                    return False
            elif not isinstance(val, val_type):
                return False
        return True
    elif origin_type is Tuple or origin_type is tuple:
        check_empty_containers(obj)
        if not isinstance(obj, tuple):
            return False
        arg_types = get_args(target_type)
        if len(obj) != len(arg_types):
            return False
        for el, el_type in zip(obj, arg_types):
            el_origin = get_origin(el_type)
            if el_origin:
                if not container_checker(el, el_type):
                    return False
            elif not isinstance(el, el_type):
                return False
        return True
    elif origin_type is Union or origin_type is BuiltinUnionType:
        # Covers Optional[T] (== Union[T, None]) and, where supported, ``X | Y``.
        # Identity is used instead of ``issubclass``, which raises TypeError for
        # non-class origins such as ``typing.Literal``.
        inner_types = get_args(target_type)
        if obj is None:
            # None matches only a union that actually admits it.
            return type(None) in inner_types
        for t in inner_types:
            if get_origin(t):
                # Don't stop at a non-matching container member: a later member
                # of the union may still match.
                if container_checker(obj, t):
                    return True
            elif isinstance(obj, t):
                return True
    return False
| |
| |
def _isinstance(obj, target_type) -> bool:
    """Eager-mode implementation of ``torch.jit.isinstance``.

    Args:
        obj: the value to test.
        target_type: a type, a parameterized container annotation, or a tuple of
            those (matched if any member matches).

    Raises:
        RuntimeError: if ``target_type`` is a non-tuple container (e.g. a list),
            or a bare container annotation such as ``List``.
    """
    if isinstance(target_type, collections.abc.Container):
        if isinstance(target_type, tuple):
            return any(_isinstance(obj, candidate) for candidate in target_type)
        raise RuntimeError(
            "The second argument to "
            "`torch.jit.isinstance` must be a type "
            "or a tuple of types"
        )

    if get_origin(target_type):
        return container_checker(obj, target_type)

    # A bare, unparameterized annotation (e.g. ``Optional``) has no origin, so
    # reject it explicitly before falling back to the builtin check.
    check_args_exist(target_type)

    # Plain, non-container types.
    return isinstance(obj, target_type)
| |
| |
class _TensorExtractor(pickle.Pickler):
    """Pickler that collects every tensor it meets instead of serializing it.

    Tensors are appended to the caller-supplied ``tensors`` list. A small set of
    known tensor-free but unpicklable object types is skipped.
    """

    def __init__(self, *args, tensors: List[torch.Tensor], **kwargs):
        super().__init__(*args, **kwargs)
        # Caller-owned list; tensors are appended in the order pickling visits them.
        self.tensors = tensors

    def persistent_id(self, obj):
        if isinstance(obj, torch.Tensor):
            self.tensors.append(obj)
            return ""
        # We only want the tensors, so an object that can't be pickled is fine to
        # skip as long as it can't contain tensors. To play it safe, only object
        # types known to hold no tensors are listed; add more as needed. Types not
        # listed here don't block users, since they can add __getstate__ or
        # __reduce__ to their class. Futures, Awaits and RRefs don't technically
        # contain a value; they only offer the means to access one.
        skippable_types = (
            LockType,
            CFuture,
            CAwait,
            torch.cuda.Event,
            threading.Thread,
        )
        if isinstance(obj, skippable_types) or is_rref_instance(obj):
            return ""
        return None
| |
| |
def _extract_tensors(obj):
    r"""
    This function is exclusively called from C++.
    See ``torch/csrc/jit/python/python_ivalue.h``.

    Return the tensors contained in ``obj``, found by pickling it with
    ``_TensorExtractor``. The pickled bytes themselves are discarded.
    """
    found: List[torch.Tensor] = []
    sink = io.BytesIO()  # throwaway buffer; only the collected tensors matter
    _TensorExtractor(sink, protocol=-1, tensors=found).dump(obj)
    return found
| |
| |
| def _get_model_id(obj) -> Optional[str]: |
| if isinstance(obj, torch.jit.ScriptModule): |
| return str(obj._c._type()) |
| elif isinstance(obj, torch.jit.ScriptFunction): |
| return obj.qualified_name |
| else: |
| return None |
| |
| |
# In Python 3.11+, typed enums (e.g. IntEnum) retain a number of base class methods in
# their subclasses that were previously dropped. To preserve the old behavior, drop
# them explicitly here.
# NOTE(review): ``sys.version_info > (3, 10)`` is also true on 3.10.x (a longer tuple
# with an equal prefix compares greater), so this runs on 3.10 as well, not only on
# 3.11+ as described above. Confirm which is intended before tightening the check.

if sys.version_info > (3, 10):
    _drop(enum.Enum.__new__)
    _drop(enum.Enum.__format__)
    _drop(enum.Enum.__repr__)
    _drop(enum.Enum.__str__)