| .. testsetup:: |
| |
| # These are hidden from the docs, but these are necessary for `doctest` |
| # since the `inspect` module doesn't play nicely with the execution |
| # environment for `doctest` |
| import torch |
| |
| original_script = torch.jit.script |
| def script_wrapper(obj, *args, **kwargs): |
| obj.__module__ = 'FakeMod' |
| return original_script(obj, *args, **kwargs) |
| |
| torch.jit.script = script_wrapper |
| |
| original_trace = torch.jit.trace |
| def trace_wrapper(obj, *args, **kwargs): |
| obj.__module__ = 'FakeMod' |
| return original_trace(obj, *args, **kwargs) |
| |
| torch.jit.trace = trace_wrapper |
| |
| .. _language-reference-v2: |
| |
| TorchScript Language Reference |
| ============================== |
| |
| This reference manual describes the syntax and core semantics of the TorchScript language. |
| TorchScript is a statically typed subset of the Python language. This document explains the supported features of |
| Python in TorchScript and also how the language diverges from regular Python. Any features of Python that are not mentioned in |
| this reference manual are not part of TorchScript. TorchScript focuses specifically on the features of Python that are needed to |
| represent neural network models in PyTorch. |
| |
| .. contents:: |
| :local: |
| :depth: 1 |
| |
| .. _type_system: |
| |
| Terminology |
| ~~~~~~~~~~~ |
| |
| This document uses the following notation: |
| |
| .. list-table:: |
| :widths: 25 25 |
| :header-rows: 1 |
| |
| * - Pattern |
| - Notes |
| * - ``::=`` |
| - Indicates that the symbol on the left is defined by the expression on the right. |
| * - ``" "`` |
| - Represents real keywords and delimiters that are part of the syntax. |
| * - ``A | B`` |
| - Indicates either A or B. |
| * - ``( )`` |
| - Indicates grouping. |
| * - ``[]`` |
| - Indicates an optional element. |
| * - ``A+`` |
| - Indicates a regular expression where term A is repeated at least once. |
| * - ``A*`` |
| - Indicates a regular expression where term A is repeated zero or more times. |
| |
| Type System |
| ~~~~~~~~~~~ |
| TorchScript is a statically typed subset of Python. The largest difference between TorchScript and the full Python language is that TorchScript only supports a small set of types that are needed to express |
| neural net models. |
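As a minimal sketch of what static typing means in practice (the function name `add_one` is hypothetical), the parameter annotation below is fixed when the function is compiled. The scripted function is therefore expected to reject an argument of a different type at call time instead of silently accepting it:

```python
import torch


def add_one(x: int) -> int:
    # The annotation fixes the type of `x` when the function is compiled.
    return x + 1


scripted_add_one = torch.jit.script(add_one)
print(scripted_add_one(41))  # 42

try:
    # A float is not accepted for a parameter declared as `int`.
    scripted_add_one(1.5)
except RuntimeError as err:
    print("rejected:", err)
```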
| |
| TorchScript Types |
| ^^^^^^^^^^^^^^^^^ |
| |
| The TorchScript type system consists of ``TSType`` and ``TSModuleType`` as defined below. |
| |
| :: |
| |
| TSAllType ::= TSType | TSModuleType |
| TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType |
| |
| ``TSType`` represents the majority of TorchScript types that are composable and that can be used in TorchScript type annotations. |
| ``TSType`` refers to any of the following: |
| |
| * Meta Types, e.g., ``Any`` |
| * Primitive Types, e.g., ``int``, ``float``, and ``str`` |
| * Structural Types, e.g., ``Optional[int]`` or ``List[MyClass]`` |
| * Nominal Types (Python classes), e.g., ``MyClass`` (user-defined), ``torch.Tensor`` (built-in) |
| |
| ``TSModuleType`` represents ``torch.nn.Module`` and its subclasses. It is treated differently from ``TSType`` because its type schema is inferred partly from the object instance and partly from the class definition. |
| As such, instances of a ``TSModuleType`` may not follow the same static type schema. ``TSModuleType`` cannot be used as a TorchScript type annotation or be composed with ``TSType`` for type safety considerations. |
| |
| Meta Types |
| ^^^^^^^^^^ |
| |
| Meta types are so abstract that they are more like type constraints than concrete types. |
| Currently TorchScript defines one meta-type, ``Any``, that represents any TorchScript type. |
| |
| ``Any`` Type |
| """""""""""" |
| |
| The ``Any`` type represents any TorchScript type. ``Any`` specifies no type constraints, so values of this type are not type-checked. |
| As such, it can be bound to any Python or TorchScript data type (e.g., ``int``, TorchScript ``tuple``, or an arbitrary Python class that is not scripted). |
| |
| :: |
| |
| TSMetaType ::= "Any" |
| |
| Where: |
| |
| * ``Any`` is the Python class name from the typing module. Therefore, to use the ``Any`` type, you must import it from ``typing`` (e.g., ``from typing import Any``). |
| * Since ``Any`` can represent any TorchScript type, the set of operators allowed on values of this type is limited. |
| |
| Operators Supported for ``Any`` Type |
| """""""""""""""""""""""""""""""""""" |
| |
| * Assignment to data of ``Any`` type. |
| * Binding to parameter or return of ``Any`` type. |
| * ``x is``, ``x is not`` where ``x`` is of ``Any`` type. |
| * ``isinstance(x, Type)`` where ``x`` is of ``Any`` type. |
| * Data of ``Any`` type is printable. |
| * Data of ``List[Any]`` type may be sortable if the data is a list of values of the same type ``T`` and ``T`` supports comparison operators. |
| |
| **Compared to Python** |
| |
| |
| ``Any`` is the least constrained type in the TorchScript type system. In that sense, it is quite similar to the |
| ``object`` class in Python. However, ``Any`` only supports a subset of the operators and methods that are supported by ``object``. |
| |
| Design Notes |
| """""""""""" |
| |
| When we script a PyTorch module, we may encounter data that is not involved in the execution of the script. Nevertheless, it has to be described |
| by a type schema. It is not only cumbersome to describe static types for unused data (in the context of the script), but also may lead to unnecessary |
| scripting failures. ``Any`` is introduced to describe the type of the data where precise static types are not necessary for compilation. |
| |
| **Example 1** |
| |
| This example illustrates how ``Any`` can be used to allow the second element of the tuple parameter to be of any type. This is possible |
| because ``x[1]`` is not involved in any computation that requires knowing its precise type. |
| |
| .. testcode:: |
| |
| import torch |
| |
| from typing import Tuple |
| from typing import Any |
| |
| def inc_first_element(x: Tuple[int, Any]): |
| return (x[0]+1, x[1]) |
| |
| m = torch.jit.script(inc_first_element) |
| print(m((1,2.0))) |
| print(m((1,(100,200)))) |
| |
| The example above produces the following output: |
| |
| .. testoutput:: |
| |
| (2, 2.0) |
| (2, (100, 200)) |
| |
| The second element of the tuple is of ``Any`` type, so it can bind to multiple types. |
| For example, ``(1, 2.0)`` binds a float type to ``Any`` as in ``Tuple[int, Any]``, |
| whereas ``(1, (100, 200))`` binds a tuple to ``Any`` in the second invocation. |
| |
| |
| **Example 2** |
| |
| This example illustrates how we can use ``isinstance`` to dynamically check the type of the data that is annotated as ``Any`` type: |
| |
| .. testcode:: |
| |
| import torch |
| from typing import Any |
| |
| def f(a:Any): |
| print(a) |
| return (isinstance(a, torch.Tensor)) |
| |
| ones = torch.ones([2]) |
| m = torch.jit.script(f) |
| print(m(ones)) |
| |
| The example above produces the following output: |
| |
| .. testoutput:: |
| |
| 1 |
| 1 |
| [ CPUFloatType{2} ] |
| True |
| |
| Primitive Types |
| ^^^^^^^^^^^^^^^ |
| |
| Primitive TorchScript types are types that represent a single type of value and go with a single pre-defined |
| type name. |
| |
| :: |
| |
| TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" |
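Here is a minimal sketch that uses several primitive types in one signature (the function `scale` is hypothetical). Mixing `int` and `float` in arithmetic produces a `float`, as in Python:

```python
import torch
from typing import Tuple


def scale(n: int, factor: float, enabled: bool, label: str) -> Tuple[str, float]:
    if enabled:
        # int * float evaluates to a float, as in Python.
        return (label, n * factor)
    return (label, float(n))


scripted_scale = torch.jit.script(scale)
print(scripted_scale(3, 2.5, True, "on"))    # ('on', 7.5)
print(scripted_scale(3, 2.5, False, "off"))  # ('off', 3.0)
```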
| |
| Structural Types |
| ^^^^^^^^^^^^^^^^ |
| |
| Structural types are types that are structurally defined without a user-defined name (unlike nominal types), |
| such as ``Future[int]``. Structural types are composable with any ``TSType``. |
| |
| :: |
| |
| TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | |
| TSOptional | TSUnion | TSFuture | TSRRef | TSAwait |
| |
| TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" |
| TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" |
| TSList ::= "List" "[" TSType "]" |
| TSOptional ::= "Optional" "[" TSType "]" |
| TSUnion ::= "Union" "[" (TSType ",")* TSType "]" |
| TSFuture ::= "Future" "[" TSType "]" |
| TSRRef ::= "RRef" "[" TSType "]" |
| TSAwait ::= "Await" "[" TSType "]" |
| TSDict ::= "Dict" "[" KeyType "," TSType "]" |
| KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" |
| |
| Where: |
| |
| * ``Tuple``, ``List``, ``Optional``, ``Union``, and ``Dict`` represent Python type class names that are defined in the module ``typing``. To use these type names, you must import them from ``typing`` (e.g., ``from typing import Tuple``). |
| * ``namedtuple`` represents the Python class ``collections.namedtuple`` or ``typing.NamedTuple``. |
| * ``Future`` and ``RRef`` represent the Python classes ``torch.futures.Future`` and ``torch.distributed.rpc.RRef``, respectively. |
| * ``Await`` represents the Python class ``torch._awaits._Await``. |
| |
| **Compared to Python** |
| |
| Apart from being composable with TorchScript types, these TorchScript structural types often support a common subset of the operators and methods of their Python counterparts. |
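As an illustrative sketch (`lookup` is a hypothetical helper), `Dict`, `List`, and `Optional` compose as they would under a Python type checker. In addition, `None` can be returned wherever `Optional` is declared:

```python
import torch
from typing import Dict, List, Optional


def lookup(table: Dict[str, List[int]], key: str) -> Optional[int]:
    # Return the first value stored under `key`, or None if there is none.
    if key in table:
        values = table[key]
        if len(values) > 0:
            return values[0]
    return None


scripted_lookup = torch.jit.script(lookup)
print(scripted_lookup({"a": [5, 6]}, "a"))  # 5
print(scripted_lookup({"a": []}, "a"))      # None
```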
| |
| **Example 1** |
| |
| This example uses ``typing.NamedTuple`` syntax to define a tuple: |
| |
| .. testcode:: |
| |
| import torch |
| from typing import NamedTuple |
| from typing import Tuple |
| |
| class MyTuple(NamedTuple): |
| first: int |
| second: int |
| |
| def inc(x: MyTuple) -> Tuple[int, int]: |
| return (x.first+1, x.second+1) |
| |
| t = MyTuple(first=1, second=2) |
| scripted_inc = torch.jit.script(inc) |
| print("TorchScript:", scripted_inc(t)) |
| |
| The example above produces the following output: |
| |
| .. testoutput:: |
| |
| TorchScript: (2, 3) |
| |
| **Example 2** |
| |
| This example uses ``collections.namedtuple`` syntax to define a tuple: |
| |
| .. testcode:: |
| |
| import torch |
| from typing import NamedTuple |
| from typing import Tuple |
| from collections import namedtuple |
| |
| _AnnotatedNamedTuple = NamedTuple('_NamedTupleAnnotated', [('first', int), ('second', int)]) |
| _UnannotatedNamedTuple = namedtuple('_NamedTupleUnannotated', ['first', 'second']) |
| |
| def inc(x: _AnnotatedNamedTuple) -> Tuple[int, int]: |
| return (x.first+1, x.second+1) |
| |
| m = torch.jit.script(inc) |
| print(inc(_UnannotatedNamedTuple(1,2))) |
| |
| The example above produces the following output: |
| |
| .. testoutput:: |
| |
| (2, 3) |
| |
| **Example 3** |
| |
| This example illustrates a common mistake of annotating structural types, i.e., not importing the composite type |
| classes from the ``typing`` module: |
| |
| :: |
| |
| import torch |
| |
| # ERROR: Tuple not recognized because not imported from typing |
| @torch.jit.export |
| def inc(x: Tuple[int, int]): |
| return (x[0]+1, x[1]+1) |
| |
| m = torch.jit.script(inc) |
| print(m((1,2))) |
| |
| Running the above code fails before scripting even begins, because Python raises a ``NameError`` when it evaluates the annotation: |
| |
| :: |
| |
| File "test-tuple.py", line 5, in <module> |
| def inc(x: Tuple[int, int]): |
| NameError: name 'Tuple' is not defined |
| |
| The remedy is to add the line ``from typing import Tuple`` to the beginning of the code. |
| |
| Nominal Types |
| ^^^^^^^^^^^^^ |
| |
| Nominal TorchScript types are Python classes. These types are called nominal because they are declared with a custom |
| name and are compared using class names. Nominal classes are further classified into the following categories: |
| |
| :: |
| |
| TSNominalType ::= TSBuiltinClass | TSCustomClass | TSEnum |
| |
| Among them, ``TSCustomClass`` and ``TSEnum`` must be compilable to TorchScript Intermediate Representation (IR). This is enforced by the type-checker. |
| |
| Built-in Class |
| ^^^^^^^^^^^^^^ |
| |
| Built-in nominal types are Python classes whose semantics are built into the TorchScript system (e.g., tensor types). |
| TorchScript defines the semantics of these built-in nominal types, and often supports only a subset of the methods or |
| attributes of their Python class definitions. |
| |
| :: |
| |
| TSBuiltinClass ::= TSTensor | "torch.device" | "torch.Stream" | "torch.dtype" | |
| "torch.nn.ModuleList" | "torch.nn.ModuleDict" | ... |
| TSTensor ::= "torch.Tensor" | "common.SubTensor" | "common.SubWithTorchFunction" | |
| "torch.nn.parameter.Parameter" | and subclasses of torch.Tensor |
| |
| |
| Special Note on torch.nn.ModuleList and torch.nn.ModuleDict |
| """"""""""""""""""""""""""""""""""""""""""""""""""""""""""" |
| |
| Although ``torch.nn.ModuleList`` and ``torch.nn.ModuleDict`` are defined as a list and dictionary in Python, |
| they behave more like tuples in TorchScript: |
| |
| * In TorchScript, instances of ``torch.nn.ModuleList`` or ``torch.nn.ModuleDict`` are immutable. |
| * Code that iterates over ``torch.nn.ModuleList`` or ``torch.nn.ModuleDict`` is completely unrolled so that elements of ``torch.nn.ModuleList`` or values of ``torch.nn.ModuleDict`` can be of different subclasses of ``torch.nn.Module``. |
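The following sketch (the module name `Stack` is hypothetical) iterates over a `torch.nn.ModuleList` whose elements are different `torch.nn.Module` subclasses. Because the loop is unrolled during scripting, each call is compiled against the concrete submodule type:

```python
import torch


class Stack(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Heterogeneous submodules in a single ModuleList.
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(4, 4), torch.nn.ReLU(), torch.nn.Identity()]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Unrolled at compile time: one call per submodule.
        for layer in self.layers:
            x = layer(x)
        return x


scripted_stack = torch.jit.script(Stack())
out = scripted_stack(torch.ones(2, 4))
print(out.shape)  # torch.Size([2, 4])
```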
| |
| **Example** |
| |
| The following example highlights the use of a few built-in TorchScript classes (``torch.*``): |
| |
| :: |
| |
| import torch |
| |
| @torch.jit.script |
| class A: |
| def __init__(self): |
| self.x = torch.rand(3) |
| |
| def f(self, y: torch.device): |
| return self.x.to(device=y) |
| |
| def g(): |
| a = A() |
| return a.f(torch.device("cpu")) |
| |
| script_g = torch.jit.script(g) |
| print(script_g.graph) |
| |
| Custom Class |
| ^^^^^^^^^^^^ |
| |
| Unlike built-in classes, semantics of custom classes are user-defined and the entire class definition must be compilable to TorchScript IR and subject to TorchScript type-checking rules. |
| |
| :: |
| |
| TSClassDef ::= [ "@torch.jit.script" ] |
| "class" ClassName [ "(object)" ] ":" |
| MethodDefinition | |
| [ "@torch.jit.ignore" ] | [ "@torch.jit.unused" ] |
| MethodDefinition |
| |
| Where: |
| |
| * Classes must be new-style classes. Python 3 supports only new-style classes. In Python 2.x, a new-style class is specified by subclassing from ``object``. |
| * Instance data attributes are statically typed, and instance attributes must be declared by assignments inside the ``__init__()`` method. |
| * Method overloading is not supported (i.e., you cannot have multiple methods with the same method name). |
| * ``MethodDefinition`` must be compilable to TorchScript IR and adhere to TorchScript’s type-checking rules (i.e., all methods must be valid TorchScript functions and class attribute definitions must be valid TorchScript statements). |
| * ``torch.jit.ignore`` and ``torch.jit.unused`` can be used to mark methods or functions that are not fully TorchScript-compatible or that the compiler should skip. |
| |
| **Compared to Python** |
| |
| |
| TorchScript custom classes are quite limited compared to their Python counterparts. TorchScript custom classes: |
| |
| * Do not support class attributes. |
| * Do not support subclassing except for subclassing an interface type or object. |
| * Do not support method overloading. |
| * Must initialize all their instance attributes in ``__init__()``; this is because TorchScript constructs a static schema of the class by inferring attribute types in ``__init__()``. |
| * Must contain only methods that satisfy TorchScript type-checking rules and are compilable to TorchScript IRs. |
| |
| **Example 1** |
| |
| Python classes can be used in TorchScript if they are annotated with ``@torch.jit.script``, similar to how a TorchScript function would be declared: |
| |
| :: |
| |
| import torch |
| |
| @torch.jit.script |
| class MyClass: |
| def __init__(self, x: int): |
| self.x = x |
| |
| def inc(self, val: int): |
| self.x += val |
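Here is a short usage sketch. The class from Example 1 is repeated so the snippet runs on its own, and a hypothetical scripted function `total` constructs and mutates an instance inside TorchScript. Unlike module types, custom classes can be constructed within TorchScript code:

```python
import torch


@torch.jit.script
class MyClass:
    def __init__(self, x: int):
        self.x = x

    def inc(self, val: int):
        self.x += val


@torch.jit.script
def total(n: int) -> int:
    # Custom classes, unlike module types, can be constructed inside TorchScript.
    obj = MyClass(0)
    for i in range(n):
        obj.inc(i)
    return obj.x


print(total(4))  # 6  (0 + 1 + 2 + 3)
```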
| |
| |
| **Example 2** |
| |
| A TorchScript custom class type must "declare" all its instance attributes by assignments in ``__init__()``. If an instance attribute is not defined in ``__init__()`` but accessed in other methods of the class, the class cannot be compiled as a TorchScript class, as shown in the following example: |
| |
| :: |
| |
| import torch |
| |
| @torch.jit.script |
| class foo: |
| def __init__(self): |
| self.y = 1 |
| |
| # ERROR: self.x is not defined in __init__ |
| def assign_x(self): |
| self.x = torch.rand(2, 3) |
| |
| The class will fail to compile and issue the following error: |
| |
| :: |
| |
| RuntimeError: |
| Tried to set nonexistent attribute: x. Did you forget to initialize it in __init__()?: |
| def assign_x(self): |
| self.x = torch.rand(2, 3) |
| ~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE |
| |
| **Example 3** |
| |
| In this example, a TorchScript custom class defines a class variable, ``name``, which is not allowed: |
| |
| :: |
| |
| import torch |
| |
| @torch.jit.script |
| class MyClass(object): |
| name = "MyClass" |
| def __init__(self, x: int): |
| self.x = x |
| |
| def fn(a: MyClass): |
| return a.name |
| |
| torch.jit.script(fn) |
| |
| It leads to the following compile-time error: |
| |
| :: |
| |
| RuntimeError: |
| '__torch__.MyClass' object has no attribute or method 'name'. Did you forget to initialize an attribute in __init__()?: |
| File "test-class2.py", line 10 |
| def fn(a: MyClass): |
| return a.name |
| ~~~~~~ <--- HERE |
| |
| Enum Type |
| ^^^^^^^^^ |
| |
| Like custom classes, semantics of the enum type are user-defined and the entire class definition must be compilable to TorchScript IR and adhere to TorchScript type-checking rules. |
| |
| :: |
| |
| TSEnumDef ::= "class" Identifier "(enum.Enum | TSEnumType)" ":" |
| ( MemberIdentifier "=" Value )+ |
| ( MethodDefinition )* |
| |
| Where: |
| |
| * Value must be a TorchScript literal of type ``int``, ``float``, or ``str``, and all member values must be of the same TorchScript type. |
| * ``TSEnumType`` is the name of a TorchScript enumerated type. Similar to Python enum, TorchScript allows restricted ``Enum`` subclassing, that is, subclassing an enumerated type is allowed only if it does not define any members. |
| |
| **Compared to Python** |
| |
| |
| * TorchScript supports only ``enum.Enum``. It does not support other variations such as ``enum.IntEnum``, ``enum.Flag``, ``enum.IntFlag``, and ``enum.auto``. |
| * Values of TorchScript enum members must be of the same type and can only be ``int``, ``float``, or ``str`` types, whereas Python enum members can be of any type. |
| * Enums containing methods are ignored in TorchScript. |
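Here is a minimal sketch of reading member values (`Priority` and `weight` are hypothetical names). It assumes that `.value` is available on enum members in TorchScript, as it is in Python. Because all values share the `int` type, `p.value` is typed as `int`:

```python
import torch
from enum import Enum


class Priority(Enum):
    LOW = 1
    HIGH = 10


def weight(p: Priority) -> int:
    # All member values are int, so p.value is an int here.
    return p.value * 2


scripted_weight = torch.jit.script(weight)
print(scripted_weight(Priority.HIGH))  # 20
```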
| |
| **Example 1** |
| |
| The following example defines the class ``Color`` as an ``Enum`` type: |
| |
| :: |
| |
| import torch |
| from enum import Enum |
| |
| class Color(Enum): |
| RED = 1 |
| GREEN = 2 |
| |
| def enum_fn(x: Color, y: Color) -> bool: |
| if x == Color.RED: |
| return True |
| return x == y |
| |
| m = torch.jit.script(enum_fn) |
| |
| print("Eager: ", enum_fn(Color.RED, Color.GREEN)) |
| print("TorchScript: ", m(Color.RED, Color.GREEN)) |
| |
| **Example 2** |
| |
| The following example shows the case of restricted enum subclassing, where ``BaseColor`` does not define any members, so it can be subclassed by ``Color``: |
| |
| :: |
| |
| import torch |
| from enum import Enum |
| |
| class BaseColor(Enum): |
| def foo(self): |
| pass |
| |
| class Color(BaseColor): |
| RED = 1 |
| GREEN = 2 |
| |
| def enum_fn(x: Color, y: Color) -> bool: |
| if x == Color.RED: |
| return True |
| return x == y |
| |
| m = torch.jit.script(enum_fn) |
| |
| print("TorchScript: ", m(Color.RED, Color.GREEN)) |
| print("Eager: ", enum_fn(Color.RED, Color.GREEN)) |
| |
| TorchScript Module Class |
| ^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| ``TSModuleType`` is a special class type that is inferred from object instances that are created outside TorchScript. ``TSModuleType`` is named by the Python class of the object instance. The ``__init__()`` method of the Python class is not considered a TorchScript method, so it does not have to comply with TorchScript’s type-checking rules. |
| |
| The type schema of a module instance class is constructed directly from an instance object (created outside the scope of TorchScript) rather than inferred from ``__init__()`` like custom classes. It is possible that two objects of the same instance class type follow two different type schemas. |
| |
| In this sense, ``TSModuleType`` is not really a static type. Therefore, for type safety considerations, ``TSModuleType`` cannot be used in a TorchScript type annotation or be composed with ``TSType``. |
| |
| Module Instance Class |
| ^^^^^^^^^^^^^^^^^^^^^ |
| |
| TorchScript module type represents the type schema of a user-defined PyTorch module instance. When scripting a PyTorch module, the module object is always created outside TorchScript (i.e., it is constructed in Python and then passed to ``torch.jit.script``). The Python module class is treated as a module instance class, so the ``__init__()`` method of the Python module class is not subject to the type-checking rules of TorchScript. |
| |
| :: |
| |
| TSModuleType ::= "class" Identifier "(torch.nn.Module)" ":" |
| ClassBodyDefinition |
| |
| Where: |
| |
| * ``forward()`` and other methods decorated with ``@torch.jit.export`` must be compilable to TorchScript IR and subject to TorchScript’s type-checking rules. |
| |
| Unlike custom classes, only the forward method and other methods decorated with ``@torch.jit.export`` of the module type need to be compilable. Most notably, ``__init__()`` is not considered a TorchScript method. Consequently, module type constructors cannot be invoked within the scope of TorchScript. Instead, TorchScript module objects are always constructed outside and passed into ``torch.jit.script(ModuleObj)``. |
| |
| **Example 1** |
| |
| This example illustrates a few features of module types: |
| |
| * The ``TestModule`` instance is created outside the scope of TorchScript (i.e., before invoking ``torch.jit.script``). |
| * ``__init__()`` is not considered a TorchScript method, so it does not have to be annotated and can contain arbitrary Python code. In addition, the ``__init__()`` method of an instance class cannot be invoked in TorchScript code. Because ``TestModule`` instances are instantiated in Python, in this example, ``TestModule(1)`` and ``TestModule(torch.ones([5]))`` create two instances with different types for their data attributes: ``self.x`` is of type ``int`` for ``TestModule(1)``, whereas it is of type ``Tensor`` for ``TestModule(torch.ones([5]))``. |
| * TorchScript automatically compiles other methods (e.g., helper methods) that are invoked by ``forward()`` or by methods annotated with ``@torch.jit.export``. |
| * Entry-points to a TorchScript program are either ``forward()`` of a module type, functions annotated as ``torch.jit.script``, or methods annotated as ``torch.jit.export``. |
| |
| .. testcode:: |
| |
| import torch |
| |
| class TestModule(torch.nn.Module): |
| def __init__(self, v): |
| super().__init__() |
| self.x = v |
| |
| def forward(self, inc: int): |
| return self.x + inc |
| |
| m = torch.jit.script(TestModule(1)) |
| print(f"First instance: {m(3)}") |
| |
| m = torch.jit.script(TestModule(torch.ones([5]))) |
| print(f"Second instance: {m(3)}") |
| |
| The example above produces the following output: |
| |
| .. testoutput:: |
| |
| First instance: 4 |
| Second instance: tensor([4., 4., 4., 4., 4.]) |
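The sketch below (`Scaler` is a hypothetical module) illustrates the points this example makes about compiled methods and entry points. `mul()` is compiled only because `forward()` calls it, while `describe()` becomes an additional entry point through `@torch.jit.export`:

```python
import torch


class Scaler(torch.nn.Module):
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def mul(self, x: torch.Tensor) -> torch.Tensor:
        # Not decorated: compiled because forward() calls it.
        return x * self.factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.mul(x)

    @torch.jit.export
    def describe(self) -> float:
        # Exported: compiled as an extra entry point even though forward() never calls it.
        return self.factor


scripted_scaler = torch.jit.script(Scaler(3.0))
print(scripted_scaler(torch.ones(2)))  # tensor([3., 3.])
print(scripted_scaler.describe())      # 3.0
```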
| |
| **Example 2** |
| |
| The following example shows an incorrect usage of module type. Specifically, this example invokes the constructor of ``TestModule`` inside the scope of TorchScript: |
| |
| .. testcode:: |
| |
| import torch |
| |
| class TestModule(torch.nn.Module): |
| def __init__(self, v): |
| super().__init__() |
| self.x = v |
| |
| def forward(self, x: int): |
| return self.x + x |
| |
| class MyModel: |
| def __init__(self, v: int): |
| self.val = v |
| |
| @torch.jit.export |
| def doSomething(self, val: int) -> int: |
| # error: should not invoke the constructor of module type |
| myModel = TestModule(self.val) |
| return myModel(val) |
| |
| # m = torch.jit.script(MyModel(2)) # Results in below RuntimeError |
| # RuntimeError: Could not get name of python class object |
| |
| .. _type_annotation: |
| |
| |
| Type Annotation |
| ~~~~~~~~~~~~~~~ |
| Since TorchScript is statically typed, programmers need to annotate types at *strategic points* of TorchScript code so that every local variable or |
| instance data attribute has a static type, and every function and method has a statically typed signature. |
| |
| When to Annotate Types |
| ^^^^^^^^^^^^^^^^^^^^^^ |
| In general, type annotations are only needed in places where static types cannot be automatically inferred (e.g., parameters or sometimes return types to |
| methods or functions). Types of local variables and data attributes are often automatically inferred from their assignment statements. Sometimes an inferred type |
| may be too restrictive, e.g., ``x`` being inferred as ``NoneType`` through assignment ``x = None``, whereas ``x`` is actually used as an ``Optional``. In such |
| cases, type annotations may be needed to override automatic inference, e.g., ``x: Optional[int] = None``. Note that it is always safe to type annotate a local variable |
| or data attribute even if its type can be automatically inferred. The annotated type must be congruent with TorchScript’s type-checking. |
| |
| When a parameter, local variable, or data attribute is not type annotated and its type cannot be automatically inferred, TorchScript assumes it to be a |
| default type of ``TensorType``, ``List[TensorType]``, or ``Dict[str, TensorType]``. |
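Here is a small sketch of the default-type rule (`squares` is hypothetical). An empty list literal would otherwise be typed `List[TensorType]`, so the annotation on `out` is what allows integers to be appended:

```python
import torch
from typing import List


def squares(n: int) -> List[int]:
    # Without the annotation, `[]` would default to List[Tensor] and
    # appending an int would fail to compile.
    out: List[int] = []
    for i in range(n):
        out.append(i * i)
    return out


scripted_squares = torch.jit.script(squares)
print(scripted_squares(4))  # [0, 1, 4, 9]
```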
| |
| Annotate Function Signature |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| Because the type of a parameter cannot be inferred from the body of a function or method, parameters need to be type annotated. Otherwise, they assume the default type ``TensorType``. |
| |
| TorchScript supports two styles for method and function signature type annotation: |
| |
| * **Python3-style** annotates types directly on the signature. As such, it allows individual parameters to be left unannotated (whose type will be the default type of ``TensorType``), or allows the return type to be left unannotated (whose type will be automatically inferred). |
| |
| |
| :: |
| |
| Python3Annotation ::= "def" Identifier [ "(" ParamAnnot* ")" ] [ReturnAnnot] ":" |
| FuncOrMethodBody |
| ParamAnnot ::= Identifier [ ":" TSType ] "," |
| ReturnAnnot ::= "->" TSType |
| |
| Note that when using Python3 style, the type of ``self`` is automatically inferred and should not be annotated. |
| |
| * **Mypy style** annotates types as a comment right below the function/method declaration. In the Mypy style, since parameter names do not appear in the annotation, all parameters have to be annotated. |
| |
| |
| :: |
| |
| MyPyAnnotation ::= "# type:" "(" ParamAnnot* ")" [ ReturnAnnot ] |
| ParamAnnot ::= TSType "," |
| ReturnAnnot ::= "->" TSType |
| |
| **Example 1** |
| |
| In this example: |
| |
| * ``a`` is not annotated and assumes the default type of ``TensorType``. |
| * ``b`` is annotated as type ``int``. |
| * The return type is not annotated and is automatically inferred as type ``TensorType`` (based on the type of the value being returned). |
| |
| :: |
| |
| import torch |
| |
| def f(a, b: int): |
| return a+b |
| |
| m = torch.jit.script(f) |
| print("TorchScript:", m(torch.ones([6]), 100)) |
| |
| **Example 2** |
| |
| The following example uses Mypy style annotation. Note that all parameters must be annotated, even those that |
| would otherwise assume the default type. |
| |
| :: |
| |
| import torch |
| |
| def f(a, b): |
| # type: (torch.Tensor, int) -> torch.Tensor |
| return a+b |
| |
| m = torch.jit.script(f) |
| print("TorchScript:", m(torch.ones([6]), 100)) |
| |
| |
| Annotate Variables and Data Attributes |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| In general, types of data attributes (including class and instance data attributes) and local variables can be automatically inferred from assignment statements. |
| Sometimes, however, if a variable or attribute is associated with values of different types (e.g., ``None`` or ``TensorType``), then it may need to be explicitly |
| type annotated as a *wider* type such as ``Optional[torch.Tensor]`` or ``Any``. |
| |
| Local Variables |
| """"""""""""""" |
| Local variables can be annotated according to Python3 typing module annotation rules, i.e., |
| |
| :: |
| |
| LocalVarAnnotation ::= Identifier [":" TSType] "=" Expr |
| |
| In general, types of local variables can be automatically inferred. In some cases, however, you may need to annotate a multi-type for local variables |
| that may be associated with different concrete types. Typical multi-types include ``Optional[T]`` and ``Any``. |
| |
| **Example** |
| |
| :: |
| |
| import torch |
| from typing import Optional |
| |
| def f(a, setVal: bool): |
| value: Optional[torch.Tensor] = None |
| if setVal: |
| value = a |
| return value |
| |
| ones = torch.ones([6]) |
| m = torch.jit.script(f) |
| print("TorchScript:", m(ones, True), m(ones, False)) |
| |
| Instance Data Attributes |
| """""""""""""""""""""""" |
| For ``ModuleType`` classes, instance data attributes can be annotated according to Python3 typing module annotation rules. Instance data attributes can be annotated (optionally) as final |
| via ``Final``. |
| |
| :: |
| |
| "class" ClassIdentifier "(torch.nn.Module):" |
| InstanceAttrIdentifier ":" ["Final("] TSType [")"] |
| ... |
| |
| Where: |
| |
| * ``InstanceAttrIdentifier`` is the name of an instance attribute. |
| * ``Final`` indicates that the attribute cannot be re-assigned outside of ``__init__`` or overridden in subclasses. |
| |
| **Example** |
| |
| :: |
| |
| import torch |
| |
| class MyModule(torch.nn.Module): |
| offset_: int |
| |
| def __init__(self, offset): |
| super().__init__() |
| self.offset_ = offset |
| |
| ... |
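Here is a sketch of the `Final` form. It assumes Python 3.8 or later so that `Final` can be imported from `typing`; `AddOffset` is a hypothetical module:

```python
import torch
from typing import Final


class AddOffset(torch.nn.Module):
    # A Final attribute must not be reassigned outside __init__.
    offset: Final[int]

    def __init__(self, offset: int):
        super().__init__()
        self.offset = offset

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.offset


scripted_add = torch.jit.script(AddOffset(2))
print(scripted_add(torch.zeros(3)))  # tensor([2., 2., 2.])
```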
| |
| |
| |
| Type Annotation APIs |
| ^^^^^^^^^^^^^^^^^^^^ |
| |
| ``torch.jit.annotate(T, expr)`` |
| """"""""""""""""""""""""""""""" |
| This API annotates an expression ``expr`` with type ``T``. This is often used when the default type of an expression is not the type intended by the programmer. |
| For instance, an empty list (dictionary) has the default type of ``List[TensorType]`` (``Dict[str, TensorType]``), but sometimes it may be used to initialize |
| a list of some other types. Another common use case is for annotating the return type of ``tensor.tolist()``. Note, however, that it cannot be used to annotate |
| the type of a module attribute in ``__init__``; ``torch.jit.Attribute`` should be used for this instead. |
| |
| **Example** |
| |
| In this example, ``[]`` is declared as a list of integers via ``torch.jit.annotate`` (instead of assuming ``[]`` to be the default type of ``List[TensorType]``): |
| |
| :: |
| |
| import torch |
| from typing import List |
| |
| def g(l: List[int], val: int): |
| l.append(val) |
| return l |
| |
| def f(val: int): |
| l = g(torch.jit.annotate(List[int], []), val) |
| return l |
| |
| m = torch.jit.script(f) |
| print("Eager:", f(3)) |
| print("TorchScript:", m(3)) |
| |
| |
| See :meth:`torch.jit.annotate` for more information. |
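For the `tensor.tolist()` case mentioned above, a minimal sketch (`to_int_list` is hypothetical) looks like the following. It assumes the input is a one-dimensional integer tensor, since the annotation must match the tensor's actual rank and element type at run time:

```python
import torch
from typing import List


def to_int_list(t: torch.Tensor) -> List[int]:
    # tolist() has no static return type in TorchScript, so annotate it explicitly.
    return torch.jit.annotate(List[int], t.tolist())


scripted_to_int_list = torch.jit.script(to_int_list)
print(scripted_to_int_list(torch.tensor([1, 2, 3])))  # [1, 2, 3]
```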
| |
| |
| Type Annotation Appendix |
| ^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| TorchScript Type System Definition |
| """""""""""""""""""""""""""""""""" |
| |
| :: |
| |
| TSAllType ::= TSType | TSModuleType |
| TSType ::= TSMetaType | TSPrimitiveType | TSStructuralType | TSNominalType |
| |
| TSMetaType ::= "Any" |
| TSPrimitiveType ::= "int" | "float" | "double" | "complex" | "bool" | "str" | "None" |
| |
| TSStructuralType ::= TSTuple | TSNamedTuple | TSList | TSDict | TSOptional | |
| TSUnion | TSFuture | TSRRef | TSAwait |
| TSTuple ::= "Tuple" "[" (TSType ",")* TSType "]" |
| TSNamedTuple ::= "namedtuple" "(" (TSType ",")* TSType ")" |
| TSList ::= "List" "[" TSType "]" |
| TSOptional ::= "Optional" "[" TSType "]" |
| TSUnion ::= "Union" "[" (TSType ",")* TSType "]" |
| TSFuture ::= "Future" "[" TSType "]" |
| TSRRef ::= "RRef" "[" TSType "]" |
| TSAwait ::= "Await" "[" TSType "]" |
| TSDict ::= "Dict" "[" KeyType "," TSType "]" |
| KeyType ::= "str" | "int" | "float" | "bool" | TensorType | "Any" |
| |
| TSNominalType ::= TSBuiltinClasses | TSCustomClass | TSEnum |
| TSBuiltinClass ::= TSTensor | "torch.device" | "torch.stream"| |
| "torch.dtype" | "torch.nn.ModuleList" | |
| "torch.nn.ModuleDict" | ... |
|     TSTensor ::= "torch.Tensor" and subclasses |
| |
| Unsupported Typing Constructs |
| """"""""""""""""""""""""""""" |
| TorchScript does not support all features and types of the Python 3 `typing <https://docs.python.org/3/library/typing.html#module-typing>`_ module. |
| Any functionality from the `typing <https://docs.python.org/3/library/typing.html#module-typing>`_ module that is not explicitly specified in this |
| documentation is unsupported. The following table summarizes ``typing`` constructs that are either unsupported or supported with restrictions in TorchScript. |
| |
| ============================= ================ |
| Item Description |
| ============================= ================ |
| ``typing.Any`` In development |
| ``typing.NoReturn`` Not supported |
| ``typing.Callable`` Not supported |
| ``typing.Literal`` Not supported |
| ``typing.ClassVar`` Not supported |
| ``typing.Final``              Supported for module attributes, class attributes, and annotations, but not for functions. |
| ``typing.AnyStr`` Not supported |
| ``typing.overload`` In development |
| Type aliases Not supported |
| Nominal typing In development |
| Structural typing Not supported |
| NewType Not supported |
| Generics Not supported |
| ============================= ================ |
| |
| |
| .. _expressions: |
| |
| |
| Expressions |
| ~~~~~~~~~~~ |
| |
| The following section describes the grammar of expressions that are supported in TorchScript. |
| It is modeled after `the expressions chapter of the Python language reference <https://docs.python.org/3/reference/expressions.html>`_. |
| |
| Arithmetic Conversions |
| ^^^^^^^^^^^^^^^^^^^^^^ |
| There are a number of implicit type conversions that are performed in TorchScript: |
| |
| |
| * A ``Tensor`` with a ``float`` or ``int`` data type can be implicitly converted to an instance of ``FloatType`` or ``IntType`` provided that it is zero-dimensional (a scalar tensor), does not have ``requires_grad`` set to ``True``, and will not require narrowing (e.g., a ``float`` tensor cannot be implicitly converted to ``IntType``). |
| * Instances of ``StringType`` can be implicitly converted to ``DeviceType``. |
| * The implicit conversion rules from the two bullet points above can be applied to instances of ``TupleType`` to produce instances of ``ListType`` with the appropriate contained type. |
| |
| |
| Explicit conversions can be invoked using the ``float``, ``int``, ``bool``, and ``str`` built-in functions |
| that accept primitive data types as arguments and can accept user-defined types if they implement |
| ``__bool__``, ``__str__``, etc. |
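The snippet below is a minimal sketch of the explicit conversions described above. It assumes a PyTorch release where these built-ins are scriptable. ``convert`` is a hypothetical name used only for illustration.

```python
import torch


@torch.jit.script
def convert(t: torch.Tensor, n: int):
    # float() on a single-element Tensor extracts its value as a Python float.
    as_float = float(t.sum())
    # int() on a float truncates toward zero, as in Python.
    as_int = int(as_float)
    # bool() on an int is False only for 0.
    as_bool = bool(n)
    # str() renders an int in decimal notation.
    as_str = str(n)
    return as_float, as_int, as_bool, as_str


print(convert(torch.tensor([1.5, 2.0]), 3))  # (3.5, 3, True, '3')
```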
| |
| |
| Atoms |
| ^^^^^ |
| Atoms are the most basic elements of expressions. |
| |
| :: |
| |
| atom ::= identifier | literal | enclosure |
| enclosure ::= parenth_form | list_display | dict_display |
| |
| Identifiers |
| """"""""""" |
| The rules that dictate what is a legal identifier in TorchScript are the same as |
| their `Python counterparts <https://docs.python.org/3/reference/lexical_analysis.html#identifiers>`_. |
| |
| Literals |
| """""""" |
| |
| :: |
| |
| literal ::= stringliteral | integer | floatnumber |
| |
| Evaluation of a literal yields an object of the appropriate type with the specific value |
| (with approximations applied as necessary for floats). Literals are immutable, and multiple evaluations |
| of identical literals may obtain the same object or distinct objects with the same value. |
| `stringliteral <https://docs.python.org/3/reference/lexical_analysis.html#string-and-bytes-literals>`_, |
| `integer <https://docs.python.org/3/reference/lexical_analysis.html#integer-literals>`_, and |
| `floatnumber <https://docs.python.org/3/reference/lexical_analysis.html#floating-point-literals>`_ |
| are defined in the same way as their Python counterparts. |
| |
| Parenthesized Forms |
| """"""""""""""""""" |
| |
| :: |
| |
| parenth_form ::= '(' [expression_list] ')' |
| |
| A parenthesized expression list yields whatever the expression list yields. If the list contains at least one |
| comma, it yields a ``Tuple``; otherwise, it yields the single expression inside the expression list. An empty |
| pair of parentheses yields an empty ``Tuple`` object (``Tuple[]``). |
| |
| List and Dictionary Displays |
| """""""""""""""""""""""""""" |
| |
| :: |
| |
| list_comprehension ::= expression comp_for |
| comp_for ::= 'for' target_list 'in' or_expr |
| list_display ::= '[' [expression_list | list_comprehension] ']' |
| dict_display ::= '{' [key_datum_list | dict_comprehension] '}' |
| key_datum_list ::= key_datum (',' key_datum)* |
| key_datum ::= expression ':' expression |
| dict_comprehension ::= key_datum comp_for |
| |
| Lists and dicts can be constructed by either listing the container contents explicitly or by providing |
| instructions on how to compute them via a set of looping instructions (i.e. a *comprehension*). A comprehension |
| is semantically equivalent to using a for loop and appending to an ongoing list. |
| Comprehensions implicitly create their own scope to make sure that the items of the target list do not leak into the |
| enclosing scope. In the case that container items are explicitly listed, the expressions in the expression list |
| are evaluated left-to-right. If a key is repeated in a ``dict_display`` that has a ``key_datum_list``, the |
| resultant dictionary uses the value from the rightmost datum in the list that uses the repeated key. |
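The following scripted function is an illustrative sketch that builds a list and a dict with comprehensions. ``displays`` is a hypothetical function name. Each comprehension behaves like a ``for`` loop that appends to the list or inserts into the dict.

```python
import torch
from typing import Dict, List, Tuple


@torch.jit.script
def displays(n: int) -> Tuple[List[int], Dict[str, int]]:
    # List comprehension: same result as appending i * i inside a for loop.
    squares = [i * i for i in range(n)]
    # Dict comprehension: keys are the decimal strings of i.
    by_name = {str(i): i * i for i in range(n)}
    return squares, by_name
```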
| |
| Primaries |
| ^^^^^^^^^ |
| |
| :: |
| |
| primary ::= atom | attributeref | subscription | slicing | call |
| |
| |
| Attribute References |
| """""""""""""""""""" |
| |
| :: |
| |
| attributeref ::= primary '.' identifier |
| |
| |
| The ``primary`` must evaluate to an object of a type that supports attribute references that have an attribute named |
| ``identifier``. |
| |
| Subscriptions |
| """"""""""""" |
| |
| :: |
| |
| subscription ::= primary '[' expression_list ']' |
| |
| |
| The ``primary`` must evaluate to an object that supports subscription. |
| |
| * If the primary is a ``List``, ``Tuple``, or ``str``, the expression list must evaluate to an integer or slice. |
| * If the primary is a ``Dict``, the expression list must evaluate to an object of the same type as the key type of the ``Dict``. |
| * If the primary is a ``ModuleList``, the expression list must be an ``integer`` literal. |
| * If the primary is a ``ModuleDict``, the expression must be a ``stringliteral``. |
| |
| |
| Slicings |
| """""""" |
| A slicing selects a range of items in a ``str``, ``Tuple``, ``List``, or ``Tensor``. Slicings may be used as |
| expressions or targets in assignment or ``del`` statements. |
| |
| :: |
| |
| slicing ::= primary '[' slice_list ']' |
| slice_list ::= slice_item (',' slice_item)* [','] |
| slice_item ::= expression | proper_slice |
| proper_slice ::= [expression] ':' [expression] [':' [expression] ] |
| |
| Slicings with more than one slice item in their slice lists can only be used with primaries that evaluate to an |
| object of type ``Tensor``. |
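Here is a small sketch of the slicing rules above. ``slices`` is a hypothetical name. Only the ``Tensor`` argument is indexed with more than one slice item.

```python
import torch
from typing import List


@torch.jit.script
def slices(xs: List[int], t: torch.Tensor):
    head = xs[1:3]      # a single proper slice on a List
    evens = xs[::2]     # a proper slice with only a step
    block = t[0:2, 1:]  # two slice items: allowed only because t is a Tensor
    return head, evens, block
```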
| |
| |
| Calls |
| """"" |
| |
| :: |
| |
| call ::= primary '(' argument_list ')' |
| argument_list ::= args [',' kwargs] | kwargs |
| args ::= [arg (',' arg)*] |
| kwargs ::= [kwarg (',' kwarg)*] |
| kwarg ::= arg '=' expression |
| arg ::= identifier |
| |
| |
| The ``primary`` must desugar or evaluate to a callable object. All argument expressions are evaluated |
| before the call is attempted. |
| |
| Power Operator |
| ^^^^^^^^^^^^^^ |
| |
| :: |
| |
| power ::= primary ['**' u_expr] |
| |
| |
| The power operator has the same semantics as the built-in ``pow()`` function; it computes its |
| left argument raised to the power of its right argument. It binds more tightly than unary operators on the |
| left, but less tightly than unary operators on the right; i.e., ``-2 ** -3 == -(2 ** (-3))``. The left and right |
| operands can be ``int``, ``float`` or ``Tensor``. Scalars are broadcast in the case of scalar-tensor/tensor-scalar |
| exponentiation operations, and tensor-tensor exponentiation is done elementwise, with the operand shapes broadcast against each other. |
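The following minimal sketch illustrates how ``**`` binds. ``power`` is a hypothetical function name. Float operands are used so that every scalar result is a ``float``.

```python
import torch


@torch.jit.script
def power(base: float, x: torch.Tensor):
    # ** binds tighter than the unary minus on its left: -(base ** 2.0)
    neg = -base ** 2.0
    # A unary minus on the right belongs to the exponent: base ** (-1.0)
    inv = base ** -1.0
    # Tensor-scalar: the scalar exponent is applied to every element.
    sq = x ** 2
    return neg, inv, sq
```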
| |
| Unary and Arithmetic Bitwise Operations |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| u_expr ::= power | '-' power | '~' power |
| |
| The unary ``-`` operator yields the negation of its argument. The unary ``~`` operator yields the bitwise inversion |
| of its argument. ``-`` can be used with ``int``, ``float``, and ``Tensor`` of ``int`` and ``float``. |
| ``~`` can only be used with ``int`` and ``Tensor`` of ``int``. |
| |
| Binary Arithmetic Operations |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| m_expr ::= u_expr | m_expr '*' u_expr | m_expr '@' m_expr | m_expr '//' u_expr | m_expr '/' u_expr | m_expr '%' u_expr |
| a_expr ::= m_expr | a_expr '+' m_expr | a_expr '-' m_expr |
| |
| The binary arithmetic operators can operate on ``Tensor``, ``int``, and ``float``. For tensor-tensor ops, both arguments must |
| have broadcastable shapes. For scalar-tensor or tensor-scalar ops, the scalar is broadcast to the size of the |
| tensor. Division ops accept either a scalar or a ``Tensor`` as their right-hand side argument and follow the same broadcasting rules. |
| The ``@`` operator is for matrix multiplication and only operates on ``Tensor`` arguments. The multiplication operator |
| (``*``) can be used with a list and an integer to produce the original list repeated that many times. |
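The sketch below exercises the binary arithmetic operators. The name ``arith`` is hypothetical. It assumes, as in Python 3, that ``/`` on two ``int`` values produces a ``float``.

```python
import torch
from typing import List


@torch.jit.script
def arith(a: int, b: int, xs: List[int], m: torch.Tensor, v: torch.Tensor):
    true_div = a / b    # / on two ints yields a float
    floor_div = a // b  # floor division stays an int
    rem = a % b
    repeated = xs * 2   # list repetition
    prod = m @ v        # matrix multiplication, Tensor operands only
    return true_div, floor_div, rem, repeated, prod
```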
| |
| Shifting Operations |
| ^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| shift_expr ::= a_expr | shift_expr ( '<<' | '>>' ) a_expr |
| |
| |
| These operators accept two ``int`` arguments, two ``Tensor`` arguments, or a ``Tensor`` argument and an ``int`` or |
| ``float`` argument. In all cases, a right shift by ``n`` is defined as floor division by ``pow(2, n)``, and a left shift |
| by ``n`` is defined as multiplication by ``pow(2, n)``. When both arguments are ``Tensors``, their shapes must be |
| broadcastable. When one is a scalar and the other is a ``Tensor``, the scalar is logically broadcast to match the size of |
| the ``Tensor``. |
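Here is a minimal sketch of the shift semantics. ``shifts`` is a hypothetical name. Shifting left by ``n`` multiplies by ``pow(2, n)``, and shifting right by ``n`` floor-divides by it.

```python
import torch


@torch.jit.script
def shifts(a: int, t: torch.Tensor):
    left = a << 2    # a * pow(2, 2)
    right = a >> 1   # a // pow(2, 1)
    t_left = t << 1  # the int shift amount is broadcast over every element
    return left, right, t_left
```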
| |
| Binary Bitwise Operations |
| ^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| and_expr ::= shift_expr | and_expr '&' shift_expr |
| xor_expr ::= and_expr | xor_expr '^' and_expr |
| or_expr ::= xor_expr | or_expr '|' xor_expr |
| |
| |
| The ``&`` operator computes the bitwise AND of its arguments, the ``^`` the bitwise XOR, and the ``|`` the bitwise OR. |
| Both operands must be ``int`` or ``Tensor``, or the left operand must be ``Tensor`` and the right operand must be |
| ``int``. When both operands are ``Tensor``, their shapes must be broadcastable. When the right operand is ``int`` and |
| the left operand is ``Tensor``, the right operand is logically broadcast to match the shape of the ``Tensor``. |
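This illustrative sketch applies the bitwise operators to ``int`` values and to an integer ``Tensor`` with an ``int`` right operand. ``bitwise`` is a hypothetical name.

```python
import torch


@torch.jit.script
def bitwise(a: int, b: int, t: torch.Tensor):
    # int & int, int | int, int ^ int, then Tensor & int (the int is broadcast).
    return a & b, a | b, a ^ b, t & 1
```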
| |
| Comparisons |
| ^^^^^^^^^^^ |
| |
| :: |
| |
| comparison ::= or_expr (comp_operator or_expr)* |
| comp_operator ::= '<' | '>' | '==' | '>=' | '<=' | '!=' | 'is' ['not'] | ['not'] 'in' |
| |
| A comparison yields a boolean value (``True`` or ``False``), or if one of the operands is a ``Tensor``, a boolean |
| ``Tensor``. Comparisons can be chained arbitrarily as long as they do not yield boolean ``Tensors`` that have more |
| than one element. ``a op1 b op2 c ...`` is equivalent to ``a op1 b and b op2 c and ...``. |
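Here is a chained comparison sketch. ``in_range`` is a hypothetical helper. Both comparisons yield plain ``bool`` values, so the chain is allowed.

```python
import torch


@torch.jit.script
def in_range(x: int, lo: int, hi: int) -> bool:
    # Equivalent to `lo <= x and x < hi`.
    return lo <= x < hi
```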
| |
| Value Comparisons |
| """"""""""""""""" |
| The operators ``<``, ``>``, ``==``, ``>=``, ``<=``, and ``!=`` compare the values of two objects. The two objects generally need to be of |
| the same type, unless there is an implicit type conversion available between the objects. User-defined types can |
| be compared if rich comparison methods (e.g., ``__lt__``) are defined on them. Built-in type comparison works like |
| Python: |
| |
| * Numbers are compared mathematically. |
| * Strings are compared lexicographically. |
| * ``lists``, ``tuples``, and ``dicts`` can be compared only to other ``lists``, ``tuples``, and ``dicts`` of the same type and are compared using the comparison operator of corresponding elements. |
| |
| Membership Test Operations |
| """""""""""""""""""""""""" |
| The operators ``in`` and ``not in`` test for membership. ``x in s`` evaluates to ``True`` if ``x`` is a member of ``s`` and ``False`` otherwise. |
| ``x not in s`` is equivalent to ``not x in s``. This operator is supported for ``lists``, ``dicts``, and ``tuples``, and can be used with |
| user-defined types if they implement the ``__contains__`` method. |
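Here is a short sketch of ``in`` and ``not in`` on a ``List`` and a ``Dict``. The function name ``membership`` is hypothetical.

```python
import torch
from typing import Dict, List


@torch.jit.script
def membership(xs: List[int], d: Dict[str, int], key: str):
    # List membership, Dict key membership, and the negated form.
    return 3 in xs, key in d, 5 not in xs
```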
| |
| Identity Comparisons |
| """""""""""""""""""" |
| For ``int``, ``double``, ``bool``, and ``torch.device`` values, the operators ``is`` and ``is not`` compare by value, so ``x is y`` is |
| equivalent to ``x == y``. For all other types, ``is`` and ``is not`` test for object identity: ``x is y`` is ``True`` if and only if |
| ``x`` and ``y`` are the same object. In both cases, ``x is not y`` yields the inverse of ``x is y``. |
| |
| Boolean Operations |
| ^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| or_test ::= and_test | or_test 'or' and_test |
| and_test ::= not_test | and_test 'and' not_test |
| not_test ::= 'bool' '(' or_expr ')' | comparison | 'not' not_test |
| |
| User-defined objects can customize their conversion to ``bool`` by implementing a ``__bool__`` method. The operator ``not`` |
| yields ``True`` if its operand is false, ``False`` otherwise. The expression ``x and y`` first evaluates ``x``; if it is ``False``, its |
| value (``False``) is returned; otherwise, ``y`` is evaluated and its value is returned (``False`` or ``True``). The expression ``x or y`` |
| first evaluates ``x``; if it is ``True``, its value (``True``) is returned; otherwise, ``y`` is evaluated and its value is returned |
| (``False`` or ``True``). |
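Short-circuit evaluation matters when the right operand is only valid if the left one holds. In this minimal sketch, ``x > 0`` is evaluated only after ``x is not None`` has succeeded. ``is_positive`` is a hypothetical name.

```python
import torch
from typing import Optional


@torch.jit.script
def is_positive(x: Optional[int]) -> bool:
    # `and` short-circuits, and TorchScript refines x from Optional[int] to int
    # for the right operand because the left operand checked `x is not None`.
    if x is not None and x > 0:
        return True
    return False
```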
| |
| Conditional Expressions |
| ^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| conditional_expression ::= or_expr ['if' or_test 'else' conditional_expression] |
| expression ::= conditional_expression |
| |
| The expression ``x if c else y`` first evaluates the condition ``c`` (not ``x``). If ``c`` is ``True``, ``x`` is |
| evaluated and its value is returned; otherwise, ``y`` is evaluated and its value is returned. As with if-statements, |
| ``x`` and ``y`` must evaluate to a value of the same type. |
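Here is a minimal sketch of a conditional expression whose two branches share the type ``float``. ``relu_scalar`` is a hypothetical name.

```python
import torch


@torch.jit.script
def relu_scalar(x: float) -> float:
    # Both branches have type float, as the rule above requires.
    return x if x > 0.0 else 0.0
```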
| |
| Expression Lists |
| ^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| expression_list ::= expression (',' expression)* [','] |
| starred_item ::= '*' primary |
| |
| A starred item can only appear on the left-hand side of an assignment statement, e.g., ``a, *b, c = ...``. |
| |
| .. _statements: |
| |
| Simple Statements |
| ~~~~~~~~~~~~~~~~~ |
| |
| The following section describes the syntax of simple statements that are supported in TorchScript. |
| It is modeled after `the simple statements chapter of the Python language reference <https://docs.python.org/3/reference/simple_stmts.html>`_. |
| |
| Expression Statements |
| ^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| expression_stmt ::= starred_expression |
| starred_expression ::= expression | (starred_item ",")* [starred_item] |
| starred_item ::= assignment_expression | "*" or_expr |
| |
| Assignment Statements |
| ^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| assignment_stmt ::= (target_list "=")+ (starred_expression) |
| target_list ::= target ("," target)* [","] |
| target ::= identifier |
| | "(" [target_list] ")" |
| | "[" [target_list] "]" |
| | attributeref |
| | subscription |
| | slicing |
| | "*" target |
| |
| Augmented Assignment Statements |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| augmented_assignment_stmt ::= augtarget augop (expression_list) |
| augtarget ::= identifier | attributeref | subscription |
| augop ::= "+=" | "-=" | "*=" | "/=" | "//=" | "%=" | |
| "**="| ">>=" | "<<=" | "&=" | "^=" | "|=" |
| |
| |
| Annotated Assignment Statements |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| :: |
| |
| annotated_assignment_stmt ::= augtarget ":" expression |
| ["=" (starred_expression)] |
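This sketch combines an annotated assignment with an augmented assignment. ``collect`` is a hypothetical function name. The annotation gives the empty list the type ``List[int]``.

```python
import torch
from typing import List, Tuple


@torch.jit.script
def collect(n: int) -> Tuple[List[int], int]:
    # Annotated assignment: without the annotation, [] would be List[Tensor].
    out: List[int] = []
    total = 0
    for i in range(n):
        out.append(i)
        total += i  # augmented assignment
    return out, total
```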
| |
| The ``raise`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| raise_stmt ::= "raise" [expression ["from" expression]] |
| |
| TorchScript does not support ``try``/``except``/``finally`` blocks, so an exception raised by a ``raise`` statement cannot be caught within TorchScript code. |
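Here is a minimal sketch of ``raise`` inside a scripted function. ``checked_sqrt`` is a hypothetical name. The exception surfaces in the Python caller. The exact exception class seen there may differ from ``ValueError`` depending on the PyTorch version, so a caller should not rely on it.

```python
import torch


@torch.jit.script
def checked_sqrt(x: float) -> float:
    if x < 0.0:
        # Propagates out of the compiled function; it cannot be caught inside TorchScript.
        raise ValueError("x must be non-negative")
    return x ** 0.5
```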
| |
| The ``assert`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| assert_stmt ::= "assert" expression ["," expression] |
| |
| TorchScript does not support ``try``/``except``/``finally`` blocks, so a failed ``assert`` cannot be caught within TorchScript code. |
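Here is a short sketch of an ``assert`` with a message. ``mean_of`` is a hypothetical name. As with ``raise``, the exception class observed by the Python caller may vary by PyTorch version.

```python
import torch
from typing import List


@torch.jit.script
def mean_of(xs: List[float]) -> float:
    # The second expression is the message attached to the failure.
    assert len(xs) > 0, "xs must not be empty"
    return sum(xs) / float(len(xs))
```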
| |
| The ``return`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| return_stmt ::= "return" [expression_list] |
| |
| TorchScript does not support ``try``/``except``/``finally`` blocks, so a ``return`` statement cannot appear inside one. |
| |
| The ``del`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| del_stmt ::= "del" target_list |
| |
| The ``pass`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| pass_stmt ::= "pass" |
| |
| The ``print`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
|     print_stmt ::= "print" "(" [expression ("," expression)*] ")" |
| |
| To print formatted text, pass a ``str.format()`` call as an argument, e.g., ``print("x = {}".format(x))``. |
| |
| The ``break`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| break_stmt ::= "break" |
| |
| The ``continue`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| continue_stmt ::= "continue" |
| |
| Compound Statements |
| ~~~~~~~~~~~~~~~~~~~ |
| |
| The following section describes the syntax of compound statements that are supported in TorchScript. |
| The section also highlights how TorchScript differs from regular Python statements. |
| It is modeled after `the compound statements chapter of the Python language reference <https://docs.python.org/3/reference/compound_stmts.html>`_. |
| |
| The ``if`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^ |
| |
| TorchScript supports both basic ``if/else`` statements and ternary ``if/else`` expressions. |
| |
| Basic ``if/else`` Statement |
| """""""""""""""""""""""""""" |
| |
| :: |
| |
| if_stmt ::= "if" assignment_expression ":" suite |
|                 ("elif" assignment_expression ":" suite)* |
| ["else" ":" suite] |
| |
| An ``elif`` clause can be repeated any number of times, but every ``elif`` must come before the ``else`` clause. |
| |
| Ternary ``if/else`` Statement |
| """""""""""""""""""""""""""""" |
| |
| :: |
| |
|     if_stmt ::= "return" [expression_list] "if" assignment_expression "else" [expression_list] |
| |
| **Example 1** |
| |
| A ``Tensor`` with a single element is implicitly converted to ``bool``: |
| |
| .. testcode:: |
| |
| import torch |
| |
| @torch.jit.script |
| def fn(x: torch.Tensor): |
| if x: # The tensor gets promoted to bool |
| return True |
| return False |
| print(fn(torch.rand(1))) |
| |
| The example above produces the following output: |
| |
| .. testoutput:: |
| |
| True |
| |
| **Example 2** |
| |
| A ``Tensor`` with more than one element cannot be converted to ``bool``, so the following code fails at runtime: |
| |
| :: |
| |
| import torch |
| |
|     # Tensors with more than one element cannot be used as conditions. |
| |
| @torch.jit.script |
| def fn(): |
| if torch.rand(2): |
| print("Tensor is available") |
| |
| if torch.rand(4,5,6): |
| print("Tensor is available") |
| |
| print(fn()) |
| |
| Running the above code yields the following ``RuntimeError``. |
| |
| :: |
| |
| RuntimeError: The following operation failed in the TorchScript interpreter. |
| Traceback of TorchScript (most recent call last): |
| @torch.jit.script |
| def fn(): |
| if torch.rand(2): |
| ~~~~~~~~~~~~ <--- HERE |
| print("Tensor is available") |
| RuntimeError: Boolean value of Tensor with more than one value is ambiguous |
| |
| If the condition is a module attribute annotated as ``Final``, its value is known at compile time, so TorchScript compiles only the branch that the value selects and skips the other branch entirely. |
| |
| **Example 3** |
| |
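| In this example, only the true branch is compiled, since ``use_empty`` is annotated as ``Final`` and set to ``True``. The false branch, which returns a different type, is never type-checked: |
| |
| :: |
| |
|     import torch |
|     from typing import Final |
| |
|     class M(torch.nn.Module): |
|         use_empty: Final[bool] |
|         def __init__(self): |
|             super().__init__() |
|             self.use_empty = True |
| |
|         def forward(self): |
|             if self.use_empty: |
|                 return torch.empty(2, 3) |
|             else: |
|                 return [] |
| |
|     m = torch.jit.script(M()) |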
| |
| |
| The ``while`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| while_stmt ::= "while" assignment_expression ":" suite |
| |
| ``while...else`` statements are not supported in TorchScript; using one results in a compilation error. |
| |
| The ``for-in`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| for_stmt ::= "for" target_list "in" expression_list ":" suite |
| ["else" ":" suite] |
| |
| ``for...else`` statements are not supported in TorchScript; using one results in a compilation error. |
| |
| **Example 1** |
| |
| For loops on tuples: these unroll the loop, generating a body for each member of the tuple. The body must type-check correctly for each member. |
| |
| .. testcode:: |
| |
| import torch |
| from typing import Tuple |
| |
| @torch.jit.script |
| def fn(): |
| tup = (3, torch.ones(4)) |
| for x in tup: |
| print(x) |
| |
| fn() |
| |
| The example above produces the following output: |
| |
| .. testoutput:: |
| |
| 3 |
| 1 |
| 1 |
| 1 |
| 1 |
| [ CPUFloatType{4} ] |
| |
| |
| **Example 2** |
| |
| For loops over ``nn.ModuleList``: the loop body is unrolled at compile time, once for each member of the module list. |
| |
| :: |
| |
|     import torch |
| |
|     class SubModule(torch.nn.Module): |
|         def __init__(self): |
|             super().__init__() |
|             self.weight = torch.nn.Parameter(torch.randn(2)) |
| |
|         def forward(self, input): |
|             return self.weight + input |
| |
|     class MyModule(torch.nn.Module): |
|         def __init__(self): |
|             super().__init__() |
|             self.mods = torch.nn.ModuleList([SubModule() for i in range(10)]) |
| |
|         def forward(self, v): |
|             for module in self.mods: |
|                 v = module(v) |
|             return v |
| |
|     model = torch.jit.script(MyModule()) |
| |
| The ``with`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^ |
| The ``with`` statement is used to wrap the execution of a block with methods defined by a context manager. |
| |
| :: |
| |
|     with_stmt ::= "with" with_item ("," with_item)* ":" suite |
| with_item ::= expression ["as" target] |
| |
| * If a target was included in the ``with`` statement, the return value from the context manager’s ``__enter__()`` is assigned to it. Unlike Python, if an exception caused the suite to be exited, its type, value, and traceback are not passed as arguments to ``__exit__()``; three ``None`` arguments are supplied instead. |
| * ``try``, ``except``, and ``finally`` statements are not supported inside ``with`` blocks. |
| * Exceptions raised within ``with`` block cannot be suppressed. |
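The following sketch uses ``torch.no_grad()`` as a context manager inside a scripted function. It assumes a PyTorch version in which ``torch.no_grad()`` is scriptable. ``add_without_grad`` is a hypothetical name.

```python
import torch


@torch.jit.script
def add_without_grad(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Operations inside the block are not recorded by autograd.
    with torch.no_grad():
        z = x + y
    return z
```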
| |
| The ``tuple`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| tuple_stmt ::= tuple([iterables]) |
| |
| * Iterable types in TorchScript include ``Tensors``, ``lists``, ``tuples``, ``dictionaries``, ``strings``, ``torch.nn.ModuleList``, and ``torch.nn.ModuleDict``. |
| * You cannot convert a List to Tuple by using this built-in function. |
| |
| To capture all outputs of a function that returns a tuple, assign the result directly or unpack it: |
| |
| :: |
| |
|     abc = func()   # func returns a tuple; abc holds the whole tuple |
|     a, b = func()  # unpacks a two-element tuple |
| |
| The ``getattr`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| getattr_stmt ::= getattr(object, name[, default]) |
| |
| * Attribute name must be a literal string. |
| * Python module objects (e.g., ``torch._C``) are not supported. |
| * Custom class objects (e.g., ``torch.classes.*``) are not supported. |
| |
| The ``hasattr`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| hasattr_stmt ::= hasattr(object, name) |
| |
| * Attribute name must be a literal string. |
| * Python module objects (e.g., ``torch._C``) are not supported. |
| * Custom class objects (e.g., ``torch.classes.*``) are not supported. |
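Here is a minimal sketch of ``hasattr`` and ``getattr`` with literal attribute names on a scripted module. ``Scale`` is a hypothetical module.

```python
import torch


class Scale(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.scale = 2.0

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Both built-ins need the attribute name as a string literal.
        if hasattr(self, "scale"):
            return x * getattr(self, "scale")
        return x


m = torch.jit.script(Scale())
```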
| |
| The ``zip`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| zip_stmt ::= zip(iterable1, iterable2) |
| |
| * Arguments must be iterables. |
| * Two iterables of the same outer container type but different lengths are supported; iteration stops at the end of the shorter one. |
| |
| **Example 1** |
| |
| Both iterables must be of the same container type: |
| |
| .. testcode:: |
| |
|     @torch.jit.script |
|     def fn(): |
|         a = [1, 2]  # List |
|         b = [2, 3, 4]  # List |
|         for x, y in zip(a, b):  # works |
|             print(x, y) |
| |
| **Example 2** |
| |
| This example fails because the iterables are of different container types: |
| |
| :: |
| |
|     @torch.jit.script |
|     def fn(): |
|         a = (1, 2)  # Tuple |
|         b = [2, 3, 4]  # List |
|         for x, y in zip(a, b):  # RuntimeError |
|             print(x, y) |
| |
| Running the above code yields the following ``RuntimeError``. |
| |
| :: |
| |
| RuntimeError: Can not iterate over a module list or |
| tuple with a value that does not have a statically determinable length. |
| |
| **Example 3** |
| |
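| Two iterables of the same container type but with different element types are supported: |
| |
| .. testcode:: |
| |
|     @torch.jit.script |
|     def fn(): |
|         a = [1.3, 2.4]  # List[float] |
|         b = [2, 3, 4]  # List[int] |
|         for x, y in zip(a, b):  # works |
|             print(x, y) |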
| |
| Iterable types in TorchScript include ``Tensors``, ``lists``, ``tuples``, ``dictionaries``, ``strings``, ``torch.nn.ModuleList``, and ``torch.nn.ModuleDict``. |
| |
| The ``enumerate`` Statement |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| :: |
| |
| enumerate_stmt ::= enumerate([iterable]) |
| |
| * Arguments must be iterables. |
| * Iterable types in TorchScript include ``Tensors``, ``lists``, ``tuples``, ``dictionaries``, ``strings``, ``torch.nn.ModuleList`` and ``torch.nn.ModuleDict``. |
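Here is a short sketch of ``enumerate`` over a ``List[str]``. ``indexed`` is a hypothetical name.

```python
import torch
from typing import List


@torch.jit.script
def indexed(xs: List[str]) -> List[str]:
    out: List[str] = []
    # enumerate yields (index, item) pairs, starting at 0.
    for i, s in enumerate(xs):
        out.append(str(i) + ":" + s)
    return out
```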
| |
| |
| .. _python-values-torch-script: |
| |
| Python Values |
| ~~~~~~~~~~~~~ |
| |
| .. _python-builtin-functions-values-resolution: |
| |
| Resolution Rules |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| When given a Python value, TorchScript attempts to resolve it in the following five different ways: |
| |
| * Compilable Python Implementation: |
| * When a Python value is backed by a Python implementation that can be compiled by TorchScript, TorchScript compiles and uses the underlying Python implementation. |
| * Example: ``torch.jit.Attribute`` |
| * Op Python Wrapper: |
| * When a Python value is a wrapper of a native PyTorch op, TorchScript emits the corresponding operator. |
| * Example: ``torch.jit._logging.add_stat_value`` |
| * Python Object Identity Match: |
| * For a limited set of ``torch.*`` API calls (in the form of Python values) that TorchScript supports, TorchScript attempts to match a Python value against each item in the set. |
| * When matched, TorchScript generates a corresponding ``SugaredValue`` instance that contains lowering logic for these values. |
| * Example: ``torch.jit.isinstance()`` |
| * Name Match: |
| * For Python built-in functions and constants, TorchScript identifies them by name, and creates a corresponding ``SugaredValue`` instance that implements their functionality. |
| * Example: ``all()`` |
| * Value Snapshot: |
| * For Python values from unrecognized modules, TorchScript attempts to take a snapshot of the value and converts it to a constant in the graph of the function(s) or method(s) that are being compiled. |
| * Example: ``math.pi`` |
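The function below relies on two of the mechanisms listed above. ``all()`` is resolved by name match, and ``math.pi`` by value snapshot. This is an illustrative sketch, not an exhaustive demonstration of the rules. ``check`` is a hypothetical name.

```python
import math
import torch
from typing import List, Tuple


@torch.jit.script
def check(r: float, flags: List[bool]) -> Tuple[float, bool]:
    # math.pi is snapshotted into the graph as a constant.
    area = math.pi * r * r
    # all() is a Python built-in, identified by name.
    ok = all(flags)
    return area, ok
```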
| |
| |
| |
| .. _python-builtin-functions-support: |
| |
| Python Built-in Functions Support |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| .. list-table:: TorchScript Support for Python Built-in Functions |
| :widths: 25 25 50 |
| :header-rows: 1 |
| |
| * - Built-in Function |
| - Support Level |
| - Notes |
| * - ``abs()`` |
| - Partial |
| - Only supports ``Tensor``/``Int``/``Float`` type inputs. | Doesn't honor ``__abs__`` override. |
| * - ``all()`` |
| - Full |
| - |
| * - ``any()`` |
| - Full |
| - |
| * - ``ascii()`` |
| - None |
| - |
| * - ``bin()`` |
| - Partial |
| - Only supports ``Int`` type input. |
| * - ``bool()`` |
| - Partial |
| - Only supports ``Tensor``/``Int``/``Float`` type inputs. |
| * - ``breakpoint()`` |
| - None |
| - |
| * - ``bytearray()`` |
| - None |
| - |
| * - ``bytes()`` |
| - None |
| - |
| * - ``callable()`` |
| - None |
| - |
| * - ``chr()`` |
| - Partial |
| - Only ASCII character set is supported. |
| * - ``classmethod()`` |
| - Full |
| - |
| * - ``compile()`` |
| - None |
| - |
| * - ``complex()`` |
| - None |
| - |
| * - ``delattr()`` |
| - None |
| - |
| * - ``dict()`` |
| - Full |
| - |
| * - ``dir()`` |
| - None |
| - |
| * - ``divmod()`` |
| - Full |
| - |
| * - ``enumerate()`` |
| - Full |
| - |
| * - ``eval()`` |
| - None |
| - |
| * - ``exec()`` |
| - None |
| - |
| * - ``filter()`` |
| - None |
| - |
| * - ``float()`` |
| - Partial |
| - Doesn't honor ``__index__`` override. |
| * - ``format()`` |
| - Partial |
| - Manual index specification not supported. | Format type modifier not supported. |
| * - ``frozenset()`` |
| - None |
| - |
| * - ``getattr()`` |
| - Partial |
| - Attribute name must be string literal. |
| * - ``globals()`` |
| - None |
| - |
| * - ``hasattr()`` |
| - Partial |
| - Attribute name must be string literal. |
| * - ``hash()`` |
| - Full |
| - ``Tensor``'s hash is based on identity not numeric value. |
| * - ``hex()`` |
| - Partial |
| - Only supports ``Int`` type input. |
| * - ``id()`` |
| - Full |
| - Only supports ``Int`` type input. |
| * - ``input()`` |
| - None |
| - |
| * - ``int()`` |
| - Partial |
| - ``base`` argument not supported. | Doesn't honor ``__index__`` override. |
| * - ``isinstance()`` |
| - Full |
|      - ``torch.jit.isinstance`` provides better support when checking against container types like ``Dict[str, int]``. |
| * - ``issubclass()`` |
| - None |
| - |
| * - ``iter()`` |
| - None |
| - |
| * - ``len()`` |
| - Full |
| - |
| * - ``list()`` |
| - Full |
| - |
| * - ``ord()`` |
| - Partial |
| - Only ASCII character set is supported. |
| * - ``pow()`` |
| - Full |
| - |
| * - ``print()`` |
| - Partial |
|      - ``sep``, ``end`` and ``file`` arguments are not supported. |
| * - ``property()`` |
| - None |
| - |
| * - ``range()`` |
| - Full |
| - |
| * - ``repr()`` |
| - None |
| - |
| * - ``reversed()`` |
| - None |
| - |
| * - ``round()`` |
| - Partial |
| - ``ndigits`` argument is not supported. |
| * - ``set()`` |
| - None |
| - |
| * - ``setattr()`` |
| - None |
| - |
| * - ``slice()`` |
| - Full |
| - |
| * - ``sorted()`` |
| - Partial |
| - ``key`` argument is not supported. |
| * - ``staticmethod()`` |
| - Full |
| - |
| * - ``str()`` |
| - Partial |
| - ``encoding`` and ``errors`` arguments are not supported. |
| * - ``sum()`` |
| - Full |
| - |
| * - ``super()`` |
| - Partial |
| - It can only be used in ``nn.Module``'s ``__init__`` method. |
| * - ``type()`` |
| - None |
| - |
| * - ``vars()`` |
| - None |
| - |
| * - ``zip()`` |
| - Full |
| - |
| * - ``__import__()`` |
| - None |
| - |
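Here is a small sketch that uses several fully supported built-ins on a ``List[int]``. ``stats`` is a hypothetical name.

```python
import torch
from typing import List, Tuple


@torch.jit.script
def stats(xs: List[int]) -> Tuple[List[int], int, int, Tuple[int, int]]:
    total = sum(xs)
    # sorted() returns a new list; divmod() returns (quotient, remainder).
    return sorted(xs), len(xs), total, divmod(total, len(xs))
```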
| |
| .. _python-builtin-values-support: |
| |
| Python Built-in Values Support |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| .. list-table:: TorchScript Support for Python Built-in Values |
| :widths: 25 25 50 |
| :header-rows: 1 |
| |
| * - Built-in Value |
| - Support Level |
| - Notes |
| * - ``False`` |
| - Full |
| - |
| * - ``True`` |
| - Full |
| - |
| * - ``None`` |
| - Full |
| - |
| * - ``NotImplemented`` |
| - None |
| - |
| * - ``Ellipsis`` |
| - Full |
| - |
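Here is a sketch of ``Ellipsis`` (written ``...``) in ``Tensor`` indexing. ``last_column`` is a hypothetical name.

```python
import torch


@torch.jit.script
def last_column(t: torch.Tensor) -> torch.Tensor:
    # `...` stands for all leading dimensions.
    return t[..., -1]
```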
| |
| |
| .. _torch_apis_in_torchscript: |
| |
| torch.* APIs |
| ~~~~~~~~~~~~ |
| |
| .. _torch_apis_in_torchscript_rpc: |
| |
| Remote Procedure Calls |
| ^^^^^^^^^^^^^^^^^^^^^^ |
| |
| TorchScript supports a subset of RPC APIs that support running a function on |
| a specified remote worker instead of locally. |
| |
| Specifically, the following APIs are fully supported: |
| |
| - ``torch.distributed.rpc.rpc_sync()`` |
| - ``rpc_sync()`` makes a blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel to execution of Python code. |
| - More details about its usage and examples can be found in :meth:`~torch.distributed.rpc.rpc_sync`. |
| |
| - ``torch.distributed.rpc.rpc_async()`` |
| - ``rpc_async()`` makes a non-blocking RPC call to run a function on a remote worker. RPC messages are sent and received in parallel to execution of Python code. |
| - More details about its usage and examples can be found in :meth:`~torch.distributed.rpc.rpc_async`. |
| - ``torch.distributed.rpc.remote()`` |
|   - ``remote()`` executes a remote call on a worker and gets a Remote Reference ``RRef`` as the return value. |
| - More details about its usage and examples can be found in :meth:`~torch.distributed.rpc.remote`. |
| |
| .. _torch_apis_in_torchscript_async: |
| |
| Asynchronous Execution |
| ^^^^^^^^^^^^^^^^^^^^^^ |
| |
| TorchScript enables you to create asynchronous computation tasks to make better use |
| of computation resources. This is done through a set of APIs that are |
| only usable within TorchScript: |
| |
| - ``torch.jit.fork()`` |
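| - ``fork(func, *args, **kwargs)`` creates an asynchronous task executing ``func`` and returns a reference to the value of its result, typed as ``torch.jit.Future[T]`` when ``func`` returns ``T``. ``fork`` returns immediately, so the result may not have been computed yet. |
| - Synonymous with ``torch.jit._fork()``, which is kept only for backward compatibility. |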
| - More details about its usage and examples can be found in :meth:`~torch.jit.fork`. |
| - ``torch.jit.wait()`` |
| - Forces completion of a ``torch.jit.Future[T]`` asynchronous task, returning the result of the task. |
| - Synonymous with ``torch.jit._wait()``, which is kept only for backward compatibility. |
| - More details about its usage and examples can be found in :meth:`~torch.jit.wait`. |
| |
| |
| .. _torch_apis_in_torchscript_annotation: |
| |
| Type Annotations |
| ^^^^^^^^^^^^^^^^ |
| |
| TorchScript is statically typed. It provides a set of utilities to help annotate variables and attributes: |
| |
| - ``torch.jit.annotate()`` |
| - Provides a type hint to TorchScript where Python 3 style type hints do not work well. |
| - A common example is annotating the type of an empty list literal ``[]``, which TorchScript treats as ``List[torch.Tensor]`` by default. When a different type is needed, you can hint TorchScript with ``torch.jit.annotate(List[int], [])``. |
| - More details can be found in :meth:`~torch.jit.annotate`. |
| - ``torch.jit.Attribute`` |
| - A common use case is providing type hints for ``torch.nn.Module`` attributes. Because a module's ``__init__`` method is not parsed by TorchScript, use ``torch.jit.Attribute`` instead of ``torch.jit.annotate`` inside ``__init__``. |
| - More details can be found in :meth:`~torch.jit.Attribute`. |
| - ``torch.jit.Final`` |
| - An alias for Python's ``typing.Final``, kept only for backward compatibility. |
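| |
| The sketch below shows two of these utilities: ``torch.jit.annotate`` gives an empty list a non-default element type, and a ``Final`` class annotation marks a module attribute as a constant. The names ``word_lengths`` and ``Scale`` are hypothetical. |
| |
```python
from typing import Final, List

import torch


@torch.jit.script
def word_lengths(words: List[str]) -> List[int]:
    # Without the hint, TorchScript would type [] as List[torch.Tensor].
    lengths = torch.jit.annotate(List[int], [])
    for w in words:
        lengths.append(len(w))
    return lengths


class Scale(torch.nn.Module):
    # typing.Final marks a constant attribute; torch.jit.Final is an alias of it.
    factor: Final[int]

    def __init__(self, factor: int) -> None:
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.factor


scripted_scale = torch.jit.script(Scale(3))
```
| |
| Inside a function body, the annotated assignment ``lengths: List[int] = []`` is an equivalent way to express the same hint. |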
| |
| |
| .. _torch_apis_in_torchscript_meta_programming: |
| |
| Meta Programming |
| ^^^^^^^^^^^^^^^^ |
| |
| TorchScript provides a set of utilities to facilitate meta programming: |
| |
| - ``torch.jit.is_scripting()`` |
| - Returns a boolean value indicating whether the current program is compiled by ``torch.jit.script`` or not. |
| - When used in an ``assert`` or an ``if`` statement, the scope or branch where ``torch.jit.is_scripting()`` evaluates to ``False`` is not compiled. |
| - Because its value can be evaluated statically at compile time, it is commonly used in ``if`` statements to stop TorchScript from compiling one of the branches. |
| - More details and examples can be found in :meth:`~torch.jit.is_scripting`. |
| - ``torch.jit.is_tracing()`` |
| - Returns a boolean value indicating whether the current program is traced by ``torch.jit.trace`` / ``torch.jit.trace_module`` or not. |
| - More details can be found in :meth:`~torch.jit.is_tracing`. |
| - ``@torch.jit.ignore`` |
| - This decorator indicates to the compiler that a function or method should be ignored and left as a Python function. |
| - This allows you to leave code in your model that is not yet TorchScript compatible. |
| - If a function decorated with ``@torch.jit.ignore`` is called from TorchScript, the call is dispatched to the Python interpreter. |
| - Models with ignored functions cannot be exported. |
| - More details and examples can be found in :meth:`~torch.jit.ignore`. |
| - ``@torch.jit.unused`` |
| - This decorator indicates to the compiler that a function or method should be ignored and replaced with the raising of an exception. |
| - This allows you to leave code in your model that is not yet TorchScript compatible and still export your model. |
| - If a function decorated with ``@torch.jit.unused`` is called from TorchScript, a runtime error is raised. |
| - More details and examples can be found in :meth:`~torch.jit.unused`. |
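| |
| A minimal sketch combining these utilities follows. ``torch.jit.is_scripting()`` selects a TorchScript-compatible branch, so the Python-only branch is never compiled. ``@torch.jit.unused`` marks the Python-only helper, and ``@torch.jit.ignore`` leaves a module method in Python so that scripted code calls back into the interpreter. All function, class, and method names here are hypothetical. |
| |
```python
import torch


@torch.jit.unused
def count_distinct_py(x: torch.Tensor) -> int:
    # Sets are not supported in TorchScript; @unused keeps this body uncompiled.
    return len(set(x.tolist()))


def count_distinct(x: torch.Tensor) -> int:
    if torch.jit.is_scripting():
        # Only this branch is compiled when the function is scripted.
        return torch.unique(x).numel()
    else:
        # Skipped by the compiler; runs only in eager mode.
        return count_distinct_py(x)


scripted_count = torch.jit.script(count_distinct)


class Normalizer(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The ignored method is called through the Python interpreter at runtime.
        return x / self.python_norm(x)

    @torch.jit.ignore
    def python_norm(self, x: torch.Tensor) -> float:
        # Left as Python: generator expressions are fine here because this is not compiled.
        return float(sum(v * v for v in x.tolist()) ** 0.5)


scripted_norm = torch.jit.script(Normalizer())
```
| |
| Because ``Normalizer`` contains an ignored method, the scripted module can run in the current process but cannot be exported; replacing ``@torch.jit.ignore`` with ``@torch.jit.unused`` would allow export at the cost of raising an error if the method is called. |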
| |
| .. _torch_apis_in_torchscript_type_refinement: |
| |
| Type Refinement |
| ^^^^^^^^^^^^^^^ |
| |
| - ``torch.jit.isinstance()`` |
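| - Returns a boolean indicating whether a variable is of the specified type. Inside a branch guarded by this check, TorchScript refines the variable to that type; this also works for parameterized containers such as ``List[str]`` or ``Dict[str, List[torch.Tensor]]``. |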
| - More details about its usage and examples can be found in :meth:`~torch.jit.isinstance`. |
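| |
| As a minimal sketch (``count_items`` is a hypothetical function), a parameter typed ``Any`` is refined to a concrete container type in each branch, so container-specific operations such as iteration or ``len`` type-check inside that branch. |
| |
```python
from typing import Any, Dict, List

import torch


@torch.jit.script
def count_items(x: Any) -> int:
    if torch.jit.isinstance(x, List[torch.Tensor]):
        # Here x is refined to List[Tensor], so it can be iterated as tensors.
        total = 0
        for t in x:
            total += t.numel()
        return total
    elif torch.jit.isinstance(x, Dict[str, str]):
        # Here x is refined to Dict[str, str].
        return len(x)
    # Any other input type falls through to a sentinel value.
    return -1
```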