| # Owner(s): ["module: nn"] |
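"""Tests for per-sample gradient computation with ExpandedWeights.

These tests exercise ``torch.nn.utils._per_sample_grad.call_for_per_sample_grads``
and the private ``ExpandedWeight`` wrapper it builds on, comparing the
``grad_sample`` fields they populate against per-sample gradients computed with a
plain for loop over the batch.
"""
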
| import unittest |
| from dataclasses import dataclass |
| from functools import partial |
| from itertools import chain, product |
| |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn import CrossEntropyLoss |
| from torch.nn.utils._expanded_weights import ExpandedWeight |
| from torch.nn.utils._expanded_weights.expanded_weights_utils import ( |
| forward_helper, |
| set_grad_sample_if_exists, |
| standard_kwargs, |
| sum_over_all_but_batch_and_last_n, |
| unpack_expanded_weight_or_tensor, |
| ) |
| from torch.nn.utils._per_sample_grad import call_for_per_sample_grads |
| from torch.testing._internal.common_cuda import TEST_CUDA, tf32_off |
| from torch.testing._internal.common_device_type import ( |
| instantiate_device_type_tests, |
| OpDTypes, |
| ops, |
| ) |
| from torch.testing._internal.common_methods_invocations import op_db, SampleInput |
| from torch.testing._internal.common_modules import module_db, modules |
| from torch.testing._internal.common_nn import module_tests, new_module_tests, TestBase |
| from torch.testing._internal.common_utils import ( |
| freeze_rng_state, |
| make_tensor, |
| parametrize, |
| run_tests, |
| skipIfTorchDynamo, |
| TestCase, |
| ) |
| from torch.utils._pytree import tree_map_only |
| |
| |
| class TestContext: |
| pass |
| |
| |
| class TestExpandedWeightHelperFunction(TestCase): |
| def test_forward_helper(self, device): |
| input = torch.randn(3, 4, device=device) |
| weight = torch.randn(5, 4, device=device) |
| bias = torch.randn(5, device=device) |
| for weight_batched, bias_batched in product([True, False], [True, False]): |
| maybe_batched_weight = weight |
| maybe_batched_bias = bias |
| if weight_batched: |
| maybe_batched_weight = ExpandedWeight( |
| weight.clone().requires_grad_(), 3, loss_reduction="sum" |
| ) |
| if bias_batched: |
| maybe_batched_bias = ExpandedWeight( |
| bias.clone().requires_grad_(), 3, loss_reduction="sum" |
| ) |
| args = (input, maybe_batched_weight, maybe_batched_bias) |
| expanded_args, expanded_kwargs = standard_kwargs(("bias",), args) |
| res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) |
| expected = nn.functional.linear(input, weight, bias) |
| self.assertEqual(res, expected) |
| |
| self.assertEqual(len(expanded_args), 2) |
            assert expanded_args[0] is args[0]  # avoids property checks in assertEqual
            assert expanded_args[1] is args[1]  # avoids property checks in assertEqual
            self.assertEqual(len(expanded_kwargs), 1)
            assert (
                expanded_kwargs["bias"] is args[2]
            )  # avoids property checks in assertEqual
| |
| def test_forward_helper_failure_args(self, device): |
| weight = torch.randn(5, 4, device=device) |
| bias = torch.randn(5, device=device) |
| with self.assertRaisesRegex( |
| RuntimeError, r"do not support inputs that are also ExpandedWeights." |
| ): |
| input = ExpandedWeight( |
| torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum" |
| ) |
| expanded_args, expanded_kwargs = standard_kwargs( |
| ("bias",), (input, weight, bias) |
| ) |
| forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) |
| with self.assertRaisesRegex( |
| RuntimeError, r"requires a Tensor as the first input" |
| ): |
| expanded_args, expanded_kwargs = standard_kwargs( |
| ("bias",), (3, weight, bias) |
| ) |
| forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) |
| with self.assertRaisesRegex( |
| RuntimeError, r"requires a batch dimension but got an input of size 0" |
| ): |
| expanded_args, expanded_kwargs = standard_kwargs( |
| ("bias",), (torch.tensor(3), weight, bias) |
| ) |
| forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) |
| with self.assertRaisesRegex( |
| RuntimeError, r"0 is not a valid batch size for Expanded Weights" |
| ): |
| expanded_args, expanded_kwargs = standard_kwargs( |
| ("bias",), (torch.randn(0, 1, 2), weight, bias) |
| ) |
| forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) |
| input = torch.randn(3, 4) |
| for weight_batched, bias_batched in product([True, False], [True, False]): |
| if not weight_batched and not bias_batched: |
| continue |
| maybe_batched_weight = weight |
| maybe_batched_bias = bias |
| if weight_batched: |
| maybe_batched_weight = ExpandedWeight( |
| weight.clone().requires_grad_(), 4, loss_reduction="sum" |
| ) |
| if bias_batched: |
| maybe_batched_bias = ExpandedWeight( |
| bias.clone().requires_grad_(), 4, loss_reduction="sum" |
| ) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"Expected ExpandedWeights to have batch size matching input", |
| ): |
| expanded_args, expanded_kwargs = standard_kwargs( |
| ("bias",), (input, maybe_batched_weight, maybe_batched_bias) |
| ) |
| forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) |
| |
| def test_set_grad_sample_if_exists(self, device): |
        def test_fn(a):
            # returns the ``grad_sample`` tensor defined below (captured by closure)
            return grad_sample
| |
| orig_weight = torch.randn(4, device=device, requires_grad=True) |
| expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum") |
| grad_sample = torch.randn(3) |
| set_grad_sample_if_exists(expanded_weight, test_fn) |
| self.assertTrue(hasattr(orig_weight, "grad_sample")) |
| self.assertEqual(orig_weight.grad_sample, grad_sample) |
| |
| basic_tensor = torch.randn(4, device=device) |
| set_grad_sample_if_exists(basic_tensor, test_fn) |
| self.assertFalse(hasattr(basic_tensor, "grad_sample")) |
| |
| non_tensor = 3 |
| set_grad_sample_if_exists(non_tensor, test_fn) |
| self.assertFalse(hasattr(non_tensor, "grad_sample")) |
| |
| def test_set_grad_sample_if_exists_failure(self, device): |
| def test_fn(a): |
| return True |
| |
| grad_tensor = torch.randn(4, requires_grad=True, device=device) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"does not support a mixture of ExpandedWeight parameters and normal Parameters", |
| ): |
| set_grad_sample_if_exists(grad_tensor, test_fn) |
| |
| def test_unpack_expanded_weight_or_tensor(self, device): |
| input = torch.randn(3, requires_grad=True, device=device) |
| self.assertEqual( |
| input, |
| unpack_expanded_weight_or_tensor( |
| ExpandedWeight(input, 3, loss_reduction="sum") |
| ), |
| ) |
| |
| input.requires_grad_(False) |
| self.assertEqual(input, unpack_expanded_weight_or_tensor(input)) |
| self.assertTrue(unpack_expanded_weight_or_tensor(4) is None) |
| |
| def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device): |
| input = torch.randn(3, requires_grad=True, device=device) |
| self.assertTrue( |
| unpack_expanded_weight_or_tensor( |
| ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input |
| ) |
| ) |
| |
| input.requires_grad_(False) |
| self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input)) |
| self.assertTrue( |
| unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None |
| ) |
| |
| def test_unpack_expanded_weight_or_tensor_failure(self, device): |
| input = torch.randn(3, requires_grad=True, device=device) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"does not support a mixture of ExpandedWeight parameters and normal Parameters", |
| ): |
| unpack_expanded_weight_or_tensor(input) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"does not support a mixture of ExpandedWeight parameters and normal Parameters", |
| ): |
| unpack_expanded_weight_or_tensor(input, lambda x: x is input) |
| |
| def test_sum_over_all_but_batch_and_last_n(self, device): |
| input = torch.randn(1, 2, 3, 4, 5, device=device) |
| res = sum_over_all_but_batch_and_last_n(input, 2) |
| expected = input.sum((1, 2)) |
| self.assertEqual(res, expected) |
| |
| res = sum_over_all_but_batch_and_last_n(input, 0) |
| expected = input.sum((1, 2, 3, 4)) |
| self.assertEqual(res, expected) |
| |
| res = sum_over_all_but_batch_and_last_n(input, 4) |
| self.assertEqual(res, input) |
| |
| |
| class TestExpandedWeightFunctional(TestCase): |
| def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction): |
| input = sample_input.input |
| args = sample_input.args |
| kwargs = sample_input.kwargs |
| batch_size = input.shape[0] if len(input.shape) > 1 else 1 |
| |
| # get per sample grads with ExpandedWeights objects |
| loss_reduction = "sum" if reduction == torch.sum else "mean" |
| (ew_input, ew_args, ew_kwargs) = make_expanded_weight( |
| sample_input, batch_size, loss_reduction |
| ) |
| diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values()) |
| diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)] |
| diff_input_list = [ |
| i.orig_weight if isinstance(i, ExpandedWeight) else i |
| for i in diff_input_list |
| ] |
| if not diff_input_list: |
| return |
| result = run_op(op, ew_input, *ew_args, **ew_kwargs) |
| reduction( |
| result |
| ).backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__ |
| expanded_weight_grad = tuple( |
| i.grad_sample if hasattr(i, "grad_sample") else i.grad |
| for i in diff_input_list |
| ) |
| |
| # get per sample grads with for loop |
| func = partial(run_op, op) |
| |
| per_sample_grad = for_loop_per_sample_grad( |
| batch_size, reduction, input, func, *args, **kwargs |
| ) |
| |
| # check equality |
| self.assertEqual(len(per_sample_grad), len(expanded_weight_grad)) |
| if loss_reduction == "mean": |
| # don't check equality of `input.grad`s since these vanilla tensors won't be scaled |
| expanded_weight_grad = expanded_weight_grad[1:] |
| per_sample_grad = per_sample_grad[1:] |
| for result_grad, expected_grad in zip(expanded_weight_grad, per_sample_grad): |
| self.assertEqual(result_grad, expected_grad) |
| |
| @ops( |
| filter(lambda op: op.supports_expanded_weight, op_db), |
| dtypes=OpDTypes.supported, |
| allowed_dtypes=(torch.double,), |
| ) |
| def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op): |
| sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) |
| for sample_input in supported_inputs(op, sample_inputs): |
| if ( |
| op.name == "nn.functional.embedding" |
| ): # embedding flips its argument order for autograd tests |
| sample_input = SampleInput( |
| sample_input.args[0], |
| args=(sample_input.input,), |
| kwargs=sample_input.kwargs, |
| ) |
| |
| self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum) |
| |
| @ops( |
| filter(lambda op: op.supports_expanded_weight, op_db), |
| dtypes=OpDTypes.supported, |
| allowed_dtypes=(torch.double,), |
| ) |
| def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op): |
| sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) |
| for sample_input in supported_inputs(op, sample_inputs): |
| if ( |
| op.name == "nn.functional.embedding" |
| ): # embedding flips its argument order for autograd tests |
| sample_input = SampleInput( |
| sample_input.args[0], |
| args=(sample_input.input,), |
| kwargs=sample_input.kwargs, |
| ) |
| |
| self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean) |
| |
| @ops( |
| filter(lambda op: op.supports_expanded_weight, op_db), |
| dtypes=OpDTypes.supported, |
| allowed_dtypes=(torch.double,), |
| ) |
| def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op): |
| sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) |
| for sample_input in supported_inputs(op, sample_inputs): |
| if ( |
| op.name == "nn.functional.embedding" |
| ): # embedding flips its argument order for autograd tests |
| sample_input = SampleInput( |
| sample_input.args[0], |
| args=(sample_input.input,), |
| kwargs=sample_input.kwargs, |
| ) |
| sample_input.input.requires_grad_(False) |
| |
| self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean) |
| |
| @skipIfTorchDynamo("Checking error message doesn't work with dynamo") |
| @ops( |
| filter(lambda op: op.supports_expanded_weight, op_db), |
| dtypes=OpDTypes.supported, |
| allowed_dtypes=(torch.double,), |
| ) |
| def test_unsupported_expand_weights(self, device, dtype, op): |
| sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) |
| unsupported_inputs = supported_inputs(op, sample_inputs, supported_inputs=False) |
| for sample_input in unsupported_inputs: |
| with self.assertRaisesRegex(RuntimeError, r"Expanded Weights"): |
| if ( |
| op.name == "nn.functional.embedding" |
| ): # embedding flips its argument order for autograd tests |
| sample_input = SampleInput( |
| sample_input.args[0], |
| args=(sample_input.input,), |
| kwargs=sample_input.kwargs, |
| ) |
| input = sample_input.input |
| |
| batch_size = input.shape[0] if len(input.shape) > 1 else 1 |
| |
| # get per sample grads with ExpandedWeights objects |
| (ew_input, ew_args, ew_kwargs) = make_expanded_weight( |
| sample_input, batch_size |
| ) |
| result = run_op(op, ew_input, *ew_args, **ew_kwargs) |
| diff_input_list = ( |
| (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values()) |
| ) |
| diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)] |
| diff_input_list = [ |
| i.orig_weight if isinstance(i, ExpandedWeight) else i |
| for i in diff_input_list |
| ] |
| result.sum().backward() # grad doesn't work with ExpandedWeight because it calls __torch_function__ |
| |
| @ops( |
| filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported |
| ) |
| def test_expanded_weight_forward(self, device, dtype, op): |
| sample_inputs = op.sample_inputs(device, dtype) |
| for sample_input in supported_inputs(op, sample_inputs): |
| if ( |
| op.name == "nn.functional.embedding" |
| ): # embedding flips its argument order for autograd tests |
| sample_input = SampleInput( |
| sample_input.args[0].clone(), |
| args=(sample_input.input.clone(),), |
| kwargs=sample_input.kwargs, |
| ) |
| if ( |
| "cuda" in device |
| and "max_norm" in sample_input.kwargs |
| and "padding_idx" in sample_input.kwargs |
| ): |
| self.skipTest( |
| "embedding is non-determinstic in this case, see issue #74679" |
| ) |
| batch_size = ( |
| sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1 |
| ) |
| for loss_reduction in ["sum", "mean"]: |
| (ew_input, ew_args, ew_kwargs) = make_expanded_weight( |
| sample_input, batch_size, loss_reduction |
| ) |
| expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs) |
| normal_result = run_op( |
| op, sample_input.input, *sample_input.args, **sample_input.kwargs |
| ) |
| self.assertEqual(expanded_weight_result, normal_result) |
| |
| def test_expanded_weight_error(self, device): |
| batch_size = 3 |
| sample_input = make_tensor( |
| (batch_size, 4), dtype=torch.float32, device=device, requires_grad=True |
| ) |
| sample_weight = make_tensor( |
| (4), dtype=torch.float32, device=device, requires_grad=True |
| ) |
| with self.assertRaisesRegex( |
| RuntimeError, r"Expanded Weights encountered but cannot handle function" |
| ): |
| torch.add( |
| sample_input, |
| ExpandedWeight(sample_weight, batch_size, loss_reduction="sum"), |
| ) |
| |
| def _test_embedding_model(self, model, num_embedding, device): |
| batch_size = 32 |
| input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device) |
| return self._test_model( |
| partial(model, num_embedding=num_embedding), batch_size, input, device |
| ) |
| |
| def _test_conv_model( |
| self, |
| model, |
| input_size, |
| num_dim, |
| device, |
| loss_reduction="sum", |
| atol=1e-4, |
| rtol=5e-5, |
| ): |
| batch_size = 32 |
| input_ending = [input_size] * num_dim |
| input = torch.randn([batch_size, 3] + input_ending, device=device) |
| return self._test_model( |
| partial(model, num_dim=num_dim), |
| batch_size, |
| input, |
| device, |
| loss_reduction, |
| atol, |
| rtol, |
| ) |
| |
| def _test_model( |
| self, |
| model, |
| batch_size, |
| input, |
| device, |
| loss_reduction="sum", |
| atol=1e-4, |
| rtol=5e-5, |
| ): |
| model = model(10).to(device) |
| targets = torch.randint(0, 10, (batch_size,), device=device) |
| criterion = CrossEntropyLoss(reduction=loss_reduction) |
| result = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input) |
| loss = criterion(result, targets) |
| loss.backward() |
| result = [] |
| for weight in model.parameters(): |
| result.append(weight.grad_sample) |
| del weight.grad_sample |
| |
| expected = [] |
| for i in range(batch_size): |
| loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0)) |
| expected.append( |
| torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss)) |
| ) |
| |
| expected = [torch.stack(grad) for grad in zip(*expected)] |
| for res, exp in zip(result, expected): |
| self.assertEqual(res, exp, atol=atol, rtol=rtol) |
| |
| def _compute_tolerances(self, device): |
| is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability( |
| 0 |
| ) == (8, 6) |
| return (9e-3, 5e-5) if is_cuda_sm86 else (1e-4, 5e-5) |
| |
| @tf32_off() |
| def test_cnn_model_sum(self, device): |
| def convnet(num_classes, num_dim): |
| return nn.Sequential( |
| nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), |
| nn.ReLU(), |
| nn.AvgPool2d(kernel_size=2, stride=2), |
| nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), |
| nn.ReLU(), |
| nn.AvgPool2d(kernel_size=2, stride=2), |
| nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), |
| nn.ReLU(), |
| nn.AvgPool2d(kernel_size=2, stride=2), |
| nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), |
| nn.ReLU(), |
| nn.AdaptiveAvgPool2d((1, 1)), |
| nn.Flatten(start_dim=1, end_dim=-1), |
| nn.Linear(128, num_classes, bias=True), |
| ) |
| |
| atol, rtol = self._compute_tolerances(device) |
| return self._test_conv_model(convnet, 28, 2, device, atol=atol, rtol=rtol) |
| |
| @tf32_off() |
| def test_cnn_model_mean(self, device): |
| def convnet(num_classes, num_dim): |
| return nn.Sequential( |
| nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), |
| nn.ReLU(), |
| nn.AvgPool2d(kernel_size=2, stride=2), |
| nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), |
| nn.ReLU(), |
| nn.AvgPool2d(kernel_size=2, stride=2), |
| nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), |
| nn.ReLU(), |
| nn.AvgPool2d(kernel_size=2, stride=2), |
| nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), |
| nn.ReLU(), |
| nn.AdaptiveAvgPool2d((1, 1)), |
| nn.Flatten(start_dim=1, end_dim=-1), |
| nn.Linear(128, num_classes, bias=True), |
| ) |
| |
| atol, rtol = self._compute_tolerances(device) |
| return self._test_conv_model( |
| convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol |
| ) |
| |
| @parametrize("num_dim", [1, 2, 3]) |
| @tf32_off() |
| def test_instance_norm_model(self, num_dim, device): |
| def instance_norm_model(num_classes, num_dim): |
| conv_layer = ( |
| nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d |
| ) |
| norm_layer = ( |
| nn.InstanceNorm1d |
| if num_dim == 1 |
| else nn.InstanceNorm2d |
| if num_dim == 2 |
| else nn.InstanceNorm3d |
| ) |
| return nn.Sequential( |
| conv_layer(3, 32, kernel_size=3, stride=1, padding=1), |
| norm_layer(32, affine=True), |
| nn.Flatten(start_dim=1, end_dim=-1), |
| nn.Linear(32 * (7**num_dim), num_classes, bias=True), |
| ) |
| |
| atol, rtol = self._compute_tolerances(device) |
| return self._test_conv_model( |
| instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol |
| ) |
| |
| @parametrize("num_dim", [1, 2, 3]) |
| @tf32_off() |
| def test_group_norm_model(self, num_dim, device): |
| def group_norm_model(num_classes, num_dim): |
| conv_layer = ( |
| nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d |
| ) |
| return nn.Sequential( |
| conv_layer(3, 32, kernel_size=3, stride=1, padding=1), |
| nn.GroupNorm(8, 32, affine=True), |
| nn.Flatten(start_dim=1, end_dim=-1), |
| nn.Linear(32 * (7**num_dim), num_classes, bias=True), |
| ) |
| |
| atol, rtol = self._compute_tolerances(device) |
| return self._test_conv_model( |
| group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol |
| ) |
| |
| @parametrize("num_dim", [1, 2, 3]) |
| @tf32_off() |
| def test_layer_norm_model(self, num_dim, device): |
| def layer_norm_model(num_classes, num_dim): |
| conv_layer = ( |
| nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d |
| ) |
| normalized_shape = [7] * num_dim |
| return nn.Sequential( |
| conv_layer(3, 32, kernel_size=3, stride=1, padding=1), |
| nn.LayerNorm(normalized_shape, elementwise_affine=True), |
| nn.Flatten(start_dim=1, end_dim=-1), |
| nn.Linear(32 * (7**num_dim), num_classes, bias=True), |
| ) |
| |
| atol, rtol = self._compute_tolerances(device) |
| return self._test_conv_model( |
| layer_norm_model, 7, num_dim, device, atol=atol, rtol=rtol |
| ) |
| |
| def test_embedding_model(self, device): |
| def embedding_model(num_classes, num_embedding): |
| return nn.Sequential( |
| nn.Embedding(num_embedding, 15), |
| nn.Flatten(start_dim=1, end_dim=-1), |
| nn.Linear(375, num_classes, bias=True), |
| ) |
| |
| return self._test_embedding_model(embedding_model, 16, device) |
| |
| def test_group_norm_error(self, device): |
| # group norm has to call native_group_norm. This checks that it hits the same errors |
| # that normal group norm would |
| |
| N = 3 |
| C = 5 |
| inp = torch.randn(N, C) |
| with self.assertRaisesRegex( |
| RuntimeError, r"Expected number of channels in input to be divisible" |
| ): |
| F.group_norm(inp, 2) # 5 is not divisible by 2 |
| |
| |
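# A minimal usage sketch of the API exercised below (illustrative only, not a test):
# call_for_per_sample_grads wraps a module so that backward() populates a
# ``grad_sample`` attribute on each parameter, holding one gradient per batch element.
#
#   model = nn.Linear(4, 2)
#   out = call_for_per_sample_grads(model, loss_reduction="sum")(torch.randn(8, 4))
#   out.sum().backward()
#   model.weight.grad_sample  # shape (8, 2, 4): per-sample gradients of the weight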
| class TestExpandedWeightModule(TestCase): |
| def _do_test( |
| self, |
| module, |
| input, |
| args=None, |
| kwargs=None, |
| batch_first=True, |
| atol=None, |
| rtol=None, |
| ): |
| args = args or () |
| kwargs = kwargs or {} |
| |
| batch_dim = 0 if batch_first else 1 |
| batch_size = input.shape[batch_dim] |
| diff_input = input.dtype == torch.float or input.dtype == torch.double |
| if diff_input: |
| input.requires_grad_() |
| |
| with freeze_rng_state(): |
| # get per sample grads with ExpandedWeights context manager |
| actual_res = call_for_per_sample_grads( |
| module, |
| batch_size=batch_size, |
| loss_reduction="sum", |
| batch_first=batch_first, |
| )(input, *args, **kwargs).sum() |
| actual_res.backward() |
| actual_grads = [] |
| for param in module.parameters(): |
| actual_grads.append(param.grad_sample) |
| del param.grad_sample |
| if diff_input: |
| actual_grads.append(input.grad.clone()) |
| input.grad = torch.zeros_like(input.grad) |
| |
| # get per sample grads with a for loop |
| expected_res = torch.tensor( |
| 0.0, device=input.device, dtype=actual_res.dtype |
| ) |
| expected_grads = [] |
| for i in range(batch_size): |
| input_slice = input.narrow(batch_dim, i, 1) |
| input_slice = input_slice.squeeze(batch_dim) |
| |
                # h's batch dim is always dim 1, regardless of batch_first. Must be contiguous for CUDA
| sliced_args = tree_map_only( |
| torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args |
| ) |
| diff_params = module.parameters() |
| if diff_input: |
| diff_params = chain(diff_params, (input_slice,)) |
| res = module( |
| input_slice.unsqueeze(batch_dim).contiguous(), |
| *sliced_args, |
| **kwargs, |
| ).sum() |
| out_grads = torch.autograd.grad( |
| res, diff_params, torch.ones_like(res), allow_unused=True |
| ) |
| expected_grads.append(out_grads) |
| expected_res += res |
| expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)] |
| if not batch_first: |
| expected_grads[-1] = expected_grads[-1].transpose(0, 1) |
| self.assertEqual(actual_res, expected_res) |
| [ |
| self.assertEqual(actual, expected, atol=atol, rtol=rtol) |
| for (actual, expected) in zip(actual_grads, expected_grads) |
| ] |
| |
| def _do_test_multi_input(self, module, input): |
| class TestModule(nn.Module): |
| def __init__(self, module): |
| super().__init__() |
| self.module = module |
| |
| def forward(self, input): |
| return self.module(input) + self.module(input) |
| |
| batch_size = input.shape[0] |
| diff_input = input.dtype == torch.float or input.dtype == torch.double |
| if diff_input: |
| input.requires_grad_() |
| with freeze_rng_state(): |
            # get per sample grads with ExpandedWeights; the wrapped TestModule
            # applies self.module to the input twice, so its grads should be doubled
| test_module = TestModule(module) |
| actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")( |
| input |
| ).sum() |
| actual_res.backward() |
| actual_grads = [] |
| for param in module.parameters(): |
| actual_grads.append(param.grad_sample) |
| del param.grad_sample |
| if diff_input: |
| actual_grads.append(input.grad.clone()) |
| input.grad = torch.zeros_like(input.grad) |
| |
            # get per sample grads with a for loop, applying the module once per sample
            # (the ExpandedWeights grads above should be exactly 2x these)
| expected_grads = [] |
| for i in range(batch_size): |
| input_slice = input[i] |
| diff_params = module.parameters() |
| if diff_input: |
| diff_params = chain(diff_params, (input_slice,)) |
| res = module(input_slice.unsqueeze(0)).sum() |
| out_grads = torch.autograd.grad( |
| res, diff_params, torch.ones_like(res), allow_unused=True |
| ) |
| expected_grads.append(out_grads) |
| expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads)) |
| expected_grads = tuple( |
| expected_grad |
| for expected_grad in expected_grads |
| if expected_grad is not None |
| ) |
| assert [ |
| self.assertEqual(actual, 2 * expected) |
| for (actual, expected) in zip(actual_grads, expected_grads) |
| ] |
| |
| def _do_test_rnn_packed_sequence( |
| self, module, input, args=None, kwargs=None, atol=None, rtol=None |
| ): |
| args = args if args is not None else () |
| kwargs = kwargs if kwargs is not None else {} |
| |
| batch_size = max(tuple(input.batch_sizes)).item() |
| |
| with freeze_rng_state(): |
| # get per sample grads with ExpandedWeights context manager |
| actual_res = call_for_per_sample_grads( |
| module, batch_size=batch_size, loss_reduction="sum" |
| )(input, *args, **kwargs).data.sum() |
| actual_res.backward() |
| actual_grads = [] |
| for param in module.parameters(): |
| self.assertEqual(param.grad_sample.shape[0], batch_size) |
| actual_grads.append(param.grad_sample) |
| del param.grad_sample |
| |
| input.data.grad = torch.zeros_like(input.data) |
| |
| # compute the per sample grads with a for loop |
| expected_res = torch.zeros_like(actual_res) |
| expected_grads = [] |
| padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence( |
| input, batch_first=True |
| ) |
| for i in range(len(seq_sizes)): |
| input_slice = padded_input[i].narrow(0, 0, seq_sizes[i]) |
| diff_params = module.parameters() |
| batch_dim = 0 if module.m.batch_first else 1 |
| res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum() |
| expected_res += res |
| out_grads = torch.autograd.grad( |
| res, diff_params, torch.ones_like(res), allow_unused=True |
| ) |
| expected_grads.append(out_grads) |
| |
| expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)] |
| self.assertEqual(actual_res, expected_res) |
| [ |
| self.assertEqual(actual, expected, atol=atol, rtol=rtol) |
| for (actual, expected) in zip(actual_grads, expected_grads) |
| ] |
| |
| @modules( |
| filter( |
| lambda m_info: m_info.module_cls |
| in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), |
| module_db, |
| ) |
| ) |
| @tf32_off() |
| def test_module(self, device, dtype, module_info, training): |
| class RNNWrapper(torch.nn.Module): |
| def __init__(self, m_cons, args, kwargs): |
| super().__init__() |
| self.m = m_cons(*args, **kwargs) |
| |
| def forward(self, *inps): |
| ret = self.m(*inps) |
| assert isinstance(ret, tuple) |
| return ret[0] |
| |
| def batch_hidden(h): |
| new_h_shape = [1] * (len(h.shape) + 1) |
| new_h_shape[1] = 2 |
| return h.unsqueeze(1).repeat(new_h_shape) |
| |
| module_cls = module_info.module_cls |
| atol, rtol = ( |
| (1e-4, 1e-5) |
| if module_cls == torch.nn.GRU and dtype == torch.float32 |
| else (None, None) |
| ) |
| module_inputs = module_info.module_inputs_func( |
| module_info, |
| device=device, |
| dtype=dtype, |
| requires_grad=True, |
| training=training, |
| with_packed_sequence=True, |
| ) |
| for module_input in module_inputs: |
| if module_input.forward_input is None: |
| continue |
| args, kwargs = ( |
| module_input.constructor_input.args, |
| module_input.constructor_input.kwargs, |
| ) |
| m = RNNWrapper(module_cls, args, kwargs) |
| batch_first = m.m.batch_first |
| m.to(device).to(dtype) |
| |
| args, kwargs = ( |
| module_input.forward_input.args, |
| module_input.forward_input.kwargs, |
| ) |
| |
            # if the RNN tests use unbatched inputs, batch them here
| input = args[0] |
| if isinstance(input, torch.Tensor) and input.dim() == 2: |
| input = input.detach() |
| new_input_shape = [1] * (len(input.shape) + 1) |
| if batch_first: |
| new_input_shape[0] = 2 |
| input = input.repeat(new_input_shape) |
| else: |
| new_input_shape[1] = 2 |
| input = input.unsqueeze(1).repeat(new_input_shape) |
| |
| h = args[1] if len(args) > 1 else None |
| if h is not None: |
| h = ( |
| batch_hidden(h) |
| if isinstance(h, torch.Tensor) |
| else tuple(batch_hidden(hx) for hx in h) |
| ) |
| args = list(args) |
| args[1] = h |
| |
| if isinstance(input, torch.nn.utils.rnn.PackedSequence): |
| self._do_test_rnn_packed_sequence( |
| m, input, args[1:], kwargs, atol=atol, rtol=rtol |
| ) |
| else: |
| self._do_test( |
| m, |
| input, |
| args[1:], |
| kwargs, |
| batch_first=batch_first, |
| atol=atol, |
| rtol=rtol, |
| ) |
| |
| def test_per_sample_api_failing(self): |
| module = nn.Linear(10, 10) |
| input = torch.randn(64, 10) |
| with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"): |
| call_for_per_sample_grads("fail")(input) |
| with self.assertRaisesRegex( |
| RuntimeError, r"Batch size passed must be None or an integer" |
| ): |
| call_for_per_sample_grads(module, batch_size=6.4)(input) |
| with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"): |
| call_for_per_sample_grads(module, batch_size=-64)(input) |
| with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"): |
| loss = call_for_per_sample_grads(module)(input).sum() |
| loss.backward() # populate grad_sample fields |
| call_for_per_sample_grads(module)(input) |
| |
| module = nn.Linear(10, 10) # reset to not have grad_sample fields |
| with self.assertRaisesRegex( |
| RuntimeError, r"Expected loss_reduction argument to be sum or mean" |
| ): |
| call_for_per_sample_grads(module, loss_reduction="")(input) |
| |
| def test_per_sample_api_compute_batch_size(self): |
| class CustomModule(nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = nn.Linear(5, 5) |
| |
| def forward(self, input1, input2): |
| return self.linear(input1) + self.linear(input2) |
| |
| module = CustomModule() |
| input1 = torch.randn(4, 5) |
| input2 = torch.randn(5, 5) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "found at least one input with batch size 4 and one with batch size 5", |
| ): |
| call_for_per_sample_grads(module)(input1, input2) |
| |
| input2 = torch.randn(4, 5) |
| call_for_per_sample_grads(module)(input1, input2) |
| |
| module = CustomModule() |
| call_for_per_sample_grads(module)(input1, input2=input2) |
| |
| module = CustomModule() |
| call_for_per_sample_grads(module)(input1=input1, input2=input2) |
| |
| def test_per_sample_api_compute_batch_size_not_pytreeable(self): |
| @dataclass |
| class NonPytreeableTuple: |
| elem1: torch.Tensor |
| elem2: torch.Tensor |
| |
| class CustomModule(nn.Module): |
| def __init__(self) -> None: |
| super().__init__() |
| self.linear = nn.Linear(5, 5) |
| |
| def forward(self, input1, input2): |
| return self.linear(input1.elem1) + self.linear(input1.elem2) |
| |
| input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5)) |
| model = CustomModule() |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "ExpandedWeights cannot compute the batch size from the inputs", |
| ): |
| call_for_per_sample_grads(model)(input, "") |
| |
        # ideally this would error because the input is not pytree-able, but that's hard to detect
| with self.assertRaisesRegex( |
| RuntimeError, "Expected ExpandedWeights to have batch size matching input" |
| ): |
| call_for_per_sample_grads(model)(input, torch.randn(5)) |
| |
| model = CustomModule() # TODO: functional call bug, sam will fix |
| call_for_per_sample_grads(model)(input, torch.randn(4, 5)) |
| model = CustomModule() |
| call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5)) |
| |
| |
| class ContextManagerTests(TestBase): |
| def __init__(self, *args, **kwargs): |
| self.test_cpu = kwargs.get("test_cpu", True) |
| self.test_cuda = kwargs.get("test_cuda", True) |
| super().__init__(*args, **kwargs) |
| |
| @property |
| def constructor_args(self): |
| return self._get_arg("constructor_args", False) |
| |
| def test_context_manager(self, test_case, device): |
| kwargs = {"device": device, "dtype": torch.double} |
| module = self.constructor(*self.constructor_args).to(**kwargs) |
| if "Embedding" in self.get_name(): |
| kwargs["dtype"] = torch.long |
| input = self._get_input().to(**kwargs) |
| if len(input.shape) == 0 or input.shape[0] == 0: |
| raise unittest.SkipTest( |
| "Can't get per sample gradients when no batch dim or batch dim is 0" |
| ) |
| if self.constructor == torch.nn.Linear and len(input.shape) == 1: |
| raise unittest.SkipTest( |
| "Can't get per sample gradients for input of rank 1" |
| ) |
| test_case._do_test(module, input) |
| |
| def test_context_manager_multiple_inputs(self, test_case, device): |
| module = self.constructor(*self.constructor_args).to(device) |
| input = self._get_input() |
| if len(input.shape) == 0 or input.shape[0] == 0: |
| raise unittest.SkipTest( |
| "Can't get per sample gradients when no batch dim or batch dim is 0" |
| ) |
| if self.constructor == torch.nn.Linear and len(input.shape) == 1: |
| raise unittest.SkipTest( |
| "Can't get per sample gradients for input of rank 1" |
| ) |
| test_case._do_test_multi_input(module, input) |
| |
| |
| def filter_supported_tests(t): |
| supported_modules = [ |
| "Linear", |
| "Conv1d", |
| "Conv2d", |
| "Conv3d", |
| "Embedding", |
| "LayerNorm", |
| "GroupNorm", |
| "InstanceNorm", |
| ] |
| if "module_name" in t and t["module_name"] in supported_modules: |
| return True |
| |
| |
| # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests |
| # These currently use the legacy nn tests |
| supported_tests = [ |
| t for t in module_tests + new_module_tests if filter_supported_tests(t) |
| ] |
| for test_param in supported_tests: |
| if "constructor" not in test_param: |
| name = test_param.pop("module_name") |
| test_param["constructor"] = getattr(nn, name) |
| decorator = test_param.pop("decorator", lambda test: test) |
| test = ContextManagerTests(**test_param) |
| test_name = test.get_name() |
| if hasattr(TestExpandedWeightModule, test_name): |
| raise RuntimeError("Found two tests with the same name: " + test_name) |
| test_name_multi_input = test.get_name() + "_multiple_inputs" |
| if hasattr(TestExpandedWeightModule, test_name_multi_input): |
| raise RuntimeError("Found two tests with the same name: " + test_name) |
| if test.test_cpu: |
| setattr( |
| TestExpandedWeightModule, |
| test_name, |
| decorator(lambda self, test=test: test.test_context_manager(self, "cpu")), |
| ) |
| setattr( |
| TestExpandedWeightModule, |
| test_name_multi_input, |
| decorator( |
| lambda self, test=test: test.test_context_manager_multiple_inputs( |
| self, "cpu" |
| ) |
| ), |
| ) |
| if TEST_CUDA and test.test_cuda: |
| # since this checks derivatives, only use double for precision |
| setattr( |
| TestExpandedWeightModule, |
| test_name + "_cuda_double", |
| decorator(lambda self, test=test: test.test_context_manager(self, "cuda")), |
| ) |
| |
| # ------------- HELPER FUNCTIONS ----------------- |
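# These helpers build the two sides of each comparison: make_expanded_weight clones a
# SampleInput and wraps its differentiable args/kwargs in ExpandedWeight (so a later
# backward() fills ``grad_sample``), while for_loop_per_sample_grad computes the
# reference gradients by calling torch.autograd.grad on one batch element at a time.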
| |
| |
| def run_op(op, input, *args, **kwargs): |
| r""" |
| OpInfo for Embedding switches the input and weight so autograd tests will only check the derivative |
| of the weight, not the input, which can't be differentiable since its dtype is int. Calls op, |
| using the special ordering that Embedding's OpInfo expects for that case. |
| """ |
| if op.name == "nn.functional.embedding": |
| return op(args[0], input, **kwargs) |
| else: |
| return op(input, *args, **kwargs) |
| |
| |
| def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"): |
| def expanded_weight_or_clone(arg): |
| if is_diff_tensor(arg): |
| return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction) |
| return clone_if_tensor(arg) |
| |
| ew_input = clone_if_tensor(sample_input.input) |
| ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args) |
| ew_kwargs = { |
| name: expanded_weight_or_clone(arg) |
| for (name, arg) in sample_input.kwargs.items() |
| } |
| return ew_input, ew_args, ew_kwargs |
| |
| |
| def supported_inputs(op, sample_inputs, supported_inputs=True): |
| r""" |
| ExpandedWeights currently does not support some use cases when there's no batch dimension or |
| operations that would cause inter-batch operations. Removes all of the cases it cannot deal with |
| """ |
| |
| def filter_fn(input): |
| convolutions = [ |
| "nn.functional.conv1d", |
| "nn.functional.conv2d", |
| "nn.functional.conv3d", |
| ] |
| batched_input_size = dict(zip(convolutions, [3, 4, 5])) |
| if op.name == "nn.functional.linear": |
| is_supported_input = ( |
| input.input.dim() > 1 |
| ) # input of rank 1 means no batch dim |
| elif op.name == "nn.functional.layer_norm": |
| normalized_shape = input.args[0] |
| is_supported_input = ( |
| input.input.shape != normalized_shape |
| ) # would cause inter-batch operations |
| elif op.name in convolutions: |
| # currently can't deal with padding computation on Python level |
| is_supported_input = input.input.dim() == batched_input_size[op.name] |
| elif op.name == "nn.functional.embedding": |
| idx = input.args[0] |
            is_supported_input = len(idx.shape) > 1  # rank-1 indices mean no batch dim
| else: |
| is_supported_input = True |
| is_supported_input = ( |
| is_supported_input and input.input.shape[0] > 0 |
| ) # 0 is not a valid batch size |
| return is_supported_input if supported_inputs else not is_supported_input |
| |
| return [input for input in sample_inputs if filter_fn(input)] |
| |
| |
| def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs): |
| # get per sample grads by getting derivative for each input in a for loop |
| per_sample_grad = [] |
| for i in range(batch_size): |
| per_sample_input = input[i] |
| result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs)) |
| diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values()) |
        diff_input_list = [
            t
            for t in diff_input_list
            if isinstance(t, torch.Tensor) and t.requires_grad
        ]
| per_sample_grad.append( |
| torch.autograd.grad( |
| result, diff_input_list, torch.ones_like(result), allow_unused=True |
| ) |
| ) |
| if len(per_sample_grad) == batch_size: |
| per_sample_grad = tuple(torch.stack(grad) for grad in zip(*per_sample_grad)) |
| return per_sample_grad |
| |
| |
| def is_diff_tensor(t): |
| return isinstance(t, ExpandedWeight) or ( |
| isinstance(t, torch.Tensor) and t.requires_grad |
| ) |
| |
| |
| def clone_if_tensor(t): |
| if isinstance(t, torch.Tensor): |
| res = torch.clone(t).detach() |
| res.requires_grad_(t.requires_grad) |
| return res |
| else: |
| return t |
| |
| |
| instantiate_device_type_tests(TestExpandedWeightHelperFunction, globals()) |
| instantiate_device_type_tests(TestExpandedWeightFunctional, globals()) |
| instantiate_device_type_tests(TestExpandedWeightModule, globals()) |
| if __name__ == "__main__": |
| run_tests() |