# -*- coding: utf-8 -*-
# Owner(s): ["module: unknown"]

import copy
import logging
import warnings

import torch
from torch import nn
from torch.testing._internal.common_utils import TestCase
from torch.ao.sparsity._experimental.data_sparsifier import DataNormSparsifier
from torch.ao.sparsity._experimental.data_scheduler import BaseDataScheduler

logging.basicConfig(format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)


class ImplementedDataScheduler(BaseDataScheduler):
    def __init__(self, sparsifier, sparsifier_hyperparam, last_epoch=-1, verbose=False):
        super().__init__(sparsifier, sparsifier_hyperparam, last_epoch, verbose)

    def get_schedule_param(self):
        if self.last_epoch > 0:
            return {name: config['sparsity_level'] * 0.5
                    for name, config in self.data_sparsifier.data_groups.items()}
        else:
            return self.base_param
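

# Illustrative note: assuming BaseDataScheduler.step() writes the values
# returned by get_schedule_param() back into the sparsifier's data_groups,
# the scheduler above halves each group's current sparsity_level on every
# step after the first (e.g. 0.7 -> 0.35 -> 0.175 -> ...).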


class TestBaseDataScheduler(TestCase):
    def _get_data(self):
        tensor1, param1, emb1 = torch.randn(5, 5), nn.Parameter(torch.randn(10, 10)), nn.Embedding(50, 5)
        data_list = [
            ('tensor1', tensor1), ('param1', param1), ('emb1', emb1)
        ]
        defaults = {
            'sparsity_level': 0.7,
            'sparse_block_shape': (1, 4),
            'zeros_per_block': 2
        }
        data_with_config = [
            {
                'name': 'tensor2', 'data': torch.randn(4, 4),
                'config': {'sparsity_level': 0.3}
            }
        ]
        return data_list, data_with_config, defaults

    def _get_sparsifier(self, data_list, data_with_config, defaults):
        sparsifier = DataNormSparsifier(data_list, **defaults)
        for data_config_dict in data_with_config:
            name, data, config = data_config_dict['name'], data_config_dict['data'], data_config_dict['config']
            sparsifier.add_data(name=name, data=data, **config)
        return sparsifier
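
    # Note: entries from data_list are registered through the constructor and
    # take `defaults` wholesale, while entries from data_with_config go through
    # add_data(), so their per-tensor config overrides the defaults (here,
    # sparsity_level=0.3 for 'tensor2').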

    def _get_scheduler(self, sparsifier, schedule_param):
        scheduler = ImplementedDataScheduler(sparsifier, schedule_param)
        return scheduler

    def _get_schedule_param(self):
        return 'sparsity_level'

    def _get_name_data_config(self, some_data, defaults):
        config = copy.deepcopy(defaults)
        if isinstance(some_data, tuple):
            # dealing with data_list
            name, data = some_data
        else:
            # dealing with data_with_config
            name, data, new_config = some_data['name'], some_data['data'], some_data['config']
            config.update(new_config)
        return name, data, config
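
    # For example: the ('tensor1', tensor1) entry yields the defaults unchanged,
    # while the 'tensor2' dict entry yields the defaults with sparsity_level
    # overridden to 0.3.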

    def test_constructor(self):
        """Checks that the scheduler is constructed with a reference to the
        sparsifier, a step count of 1, and base parameters taken from each
        data group"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        assert scheduler.data_sparsifier == sparsifier
        assert scheduler._step_count == 1

        for name, config in sparsifier.data_groups.items():
            assert scheduler.base_param[name] == config.get(schedule_param, None)

    def test_order_of_steps(self):
        """Checks that a warning is thrown if the scheduler step is called
        before the sparsifier step"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        # Sparsifier step is not called
        with self.assertWarns(UserWarning):
            scheduler.step()

        # Correct order has no warnings
        # Note: this check would also trigger if unrelated warnings were present,
        # hence the filename filter below.
        with warnings.catch_warnings(record=True) as w:
            sparsifier.step()
            scheduler.step()

            # Make sure there is no warning related to the base_data_scheduler
            for warning in w:
                fname = warning.filename
                fname = '/'.join(fname.split('/')[-5:])
                # last five path components of the module defining BaseDataScheduler
                # (torch.ao.sparsity._experimental.data_scheduler)
                assert fname != 'ao/sparsity/_experimental/data_scheduler/base_data_scheduler.py'
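
    # Design note (assumption): unittest offers assertWarns but no assertNoWarns,
    # so test_order_of_steps records every warning and filters by filename rather
    # than asserting that no warning at all was raised.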

    def test_step(self):
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        all_data = data_list + data_with_config

        # Before any step, each group still holds its original schedule parameter
        for some_data in all_data:
            name, _, config = self._get_name_data_config(some_data, defaults)
            assert sparsifier.data_groups[name][schedule_param] == config[schedule_param]

        sparsifier.step()
        scheduler.step()

        # One scheduler step halves each group's sparsity_level
        for some_data in all_data:
            name, _, config = self._get_name_data_config(some_data, defaults)
            assert sparsifier.data_groups[name][schedule_param] == config[schedule_param] * 0.5

        # checking step count
        step_cnt = 5
        for _ in range(step_cnt):
            sparsifier.step()
            scheduler.step()

        assert scheduler._step_count == step_cnt + 2  # step_cnt + step above + 1 step in constructor

    def test_state_dict(self):
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler1 = self._get_scheduler(sparsifier, schedule_param)

        sparsifier.step()
        scheduler1.step()

        scheduler2 = self._get_scheduler(sparsifier, schedule_param)
        all_data = data_list + data_with_config
        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data, defaults)
            assert scheduler1.base_param[name] != scheduler2.base_param[name]
            assert scheduler1._last_param[name] == scheduler2.base_param[name]

        scheduler1_state = scheduler1.state_dict()
        scheduler2.load_state_dict(scheduler1_state)

        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data, defaults)
            assert scheduler1.base_param[name] == scheduler2.base_param[name]
            assert scheduler1._last_param[name] == scheduler2._last_param[name]
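

if __name__ == "__main__":
    # Standard entry point for PyTorch test modules so this file can be run
    # directly; run_tests is provided by torch.testing._internal.common_utils.
    from torch.testing._internal.common_utils import run_tests

    run_tests()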