| # Owner(s): ["module: nn"] |
| |
| import tempfile |
| from copy import deepcopy |
| from functools import partial |
| from unittest import expectedFailure |
| |
| import torch |
| from torch import nn |
| from torch.nn.modules.lazy import LazyModuleMixin |
| from torch.nn.utils.parametrize import ( |
| register_parametrization, |
| remove_parametrizations, |
| ) |
| from torch.testing._internal.common_subclass import ( |
| DiagTensorBelow, |
| subclass_db, |
| ) |
| from torch.testing._internal.common_utils import ( |
| TestCase, |
| instantiate_parametrized_tests, |
| parametrize, |
| run_tests, |
| skipIfTorchDynamo, |
| subtest, |
| ) |
| from torch.testing._internal.logging_tensor import LoggingTensor |
| from torch.utils._pytree import tree_map |
| |
| # The current test methodology in this file is to test a variety of real use cases |
| # with a set of fully-fledged tensor subclasses. In the future, this may change |
| # to more narrowly specify toy subclasses for each of the specific invariants under |
| # test, avoiding the need to maintain the set of fully-fledged tensor subclasses. |
| |
| |
| # Decorator for parametrizing tests across the various tensor classes. |
| parametrize_tensor_cls = parametrize("tensor_cls", [ |
| subtest(tensor_cls, name=info.name) for tensor_cls, info in subclass_db.items()]) |
| |
| |
class TestSubclass(TestCase):
    """Tests nn.Parameter / nn.Module integration with tensor subclasses.

    Each test is parametrized over the fully-fledged subclasses registered
    in ``subclass_db``; see the module-level comment for the methodology.
    """

    def _create_tensor(self, tensor_cls):
        # Build an instance of the requested subclass via its registered
        # factory; 3 is the size argument passed to create_fn.
        return subclass_db[tensor_cls].create_fn(3)

    @parametrize_tensor_cls
    @parametrize("tensor_requires_grad", [False, True])
    def test_param_invariants(self, tensor_cls, tensor_requires_grad):
        """Parameter construction honors requires_grad and leaves the source tensor untouched."""
        x = self._create_tensor(tensor_cls).requires_grad_(tensor_requires_grad)
        # Deliberately pass the opposite requires_grad to the constructor.
        param = nn.Parameter(x, requires_grad=(not tensor_requires_grad))

        self.assertIsInstance(param, nn.Parameter)
        # Ensure requires_grad passed to Parameter's constructor takes precedence.
        self.assertEqual(param.requires_grad, not tensor_requires_grad)

        # Ensure original tensor is not mutated by Parameter construction.
        self.assertNotIsInstance(x, nn.Parameter)
        self.assertEqual(x.requires_grad, tensor_requires_grad)

        # The param must not be mistaken for an unrelated Parameter subclass.
        class UninitializedParam(nn.Parameter):
            pass

        self.assertNotIsInstance(param, UninitializedParam)

    @skipIfTorchDynamo()
    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_deepcopy(self, tensor_cls, as_param):
        """deepcopy yields a distinct, equal object of the same subclass."""
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        x_copy = deepcopy(x)
        self.assertEqual(x, x_copy)
        self.assertEqual(x.__class__, x_copy.__class__)
        self.assertIsNot(x, x_copy)
        self.assertIsInstance(x_copy, tensor_cls)
        if as_param:
            # Deepcopy should preserve both custom type and "parameter-ness".
            self.assertIsInstance(x_copy, nn.Parameter)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_serialization(self, tensor_cls, as_param):
        """A save/load round trip preserves value, subclass type, and parameter-ness."""
        with tempfile.TemporaryFile() as f:
            x = self._create_tensor(tensor_cls)
            if as_param:
                x = nn.Parameter(x)
            torch.save(x, f)
            f.seek(0)
            # The subclass must be allow-listed for safe deserialization.
            with torch.serialization.safe_globals([tensor_cls]):
                x_loaded = torch.load(f)

            self.assertEqual(x, x_loaded)
            self.assertIsNot(x, x_loaded)
            self.assertIsInstance(x_loaded, tensor_cls)
            if as_param:
                # Serialization should preserve both custom type and "parameter-ness".
                self.assertIsInstance(x_loaded, nn.Parameter)

    @skipIfTorchDynamo("Visible only with functorch as functorch monkeypatches tensor str")
    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_repr(self, tensor_cls, as_param):
        """repr() names the subclass exactly once and mentions Parameter only when wrapped."""
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        str_repr = x.__repr__()
        if tensor_cls is not torch.Tensor:
            # The subclass name should appear exactly once in the repr.
            self.assertEqual(str_repr.count(f"{tensor_cls.__name__}("), 1)
        self.assertEqual(str_repr.count("Parameter"), 1 if as_param else 0)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_type_propagation(self, tensor_cls, as_param):
        """Ops propagate the subclass type (when closed under ops) but never parameter-ness."""
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)

        # Call the add operator to produce an output tensor.
        output = x + self._create_tensor(torch.Tensor)

        # Custom type should be propagated across operations if closed under the op, but
        # "parameter-ness" should not be.
        if subclass_db[tensor_cls].closed_under_ops:
            self.assertIsInstance(output, tensor_cls)
        else:
            self.assertIsInstance(output, torch.Tensor)
        self.assertNotIsInstance(output, nn.Parameter)

    @parametrize_tensor_cls
    def test_module_optimization(self, tensor_cls):
        """Subclass parameters work in Parameter containers, init fns, backward, and SGD."""
        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p1 = nn.Parameter(create_fn())

                self.p_list = nn.ParameterList([create_fn() for _ in range(3)])
                self.p_list.append(create_fn())

                self.p_dict = nn.ParameterDict({
                    'foo': create_fn(),
                    'bar': create_fn(),
                })
                self.p_dict['baz'] = create_fn()

                with torch.no_grad():
                    nn.init.normal_(self.p1)
                    for p in self.p_list:
                        nn.init.uniform_(p)
                    for p in self.p_dict.values():
                        nn.init.uniform_(p)

            def forward(self, x):
                out = self.p1 + x
                for p in self.p_list:
                    out = p + out

                for v in self.p_dict.values():
                    out = v + out

                return out

        m = MyModule()
        # 1 (p1) + 4 (p_list) + 3 (p_dict) entries expected in the state dict.
        self.assertEqual(len(m.state_dict()), 8)

        optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
        m(create_fn()).sum().backward(torch.tensor(1))
        optimizer.step()

    @parametrize_tensor_cls
    @parametrize("leave_parametrized", [False, True])
    def test_parametrization(self, tensor_cls, leave_parametrized):
        """register/remove_parametrizations preserves the subclass type of the weight."""
        # TODO: Either implement set_() properly for these tensor subclasses or apply a
        # more general fix to avoid the need for special set_() handling. For now, skip
        # testing these as they're expected to fail.
        if tensor_cls in [LoggingTensor, DiagTensorBelow]:
            return

        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = nn.Parameter(create_fn())

            def forward(self, x):
                return self.weight + x

        class MyParametrization(nn.Module):
            def forward(self, X):
                # Simple sign flip; enough to exercise type propagation.
                return -X

        m = MyModule()
        self.assertEqual(len(m.state_dict()), 1)
        register_parametrization(m, 'weight', MyParametrization())
        self.assertIsInstance(m.weight, tensor_cls)
        output = m(self._create_tensor(torch.Tensor))
        self.assertIsInstance(output, tensor_cls)
        remove_parametrizations(m, 'weight', leave_parametrized=leave_parametrized)

    # Lazy modules with custom tensors are not supported yet.
    @expectedFailure
    @parametrize_tensor_cls
    def test_lazy_module(self, tensor_cls):
        """Lazy-module materialization with custom tensor input (expected to fail for now)."""
        if tensor_cls is torch.Tensor:
            self.fail('dummy fail for base tensor until the test passes for subclasses')

        class MyLazyModule(LazyModuleMixin, nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = nn.UninitializedParameter()

            def initialize_parameters(self, input) -> None:  # type: ignore[override]
                if self.has_uninitialized_params():
                    with torch.no_grad():
                        self.param.materialize(input.shape)
                        nn.init.uniform_(self.param)

            def forward(self, x):
                return self.param + x

        m = MyLazyModule()
        self.assertTrue(m.has_uninitialized_params())
        output = m(self._create_tensor(tensor_cls))
        self.assertFalse(m.has_uninitialized_params())
        self.assertIsInstance(m.param, tensor_cls)

    def test_non_rewrapping_torch_dispatch_subclass_as_parameter_throws_for_detach(self):
        """Wrapping a non-rewrapping subclass in nn.Parameter raises, since detach() loses the type."""

        # Define a subclass that does not rewrap for any function in its __torch_dispatch__ impl.
        class NonRewrappingTensor(torch.Tensor):
            @staticmethod
            def __new__(
                cls, t: torch.Tensor
            ):
                r = super()._make_wrapper_subclass(
                    cls, t.shape, dtype=t.dtype, requires_grad=t.requires_grad, device=t.device)
                return r

            def __init__(self, t) -> None:
                # Keep the wrapped tensor around for unwrapping in dispatch.
                self.tensor: torch.Tensor = t

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):

                def unwrap(e) -> torch.Tensor:
                    if isinstance(e, NonRewrappingTensor):
                        t = e.tensor
                        return t
                    else:
                        return e

                # NOTE(review): assumes kwargs is always a dict by the time we
                # get here; if the dispatcher ever passed kwargs=None, the
                # ** expansion below would fail — confirm.
                r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
                # Return an unwrapped tensor no longer of original subclass type.
                return r

        with self.assertRaisesRegex(RuntimeError, r"requires that detach\(\) returns an instance of the same type"):
            param = nn.Parameter(NonRewrappingTensor(torch.randn(3)))

    def test_tensor_subclass_storage_data_accesses_throw(self):
        """Storage metadata access works on subclasses, but data access raises."""
        from torch.testing._internal.logging_tensor import LoggingTensor
        x = torch.ones(2)
        x_log = LoggingTensor(x)
        # Accessing storage on a tensor subclass is valid
        storage = x_log.untyped_storage()
        # This includes accessing metadata on the storage
        sz = storage.size()
        # But storage methods that access data will throw
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.data_ptr()
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.resize_(0)
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.copy_(storage)
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.fill_(0)
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage._write_file("file")
| |
| |
# Materialize concrete test methods from the @parametrize decorators above.
instantiate_parametrized_tests(TestSubclass)

if __name__ == '__main__':
    run_tests()