| # -*- coding: utf-8 -*- |
| # Owner(s): ["module: unknown"] |
| |
| import logging |
| import warnings |
| from torch.testing._internal.common_utils import TestCase |
| from torch import nn |
| import torch |
| from typing import Tuple |
| import copy |
| |
| from torch.ao.sparsity._experimental.data_sparsifier import DataNormSparsifier |
| from torch.ao.sparsity._experimental.data_scheduler import BaseDataScheduler |
| |
# Configure the root logger once at import time so log records emitted while
# these tests run are shown at INFO level with a timestamp, logger name and
# level prefix.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
)
| |
| |
class ImplementedDataScheduler(BaseDataScheduler):
    """Minimal concrete ``BaseDataScheduler`` used as a test double.

    It halves each data group's current ``'sparsity_level'`` once
    ``last_epoch`` is positive, and otherwise reports the base parameters
    unchanged.
    """

    def __init__(self, sparsifier, sparsifier_hyperparam, last_epoch=-1, verbose=False):
        # Pure pass-through; all arguments are forwarded positionally to the base class.
        super().__init__(sparsifier, sparsifier_hyperparam, last_epoch, verbose)

    def get_schedule_param(self):
        """Return a mapping from data-group name to its scheduled value."""
        # Nothing to decay yet: hand back the base values as they are.
        if self.last_epoch <= 0:
            return self.base_param

        # Halve whatever sparsity level each data group currently holds.
        halved = {}
        for group_name, group_config in self.data_sparsifier.data_groups.items():
            halved[group_name] = group_config['sparsity_level'] * 0.5
        return halved
| |
| |
class TestBaseDataScheduler(TestCase):
    """Tests for the data scheduler base class.

    The scheduler under test is ``ImplementedDataScheduler`` driving a
    ``DataNormSparsifier`` and scheduling the ``'sparsity_level'``
    hyper-parameter.
    """

    def _get_data(self):
        """Build the fixture data shared by every test.

        Returns:
            A ``(data_list, data_with_config, defaults)`` tuple where
            ``data_list`` holds ``(name, data)`` pairs that use the default
            config, ``data_with_config`` holds dicts with ``'name'``,
            ``'data'`` and a per-entry ``'config'`` override, and
            ``defaults`` is the default sparsifier config.
        """
        tensor1, param1, emb1 = torch.randn(5, 5), nn.Parameter(torch.randn(10, 10)), nn.Embedding(50, 5)
        data_list = [
            ('tensor1', tensor1), ('param1', param1), ('emb1', emb1)
        ]
        defaults = {
            'sparsity_level': 0.7,
            'sparse_block_shape': (1, 4),
            'zeros_per_block': 2
        }
        data_with_config = [
            {
                'name': 'tensor2', 'data': torch.randn(4, 4),
                'config': {'sparsity_level': 0.3}
            }
        ]
        return data_list, data_with_config, defaults

    def _get_sparsifier(self, data_list, data_with_config, defaults):
        """Create a ``DataNormSparsifier`` holding all of the fixture data.

        ``data_list`` is passed to the constructor together with
        ``defaults``; every ``data_with_config`` entry is added afterwards
        with its own config passed as keyword arguments.
        """
        sparsifier = DataNormSparsifier(data_list, **defaults)
        for data_config_dict in data_with_config:
            name, data, config = data_config_dict['name'], data_config_dict['data'], data_config_dict['config']
            sparsifier.add_data(name=name, data=data, **config)
        return sparsifier

    def _get_scheduler(self, sparsifier, schedule_param):
        """Create the scheduler under test for ``sparsifier``."""
        scheduler = ImplementedDataScheduler(sparsifier, schedule_param)
        return scheduler

    def _get_schedule_param(self):
        """Return the name of the hyper-parameter being scheduled."""
        return 'sparsity_level'

    def _get_name_data_config(self, some_data, defaults):
        """Normalise one fixture entry to ``(name, data, config)``.

        Args:
            some_data: either a ``(name, data)`` tuple (an entry of
                ``data_list``) or a dict with ``'name'``, ``'data'`` and
                ``'config'`` keys (an entry of ``data_with_config``).
            defaults: the default config; it is deep-copied, never mutated.

        Returns:
            ``(name, data, config)`` where ``config`` is a copy of
            ``defaults`` updated with the entry's own config, if it has one.
        """
        config = copy.deepcopy(defaults)
        # Use the builtin ``tuple`` for the runtime check rather than the
        # deprecated ``typing.Tuple`` alias.
        if isinstance(some_data, tuple):
            # dealing with data_list
            name, data = some_data
        else:
            # dealing with data_with_config
            name, data, new_config = some_data['name'], some_data['data'], some_data['config']
            config.update(new_config)
        return name, data, config

    def test_constructor(self):
        """Checks that a new scheduler references its sparsifier, starts
        with a step count of 1 and records each data group's current value
        of the scheduled parameter as its base parameter"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        assert scheduler.data_sparsifier == sparsifier
        assert scheduler._step_count == 1

        for name, config in sparsifier.data_groups.items():
            assert scheduler.base_param[name] == config.get(schedule_param, None)

    def test_order_of_steps(self):
        """Checks that a warning is raised if the scheduler step is called
        before the sparsifier step, and that the correct order produces no
        warning originating from the data scheduler module"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        # Sparsifier step is not called
        with self.assertWarns(UserWarning):
            scheduler.step()

        # Derive the scheduler's source path from the class itself instead of
        # hard-coding it: a stale literal path can never match and silently
        # turns the check below into a no-op.
        # NOTE(review): assumes the warning is attributed to the module that
        # defines BaseDataScheduler (default warning stacklevel) - confirm.
        scheduler_file_suffix = BaseDataScheduler.__module__.replace('.', '/') + '.py'

        # Correct order has no warnings
        with warnings.catch_warnings(record=True) as caught:
            # Record every warning so the result does not depend on the
            # filters inherited from the surrounding test session.
            warnings.simplefilter("always")
            sparsifier.step()
            scheduler.step()

        # Warnings from unrelated modules are tolerated; only a warning
        # emitted from the base data scheduler's own file is a failure.
        for warning in caught:
            origin = warning.filename.replace('\\', '/')
            assert not origin.endswith(scheduler_file_suffix), (
                f"unexpected scheduler warning: {warning.message}"
            )

    def test_step(self):
        """Checks that one sparsifier/scheduler step halves the scheduled
        parameter of every data group and that the step count is tracked"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler = self._get_scheduler(sparsifier, schedule_param)

        all_data = data_list + data_with_config

        # Before stepping, every group still holds its configured value.
        for some_data in all_data:
            name, _, config = self._get_name_data_config(some_data, defaults)
            assert sparsifier.data_groups[name][schedule_param] == config[schedule_param]

        sparsifier.step()
        scheduler.step()

        # After one step the test scheduler has halved each value.
        for some_data in all_data:
            name, _, config = self._get_name_data_config(some_data, defaults)
            assert sparsifier.data_groups[name][schedule_param] == config[schedule_param] * 0.5

        # checking step count
        step_cnt = 5
        for _ in range(step_cnt):
            sparsifier.step()
            scheduler.step()

        assert scheduler._step_count == step_cnt + 2  # step_cnt + step above + 1 step in constructor

    def test_state_dict(self):
        """Checks that loading one scheduler's state dict into another makes
        their base and last parameters agree"""
        data_list, data_with_config, defaults = self._get_data()
        sparsifier = self._get_sparsifier(data_list, data_with_config, defaults)
        schedule_param = self._get_schedule_param()
        scheduler1 = self._get_scheduler(sparsifier, schedule_param)

        sparsifier.step()
        scheduler1.step()

        # scheduler2 is created after scheduler1 has stepped, so its base
        # values are expected to equal scheduler1's last values rather than
        # scheduler1's base values.
        scheduler2 = self._get_scheduler(sparsifier, schedule_param)
        all_data = data_list + data_with_config
        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data, defaults)
            assert scheduler1.base_param[name] != scheduler2.base_param[name]
            assert scheduler1._last_param[name] == scheduler2.base_param[name]

        scheduler1_state = scheduler1.state_dict()
        scheduler2.load_state_dict(scheduler1_state)

        # After loading, both schedulers must agree on every tracked value.
        for some_data in all_data:
            name, _, _ = self._get_name_data_config(some_data, defaults)
            assert scheduler1.base_param[name] == scheduler2.base_param[name]
            assert scheduler1._last_param[name] == scheduler2._last_param[name]