# blob: 252fb3dcb8d2f3982fe9600562476e422cc42cfc [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import os
import sys
import unittest
from torch.testing._internal.common_utils import (
enable_profiling_mode_for_profiling_tests, GRAPH_EXECUTOR, ProfilingMode,
set_default_dtype,
)
import torch
import torch.nn as nn
import torch.nn.functional as F
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
from torch.testing._internal.common_utils import slowTest, suppress_warnings
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
# This module only defines test cases. Executing it as a script is refused and
# the message points at the test/test_jit.py entry point instead.
if __name__ == '__main__':
    _direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_direct_run_message)
# Probe for the optional torchvision dependency. An ImportError or a
# RuntimeError raised while importing it both count as "not available".
try:
    import torchvision
except (ImportError, RuntimeError):
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

# Decorator that skips a test when the probe above failed.
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
class MnistNet(nn.Module):
    """Small convolutional network used by the MNIST tracing tests.

    Two 5x5 convolutions, each followed by 2x2 max-pooling and ReLU, feed two
    linear layers. ``forward`` returns ``log_softmax`` over dim 1 of the 10
    outputs of ``fc2``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # First conv stage: conv -> pool -> relu.
        pooled = F.max_pool2d(self.conv1(x), 2)
        features = F.relu(pooled)
        # Second conv stage, with channel dropout applied before pooling.
        pooled = F.max_pool2d(self.conv2_drop(self.conv2(features)), 2)
        features = F.relu(pooled)
        # Flatten to rows of 320 values, the input width of fc1.
        flat = features.reshape(-1, 320)
        hidden = F.relu(self.fc1(flat))
        # Dropout is only active while the module is in training mode.
        hidden = F.dropout(hidden, training=self.training)
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)
class TestModels(JitTestCase):
@staticmethod
def _test_dcgan_models(self, device, check_export_import=True):
    """Pass a DCGAN generator and discriminator to ``self.checkTrace``.

    This is a staticmethod, so callers hand in the test case instance
    explicitly as ``self``.

    Args:
        self: object providing ``checkTrace`` (the test case instance).
        device: device the models are moved to and the inputs are created on.
        check_export_import: forwarded to ``checkTrace`` as ``export_import``.
    """
    class DCGANGenerator(nn.Module):
        """ConvTranspose2d/BatchNorm2d/ReLU stages followed by a Tanh output."""

        def __init__(self, nz, ngf, nc):
            super().__init__()
            # (in_channels, out_channels, stride, padding) per stage; the
            # input is Z. State sizes after each stage (per the original
            # notes): (ngf*8) x 4 x 4, (ngf*4) x 8 x 8, (ngf*2) x 16 x 16,
            # (ngf) x 32 x 32.
            stages = (
                (nz, ngf * 8, 1, 0),
                (ngf * 8, ngf * 4, 2, 1),
                (ngf * 4, ngf * 2, 2, 1),
                (ngf * 2, ngf, 2, 1),
            )
            layers = []
            for c_in, c_out, stride, padding in stages:
                layers.append(nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False))
                layers.append(nn.BatchNorm2d(c_out))
                layers.append(nn.ReLU(True))
            # Output stage; state size. (nc) x 64 x 64
            layers.append(nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False))
            layers.append(nn.Tanh())
            self.main = nn.Sequential(*layers)

        def forward(self, input):
            return self.main(input)

    class DCGANDiscriminator(nn.Module):
        """Conv2d/LeakyReLU stack ending in a Sigmoid, flattened to 1-D."""

        def __init__(self, nc, ndf):
            super().__init__()
            # input is (nc) x 64 x 64; the first stage has no BatchNorm.
            layers = [
                nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
            ]
            # State sizes after each stage (per the original notes):
            # (ndf*2) x 16 x 16, (ndf*4) x 8 x 8, (ndf*8) x 4 x 4.
            for c_in, c_out in ((ndf, ndf * 2), (ndf * 2, ndf * 4), (ndf * 4, ndf * 8)):
                layers.append(nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False))
                layers.append(nn.BatchNorm2d(c_out))
                layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False))
            layers.append(nn.Sigmoid())
            self.main = nn.Sequential(*layers)

        def forward(self, input):
            scores = self.main(input)
            return scores.view(-1, 1).squeeze(1)

    bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10

    generator = DCGANGenerator(nz, ngf, nc).to(device)
    latent = torch.rand(bs, nz, 1, 1, device=device)
    self.checkTrace(generator, (latent,), export_import=check_export_import)

    # The discriminator's example input is the output of a second, freshly
    # constructed generator.
    fresh_generator = DCGANGenerator(nz, ngf, nc).to(device)
    example_input = fresh_generator(torch.rand(bs, nz, 1, 1, device=device))
    discriminator = DCGANDiscriminator(nc, ndf).to(device)
    self.checkTrace(discriminator, (example_input,), export_import=check_export_import)
def test_dcgan_models(self):
    """Run the DCGAN helper on the CPU with double as the default dtype."""
    # Note: Can sometimes fail with low precision if run with float dtype
    precise_dtype = torch.double
    with set_default_dtype(precise_dtype):
        self._test_dcgan_models(self, device='cpu')
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_dcgan_models_cuda(self):
    """Run the DCGAN helper on CUDA with double as the default dtype."""
    # Note: Can sometimes fail with low precision if run with float dtype
    precise_dtype = torch.double
    with set_default_dtype(precise_dtype):
        # XXX: export_import on CUDA modules doesn't work (#11480)
        helper_kwargs = dict(device='cuda', check_export_import=False)
        self._test_dcgan_models(self, **helper_kwargs)
@staticmethod
def _test_neural_style(self, device, check_export_import=True):
    """Pass the fast-neural-style ``TransformerNet`` to ``self.checkTrace``.

    This is a staticmethod, so callers hand in the test case instance
    explicitly as ``self``.

    Args:
        self: object providing ``checkTrace`` (the test case instance).
        device: device the model is moved to and the example input is
            created on.
        check_export_import: forwarded to ``checkTrace`` as ``export_import``.
    """
    class ConvLayer(torch.nn.Module):
        """Reflection padding of ``kernel_size // 2`` followed by a Conv2d."""

        def __init__(self, in_channels, out_channels, kernel_size, stride):
            super().__init__()
            reflection_padding = kernel_size // 2
            self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
            self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

        def forward(self, x):
            out = self.reflection_pad(x)
            out = self.conv2d(out)
            return out

    class ResidualBlock(torch.nn.Module):
        """ResidualBlock
        introduced in: https://arxiv.org/abs/1512.03385
        recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
        """

        def __init__(self, channels):
            super().__init__()
            self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
            self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
            self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
            self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            residual = x
            out = self.relu(self.in1(self.conv1(x)))
            out = self.in2(self.conv2(out))
            # Skip connection: add the block input back onto its output.
            out = out + residual
            return out

    class UpsampleConvLayer(torch.nn.Module):
        """UpsampleConvLayer
        Upsamples the input and then does a convolution. This method gives better results
        compared to ConvTranspose2d.
        ref: http://distill.pub/2016/deconv-checkerboard/
        """

        def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
            super().__init__()
            self.upsample = upsample
            # The upsampling layer only exists when a scale factor was given.
            if upsample:
                self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
            reflection_padding = kernel_size // 2
            self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
            self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)

        def forward(self, x):
            x_in = x
            if self.upsample:
                x_in = self.upsample_layer(x_in)
            out = self.reflection_pad(x_in)
            out = self.conv2d(out)
            return out

    class TransformerNet(torch.nn.Module):
        """Three conv stages, five residual blocks, then three up/out stages."""

        def __init__(self):
            super().__init__()
            # Initial convolution layers
            self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
            self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
            self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
            self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
            self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
            self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
            # Residual layers
            self.res1 = ResidualBlock(128)
            self.res2 = ResidualBlock(128)
            self.res3 = ResidualBlock(128)
            self.res4 = ResidualBlock(128)
            self.res5 = ResidualBlock(128)
            # Upsampling Layers
            self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
            self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
            self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
            self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
            self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
            # Non-linearities
            self.relu = torch.nn.ReLU()

        def forward(self, X):
            y = self.relu(self.in1(self.conv1(X)))
            y = self.relu(self.in2(self.conv2(y)))
            y = self.relu(self.in3(self.conv3(y)))
            y = self.res1(y)
            y = self.res2(y)
            y = self.res3(y)
            y = self.res4(y)
            y = self.res5(y)
            y = self.relu(self.in4(self.deconv1(y)))
            y = self.relu(self.in5(self.deconv2(y)))
            y = self.deconv3(y)
            return y

    # Place both the model and its example input on the requested device.
    # Previously `device` was ignored, so the CUDA variant of this test ran
    # entirely on the CPU.
    model = TransformerNet().to(device)
    example = torch.rand(5, 3, 16, 16, device=device)
    self.checkTrace(model, (example,), export_import=check_export_import)
@slowTest
def test_neural_style(self):
    """Run the neural-style helper with ``device='cpu'``."""
    target_device = 'cpu'
    self._test_neural_style(self, device=target_device)
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_neural_style_cuda(self):
    """Run the neural-style helper with ``device='cuda'``, without export/import."""
    # XXX: export_import on CUDA modules doesn't work (#11480)
    helper_kwargs = dict(device='cuda', check_export_import=False)
    self._test_neural_style(self, **helper_kwargs)
# `staticmethod` is applied last (outermost) so that `unittest.skipIf` wraps
# the plain function rather than the staticmethod descriptor.
@staticmethod
@unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.LEGACY, "Bug found in deprecated executor")
def _test_mnist(self, device, check_export_import=True):
    """Pass an eval-mode ``MnistNet`` on ``device`` to ``self.checkTrace``.

    This is a staticmethod, so callers hand in the test case instance
    explicitly as ``self``. The helper is skipped when ``GRAPH_EXECUTOR`` is
    ``ProfilingMode.LEGACY``.

    Args:
        self: object providing ``checkTrace`` (the test case instance).
        device: device the model is moved to and the input is created on.
        check_export_import: forwarded to ``checkTrace`` as ``export_import``.
    """
    # eval() is present because dropout makes this nondeterministic
    with enable_profiling_mode_for_profiling_tests():
        self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
                        export_import=check_export_import)
def test_mnist(self):
    """Run the MNIST helper with ``device='cpu'``."""
    target_device = 'cpu'
    self._test_mnist(self, device=target_device)
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_mnist_cuda(self):
    """Run the MNIST helper with ``device='cuda'``, without export/import."""
    # XXX: export_import on CUDA modules doesn't work (#11480)
    helper_kwargs = dict(device='cuda', check_export_import=False)
    self._test_mnist(self, **helper_kwargs)
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_mnist_training_leaks_no_memory_cuda(self):
    """Run forward/backward steps on a traced ``MnistNet`` under the CUDA leak check.

    One warm-up step is run first; five more steps then run inside
    ``self.assertLeaksNoCudaTensors()``.
    """
    net = MnistNet().cuda()
    example_batch = [torch.randn(5, 1, 28, 28, device='cuda')]
    # MnistNet uses dropout, don't check its trace
    traced_net = torch.jit.trace(net, example_batch, check_trace=False)
    # Drop the tracing example before any leak accounting starts, matching the
    # original, which never bound it to a local name.
    del example_batch

    def run_steps(num_steps):
        step = 0
        while step < num_steps:
            # Get some fake data
            inp = torch.randn(5, 1, 28, 28, device='cuda')
            out = traced_net(inp)
            # Here's some fake loss
            out.sum().backward()
            # Zero out grads
            traced_net.zero_grad()
            step += 1

    # Set it up so the params have .grad fields so they are not reported as leaks
    run_steps(1)
    with self.assertLeaksNoCudaTensors():
        run_steps(5)
@staticmethod
def _test_reinforcement_learning(self, device, test_export_import=True):
    """Pass a two-layer policy network on ``device`` to ``self.checkTrace``.

    This is a staticmethod, so callers hand in the test case instance
    explicitly as ``self``.

    Args:
        self: object providing ``checkTrace`` (the test case instance).
        device: device the model is moved to and the input is created on.
        test_export_import: forwarded to ``checkTrace`` as ``export_import``.
    """
    class Policy(nn.Module):
        """Linear(4, 128) -> ReLU -> Linear(128, 2) -> softmax over dim 1."""

        def __init__(self):
            super().__init__()
            self.affine1 = nn.Linear(4, 128)
            self.affine2 = nn.Linear(128, 2)

        def forward(self, x):
            hidden = F.relu(self.affine1(x))
            return F.softmax(self.affine2(hidden), dim=1)

    with enable_profiling_mode_for_profiling_tests():
        policy = Policy().to(device)
        observation = torch.rand(1, 4, device=device)
        self.checkTrace(policy, (observation,), export_import=test_export_import)
def test_reinforcement_learning(self):
    """Run the reinforcement-learning helper with ``device='cpu'``."""
    target_device = 'cpu'
    self._test_reinforcement_learning(self, device=target_device)
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_reinforcement_learning_cuda(self):
    """Run the reinforcement-learning helper with ``device='cuda'``, without export/import."""
    # XXX: export_import on CUDA modules doesn't work (#11480)
    helper_kwargs = dict(device='cuda', test_export_import=False)
    self._test_reinforcement_learning(self, **helper_kwargs)
@staticmethod
def _test_snli(self, device, check_export_import=True, quantized=False):
    """Pass an SNLI classifier (embedding, projection, LSTM, MLP) to ``self.checkTrace``.

    This is a staticmethod, so callers hand in the test case instance
    explicitly as ``self``.

    Args:
        self: object providing ``checkTrace`` (the test case instance).
        device: device the inputs (and, when not quantized, the model) are
            moved to.
        check_export_import: forwarded to ``checkTrace`` as ``export_import``
            on the non-quantized path.
        quantized: when true, the model is kept on the CPU, passed through
            ``torch.jit.quantized.quantize_linear_modules`` and checked with
            ``export_import=False``.
    """
    class Bottle(nn.Module):
        """Mixin: inputs with more than two dims are flattened over their
        first two dims before the next ``forward`` in the MRO is applied,
        and the result is reshaped back afterwards."""

        def forward(self, input):
            if len(input.size()) <= 2:
                return super().forward(input)
            lead = input.size()[:2]
            flat_out = super().forward(input.view(lead[0] * lead[1], -1))
            return flat_out.view(lead[0], lead[1], -1)

    class Linear(Bottle, nn.Linear):
        pass

    class Encoder(nn.Module):
        """LSTM encoder returning the final hidden state(s) per batch element."""

        def __init__(self, config):
            super().__init__()
            self.config = config
            if config.projection:
                input_size = config.d_proj
            else:
                input_size = config.d_embed
            # A single-layer LSTM is built without dropout.
            if config.n_layers == 1:
                dropout = 0
            else:
                dropout = config.dp_ratio
            self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
                               num_layers=config.n_layers, dropout=dropout,
                               bidirectional=config.birnn)

        def forward(self, inputs):
            batch_size = inputs.size()[1]
            state_shape = self.config.n_cells, batch_size, self.config.d_hidden
            # The same zero tensor serves as both the initial hidden and cell state.
            zero_state = inputs.new_zeros(state_shape)
            _, (ht, _) = self.rnn(inputs, (zero_state, zero_state))
            if self.config.birnn:
                # Concatenate the last two entries of ht per batch element.
                return ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
            return ht[-1]

    class SNLIClassifier(nn.Module):
        """Embeds, optionally projects and encodes both sentences, then scores
        their concatenation with a four-layer MLP."""

        def __init__(self, config):
            super().__init__()
            self.config = config
            self.embed = nn.Embedding(config.n_embed, config.d_embed)
            self.projection = Linear(config.d_embed, config.d_proj)
            self.encoder = Encoder(config)
            self.dropout = nn.Dropout(p=config.dp_ratio)
            self.relu = nn.ReLU()
            seq_in_size = 2 * config.d_hidden
            if self.config.birnn:
                seq_in_size *= 2
            # Three square Linear/ReLU/Dropout stages, then the output layer.
            head = []
            for _ in range(3):
                head.extend([Linear(seq_in_size, seq_in_size), self.relu, self.dropout])
            head.append(Linear(seq_in_size, config.d_out))
            self.out = nn.Sequential(*head)

        def forward(self, premise, hypothesis):
            prem_embed = self.embed(premise)
            hypo_embed = self.embed(hypothesis)
            if self.config.fix_emb:
                prem_embed = prem_embed.detach()
                hypo_embed = hypo_embed.detach()
            if self.config.projection:
                prem_embed = self.relu(self.projection(prem_embed))
                hypo_embed = self.relu(self.projection(hypo_embed))
            prem_code = self.encoder(prem_embed)
            hypo_code = self.encoder(hypo_embed)
            return self.out(torch.cat([prem_code, hypo_code], 1))

    class Config:
        n_embed = 100
        d_embed = 100
        d_proj = 300
        dp_ratio = 0.0  # For deterministic testing TODO: change by fixing seed in checkTrace?
        d_hidden = 30
        birnn = True
        d_out = 300
        fix_emb = True
        projection = True
        n_layers = 2
        n_cells = 4  # 2 * n_layers because birnn = True

    premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
    hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)
    example_inputs = (premise, hypothesis)

    if not quantized:
        self.checkTrace(SNLIClassifier(Config()).to(device), example_inputs,
                        inputs_require_grads=False, export_import=check_export_import)
        return

    snli = SNLIClassifier(Config()).cpu()
    torch.jit.quantized.quantize_linear_modules(snli)
    # we don't do export/import checks because we would need to call
    # _pack/_unpack
    self.checkTrace(snli, example_inputs, inputs_require_grads=False,
                    export_import=False)
@slowTest
def test_snli(self):
    """Run the SNLI helper with ``device='cpu'``."""
    target_device = 'cpu'
    self._test_snli(self, device=target_device)
@skipIfNoFBGEMM
# Suppression: this exercises a deprecated API
@suppress_warnings
def test_snli_quantized(self):
    """Run the SNLI helper on the CPU with ``quantized=True``."""
    helper_kwargs = dict(device='cpu', quantized=True)
    self._test_snli(self, **helper_kwargs)
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_snli_cuda(self):
    """Run the SNLI helper with ``device='cuda'``, without export/import."""
    # XXX: export_import on CUDA modules doesn't work (#11480)
    helper_kwargs = dict(device='cuda', check_export_import=False)
    self._test_snli(self, **helper_kwargs)
@staticmethod
def _test_super_resolution(self, device, check_export_import=True):
    """Pass a four-conv PixelShuffle super-resolution net to ``self.checkTrace``.

    This is a staticmethod, so callers hand in the test case instance
    explicitly as ``self``.

    Args:
        self: object providing ``checkTrace`` (the test case instance).
        device: device the model is moved to and the input is created on.
        check_export_import: forwarded to ``checkTrace`` as ``export_import``.
    """
    class Net(nn.Module):
        """Three ReLU-activated convolutions, a fourth convolution producing
        ``upscale_factor ** 2`` channels, then PixelShuffle."""

        def __init__(self, upscale_factor):
            super().__init__()
            self.relu = nn.ReLU()
            self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
            self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
            self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
            self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
            self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        def forward(self, x):
            for conv in (self.conv1, self.conv2, self.conv3):
                x = self.relu(conv(x))
            return self.pixel_shuffle(self.conv4(x))

    net = Net(upscale_factor=4).to(device)
    low_res = torch.rand(5, 1, 32, 32, device=device)
    self.checkTrace(net, (low_res,), export_import=check_export_import)
@slowTest
def test_super_resolution(self):
    """Run the super-resolution helper with ``device='cpu'``."""
    target_device = 'cpu'
    self._test_super_resolution(self, device=target_device)
@unittest.skipIf(not RUN_CUDA, 'no CUDA')
def test_super_resolution_cuda(self):
    """Run the super-resolution helper with ``device='cuda'``, without export/import."""
    # XXX: export_import on CUDA modules doesn't work (#11480)
    helper_kwargs = dict(device='cuda', check_export_import=False)
    self._test_super_resolution(self, **helper_kwargs)
@suppress_warnings
def test_time_sequence_prediction(self):
    """Pass an eager module wrapping a scripted LSTMCell model to ``self.checkTrace``.

    ``Sequence`` is a ``torch.jit.ScriptModule`` whose ``forward`` is compiled
    with ``script_method``; ``Traced`` is a plain ``nn.Module`` that calls it.
    The check runs with ``export_import=False`` and with emit hooks disabled.
    """
    class Sequence(torch.jit.ScriptModule):
        """Two stacked LSTMCells (1 -> 51 -> 51) followed by Linear(51, 1)."""
        def __init__(self):
            super().__init__()
            self.lstm1 = nn.LSTMCell(1, 51)
            self.lstm2 = nn.LSTMCell(51, 51)
            self.linear = nn.Linear(51, 1)

        @torch.jit.script_method
        def forward(self, input):
            # TODO: add future as input with default val
            # see https://github.com/pytorch/pytorch/issues/8724
            # `outputs` starts with zero columns and gains one column per
            # step through torch.cat along dim 1.
            outputs = torch.empty((3, 0))
            # Hidden and cell states of both cells; the leading 3 is
            # hard-coded here and matches the (3, 4) input used below.
            h_t = torch.zeros((3, 51))
            c_t = torch.zeros((3, 51))
            h_t2 = torch.zeros((3, 51))
            c_t2 = torch.zeros((3, 51))
            # `output` is defined before the loops; the second loop reads it.
            output = torch.zeros([3, 51])
            future = 2
            # TODO: chunk call should appear as the for loop iterable
            # We hard-code it to 4 for now.
            # Unpacking into exactly four names means the input must split
            # into four chunks along dim 1.
            a, b, c, d = input.chunk(input.size(1), dim=1)
            for input_t in (a, b, c, d):
                h_t, c_t = self.lstm1(input_t, (h_t, c_t))
                h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                output = self.linear(h_t2)
                outputs = torch.cat((outputs, output), 1)
            # Feed the previous output back in for `future` extra steps.
            for _ in range(future):  # if we should predict the future
                h_t, c_t = self.lstm1(output, (h_t, c_t))
                h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
                output = self.linear(h_t2)
                outputs = torch.cat((outputs, output), 1)
            return outputs

    class Traced(nn.Module):
        """Eager wrapper whose forward delegates to the scripted ``Sequence``."""
        def __init__(self):
            super().__init__()
            self.seq = Sequence()

        def forward(self, input):
            return self.seq.forward(input)

    # disabled due to a jitter issue that will be fixed by using load/store in the compiler
    with torch._jit_internal._disable_emit_hooks():
        # TODO: toggle export_import once above issues are fixed
        self.checkTrace(Traced(), (torch.rand(3, 4),),
                        export_import=False)
@staticmethod
def _test_vae(self, device, check_export_import=True, quantized=False):
    """Pass an eval-mode variational autoencoder on ``device`` to ``self.checkTrace``.

    This is a staticmethod, so callers hand in the test case instance
    explicitly as ``self``.

    Args:
        self: object providing ``checkTrace`` (the test case instance).
        device: device the model is moved to and the input is created on.
        check_export_import: forwarded to ``checkTrace`` as ``export_import``
            on the non-quantized path.
        quantized: when true, the model is passed through
            ``torch.jit.quantized.quantize_linear_modules`` and checked with
            ``export_import=False``.
    """
    class VAE(nn.Module):
        """784 -> 400 -> (20, 20) encoder and 20 -> 400 -> 784 decoder."""

        def __init__(self):
            super().__init__()
            self.fc1 = nn.Linear(784, 400)
            self.fc21 = nn.Linear(400, 20)
            self.fc22 = nn.Linear(400, 20)
            self.fc3 = nn.Linear(20, 400)
            self.fc4 = nn.Linear(400, 784)

        def encode(self, x):
            hidden = F.relu(self.fc1(x))
            mu = self.fc21(hidden)
            logvar = self.fc22(hidden)
            return mu, logvar

        def reparameterize(self, mu, logvar):
            # Outside training mode the mean is returned unchanged.
            if not self.training:
                return mu
            std = torch.exp(0.5 * logvar)
            noise = torch.randn_like(std)
            return noise.mul(std).add_(mu)

        def decode(self, z):
            hidden = F.relu(self.fc3(z))
            return torch.sigmoid(self.fc4(hidden))

        def forward(self, x):
            mu, logvar = self.encode(x.view(-1, 784))
            latent = self.reparameterize(mu, logvar)
            reconstruction = self.decode(latent)
            return reconstruction, mu, logvar

    if not quantized:
        with enable_profiling_mode_for_profiling_tests():
            # eval() is present because randn_like makes this nondeterministic
            model = VAE().to(device).eval()
            batch = torch.rand(128, 1, 28, 28, device=device)
            self.checkTrace(model, (batch,), export_import=check_export_import)
        return

    vae = VAE().to(device).eval()
    torch.jit.quantized.quantize_linear_modules(vae)
    # We don't do export/import checks because we would need to call
    # _unpack and _pack
    batch = torch.rand(128, 1, 28, 28, device=device)
    self.checkTrace(vae, (batch,),
                    export_import=False, allow_unused=True,
                    inputs_require_grads=False)
def test_vae(self):
    """Run the VAE helper with ``device='cpu'``."""
    target_device = 'cpu'
    self._test_vae(self, device=target_device)
@skipIfNoFBGEMM
# Suppression: this exercises a deprecated API
@suppress_warnings
def test_vae_quantized(self):
    """Run the VAE helper on the CPU with ``quantized=True``."""
    helper_kwargs = dict(device='cpu', quantized=True)
    self._test_vae(self, **helper_kwargs)
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_vae_cuda(self):
    """Run the VAE helper with ``device='cuda'``, without export/import."""
    # XXX: export_import on CUDA modules doesn't work (#11480)
    helper_kwargs = dict(device='cuda', check_export_import=False)
    self._test_vae(self, **helper_kwargs)
@slowTest
@skipIfNoTorchVision
def test_script_module_trace_resnet18(self):
    """Compare a traced torchvision resnet18 with its export/import copy.

    Both modules are run on the same random input; the outputs and the
    gradients with respect to that input must be equal.
    """
    # Example input used for tracing (previously an unused local, with an
    # identical tensor rebuilt inline for the trace call).
    x = torch.ones(1, 3, 224, 224)
    m_orig = torch.jit.trace(torchvision.models.resnet18(), x)
    m_import = self.getExportImportCopy(m_orig)

    input = torch.randn(1, 3, 224, 224, requires_grad=True)

    output_orig = m_orig(input)
    output_orig.sum().backward()
    grad_orig = input.grad.clone()
    # Reset the accumulated gradient so the second backward pass is measured
    # on its own.
    input.grad.zero_()

    output_import = m_import(input)
    output_import.sum().backward()
    grad_import = input.grad.clone()

    self.assertEqual(output_orig, output_import)
    self.assertEqual(grad_orig, grad_import)
@slowTest
@skipIfNoTorchVision
def test_script_module_script_resnet(self):
    """Compare a scripted ResNet-18-style model with its export/import copy.

    The model is built from locally defined ``ScriptModule`` classes. The
    original and the copy returned by ``getExportImportCopy`` are run on the
    same random input; outputs and input gradients must be equal.

    NOTE(review): nothing in this body references torchvision, although the
    test is decorated with ``skipIfNoTorchVision`` — confirm whether the skip
    is still needed.
    """
    def conv1x1(in_planes, out_planes, stride=1):
        """1x1 convolution"""
        return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)

    def conv3x3(in_planes, out_planes, stride=1):
        """3x3 convolution with padding"""
        return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                         padding=1, bias=False)

    class BasicBlock(torch.jit.ScriptModule):
        """Two 3x3 conv + BatchNorm stages with a residual connection."""
        expansion = 1
        # presumably listed so TorchScript treats `downsample` (which may be
        # None) as a constant — confirm against the TorchScript docs
        __constants__ = ['downsample']

        def __init__(self, inplanes, planes, stride=1, downsample=None):
            super().__init__()
            self.conv1 = conv3x3(inplanes, planes, stride)
            self.bn1 = nn.BatchNorm2d(planes)
            self.relu = nn.ReLU(inplace=True)
            self.conv2 = conv3x3(planes, planes)
            self.bn2 = nn.BatchNorm2d(planes)
            self.downsample = downsample
            self.stride = stride

        @torch.jit.script_method
        def forward(self, x):
            residual = x
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.bn2(out)
            # When a downsample module was supplied, the skip path goes
            # through it instead of using x directly.
            if self.downsample is not None:
                residual = self.downsample(x)
            out += residual
            out = self.relu(out)
            return out

    class ResNet(torch.jit.ScriptModule):
        """Stem conv, four stages built by ``_make_layer``, pooling and a Linear head."""
        __constants__ = ['layer1', 'layer2', 'layer3', 'layer4']

        def __init__(self, block, layers, num_classes=1000):
            super().__init__()
            # Running input channel count; updated by _make_layer.
            self.inplanes = 64
            self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                                   bias=False)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu = nn.ReLU(inplace=True)
            self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
            self.layer1 = self._make_layer(block, 64, layers[0])
            self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
            self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
            self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512 * block.expansion, num_classes)
            # Re-initialise conv weights (Kaiming normal) and set BatchNorm
            # weight/bias to 1/0.
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

        def _make_layer(self, block, planes, blocks, stride=1):
            """Build an nn.Sequential of ``blocks`` instances of ``block``.

            Only the first block receives ``stride`` and, when the stride or
            channel count changes, a 1x1-conv + BatchNorm downsample module.
            """
            downsample = None
            if stride != 1 or self.inplanes != planes * block.expansion:
                downsample = nn.Sequential(
                    conv1x1(self.inplanes, planes * block.expansion, stride),
                    nn.BatchNorm2d(planes * block.expansion),
                )
            layers = []
            layers.append(block(self.inplanes, planes, stride, downsample))
            self.inplanes = planes * block.expansion
            for _ in range(1, blocks):
                layers.append(block(self.inplanes, planes))
            return nn.Sequential(*layers)

        @torch.jit.script_method
        def forward(self, x):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.avgpool(x)
            # Flatten everything but the batch dimension before the Linear head.
            x = x.view(x.size(0), -1)
            x = self.fc(x)
            return x

    resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
    resnet18_imported = self.getExportImportCopy(resnet18)
    input = torch.randn(1, 3, 224, 224, requires_grad=True)
    output_orig = resnet18(input)
    output_orig.sum().backward()
    grad_orig = input.grad.clone()
    # Clear the accumulated gradient before the second backward pass.
    input.grad.zero_()
    output_import = resnet18_imported(input)
    output_import.sum().backward()
    grad_import = input.grad.clone()
    self.assertEqual(output_orig, output_import)
    self.assertEqual(grad_orig, grad_import)
@skipIfNoTorchVision
def test_alexnet(self):
    """Trace torchvision's AlexNet, run the 'cse' pass on the graph, and
    check that a function rebuilt from the graph reproduces the traced outputs.

    Both the trace and the replay run inside ``torch.random.fork_rng``.
    """
    example = torch.ones(1, 3, 224, 224)
    alexnet = torchvision.models.AlexNet()

    with torch.random.fork_rng(devices=[]):
        traced = torch.jit._get_trace_graph(alexnet, example, return_inputs=True)
    graph, traced_outputs, graph_inputs = traced

    self.run_pass('cse', graph)
    replay = self.createFunctionFromGraph(graph)

    with torch.random.fork_rng(devices=[]):
        replayed_outputs = replay(*graph_inputs)
        self.assertEqual(traced_outputs, replayed_outputs)