# Owner(s): ["oncall: quantization"]
import copy
import torch
import torch.nn as nn
import torch.ao.nn.quantized as nnq
import torch.ao.nn.intrinsic as nni
import torch.ao.nn.intrinsic.quantized as nniq
import torch.ao.nn.intrinsic.qat as nniqat
from torch.ao.quantization import (
quantize,
prepare,
convert,
prepare_qat,
quantize_qat,
fuse_modules,
fuse_modules_qat,
QConfig,
default_qconfig,
default_qat_qconfig,
)
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
ModelForFusion,
ModelWithSequentialFusion,
ModelForLinearBNFusion,
ModelForFusionWithBias,
ModelForConvTransposeBNFusion,
SingleLayerLinearModel,
test_only_eval_fn,
test_only_train_fn,
skipIfNoFBGEMM,
)
from torch.testing._internal.common_quantized import (
override_quantized_engine,
supported_qengines,
)
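
# Eager-mode quantization flow exercised throughout these tests (a minimal
# sketch; `data` stands in for the calibration batches used below):
#
#     model = fuse_modules(model.eval(), [["conv", "bn", "relu"]])  # fold BN, pair modules
#     model = prepare(model)              # insert observers
#     test_only_eval_fn(model, data)      # calibrate
#     model = convert(model)              # swap in quantized modules
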
@skipIfNoFBGEMM
class TestFuseEager(QuantizationTestCase):
def test_fuse_module_train(self):
model = ModelForFusion(default_qat_qconfig).train()
        # Test step-by-step fusion
model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
self.assertEqual(type(model.conv1), nni.ConvBnReLU2d,
msg="Fused Conv + BN + Relu first layer")
self.assertEqual(type(model.bn1), torch.nn.Identity,
msg="Fused Conv + BN + Relu (skipped BN)")
self.assertEqual(type(model.relu1), torch.nn.Identity,
msg="Fused Conv + BN + Relu (skipped Relu)")
self.assertEqual(type(model.sub1.conv), nni.ConvBn2d,
msg="Fused submodule Conv + BN")
self.assertEqual(type(model.sub1.bn), torch.nn.Identity,
msg="Fused submodule Conv + BN (skipped BN)")
self.assertEqual(type(model.sub2.conv), torch.nn.Conv2d,
msg="Non-fused submodule Conv")
self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
msg="Non-fused submodule ReLU")
model = prepare_qat(model)
self.checkObservers(model)
def checkQAT(model):
self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
self.assertEqual(type(model.bn1), nn.Identity)
self.assertEqual(type(model.relu1), nn.Identity)
self.assertEqual(type(model.sub1.conv), nniqat.ConvBn2d)
self.assertEqual(type(model.sub1.bn), nn.Identity)
self.assertEqual(type(model.sub2.conv), nn.Conv2d)
self.assertEqual(type(model.sub2.relu), nn.ReLU)
checkQAT(model)
test_only_train_fn(model, self.img_data_1d_train)
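        # convert replaces the QAT modules with their quantized
        # implementations (nniq.* / nnq.*).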
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
self.assertEqual(type(model.bn1), nn.Identity)
self.assertEqual(type(model.relu1), nn.Identity)
self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
self.assertEqual(type(model.sub1.bn), nn.Identity)
self.assertEqual(type(model.sub2.conv), nn.Conv2d)
self.assertEqual(type(model.sub2.relu), nn.ReLU)
test_only_eval_fn(model, self.img_data_1d)
self.checkNoQconfig(model)
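        # A BatchNorm that was deliberately left unfused ends up running on
        # quantized tensors after convert, and aten::native_batch_norm has no
        # QuantizedCPU kernel, hence the expected RuntimeError.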
with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
checkQuantized(model)
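        # Repeat with the fusions specified in a single call, then use the
        # one-shot quantize_qat API; the same unfused-BatchNorm failure is
        # expected.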
model = ModelForFusion(default_qat_qconfig).train()
model = fuse_modules_qat(
model,
[['conv1', 'bn1', 'relu1'],
['sub1.conv', 'sub1.bn']])
model = quantize_qat(model, test_only_train_fn, [self.img_data_1d_train])
with self.assertRaisesRegex(RuntimeError, "Could not run 'aten::native_batch_norm' with arguments from the 'QuantizedCPU'"):
checkQuantized(model)

    def test_fuse_module_eval(self):
model = ModelForFusion(default_qconfig)
model.eval()
model = fuse_modules(
model,
[['conv3', 'bn3', 'relu4'],
['conv1', 'bn1', 'relu1'],
['conv2', 'relu2'],
['bn2', 'relu3'],
['sub1.conv', 'sub1.bn']])
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
msg="Fused Conv + BN + Relu first layer (BN is folded)")
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
msg="Fused Conv + BN + Relu (Conv + folded BN only)")
self.assertEqual(type(model.conv1[1]), nn.ReLU,
msg="Fused Conv + BN + Relu second layer (Relu only)")
        self.assertEqual(type(model.bn1), nn.Identity,
                         msg="Fused Conv + BN + Relu (skipped BN)")
        self.assertEqual(type(model.relu1), nn.Identity,
                         msg="Fused Conv + BN + Relu (skipped Relu)")
        self.assertEqual(type(model.conv2), nni.ConvReLU3d,
                         msg="Fused Conv + Relu for Conv3d")
        self.assertEqual(type(model.bn2), nni.BNReLU3d,
                         msg="Fused BN + Relu (Relu is folded)")
        self.assertEqual(type(model.relu3), nn.Identity,
                         msg="Fused BN + Relu (skipped Relu)")
        self.assertEqual(type(model.conv2[0]), nn.Conv3d,
                         msg="Fused Conv + Relu for Conv3d (Conv only)")
        self.assertEqual(type(model.conv2[1]), nn.ReLU,
                         msg="Fused Conv + Relu for Conv3d (Relu only)")
        self.assertEqual(type(model.relu2), nn.Identity,
                         msg="Fused Conv + Relu for Conv3d (skipped Relu)")
self.assertEqual(type(model.conv3), nni.ConvReLU1d,
msg="Fused Conv + Relu for Conv1d (folded BN)")
        self.assertEqual(type(model.conv3[0]), nn.Conv1d,
                         msg="Fused Conv + Relu for Conv1d (Conv only)")
        self.assertEqual(type(model.conv3[1]), nn.ReLU,
                         msg="Fused Conv + Relu for Conv1d (Relu only)")
self.assertEqual(type(model.bn3), nn.Identity,
msg="Fused Conv + BN + Relu for Conv1d (Skipped BN)")
self.assertEqual(type(model.sub1.conv), nn.Conv2d,
msg="Fused submodule Conv + folded BN")
self.assertEqual(type(model.sub1.bn), nn.Identity,
msg="Fused submodule (skipped BN)")
self.assertEqual(type(model.sub2.conv), nn.Conv2d,
msg="Non-fused submodule Conv")
self.assertEqual(type(model.sub2.relu), torch.nn.ReLU,
msg="Non-fused submodule ReLU")
model = prepare(model)
self.checkObservers(model)
test_only_eval_fn(model, self.img_data_1d)
model = convert(model)
def checkQuantized(model):
self.assertEqual(type(model.conv3), nniq.ConvReLU1d)
self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
self.assertEqual(type(model.bn1), nn.Identity)
self.assertEqual(type(model.relu1), nn.Identity)
self.assertEqual(type(model.sub1.conv), nnq.Conv2d)
self.assertEqual(type(model.sub1.bn), nn.Identity)
self.assertEqual(type(model.sub2.conv), nn.Conv2d)
self.assertEqual(type(model.sub2.relu), nn.ReLU)
self.assertEqual(type(model.bn2), nniq.BNReLU3d)
test_only_eval_fn(model, self.img_data_1d)
self.checkNoQconfig(model)
checkQuantized(model)
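        # Repeat with the one-shot quantize API, which runs prepare,
        # calibration, and convert internally.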
model = ModelForFusion(default_qconfig).eval()
model = fuse_modules(
model,
[['conv1', 'bn1', 'relu1'],
['conv2', 'relu2'],
['bn2', 'relu3'],
['sub1.conv', 'sub1.bn'],
['conv3', 'bn3', 'relu4']])
model = quantize(model, test_only_eval_fn, [self.img_data_1d])
checkQuantized(model)

    def test_fusion_sequential_model_train(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ModelWithSequentialFusion().train()
model.to(torch.float)
fuse_modules_qat(
                    model, [['conv1', 'relu1'],
['features.0.0', 'features.0.1', 'features.0.2'],
['features.1.0', 'features.1.1', 'features.1.2'],
['features.2.0', 'features.2.1', 'features.2.2'],
['classifier.0', 'classifier.1']],
inplace=True)
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
msg="Fused Conv + Relu: nni.ConvReLU2d")
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
msg="Fused Conv + Relu: Conv2d")
self.assertEqual(type(model.conv1[1]), nn.ReLU,
msg="Fused Conv + Relu: Relu")
self.assertEqual(type(model.relu1), nn.Identity,
msg="Fused Conv + Relu: Identity")
for i in range(3):
                    self.assertEqual(type(model.features[i][0]), nni.ConvBnReLU2d,
                                     msg="Fused submodule Conv + BN + ReLU")
                    self.assertEqual(type(model.features[i][1]), nn.Identity,
                                     msg="Fused submodule (skipped BN)")
                    self.assertEqual(type(model.features[i][2]), nn.Identity,
                                     msg="Fused submodule (skipped ReLU)")
self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity)
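                # Attach the backend-specific QAT qconfig, prepare, and run a
                # forward pass so the fake-quant observers see real data.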
model.qconfig = torch.ao.quantization.get_default_qat_qconfig(qengine)
prepare_qat(model, inplace=True)
self.checkObservers(model)
model(self.img_data_2d[0][0])
def checkQAT(model):
self.assertEqual(type(model.conv1), nniqat.ConvReLU2d)
self.assertEqual(type(model.relu1), nn.Identity)
for i in range(3):
                        self.assertEqual(type(model.features[i][0]), nniqat.ConvBnReLU2d,
                                         msg="Fused submodule Conv + BN + ReLU")
                        self.assertEqual(type(model.features[i][1]), nn.Identity,
                                         msg="Fused submodule (skipped BN)")
                        self.assertEqual(type(model.features[i][2]), nn.Identity,
                                         msg="Fused submodule (skipped ReLU)")
self.assertEqual(type(model.classifier[0]), nniqat.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity)
checkQAT(model)
model(self.img_data_2d[1][0])
convert(model, inplace=True)
model(self.img_data_2d[1][0])
self.checkModelWithSequentialQuantized(model)

    def test_fusion_sequential_model_eval(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ModelWithSequentialFusion().eval()
model.to(torch.float)
fuse_modules(
model,
[['conv1', 'relu1'],
['features.0.0', 'features.0.1', 'features.0.2'],
['features.1.0', 'features.1.1', 'features.1.2'],
['features.2.0', 'features.2.1', 'features.2.2'],
['classifier.0', 'classifier.1']],
inplace=True)
self.assertEqual(type(model.conv1), nni.ConvReLU2d,
msg="Fused Conv + Relu: nni.ConvReLU2d")
self.assertEqual(type(model.conv1[0]), nn.Conv2d,
msg="Fused Conv + Relu: Conv2d")
self.assertEqual(type(model.conv1[1]), nn.ReLU,
msg="Fused Conv + Relu: Relu")
self.assertEqual(type(model.relu1), nn.Identity,
msg="Fused Conv + Relu: Identity")
for i in range(3):
                    self.assertEqual(type(model.features[i][0]), nni.ConvReLU2d,
                                     msg="Fused submodule Conv + ReLU (BN folded)")
                    self.assertEqual(type(model.features[i][1]), nn.Identity,
                                     msg="Fused submodule (skipped BN)")
                    self.assertEqual(type(model.features[i][2]), nn.Identity,
                                     msg="Fused submodule (skipped ReLU)")
self.assertEqual(type(model.classifier[0]), nni.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity)
model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
prepare(model, inplace=True)
self.checkObservers(model)
model(self.img_data_2d[0][0])
convert(model, inplace=True)
model(self.img_data_2d[1][0])
self.checkModelWithSequentialQuantized(model)

    def checkModelWithSequentialQuantized(self, model):
self.assertEqual(type(model.conv1), nniq.ConvReLU2d)
self.assertEqual(type(model.relu1), nn.Identity)
for i in range(3):
self.assertEqual(type(model.features[i][0]), nniq.ConvReLU2d)
self.assertEqual(type(model.features[i][1]), nn.Identity)
self.assertEqual(type(model.features[i][2]), nn.Identity)
self.assertEqual(type(model.classifier[0]), nniq.LinearReLU)
self.assertEqual(type(model.classifier[1]), nn.Identity)

    def test_fusion_conv_with_bias(self):
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model_orig = ModelForFusionWithBias().train()
# reference model
model_ref = copy.deepcopy(model_orig)
# output with no fusion.
out_ref = model_ref(self.img_data_2d[0][0])
                # Fused model: with Identity as both "observer" classes, the
                # qconfig makes prepare_qat numerically a no-op, so the fused
                # output can be compared exactly against the reference.
model_orig.qconfig = QConfig(activation=torch.nn.Identity,
weight=torch.nn.Identity)
model = fuse_modules_qat(
model_orig,
[["conv1", "bn1", "relu1"],
["conv2", "bn2"]])
prep_model = prepare_qat(model, inplace=False)
# output with fusion but no observers.
out_fused = prep_model(self.img_data_2d[0][0])
self.assertEqual(out_ref, out_fused)
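                # QAT fusion keeps BatchNorm as a live submodule of the fused
                # module; its parameters and running stats should be carried
                # over unchanged from the original BN.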
def checkBN(bn_ref, bn):
self.assertEqual(bn_ref.weight, bn.weight)
self.assertEqual(bn_ref.bias, bn.bias)
self.assertEqual(bn_ref.running_mean, bn.running_mean)
self.assertEqual(bn_ref.running_var, bn.running_var)
checkBN(model_ref.bn1, prep_model.conv1.bn)
checkBN(model_ref.bn2, prep_model.conv2.bn)
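                # Now run the real QAT flow with a backend-specific qconfig.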
model.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
prepare_qat(model, inplace=True)
model(self.img_data_2d[0][0])
def checkQAT(model):
self.assertEqual(type(model.conv1), nniqat.ConvBnReLU2d)
self.assertEqual(type(model.bn1), nn.Identity)
self.assertEqual(type(model.relu1), nn.Identity)
self.assertEqual(type(model.conv2), nniqat.ConvBn2d)
self.assertEqual(type(model.bn2), nn.Identity)
checkQAT(model)

    def test_fusion_linear_bn_eval(self):
model = ModelForLinearBNFusion().train()
inp1 = torch.randn(8, 20)
inp2 = torch.randn(8, 20)
# Get some interesting values into the running mean and variance.
model(inp1)
model.eval()
golden = model(inp2)
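        # Eval-mode fusion folds the BN affine transform and running stats
        # into the Linear's weight and bias, so the output must match the
        # pre-fusion golden output.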
model = fuse_modules(model, [["fc", "bn"]])
self.assertEqual(type(model.bn), nn.Identity)
self.assertEqual(golden, model(inp2))

    def test_fusion_convtranspose_bn_eval(self):
model = ModelForConvTransposeBNFusion().train()
inp1 = torch.randn(8, 3, 16)
inp2 = torch.randn(8, 3, 16)
# Get some interesting values into the running mean and variance.
model(inp1)
model.eval()
golden = model(inp2)
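        # The same must hold for ConvTranspose + BN fusion, which has to
        # account for the transposed weight layout (in_channels first).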
model = fuse_modules(model, [["conv1", "bn1"], ["conv2", "bn2"], ["conv3", "bn3"]])
self.assertEqual(type(model.bn1), nn.Identity)
self.assertEqual(type(model.bn2), nn.Identity)
self.assertEqual(type(model.bn3), nn.Identity)
self.assertEqual(golden, model(inp2))

    def test_fuse_function_customization(self):
dummy_model = SingleLayerLinearModel().train()
dummy_model.eval()
        # A custom fuser function that replaces each matched module list with
        # a single Identity.
def custom_fuse_func(module, is_qat, add_fuser_mapping):
return [torch.nn.Identity()]
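        # fuse_modules passes each matched module list to fuser_func and
        # splices whatever it returns back into the model.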
dummy_model = fuse_modules(dummy_model, [["fc1"]], fuser_func=custom_fuse_func)
self.assertEqual(type(dummy_model.fc1), nn.Identity)

    def test_forward_hooks_preserved(self):
        r"""Test that forward pre-hooks on the first module and forward hooks on
        the last module of a fused modules list are preserved after fusion.
        (e.g. before fusion: [nn.Conv2d (with forward pre-hooks), nn.BatchNorm2d, nn.ReLU (with forward hooks)]
        after fusion: [nni.ConvBnReLU2d (with pre- and post-hooks), nn.Identity, nn.Identity])
        """
model = ModelForFusion(default_qat_qconfig).train()
counter = {
'pre_forwards': 0,
'forwards': 0,
}
fused = False
def fw_pre_hook(fused_module_class, h_module, input):
if fused:
self.assertEqual(type(h_module), fused_module_class,
"After fusion owner of the first module's forward pre hook is not a fused module")
counter['pre_forwards'] += 1
def fw_hook(fused_module_class, h_module, input, output):
if fused:
self.assertEqual(type(h_module), fused_module_class,
"After fusion owner of the last module's forward hook is not a fused module")
counter['forwards'] += 1
        # Two forward pre-hooks and two forward hooks are registered, so each
        # pass over the dataset should bump each counter by two per input.
model.conv1.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBnReLU2d, *args))
model.sub1.conv.register_forward_pre_hook(lambda *args: fw_pre_hook(nni.ConvBn2d, *args))
model.relu1.register_forward_hook(lambda *args: fw_hook(nni.ConvBnReLU2d, *args))
model.sub1.bn.register_forward_hook(lambda *args: fw_hook(nni.ConvBn2d, *args))
test_only_eval_fn(model, self.img_data_1d)
self.assertEqual(counter['pre_forwards'], 2 * len(self.img_data_1d))
self.assertEqual(counter['forwards'], 2 * len(self.img_data_1d))
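        # Fusion should migrate the hooks onto the fused modules, so the
        # counters must keep advancing at the same rate after fusion.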
model = fuse_modules_qat(model, ['conv1', 'bn1', 'relu1'])
model = fuse_modules_qat(model, ['sub1.conv', 'sub1.bn'])
fused = True
before_fusion_pre_count = counter['pre_forwards']
before_fusion_post_count = counter['forwards']
test_only_eval_fn(model, self.img_data_1d)
self.assertEqual(counter['pre_forwards'] - before_fusion_pre_count, 2 * len(self.img_data_1d))
self.assertEqual(counter['forwards'] - before_fusion_post_count, 2 * len(self.img_data_1d))

    def test_fuse_modules_with_nested_hooks(self):
        r"""Test that a nested module whose sub-submodules have forward hooks
        registered can be fused safely. Guards against regressions of issues
        like https://github.com/pytorch/pytorch/issues/105063.
        """
def myhook(*x):
return ""
for qengine in supported_qengines:
with override_quantized_engine(qengine):
model = ModelWithSequentialFusion().eval()
for sub_model in model.modules():
if isinstance(sub_model, nn.Sequential):
for layer in sub_model:
if hasattr(layer, 'register_forward_hook'):
layer.register_forward_hook(myhook)
fuse_modules(model, [['features.0.0', 'features.0.1', 'features.0.2']], inplace=True)
                self.assertEqual(
                    type(model.features[0][0]),
                    nni.ConvReLU2d,
                    msg="Fused submodule Conv + ReLU (BN folded)"
                )
                self.assertEqual(
                    type(model.features[0][1]),
                    nn.Identity,
                    msg="Fused submodule (skipped BN)"
                )
                self.assertEqual(
                    type(model.features[0][2]),
                    nn.Identity,
                    msg="Fused submodule (skipped ReLU)"
                )


if __name__ == '__main__':
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_quantization.py TESTNAME\n\n"
"instead."
)