# -*- coding: utf-8 -*-
# Owner(s): ["module: unknown"]
import copy
import logging
import torch
from torch import nn
from torch.ao.sparsity import BasePruner, PruningParametrization, ZeroesParametrization
from torch.nn.utils import parametrize
from torch.testing._internal.common_utils import TestCase
logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
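# Run every test on CPU and, when available, on CUDA; DEVICES is a set, so on
# CPU-only machines it collapses to just {torch.device('cpu')}.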
DEVICES = {
torch.device("cpu"),
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
}
NEEDS_ZEROS = {
    # These layers should have pruned indices zeroed out rather than removed:
    # BatchNorm2d keeps per-channel statistics, so dropping channels outright
    # would desynchronize it from the conv layer it follows.
    nn.BatchNorm2d
}
class Linear(nn.Module):
r"""Model with Linear layers, in Sequential and outside, without biases"""
def __init__(self):
super().__init__()
self.seq = nn.Sequential(
nn.Linear(16, 16, bias=False)
)
self.linear = nn.Linear(16, 16, bias=False)
def forward(self, x):
x = self.seq(x)
x = self.linear(x)
return x
class LinearB(nn.Module):
r"""Model with Linear layers, in Sequential and outside, with biases"""
def __init__(self):
super().__init__()
self.seq = nn.Sequential(
nn.Linear(16, 16, bias=True)
)
self.linear = nn.Linear(16, 16, bias=True)
def forward(self, x):
x = self.seq(x)
x = self.linear(x)
return x
class MultipleLinear(nn.Module):
r"""Model with multiple Linear layers, in Sequential and outside, without biases
and with activation functions"""
def __init__(self):
super().__init__()
self.seq = nn.Sequential(
nn.Linear(7, 5, bias=False),
nn.ReLU(),
nn.Linear(5, 8, bias=False),
nn.ReLU(),
nn.Linear(8, 6, bias=False)
)
self.linear = nn.Linear(6, 4, bias=False)
def forward(self, x):
x = self.seq(x)
x = self.linear(x)
return x
class MultipleLinearB(nn.Module):
r"""Model with multiple Linear layers, in Sequential and outside, with biases
and with activation functions"""
def __init__(self):
super().__init__()
self.seq = nn.Sequential(
nn.Linear(7, 5, bias=True),
nn.ReLU(),
nn.Linear(5, 8, bias=True),
nn.ReLU(),
nn.Linear(8, 6, bias=True)
)
self.linear = nn.Linear(6, 4, bias=True)
def forward(self, x):
x = self.seq(x)
x = self.linear(x)
return x
class MultipleLinearMixed(nn.Module):
r"""Model with multiple Linear layers, in Sequential and outside, some with biases
and with activation functions"""
def __init__(self):
super().__init__()
self.seq = nn.Sequential(
nn.Linear(7, 5, bias=True),
nn.ReLU(),
nn.Linear(5, 8, bias=False),
nn.ReLU(),
nn.Linear(8, 6, bias=True)
)
self.linear = nn.Linear(6, 4, bias=False)
def forward(self, x):
x = self.seq(x)
x = self.linear(x)
return x
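# The Conv2d models below take (N, 1, 28, 28) inputs; two valid 3x3 convolutions
# (stride 1, no padding) shrink 28 -> 26 -> 24, so the tests assert an output
# shape of (N, 64, 24, 24).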
class Conv2dA(nn.Module):
r"""Model with Conv2d layers, in Sequential and outside, without biases"""
def __init__(self):
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, bias=False),
)
self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=False)
def forward(self, x):
x = self.seq(x)
x = self.conv2d(x)
return x
class Conv2dB(nn.Module):
r"""Model with Conv2d layers, in Sequential and outside, with biases"""
def __init__(self):
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, bias=True),
)
self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=True)
def forward(self, x):
x = self.seq(x)
x = self.conv2d(x)
return x
class Conv2dC(nn.Module):
r"""Model with Conv2d layers, in Sequential and outside, with and without biases"""
def __init__(self):
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, bias=True),
)
self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=False)
def forward(self, x):
x = self.seq(x)
x = self.conv2d(x)
return x
class Conv2dBN(nn.Module):
r"""Model with Conv2d layers and BatchNorms"""
def __init__(self):
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, bias=True),
nn.BatchNorm2d(32)
)
self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=True)
self.bn = nn.BatchNorm2d(64)
def forward(self, x):
x = self.seq(x)
x = self.conv2d(x)
x = self.bn(x)
return x
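# Concrete pruners for the tests: update_mask() is the abstract hook of
# BasePruner (see the TypeError check in test_constructor), and these
# implementations simply mark fixed output indices as pruned.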
class SimplePruner(BasePruner):
def update_mask(self, layer, **kwargs):
layer.parametrizations.weight[0].pruned_outputs.add(1)
class MultiplePruner(BasePruner):
def update_mask(self, layer, **kwargs):
layer.parametrizations.weight[0].pruned_outputs.update([1, 2])
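# The tests below exercise the full pruning flow:
#   pruner = SimplePruner(None)    # None, or a dict of defaults per module group
#   pruner.prepare(model, None)    # None, or an explicit list of modules/tuples
#   pruner.step()                  # calls update_mask() on every module group
#   pruner.squash_mask()           # folds masks in and removes parametrizations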
class TestBasePruner(TestCase):
def _check_pruner_prepared(self, model, pruner, device):
for config in pruner.module_groups:
            # config['module'] is either a single module or a tuple of modules
            # (e.g. a conv grouped with its BatchNorm) that are pruned together.
            if isinstance(config['module'], tuple):
                modules = list(config['module'])
            else:
                modules = [config['module']]
for module in modules:
assert module.weight.device.type == device.type
# Check mask exists
assert hasattr(module, 'mask')
# Check parametrization exists and is correct
assert parametrize.is_parametrized(module)
assert hasattr(module, "parametrizations")
# Assume that this is the 1st/only parametrization
if isinstance(module, tuple(NEEDS_ZEROS)):
assert type(module.parametrizations.weight[0]) == ZeroesParametrization
else:
assert type(module.parametrizations.weight[0]) == PruningParametrization
def _check_pruner_mask_squashed(self, model, pruner, device):
for config in pruner.module_groups:
            if isinstance(config['module'], tuple):
                modules = list(config['module'])
            else:
                modules = [config['module']]
for module in modules:
assert module.weight.device.type == device.type
assert not hasattr(module, "parametrizations")
assert not hasattr(module, 'mask')
def _check_pruner_valid_before_step(self, model, pruner, device):
for config in pruner.module_groups:
            if isinstance(config['module'], tuple):
                modules = list(config['module'])
            else:
                modules = [config['module']]
for module in modules:
assert module.weight.device.type == device.type
assert module.parametrizations.weight[0].pruned_outputs == set()
def _check_pruner_valid_after_step(self, model, pruner, pruned_set, device):
for config in pruner.module_groups:
            if isinstance(config['module'], tuple):
                modules = list(config['module'])
            else:
                modules = [config['module']]
for module in modules:
assert module.weight.device.type == device.type
assert module.parametrizations.weight[0].pruned_outputs == pruned_set
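    # A module_groups entry, as the assertions above and test_constructor below
    # suggest (a sketch, not the full schema):
    #   {'module': <nn.Module or tuple of modules>, 'fqn': 'linear', **user_defaults}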
def _test_constructor_on_device(self, model, device):
self.assertRaisesRegex(TypeError, 'BasePruner .* update_mask',
BasePruner)
model1 = copy.deepcopy(model).to(device)
pruner = SimplePruner(None)
pruner.prepare(model1, None)
for g in pruner.module_groups:
module = g['module']
assert module.weight.device.type == device.type
assert len(pruner.module_groups) == 2
pruner.step()
        # Can instantiate the pruner with a defaults dict; its entries show up in each module group
model2 = copy.deepcopy(model).to(device)
pruner = SimplePruner({'test': 3})
pruner.prepare(model2, [model2.linear])
assert len(pruner.module_groups) == 1
assert pruner.module_groups[0]['fqn'] == 'linear'
assert 'test' in pruner.module_groups[0]
assert pruner.module_groups[0]['test'] == 3
def test_constructor(self):
model = Linear()
for device in DEVICES:
            self._test_constructor_on_device(model, device)
def _test_prepare_linear_on_device(self, model, device):
model = copy.deepcopy(model).to(device)
x = torch.ones(128, 16, device=device)
pruner = SimplePruner(None)
pruner.prepare(model, None)
self._check_pruner_prepared(model, pruner, device)
assert model(x).shape == (128, 16)
def test_prepare_linear(self):
models = [Linear(), LinearB()] # without and with bias
for device in DEVICES:
for model in models:
                self._test_prepare_linear_on_device(model, device)
def _test_prepare_conv2d_on_device(self, model, config, device):
x = torch.ones((1, 1, 28, 28), device=device)
pruner = SimplePruner(None)
pruner.prepare(model, config)
self._check_pruner_prepared(model, pruner, device)
assert model(x).shape == (1, 64, 24, 24)
    def test_prepare_conv2d(self):
        for device in DEVICES:
            # Build fresh models per device so prepare() never runs twice on
            # the same modules when both CPU and CUDA are tested.
            bn_model = Conv2dBN()
            bn_config = [(bn_model.seq[0], bn_model.seq[1]), (bn_model.conv2d, bn_model.bn)]
            models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model]
            configs = [None, None, None, bn_config]
            for model, config in zip(models, configs):
                model = model.to(device)
                self._test_prepare_conv2d_on_device(model, config, device)
def _test_squash_mask_linear_on_device(self, model, device):
model = copy.deepcopy(model).to(device)
x = torch.ones(128, 16, device=device)
pruner = SimplePruner(None)
pruner.prepare(model, None)
pruner.squash_mask()
self._check_pruner_mask_squashed(model, pruner, device)
assert model(x).shape == (128, 16)
def test_squash_mask_linear(self):
models = [Linear(), LinearB()] # without and with bias
for device in DEVICES:
for model in models:
                self._test_squash_mask_linear_on_device(model, device)
    def _test_squash_mask_conv2d_on_device(self, model, config, device):
        # No deepcopy here: config (when given) holds references to this exact
        # model's submodules, and a copy would leave them pointing at the original.
        model = model.to(device)
        x = torch.ones((1, 1, 28, 28), device=device)
        pruner = SimplePruner(None)
        pruner.prepare(model, config)
        pruner.squash_mask()
        self._check_pruner_mask_squashed(model, pruner, device)
        assert model(x).shape == (1, 64, 24, 24)
    def test_squash_mask_conv2d(self):
        for device in DEVICES:
            # Fresh models per device, for the same reason as in test_prepare_conv2d.
            bn_model = Conv2dBN()
            bn_config = [(bn_model.seq[0], bn_model.seq[1]), (bn_model.conv2d, bn_model.bn)]
            models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model]
            configs = [None, None, None, bn_config]
            for model, config in zip(models, configs):
                self._test_squash_mask_conv2d_on_device(model, config, device)
    def _test_step_linear_on_device(self, model, is_basic, device):
        # Work on a copy so repeated prepare() calls across devices start clean.
        model = copy.deepcopy(model).to(device)
        if is_basic:
            pruner = SimplePruner(None)
            pruner.prepare(model, None)
            self._check_pruner_valid_before_step(model, pruner, device)
            pruner.step()
            self._check_pruner_valid_after_step(model, pruner, {1}, device)
        else:
            pruner = MultiplePruner(None)
            pruner.prepare(model, None)
            self._check_pruner_valid_before_step(model, pruner, device)
            pruner.step()
            self._check_pruner_valid_after_step(model, pruner, {1, 2}, device)
def test_step_linear(self):
basic_models = [Linear(), LinearB()]
complex_models = [MultipleLinear(), MultipleLinearB(), MultipleLinearMixed()]
for device in DEVICES:
for model in basic_models:
                self._test_step_linear_on_device(model, True, device)
for model in complex_models:
                self._test_step_linear_on_device(model, False, device)
def _test_step_conv2d_on_device(self, model, config, device):
model = model.to(device)
        x = torch.ones((1, 1, 28, 28), device=device)
pruner = SimplePruner(None)
pruner.prepare(model, config)
self._check_pruner_valid_before_step(model, pruner, device)
pruner.step()
if type(model) is Conv2dBN:
assert pruner.get_module_pruned_outputs(model.seq[1]) == pruner.get_module_pruned_outputs(model.seq[0])
assert pruner.get_module_pruned_outputs(model.bn) == pruner.get_module_pruned_outputs(model.conv2d)
self._check_pruner_valid_after_step(model, pruner, {1}, device)
assert model(x).shape == (1, 64, 24, 24)
    def test_step_conv2d(self):
        for device in DEVICES:
            # Fresh models per device so _check_pruner_valid_before_step sees
            # untouched parametrizations on every iteration.
            bn_model = Conv2dBN()
            bn_config = [(bn_model.seq[0], bn_model.seq[1]),
                         (bn_model.conv2d, bn_model.bn)]
            models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model]
            # Note: configs previously had five entries for four models, so zip()
            # silently dropped bn_config; it must line up one-to-one with models.
            configs = [None, None, None, bn_config]
            for model, config in zip(models, configs):
                self._test_step_conv2d_on_device(model, config, device)
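if __name__ == "__main__":
    from torch.testing._internal.common_utils import run_tests
    # Standard entry point used by PyTorch test files, so the module can be run directly.
    run_tests()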