| .. _export.ir_spec: |
| |
| torch.export IR Specification |
| ============================= |
| |
| Export IR is an intermediate representation (IR) for compilers, which bears |
| similarities to MLIR and TorchScript. It is specifically designed to express the |
| semantics of PyTorch programs. Export IR primarily represents computation in a |
| streamlined list of operations, with limited support for dynamism such as |
| control flow. |
| |
| To create an Export IR graph, a frontend can be used that soundly captures a |
| PyTorch program via a trace-specializing mechanism. The resulting Export IR can |
| then be optimized and executed by a backend. This can be done today through |
| :func:`torch.export.export`. |
| |
| The key concepts that will be covered in this document include: |
| |
| - ExportedProgram: the data structure containing the Export IR program. |
| - Graph: consists of a list of nodes. |
| - Nodes: represent operations, control flow, and the metadata stored on each node. |
| - Values: produced and consumed by nodes. |
| - Types: associated with values and nodes. |
| - Size and memory layout: also defined for values. |
| |
| Assumptions |
| ------------ |
| |
| This doc assumes that the audience is sufficiently familiar with PyTorch, |
| specifically with :mod:`torch.fx` and its related tooling. It therefore does not |
| repeat content already covered in the :mod:`torch.fx` documentation and paper. |
| |
| What is Export IR |
| ----------------- |
| |
| Export IR is a graph-based intermediate representation (IR) of PyTorch programs. |
| Export IR is realized on top of :class:`torch.fx.Graph`. In other words, **all |
| Export IR graphs are also valid FX graphs**, and if interpreted using standard |
| FX semantics, Export IR can be interpreted soundly. One implication is that an |
| exported graph can be converted to a valid Python program via standard FX |
| codegen. |
| |
| This documentation will primarily focus on highlighting areas where Export IR |
| differs from FX in terms of its strictness, while skipping parts where it shares |
| similarities with FX. |
| |
| ExportedProgram |
| --------------- |
| |
| The top-level Export IR construct is an :class:`torch.export.ExportedProgram` |
| class. It bundles the computational graph of a PyTorch model (which is usually a |
| :class:`torch.nn.Module`) with the parameters or weights that this model |
| consumes. |
| |
| Some notable attributes of the :class:`torch.export.ExportedProgram` class are: |
| |
| - ``graph_module`` (:class:`torch.fx.GraphModule`): Data structure containing |
| the flattened computational graph of the PyTorch model. The graph can be |
| directly accessed through ``ExportedProgram.graph``. |
| - ``graph_signature`` (:class:`torch.export.ExportGraphSignature`): The graph |
| signature, which specifies the parameters and buffer names used and mutated |
| within the graph. Instead of storing parameters and buffers as attributes of |
| the graph, they are lifted as inputs to the graph. The ``graph_signature`` is |
| utilized to keep track of additional information on these parameters and |
| buffers. |
| - ``state_dict`` (``Dict[str, Union[torch.Tensor, torch.nn.Parameter]]``): Data |
| structure containing the parameters and buffers. |
| - ``range_constraints`` (``Dict[sympy.Symbol, RangeConstraint]``): For programs |
| that are exported with data-dependent behavior, the metadata on each node will |
| contain symbolic shapes (which look like ``s0``, ``i0``). This attribute maps |
| the symbolic shapes to their lower/upper ranges. |
| |
| Graph |
| ----- |
| |
| An Export IR Graph is a PyTorch program represented in the form of a DAG |
| (directed acyclic graph). Each node in this graph represents a particular |
| computation or operation, and edges of this graph consist of references between |
| nodes. |
| |
| A Graph can be viewed as having this schema: |
| |
| .. code-block:: python |
| |
|     class Graph: |
|         nodes: List[Node] |
| |
| In practice, Export IR's graph is realized as the :class:`torch.fx.Graph` Python class. |
| |
| An Export IR graph contains the following nodes (Nodes will be described in more |
| detail in the next section): |
| |
| - 0 or more nodes of op type ``placeholder`` |
| - 0 or more nodes of op type ``call_function`` |
| - exactly 1 node of op type ``output`` |
| |
| **Corollary:** The smallest valid Graph consists of a single node, i.e. ``nodes`` is never empty. |
| |
| **Definition:** |
| The set of ``placeholder`` nodes of a Graph represents the **inputs** of the |
| Graph of the GraphModule. The ``output`` node of a Graph represents the **outputs** |
| of the Graph of the GraphModule. |
| |
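| Example:: |
| |
|     import torch |
|     from torch import nn |
| |
|     class MyModule(nn.Module): |
| |
|         def forward(self, x, y): |
|             return x + y |
| |
|     # export() needs example inputs to trace the module with |
|     example_args = (torch.randn(1), torch.randn(1)) |
|     mod = torch.export.export(MyModule(), example_args) |
|     print(mod.graph) |
| |
| Printing ``mod.graph`` gives the textual representation of a Graph, with each line being a node. |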
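The structural rules above can be sketched as a small check. This is only an illustration: ``check_graph_shape`` is a hypothetical helper, not part of PyTorch, and the graph is built by hand with the FX ``Graph`` API rather than through export so that the example stays small. ``get_attr`` is included in the allowed set because it is described later in this document.

```python
import torch
import torch.fx

# Node op types that this document describes for Export IR graphs.
ALLOWED_OPS = {"placeholder", "call_function", "get_attr", "output"}


def check_graph_shape(graph: torch.fx.Graph) -> None:
    """Hypothetical helper: assert the structural rules described above."""
    nodes = list(graph.nodes)
    assert len(nodes) >= 1, "a Graph always contains at least the output node"
    assert all(n.op in ALLOWED_OPS for n in nodes), "unexpected node op type"
    outputs = [n for n in nodes if n.op == "output"]
    assert len(outputs) == 1, "exactly one output node is expected"


# Build the graph for x + y by hand.
graph = torch.fx.Graph()
x = graph.placeholder("x")
y = graph.placeholder("y")
add = graph.call_function(torch.ops.aten.add.Tensor, (x, y))
graph.output((add,))

check_graph_shape(graph)
print(graph)
```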
| |
| Node |
| ---- |
| |
| A Node represents a particular computation or operation and is represented in |
| Python using the :class:`torch.fx.Node` class. Edges between nodes are |
| represented as direct references to other nodes via the ``args`` property of the |
| Node class. Using the same FX machinery, we can represent the operations that a |
| computational graph typically needs, such as operator calls, placeholders |
| (aka inputs), conditionals, and loops. |
| |
| The Node has the following schema: |
| |
| .. code-block:: python |
| |
|     class Node: |
|         name: str  # name of node |
|         op_name: str  # type of operation |
| |
|         # interpretation of the fields below depends on op_name |
|         target: Union[str, Callable] |
|         args: List[object] |
|         kwargs: Dict[str, object] |
|         meta: Dict[str, object] |
| |
| **FX Text Format** |
| |
| As in the example above, notice that each line has this format:: |
| |
|     %<name>:[...] = <op_name>[target=<target>](args = (%arg1, %arg2, arg3, arg4, …), kwargs = {"keyword": arg5}) |
| |
| This format captures everything present in the Node class, with the exception of |
| ``meta``, in a compact format. |
| |
| Concretely: |
| |
| - **<name>** is the name of the node as it would appear in ``node.name``. |
| |
| - **<op_name>** is the ``node.op`` field, which must be one of these: |
| ``call_function``, ``placeholder``, |
| ``get_attr``, or ``output``. |
| |
| - **<target>** is the target of the node as ``node.target``. The meaning of this |
| field depends on ``op_name``. |
| |
| - **arg1, … arg4, …** are what is listed in the ``node.args`` tuple. If a |
| value in the list is a :class:`torch.fx.Node`, then it is |
| indicated with a leading ``%``. |
| |
| For example, a call to the add operator would appear as:: |
| |
|     %add1 = call_function[target = torch.ops.aten.add.Tensor](args = (%x, %y), kwargs = {}) |
| |
| Here ``%x`` and ``%y`` are two other Nodes named ``x`` and ``y``. Note that |
| the string ``torch.ops.aten.add.Tensor`` represents the callable object that |
| is actually stored in the target field, not merely its string name. |
| |
| The final line of this text format is:: |
| |
|     return [add1] |
| |
| which is a Node with ``op_name = output``, indicating that we are returning this |
| one element. |
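To relate the text format back to the ``Node`` fields, the sketch below builds a small graph by hand and prints each field next to the formatted line. The graph and the input names ``x`` and ``y`` are chosen for illustration; the auto-generated name of the ``call_function`` node depends on the FX version.

```python
import torch
import torch.fx

graph = torch.fx.Graph()
x = graph.placeholder("x")
y = graph.placeholder("y")
add = graph.call_function(torch.ops.aten.add.Tensor, (x, y))
graph.output((add,))

for node in graph.nodes:
    # node.op corresponds to <op_name> in the text format.
    print(node.name, node.op, node.target, node.args, node.kwargs)
    # format_node() renders this single node in the FX text format.
    print("   ", node.format_node())
```

The ``args`` of the ``call_function`` node hold the ``x`` and ``y`` Node objects themselves, which is how edges are represented.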
| |
| call_function |
| ^^^^^^^^^^^^^ |
| |
| A ``call_function`` node represents a call to an operator. |
| |
| **Definitions** |
| |
| - **Functional:** We say a callable is “functional” if it satisfies all the |
| following requirements: |
| |
| - Non-mutating: The operator does not mutate the value of its input (for |
| tensors, this includes both metadata and data). |
| - No side effects: The operator does not mutate states that are visible |
| from outside, like changing values of module parameters. |
| |
| - **Operator:** is a functional callable with a predefined schema. Examples of |
| such operators include functional ATen operators. |
| |
| **Representation in FX** |
| |
| .. code-block:: |
| |
|     %name = call_function[target = operator](args = (%x, %y, …), kwargs = {}) |
| |
| |
| **Differences from vanilla FX call_function** |
| |
| 1. In an FX graph, a call_function can refer to any callable. In Export IR, we |
| restrict it to only a select subset of ATen operators, custom operators, and |
| control flow operators. |
| |
| 2. In Export IR, constant arguments will be embedded within the graph. |
| |
| 3. In an FX graph, a get_attr node can represent reading any attribute stored in |
| the graph module. However, in Export IR this is restricted to reading only |
| submodules, as all parameters/buffers will be passed in as inputs to the graph |
| module. |
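One way to get a feel for the "functional" requirement is to look at an ATen overload's schema. The sketch below is an approximation under stated assumptions: ``is_functional_aten_op`` is a hypothetical helper, it relies on the private ``_schema`` attribute of ``OpOverload`` objects, and it only covers ATen-style overloads, not custom or control flow operators.

```python
import torch


def is_functional_aten_op(target) -> bool:
    """Hypothetical check: an operator overload whose schema reports no mutation."""
    return isinstance(target, torch._ops.OpOverload) and not target._schema.is_mutable


# Out-of-place add: does not mutate its inputs.
print(is_functional_aten_op(torch.ops.aten.add.Tensor))
# In-place add_: mutates its first argument, so it is not functional.
print(is_functional_aten_op(torch.ops.aten.add_.Tensor))
# torch.add is a plain Python-level callable, not an operator overload.
print(is_functional_aten_op(torch.add))
```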
| |
| Metadata |
| ~~~~~~~~ |
| |
| ``Node.meta`` is a dict attached to every FX node. However, the FX spec does not |
| specify what metadata can or will be there. Export IR provides a stronger |
| contract: all ``call_function`` nodes are guaranteed to have exactly the |
| following metadata fields: |
| |
| - ``node.meta["stack_trace"]`` is a string containing the Python stack trace |
| referencing the original Python source code. An example stack trace looks |
| like:: |
| |
|     File "my_module.py", line 19, in forward |
|       return x + dummy_helper(y) |
|     File "helper_utility.py", line 89, in dummy_helper |
|       return y + 1 |
| |
| - ``node.meta["val"]`` describes the output of running the operation. It can be |
| of type ``SymInt``, ``FakeTensor``, a |
| ``List[Union[FakeTensor, SymInt]]``, or ``None``. |
| |
| - ``node.meta["nn_module_stack"]`` describes the "stacktrace" of the |
| :class:`torch.nn.Module` from which the node came, if it was from a |
| :class:`torch.nn.Module` call. For example, if a node containing the ``addmm`` |
| op is called from a :class:`torch.nn.Linear` module inside of a |
| :class:`torch.nn.Sequential` module, the ``nn_module_stack`` would look |
| something like:: |
| |
|     {'self_linear': ('self.linear', <class 'torch.nn.Linear'>), 'self_sequential': ('self.sequential', <class 'torch.nn.Sequential'>)} |
| |
| - ``node.meta["source_fn_stack"]`` contains the torch function or the leaf |
| :class:`torch.nn.Module` class this node was called from before decomposition. |
| For example, a node containing the ``addmm`` op from a |
| :class:`torch.nn.Linear` module call would contain :class:`torch.nn.Linear` in |
| its ``source_fn_stack``, and a node containing the ``addmm`` op from a |
| :func:`torch.nn.functional.linear` function call would contain |
| :func:`torch.nn.functional.linear` in its ``source_fn_stack``. |
| |
| placeholder |
| ^^^^^^^^^^^ |
| |
| Placeholder represents an input to a graph. Its semantics are exactly the same as in FX. |
| Placeholder nodes must be the first N nodes in the nodes list of a graph. N can be zero. |
| |
| **Representation in FX** |
| |
| .. code-block:: python |
| |
|     %name = placeholder[target = name](args = ()) |
| |
| The target field is a string which is the name of the input. |
| |
| ``args``, if non-empty, should be of size 1 representing the default value of this input. |
| |
| **Metadata** |
| |
| Placeholder nodes also have ``meta['val']``, like ``call_function`` nodes. The |
| ``val`` field in this case represents the input shape/dtype that the graph is |
| expected to receive for this input parameter. |
| |
| output |
| ^^^^^^ |
| |
| An output node represents a return statement in a function; it thus terminates the |
| current graph. There is one and only one output node, and it will always be the |
| last node of the graph. |
| |
| **Representation in FX** |
| |
| .. code-block:: |
| |
|     output[](args = (%something, …)) |
| |
| This has exactly the same semantics as in :mod:`torch.fx`. ``args`` represents the node |
| to be returned. |
| |
| **Metadata** |
| |
| The output node has the same metadata as ``call_function`` nodes. |
| |
| get_attr |
| ^^^^^^^^ |
| |
| ``get_attr`` nodes represent reading a submodule from the encapsulating |
| :class:`torch.fx.GraphModule`. In a vanilla FX graph from |
| :func:`torch.fx.symbolic_trace`, ``get_attr`` nodes are also used to read |
| attributes such as parameters and buffers from the top-level |
| :class:`torch.fx.GraphModule`. In Export IR, parameters and buffers are instead |
| passed in as inputs to the graph module and stored in the top-level |
| :class:`torch.export.ExportedProgram`. |
| |
| **Representation in FX** |
| |
| .. code-block:: python |
| |
|     %name = get_attr[target = name](args = ()) |
| |
| **Example** |
| |
| Consider the following model:: |
| |
|     from functorch.experimental.control_flow import cond |
| |
|     def true_fn(x): |
|         return x.sin() |
| |
|     def false_fn(x): |
|         return x.cos() |
| |
|     def f(x, y): |
|         return cond(y, true_fn, false_fn, [x]) |
| |
| Graph:: |
| |
|     graph(): |
|         %x_1 : [num_users=1] = placeholder[target=x_1] |
|         %y_1 : [num_users=1] = placeholder[target=y_1] |
|         %true_graph_0 : [num_users=1] = get_attr[target=true_graph_0] |
|         %false_graph_0 : [num_users=1] = get_attr[target=false_graph_0] |
|         %conditional : [num_users=1] = call_function[target=torch.ops.higher_order.cond](args = (%y_1, %true_graph_0, %false_graph_0, [%x_1]), kwargs = {}) |
|         return conditional |
| |
| The line, ``%true_graph_0 : [num_users=1] = get_attr[target=true_graph_0]``, |
| reads the submodule ``true_graph_0`` which contains the ``sin`` operator. |
| |
| References |
| ---------- |
| |
| SymInt |
| ^^^^^^ |
| |
| A SymInt is an object that can either be a literal integer or a symbol that represents |
| an integer (represented in Python by the ``sympy.Symbol`` class). When SymInt is a |
| symbol, it describes a variable of type integer that is unknown to the graph at |
| compile time, that is, its value is only known at runtime. |
| |
| FakeTensor |
| ^^^^^^^^^^ |
| |
| A FakeTensor is an object that contains the metadata of a tensor. It can be |
| viewed as having the following metadata: |
| |
| .. code-block:: python |
| |
|     class FakeTensor: |
|         size: List[SymInt] |
|         dtype: torch.dtype |
|         device: torch.device |
|         dim_order: List[int]  # This doesn't exist yet |
| |
| The size field of FakeTensor is a list of integers or SymInts. If SymInts are |
| present, this means this tensor has a dynamic shape. If integers are present, it |
| is assumed that the tensor will have that exact static shape. The rank of the |
| FakeTensor is never dynamic. The dtype field represents the dtype of the |
| output of that node. There are no implicit type promotions in Export IR. There |
| are no strides in FakeTensor. |
| |
| In other words: |
| |
| - If the operator in node.target returns a Tensor, then ``node.meta['val']`` is a |
| FakeTensor describing that tensor. |
| - If the operator in node.target returns an n-tuple of Tensors, then |
| ``node.meta['val']`` is an n-tuple of FakeTensors describing each tensor. |
| - If the operator in node.target returns an int/float/scalar that is known at |
| compile time, then ``node.meta['val']`` is None. |
| - If the operator in node.target returns an int/float/scalar that is not known |
| at compile time, then ``node.meta['val']`` is of type SymInt. |
| |
| For example: |
| |
| - ``aten::add`` returns a Tensor; so its ``val`` will be a FakeTensor with the dtype |
| and size of the tensor returned by this operator. |
| - ``aten::sym_size`` returns an integer; so its ``val`` will be a SymInt because its |
| value is only available at runtime. |
| - ``max_pool2d_with_indices`` returns a tuple of (Tensor, Tensor); so its ``val`` |
| will also be a 2-tuple of FakeTensor objects, where the first FakeTensor describes |
| the first element of the return value, and so on. |
| |
| Python code:: |
| |
|     import torch |
| |
|     def add_one(x): |
|         return torch.ops.aten.add(x, 1) |
| |
| Graph:: |
| |
|     graph(): |
|         %ph_0 : [#users=1] = placeholder[target=ph_0] |
|         %add_tensor : [#users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%ph_0, 1), kwargs = {}) |
|         return [add_tensor] |
| |
| FakeTensor:: |
| |
|     FakeTensor(dtype=torch.int, size=[2,], device=CPU) |
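A FakeTensor can also be created by hand to see which metadata it carries. The sketch below uses ``FakeTensorMode`` from the ``torch._subclasses.fake_tensor`` module, which is an internal namespace, so treat the import path as an assumption that may change between releases.

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

# Inside the mode, factory functions and operators produce FakeTensors:
# shapes and dtypes are propagated, but no real data is computed.
with FakeTensorMode():
    fake = torch.empty(2, 3, dtype=torch.int32)
    result = fake + 1

print(isinstance(result, FakeTensor))
print(result.shape, result.dtype, result.device)
```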
| |
| Pytree-able Types |
| ^^^^^^^^^^^^^^^^^ |
| |
| We define a type “Pytree-able”, if it is either a leaf type or a container type |
| that contains other Pytree-able types. |
| |
| Note: |
| |
| The concept of pytree is the same as the one documented |
| `here <https://jax.readthedocs.io/en/latest/pytrees.html>`__ for JAX. |
| |
| |
| The following types are defined as **leaf type**: |
| |
| .. list-table:: |
| :widths: 50 50 |
| :header-rows: 1 |
| |
| * - Type |
| - Definition |
| * - Tensor |
| - :class:`torch.Tensor` |
| * - Scalar |
| - Any numerical types from Python, including integral types, floating point types, and zero dimensional tensors. |
| * - int |
| - Python int (bound as int64_t in C++) |
| * - float |
| - Python float (bound as double in C++) |
| * - bool |
| - Python bool |
| * - str |
| - Python string |
| * - ScalarType |
| - :class:`torch.dtype` |
| * - Layout |
| - :class:`torch.layout` |
| * - MemoryFormat |
| - :class:`torch.memory_format` |
| * - Device |
| - :class:`torch.device` |
| |
| The following types are defined as **container type**: |
| |
| .. list-table:: |
| :widths: 50 50 |
| :header-rows: 1 |
| |
| * - Type |
| - Definition |
| * - Tuple |
| - Python tuple |
| * - List |
| - Python list |
| * - Dict |
| - Python dict with Scalar keys |
| * - NamedTuple |
| - Python namedtuple |
| * - Dataclass |
| - Must be registered through `register_dataclass <https://github.com/pytorch/pytorch/blob/901aa85b58e8f490631ce1db44e6555869a31893/torch/export/__init__.py#L693>`__ |
| * - Custom class |
| - Any custom class defined with `_register_pytree_node <https://github.com/pytorch/pytorch/blob/901aa85b58e8f490631ce1db44e6555869a31893/torch/utils/_pytree.py#L72>`__ |
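As a small illustration of leaf and container types, a nested structure can be flattened into its leaves and rebuilt. This sketch uses ``torch.utils._pytree``, an internal module, so the import path is an assumption; the nested value is made up for this example.

```python
import torch.utils._pytree as pytree

# A dict and a list/tuple are containers; ints and strs are leaves.
nested = {0: [1, 2], 1: (3, "s")}

leaves, spec = pytree.tree_flatten(nested)
print(leaves)

# The spec records the container structure, so the value can be rebuilt.
rebuilt = pytree.tree_unflatten(leaves, spec)
print(rebuilt == nested)
```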