| # -*- coding: utf-8 -*- |
| # Owner(s): ["module: unknown"] |
| |
| |
| import logging |
| |
| import torch |
| from torch.ao.sparsity.sparsifier.utils import ( |
| fqn_to_module, |
| get_arg_info_from_tensor_fqn, |
| module_to_fqn, |
| ) |
| |
| from torch.testing._internal.common_quantization import ( |
| ConvBnReLUModel, |
| ConvModel, |
| FunctionalLinear, |
| LinearAddModel, |
| ManualEmbeddingBagLinear, |
| SingleLayerLinearModel, |
| TwoLayerLinearModel, |
| ) |
| from torch.testing._internal.common_utils import TestCase |
| |
| logging.basicConfig( |
| format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO |
| ) |
| |
| model_list = [ |
| ConvModel, |
| SingleLayerLinearModel, |
| TwoLayerLinearModel, |
| LinearAddModel, |
| ConvBnReLUModel, |
| ManualEmbeddingBagLinear, |
| FunctionalLinear, |
| ] |
| |
| |
| class TestSparsityUtilFunctions(TestCase): |
| def test_module_to_fqn(self): |
| """ |
| Tests that module_to_fqn works as expected when compared to known good |
| module.get_submodule(fqn) function |
| """ |
| for model_class in model_list: |
| model = model_class() |
| list_of_modules = [m for _, m in model.named_modules()] + [model] |
| for module in list_of_modules: |
| fqn = module_to_fqn(model, module) |
| check_module = model.get_submodule(fqn) |
| self.assertEqual(module, check_module) |
| |
| def test_module_to_fqn_fail(self): |
| """ |
| Tests that module_to_fqn returns None when an fqn that doesn't |
| correspond to a path to a node/tensor is given |
| """ |
| for model_class in model_list: |
| model = model_class() |
| fqn = module_to_fqn(model, torch.nn.Linear(3, 3)) |
| self.assertEqual(fqn, None) |
| |
| def test_module_to_fqn_root(self): |
| """ |
| Tests that module_to_fqn returns '' when model and target module are the same |
| """ |
| for model_class in model_list: |
| model = model_class() |
| fqn = module_to_fqn(model, model) |
| self.assertEqual(fqn, "") |
| |
| def test_fqn_to_module(self): |
| """ |
| Tests that fqn_to_module operates as inverse |
| of module_to_fqn |
| """ |
| for model_class in model_list: |
| model = model_class() |
| list_of_modules = [m for _, m in model.named_modules()] + [model] |
| for module in list_of_modules: |
| fqn = module_to_fqn(model, module) |
| check_module = fqn_to_module(model, fqn) |
| self.assertEqual(module, check_module) |
| |
| def test_fqn_to_module_fail(self): |
| """ |
| Tests that fqn_to_module returns None when it tries to |
| find an fqn of a module outside the model |
| """ |
| for model_class in model_list: |
| model = model_class() |
| fqn = "foo.bar.baz" |
| check_module = fqn_to_module(model, fqn) |
| self.assertEqual(check_module, None) |
| |
| def test_fqn_to_module_for_tensors(self): |
| """ |
| Tests that fqn_to_module works for tensors, actually all parameters |
| of the model. This is tested by identifying a module with a tensor, |
| and generating the tensor_fqn using module_to_fqn on the module + |
| the name of the tensor. |
| """ |
| for model_class in model_list: |
| model = model_class() |
| list_of_modules = [m for _, m in model.named_modules()] + [model] |
| for module in list_of_modules: |
| module_fqn = module_to_fqn(model, module) |
| for tensor_name, tensor in module.named_parameters(recurse=False): |
| tensor_fqn = ( # string manip to handle tensors on root |
| module_fqn + ("." if module_fqn != "" else "") + tensor_name |
| ) |
| check_tensor = fqn_to_module(model, tensor_fqn) |
| self.assertEqual(tensor, check_tensor) |
| |
| def test_get_arg_info_from_tensor_fqn(self): |
| """ |
| Tests that get_arg_info_from_tensor_fqn works for all parameters of the model. |
| Generates a tensor_fqn in the same way as test_fqn_to_module_for_tensors and |
| then compares with known (parent) module and tensor_name as well as module_fqn |
| from module_to_fqn. |
| """ |
| for model_class in model_list: |
| model = model_class() |
| list_of_modules = [m for _, m in model.named_modules()] + [model] |
| for module in list_of_modules: |
| module_fqn = module_to_fqn(model, module) |
| for tensor_name, tensor in module.named_parameters(recurse=False): |
| tensor_fqn = ( |
| module_fqn + ("." if module_fqn != "" else "") + tensor_name |
| ) |
| arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn) |
| self.assertEqual(arg_info["module"], module) |
| self.assertEqual(arg_info["module_fqn"], module_fqn) |
| self.assertEqual(arg_info["tensor_name"], tensor_name) |
| self.assertEqual(arg_info["tensor_fqn"], tensor_fqn) |
| |
| def test_get_arg_info_from_tensor_fqn_fail(self): |
| """ |
| Tests that get_arg_info_from_tensor_fqn works as expected for invalid tensor_fqn |
| inputs. The string outputs still work but the output module is expected to be None. |
| """ |
| for model_class in model_list: |
| model = model_class() |
| tensor_fqn = "foo.bar.baz" |
| arg_info = get_arg_info_from_tensor_fqn(model, tensor_fqn) |
| self.assertEqual(arg_info["module"], None) |
| self.assertEqual(arg_info["module_fqn"], "foo.bar") |
| self.assertEqual(arg_info["tensor_name"], "baz") |
| self.assertEqual(arg_info["tensor_fqn"], "foo.bar.baz") |