{
"cells": [
{
"cell_type": "markdown",
"id": "903e2f76",
"metadata": {},
"source": [
"# Whirlwind Tour\n",
"\n",
"<a href=\"https://colab.research.google.com/github/pytorch/pytorch/blob/master/functorch/notebooks/whirlwind_tour.ipynb\">\n",
" <img style=\"width: auto\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>\n",
"\n",
"## What is functorch?\n",
"\n",
"functorch is a library for [JAX](https://github.com/google/jax)-like composable function transforms in PyTorch.\n",
"- A \"function transform\" is a higher-order function that accepts a numerical function and returns a new function that computes a different quantity.\n",
"- functorch has auto-differentiation transforms (`grad(f)` returns a function that computes the gradient of `f`), a vectorization/batching transform (`vmap(f)` returns a function that computes `f` over batches of inputs), and others.\n",
"- These function transforms can compose with each other arbitrarily. For example, composing `vmap(grad(f))` computes a quantity called per-sample-gradients that stock PyTorch cannot efficiently compute today.\n",
"\n",
"Furthermore, we also provide an experimental compilation transform in the `functorch.compile` namespace. Our compilation transform, named AOT (ahead-of-time) Autograd, returns to you an [FX graph](https://pytorch.org/docs/stable/fx.html) (that optionally contains a backward pass), of which compilation via various backends is one path you can take.\n",
"\n",
"\n",
"## Why composable function transforms?\n",
"There are a number of use cases that are tricky to do in PyTorch today:\n",
"- computing per-sample-gradients (or other per-sample quantities)\n",
"- running ensembles of models on a single machine\n",
"- efficiently batching together tasks in the inner-loop of MAML\n",
"- efficiently computing Jacobians and Hessians\n",
"- efficiently computing batched Jacobians and Hessians\n",
"\n",
"Composing `vmap`, `grad`, `vjp`, and `jvp` transforms allows us to express the above without designing a separate subsystem for each.\n",
"\n",
"## What are the transforms?\n",
"\n",
"### grad (gradient computation)\n",
"\n",
"`grad(func)` is our gradient computation transform. It returns a new function that computes the gradients of `func`. It assumes `func` returns a single-element Tensor and by default it computes the gradients of the output of `func` w.r.t. to the first input."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f920b923",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from functorch import grad\n",
"x = torch.randn([])\n",
"cos_x = grad(lambda x: torch.sin(x))(x)\n",
"assert torch.allclose(cos_x, x.cos())\n",
"\n",
"# Second-order gradients\n",
"neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)\n",
"assert torch.allclose(neg_sin_x, -x.sin())"
]
},
{
"cell_type": "markdown",
"id": "ef3b2d85",
"metadata": {},
"source": [
"### vmap (auto-vectorization)\n",
"\n",
"Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.\n",
"\n",
"`vmap(func)(*inputs)` is a transform that adds a dimension to all Tensor operations in `func`. `vmap(func)` returns a new function that maps `func` over some dimension (default: 0) of each Tensor in inputs.\n",
"\n",
"vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with `vmap(func)`, leading to a simpler modeling experience:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6ebac649",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from functorch import vmap\n",
"batch_size, feature_size = 3, 5\n",
"weights = torch.randn(feature_size, requires_grad=True)\n",
"\n",
"def model(feature_vec):\n",
" # Very simple linear model with activation\n",
" assert feature_vec.dim() == 1\n",
" return feature_vec.dot(weights).relu()\n",
"\n",
"examples = torch.randn(batch_size, feature_size)\n",
"result = vmap(model)(examples)"
]
},
{
"cell_type": "markdown",
"id": "5161e6d2",
"metadata": {},
"source": [
"When composed with `grad`, `vmap` can be used to compute per-sample-gradients:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ffb2fcb1",
"metadata": {},
"outputs": [],
"source": [
"from functorch import vmap\n",
"batch_size, feature_size = 3, 5\n",
"\n",
"def model(weights,feature_vec):\n",
" # Very simple linear model with activation\n",
" assert feature_vec.dim() == 1\n",
" return feature_vec.dot(weights).relu()\n",
"\n",
"def compute_loss(weights, example, target):\n",
" y = model(weights, example)\n",
" return ((y - target) ** 2).mean() # MSELoss\n",
"\n",
"weights = torch.randn(feature_size, requires_grad=True)\n",
"examples = torch.randn(batch_size, feature_size)\n",
"targets = torch.randn(batch_size)\n",
"inputs = (weights,examples, targets)\n",
"grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)"
]
},
{
"cell_type": "markdown",
"id": "11d711af",
"metadata": {},
"source": [
"### vjp (vector-Jacobian product)\n",
"\n",
"The `vjp` transform applies `func` to `inputs` and returns a new function that computes the vector-Jacobian product (vjp) given some `cotangents` Tensors."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ad48f9d4",
"metadata": {},
"outputs": [],
"source": [
"from functorch import vjp\n",
"\n",
"inputs = torch.randn(3)\n",
"func = torch.sin\n",
"cotangents = (torch.randn(3),)\n",
"\n",
"outputs, vjp_fn = vjp(func, inputs); vjps = vjp_fn(*cotangents)"
]
},
{
"cell_type": "markdown",
"id": "e0221270",
"metadata": {},
"source": [
"### jvp (Jacobian-vector product)\n",
"\n",
"The `jvp` transforms computes Jacobian-vector-products and is also known as \"forward-mode AD\". It is not a higher-order function unlike most other transforms, but it returns the outputs of `func(inputs)` as well as the jvps."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f3772f43",
"metadata": {},
"outputs": [],
"source": [
"from functorch import jvp\n",
"x = torch.randn(5)\n",
"y = torch.randn(5)\n",
"f = lambda x, y: (x * y)\n",
"_, output = jvp(f, (x, y), (torch.ones(5), torch.ones(5)))\n",
"assert torch.allclose(output, x + y)"
]
},
{
"cell_type": "markdown",
"id": "7b00953b",
"metadata": {},
"source": [
"### jacrev, jacfwd, and hessian\n",
"\n",
"The `jacrev` transform returns a new function that takes in `x` and returns the Jacobian of the function\n",
"with respect to `x` using reverse-mode AD."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "20f53be2",
"metadata": {},
"outputs": [],
"source": [
"from functorch import jacrev\n",
"x = torch.randn(5)\n",
"jacobian = jacrev(torch.sin)(x)\n",
"expected = torch.diag(torch.cos(x))\n",
"assert torch.allclose(jacobian, expected)"
]
},
{
"cell_type": "markdown",
"id": "b9007c88",
"metadata": {},
"source": [
"Use `jacrev` to compute the jacobian. This can be composed with `vmap` to produce batched jacobians:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97d6c382",
"metadata": {},
"outputs": [],
"source": [
"x = torch.randn(64, 5)\n",
"jacobian = vmap(jacrev(torch.sin))(x)\n",
"assert jacobian.shape == (64, 5, 5)"
]
},
{
"cell_type": "markdown",
"id": "cda642ec",
"metadata": {},
"source": [
"`jacfwd` is a drop-in replacement for `jacrev` that computes Jacobians using forward-mode AD:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a8c1dedb",
"metadata": {},
"outputs": [],
"source": [
"from functorch import jacfwd\n",
"x = torch.randn(5)\n",
"jacobian = jacfwd(torch.sin)(x)\n",
"expected = torch.diag(torch.cos(x))\n",
"assert torch.allclose(jacobian, expected)"
]
},
{
"cell_type": "markdown",
"id": "39f85b50",
"metadata": {},
"source": [
"Composing `jacrev` with itself or `jacfwd` can produce hessians:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1e511139",
"metadata": {},
"outputs": [],
"source": [
"def f(x):\n",
" return x.sin().sum()\n",
"\n",
"x = torch.randn(5)\n",
"hessian0 = jacrev(jacrev(f))(x)\n",
"hessian1 = jacfwd(jacrev(f))(x)"
]
},
{
"cell_type": "markdown",
"id": "18efdc65",
"metadata": {},
"source": [
"The `hessian` is a convenience function that combines `jacfwd` and `jacrev`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fd1765df",
"metadata": {},
"outputs": [],
"source": [
"from functorch import hessian\n",
"\n",
"def f(x):\n",
" return x.sin().sum()\n",
"\n",
"x = torch.randn(5)\n",
"hess = hessian(f)(x)"
]
},
{
"cell_type": "markdown",
"id": "b597d7ad",
"metadata": {},
"source": [
"## Conclusion\n",
"\n",
"Check out our other tutorials (in the left bar) for more detailed explanations of how to apply functorch transforms for various use cases. `functorch` is very much a work in progress and we'd love to hear how you're using it -- we encourage you to start a conversation at our [issues tracker](https://github.com/pytorch/functorch) to discuss your use case."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}