# -*- coding: utf-8 -*-
# Owner(s): ["module: unknown"]
import logging
import torch
import torch.ao.quantization as tq
from torch import nn
from torch.ao import sparsity
from torch.testing._internal.common_utils import TestCase
logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)
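# Shared sparsifier settings: target ~80% sparsity computed over (1, 4) weight blocks,
# with every element of a selected block zeroed out (zeros_per_block == 4).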
sparse_defaults = {
"sparsity_level": 0.8,
"sparse_block_shape": (1, 4),
"zeros_per_block": 4,
}
# This series of tests checks the composability goals for sparsity and quantization,
# namely that performing quantization and sparsity model manipulations in various
# orderings does not cause problems.
class TestComposability(TestCase):
def _get_model_and_sparsifier_and_sparse_config(self, qconfig=None):
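# Builds a small Linear/ReLU stack in which only the block between the QuantStub and
# DeQuantStub (i.e. model[5]) is quantized, along with a WeightNormSparsifier and a
# sparse config targeting model[0] (default settings) and model[5] (per-module overrides).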
model = nn.Sequential(
nn.Linear(4, 4), # 0
nn.ReLU(),
nn.Linear(4, 4), # 2
nn.ReLU(),
tq.QuantStub(),
nn.Linear(4, 4), # 5
nn.ReLU(),
tq.DeQuantStub(),
)
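# both the QuantStub (model[4]) and the Linear to be quantized (model[5]) need a qconfig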
if qconfig is None:
model[4].qconfig = tq.get_default_qconfig("fbgemm")
model[5].qconfig = tq.get_default_qconfig("fbgemm")
else:
model[4].qconfig = qconfig
model[5].qconfig = qconfig
sparsifier = sparsity.WeightNormSparsifier(**sparse_defaults)
sparse_config = [
{
"module": model[5],
"sparsity_level": 0.7,
"sparse_block_shape": (1, 4),
"zeros_per_block": 4,
},
model[0],
]
return model, sparsifier, sparse_config
def _squash_mask_calibrate_and_convert(self, model, sparsifier, input):
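# Applies one sparsifier step, makes the sparsified weights permanent (squash_mask),
# runs a calibration forward pass, and converts the model to its quantized form in place.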
sparsifier.step()
sparsifier.squash_mask()
model(input)
tq.convert(model, inplace=True)
def _calculate_sparsity(self, tensor):
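# fraction of elements that are exactly zero, e.g. a 4x4 tensor with 12 zeros -> 0.75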
return ((tensor == 0).sum() / tensor.numel()).item()
# This test checks whether performing quantization prepare before sparse prepare
# causes any issues and verifies that the correct observers are inserted and that
# the quantized model works as expected
def test_q_prep_before_s_prep(self):
(
mod,
sparsifier,
sparse_config,
) = self._get_model_and_sparsifier_and_sparse_config()
tq.prepare(mod, inplace=True)
sparsifier.prepare(mod, config=sparse_config)
# check that correct modules had parametrizations added
self.assertTrue(hasattr(mod[0], "parametrizations"))
self.assertTrue(hasattr(mod[5], "parametrizations"))
# check that correct observers were inserted
self.assertTrue(hasattr(mod[5], "activation_post_process"))
self._squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))
# This test checks whether performing sparsity prepare before quantization prepare
# causes any issues. In particular, the previous quantization flow was unable to match
# the post-sparse-prepare module names (adding parametrizations changes the module class
# names), which would result in those parametrized modules not being quantized. This test
# verifies that the fix for this was successful.
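# (for example, after adding a parametrization an nn.Linear typically reports its class
# name as something like "ParametrizedLinear", which is what used to break the matching)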
def test_s_prep_before_q_prep(self):
(
mod,
sparsifier,
sparse_config,
) = self._get_model_and_sparsifier_and_sparse_config()
sparsifier.prepare(mod, config=sparse_config)
tq.prepare(mod, inplace=True)
# check that correct modules had parametrizations added and
# that none were lost during prepare
self.assertTrue(hasattr(mod[0], "parametrizations"))
self.assertTrue(hasattr(mod[5], "parametrizations"))
# check that correct observers were inserted and that matching
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
self._squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))
# If the sparsified modules have not undergone the final squash mask operation, it's possible
# that the problem outlined in test_s_prep_before_q_prep would occur. This test verifies
# both that the fix to the convert flow avoids this issue and that the resulting quantized
# module uses the sparse version of the weight value.
def test_convert_without_squash_mask(self):
(
mod,
sparsifier,
sparse_config,
) = self._get_model_and_sparsifier_and_sparse_config()
sparsifier.prepare(mod, config=sparse_config)
tq.prepare(mod, inplace=True)
# check that correct modules had parametrizations added and
# that none were lost during prepare
self.assertTrue(hasattr(mod[0], "parametrizations"))
self.assertTrue(hasattr(mod[5], "parametrizations"))
# check that correct observers were inserted and that matching
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
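# run one sparsifier step and record the sparsity of the (still parametrized) weight so it
# can be compared against the converted module's weight below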
sparsifier.step()
sparsity_level = self._calculate_sparsity(mod[5].weight)
mod(torch.randn(1, 4, 4, 4))
tq.convert(mod, inplace=True)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))
# check that module was actually sparsified
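# _weight_bias() unpacks the quantized Linear's packed parameters; index 0 is the weight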
cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0])
self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
self.assertGreaterAlmostEqual(
sparsity_level, sparse_config[0]["sparsity_level"]
)
self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
# This tests whether performing sparse prepare before fusion causes any issues. The
# worry was that the link created between the sparsifier and the modules that need to
# be sparsified would be broken.
def test_s_prep_before_fusion(self):
(
mod,
sparsifier,
sparse_config,
) = self._get_model_and_sparsifier_and_sparse_config()
sparsifier.prepare(mod, config=sparse_config)
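# fuse the Linear at index "5" with the ReLU at index "6" into a single LinearReLU module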
tq.fuse_modules(mod, [["5", "6"]], inplace=True)
mod[5].qconfig = tq.get_default_qconfig("fbgemm")
tq.prepare(mod, inplace=True)
# check that correct modules had parametrizations added and
# that none were lost during prepare or fusion
self.assertTrue(hasattr(mod[0], "parametrizations"))
self.assertTrue(hasattr(mod[5][0], "parametrizations"))
# check that correct observers were inserted and that matching
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
self._squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.nn.intrinsic.quantized.LinearReLU))
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))
# This tests whether performing fusion before sparse prepare causes any issues. The
# main worry was that the links to the modules in the sparse config would be broken by fusion.
def test_fusion_before_s_prep(self):
(
mod,
sparsifier,
sparse_config,
) = self._get_model_and_sparsifier_and_sparse_config()
tq.fuse_modules(mod, [["5", "6"]], inplace=True)
sparsifier.prepare(mod, config=sparse_config)
mod[5].qconfig = tq.get_default_qconfig("fbgemm")
tq.prepare(mod, inplace=True)
# check that correct modules had parametrizations added and
# that none were lost during prepare
self.assertTrue(hasattr(mod[0], "parametrizations"))
self.assertTrue(hasattr(mod[5][0], "parametrizations"))
# check that correct observers were inserted and that matching
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
sparsifier.step()
sparsity_level = self._calculate_sparsity(mod[5][0].weight)
mod(torch.randn(1, 4, 4, 4))
tq.convert(mod, inplace=True)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.nn.intrinsic.quantized.LinearReLU))
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))
# check that module was actually sparsified
cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0])
self.assertGreaterAlmostEqual(cur_sparsity, sparsity_level)
self.assertGreaterAlmostEqual(
sparsity_level, sparse_config[0]["sparsity_level"]
)
self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
# This tests whether performing sparse prepare before qat prepare causes issues.
# The primary worries were that qat_prep wouldn't recognize the parametrized
# modules and that the convert step for qat would remove the parametrizations
# from the modules.
def test_s_prep_before_qat_prep(self):
(
mod,
sparsifier,
sparse_config,
) = self._get_model_and_sparsifier_and_sparse_config(
tq.get_default_qat_qconfig("fbgemm")
)
sparsifier.prepare(mod, config=sparse_config)
tq.prepare_qat(mod, inplace=True)
self.assertTrue(hasattr(mod[0], "parametrizations"))
self.assertTrue(hasattr(mod[5], "parametrizations"))
# check that correct observers were inserted and that matching
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
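# prepare_qat should have swapped the parametrized Linear for its qat counterpart
# without dropping the parametrization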
self.assertTrue(isinstance(mod[5], torch.nn.qat.Linear))
self._squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))
# check that module was actually sparsified
cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0])
self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])
# This tests whether performing qat prepare before sparse prepare causes issues.
def test_qat_prep_before_s_prep(self):
mod, sparsifier, _ = self._get_model_and_sparsifier_and_sparse_config(
tq.get_default_qat_qconfig("fbgemm")
)
tq.prepare_qat(mod, inplace=True)
# need to set up sparse_config on the new modules created by prepare_qat
sparse_config = [
{
"module": mod[5],
"sparsity_level": 0.7,
"sparse_block_shape": (1, 4),
"zeros_per_block": 4,
},
mod[0],
]
sparsifier.prepare(mod, config=sparse_config)
# check that correct modules had parametrizations added and
# that none were lost during qat prepare
self.assertTrue(hasattr(mod[0], "parametrizations"))
self.assertTrue(hasattr(mod[5], "parametrizations"))
# check that correct observers were inserted and that matching
# occurred successfully
self.assertTrue(hasattr(mod[5], "activation_post_process"))
self.assertTrue(isinstance(mod[5], torch.nn.qat.Linear))
self._squash_mask_calibrate_and_convert(
mod, sparsifier, torch.randn(1, 4, 4, 4)
)
# check that final module is the expected quantized module and that the model runs
self.assertTrue(isinstance(mod[5], torch.nn.quantized.Linear))
self.assertEqual(mod(torch.randn(1, 4, 4, 4)).shape, torch.Size([1, 4, 4, 4]))
# check that module was actually sparsified
cur_sparsity = self._calculate_sparsity(mod[5]._weight_bias()[0])
self.assertGreaterAlmostEqual(cur_sparsity, sparse_config[0]["sparsity_level"])