# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch
import torch.nn as nn
from functorch import make_functional
from functorch.compile import nnc_jit
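
# Allow the TorchScript fuser to fuse pointwise ops on CPU.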
torch._C._jit_override_can_fuse_on_cpu(True)
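

# Timing helper: warm up, then report the total time for `iters` calls.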
def bench(f, iters=100, warmup=10):
    for _ in range(warmup):
        f()
    begin = time.time()
    for _ in range(iters):
        f()
    print(time.time() - begin)
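

# A stack of bias-free Linear layers whose output is reduced to a scalar
# (sum of squares), so it can be used directly as a loss.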
class Foo(nn.Module):
    def __init__(self, num_layers=3, features=100):
        super().__init__()
        mods = []
        for _ in range(num_layers):
            mods.append(nn.Linear(features, features, bias=False))
        self.mod = nn.Sequential(*mods)

    def forward(self, x):
        return (self.mod(x) ** 2).sum()


batch_size = 16
features = 64
num_layers = 8
inp = torch.randn((batch_size, features))

mod = Foo(num_layers, features)
jit_mod = torch.jit.script(mod)
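
# make_functional splits the module into a pure function and a list of
# parameters, so a whole training step (forward, backward, update) can be
# traced and compiled as one function.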
func_model, weights = make_functional(mod)
lr = 1.0
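

# One hand-written SGD step in functional style: gradients flow into fresh
# leaf tensors and the updated weights are returned rather than mutated,
# so the whole step can be traced by nnc_jit.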
def functional_step(x, weights):
    # Detach so each step builds its autograd graph from new leaf tensors.
    weights = [weight.detach().requires_grad_() for weight in weights]
    out = func_model(weights, x)
    out.backward()
    # SGD update: w <- w - lr * grad(w)
    new_weights = [weight - lr * weight.grad for weight in weights]
    return out, new_weights
optim = torch.optim.SGD(
    jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0
)
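

# Baseline: the scripted module trained with torch.optim.SGD, which updates
# the parameters in place.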
def jit_step(x, weights):
    optim.zero_grad()
    loss = jit_mod(x)
    loss.backward()
    optim.step()
    return loss, None
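

# Run 1000 steps with a fixed seed. The first call happens outside the timed
# region, so it doubles as a warmup (and, for the NNC variant, covers
# compilation).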
def train(train_step, weights):
    torch.manual_seed(16)
    train_step(inp, weights)
    begin = time.time()
    for itr in range(1000):
        loss, weights = train_step(torch.randn(batch_size, features), weights)
        if itr % 200 == 0:
            print(f"Loss at {itr}: {loss}")
    print("Time taken: ", time.time() - begin)
    print()
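

# grad_pt runs the functional step eagerly; grad_nnc compiles the same step,
# backward pass included, with NNC.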
grad_pt = functional_step
grad_nnc = nnc_jit(functional_step)

print("Starting PT training")
train(grad_pt, weights)

print("Starting NNC training")
train(grad_nnc, weights)

print("Starting JIT training")
train(jit_step, None)