{
"cells": [
{
"cell_type": "markdown",
"id": "b687b169-ec83-493d-a7c5-f8c6cd402ea3",
"metadata": {},
"source": [
"# Neural Tangent Kernels\n",
"\n",
"<a href=\"https://colab.research.google.com/github/pytorch/pytorch/blob/master/functorch/notebooks/neural_tangent_kernels.ipynb\">\n",
" <img style=\"width: auto\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>\n",
"\n",
"The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). There has been a lot of research around it [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents) (see [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details), demonstrates how to easily compute this quantity using functorch."
]
},
{
"cell_type": "markdown",
"id": "77f41c65-f070-4b60-b3d0-1c8f56ed4f64",
"metadata": {},
"source": [
"## Setup\n",
"\n",
"First, some setup. Let's define a simple CNN that we wish to compute the NTK of."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "855fa70b-5b63-4973-94df-41be57ab6ecf",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from functorch import make_functional, vmap, vjp, jvp, jacrev\n",
"device = 'cuda'\n",
"\n",
"class CNN(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv1 = nn.Conv2d(3, 32, (3, 3))\n",
" self.conv2 = nn.Conv2d(32, 32, (3, 3))\n",
" self.conv3 = nn.Conv2d(32, 32, (3, 3))\n",
" self.fc = nn.Linear(21632, 10)\n",
" \n",
" def forward(self, x):\n",
" x = self.conv1(x)\n",
" x = x.relu()\n",
" x = self.conv2(x)\n",
" x = x.relu()\n",
" x = self.conv3(x)\n",
" x = x.flatten(1)\n",
" x = self.fc(x)\n",
" return x"
]
},
{
"cell_type": "markdown",
"id": "52c600e9-207a-41ec-93b4-5d940827bda0",
"metadata": {},
"source": [
"And let's generate some random data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "0001a907-f5c9-4532-9ee9-2e94b8487d08",
"metadata": {},
"outputs": [],
"source": [
"x_train = torch.randn(20, 3, 32, 32, device=device)\n",
"x_test = torch.randn(5, 3, 32, 32, device=device)"
]
},
{
"cell_type": "markdown",
"id": "8af210fe-9613-48ee-a96c-d0836458b0f1",
"metadata": {},
"source": [
"## Create a function version of the model\n",
"\n",
"functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.\n",
"\n",
"We'll use functorch's `make_functional` to accomplish the first step. If your module has buffers, you'll want to use `make_functional_with_buffers` instead."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "e6b4bb59-bdde-46cd-8a28-7fd00a37a387",
"metadata": {},
"outputs": [],
"source": [
"net = CNN().to(device)\n",
"fnet, params = make_functional(net)"
]
},
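{
"cell_type": "markdown",
"id": "c0ffee01-aaaa-4bbb-8ccc-000000000001",
"metadata": {},
"source": [
"Our `CNN` has no buffers, but if yours does (e.g. the running statistics of a `BatchNorm2d` layer), here is a minimal sketch of the `make_functional_with_buffers` variant. The `CNNWithBN` class below is hypothetical and only for illustration; the buffers come back as their own pytree, and the functional model takes them as a second argument."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee01-aaaa-4bbb-8ccc-000000000002",
"metadata": {},
"outputs": [],
"source": [
"from functorch import make_functional_with_buffers\n",
"\n",
"# Hypothetical variant of CNN with a BatchNorm layer; its running mean and\n",
"# variance are buffers, not parameters.\n",
"class CNNWithBN(nn.Module):\n",
"    def __init__(self):\n",
"        super().__init__()\n",
"        self.conv1 = nn.Conv2d(3, 32, (3, 3))\n",
"        self.bn1 = nn.BatchNorm2d(32)\n",
"        self.fc = nn.Linear(32 * 30 * 30, 10)\n",
"\n",
"    def forward(self, x):\n",
"        x = self.bn1(self.conv1(x)).relu()\n",
"        return self.fc(x.flatten(1))\n",
"\n",
"net_bn = CNNWithBN().to(device)\n",
"# The functional model is called as fnet_bn(params, buffers, x)\n",
"fnet_bn, params_bn, buffers_bn = make_functional_with_buffers(net_bn)\n",
"print(fnet_bn(params_bn, buffers_bn, x_test).shape)  # torch.Size([5, 10])"
]
},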
{
"cell_type": "markdown",
"id": "319276a4-da45-499a-af47-0677107559b6",
"metadata": {},
"source": [
"Keep in mind that the model was originally written to accept a batch of input data points. In our CNN example, there are no inter-batch operations. That is, each data point in the batch is independent of other data points. With this assumption in mind, we can easily generate a function that evaluates the model on a single data point:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "0b8b4021-eb10-4a50-9d99-3817cb0ce4cc",
"metadata": {},
"outputs": [],
"source": [
"def fnet_single(params, x):\n",
" return fnet(params, x.unsqueeze(0)).squeeze(0)"
]
},
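{
"cell_type": "markdown",
"id": "c0ffee01-aaaa-4bbb-8ccc-000000000003",
"metadata": {},
"source": [
"As a quick sanity check (an illustrative aside, not required for what follows), `fnet_single` on one data point should agree with the corresponding row of the original module's batched output:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee01-aaaa-4bbb-8ccc-000000000004",
"metadata": {},
"outputs": [],
"source": [
"# The stateless single-example function should match the stateful module\n",
"# on the corresponding batch element (up to floating-point noise).\n",
"expected = net(x_test)\n",
"actual = fnet_single(params, x_test[0])\n",
"assert torch.allclose(actual, expected[0], atol=1e-5)"
]
},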
{
"cell_type": "markdown",
"id": "62bc6b5a-31fa-411e-8069-e6c1f6d05248",
"metadata": {},
"source": [
"## Compute the NTK: method 1 (Jacobian contraction)\n",
"\n",
"We're ready to compute the empirical NTK. The empirical NTK for two data points $x_1$ and $x_2$ is defined as the matrix product between the Jacobian of the model evaluated at $x_1$ and the Jacobian of the model evaluated at $x_2$:\n",
"\n",
"$$J_{net}(x_1) J_{net}^T(x_2)$$\n",
"\n",
"In the batched case where $x_1$ is a batch of data points and $x_2$ is a batch of data points, then we want the matrix product between the Jacobians of all combinations of data points from $x_1$ and $x_2$.\n",
"\n",
"The first method consists of doing just that - computing the two Jacobians, and contracting them. Here's how to compute the NTK in the batched case:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "99a38a4b-64d3-4e13-bd63-2d71e8dd6840",
"metadata": {},
"outputs": [],
"source": [
"def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):\n",
" # Compute J(x1)\n",
" jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
" jac1 = [j.flatten(2) for j in jac1]\n",
" \n",
" # Compute J(x2)\n",
" jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n",
" jac2 = [j.flatten(2) for j in jac2]\n",
" \n",
" # Compute J(x1) @ J(x2).T\n",
" result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])\n",
" result = result.sum(0)\n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "cbf54d2b-c4bc-46bd-9e55-e1471d639a4e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([20, 5, 10, 10])\n"
]
}
],
"source": [
"result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)\n",
"print(result.shape)"
]
},
{
"cell_type": "markdown",
"id": "ea844f45-98fb-4cba-8056-644292b968ab",
"metadata": {},
"source": [
"In some cases, you may only want the diagonal or the trace of this quantity, especially if you know beforehand that the network architecture results in an NTK where the non-diagonal elements can be approximated by zero. It's easy to adjust the above function to do that:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "aae760c9-e906-4fda-b490-1126a86b7e96",
"metadata": {},
"outputs": [],
"source": [
"def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):\n",
" # Compute J(x1)\n",
" jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n",
" jac1 = [j.flatten(2) for j in jac1]\n",
" \n",
" # Compute J(x2)\n",
" jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n",
" jac2 = [j.flatten(2) for j in jac2]\n",
" \n",
" # Compute J(x1) @ J(x2).T\n",
" einsum_expr = None\n",
" if compute == 'full':\n",
" einsum_expr = 'Naf,Mbf->NMab'\n",
" elif compute == 'trace':\n",
" einsum_expr = 'Naf,Maf->NM'\n",
" elif compute == 'diagonal':\n",
" einsum_expr = 'Naf,Maf->NMa'\n",
" else:\n",
" assert False\n",
" \n",
" result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])\n",
" result = result.sum(0)\n",
" return result"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "42d974f3-1f9d-4953-8677-5ee22cfc67eb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([20, 5])\n"
]
}
],
"source": [
"result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')\n",
"print(result.shape)"
]
},
{
"cell_type": "markdown",
"id": "6c941e5d-51d7-47b2-80ee-edcd4aee6aaa",
"metadata": {},
"source": [
"The asymptotic time complexity of this method is $N O [FP]$ (time to compute the Jacobians) $ + N^2 O^2 P$ (time to contract the Jacobians), where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, $P$ is the total number of parameters, and $[FP]$ is the cost of a single forward pass through the model. See section section 3.2 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
]
},
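{
"cell_type": "markdown",
"id": "c0ffee01-aaaa-4bbb-8ccc-000000000005",
"metadata": {},
"source": [
"To make these symbols concrete for our setup, the cell below is a back-of-the-envelope instantiation (the contraction term indicates asymptotic scaling, not an exact operation count):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee01-aaaa-4bbb-8ccc-000000000006",
"metadata": {},
"outputs": [],
"source": [
"# P is the total number of parameters; for this CNN it is dominated by the\n",
"# final linear layer. N and O instantiate the N^2 O^2 P contraction term.\n",
"P = sum(p.numel() for p in params)\n",
"N = x_train.shape[0]\n",
"O = 10  # the model's output size\n",
"print(f'P = {P}, contraction term scales like N^2 O^2 P = {N**2 * O**2 * P:.2e}')"
]
},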
{
"cell_type": "markdown",
"id": "6c931e5d-51d7-47b2-80ee-ddcd4aee6aaa",
"metadata": {},
"source": [
"## Compute the NTK: method 2 (NTK-vector products)\n",
"\n",
"The next method we will discuss is a way to compute the NTK using NTK-vector products.\n",
"\n",
"This method reformulates NTK as a stack of NTK-vector products applied to columns of an identity matrix $I_O$ of size $O\\times O$ (where $O$ is the output size of the model):\n",
"\n",
"$$J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \\left[J_{net}(x_1) \\left[J_{net}^T(x_2) e_o\\right]\\right]_{o=1}^{O},$$\n",
"where $e_o\\in \\mathbb{R}^O$ are column vectors of the identity matrix $I_O$.\n",
"\n",
"- Let $\\textrm{vjp}_o = J_{net}^T(x_2) e_o$. We can use a vector-Jacobian product to compute this.\n",
"- Now, consider $J_{net}(x_1) \\textrm{vjp}_o$. This is a Jacobian-vector product!\n",
"- Finally, we can run the above computation in parallel over all columns $e_o$ of $I_O$ using `vmap`.\n",
"\n",
"This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK.\n",
"\n",
"Let's code that up:"
]
},
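{
"cell_type": "markdown",
"id": "c0ffee01-aaaa-4bbb-8ccc-000000000007",
"metadata": {},
"source": [
"As a warm-up, here are the two primitives in isolation (an illustrative toy; the function `f` and the shapes below are made up and unrelated to the NTK computation). A vjp pulls a vector shaped like the output back to parameter space; a jvp pushes a vector shaped like the parameters forward to output space."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee01-aaaa-4bbb-8ccc-000000000008",
"metadata": {},
"outputs": [],
"source": [
"# Toy function: f(w) = w @ x, with a [3, 4] parameter and a 3-element output.\n",
"w = torch.randn(3, 4, device=device)\n",
"x = torch.randn(4, device=device)\n",
"\n",
"def f(w):\n",
"    return w @ x\n",
"\n",
"# vjp: the cotangent has the output's shape; the result has w's shape.\n",
"out, f_vjp = vjp(f, w)\n",
"(pullback,) = f_vjp(torch.randn(3, device=device))\n",
"\n",
"# jvp: the tangent has w's shape; the result has the output's shape.\n",
"_, pushforward = jvp(f, (w,), (torch.randn(3, 4, device=device),))\n",
"\n",
"print(pullback.shape, pushforward.shape)  # torch.Size([3, 4]) torch.Size([3])"
]
},
{
"cell_type": "markdown",
"id": "c0ffee01-aaaa-4bbb-8ccc-000000000009",
"metadata": {},
"source": [
"Now, the NTK-vector products implementation itself. Note that, relative to the derivation above, the code takes the vector-Jacobian product with $J_{net}(x_1)$ and the Jacobian-vector product with $J_{net}(x_2)$; stacking the resulting slices over all basis vectors then directly yields $J_{net}(x_1) J_{net}^T(x_2)$."
]
},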
{
"cell_type": "code",
"execution_count": 9,
"id": "dc4b49d7-3096-45d5-a7a1-7032309a2613",
"metadata": {},
"outputs": [],
"source": [
"def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):\n",
" def get_ntk(x1, x2):\n",
" def func_x1(params):\n",
" return func(params, x1)\n",
"\n",
" def func_x2(params):\n",
" return func(params, x2)\n",
"\n",
" output, vjp_fn = vjp(func_x1, params)\n",
"\n",
" def get_ntk_slice(vec):\n",
" # This computes vec @ J(x2).T\n",
" # `vec` is some unit vector (a single slice of the Identity matrix)\n",
" vjps = vjp_fn(vec)\n",
" # This computes J(X1) @ vjps\n",
" _, jvps = jvp(func_x2, (params,), vjps)\n",
" return jvps\n",
"\n",
" # Here's our identity matrix\n",
" basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)\n",
" return vmap(get_ntk_slice)(basis)\n",
" \n",
" # get_ntk(x1, x2) computes the NTK for a single data point x1, x2\n",
" # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,\n",
" # we actually wish to compute the NTK between every pair of data points\n",
" # between {x1} and {x2}. That's what the vmaps here do.\n",
" result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)\n",
" \n",
" if compute == 'full':\n",
" return result\n",
" if compute == 'trace':\n",
" return torch.einsum('NMKK->NM', result)\n",
" if compute == 'diagonal':\n",
" return torch.einsum('NMKK->NMK', result)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "f750544f-9e48-47fe-9f9b-e1b8ae49b245",
"metadata": {},
"outputs": [],
"source": [
"result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)\n",
"result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)\n",
"assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)"
]
},
{
"cell_type": "markdown",
"id": "84253466-971d-4475-999c-fe3de6bd25b5",
"metadata": {},
"source": [
"Our code for `empirical_ntk_ntk_vps` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch.\n",
"\n",
"The asymptotic time complexity of this method is $N^2 O [FP]$, where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, and $[FP]$ is the cost of a single forward pass through the model. Hence this method performs more forward passes through the network than method 1, Jacobian contraction ($N^2 O$ instead of $N O$), but avoids the contraction cost altogether (no $N^2 O^2 P$ term, where $P$ is the total number of model's parameters). Therefore, this method is preferable when $O P$ is large relative to $[FP]$, such as fully-connected (not convolutional) models with many outputs $O$. Memory-wise, both methods should be comparable. See section 3.3 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details."
]
}
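,
{
"cell_type": "markdown",
"id": "c0ffee01-aaaa-4bbb-8ccc-000000000010",
"metadata": {},
"source": [
"To see the trade-off concretely, here is a rough timing sketch (an informal micro-benchmark; `time_fn` is a local helper, not a functorch API, and the measurements assume the CUDA device used throughout). For this convolutional model with only $O = 10$ outputs, we would expect Jacobian contraction to be competitive."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee01-aaaa-4bbb-8ccc-000000000011",
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"\n",
"def time_fn(fn, *args, reps=3):\n",
"    # Warm up once, then average a few timed runs; synchronize so we\n",
"    # measure the GPU work itself rather than just kernel launches.\n",
"    fn(*args)\n",
"    torch.cuda.synchronize()\n",
"    start = time.time()\n",
"    for _ in range(reps):\n",
"        fn(*args)\n",
"    torch.cuda.synchronize()\n",
"    return (time.time() - start) / reps\n",
"\n",
"t_jac = time_fn(empirical_ntk_jacobian_contraction, fnet_single, params, x_test, x_train)\n",
"t_vps = time_fn(empirical_ntk_ntk_vps, fnet_single, params, x_test, x_train)\n",
"print(f'jacobian contraction: {t_jac:.4f}s, ntk-vector products: {t_vps:.4f}s')"
]
}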
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}