| # Owner(s): ["module: nn"] |
| |
| from itertools import chain, product |
| from inspect import signature, isgenerator |
| from copy import deepcopy |
| import tempfile |
| from operator import methodcaller |
| |
| import torch |
| |
| from torch._subclasses.meta_utils import assert_metadata_eq |
| from torch.testing._internal.common_cuda import with_tf32_off |
| from torch.testing._internal.common_device_type import ( |
| instantiate_device_type_tests, onlyCPU, onlyCUDA, toleranceOverride, tol, skipMeta) |
| from torch.testing._internal.common_modules import module_db, modules, ModuleErrorEnum, TrainEvalMode |
| from torch.testing._internal.common_utils import ( |
| TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck, |
| gradgradcheck, parametrize, wrapSwapTensorsTest) |
| from unittest.mock import patch, call |
| |
| |
| class TestModule(TestCase): |
| _do_cuda_memory_leak_check = True |
| _do_cuda_non_default_stream = True |
| precision = 1e-5 |
| rel_tol = 1e-5 |
| |
| def _assert_module_parameters_and_buffer_are(self, module, device, dtype): |
| # Check device placement and dtype for created parameters and buffers. |
| # Only verify floating point dtypes since that's what the kwarg or methods |
| # such as `float()` applies to. |
| if not isinstance(device, torch.device): |
| device = torch.device(device) |
| |
| def _check_module(items, name, device=device, dtype=dtype): |
| for item_name, item in items: |
| self.assertEqual( |
| item.device, device, |
| f'{name} {item_name} is on device {item.device} instead of the expected device {device}') |
| if item.dtype.is_floating_point: |
| self.assertEqual( |
| item.dtype, dtype, |
| f'{name} {item_name} is of dtype {item.dtype} instead of the expected dtype {dtype}') |
| _check_module(module.named_parameters(), "Parameter") |
| _check_module(module.named_buffers(), "Buffer") |
| |
| @modules(module_db) |
| def test_forward(self, device, dtype, module_info, training): |
| module_cls = module_info.module_cls |
| module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, |
| requires_grad=False, training=training) |
| dtype_to_method_caller = { |
| torch.float32: methodcaller("float"), |
| torch.float64: methodcaller("double"), |
| } |
| for module_input in module_inputs: |
| if module_input.forward_input is None: |
| continue |
| |
| with freeze_rng_state(): |
| # === Instantiate the module. === |
| args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs |
| m = module_cls(*args, **kwargs) |
| m.to(device).to(dtype) |
| m.train(training) |
| |
| # === Do forward pass. === |
| args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs |
| outputs = m(*args, **kwargs) |
| |
| # === Compare outputs to a reference if one is specified. === |
| # TODO: Handle precision |
| reference_fn = module_input.reference_fn |
| if reference_fn is not None: |
| ref_outputs = reference_fn(m, *args, **kwargs) |
| self.assertEqual(outputs, ref_outputs) |
| |
| # === Use the method call and verify the parameters and buffers === |
| if dtype in dtype_to_method_caller: |
| dtype_to_method_caller[dtype](m) |
| m(*args, **kwargs) |
| self._assert_module_parameters_and_buffer_are(m, device, dtype) |
| |
    # Tests passing factory kwargs (e.g. device / dtype) during module instantiation.
    # They should be applied to any created parameters and buffers.
    @modules(module_db)
    def test_factory_kwargs(self, device, dtype, module_info, training):
        """Verify that factory kwargs (``device`` / ``dtype``) passed to the
        constructor are applied to every parameter and buffer the module
        itself creates (including uninitialized ones for lazy modules)."""
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=False, training=training)
        for module_input in module_inputs:
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            # Check if this module creates parameters or registers buffers.
            # The mock magic here passes through to the real Parameter / register_buffer
            # logic and is only used to check call inputs.
            module_creates_params_or_buffers = False
            parameter_new = mock_wrapper(torch.nn.Parameter.__new__)
            with patch.object(torch.nn.Parameter, '__new__', parameter_new):
                register_buffer = mock_wrapper(torch.nn.Module.register_buffer)
                with patch.object(torch.nn.Module, 'register_buffer', register_buffer):
                    m = module_cls(*args, **kwargs)
                    m.train(training)

            # Check if a parameter or buffer was created with a tensor not passed to the constructor.
            # A tensor that *was* passed in cannot have been affected by factory kwargs.
            constructor_tensors = get_tensors_from(args, kwargs)
            for mock in [parameter_new.mock, register_buffer.mock]:
                for call_args, call_kwargs in mock.call_args_list:
                    call_tensors = get_tensors_from(call_args, call_kwargs)
                    if len(call_tensors) > 0 and not constructor_tensors.intersection(call_tensors):
                        module_creates_params_or_buffers = True
                        # NOTE(review): this only breaks out of the inner loop over
                        # call_args_list; the outer loop over mocks continues, which
                        # is harmless because the flag is already set.
                        break

            if not module_creates_params_or_buffers:
                continue

            # Instantiate module with the factory kwargs.
            kwargs.update({
                'device': device,
                'dtype': dtype,
            })

            if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                # Ensure device and dtype are passed to all UninitializedParameters and UninitializedBuffers.
                uninit_param_new = mock_wrapper(torch.nn.UninitializedParameter.__new__)
                with patch.object(torch.nn.UninitializedParameter, '__new__', uninit_param_new):
                    uninit_buffer_new = mock_wrapper(torch.nn.UninitializedBuffer.__new__)
                    with patch.object(torch.nn.UninitializedBuffer, '__new__', uninit_buffer_new):
                        m = module_cls(*args, **kwargs)
                        m.train(training)
                        # Every recorded construction call must have received the factory kwargs.
                        uninit_param_new.mock.assert_has_calls(
                            [call(device=device, dtype=dtype) for _ in uninit_param_new.mock.mock_calls])
                        uninit_buffer_new.mock.assert_has_calls(
                            [call(device=device, dtype=dtype) for _ in uninit_buffer_new.mock.mock_calls])
            else:
                # Check device placement and dtype for created parameters and buffers.
                # Only verify floating point dtypes since that's what the kwarg applies to.
                m = module_cls(*args, **kwargs)
                m.train(training)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)
| |
    @onlyCUDA
    @modules(module_db)
    def test_multiple_device_transfer(self, device, dtype, module_info, training):
        """Run forward passes while transferring a module GPU -> CPU -> GPU
        (and to a second GPU when available), verifying parameters and buffers
        land on the expected device after each move."""
        module_cls = module_info.module_cls
        module_inputs_device = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                              requires_grad=False, training=training)
        module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
                                                           requires_grad=False, training=training)
        for module_input_device, module_input_cpu in zip(module_inputs_device, module_inputs_cpu):
            if module_input_device.forward_input is None:
                continue

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input_device.constructor_input.args, module_input_device.constructor_input.kwargs
                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)

                # === Do forward pass on GPU ===
                input_device_args = module_input_device.forward_input.args
                input_device_kwargs = module_input_device.forward_input.kwargs
                m(*input_device_args, **input_device_kwargs)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

                # === Move to CPU ===
                input_cpu_args = module_input_cpu.forward_input.args
                input_cpu_kwargs = module_input_cpu.forward_input.kwargs
                m.cpu()
                m(*input_cpu_args, **input_cpu_kwargs)
                self._assert_module_parameters_and_buffer_are(m, "cpu", dtype)

                # === Move back to GPU and forward pass ===
                m.cuda()
                m(*input_device_args, **input_device_kwargs)
                self._assert_module_parameters_and_buffer_are(m, device, dtype)

                if torch.cuda.device_count() >= 2:
                    # === test cross-GPU transfer works
                    # Recursively move any tensors in the (possibly nested)
                    # forward inputs onto the second GPU; non-tensors pass through.
                    def _to_device1(objs):
                        if isinstance(objs, (tuple, list)):
                            return type(objs)(_to_device1(item) for item in objs)
                        elif isinstance(objs, dict):
                            return {name: _to_device1(item) for name, item in objs.items()}
                        elif isinstance(objs, torch.Tensor):
                            return objs.cuda(1)
                        else:
                            return objs
                    input_device_1_args = _to_device1(input_device_args)
                    input_device_1_kwargs = _to_device1(input_device_kwargs)

                    m.cuda(1)
                    with torch.cuda.device(1):
                        m(*input_device_1_args, **input_device_1_kwargs)
                    self._assert_module_parameters_and_buffer_are(m, torch.device("cuda:1"), dtype)
| |
| @modules(module_db) |
| def test_repr(self, device, dtype, module_info, training): |
| # Test module can be represented with repr and str without errors. |
| module_cls = module_info.module_cls |
| module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, |
| requires_grad=False, training=training) |
| for module_input in module_inputs: |
| args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs |
| m = module_cls(*args, **kwargs) |
| m.to(device).to(dtype) |
| m.train(training) |
| |
| # Check that these methods do not raise errors |
| m.__repr__() |
| str(m) |
| |
| @modules(module_db) |
| def test_save_load(self, device, dtype, module_info, training): |
| # Test that module can be pickled and unpickled. |
| module_cls = module_info.module_cls |
| module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, |
| requires_grad=False, training=training) |
| for module_input in module_inputs: |
| if module_input.forward_input is None: |
| continue |
| |
| args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs |
| |
| with freeze_rng_state(): |
| # === Instantiate the module. === |
| args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs |
| m = module_cls(*args, **kwargs) |
| m.to(device).to(dtype) |
| m.train(training) |
| sd = m.state_dict() |
| |
| # === Do forward pass. === |
| args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs |
| output = m(*args, **kwargs) |
| |
| # === Check saved/loaded module gives the same output. === |
| with tempfile.TemporaryFile() as f: |
| torch.save(m, f) |
| f.seek(0) |
| # weights_only=False as this is legacy code that saves the model |
| m_copy = torch.load(f, weights_only=False) |
| output_from_copy = m_copy(*args, **kwargs) |
| self.assertEqual(output, output_from_copy) |
| |
| # === Check saved/loaded state_dict are the same (including weights_only load). === |
| with tempfile.TemporaryFile() as f: |
| torch.save(sd, f) |
| f.seek(0) |
| sd_copy = torch.load(f) |
| self.assertEqual(sd_copy, sd) |
| del sd_copy |
| f.seek(0) |
| sd_copy_wo = torch.load(f, weights_only=True) |
| self.assertEqual(sd_copy_wo, sd) |
| |
| @skipMeta |
| @modules([module_info for module_info in module_db |
| if 'inplace' in signature(module_info.module_cls).parameters]) |
| def test_check_inplace(self, device, dtype, module_info, training): |
| # Check if the inplace variant of the module gives the same result as the out of place |
| # variant. |
| module_cls = module_info.module_cls |
| module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, |
| requires_grad=True, training=training) |
| for module_input in module_inputs: |
| if module_input.forward_input is None: |
| continue |
| |
| # === Instantiate the module. === |
| args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs |
| m_op = module_cls(*args, **kwargs, inplace=False) |
| m_op.to(device).to(dtype) |
| m_op.train(training) |
| m_inplace = module_cls(*args, **kwargs, inplace=True) |
| m_inplace.to(device).to(dtype) |
| m_inplace.train(training) |
| |
| # === Inplace modules only supports inplace operations on the first argument === |
| input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs |
| |
| # === Do not allow the first input to be in input_kwargs === |
| forward_sig = signature(m_op).parameters |
| self.assertGreaterEqual(len(forward_sig), 1) |
| first_param_name = next(iter(forward_sig.items())) |
| self.assertNotIn(first_param_name, input_kwargs) |
| |
| # === Out of place operation does not write to original tensor === |
| self.assertGreaterEqual(len(input_args), 1) |
| input_version = input_args[0]._version |
| with freeze_rng_state(): |
| output_op = m_op(*input_args, **input_kwargs) |
| self.assertEqual(input_args[0]._version, input_version) |
| |
| # === Check that the inplace operation gives the same result === |
| input_arg_copy = deepcopy(input_args) |
| input_arg_clone = tuple(i.clone() for i in input_arg_copy) |
| input_clone_version = input_arg_clone[0]._version |
| with freeze_rng_state(): |
| output_ip = m_inplace(*input_arg_clone, **input_kwargs) |
| self.assertGreater(input_arg_clone[0]._version, input_clone_version) |
| self.assertEqual(output_op, output_ip) |
| |
| # === Check that the gradients are the same === |
| grad = output_op.data.clone().normal_() |
| output_op.backward(grad) |
| output_ip.backward(grad) |
| self.assertEqual(input_args[0].grad, input_arg_copy[0].grad) |
| |
| def _traverse_obj(self, obj, func): |
| if isinstance(obj, (tuple, list)): |
| return type(obj)(self._traverse_obj(o, func) for o in obj) |
| elif isgenerator(obj): |
| return tuple(self._traverse_obj(o, func) for o in obj) |
| elif isinstance(obj, dict): |
| return {name: self._traverse_obj(o, func) for name, o in obj.items()} |
| elif isinstance(obj, (torch.Tensor, torch.nn.Parameter)): |
| return func(obj) |
| else: |
| return obj |
| |
| def _retain_grad(self, obj): |
| # gradients needs to be retained to check for grad. This is useful when |
| # non-leafs are present in the graph. |
| def inner_retain_grad(obj): |
| if obj.requires_grad: |
| obj.retain_grad() |
| self._traverse_obj(obj, inner_retain_grad) |
| |
| def _get_grads(self, obj): |
| def inner_get_grad(obj): |
| if obj.requires_grad: |
| return obj.grad |
| return self._traverse_obj(obj, inner_get_grad) |
| |
| def _zero_grad(self, obj): |
| def inner_zero_grad(obj): |
| if obj.grad is not None: |
| obj.grad = None |
| self._traverse_obj(obj, inner_zero_grad) |
| |
    @modules(module_db)
    def test_non_contiguous_tensors(self, device, dtype, module_info, training):
        """Verify forward outputs and gradients agree between contiguous and
        non-contiguous versions of the inputs and the incoming gradient."""
        # Check modules work with non-contiguous tensors

        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)

        def _make_non_contiguous(obj):
            # Build a non-contiguous tensor with the same values: repeat each
            # element along the last dim, then stride over every other element.
            def inner_make_non_contiguous(obj):
                # Scalar tensors can not be made non-contiguous
                if not isinstance(obj, torch.Tensor) or obj.dim() == 0:
                    return obj

                out = torch.repeat_interleave(obj, 2, dim=-1)
                out = out[..., ::2].detach()
                out.requires_grad = obj.requires_grad
                return out
            return self._traverse_obj(obj, inner_make_non_contiguous)

        def _can_be_noncontiguous(obj):
            # True if any tensor in the (possibly nested) structure has dim > 0.
            if isinstance(obj, (tuple, list)):
                return any(_can_be_noncontiguous(o) for o in obj)
            elif isinstance(obj, dict):
                return any(_can_be_noncontiguous(o) for o in obj.values())
            # scalar tensors can not be non-contiguous
            return isinstance(obj, torch.Tensor) and obj.dim() != 0

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            if not (_can_be_noncontiguous(input_args) or _can_be_noncontiguous(input_kwargs)):
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            # Retain grads so non-leaf inputs also expose .grad for comparison.
            self._retain_grad((input_args, input_kwargs))

            # === Forward with default input
            with freeze_rng_state():
                default_output = m(*input_args, **input_kwargs)
                if isinstance(default_output, torch.Tensor):
                    grad_output = default_output.clone().detach_().normal_()
                    default_output.backward(grad_output, retain_graph=True)
                else:
                    # Multi-output case: build a random grad for each tensor that
                    # requires grad (None placeholders elsewhere), then backprop
                    # through each differentiable output.
                    grad_output = tuple(self._traverse_obj(o, lambda o: o.clone().detach_().normal_() if o.requires_grad else None)
                                        for o in default_output)
                    flattened_default_output = torch.utils._pytree.tree_leaves(default_output)
                    flattened_grad_output = torch.utils._pytree.tree_leaves(grad_output)
                    for o, g_o in zip(flattened_default_output, flattened_grad_output):
                        if (o.requires_grad):
                            o.backward(g_o, retain_graph=True)

            # Snapshot the reference grads before re-running with other layouts.
            default_input_args_grad, default_input_kwargs_grad = deepcopy(self._get_grads((input_args, input_kwargs)))
            default_param_grad = deepcopy([p.grad for p in m.parameters()])

            # === Construct non-contiguous tensors ===
            nc_input_args, nc_input_kwargs = _make_non_contiguous((input_args, input_kwargs))
            nc_grad_output = _make_non_contiguous(grad_output)

            # === Compare results with non-contiguous and contiguous tensors ===
            inputs = [(input_args, input_kwargs), (nc_input_args, nc_input_kwargs)]
            grads = [grad_output, nc_grad_output]

            # Exercise every (input layout) x (grad layout) combination.
            for (in_args, in_kwargs), g_out in product(inputs, grads):
                g_out_copy = deepcopy(g_out)
                self._zero_grad((in_args, in_kwargs))
                self._zero_grad(m.parameters())

                with freeze_rng_state():
                    out = m(*in_args, **in_kwargs)
                    if isinstance(out, torch.Tensor):
                        out.backward(g_out_copy, retain_graph=True)
                    else:
                        flattened_out = torch.utils._pytree.tree_leaves(out)
                        flattened_g_out_copy = torch.utils._pytree.tree_leaves(g_out_copy)
                        for o, g_o in zip(flattened_out, flattened_g_out_copy):
                            if o.requires_grad:
                                o.backward(g_o, retain_graph=True)

                input_args_grad, input_kwargs_grad = self._get_grads((in_args, in_kwargs))
                self.assertEqual(out, default_output)
                self.assertEqual(input_args_grad, default_input_args_grad, atol=1e-4, rtol=0)
                self.assertEqual(input_kwargs_grad, default_input_kwargs_grad, atol=1e-4, rtol=0)

                param_grad = [p.grad for p in m.parameters()]
                self.assertEqual(param_grad, default_param_grad)
| |
    def _test_gradients_helper(self, device, dtype, module_info, training, check):
        """Run ``check`` (gradcheck / gradgradcheck) over each sample input.

        First checks the total derivative w.r.t. all differentiable inputs,
        parameters, and tensor kwargs together; then checks partial derivatives
        with only one parameter / tensor kwarg requiring grad at a time.
        """
        # Check gradients
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)
        # === Set nondet tol for gradcheck to user-defined value if on CUDA and cudNN is enabled
        gradcheck_nondet_tol = 0.0
        if (torch.device(device).type == 'cuda' and torch.backends.cudnn.enabled):
            gradcheck_nondet_tol = module_info.gradcheck_nondet_tol

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            # === Instantiate the module. ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)
            m.train(training)

            params = tuple(m.parameters())

            # === Lazy modules need to see an input to initialize params before gradcheck is run. ===
            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            if issubclass(module_info.module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                with torch.no_grad():
                    m(*input_args, **input_kwargs)

            # === Perform gradient check on the input_args ===
            # Split kwargs into tensor kwargs (differentiated) and the rest.
            other_kwargs = {}
            kwarg_tensors = []
            for name, obj in input_kwargs.items():
                if isinstance(obj, torch.Tensor):
                    kwarg_tensors.append((name, obj))
                else:
                    other_kwargs[name] = obj

            def fn_to_gradcheck(*flat_input_and_params):
                # NOTE: `flat_spec` is a closure over the variable assigned below,
                # just before `check(...)` invokes this function.
                input_and_params = torch.utils._pytree.tree_unflatten(flat_input_and_params, flat_spec)
                new_input_args = input_and_params[:len(input_args)]
                # NOTE(review): when `kwarg_tensors` is empty this slice is
                # `[-0:]` (the whole sequence), but the zip below then produces
                # an empty dict, so the result is still correct.
                kwarg_args = input_and_params[-len(kwarg_tensors):]
                new_kwargs = {name: obj for (name, _), obj in zip(kwarg_tensors, kwarg_args)}

                with freeze_rng_state():
                    output = m(*new_input_args, **new_kwargs, **other_kwargs)
                    output_flattened = torch.utils._pytree.tree_leaves(output)
                    return output_flattened

            # check total derivative
            grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
            flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)

            self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))

            # check partial derivatives
            old_params_requires_grad = [p.requires_grad for p in params]
            for p in params:
                p.requires_grad = False

            old_kwargs_requires_grad = [obj.requires_grad for (_, obj) in kwarg_tensors]
            for (_, obj) in kwarg_tensors:
                obj.requires_grad = False

            # Re-enable grad for one parameter at a time.
            for p, old in zip(params, old_params_requires_grad):
                p.requires_grad = old
                grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
                p.requires_grad = False

            # Re-enable grad for one tensor kwarg at a time.
            for (_, obj), old in zip(kwarg_tensors, old_kwargs_requires_grad):
                obj.requires_grad = old
                grad_input = input_args + params + tuple(obj for (_, obj) in kwarg_tensors)
                flat_input, flat_spec = torch.utils._pytree.tree_flatten(grad_input)
                self.assertTrue(check(fn_to_gradcheck, flat_input, nondet_tol=gradcheck_nondet_tol))
                obj.requires_grad = False
| |
    @modules(module_db, allowed_dtypes=[torch.double])
    def test_grad(self, device, dtype, module_info, training):
        # First-order gradient check via `gradcheck` (double precision only,
        # for numerical-jacobian accuracy).
        self._test_gradients_helper(device, dtype, module_info, training, gradcheck)
| |
    @modules([m for m in module_db if m.supports_gradgrad],
             allowed_dtypes=[torch.double])
    def test_gradgrad(self, device, dtype, module_info, training):
        # Second-order gradient check via `gradgradcheck`, only for modules
        # whose ModuleInfo declares `supports_gradgrad`.
        self._test_gradients_helper(device, dtype, module_info, training, gradgradcheck)
| |
    @onlyCUDA
    @with_tf32_off  # Turn off TF32 to compute at full precision https://github.com/pytorch/pytorch/issues/86798
    @toleranceOverride({torch.float32: tol(5e-2, 0),
                        torch.float64: tol(4e-4, 0)})
    @modules(module_db)
    def test_cpu_gpu_parity(self, device, dtype, module_info, training):
        """Check a module produces the same forward outputs and gradients on
        CPU and GPU, starting from identical parameters."""
        # TODO: RNN / GRU / LSTM don't support backwards on eval mode for cuDNN; skip this in a
        # nicer way for eval mode only.
        # See https://github.com/pytorch/pytorch/issues/79161
        rnn_modules = {torch.nn.RNN, torch.nn.GRU, torch.nn.LSTM}
        if (module_info.module_cls in rnn_modules
                and not training
                and 'cuda' in device
                and torch.backends.cudnn.enabled):
            return

        # Test cpu and gpu results are the same
        module_cls = module_info.module_cls
        module_inputs_cpu = module_info.module_inputs_func(module_info, device="cpu", dtype=dtype,
                                                           requires_grad=True, training=training)

        def _to_device(obj):
            # Recursively copy tensors to the test device, preserving
            # requires_grad; everything else is deep-copied unchanged.
            if isinstance(obj, torch.Tensor):
                res = obj.detach().to(device=device)
                res.requires_grad = obj.requires_grad
                return res
            elif isinstance(obj, tuple):
                return tuple(_to_device(o) for o in obj)
            elif isinstance(obj, dict):
                return {key: _to_device(o) for key, o in obj.items()}
            else:
                return deepcopy(obj)

        for module_input in module_inputs_cpu:
            # === Move input from cpu to device ===
            cpu_forward_args = module_input.forward_input.args
            cpu_forward_kwargs = module_input.forward_input.kwargs

            gpu_forward_args, gpu_forward_kwargs = _to_device((cpu_forward_args, cpu_forward_kwargs))

            self._retain_grad((cpu_forward_args, cpu_forward_kwargs, gpu_forward_args, gpu_forward_kwargs))

            # === Construct module on cpu and gpu ===
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

            cpu_module = module_cls(*args, **kwargs).to(dtype).to("cpu")
            cpu_module.train(training)
            gpu_module = module_cls(*args, **kwargs).to(dtype).to(device)
            gpu_module.train(training)

            # === Lazy modules need to see an input to initialize params ===
            if issubclass(module_cls, torch.nn.modules.lazy.LazyModuleMixin):
                with torch.no_grad():
                    cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
                    gpu_module(*gpu_forward_args, **gpu_forward_kwargs)

            # Sync parameters so both modules start from identical weights.
            for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                gpu_p.data.copy_(cpu_p)

            # === Compare forward output between cpu and gpu ===
            cpu_outputs = cpu_module(*cpu_forward_args, **cpu_forward_kwargs)
            gpu_outputs = gpu_module(*gpu_forward_args, **gpu_forward_kwargs)

            self.assertEqual(cpu_outputs, gpu_outputs)

            # === Run backwards on CPU and GPU and compare results ===
            def check_backward(cpu_output, gpu_output):
                # Use the same (random) grad_output values on both devices.
                cpu_grad_output = cpu_output.clone().normal_()
                gpu_grad_output = cpu_grad_output.type_as(gpu_output)

                cpu_output.backward(cpu_grad_output, retain_graph=True)
                gpu_output.backward(gpu_grad_output, retain_graph=True)

                cpu_grad_input = self._get_grads(cpu_forward_args)
                gpu_grad_input = self._get_grads(gpu_forward_args)
                self.assertEqual(cpu_grad_input, gpu_grad_input)

                for cpu_p, gpu_p in zip(cpu_module.parameters(), gpu_module.parameters()):
                    self.assertEqual(cpu_p.grad, gpu_p.grad)

                cpu_grad_kwarg_input = self._get_grads(cpu_forward_kwargs)
                gpu_grad_kwarg_input = self._get_grads(gpu_forward_kwargs)
                self.assertEqual(cpu_grad_kwarg_input, gpu_grad_kwarg_input)

            # Backprop several times so grad accumulation parity is exercised too.
            for _ in range(5):
                if isinstance(cpu_outputs, torch.Tensor):
                    check_backward(cpu_outputs, gpu_outputs)
                else:
                    flatten_cpu_outputs = torch.utils._pytree.tree_leaves(cpu_outputs)
                    flatten_gpu_outputs = torch.utils._pytree.tree_leaves(gpu_outputs)
                    for cpu_output, gpu_output in zip(flatten_cpu_outputs, flatten_gpu_outputs):
                        if cpu_output.requires_grad:
                            check_backward(cpu_output, gpu_output)
| |
    @with_tf32_off
    @modules(module_db)
    def test_memory_format(self, device, dtype, module_info, training):
        """Check forward/backward results and output memory format across all
        combinations of input and module memory formats (contiguous, plus
        channels_last / channels_last_3d where the inputs allow it)."""
        is_sm86or80 = device.startswith("cuda") and (torch.cuda.get_device_capability(0) == (8, 6)
                                                     or torch.cuda.get_device_capability(0) == (8, 0))
        # TODO tighten it to a specific module
        atol, rtol = (3e-3, 7e-3) if is_sm86or80 else (None, None)
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)
        module_memformat_affects_out = module_info.module_memformat_affects_out

        def _get_mem_formats(channels_last=False, channels_last_3d=False):
            # Returns (input memory formats, module memory formats) to test.
            if channels_last:
                return ([torch.contiguous_format, torch.channels_last],
                        [torch.preserve_format, torch.contiguous_format, torch.channels_last])
            elif channels_last_3d:
                return ([torch.contiguous_format, torch.channels_last_3d],
                        [torch.preserve_format, torch.contiguous_format, torch.channels_last_3d])
            else:
                return ([torch.contiguous_format],
                        [torch.preserve_format, torch.contiguous_format])

        # Check that at least one Tensor input has dim == n
        def _check_dims(obj, n):
            if isinstance(obj, torch.Tensor):
                return obj.dim() == n
            elif isinstance(obj, (tuple, list)):
                return any(_check_dims(o, n) for o in obj)
            else:
                return False

        # Called after _check_dims, when we know that >= 1 tensor can be converted to mem_format
        def _to_mem_format(mem_format, obj):
            def inner_to_mem_format(obj):
                d = obj.dim()
                # Only 4D tensors support channels_last and 5D channels_last_3d;
                # tensors of other ranks are cloned without a layout change.
                if ((mem_format == torch.channels_last and d != 4)
                        or (mem_format == torch.channels_last_3d and d != 5)):
                    return obj.clone().detach().requires_grad_(obj.requires_grad)
                return obj.clone().to(memory_format=mem_format).detach().requires_grad_(obj.requires_grad)

            return self._traverse_obj(obj, inner_to_mem_format)

        def _check_out_mem_format(output, input_mem_format, module_mem_format):
            # Output should be channels_last(_3d) if the input was, or if the
            # module was converted and that conversion affects outputs.
            def inner_check_out_mem_format(output):
                d = output.dim()
                if (d == 4 and ((input_mem_format == torch.channels_last)
                                or (module_mem_format == torch.channels_last and module_memformat_affects_out))):
                    self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last))
                elif (d == 5 and ((input_mem_format == torch.channels_last_3d)
                                  or (module_mem_format == torch.channels_last_3d and module_memformat_affects_out))):
                    self.assertTrue(output.numel() == 0 or output.is_contiguous(memory_format=torch.channels_last_3d))
                else:
                    self.assertTrue(output.is_contiguous())
            return self._traverse_obj(output, inner_check_out_mem_format)

        def _req_grad(t):
            return isinstance(t, torch.Tensor) and t.requires_grad

        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue

            supports_channels_last = _check_dims(module_input.forward_input.args, 4)
            supports_channels_last_3d = _check_dims(module_input.forward_input.args, 5)
            input_mem_formats, module_mem_formats = _get_mem_formats(supports_channels_last, supports_channels_last_3d)

            with freeze_rng_state():
                # === Instantiate the module. ===
                args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs

                m = module_cls(*args, **kwargs)
                m.to(device).to(dtype)
                m.train(training)

                # === Get output in (contiguous, contiguous) configuration. ===
                args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                desired_outputs = m(*args, **kwargs)
                # === Do backward pass. ===
                ref_diff_outputs = tuple(t for t in torch.utils._pytree.tree_leaves(desired_outputs) if _req_grad(t))
                if training and len(ref_diff_outputs) > 0:
                    params = tuple(p for p in m.parameters())
                    ref_diff_inputs = tuple(
                        t
                        for t in torch.utils._pytree.tree_leaves((args, kwargs, params))
                        if _req_grad(t)
                    )
                    ref_grad_outputs = tuple(
                        torch.rand_like(t)
                        for t in ref_diff_outputs
                    )
                    ref_grad_inputs = torch.autograd.grad(
                        ref_diff_outputs,
                        ref_diff_inputs,
                        grad_outputs=ref_grad_outputs,
                    )

                for input_mem_format in input_mem_formats:
                    # === Change memformat of input. ===
                    d_args = _to_mem_format(input_mem_format, module_input.forward_input.args)
                    d_kwargs = _to_mem_format(input_mem_format, module_input.forward_input.kwargs)

                    # See https://github.com/pytorch/pytorch/issues/107861
                    # When inductor tests are turned on, the setting of requires_grad will be lost
                    for t1, t2 in zip(
                        torch.utils._pytree.tree_leaves(d_args),
                        torch.utils._pytree.tree_leaves(module_input.forward_input.args),
                    ):
                        t1.requires_grad_(t2.requires_grad)
                    for t1, t2 in zip(
                        torch.utils._pytree.tree_leaves(d_kwargs),
                        torch.utils._pytree.tree_leaves(module_input.forward_input.kwargs),
                    ):
                        t1.requires_grad_(t2.requires_grad)

                    module_input.forward_input.args = d_args
                    module_input.forward_input.kwargs = d_kwargs

                    for module_mem_format in module_mem_formats:
                        # === Change memformat of module ===
                        m.to(memory_format=module_mem_format)

                        # === Do forward pass. ===
                        args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
                        outputs = m(*args, **kwargs)

                        # === Compare outputs to (contiguous, contiguous) output. ===
                        if input_mem_format != torch.contiguous_format or module_mem_format != torch.contiguous_format:
                            self.assertEqual(outputs, desired_outputs, rtol=rtol, atol=atol)

                        # === Check mem format of output. ===
                        _check_out_mem_format(outputs, input_mem_format, module_mem_format)

                        # === Do backward pass. ===
                        diff_outputs = tuple(t for t in torch.utils._pytree.tree_leaves(outputs) if _req_grad(t))
                        if training and len(diff_outputs) > 0:
                            params = tuple(p for p in m.parameters())
                            diff_inputs = tuple(
                                t
                                for t in torch.utils._pytree.tree_leaves((args, kwargs, params))
                                if _req_grad(t)
                            )
                            # Reuse the reference grad_outputs (copied into this
                            # config's layout) so gradients are comparable.
                            # NOTE(review): `ref_grad_outputs` only exists if the
                            # reference (contiguous) run had differentiable
                            # outputs; that should hold whenever this branch is
                            # taken — confirm no module differs per layout.
                            grad_outputs = tuple(
                                torch.empty_like(t1).copy_(t2)
                                for (t1, t2) in zip(diff_outputs, ref_grad_outputs)
                            )

                            grad_inputs = torch.autograd.grad(
                                diff_outputs,
                                diff_inputs,
                                grad_outputs=grad_outputs,
                            )

                            if (
                                input_mem_format != torch.contiguous_format
                                or module_mem_format != torch.contiguous_format
                            ):
                                self.assertEqual(
                                    grad_inputs, ref_grad_inputs, rtol=rtol, atol=atol
                                )

                            # === Check mem format of grad_inputs. ===
                            _check_out_mem_format(grad_inputs, input_mem_format, module_mem_format)
| |
| # Test whether train and eval modes differ for each module. Use to verify |
| # that the ModuleInfo entry flag is correct. |
| @modules(module_db, train_eval_mode=TrainEvalMode.train_only) |
| def test_if_train_and_eval_modes_differ(self, device, dtype, module_info, training): |
| module_cls = module_info.module_cls |
| module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, |
| requires_grad=False, training=training) |
| |
| # Run forward inputs through to see if the training flag is accessed during forward. |
| for module_input in module_inputs: |
| if module_input.forward_input is None: |
| continue |
| |
| # === Instantiate the module. === |
| args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs |
| m = module_cls(*args, **kwargs) |
| m.to(device).to(dtype) |
| m.train(training) |
| |
| # Remove training attribute and see if forward still works. |
| delattr(m, 'training') |
| |
| # === Do forward pass. === |
| try: |
| args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs |
| m(*args, **kwargs) |
| except AttributeError as e: |
| if "'training'" in str(e): |
| self.assertTrue(module_info.train_and_eval_differ, |
| f"The ModuleInfo entry for {module_info.name} has " |
| "train_and_eval_differ=False, but the training mode was found to " |
| "affect the forward pass. Consider setting train_and_eval_differ=True " |
| "for this ModuleInfo entry.") |
| else: |
| raise e |
| |
| |
| @onlyCPU |
| @modules(module_db) |
| def test_device_ctx_init(self, device, dtype, module_info, training): |
| module_cls = module_info.module_cls |
| module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, |
| requires_grad=False, training=training) |
| with torch.device('meta'): |
| module_inputs_meta = module_info.module_inputs_func(module_info, device=None, dtype=dtype, |
| requires_grad=False, training=training) |
| |
| for module_input, module_input_meta in zip(module_inputs, module_inputs_meta): |
| c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs |
| |
| c_args_meta, c_kwargs_meta = module_input_meta.constructor_input.args, module_input_meta.constructor_input.kwargs |
| |
| m_cpu = module_cls(*c_args, **c_kwargs) |
| |
| with torch.device('meta'): |
| m = module_cls(*c_args_meta, **c_kwargs_meta) |
| |
| for (p_meta, p_cpu) in chain(zip(m.parameters(), m_cpu.parameters()), |
| zip(m.buffers(), m_cpu.buffers())): |
| if torch.nn.parameter.is_lazy(p_meta): |
| continue |
| self.assertTrue(p_meta.is_meta) |
| assert_metadata_eq(self.assertEqual, p_meta, p_cpu) |
| |
| |
| @modules([module for module in module_db if module.module_error_inputs_func is not None]) |
| def test_errors(self, device, dtype, module_info, training): |
| module_cls = module_info.module_cls |
| error_inputs = module_info.module_error_inputs_func(module_info, device=device, dtype=dtype, |
| requires_grad=False, training=training) |
| for error_input in error_inputs: |
| module_input = error_input.module_error_input |
| c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs |
| if error_input.error_on == ModuleErrorEnum.CONSTRUCTION_ERROR: |
| with self.assertRaisesRegex(error_input.error_type, error_input.error_regex): |
| m = module_cls(*c_args, **c_kwargs) |
| elif error_input.error_on == ModuleErrorEnum.FORWARD_ERROR: |
| m = module_cls(*c_args, **c_kwargs) |
| fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs |
| with self.assertRaisesRegex(error_input.error_type, error_input.error_regex): |
| m(*fw_args, **fw_kwargs) |
| else: |
| raise NotImplementedError(f"Unknown error type {error_input.error_on}") |
| |
| # Only run this test for float32 because the test loops over all the dtypes |
| @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32]) |
| @parametrize('swap', [True, False]) |
| @parametrize('set_grad', [True, False]) |
| @wrapSwapTensorsTest() |
| def test_to(self, device, dtype, module_info, training, swap, set_grad): |
| module_cls = module_info.module_cls |
| devices = ['cpu'] |
| if torch.cuda.is_available(): |
| devices += ['cuda'] |
| dtypes = module_info.dtypes |
| module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, |
| requires_grad=False, training=training) |
| torch.__future__.set_swap_module_params_on_conversion(swap) |
| |
| for module_input in module_inputs: |
| c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs |
| args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs |
| |
| m = module_cls(*c_args, **c_kwargs) |
| |
| # Avoid using `module.to()` when constructing module since that is the method we are testing |
| def _to(m, set_grad=False): |
| for c in m.children(): |
| _to(c, set_grad=set_grad) |
| for n, p in m.named_parameters(recurse=False): |
| new_p = torch.nn.Parameter(p.detach().clone().to(device, dtype)) |
| setattr(m, n, new_p) |
| if set_grad: |
| new_p.grad = torch.randn_like(new_p) |
| for n, b in m.named_buffers(recurse=False): |
| new_b = b.detach().clone().to(device, dtype) |
| setattr(m, n, new_b) |
| _to(m, set_grad=set_grad) |
| |
| # Check .to() can be run after forward and backward with swap |
| has_params = len(list(m.parameters())) > 0 |
| if swap and not set_grad and has_params: |
| out = m(*args, **kwargs) |
| if isinstance(out, tuple): |
| out = out[0] |
| out.sum().backward() |
| m.to(dtype=torch.half) |
| # reset |
| m.to(dtype=torch.float32) |
| |
| prev_device, prev_dtype = device, dtype |
| for device_, dtype_ in product(devices, dtypes): |
| # if device/dtype do not change, grad.to(device, dtype) is a no-op so |
| # swapping will not change ._cdata |
| # parameters will be wrapped in an nn.Parameter before swapping |
| # which will cause the ._cdata to change |
| g_no_swap = device_ == prev_device and dtype_ == prev_dtype |
| prev_prev_device, prev_prev_dtype = prev_device, prev_dtype |
| prev_device, prev_dtype = device_, dtype_ |
| |
| p_ids_before = [id(p) for p in m.parameters()] |
| p_cdatas_before = [p._cdata for p in m.parameters()] |
| if set_grad: |
| g_ids_before = [id(p.grad) for p in m.parameters()] |
| g_cdatas_before = [p.grad._cdata for p in m.parameters()] |
| |
| m.to(device=device_, dtype=dtype_) |
| |
| self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters())) |
| self.assertTrue(all(p.device.type == device_ for p in m.parameters())) |
| self.assertTrue(all(p.dtype == dtype_ for p in m.parameters())) |
| p_ids_after = [id(p) for p in m.parameters()] |
| p_cdatas_after = [p._cdata for p in m.parameters()] |
| |
| if set_grad: |
| self.assertTrue(all(p.grad.device.type == device_ for p in m.parameters())) |
| self.assertTrue(all(p.grad.dtype == dtype_ for p in m.parameters())) |
| g_ids_after = [id(p.grad) for p in m.parameters()] |
| g_cdatas_after = [p.grad._cdata for p in m.parameters()] |
| |
| if swap: |
| # id same, ._cdata differs --> swapped cdata of THPVariable |
| self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after))) |
| self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after))) |
| if set_grad: |
| self.assertTrue( |
| all(a == b if g_no_swap else a != b for a, b in zip(g_cdatas_before, g_cdatas_after))) |
| else: |
| # id and _cdata remain the same --> .data setting |
| self.assertTrue(all(a == b for a, b in zip(p_cdatas_before, p_cdatas_after))) |
| self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after))) |
| if set_grad: |
| self.assertTrue(all(a == b for a, b in zip(g_cdatas_before, g_cdatas_after))) |
| self.assertTrue(all(a == b for a, b in zip(g_ids_before, g_ids_after))) |
| |
| @modules([module for module in module_db if not module.is_lazy], allowed_dtypes=[torch.float32]) |
| @parametrize('swap', [True, False]) |
| @wrapSwapTensorsTest() |
| def test_to_empty(self, device, dtype, module_info, swap, training): |
| module_cls = module_info.module_cls |
| |
| with torch.device("meta"): |
| module_inputs = module_info.module_inputs_func(module_info, device=None, dtype=dtype, |
| requires_grad=False, training=training) |
| |
| torch.__future__.set_swap_module_params_on_conversion(swap) |
| device_ = torch.device(device) |
| |
| for module_input in module_inputs: |
| c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs |
| |
| with torch.device("meta"): |
| m = module_cls(*c_args, **c_kwargs) |
| |
| p_ids_before = [id(p) for p in m.parameters()] |
| p_cdatas_before = [p._cdata for p in m.parameters()] |
| m.to_empty(device=device_) |
| |
| self.assertTrue(all(isinstance(p, torch.nn.Parameter) for p in m.parameters())) |
| self.assertTrue(all(p.device == device_ for p in m.parameters())) |
| self.assertTrue(all(p.dtype == dtype for p in m.parameters())) |
| p_ids_after = [id(p) for p in m.parameters()] |
| p_cdatas_after = [p._cdata for p in m.parameters()] |
| |
| if swap: |
| # id same, ._cdata differs --> swapped cdata of THPVariable |
| self.assertTrue(all(a == b for a, b in zip(p_ids_before, p_ids_after))) |
| self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after))) |
| else: |
| # id and ._cdata differ |
| # meta and device have different shallow copy types, so this will create a new |
| # parameter and assign it to the module |
| self.assertTrue(all(a != b for a, b in zip(p_ids_before, p_ids_after))) |
| self.assertTrue(all(a != b for a, b in zip(p_cdatas_before, p_cdatas_after))) |
| |
| |
# Generate per-device-type variants of TestModule (one concrete test class per
# available device type); allow_mps=True presumably includes MPS — confirm
# against the helper's signature.
instantiate_device_type_tests(TestModule, globals(), allow_mps=True)

if __name__ == '__main__':
    run_tests()