This library is currently under heavy development. If you have suggestions on the API or use cases you'd like to be covered, please open a GitHub issue or reach out. We'd love to hear about how you're using the library.
functorch is a library of JAX-like composable function transforms for PyTorch.
It aims to provide composable vmap and grad transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance.
In addition, there is experimental functionality to trace through these transformations using FX in order to capture the results of these transforms ahead of time. This would allow us to compile the results of vmap or grad to improve performance.
There are a number of use cases that are tricky to do in PyTorch today:
Composing vmap, grad, vjp, and jvp transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.
There are two ways to install functorch: alongside a nightly or source build of PyTorch, or as the functorch beta via pip.
We recommend trying out the functorch beta first.
Follow the instructions in this Colab notebook
As of 9/21/2022, functorch comes installed alongside a nightly PyTorch binary. Please install a Preview (nightly) PyTorch binary; see https://pytorch.org/ for instructions.
Once you've done that, run a quick sanity check in Python:

```python
import torch
from functorch import vmap

x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```
As of 9/21/2022, functorch comes installed alongside PyTorch and is in the PyTorch source tree. Please install PyTorch from source; you will then be able to import functorch.
Try running some tests to make sure everything works:

```shell
pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v
```
AOTAutograd has some additional optional requirements. You can install them via:

```shell
pip install networkx
```
To run functorch tests, please install our test dependencies (expecttest, pyyaml).
Follow the instructions here
Prerequisite: Install PyTorch

```shell
pip install functorch
```
Finally, run a quick sanity check in Python:

```python
import torch
from functorch import vmap

x = torch.randn(3)
y = vmap(torch.sin)(x)
assert torch.allclose(y, x.sin())
```
Right now, we support the following transforms:
- grad, vjp, jvp
- jacrev, jacfwd, hessian
- vmap
Furthermore, we have some utilities for working with PyTorch modules:
- make_functional(model)
- make_functional_with_buffers(model)
Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.
vmap(func)(*inputs) is a transform that adds a dimension to all Tensor operations in func. vmap(func) returns a new function that maps func over some dimension (default: 0) of each Tensor in inputs.
vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func), leading to a simpler modeling experience:
```python
import torch
from functorch import vmap

batch_size, feature_size = 3, 5
weights = torch.randn(feature_size, requires_grad=True)

def model(feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

examples = torch.randn(batch_size, feature_size)
result = vmap(model)(examples)  # shape: (batch_size,)
```
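To make "maps func over some dimension" concrete, here is a minimal sketch (assuming functorch is installed as described above) that checks vmap against an explicit Python loop, then uses vmap's in_dims argument to batch over dimension 1 instead of the default 0. The helper dot_with_weights and the shapes are illustrative choices, not part of functorch.

```python
import torch
from functorch import vmap

torch.manual_seed(0)
feature_size, batch_size = 5, 3
weights = torch.randn(feature_size)

def dot_with_weights(feature_vec):
    # Hypothetical per-example function: written for a single 1-D vector
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights)

# Batch stored along dim 0: shape (batch_size, feature_size)
examples = torch.randn(batch_size, feature_size)
batched = vmap(dot_with_weights)(examples)
looped = torch.stack([dot_with_weights(ex) for ex in examples])
assert torch.allclose(batched, looped)

# Batch stored along dim 1: shape (feature_size, batch_size); in_dims=1 maps over it
examples_t = examples.t()
batched_t = vmap(dot_with_weights, in_dims=1)(examples_t)
assert torch.allclose(batched_t, looped)
```

The loop is only there as a reference; the point of vmap is that dot_with_weights never has to know about the batch dimension.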
grad(func)(*inputs) assumes func returns a single-element Tensor. It computes the gradients of the output of func with respect to inputs[0].
```python
import torch
from functorch import grad

x = torch.randn([])
cos_x = grad(lambda x: torch.sin(x))(x)
assert torch.allclose(cos_x, x.cos())

# Second-order gradients
neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
assert torch.allclose(neg_sin_x, -x.sin())
```
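By default grad differentiates with respect to the first input only. The sketch below, assuming functorch is installed, uses grad's argnums parameter to select a different input, or a tuple of inputs; the function f is an illustrative example.

```python
import torch
from functorch import grad

def f(x, y):
    # Scalar-valued function of two inputs
    return (x * y).sum()

x = torch.randn(4)
y = torch.randn(4)

# Default (argnums=0): gradient with respect to x, which is y
dx = grad(f)(x, y)
assert torch.allclose(dx, y)

# argnums=1 differentiates with respect to y instead
dy = grad(f, argnums=1)(x, y)
assert torch.allclose(dy, x)

# A tuple of argnums returns one gradient per selected input
dx2, dy2 = grad(f, argnums=(0, 1))(x, y)
assert torch.allclose(dx2, y) and torch.allclose(dy2, x)
```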
When composed with vmap, grad can be used to compute per-sample-gradients:
```python
import torch
from functorch import vmap, grad

batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    # Very simple linear model with activation
    assert feature_vec.dim() == 1
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()  # MSELoss

weights = torch.randn(feature_size, requires_grad=True)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)
inputs = (weights, examples, targets)
# weights are shared (None); examples and targets are batched along dim 0
grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
```
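As a sanity check on what "per-sample-gradients" means, this sketch (assuming functorch is installed) compares the vmap(grad(...)) result against a plain Python loop that calls torch.autograd.grad once per example. The loop is a reference implementation for comparison only, not how functorch computes the result.

```python
import torch
from functorch import vmap, grad

torch.manual_seed(0)
batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()

weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

# functorch: one gradient row per example, shape (batch_size, feature_size)
per_sample = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)

# Reference: loop over examples with ordinary autograd
looped = []
for example, target in zip(examples, targets):
    w = weights.detach().requires_grad_()
    (g,) = torch.autograd.grad(compute_loss(w, example, target), w)
    looped.append(g)
looped = torch.stack(looped)

assert per_sample.shape == (batch_size, feature_size)
assert torch.allclose(per_sample, looped)
```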
The vjp transform applies func to inputs and returns the outputs together with a new function that computes vector-Jacobian products (vjps) given some cotangent Tensors.
```python
from functorch import vjp

outputs, vjp_fn = vjp(func, *inputs)
# cotangents must have the same structure and shapes as outputs
vjps = vjp_fn(cotangents)
```
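The snippet above is schematic (func, inputs, and cotangents are placeholders). A concrete, runnable sketch, assuming functorch is installed and using an illustrative elementwise product f(x, y) = x * y, looks like this; vjp_fn returns one vjp per primal input.

```python
import torch
from functorch import vjp

def f(x, y):
    return x * y

x = torch.randn(5)
y = torch.randn(5)

# vjp evaluates f(x, y) and returns its output together with vjp_fn
output, vjp_fn = vjp(f, x, y)
assert torch.allclose(output, x * y)

# vjp_fn takes a cotangent shaped like `output`; result is a tuple (one per input)
cotangent = torch.ones(5)
grad_x, grad_y = vjp_fn(cotangent)

# The Jacobian w.r.t. x is diag(y), so ones^T diag(y) = y (and symmetrically for y)
assert torch.allclose(grad_x, y)
assert torch.allclose(grad_y, x)
```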
The jvp transform computes Jacobian-vector products and is also known as "forward-mode AD". Unlike most other transforms, it is not a higher-order function: it directly returns the outputs of func(inputs) as well as the jvps.
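```python
import torch
from functorch import jvp

x = torch.randn(5)
y = torch.randn(5)
f = lambda x, y: (x * y)
# With tangents of ones for both inputs, the jvp is y * 1 + x * 1 = x + y
_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))
assert torch.allclose(output, x + y)
```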
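To connect the name "Jacobian-vector product" to the math, this sketch (assuming functorch is installed) checks that jvp returns J(x) @ v, where J is materialized separately with jacrev. Note that jvp takes the primals and tangents as tuples. The function f is an illustrative choice.

```python
import torch
from functorch import jvp, jacrev

def f(x):
    return x.sin() * 2

x = torch.randn(5)
v = torch.randn(5)

# Forward mode: one call returns both f(x) and J(x) @ v
output, jvp_out = jvp(f, (x,), (v,))

# Reference: build the full Jacobian with reverse-mode AD, then multiply
J = jacrev(f)(x)
assert torch.allclose(output, f(x))
assert torch.allclose(jvp_out, J @ v)
```

The jvp call never builds J explicitly, which is why it is the cheaper option when only a single product is needed.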
The jacrev transform computes Jacobians using reverse-mode AD. For example, jacrev(torch.sin) returns a new function that takes in x and returns the Jacobian of torch.sin with respect to x.
```python
import torch
from functorch import jacrev

x = torch.randn(5)
jacobian = jacrev(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```
jacrev can be composed with vmap to produce batched Jacobians:
```python
import torch
from functorch import vmap, jacrev

x = torch.randn(64, 5)
jacobian = vmap(jacrev(torch.sin))(x)
assert jacobian.shape == (64, 5, 5)
```
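Beyond the shape check, the batched result can be verified entry by entry. Since sin is elementwise, each per-example Jacobian is diag(cos(x_i)). This sketch, assuming functorch is installed, compares against torch.diag_embed and against a plain loop over examples.

```python
import torch
from functorch import vmap, jacrev

x = torch.randn(64, 5)
batched_jac = vmap(jacrev(torch.sin))(x)

# Each per-example Jacobian of elementwise sin is diag(cos(x_i))
expected = torch.diag_embed(torch.cos(x))
assert batched_jac.shape == (64, 5, 5)
assert torch.allclose(batched_jac, expected)

# Equivalent (slower) loop over examples
looped = torch.stack([jacrev(torch.sin)(row) for row in x])
assert torch.allclose(batched_jac, looped)
```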
jacfwd is a drop-in replacement for jacrev that computes Jacobians using forward-mode AD:
```python
import torch
from functorch import jacfwd

x = torch.randn(5)
jacobian = jacfwd(torch.sin)(x)
expected = torch.diag(torch.cos(x))
assert torch.allclose(jacobian, expected)
```
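Both transforms produce the same Jacobian; they differ in cost. As a general rule of automatic differentiation (not something specific to functorch), forward mode tends to be cheaper when a function has fewer inputs than outputs (a "tall" Jacobian), and reverse mode when it has more inputs than outputs. This sketch, assuming functorch is installed, uses an illustrative function from 2 inputs to 6 outputs and checks that the two agree.

```python
import torch
from functorch import jacfwd, jacrev

def f(x):
    # Maps 2 inputs to 6 outputs, so the Jacobian is "tall" (6 x 2)
    return torch.cat([x.sin(), x.cos(), x * x])

x = torch.randn(2)
jac_fwd = jacfwd(f)(x)
jac_rev = jacrev(f)(x)

assert jac_fwd.shape == (6, 2)
assert torch.allclose(jac_fwd, jac_rev)
```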
Composing jacrev with itself or with jacfwd can produce Hessians:
```python
import torch
from functorch import jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hessian0 = jacrev(jacrev(f))(x)
hessian1 = jacfwd(jacrev(f))(x)
```
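For this f the Hessian has a closed form: the gradient of sum(sin(x)) is cos(x), whose derivative is -sin(x) on the diagonal and zero elsewhere. This sketch, assuming functorch is installed, checks both compositions against that closed form.

```python
import torch
from functorch import jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hessian0 = jacrev(jacrev(f))(x)
hessian1 = jacfwd(jacrev(f))(x)

# d/dx sum(sin x) = cos x, and d/dx cos x = -sin x (diagonal only)
expected = torch.diag(-x.sin())
assert torch.allclose(hessian0, expected)
assert torch.allclose(hessian1, expected)
```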
hessian is a convenience function that combines jacfwd and jacrev:
```python
import torch
from functorch import hessian

def f(x):
    return x.sin().sum()

x = torch.randn(5)
hess = hessian(f)(x)
```
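A second closed-form check, assuming functorch is installed: for the quadratic form x^T A x with an arbitrary (non-symmetric) matrix A, the Hessian is A + A^T regardless of x. The function quadratic_form and the matrix A are illustrative.

```python
import torch
from functorch import hessian

torch.manual_seed(0)
A = torch.randn(4, 4)

def quadratic_form(x):
    # x^T A x, a scalar
    return x @ A @ x

x = torch.randn(4)
H = hessian(quadratic_form)(x)

# The Hessian of x^T A x is A + A^T, independent of x
assert H.shape == (4, 4)
assert torch.allclose(H, A + A.T, atol=1e-5)
```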
We can also trace through these transformations in order to capture the results as new code using make_fx. There is also experimental integration with the NNC compiler (only works on CPU for now!).
```python
import torch
from functorch import make_fx, grad

def f(x):
    return torch.sin(x).sum()

x = torch.randn(100)
grad_f = make_fx(grad(f))(x)
print(grad_f.code)
```

The printed code:

```python
def forward(self, x_1):
    sin = torch.ops.aten.sin(x_1)
    sum_1 = torch.ops.aten.sum(sin, None); sin = None
    cos = torch.ops.aten.cos(x_1); x_1 = None
    _tensor_constant0 = self._tensor_constant0
    mul = torch.ops.aten.mul(_tensor_constant0, cos); _tensor_constant0 = cos = None
    return mul
```
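The object returned by make_fx is a torch.fx GraphModule, so the captured code can be called directly. This sketch, assuming functorch is installed and that the trace is reused on inputs of the same shape as the one it was traced with, checks that the traced gradient matches the eager transform and the closed form cos(x).

```python
import torch
from functorch import make_fx, grad

def f(x):
    return torch.sin(x).sum()

x = torch.randn(100)
# make_fx records the operations that grad(f) performs on x into a GraphModule
grad_f = make_fx(grad(f))(x)

# Call the traced module on a fresh input of the same shape
new_x = torch.randn(100)
assert torch.allclose(grad_f(new_x), grad(f)(new_x))
assert torch.allclose(grad_f(new_x), new_x.cos())
```

Because tracing records the operations seen for the example input, it is safest to treat the result as specialized to that input's shape.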
Sometimes you may want to perform a transform with respect to the parameters and/or buffers of an nn.Module. This can happen for example in:
Our solution to this right now is an API that, given an nn.Module, creates a stateless version of it that can be called like a function.
- make_functional(model) returns a functional version of model and the model.parameters()
- make_functional_with_buffers(model) returns a functional version of model and the model.parameters() and model.buffers()

Here's an example where we compute per-sample-gradients using an nn.Linear layer:
```python
import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), (None, 0, 0))(params, data, targets)
```
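For modules that also carry buffers, make_functional_with_buffers returns a function that takes (params, buffers, *inputs). The sketch below assumes functorch is installed; ScaledLinear is a hypothetical module invented for illustration (a Linear layer whose output is multiplied by a registered buffer). The buffers are shared across the batch, like the parameters.

```python
import torch
from functorch import make_functional_with_buffers, vmap, grad

class ScaledLinear(torch.nn.Module):
    # Hypothetical module: a Linear layer plus a non-trainable buffer
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3, 3)
        self.register_buffer("scale", torch.tensor(2.0))

    def forward(self, x):
        return self.linear(x) * self.scale

model = ScaledLinear()
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

# func_model is called as func_model(params, buffers, *inputs)
func_model, params, buffers = make_functional_with_buffers(model)

def compute_loss(params, buffers, data, targets):
    preds = func_model(params, buffers, data)
    return torch.mean((preds - targets) ** 2)

# params and buffers are shared (None); data and targets are batched on dim 0
per_sample_grads = vmap(grad(compute_loss), (None, None, 0, 0))(params, buffers, data, targets)

# One gradient per parameter (linear.weight, linear.bias), each with a leading batch dim
weight_grads, bias_grads = per_sample_grads
assert weight_grads.shape == (64, 3, 3)
assert bias_grads.shape == (64, 3)
```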
If you're making an ensemble of models, you may find combine_state_for_ensemble useful.
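A minimal ensembling sketch, assuming functorch is installed and all models share the same architecture: combine_state_for_ensemble stacks each model's parameters and buffers along a new leading dimension, and vmap then maps over that dimension while sharing one minibatch. The model type, sizes, and number of models are illustrative.

```python
import torch
from functorch import combine_state_for_ensemble, vmap

torch.manual_seed(0)
num_models = 4
models = [torch.nn.Linear(3, 2) for _ in range(num_models)]
minibatch = torch.randn(8, 3)

# fmodel is called as fmodel(params, buffers, *inputs); params/buffers are stacked
fmodel, params, buffers = combine_state_for_ensemble(models)

# Map over the model dimension of params and buffers; share the same minibatch (None)
predictions = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, minibatch)
assert predictions.shape == (num_models, 8, 2)

# Reference: run each model separately and stack the outputs
expected = torch.stack([m(minibatch) for m in models])
assert torch.allclose(predictions, expected, atol=1e-5)
```

The stacked parameters do not track gradients by default; if you intend to train the ensemble, you would need to call requires_grad_() on them.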
For more documentation, see our docs website.
- torch._C._functorch.dump_tensor: Dumps dispatch keys on stack
- torch._C._functorch._set_vmap_fallback_warning_enabled(False) if the vmap warning spam bothers you.
In the end state, we'd like to upstream this into PyTorch once we iron out the design details. To figure out the details, we need your help -- please send us your use cases by starting a conversation in the issue tracker or trying our project out.
Functorch has a BSD-style license, as found in the LICENSE file.
If you use functorch in your publication, please cite it by using the following BibTeX entry.
```bibtex
@Misc{functorch2021,
  author = {Horace He and Richard Zou},
  title = {functorch: JAX-like composable function transforms for PyTorch},
  howpublished = {\url{https://github.com/pytorch/functorch}},
  year = {2021}
}
```