# blob: d3bb54ea2881c9e1ca9f3d4ebe5c66d0b3b1afd7 [file] [log] [blame] [edit]
# Owner(s): ["module: nn"]
import tempfile
from copy import deepcopy
from functools import partial
from unittest import expectedFailure
import torch
from torch import nn
from torch.nn.modules.lazy import LazyModuleMixin
from torch.nn.utils.parametrize import (
register_parametrization,
remove_parametrizations,
)
from torch.testing._internal.common_subclass import (
DiagTensorBelow,
subclass_db,
)
from torch.testing._internal.common_utils import (
TestCase,
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfTorchDynamo,
subtest,
)
from torch.testing._internal.logging_tensor import LoggingTensor
from torch.utils._pytree import tree_map
# The current test methodology in this file is to test a variety of real use cases
# with a set of fully-fledged tensor subclasses. In the future, this may change
# to more narrowly specify toy subclasses for each of the specific invariants under
# test, avoiding the need to maintain the set of fully-fledged tensor subclasses.
# Decorator for parametrizing tests across the various tensor classes.
parametrize_tensor_cls = parametrize("tensor_cls", [
subtest(tensor_cls, name=info.name) for tensor_cls, info in subclass_db.items()])
class TestSubclass(TestCase):
    """Checks that tensor subclasses behave correctly when used as ``nn.Parameter``.

    Most tests are parametrized over every tensor class registered in
    ``subclass_db`` (via ``parametrize_tensor_cls``).
    """

    def _create_tensor(self, tensor_cls):
        """Return a sample tensor by calling ``subclass_db[tensor_cls].create_fn(3)``."""
        return subclass_db[tensor_cls].create_fn(3)

    @parametrize_tensor_cls
    @parametrize("tensor_requires_grad", [False, True])
    def test_param_invariants(self, tensor_cls, tensor_requires_grad):
        """Wrapping a tensor in ``nn.Parameter`` honours ``requires_grad`` and leaves the source intact."""
        x = self._create_tensor(tensor_cls).requires_grad_(tensor_requires_grad)
        # Deliberately pass the opposite flag so precedence is observable.
        param = nn.Parameter(x, requires_grad=(not tensor_requires_grad))

        self.assertIsInstance(param, nn.Parameter)
        # Ensure requires_grad passed to Parameter's constructor takes precedence.
        self.assertEqual(param.requires_grad, not tensor_requires_grad)

        # Ensure original tensor is not mutated by Parameter construction.
        self.assertNotIsInstance(x, nn.Parameter)
        self.assertEqual(x.requires_grad, tensor_requires_grad)

        # A Parameter must not report itself as an instance of an unrelated
        # Parameter subclass.
        class UninitializedParam(nn.Parameter):
            pass

        self.assertNotIsInstance(param, UninitializedParam)

    @skipIfTorchDynamo()
    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_deepcopy(self, tensor_cls, as_param):
        """``deepcopy`` yields an equal, distinct object of the same class."""
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        x_copy = deepcopy(x)

        self.assertEqual(x, x_copy)
        self.assertEqual(x.__class__, x_copy.__class__)
        self.assertIsNot(x, x_copy)
        self.assertIsInstance(x_copy, tensor_cls)
        if as_param:
            # Deepcopy should preserve both custom type and "parameter-ness".
            self.assertIsInstance(x_copy, nn.Parameter)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_serialization(self, tensor_cls, as_param):
        """A ``torch.save`` / ``torch.load`` round trip preserves value and type."""
        with tempfile.TemporaryFile() as f:
            x = self._create_tensor(tensor_cls)
            if as_param:
                x = nn.Parameter(x)
            torch.save(x, f)
            f.seek(0)
            # The subclass is registered as a safe global for the duration of the load.
            with torch.serialization.safe_globals([tensor_cls]):
                x_loaded = torch.load(f)

            self.assertEqual(x, x_loaded)
            self.assertIsNot(x, x_loaded)
            self.assertIsInstance(x_loaded, tensor_cls)
            if as_param:
                # Serialization should preserve both custom type and "parameter-ness".
                self.assertIsInstance(x_loaded, nn.Parameter)

    @skipIfTorchDynamo("Visible only with functorch as functorch monkeypatches tensor str")
    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_repr(self, tensor_cls, as_param):
        """The repr names the subclass once and mentions "Parameter" only for parameters."""
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)
        str_repr = x.__repr__()

        if tensor_cls is not torch.Tensor:
            self.assertEqual(str_repr.count(f"{tensor_cls.__name__}("), 1)
        self.assertEqual(str_repr.count("Parameter"), 1 if as_param else 0)

    @parametrize_tensor_cls
    @parametrize("as_param", [False, True])
    def test_type_propagation(self, tensor_cls, as_param):
        """The result of an op keeps the subclass (when closed under ops) but is never a Parameter."""
        x = self._create_tensor(tensor_cls)
        if as_param:
            x = nn.Parameter(x)

        # Call the add operator to produce an output tensor.
        output = x + self._create_tensor(torch.Tensor)

        # Custom type should be propagated across operations if closed under the op, but
        # "parameter-ness" should not be.
        if subclass_db[tensor_cls].closed_under_ops:
            self.assertIsInstance(output, tensor_cls)
        else:
            self.assertIsInstance(output, torch.Tensor)
        self.assertNotIsInstance(output, nn.Parameter)

    @parametrize_tensor_cls
    def test_module_optimization(self, tensor_cls):
        """Subclass parameters held directly, in a ParameterList and in a ParameterDict can be optimized."""
        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p1 = nn.Parameter(create_fn())

                self.p_list = nn.ParameterList([create_fn() for _ in range(3)])
                self.p_list.append(create_fn())

                self.p_dict = nn.ParameterDict({
                    'foo': create_fn(),
                    'bar': create_fn(),
                })
                self.p_dict['baz'] = create_fn()

                with torch.no_grad():
                    nn.init.normal_(self.p1)
                    for p in self.p_list:
                        nn.init.uniform_(p)
                    for p in self.p_dict.values():
                        nn.init.uniform_(p)

            def forward(self, x):
                out = self.p1 + x
                for p in self.p_list:
                    out = p + out

                for v in self.p_dict.values():
                    out = v + out

                return out

        m = MyModule()
        # 1 (p1) + 4 (p_list: 3 initial + 1 appended) + 3 (p_dict) entries.
        self.assertEqual(len(m.state_dict()), 8)

        optimizer = torch.optim.SGD(m.parameters(), lr=0.1)
        m(create_fn()).sum().backward(torch.tensor(1))
        optimizer.step()

    @parametrize_tensor_cls
    @parametrize("leave_parametrized", [False, True])
    def test_parametrization(self, tensor_cls, leave_parametrized):
        """Registering and removing a parametrization works on a subclass-typed weight."""
        # TODO: Either implement set_() properly for these tensor subclasses or apply a
        # more general fix to avoid the need for special set_() handling. For now, skip
        # testing these as they're expected to fail.
        if tensor_cls in (LoggingTensor, DiagTensorBelow):
            # Report an explicit skip instead of returning, so these cases are not
            # counted as passing when nothing was exercised.
            self.skipTest(
                f"set_() is not properly supported for {tensor_cls.__name__}"
            )

        create_fn = partial(self._create_tensor, tensor_cls)

        class MyModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = nn.Parameter(create_fn())

            def forward(self, x):
                return self.weight + x

        class MyParametrization(nn.Module):
            def forward(self, X):
                return -X

        m = MyModule()
        self.assertEqual(len(m.state_dict()), 1)
        register_parametrization(m, 'weight', MyParametrization())
        self.assertIsInstance(m.weight, tensor_cls)
        output = m(self._create_tensor(torch.Tensor))
        self.assertIsInstance(output, tensor_cls)
        remove_parametrizations(m, 'weight', leave_parametrized=leave_parametrized)

    # Lazy modules with custom tensors are not supported yet.
    @expectedFailure
    @parametrize_tensor_cls
    def test_lazy_module(self, tensor_cls):
        """Expected failure: a lazy module should materialize its parameter as ``tensor_cls``."""
        if tensor_cls is torch.Tensor:
            self.fail('dummy fail for base tensor until the test passes for subclasses')

        class MyLazyModule(LazyModuleMixin, nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.param = nn.UninitializedParameter()

            def initialize_parameters(self, input) -> None:  # type: ignore[override]
                if self.has_uninitialized_params():
                    with torch.no_grad():
                        self.param.materialize(input.shape)
                        nn.init.uniform_(self.param)

            def forward(self, x):
                return self.param + x

        m = MyLazyModule()
        self.assertTrue(m.has_uninitialized_params())
        # The forward call is made for its side effect; the output itself is not inspected.
        m(self._create_tensor(tensor_cls))
        self.assertFalse(m.has_uninitialized_params())
        self.assertIsInstance(m.param, tensor_cls)

    def test_non_rewrapping_torch_dispatch_subclass_as_parameter_throws_for_detach(self):
        """Building a Parameter from a subclass whose ops return unwrapped tensors raises."""

        # Define a subclass that does not rewrap for any function in its __torch_dispatch__ impl.
        class NonRewrappingTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, t: torch.Tensor):
                r = super()._make_wrapper_subclass(
                    cls,
                    t.shape,
                    dtype=t.dtype,
                    requires_grad=t.requires_grad,
                    device=t.device,
                )
                return r

            def __init__(self, t) -> None:
                self.tensor: torch.Tensor = t

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                # The signature allows kwargs to be omitted; `**None` would raise
                # TypeError, so normalise it to an empty mapping first.
                if kwargs is None:
                    kwargs = {}

                def unwrap(e) -> torch.Tensor:
                    return e.tensor if isinstance(e, NonRewrappingTensor) else e

                r = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
                # Return an unwrapped tensor no longer of original subclass type.
                return r

        with self.assertRaisesRegex(
            RuntimeError,
            r"requires that detach\(\) returns an instance of the same type",
        ):
            nn.Parameter(NonRewrappingTensor(torch.randn(3)))

    def test_tensor_subclass_storage_data_accesses_throw(self):
        """Storage of a LoggingTensor exposes metadata, but data-accessing methods raise."""
        x = torch.ones(2)
        x_log = LoggingTensor(x)
        # Accessing storage on a tensor subclass is valid
        storage = x_log.untyped_storage()
        # This includes accessing metadata on the storage
        storage.size()
        # But storage methods that access data will throw
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.data_ptr()
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.resize_(0)
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.copy_(storage)
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage.fill_(0)
        with self.assertRaisesRegex(RuntimeError, "on an invalid python storage"):
            storage._write_file("file")
# Generate the concrete test methods for the @parametrize-decorated tests
# declared on TestSubclass.
instantiate_parametrized_tests(TestSubclass)


if __name__ == "__main__":
    # Hand control to the PyTorch test runner when executed as a script.
    run_tests()