| :orphan: |
| |
| .. _distributed-autograd-design: |
| |
| Distributed Autograd Design |
| =========================== |
| |
This note presents the detailed design for distributed autograd and walks
through its internals. Make sure you're familiar with
| :ref:`autograd-mechanics` and the :ref:`distributed-rpc-framework` before |
| proceeding. |
| |
| Background |
| ^^^^^^^^^^ |
| |
Let's say you have two nodes and a very simple model partitioned across
them. This can be implemented using :mod:`torch.distributed.rpc` as follows:
| |
| .. code:: |
| |
| import torch |
| import torch.distributed.rpc as rpc |
| |
| def my_add(t1, t2): |
| return torch.add(t1, t2) |
| |
| # On worker 0: |
| t1 = torch.rand((3, 3), requires_grad=True) |
| t2 = torch.rand((3, 3), requires_grad=True) |
| |
| # Perform some computation remotely. |
| t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2)) |
| |
| # Perform some computation locally based on remote result. |
| t4 = torch.rand((3, 3), requires_grad=True) |
| t5 = torch.mul(t3, t4) |
| |
| # Compute some loss. |
| loss = t5.sum() |
| |
The main motivation behind distributed autograd is to enable running a backward
pass on such distributed models using the ``loss`` that we've computed, and to
record appropriate gradients for all tensors that require gradients.
| |
.. _attaching_send_recv_functions:
| |
| Autograd recording during the forward pass |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| PyTorch builds the autograd graph during the forward pass and this graph is |
| used to execute the backward pass. For more details see |
| :ref:`how-autograd-encodes-history`. |
| |
| For distributed autograd, we need to keep track of all RPCs during the forward |
| pass to ensure the backward pass is executed appropriately. For this purpose, |
| we attach ``send`` and ``recv`` functions to the autograd graph when we perform |
| an RPC. |
| |
| - The ``send`` function is attached to the source of the RPC and its output |
| edges point to the autograd function for the input tensors of the RPC. |
| The input for this function during the backward pass is received from the |
| destination as the output of the appropriate ``recv`` function. |
- The ``recv`` function is attached to the destination of the RPC and its
  inputs are retrieved from operators executed on the destination using the
  input tensors. During the backward pass, the output gradients of this
  function are sent to the appropriate ``send`` function on the source node.
| - Each ``send-recv`` pair is assigned a globally unique ``autograd_message_id`` |
| to uniquely identify the pair. This is useful to look up the corresponding |
| function on a remote node during the backward pass. |
| - For :ref:`rref`, whenever we call :meth:`torch.distributed.rpc.RRef.to_here` |
| we attach an appropriate ``send-recv`` pair for the tensors involved. |
| |
| As an example, this is what the autograd graph for our example above would look |
| like (t5.sum() excluded for simplicity): |
| |
| .. image:: ../_static/img/distributed_autograd/send_recv_functions.png |
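
To make the bookkeeping above concrete, the following is a conceptual Python
sketch of how a ``send-recv`` pair could be recorded for a single RPC. The real
recording happens inside the RPC layer in C++; the names ``SendStub``,
``RecvStub`` and ``record_rpc_tensors`` are purely illustrative and are not
part of the PyTorch API.

.. code::

    import itertools

    # Per-worker counter. In this sketch we combine it with the worker id to
    # form a globally unique autograd_message_id; the real id scheme is an
    # internal implementation detail.
    _local_ids = itertools.count()

    class SendStub:
        """Illustrative stand-in for the autograd function recorded on the
        worker that sends tensors as part of an RPC."""
        def __init__(self, message_id, tensors):
            self.message_id = message_id
            # During the backward pass, gradients for these tensors arrive
            # over RPC as the output of the matching recv function.
            self.tensors = tensors

    class RecvStub:
        """Illustrative stand-in for the autograd function recorded on the
        worker that receives those tensors."""
        def __init__(self, message_id, send_worker):
            self.message_id = message_id    # same id as the matching send
            self.send_worker = send_worker  # worker holding the matching send

    def record_rpc_tensors(worker_id, send_worker, tensors):
        # The pair shares one autograd_message_id so the backward pass can
        # match the two halves up across workers.
        message_id = (worker_id, next(_local_ids))
        send_fn = SendStub(message_id, tensors)      # lives on the sender
        recv_fn = RecvStub(message_id, send_worker)  # lives on the receiver
        return send_fn, recv_fn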
| |
.. _autograd_context:
| |
| Distributed Autograd Context |
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| Each forward and backward pass that uses distributed autograd is assigned a |
| unique :class:`torch.distributed.autograd.context` and this context has a |
| globally unique ``autograd_context_id``. This context is created on each node |
| as needed. |
| |
This context serves the following purposes:
| |
| 1. Multiple nodes running distributed backward passes might accumulate |
| gradients on the same tensor and as a result the ``.grad`` field of the |
| tensor would have gradients from a variety of distributed backward passes |
| before we have the opportunity to run the optimizer. This is similar to |
| calling :meth:`torch.autograd.backward` multiple times locally. In order to |
| provide a way of separating out the gradients for each backward pass, the |
| gradients are accumulated in the :class:`torch.distributed.autograd.context` |
| for each backward pass. |
| 2. During the forward pass we store the ``send`` and ``recv`` functions for |
| each autograd pass in this context. This ensures we hold references to the |
| appropriate nodes in the autograd graph to keep it alive. In addition to |
| this, it is easy to look up the appropriate ``send`` and ``recv`` functions |
| during the backward pass. |
| 3. In general we also use this context to store some metadata for each |
| distributed autograd pass. |
| |
| | |
| |
From the user's perspective, the autograd context is set up as follows:
| |
| .. code:: |
| |
    import torch.distributed.autograd as dist_autograd
    with dist_autograd.context() as context_id:
        loss = model.forward()
        dist_autograd.backward(context_id, [loss])
| |
| It is important to note that your model's forward pass must be invoked within |
| the distributed autograd context manager, as a valid context is needed in |
| order to ensure that all ``send`` and ``recv`` functions are stored properly |
| to run the backward pass across all participating nodes. |
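
For example, gradients produced by a distributed backward pass live in the
context rather than on the tensors' ``.grad`` fields. The snippet below is a
minimal sketch that assumes RPC has already been initialized on this worker
and that a peer named ``"worker1"`` is part of the group:

.. code::

    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc

    # Assumes rpc.init_rpc(...) has been called and "worker1" exists.
    with dist_autograd.context() as context_id:
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        loss = rpc.rpc_sync("worker1", torch.add, args=(t1, t2)).sum()

        dist_autograd.backward(context_id, [loss])

        # Gradients are accumulated in this context, keyed by tensor ...
        grads = dist_autograd.get_gradients(context_id)
        print(grads[t1])  # d(loss)/d(t1): a 3x3 tensor of ones
        # ... while the usual .grad field is left untouched.
        print(t1.grad)    # None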
| |
| Distributed Backward Pass |
| ^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
In this section we outline the challenge of computing dependencies accurately
during a distributed backward pass and describe two algorithms (with
tradeoffs) for executing a distributed backward pass.
| |
| Computing dependencies |
| ---------------------- |
| |
Consider the following piece of code being run on a single machine:
| |
| .. code:: |
| |
    import torch
    a = torch.rand((3, 3), requires_grad=True)
    b = torch.rand((3, 3), requires_grad=True)
    c = torch.rand((3, 3), requires_grad=True)
    d = a + b
    e = b * c
    d.sum().backward()
| |
| This is what the autograd graph for the code above would look like: |
| |
| .. image:: ../_static/img/distributed_autograd/local_dependencies.png |
| :scale: 80% |
| |
| The first step the autograd engine performs as part of the backward pass is |
| computing the number of dependencies for each node in the autograd graph. This |
| helps the autograd engine know when a node in the graph is ready for execution. |
| The numbers in brackets for ``add(1)`` and ``mul(0)`` denote the number of |
| dependencies. As you can see, this means during the backward pass the ``add`` |
| node needs 1 input and the ``mul`` node doesn't need any inputs (in other |
| words doesn't need to be executed). The local autograd engine computes these |
| dependencies by traversing the graph from the root nodes (``d`` in this case). |
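
You can observe this locally: re-running the example above and inspecting the
``.grad`` fields confirms that the ``mul`` branch is never executed during the
backward pass, so ``c`` receives no gradient.

.. code::

    import torch
    a = torch.rand((3, 3), requires_grad=True)
    b = torch.rand((3, 3), requires_grad=True)
    c = torch.rand((3, 3), requires_grad=True)
    d = a + b
    e = b * c
    d.sum().backward()

    print(a.grad)  # all ones: the add node was executed
    print(b.grad)  # all ones: the gradient flows through add only
    print(c.grad)  # None: the mul node was never executed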
| |
| The fact that certain nodes in the autograd graph might not be executed in the |
| backward pass poses a challenge for distributed autograd. Consider this piece |
| of code which uses RPC. |
| |
| .. code:: |
| |
| import torch |
| import torch.distributed.rpc as rpc |
| |
| a = torch.rand((3, 3), requires_grad=True) |
| b = torch.rand((3, 3), requires_grad=True) |
| c = torch.rand((3, 3), requires_grad=True) |
| |
| d = rpc.rpc_sync("worker1", torch.add, args=(a, b)) |
| e = rpc.rpc_sync("worker1", torch.mul, args=(b, c)) |
| loss = d.sum() |
| |
| The associated autograd graph for the code above would be: |
| |
| .. image:: ../_static/img/distributed_autograd/distributed_dependencies.png |
| |
| Computing dependencies of this distributed autograd graph is much more |
| challenging and requires some overhead (either in terms of computation or |
| network communication). |
| |
For performance sensitive applications we can avoid a
lot of overhead by assuming every ``send`` and ``recv`` function is valid as
part of the backward pass (most applications don't perform RPCs that aren't
used). This simplifies the distributed autograd algorithm and is much more
efficient, at the cost that the application needs to be aware of the
limitations. This algorithm is called the `FAST mode algorithm`_ and is
described in detail below.
| |
| In the general case it might not be necessary that every ``send`` and ``recv`` |
| function is valid as part of the backward pass. To address this, we have |
| proposed a `SMART mode algorithm`_ which is described in a later section. |
| Please note that currently, only the `FAST` mode algorithm is implemented. |
| |
| .. _fast-mode-algorithm: |
| |
| FAST mode algorithm |
| ------------------- |
| |
| The key assumption of this algorithm is that each ``send`` function has a |
| dependency of 1 when we run a backward pass. In other words, we assume we'll |
| receive a gradient over RPC from another node. |
| |
| The algorithm is as follows: |
| |
| 1. We start from the worker which has the roots for the backward pass |
| (all roots must be local). |
2. Look up all the ``send`` functions for the current
| `Distributed Autograd Context`_. |
| 3. Compute dependencies locally starting from the provided roots and all the |
| ``send`` functions we retrieved. |
| 4. After computing dependencies, kick off the local autograd engine with the |
| provided roots. |
| 5. When the autograd engine executes the ``recv`` function, the ``recv`` |
| function sends the input gradients via RPC to the appropriate worker. |
| Each ``recv`` function knows the destination worker id since it is recorded |
| as part of the forward pass. The ``recv`` function also sends over the |
| ``autograd_context_id`` and ``autograd_message_id`` to the remote host. |
| 6. When this request is received on the remote host, we use the |
| ``autograd_context_id`` and ``autograd_message_id`` to look up the |
| appropriate ``send`` function. |
| 7. If this is the first time a worker has received a request for the given |
| ``autograd_context_id``, it will compute dependencies locally as described |
| in points 1-3 above. |
8. The ``send`` function retrieved in step 6 is then enqueued for execution on
   the local autograd engine for that worker.
| 9. Finally, instead of accumulating the gradients on the ``.grad`` field of the |
| Tensor, we accumulate the gradients separately per |
| `Distributed Autograd Context`_. The gradients are stored in a |
| ``Dict[Tensor, Tensor]``, which is basically a map from Tensor to its |
| associated gradient and this map can be retrieved using the |
| :meth:`~torch.distributed.autograd.get_gradients` API. |
| |
| | |
| |
As an example, the complete code with distributed autograd would be as follows:
| |
| .. code:: |
| |
| import torch |
| import torch.distributed.autograd as dist_autograd |
| import torch.distributed.rpc as rpc |
| |
| def my_add(t1, t2): |
| return torch.add(t1, t2) |
| |
| # On worker 0: |
| |
| # Setup the autograd context. Computations that take |
| # part in the distributed backward pass must be within |
| # the distributed autograd context manager. |
| with dist_autograd.context() as context_id: |
| t1 = torch.rand((3, 3), requires_grad=True) |
| t2 = torch.rand((3, 3), requires_grad=True) |
| |
| # Perform some computation remotely. |
| t3 = rpc.rpc_sync("worker1", my_add, args=(t1, t2)) |
| |
| # Perform some computation locally based on remote result. |
| t4 = torch.rand((3, 3), requires_grad=True) |
| t5 = torch.mul(t3, t4) |
| |
| # Compute some loss. |
| loss = t5.sum() |
| |
| # Run the backward pass. |
| dist_autograd.backward(context_id, [loss]) |
| |
| # Retrieve the gradients from the context. |
| dist_autograd.get_gradients(context_id) |
| |
| The distributed autograd graph with dependencies would be as follows (t5.sum() excluded for simplicity): |
| |
| .. image:: ../_static/img/distributed_autograd/distributed_dependencies_computed.png |
| |
| The `FAST mode algorithm`_ applied to the above example would be as follows: |
| |
| 1. On ``Worker 0`` we start from the roots ``loss`` and ``send1`` to compute |
| dependencies. As a result ``send1`` is marked with a dependency of 1 and ``mul`` |
| on ``Worker 0`` is marked with a dependency of 1. |
2. Now, we kick off the local autograd engine on ``Worker 0``. We first execute
   the ``mul`` function and accumulate its output in the autograd context as the
   gradient for ``t4``. Then, we execute ``recv2``, which sends the gradients to
   ``Worker 1``.
| 3. Since this is the first time ``Worker 1`` has heard about this backward pass, |
| it starts dependency computation and marks the dependencies for ``send2``, |
| ``add`` and ``recv1`` appropriately. |
| 4. Next, we enqueue ``send2`` on the local autograd engine of ``Worker 1``, which |
| in turn executes ``add`` and ``recv1``. |
| 5. When ``recv1`` is executed it sends the gradients over to ``Worker 0``. |
| 6. Since ``Worker 0`` has already computed dependencies for this backward pass, |
| it just enqueues and executes ``send1`` locally. |
| 7. Finally, gradients for ``t1``, ``t2`` and ``t4`` are accumulated in the |
| `Distributed Autograd Context`_. |
| |
| SMART mode algorithm |
| -------------------- |
Full details of this algorithm are still in the works, but for the general idea
you can refer to the **Distributed Autograd Algorithm Smart mode** section in
the `RFC`_.
| |
| Distributed Optimizer |
| ^^^^^^^^^^^^^^^^^^^^^ |
| |
| The :class:`~torch.distributed.optim.DistributedOptimizer` operates as follows: |
| |
| 1. Takes a list of remote parameters (:class:`~torch.distributed.rpc.RRef`) to |
| optimize. These could also be local parameters wrapped within a local |
| ``RRef``. |
| 2. Takes a :class:`~torch.optim.Optimizer` class as the local |
| optimizer to run on all distinct ``RRef`` owners. |
| 3. The distributed optimizer creates an instance of the local ``Optimizer`` on |
| each of the worker nodes and holds an ``RRef`` to them. |
| 4. When :meth:`torch.distributed.optim.DistributedOptimizer.step` is invoked, |
| the distributed optimizer uses RPC to remotely execute all the local |
| optimizers on the appropriate remote workers. A distributed autograd |
| ``context_id`` must be provided as input to |
| :meth:`torch.distributed.optim.DistributedOptimizer.step`. This is used |
| by local optimizers to apply gradients stored in the corresponding |
| context. |
| 5. If multiple concurrent distributed optimizers are updating the same |
| parameters on a worker, these updates are serialized via a lock. |
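
Condensed to just the optimizer interaction, a typical training step looks like
the sketch below. The ``train_step`` helper and its arguments are illustrative:
``param_rrefs`` is assumed to be a list of ``RRef`` objects pointing to the
parameters being optimized and ``compute_loss`` any callable that runs the
forward pass.

.. code::

    import torch.distributed.autograd as dist_autograd
    from torch import optim
    from torch.distributed.optim import DistributedOptimizer

    def train_step(param_rrefs, compute_loss, lr=0.05):
        # One local optim.SGD instance is created on every distinct RRef owner.
        dist_optim = DistributedOptimizer(optim.SGD, param_rrefs, lr=lr)

        with dist_autograd.context() as context_id:
            loss = compute_loss()
            # Gradients are accumulated per context on each worker involved.
            dist_autograd.backward(context_id, [loss])
            # step() runs each local optimizer on its owner worker, applying
            # the gradients stored under this context_id there.
            dist_optim.step(context_id)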
| |
| Simple end to end example |
| ^^^^^^^^^^^^^^^^^^^^^^^^^ |
| |
| Putting it all together, the following is a simple end to end example using |
| distributed autograd and the distributed optimizer. If the code is placed into a |
| file called "dist_autograd_simple.py", it can be run with the command |
| :code:`MASTER_ADDR="localhost" MASTER_PORT=29500 python dist_autograd_simple.py`: |
| |
| .. code:: |
| |
| import torch |
| import torch.multiprocessing as mp |
| import torch.distributed.autograd as dist_autograd |
| from torch.distributed import rpc |
| from torch import optim |
| from torch.distributed.optim import DistributedOptimizer |
| |
| def random_tensor(): |
| return torch.rand((3, 3), requires_grad=True) |
| |
| def _run_process(rank, dst_rank, world_size): |
| name = "worker{}".format(rank) |
| dst_name = "worker{}".format(dst_rank) |
| |
| # Initialize RPC. |
| rpc.init_rpc( |
| name=name, |
| rank=rank, |
| world_size=world_size |
| ) |
| |
| # Use a distributed autograd context. |
| with dist_autograd.context() as context_id: |
| # Forward pass (create references on remote nodes). |
| rref1 = rpc.remote(dst_name, random_tensor) |
| rref2 = rpc.remote(dst_name, random_tensor) |
| loss = rref1.to_here() + rref2.to_here() |
| |
| # Backward pass (run distributed autograd). |
| dist_autograd.backward(context_id, [loss.sum()]) |
| |
| # Build DistributedOptimizer. |
| dist_optim = DistributedOptimizer( |
| optim.SGD, |
| [rref1, rref2], |
| lr=0.05, |
| ) |
| |
| # Run the distributed optimizer step. |
| dist_optim.step(context_id) |
| |
| def run_process(rank, world_size): |
| dst_rank = (rank + 1) % world_size |
| _run_process(rank, dst_rank, world_size) |
| rpc.shutdown() |
| |
| if __name__ == '__main__': |
| # Run world_size workers |
| world_size = 2 |
| mp.spawn(run_process, args=(world_size,), nprocs=world_size) |
| |
| .. _RFC: https://github.com/pytorch/pytorch/issues/23110 |