| { |
| "cells": [ |
| { |
| "cell_type": "markdown", |
| "id": "b687b169-ec83-493d-a7c5-f8c6cd402ea3", |
| "metadata": {}, |
| "source": [ |
| "# Neural Tangent Kernels\n", |
| "\n", |
| "<a href=\"https://colab.research.google.com/github/pytorch/pytorch/blob/master/functorch/notebooks/neural_tangent_kernels.ipynb\">\n", |
| " <img style=\"width: auto\" src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n", |
| "</a>\n", |
| "\n", |
| "The neural tangent kernel (NTK) is a kernel that describes [how a neural network evolves during training](https://en.wikipedia.org/wiki/Neural_tangent_kernel). It has been the subject of extensive research [in recent years](https://arxiv.org/abs/1806.07572). This tutorial, inspired by the implementation of [NTKs in JAX](https://github.com/google/neural-tangents) (see [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details), shows how to compute this quantity easily using functorch." |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "77f41c65-f070-4b60-b3d0-1c8f56ed4f64", |
| "metadata": {}, |
| "source": [ |
| "## Setup\n", |
| "\n", |
| "First, some setup. Let's define a simple CNN whose NTK we wish to compute." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 1, |
| "id": "855fa70b-5b63-4973-94df-41be57ab6ecf", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import torch\n", |
| "import torch.nn as nn\n", |
| "from functorch import make_functional, vmap, vjp, jvp, jacrev\n", |
| "\n", |
| "# Use the GPU if one is available; otherwise fall back to the CPU\n", |
| "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", |
| "\n", |
| "class CNN(nn.Module):\n", |
| "    def __init__(self):\n", |
| "        super().__init__()\n", |
| "        self.conv1 = nn.Conv2d(3, 32, (3, 3))\n", |
| "        self.conv2 = nn.Conv2d(32, 32, (3, 3))\n", |
| "        self.conv3 = nn.Conv2d(32, 32, (3, 3))\n", |
| "        self.fc = nn.Linear(21632, 10)\n", |
| "\n", |
| "    def forward(self, x):\n", |
| "        x = self.conv1(x)\n", |
| "        x = x.relu()\n", |
| "        x = self.conv2(x)\n", |
| "        x = x.relu()\n", |
| "        x = self.conv3(x)\n", |
| "        x = x.flatten(1)\n", |
| "        x = self.fc(x)\n", |
| "        return x" |
| ] |
| }, |
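| { |
| "cell_type": "markdown", |
| "id": "3f2a9c10-7d4e-4b8a-9c61-5e0b2d7a1c01", |
| "metadata": {}, |
| "source": [ |
| "The input size of `self.fc` follows from the convolution shapes. As a quick sketch, assuming stride 1, no padding and no dilation (the `nn.Conv2d` defaults used above), each 3x3 convolution shrinks the spatial size by 2, so 32x32 inputs are 26x26 after three convolutions, with 32 channels. The helper `conv_out_size` is hypothetical, written only for this illustration." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "3f2a9c10-7d4e-4b8a-9c61-5e0b2d7a1c02", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def conv_out_size(size, kernel_size, stride=1, padding=0):\n", |
| "    # Standard convolution output-size formula along one spatial dimension (dilation 1)\n", |
| "    return (size + 2 * padding - kernel_size) // stride + 1\n", |
| "\n", |
| "spatial = 32  # height and width of the inputs used in this notebook\n", |
| "for _ in range(3):  # conv1, conv2 and conv3 all use 3x3 kernels\n", |
| "    spatial = conv_out_size(spatial, 3)\n", |
| "\n", |
| "fc_in_features = 32 * spatial * spatial  # conv3 has 32 output channels\n", |
| "print(spatial, fc_in_features)  # 26 21632" |
| ] |
| }, |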
| { |
| "cell_type": "markdown", |
| "id": "52c600e9-207a-41ec-93b4-5d940827bda0", |
| "metadata": {}, |
| "source": [ |
| "Next, let's generate some random data: 20 training images and 5 test images, each with 3 channels at 32x32 resolution." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 2, |
| "id": "0001a907-f5c9-4532-9ee9-2e94b8487d08", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "x_train = torch.randn(20, 3, 32, 32, device=device)\n", |
| "x_test = torch.randn(5, 3, 32, 32, device=device)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "8af210fe-9613-48ee-a96c-d0836458b0f1", |
| "metadata": {}, |
| "source": [ |
| "## Create a function version of the model\n", |
| "\n", |
| "functorch transforms operate on functions. In particular, to compute the NTK, we will need a function that accepts the parameters of the model and a single input (as opposed to a batch of inputs!) and returns a single output.\n", |
| "\n", |
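| "We'll use functorch's `make_functional` to accomplish the first part: turning the module into a function that takes the model's parameters as an explicit argument. If your module has buffers (for example, the running statistics of batch normalization), you'll want to use `make_functional_with_buffers` instead." |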
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 3, |
| "id": "e6b4bb59-bdde-46cd-8a28-7fd00a37a387", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "net = CNN().to(device)\n", |
| "fnet, params = make_functional(net)" |
| ] |
| }, |
| { |
| "cell_type": "markdown", |
| "id": "319276a4-da45-499a-af47-0677107559b6", |
| "metadata": {}, |
| "source": [ |
| "Keep in mind that the model was originally written to accept a batch of input data points. Our CNN performs no operations across the batch dimension (unlike, say, batch normalization in training mode), so each data point in the batch is processed independently of the others. With this assumption in mind, we can easily generate a function that evaluates the model on a single data point:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 4, |
| "id": "0b8b4021-eb10-4a50-9d99-3817cb0ce4cc", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def fnet_single(params, x):\n", |
| "    # Add a batch dimension of size 1, run the model, then remove the batch dimension\n", |
| "    return fnet(params, x.unsqueeze(0)).squeeze(0)" |
| ] |
| }, |
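| { |
| "cell_type": "markdown", |
| "id": "c4d8e3a5-2f61-4b9d-a0c7-8e1f5b3d2a01", |
| "metadata": {}, |
| "source": [ |
| "To see why this assumption matters, here is a small sketch using NumPy (an extra dependency assumed only for this illustration) with two hypothetical toy models. Evaluating the data points one at a time, with the same add-then-remove-a-batch-dimension trick as `fnet_single`, reproduces the batched output of a model that treats samples independently, but not of a model that normalizes with batch statistics:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "c4d8e3a5-2f61-4b9d-a0c7-8e1f5b3d2a02", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import numpy as np\n", |
| "\n", |
| "rng = np.random.default_rng(0)\n", |
| "W = rng.standard_normal((3, 4))\n", |
| "X = rng.standard_normal((5, 4))  # a batch of 5 data points\n", |
| "\n", |
| "def per_sample_model(X):\n", |
| "    # Each row of X is transformed independently of the others\n", |
| "    return np.tanh(X @ W.T)\n", |
| "\n", |
| "def batch_stat_model(X):\n", |
| "    # Normalizes with the batch mean and std, so every row depends on the whole batch\n", |
| "    Y = X @ W.T\n", |
| "    return (Y - Y.mean(axis=0)) / (Y.std(axis=0) + 1e-5)\n", |
| "\n", |
| "def evaluate_single(model, x):\n", |
| "    # Same idea as fnet_single: add a batch dimension of 1, then remove it\n", |
| "    return model(x[None])[0]\n", |
| "\n", |
| "for model in (per_sample_model, batch_stat_model):\n", |
| "    one_at_a_time = np.stack([evaluate_single(model, x) for x in X])\n", |
| "    print(model.__name__, np.allclose(one_at_a_time, model(X)))\n", |
| "# per_sample_model True\n", |
| "# batch_stat_model False" |
| ] |
| }, |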
| { |
| "cell_type": "markdown", |
| "id": "62bc6b5a-31fa-411e-8069-e6c1f6d05248", |
| "metadata": {}, |
| "source": [ |
| "## Compute the NTK: method 1 (Jacobian contraction)\n", |
| "\n", |
| "We're ready to compute the empirical NTK. The empirical NTK for two data points $x_1$ and $x_2$ is defined as the matrix product of the Jacobian of the model (with respect to its parameters) evaluated at $x_1$ and the transpose of the Jacobian evaluated at $x_2$:\n", |
| "\n", |
| "$$J_{net}(x_1) J_{net}^T(x_2)$$\n", |
| "\n", |
| "In the batched case, where $x_1$ and $x_2$ are each a batch of data points, we want this product for every pair of data points drawn from $x_1$ and $x_2$.\n", |
| "\n", |
| "The first method does exactly that: it computes the two Jacobians and contracts them. Here's how to compute the NTK in the batched case:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 5, |
| "id": "99a38a4b-64d3-4e13-bd63-2d71e8dd6840", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2):\n", |
| "    # Compute J(x1): one Jacobian per parameter tensor, each of shape [N, O, *param.shape]\n", |
| "    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n", |
| "    # Flatten the parameter dimensions of each Jacobian: [N, O, P_k]\n", |
| "    jac1 = [j.flatten(2) for j in jac1]\n", |
| "\n", |
| "    # Compute J(x2)\n", |
| "    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n", |
| "    jac2 = [j.flatten(2) for j in jac2]\n", |
| "\n", |
| "    # Compute J(x1) @ J(x2).T for each parameter tensor, then sum over parameter tensors\n", |
| "    result = torch.stack([torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2)])\n", |
| "    result = result.sum(0)\n", |
| "    return result" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 6, |
| "id": "cbf54d2b-c4bc-46bd-9e55-e1471d639a4e", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "torch.Size([20, 5, 10, 10])\n" |
| ] |
| } |
| ], |
| "source": [ |
| "result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test)\n", |
| "print(result.shape)" |
| ] |
| }, |
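| { |
| "cell_type": "markdown", |
| "id": "e9a1b7c3-4d52-4e6f-b812-0a7c3e5d9f01", |
| "metadata": {}, |
| "source": [ |
| "`jac1` and `jac2` are lists with one Jacobian per parameter tensor, so the function contracts each pair separately and sums the results. This equals contracting a single Jacobian that spans all $P$ parameters, because a dot product over a concatenated axis splits into a sum of dot products over its pieces. A NumPy sketch with random stand-in Jacobians (the sizes and the split into two tensors are arbitrary assumptions):" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "e9a1b7c3-4d52-4e6f-b812-0a7c3e5d9f02", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import numpy as np\n", |
| "\n", |
| "rng = np.random.default_rng(0)\n", |
| "N, M, O = 4, 3, 2       # batch sizes of x1 and x2, and the output size\n", |
| "param_sizes = [5, 7]    # flattened sizes of two hypothetical parameter tensors\n", |
| "\n", |
| "# Per-parameter Jacobians shaped like the flattened jac1 / jac2 lists: [batch, O, P_k]\n", |
| "jac1 = [rng.standard_normal((N, O, p)) for p in param_sizes]\n", |
| "jac2 = [rng.standard_normal((M, O, p)) for p in param_sizes]\n", |
| "\n", |
| "# Contract each parameter tensor separately, then sum (as in the function above)\n", |
| "ntk_summed = sum(np.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac1, jac2))\n", |
| "\n", |
| "# Reference: concatenate all parameters into one [batch, O, P] Jacobian and contract once\n", |
| "ntk_concat = np.einsum('Naf,Mbf->NMab', np.concatenate(jac1, axis=2), np.concatenate(jac2, axis=2))\n", |
| "\n", |
| "print(ntk_summed.shape)                     # (4, 3, 2, 2)\n", |
| "print(np.allclose(ntk_summed, ntk_concat))  # True" |
| ] |
| }, |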
| { |
| "cell_type": "markdown", |
| "id": "ea844f45-98fb-4cba-8056-644292b968ab", |
| "metadata": {}, |
| "source": [ |
| "In some cases, you may only want the diagonal or the trace of each $O \\times O$ block of this quantity, especially if you know beforehand that the network architecture results in an NTK whose off-diagonal elements can be approximated by zero. It's easy to adjust the above function to do that:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 7, |
| "id": "aae760c9-e906-4fda-b490-1126a86b7e96", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def empirical_ntk_jacobian_contraction(fnet_single, params, x1, x2, compute='full'):\n", |
| "    # Compute J(x1)\n", |
| "    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)\n", |
| "    jac1 = [j.flatten(2) for j in jac1]\n", |
| "\n", |
| "    # Compute J(x2)\n", |
| "    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)\n", |
| "    jac2 = [j.flatten(2) for j in jac2]\n", |
| "\n", |
| "    # Compute J(x1) @ J(x2).T, or only its trace or diagonal over the output dimensions\n", |
| "    if compute == 'full':\n", |
| "        einsum_expr = 'Naf,Mbf->NMab'\n", |
| "    elif compute == 'trace':\n", |
| "        einsum_expr = 'Naf,Maf->NM'\n", |
| "    elif compute == 'diagonal':\n", |
| "        einsum_expr = 'Naf,Maf->NMa'\n", |
| "    else:\n", |
| "        raise ValueError(f'Expected compute to be full, trace or diagonal; got {compute!r}')\n", |
| "\n", |
| "    result = torch.stack([torch.einsum(einsum_expr, j1, j2) for j1, j2 in zip(jac1, jac2)])\n", |
| "    result = result.sum(0)\n", |
| "    return result" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 8, |
| "id": "42d974f3-1f9d-4953-8677-5ee22cfc67eb", |
| "metadata": {}, |
| "outputs": [ |
| { |
| "name": "stdout", |
| "output_type": "stream", |
| "text": [ |
| "torch.Size([20, 5])\n" |
| ] |
| } |
| ], |
| "source": [ |
| "result = empirical_ntk_jacobian_contraction(fnet_single, params, x_train, x_test, 'trace')\n", |
| "print(result.shape)" |
| ] |
| }, |
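| { |
| "cell_type": "markdown", |
| "id": "a6f2c8d4-1b73-4c05-9d3e-2f8b6a4c7e01", |
| "metadata": {}, |
| "source": [ |
| "The `trace` and `diagonal` expressions never form the $O \\times O$ blocks. As a NumPy sketch with random stand-in Jacobians (the sizes are arbitrary assumptions), they agree with taking the diagonal or trace of the full result:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "a6f2c8d4-1b73-4c05-9d3e-2f8b6a4c7e02", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import numpy as np\n", |
| "\n", |
| "rng = np.random.default_rng(1)\n", |
| "N, M, O, P = 4, 3, 2, 6  # arbitrary sizes for illustration\n", |
| "J1 = rng.standard_normal((N, O, P))  # stands in for a flattened J(x1)\n", |
| "J2 = rng.standard_normal((M, O, P))  # stands in for a flattened J(x2)\n", |
| "\n", |
| "ntk_full = np.einsum('Naf,Mbf->NMab', J1, J2)  # [N, M, O, O]\n", |
| "ntk_diag = np.einsum('Naf,Maf->NMa', J1, J2)   # [N, M, O]\n", |
| "ntk_trace = np.einsum('Naf,Maf->NM', J1, J2)   # [N, M]\n", |
| "\n", |
| "print(np.allclose(ntk_diag, np.diagonal(ntk_full, axis1=2, axis2=3)))  # True\n", |
| "print(np.allclose(ntk_trace, np.trace(ntk_full, axis1=2, axis2=3)))    # True" |
| ] |
| }, |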
| { |
| "cell_type": "markdown", |
| "id": "6c941e5d-51d7-47b2-80ee-edcd4aee6aaa", |
| "metadata": {}, |
| "source": [ |
| "The asymptotic time complexity of this method is $N O [FP]$ (time to compute the Jacobians) plus $N^2 O^2 P$ (time to contract the Jacobians), where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, $P$ is the total number of parameters, and $[FP]$ is the cost of a single forward pass through the model. See section 3.2 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details." |
| ] |
| }, |
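| { |
| "cell_type": "markdown", |
| "id": "b3e7d1f9-6c24-4a8b-8f50-9d2e4c1a6b01", |
| "metadata": {}, |
| "source": [ |
| "To get a feel for the contraction term, here is a back-of-the-envelope count for the CNN above in pure Python. The helper functions are hypothetical, written only for this illustration. With different batch sizes $N_1$ and $N_2$ for $x_1$ and $x_2$, the $N^2 O^2 P$ term becomes $N_1 N_2 O^2 P$ multiply-adds; the count ignores constant factors and the cost of computing the Jacobians. In the notebook, `sum(p.numel() for p in net.parameters())` should give the same parameter count." |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "b3e7d1f9-6c24-4a8b-8f50-9d2e4c1a6b02", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def conv2d_params(in_channels, out_channels, kernel_size):\n", |
| "    # Weight of shape [out, in, k, k] plus one bias per output channel\n", |
| "    return out_channels * in_channels * kernel_size * kernel_size + out_channels\n", |
| "\n", |
| "def linear_params(in_features, out_features):\n", |
| "    # Weight of shape [out, in] plus one bias per output feature\n", |
| "    return out_features * in_features + out_features\n", |
| "\n", |
| "n_params = (conv2d_params(3, 32, 3) + conv2d_params(32, 32, 3)\n", |
| "            + conv2d_params(32, 32, 3) + linear_params(21632, 10))\n", |
| "\n", |
| "n_x1, n_x2, n_out = 20, 5, 10  # sizes of x_train, x_test and the model output\n", |
| "contraction_macs = n_x1 * n_x2 * n_out ** 2 * n_params\n", |
| "\n", |
| "print(n_params)          # 235722\n", |
| "print(contraction_macs)  # 2357220000" |
| ] |
| }, |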
| { |
| "cell_type": "markdown", |
| "id": "6c931e5d-51d7-47b2-80ee-ddcd4aee6aaa", |
| "metadata": {}, |
| "source": [ |
| "## Compute the NTK: method 2 (NTK-vector products)\n", |
| "\n", |
| "The next method we will discuss is a way to compute the NTK using NTK-vector products.\n", |
| "\n", |
| "This method reformulates the NTK as a stack of NTK-vector products applied to the columns of an identity matrix $I_O$ of size $O\\times O$ (where $O$ is the output size of the model):\n", |
| "\n", |
| "$$J_{net}(x_1) J_{net}^T(x_2) = J_{net}(x_1) J_{net}^T(x_2) I_{O} = \\left[J_{net}(x_1) \\left[J_{net}^T(x_2) e_o\\right]\\right]_{o=1}^{O},$$\n", |
| "where $e_o\\in \\mathbb{R}^O$ are column vectors of the identity matrix $I_O$.\n", |
| "\n", |
| "- Let $\\textrm{vjp}_o = J_{net}^T(x_2) e_o$. We can use a vector-Jacobian product to compute this.\n", |
| "- Now, consider $J_{net}(x_1) \\textrm{vjp}_o$. This is a Jacobian-vector product!\n", |
| "- Finally, we can run the above computation in parallel over all columns $e_o$ of $I_O$ using `vmap`.\n", |
| "\n", |
| "This suggests that we can use a combination of reverse-mode AD (to compute the vector-Jacobian product) and forward-mode AD (to compute the Jacobian-vector product) to compute the NTK.\n", |
| "\n", |
| "Let's code that up:" |
| ] |
| }, |
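| { |
| "cell_type": "markdown", |
| "id": "d5c9a2e7-8f13-4b6c-a4d1-7e3b9f2c5d01", |
| "metadata": {}, |
| "source": [ |
| "The identity above can be checked with explicit matrices. In this NumPy sketch, random matrices stand in for $J_{net}(x_1)$ and $J_{net}(x_2)$ (the sizes are arbitrary assumptions); autodiff never materializes these Jacobians, but the algebra is the same. Each column $o$ of the NTK is a Jacobian-vector product applied to a vector-Jacobian product:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "d5c9a2e7-8f13-4b6c-a4d1-7e3b9f2c5d02", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import numpy as np\n", |
| "\n", |
| "rng = np.random.default_rng(2)\n", |
| "O, P = 3, 8                         # output size and parameter count (arbitrary)\n", |
| "J_x1 = rng.standard_normal((O, P))  # stands in for J_net(x_1)\n", |
| "J_x2 = rng.standard_normal((O, P))  # stands in for J_net(x_2)\n", |
| "\n", |
| "I_O = np.eye(O)\n", |
| "columns = []\n", |
| "for o in range(O):\n", |
| "    vjp_o = J_x2.T @ I_O[:, o]    # vector-Jacobian product J^T(x_2) e_o, shape [P]\n", |
| "    columns.append(J_x1 @ vjp_o)  # Jacobian-vector product J(x_1) vjp_o, shape [O]\n", |
| "ntk_from_products = np.stack(columns, axis=1)  # place the results as columns\n", |
| "\n", |
| "print(np.allclose(ntk_from_products, J_x1 @ J_x2.T))  # True" |
| ] |
| }, |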
| { |
| "cell_type": "code", |
| "execution_count": 9, |
| "id": "dc4b49d7-3096-45d5-a7a1-7032309a2613", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):\n", |
| "    def get_ntk(x1, x2):\n", |
| "        def func_x1(params):\n", |
| "            return func(params, x1)\n", |
| "\n", |
| "        def func_x2(params):\n", |
| "            return func(params, x2)\n", |
| "\n", |
| "        output, vjp_fn = vjp(func_x1, params)\n", |
| "\n", |
| "        def get_ntk_slice(vec):\n", |
| "            # This computes J(x1).T @ vec\n", |
| "            # `vec` is some unit vector (a single slice of the identity matrix)\n", |
| "            vjps = vjp_fn(vec)\n", |
| "            # This computes J(x2) @ vjps\n", |
| "            _, jvps = jvp(func_x2, (params,), vjps)\n", |
| "            return jvps\n", |
| "\n", |
| "        # Here's our identity matrix\n", |
| "        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)\n", |
| "        # Slice o is J(x2) @ J(x1).T @ e_o, which is row o of J(x1) @ J(x2).T,\n", |
| "        # so stacking the slices along dim 0 yields J(x1) @ J(x2).T\n", |
| "        return vmap(get_ntk_slice)(basis)\n", |
| "\n", |
| "    # get_ntk(x1, x2) computes the NTK for a single pair of data points x1, x2.\n", |
| "    # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,\n", |
| "    # we actually wish to compute the NTK between every pair of data points\n", |
| "    # between {x1} and {x2}. That's what the vmaps here do.\n", |
| "    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)\n", |
| "\n", |
| "    if compute == 'full':\n", |
| "        return result\n", |
| "    if compute == 'trace':\n", |
| "        return torch.einsum('NMKK->NM', result)\n", |
| "    if compute == 'diagonal':\n", |
| "        return torch.einsum('NMKK->NMK', result)\n", |
| "    raise ValueError(f'Expected compute to be full, trace or diagonal; got {compute!r}')" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": 10, |
| "id": "f750544f-9e48-47fe-9f9b-e1b8ae49b245", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "result_from_jacobian_contraction = empirical_ntk_jacobian_contraction(fnet_single, params, x_test, x_train)\n", |
| "result_from_ntk_vps = empirical_ntk_ntk_vps(fnet_single, params, x_test, x_train)\n", |
| "assert torch.allclose(result_from_jacobian_contraction, result_from_ntk_vps, atol=1e-5)" |
| ] |
| }, |
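| { |
| "cell_type": "markdown", |
| "id": "f1b4e6a8-3c97-4d2e-b6a5-4c8d2e7f1a01", |
| "metadata": {}, |
| "source": [ |
| "A model with a known closed-form NTK gives an independent check. For a single linear layer $f(x) = W x + b$, we have $\\partial f_a / \\partial W_{cj} = \\delta_{ac} x_j$ and $\\partial f_a / \\partial b_c = \\delta_{ac}$, so $J(x_1) J^T(x_2) = (x_1^T x_2 + 1) I_O$. The sketch below reimplements the NTK-vector-product method with `torch.func` (assuming PyTorch 2.x, where `torch.func.functional_call` replaces `make_functional`) on a hypothetical `nn.Linear(4, 3)` and compares the result with that formula:" |
| ] |
| }, |
| { |
| "cell_type": "code", |
| "execution_count": null, |
| "id": "f1b4e6a8-3c97-4d2e-b6a5-4c8d2e7f1a02", |
| "metadata": {}, |
| "outputs": [], |
| "source": [ |
| "import torch\n", |
| "import torch.func\n", |
| "import torch.nn as nn\n", |
| "\n", |
| "lin = nn.Linear(4, 3)  # hypothetical toy model with a closed-form NTK\n", |
| "lin_params = {name: p.detach() for name, p in lin.named_parameters()}\n", |
| "\n", |
| "def lin_single(params, x):\n", |
| "    # Evaluate the layer on a single data point with explicit parameters\n", |
| "    return torch.func.functional_call(lin, params, (x.unsqueeze(0),)).squeeze(0)\n", |
| "\n", |
| "def lin_ntk_single(params, x1, x2):\n", |
| "    out, vjp_fn = torch.func.vjp(lambda p: lin_single(p, x1), params)\n", |
| "\n", |
| "    def ntk_slice(vec):\n", |
| "        vjps = vjp_fn(vec)  # J(x1).T @ vec, structured like params (as a 1-tuple)\n", |
| "        _, jvps = torch.func.jvp(lambda p: lin_single(p, x2), (params,), vjps)\n", |
| "        return jvps  # J(x2) @ J(x1).T @ vec\n", |
| "\n", |
| "    basis = torch.eye(out.numel(), dtype=out.dtype)\n", |
| "    return torch.func.vmap(ntk_slice)(basis)\n", |
| "\n", |
| "# Pair every row of lin_x1 with every row of lin_x2\n", |
| "lin_x1, lin_x2 = torch.randn(5, 4), torch.randn(2, 4)\n", |
| "lin_ntk = torch.func.vmap(torch.func.vmap(lin_ntk_single, (None, None, 0)), (None, 0, None))(\n", |
| "    lin_params, lin_x1, lin_x2)\n", |
| "\n", |
| "# Closed form: (x1 . x2 + 1) times the 3x3 identity for every pair\n", |
| "lin_expected = (lin_x1 @ lin_x2.T + 1)[:, :, None, None] * torch.eye(3)\n", |
| "print(lin_ntk.shape)  # torch.Size([5, 2, 3, 3])\n", |
| "print(torch.allclose(lin_ntk, lin_expected, atol=1e-5))  # True" |
| ] |
| }, |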
| { |
| "cell_type": "markdown", |
| "id": "84253466-971d-4475-999c-fe3de6bd25b5", |
| "metadata": {}, |
| "source": [ |
| "Our code for `empirical_ntk_ntk_vps` looks like a direct translation from the math above! This showcases the power of function transforms: good luck trying to write an efficient version of the above using stock PyTorch.\n", |
| "\n", |
| "The asymptotic time complexity of this method is $N^2 O [FP]$, where $N$ is the batch size of $x_1$ and $x_2$, $O$ is the model's output size, and $[FP]$ is the cost of a single forward pass through the model. Hence this method performs more forward passes through the network than method 1 (Jacobian contraction): $N^2 O$ instead of $N O$. In exchange, it avoids the contraction cost altogether (there is no $N^2 O^2 P$ term, where $P$ is the total number of the model's parameters). Therefore, this method is preferable when $O P$ is large relative to $[FP]$, such as for fully-connected (not convolutional) models with many outputs $O$. Memory-wise, both methods should be comparable. See section 3.3 in [Fast Finite Width Neural Tangent Kernel](https://arxiv.org/abs/2206.08720) for details." |
| ] |
| } |
| ], |
| "metadata": { |
| "kernelspec": { |
| "display_name": "Python 3 (ipykernel)", |
| "language": "python", |
| "name": "python3" |
| }, |
| "language_info": { |
| "codemirror_mode": { |
| "name": "ipython", |
| "version": 3 |
| }, |
| "file_extension": ".py", |
| "mimetype": "text/x-python", |
| "name": "python", |
| "nbconvert_exporter": "python", |
| "pygments_lexer": "ipython3", |
| "version": "3.9.7" |
| } |
| }, |
| "nbformat": 4, |
| "nbformat_minor": 5 |
| } |