Migrating from functorch to torch.func
======================================

torch.func, previously known as "functorch", is
`JAX-like <https://github.com/google/jax>`_ composable function transforms for PyTorch.

functorch started as an out-of-tree library over at
the `pytorch/functorch <https://github.com/pytorch/functorch>`_ repository.
Our goal has always been to upstream functorch directly into PyTorch and provide
it as a core PyTorch library.

As the final step of the upstreaming, we've decided to migrate from being a top-level package
(``functorch``) to being a part of PyTorch to reflect how the function transforms are
integrated directly into PyTorch core. As of PyTorch 2.0, we are deprecating
``import functorch`` and ask that users migrate to the torch.func APIs, which we
will maintain going forward. ``import functorch`` will be kept around to maintain
backwards compatibility for a couple of releases.

function transforms
-------------------

The torch.func APIs listed below are drop-in replacements for the corresponding
`functorch APIs <https://pytorch.org/functorch/1.13/functorch.html>`_
and are fully backwards compatible.


==============================  =======================================
functorch API                   PyTorch API (as of PyTorch 2.0)
==============================  =======================================
functorch.vmap                  :func:`torch.vmap` or :func:`torch.func.vmap`
functorch.grad                  :func:`torch.func.grad`
functorch.vjp                   :func:`torch.func.vjp`
functorch.jvp                   :func:`torch.func.jvp`
functorch.jacrev                :func:`torch.func.jacrev`
functorch.jacfwd                :func:`torch.func.jacfwd`
functorch.hessian               :func:`torch.func.hessian`
functorch.functionalize         :func:`torch.func.functionalize`
==============================  =======================================
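
Because these are drop-in replacements, migrating usually only requires changing the
namespace. Here is a minimal before/after sketch (the toy function ``f`` is only for
illustration)::

    import torch

    def f(x):
        return x.sin().sum()

    x = torch.randn(4, 3)

    # Before: functorch.vmap(f)(x) and functorch.grad(f)(x[0])
    # After (as of PyTorch 2.0): same call signatures, new namespace.
    batched = torch.vmap(f)(x)
    gradient = torch.func.grad(f)(x[0])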

Furthermore, if you are using torch.autograd.functional APIs, please try out
the :mod:`torch.func` equivalents instead. :mod:`torch.func` function
transforms are more composable and more performant in many cases.

===========================================  =======================================
torch.autograd.functional API                torch.func API (as of PyTorch 2.0)
===========================================  =======================================
:func:`torch.autograd.functional.vjp`        :func:`torch.func.grad` or :func:`torch.func.vjp`
:func:`torch.autograd.functional.jvp`        :func:`torch.func.jvp`
:func:`torch.autograd.functional.jacobian`   :func:`torch.func.jacrev` or :func:`torch.func.jacfwd`
:func:`torch.autograd.functional.hessian`    :func:`torch.func.hessian`
===========================================  =======================================
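
As a rough sketch of such a migration (the toy function ``f`` below is only for
illustration), both APIs compute the same quantity::

    import torch
    from torch.func import jacrev

    def f(x):
        return x.sin()

    x = torch.randn(3)

    # torch.autograd.functional style
    jac_old = torch.autograd.functional.jacobian(f, x)

    # torch.func style (as of PyTorch 2.0); jacfwd would also work here
    jac_new = jacrev(f)(x)

    assert torch.allclose(jac_old, jac_new)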

NN module utilities
-------------------

We've changed the APIs to apply function transforms over NN modules to make them
fit better into the PyTorch design philosophy. The new API is different, so
please read this section carefully.

functorch.make_functional
^^^^^^^^^^^^^^^^^^^^^^^^^

:func:`torch.func.functional_call` is the replacement for
`functorch.make_functional <https://pytorch.org/functorch/1.13/generated/functorch.make_functional.html#functorch.make_functional>`_
and
`functorch.make_functional_with_buffers <https://pytorch.org/functorch/1.13/generated/functorch.make_functional_with_buffers.html#functorch.make_functional_with_buffers>`_.
However, it is not a drop-in replacement.

If you're in a hurry, you can use
`helper functions in this gist <https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf>`_
that emulate the behavior of functorch.make_functional and functorch.make_functional_with_buffers.
We recommend using :func:`torch.func.functional_call` directly because it is a more explicit
and flexible API.
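
For reference, a rough, parameters-only sketch of what such a helper can look like
(this is illustrative and not the gist's exact code; ``make_functional_sketch`` is a
hypothetical name)::

    import torch
    from torch.func import functional_call

    def make_functional_sketch(module):
        # Return a functional version of ``module`` plus a tuple of its parameters,
        # mirroring the (fmodel, params) pair returned by functorch.make_functional.
        names = [name for name, _ in module.named_parameters()]
        params = tuple(p.detach() for p in module.parameters())

        def fmodule(params, *args, **kwargs):
            # Stitch the positional parameters back into a name -> tensor dict
            # and dispatch to torch.func.functional_call.
            return functional_call(module, dict(zip(names, params)), args, kwargs)

        return fmodule, params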

Concretely, functorch.make_functional returns a functional module and parameters.
The functional module accepts the parameters and the model inputs as arguments.
:func:`torch.func.functional_call` allows one to call the forward pass of an existing
module with new parameters, buffers, and inputs.

Here's an example of how to compute gradients of parameters of a model using functorch
vs :mod:`torch.func`::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)

    def compute_loss(params, inputs, targets):
        prediction = fmodel(params, inputs)
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = functorch.grad(compute_loss)(params, inputs, targets)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())

    def compute_loss(params, inputs, targets):
        prediction = torch.func.functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = torch.func.grad(compute_loss)(params, inputs, targets)

And here's an example of how to compute jacobians of model parameters::

    # ---------------
    # using functorch
    # ---------------
    import torch
    import functorch
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    fmodel, params = functorch.make_functional(model)
    jacobians = functorch.jacrev(fmodel)(params, inputs)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import torch
    from torch.func import jacrev, functional_call
    inputs = torch.randn(64, 3)
    model = torch.nn.Linear(3, 3)

    params = dict(model.named_parameters())
    # jacrev computes jacobians of argnums=0 by default.
    # We set it to 1 to compute jacobians of params
    jacobians = jacrev(functional_call, argnums=1)(model, params, (inputs,))

Note that, to keep memory consumption down, you should only carry around a
single copy of your parameters. ``model.named_parameters()`` does not copy the
parameters. If, during model training, you update the parameters of the model
in-place, then the ``nn.Module`` that is your model holds the single copy of
the parameters and everything is OK.

However, if you want to carry your parameters around in a dictionary and update
them out-of-place, then there are two copies of the parameters: the one in the
dictionary and the one in the ``model``. In this case, you should change
``model`` to not hold memory by converting it to the meta device via
``model.to('meta')``.
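
Here is a minimal sketch of that pattern, assuming a plain SGD-style update
(``lr`` is just an illustrative hyperparameter)::

    import torch
    from torch.func import functional_call, grad

    model = torch.nn.Linear(3, 3)
    # Keep the only "real" copy of the parameters in a dictionary ...
    params = {k: v.detach().clone() for k, v in model.named_parameters()}
    # ... and free the copy held by the module by moving it to the meta device.
    model = model.to('meta')

    inputs = torch.randn(64, 3)
    targets = torch.randn(64, 3)

    def compute_loss(params, inputs, targets):
        prediction = functional_call(model, params, (inputs,))
        return torch.nn.functional.mse_loss(prediction, targets)

    grads = grad(compute_loss)(params, inputs, targets)

    # Out-of-place parameter update (plain SGD, for illustration only).
    lr = 0.1
    params = {k: p - lr * grads[k] for k, p in params.items()}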

functorch.combine_state_for_ensemble
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Please use :func:`torch.func.stack_module_state` instead of
`functorch.combine_state_for_ensemble <https://pytorch.org/functorch/1.13/generated/functorch.combine_state_for_ensemble.html>`_.
:func:`torch.func.stack_module_state` returns two dictionaries, one of stacked parameters and
one of stacked buffers, that can then be used with :func:`torch.vmap` and :func:`torch.func.functional_call`
for ensembling.

Here is an example of how to ensemble over a very simple model::

    import torch
    num_models = 5
    batch_size = 64
    in_features, out_features = 3, 3
    models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
    data = torch.randn(batch_size, 3)

    # ---------------
    # using functorch
    # ---------------
    import functorch
    fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
    output = functorch.vmap(fmodel, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)

    # ------------------------------------
    # using torch.func (as of PyTorch 2.0)
    # ------------------------------------
    import copy

    # Construct a version of the model with no memory by putting the Tensors on
    # the meta device.
    base_model = copy.deepcopy(models[0])
    base_model.to('meta')

    params, buffers = torch.func.stack_module_state(models)

    # It is possible to vmap directly over torch.func.functional_call,
    # but wrapping it in a function makes it clearer what is going on.
    def call_single_model(params, buffers, data):
        return torch.func.functional_call(base_model, (params, buffers), (data,))

    output = torch.vmap(call_single_model, (0, 0, None))(params, buffers, data)
    assert output.shape == (num_models, batch_size, out_features)


functorch.compile
-----------------

We are no longer supporting functorch.compile (also known as AOTAutograd)
as a frontend for compilation in PyTorch; we have integrated AOTAutograd
into PyTorch's compilation story. If you are a user, please use
:func:`torch.compile` instead.
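
For instance, a minimal usage sketch of :func:`torch.compile` (the toy model and
input shapes here are only illustrative)::

    import torch

    model = torch.nn.Linear(3, 3)
    compiled_model = torch.compile(model)

    # The compiled module is called exactly like the original one.
    out = compiled_model(torch.randn(64, 3))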