| # Owner(s): ["module: nestedtensor"] |
| |
| import io |
| import itertools |
| import sys |
| from typing import Optional, Tuple |
| import unittest |
| from functools import partial |
| |
| import numpy as np |
| import torch |
| import torch.nn |
| import torch.nn.functional as F |
| from torch.testing._internal.common_cuda import SM80OrLater |
| from torch.testing._internal.common_device_type import ( |
| dtypes, |
| dtypesIfCUDA, |
| instantiate_device_type_tests, |
| onlyCPU, |
| onlyCUDA, |
| skipMeta, |
| PYTORCH_CUDA_MEMCHECK, |
| ) |
| from torch.testing._internal.common_dtype import floating_types_and_half |
| from torch.testing._internal.common_utils import ( |
| decorateIf, |
| freeze_rng_state, |
| gradcheck, |
| instantiate_parametrized_tests, |
| IS_FBCODE, |
| parametrize, |
| run_tests, |
| skipIfSlowGradcheckEnv, |
| markDynamoStrictTest, |
| xfailIfTorchDynamo, |
| subtest, |
| TEST_WITH_ROCM, |
| TestCase, |
| ) |
| |
| from torch.nested._internal.nested_tensor import ( |
| buffer_from_jagged, |
| jagged_from_list, |
| NestedTensor, |
| ) |
| |
| # Tests are ported from pytorch/nestedtensor. |
| # This makes porting as_nested_tensor easier in the future. |
| |
| |
| def _iter_constructors(): |
| # yield as_nested_tensor |
| yield torch.nested.nested_tensor |
| |
# Helper function to generate a pair of random nested tensors;
# one is contiguous, the other is not, but they have the same entries.
# Each output nested tensor consists of
| # * `len(ragged_sizes)` matrices |
| # * matrices[i].shape == (20, ragged_sizes[i]) |
| |
| |
| def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16): |
| xs = [] |
| for size in ragged_sizes: |
| xs.append(torch.randn((size, 20), device=device, dtype=dtype)) |
| # contiguous nested tensor |
| ys = [] |
| for x in xs: |
| ys.append(x.transpose(-1, -2)) |
| nt_contiguous = torch.nested.nested_tensor(ys) |
| # noncontiguous nested tensor |
| nt_noncontiguous = torch.nested.nested_tensor(xs).transpose(-1, -2) |
| return nt_contiguous, nt_noncontiguous |
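

# A minimal usage sketch of the helper above (function name ours; not called
# by the tests): same entries in both outputs, but different contiguity.
def _sketch_noncontiguous_pair():
    nt_c, nt_nc = random_nt_noncontiguous_pair((2, 3))
    assert nt_c.is_contiguous() and not nt_nc.is_contiguous()
    for a, b in zip(nt_c.unbind(), nt_nc.unbind()):
        torch.testing.assert_close(a, b)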
| |
# Helper function to pad a noncontiguous nested tensor
| # can be replaced once to_padded_tensor supports noncontiguous memory |
| |
| |
| def noncontiguous_to_padded_tensor(input, shape=None): |
| tensors = input.unbind() |
| ntensors = len(tensors) |
| assert ntensors > 0 |
| if shape is None: |
| shape = [] |
| for size in tensors[0].shape: |
| shape.append(size) |
| for i in range(1, ntensors): |
| new_shape = tensors[i].shape |
| for j in range(len(shape)): |
| shape[j] = max(shape[j], new_shape[j]) |
| shape = [ntensors] + shape |
| result = tensors[0].new_zeros(shape) |
| for itensor in range(ntensors): |
| tensor = tensors[itensor] |
| view = result[itensor] |
| for idim in range(tensor.dim()): |
| view = view.narrow(idim, 0, tensor.size(idim)) |
| view.copy_(tensor) |
| return result |
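

# Illustrative sketch of the padding behavior above (function name ours; not
# called by the tests): each component lands in the leading corner of its
# padded slice and the remainder stays zero.
def _sketch_noncontiguous_to_padded():
    nt = torch.nested.nested_tensor([torch.ones(2, 3), torch.ones(1, 5)])
    padded = noncontiguous_to_padded_tensor(nt)
    assert padded.shape == (2, 2, 5)  # batch dim + elementwise max of shapes
    assert padded[0, :2, :3].eq(1).all()
    assert padded[0, :, 3:].eq(0).all()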
| |
| # Helper function to generate a random nested tensor |
| |
| |
| def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch.strided, require_non_empty=True): |
| if min_dims is None: |
| min_dims = tuple([0] * len(max_dims)) |
| |
| assert len(max_dims) == len(min_dims) |
| for min_dim, max_dim in zip(min_dims, max_dims): |
| assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim" |
| assert min_dim >= 0, "random_nt: min_dim must be non-negative" |
| if require_non_empty: |
| assert not (min_dim == 0 and max_dim == 1), ( |
| "random_nt: zero cannot be the only possible value if require_non_empty is True" |
| ) |
| |
| if require_non_empty: |
| # Select a random idx that will be required to be non-empty |
| non_zero_idx = torch.randint(low=0, high=num_tensors, size=(1,)).item() |
| |
| ts1 = [] |
    for i in range(num_tensors):
| tensor_dims = [] |
| for min_dim, max_dim in zip(min_dims, max_dims): |
| new_min_dim = min_dim |
| if require_non_empty and i == non_zero_idx and min_dim == 0: |
| new_min_dim = 1 |
| tensor_dims.append(torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item()) |
| t1 = torch.randn(tensor_dims, device=device, dtype=dtype) |
| ts1.append(t1) |
| |
| return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout) |
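

# Usage sketch for random_nt (function name ours; not called by the tests).
# With the default require_non_empty=True, one randomly chosen component is
# forced to be non-empty even though min_dims permits zero-sized dims.
def _sketch_random_nt():
    nt = random_nt("cpu", torch.float32, num_tensors=3, max_dims=(4, 4))
    assert nt.size(0) == 3
    assert any(t.numel() > 0 for t in nt.unbind())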
| |
| |
| # Alternate approach to generating a random NT. |
| # dims should be something like [5, None, 10], with None indicating that a |
| # random ragged structure should be used |
| def random_nt_from_dims(dims, device=None, dtype=None, layout=torch.strided, requires_grad=False): |
    sizes = [
        [d if d is not None else torch.randint(2, 10, size=(1,)).item() for d in dims[1:]]
        for _ in range(dims[0])
    ]
| return torch.nested.nested_tensor([ |
| torch.randn(*size) for size in sizes |
| ], device=device, dtype=dtype, layout=layout, requires_grad=requires_grad) |
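

# Usage sketch (function name ours; not called by the tests): a None entry in
# dims marks the ragged dim, whose per-component length is drawn from [2, 10).
def _sketch_random_nt_from_dims():
    nt = random_nt_from_dims([5, None, 10])
    assert nt.size(0) == 5   # batch dim
    assert nt.size(2) == 10  # regular dim, consistent across components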
| |
| |
| # Creates an NT matching another NT's number of components and |
| # shape / ragged structure for all dims specified to be -1. |
| def random_nt_from_similar(other, dims=None): |
| if dims is None: |
| return torch.randn_like(other) |
| assert len(dims) == other.dim() |
| assert dims[0] == -1 or dims[0] == other.size(0) |
| |
| ret_sizes = [] |
| for t in other.unbind(): |
| other_size = t.shape |
| ret_size = [] |
| for i, d in enumerate(dims[1:]): |
| if d == -1: |
| ret_size.append(other_size[i]) |
| else: |
| ret_size.append(d) |
| ret_sizes.append(ret_size) |
| |
| return torch.nested.nested_tensor([ |
| torch.randn(*size) for size in ret_sizes |
| ], device=other.device) |
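

# Usage sketch (function name ours; not called by the tests): -1 entries keep
# `other`'s (possibly ragged) sizes; concrete entries override that dim for
# every component.
def _sketch_random_nt_from_similar():
    base = random_nt_from_dims([3, None, 4])
    like = random_nt_from_similar(base, dims=[-1, -1, 8])
    for b, t in zip(base.unbind(), like.unbind()):
        assert t.shape == (b.size(0), 8)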
| |
| |
| # makes naming nice for tests that parametrize over layout. |
| def layout_name(layout): |
| # e.g. "torch.jagged" -> "jagged" |
    return repr(layout).split(".")[-1]
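

# Quick sketch (function name ours; not called by the tests): layout_name is
# used as @parametrize's name_fn so generated test names read like
# "..._jagged" rather than "..._torch_jagged".
def _sketch_layout_name():
    assert layout_name(torch.jagged) == "jagged"
    assert layout_name(torch.strided) == "strided"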
| |
| |
| @markDynamoStrictTest |
| class TestNestedTensor(TestCase): |
| @parametrize("batch_size", [2, 4]) |
| @parametrize("max_seq_len", [3, 5]) |
| @parametrize("vocab_size", [10, 20]) |
| def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size): |
| data = [] |
| nested_tensor_ref_list = [] |
| for _ in range(batch_size): |
| if max_seq_len == 0: |
| length = 0 |
| else: |
| length = np.random.randint(low=1, high=max_seq_len) |
| row = list(np.random.randint(low=0, high=vocab_size, size=(length,))) |
| data.append(row) |
| nested_tensor_ref_list.append(torch.Tensor(row)) |
| nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64) |
| nested_tensor_list = nested_tensor.unbind() |
        for i in range(batch_size):
            self.assertEqual(
                nested_tensor_list[i],
                nested_tensor_ref_list[i].type(torch.int64)
            )
| |
| @parametrize("batch_size", [2, 4]) |
| @parametrize("max_seq_len", [3, 5]) |
| @parametrize("vocab_size", [10, 20]) |
| def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size): |
| data = [] |
| nested_tensor_ref_list = [] |
| for _ in range(batch_size): |
| if max_seq_len == 0: |
| length = 0 |
| else: |
| length = np.random.randint(low=1, high=max_seq_len) |
| row = list(np.random.randint(low=0, high=vocab_size, size=(length,))) |
| row = [list(item * np.arange(max_seq_len)) for item in row] |
| data.append(row) |
| nested_tensor_ref_list.append(torch.Tensor(row)) |
| nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64) |
| nested_tensor_list = nested_tensor.unbind() |
        for i in range(batch_size):
            self.assertEqual(
                nested_tensor_list[i],
                nested_tensor_ref_list[i].type(torch.int64)
            )
| |
| @parametrize("batch_size", [2, 4]) |
| @parametrize("max_seq_len", [3, 5]) |
| @parametrize("vocab_size", [10, 20]) |
| def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size): |
| data = [] |
| nested_tensor_ref_list = [] |
| for _ in range(batch_size): |
| if max_seq_len == 0: |
| length = 0 |
| else: |
| length = np.random.randint(low=1, high=max_seq_len) |
| row = list( |
| np.random.randint(low=0, high=vocab_size, size=(length,)).astype(float) |
| ) |
| row = [list(item * np.arange(max_seq_len)) for item in row] |
| data.append(row) |
| nested_tensor_ref_list.append(torch.Tensor(row)) |
| nested_tensor = torch.nested.nested_tensor(data, dtype=torch.float) |
| nested_tensor_list = nested_tensor.unbind() |
        for i in range(batch_size):
            self.assertEqual(
                nested_tensor_list[i],
                nested_tensor_ref_list[i].type(torch.float)
            )
| |
| @torch.inference_mode() |
| def _test_unbind_case(self, a, b): |
| nt = torch.nested.nested_tensor([a, b]) |
| a1, b1 = nt.unbind() |
| self.assertTrue(a is not a1) |
| self.assertTrue(b is not b1) |
| |
| nt = torch.nested.nested_tensor([a, b], dtype=a.dtype) |
| a1, b1 = nt.unbind(0) |
| self.assertEqual(a, a1) |
| self.assertEqual(b, b1) |
| |
| a = torch.randn((2, 3)).add_(1) |
| nt = torch.nested.nested_tensor([a]) |
| self.assertEqual(a, nt.unbind(0)[0]) |
| |
| @torch.inference_mode() |
| def test_unbind_0(self): |
| self._test_unbind_case( |
| torch.tensor([1, 2]), torch.tensor([7, 8]), |
| ) |
| |
| @torch.inference_mode() |
| def test_unbind_1(self): |
| self._test_unbind_case( |
| torch.tensor([1]), torch.tensor([7]), |
| ) |
| |
| @torch.inference_mode() |
| def test_unbind_3(self): |
| self._test_unbind_case( |
| torch.tensor([1.0]), torch.tensor([]), |
| ) |
| |
| @torch.inference_mode() |
| def test_unbind_4(self): |
| self._test_unbind_case( |
| torch.tensor([]), torch.tensor([]), |
| ) |
| |
| @torch.inference_mode() |
| def test_unbind_dim(self): |
| def _test_fn(unbind_fn): |
| a = torch.rand(3, 2) |
| b = torch.rand(2, 3) |
| nt = torch.nested.nested_tensor([a, b]) |
| self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1)) |
| |
| # Both of these tests are necessary, because we're using |
| # torch_function. |
| _test_fn(lambda x, dim: x.unbind(dim)) |
| # TODO: Re-enable this once using torch_dispatch |
| # _test_fn(lambda x, dim: torch.unbind(x, dim)) |
| |
| @torch.inference_mode() |
| def test_nested_tensor(self): |
| self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0]))) |
| self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0)) |
| |
| @torch.inference_mode() |
| def test_nested_tensor_matching_dim(self): |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.", |
| lambda: torch.nested.nested_tensor([torch.tensor(1.0), torch.tensor([])]), |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.", |
| lambda: torch.nested.nested_tensor( |
| [torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])] |
| ), |
| ) |
| |
| @torch.inference_mode() |
| def test_default_nested_tensor(self): |
| self.assertRaises(TypeError, lambda: torch.nested.nested_tensor()) |
| default_nested_tensor = torch.nested.nested_tensor([]) |
| default_tensor = torch.tensor([]) |
| # self.assertEqual(default_nested_tensor.nested_dim(), 1) |
| # self.assertEqual(default_nested_tensor.nested_size(), ()) |
| self.assertEqual(default_nested_tensor.dim(), default_tensor.dim()) |
| self.assertEqual(default_nested_tensor.layout, default_tensor.layout) |
| self.assertEqual(default_nested_tensor.device, default_tensor.device) |
| self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype) |
| self.assertEqual( |
| default_nested_tensor.requires_grad, default_tensor.requires_grad |
| ) |
| self.assertIsNone(default_tensor.grad) |
| # TODO: Re-enable once we have a performance driven |
| # use case and implementation. |
| # self.assertEqual(default_nested_tensor.is_pinned(), |
| # default_tensor.is_pinned()) |
| |
| @torch.inference_mode() |
| def test_dim(self): |
| for constructor in _iter_constructors(): |
| a1 = constructor([]) |
| self.assertEqual(a1.dim(), 1) |
| a1 = constructor([torch.tensor(3.0)]) |
| self.assertEqual(a1.dim(), 1) |
| a1 = constructor([torch.tensor([1, 2, 3, 4])]) |
| self.assertEqual(a1.dim(), 2) |
| |
| @unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.") |
| @torch.inference_mode() |
| def test_numel(self): |
| for constructor in _iter_constructors(): |
| a1 = constructor([]) |
| self.assertEqual(a1.numel(), 0) |
| a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)]) |
| self.assertEqual(a1.numel(), 2) |
| a1 = constructor([torch.randn(2, 2, 2)]) |
| self.assertEqual(a1.numel(), 8) |
| a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)]) |
| self.assertEqual(a1.numel(), 12) |
| a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)]) |
| self.assertEqual(a1.numel(), 27) |
| a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)]) |
| self.assertEqual(a1.numel(), 341) |
| |
| # Interesting edge case |
| a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)]) |
| self.assertEqual(a1.numel(), 6) |
| |
| @torch.inference_mode() |
| def test_size(self): |
| for constructor in _iter_constructors(): |
| a1 = constructor([]) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "NestedTensorImpl doesn't support sizes", |
| lambda: a1.size(), |
| ) |
| |
| def test_size_dim(self): |
| a = torch.nested.nested_tensor([]) |
| self.assertEqual(a.size(0), 0) |
| |
| a = torch.nested.nested_tensor([torch.tensor(1)]) |
| self.assertEqual(a.size(0), 1) |
| |
| a = torch.nested.nested_tensor([torch.tensor(1), torch.tensor(2)]) |
| self.assertEqual(a.size(0), 2) |
| |
| a = torch.nested.nested_tensor([torch.rand(1, 2), |
| torch.rand(1, 8)]) |
| self.assertEqual(a.size(0), 2) |
| self.assertEqual(a.size(1), 1) |
| self.assertRaisesRegex( |
| RuntimeError, "Given dimension 2 is irregular and does not have a size", lambda: a.size(2)) |
| |
| a = torch.nested.nested_tensor([torch.rand(3, 4), |
| torch.rand(5, 4)]) |
| self.assertEqual(a.size(0), 2) |
| self.assertRaisesRegex( |
| RuntimeError, "Given dimension 1 is irregular and does not have a size", lambda: a.size(1)) |
| self.assertEqual(a.size(2), 4) |
| |
| @unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.") |
| @torch.inference_mode() |
| def test_stride(self): |
| for constructor in _iter_constructors(): |
| a1 = constructor([]) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "NestedTensorImpl doesn't support strides", |
| lambda: a1.stride(), |
| ) |
| |
| @unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.") |
| @torch.inference_mode() |
| def test_is_contiguous(self): |
| # Test empty case |
| nt_empty = torch.nested.nested_tensor([]) |
| assert nt_empty.is_contiguous() |
| self.assertEqual(nt_empty, nt_empty.contiguous()) |
| |
| nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) |
| |
| # Test contiguous case |
| assert nt_contiguous.is_contiguous() |
| self.assertEqual(nt_contiguous, nt_contiguous.contiguous()) |
| |
| # Test non_contiguous case |
| assert not nt_noncontiguous.is_contiguous() |
| self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous()) |
| |
| # Test querying by memory_format |
| self.assertTrue(nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)) |
| self.assertTrue(not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)) |
| |
| @torch.inference_mode() |
| def test_repr_string(self): |
| a = torch.nested.nested_tensor([]) |
| expected = "nested_tensor([\n\n])" |
| self.assertEqual(str(a), expected) |
| self.assertEqual(repr(a), expected) |
| |
| a = torch.nested.nested_tensor([torch.tensor(1.0)]) |
| expected = "nested_tensor([\n tensor(1.)\n])" |
| self.assertEqual(str(a), expected) |
| self.assertEqual(repr(a), expected) |
| |
| a = torch.nested.nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])]) |
| expected = "nested_tensor([\n tensor([[1, 2]]),\n tensor([[4, 5]])\n])" |
| self.assertEqual(str(a), expected) |
| self.assertEqual(repr(a), expected) |
| |
| def test_to_padded_tensor_on_empty_tensor(self): |
| nt = torch.nested.nested_tensor([]) |
| empty = torch.nested.to_padded_tensor(nt, 4) |
| self.assertEqual(empty, torch.tensor([])) |
| |
| def test_nested_namespace(self): |
| nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 5)]) |
| result = nt.to_padded_tensor(4) |
| nested_namespace_result = torch.nested.to_padded_tensor(nt, 4) |
| self.assertEqual(result, nested_namespace_result) |
| |
| def test_to(self): |
| ntensors = 4 |
| nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) |
| |
| def test_copy_behavior(t, non_blocking=False): |
| self.assertIs(t, t.to(t, non_blocking=non_blocking)) |
| self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking)) |
| self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking)) |
| self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True)) |
| self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True)) |
| self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)) |
| |
| devices = [t.device] |
| if t.device.type == 'cuda': |
| if t.device.index == -1: |
| devices.append(f'cuda:{torch.cuda.current_device()}') |
| elif t.device.index == torch.cuda.current_device(): |
| devices.append('cuda') |
| for device in devices: |
| self.assertIs(t, t.to(device, non_blocking=non_blocking)) |
| self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking)) |
| self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True)) |
| self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)) |
| |
| test_copy_behavior(nt) |
| self.assertEqual(nt.device, nt.to('cpu').device) |
| self.assertEqual(nt.device, nt.to('cpu', dtype=torch.float32).device) |
| self.assertIs(torch.float32, nt.to('cpu', dtype=torch.float32).dtype) |
| self.assertEqual(nt.device, nt.to(torch.float32).device) |
| self.assertIs(torch.float32, nt.to(dtype=torch.float32).dtype) |
| |
| def test_data_ptr(getter): |
| self.assertEqual(getter(nt), getter(nt.to('cpu'))) |
| self.assertEqual(getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False))) |
| self.assertEqual(getter(nt), getter(nt.to('cpu', copy=False))) |
| self.assertNotEqual(getter(nt), getter(nt.to('cpu', copy=True))) |
| |
| test_data_ptr(lambda nt: nt.data_ptr()) |
| |
| if torch.cuda.is_available(): |
| for non_blocking in [True, False]: |
| for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']: |
| nt2 = random_nt(cuda, torch.float32, ntensors, (4, 4)) |
| test_copy_behavior(nt2, non_blocking) |
| self.assertEqual(nt2.device, nt2.to(cuda, non_blocking=non_blocking).device) |
| self.assertEqual(nt.device, nt2.to('cpu', non_blocking=non_blocking).device) |
| self.assertEqual(nt2.device, nt.to(cuda, non_blocking=non_blocking).device) |
| self.assertIs(torch.int32, nt2.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype) |
| self.assertEqual(nt.device, nt2.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device) |
| self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype) |
| self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device) |
| |
| def test_copy_(self): |
| ntensors = 4 |
| nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) |
| nt_copy = torch.empty_like(nt) |
| nt_copy.copy_(nt) |
| |
        for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy.unbind()):
| self.assertEqual(nt_ub, nt_copy_ub) |
| |
| nt_error = torch.nested.nested_tensor([torch.tensor([0, 0])]) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "copy_ only supports tensors that are the same size for Nested implementations", |
| lambda: nt_error.copy_(nt) |
| ) |
| |
| if torch.cuda.is_available(): |
| nt = random_nt(torch.device('cuda'), torch.float32, ntensors, (4, 4)) |
| nt_copy = torch.empty_like(nt, device=torch.device('cpu')) |
| nt_copy.copy_(nt, non_blocking=True) |
| torch.cuda.current_stream(torch.cuda.current_device()).synchronize() |
            for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy.unbind()):
| self.assertEqual(nt_ub, nt_copy_ub) |
| |
| nt_copy = torch.empty_like(nt, device=torch.device('cpu')) |
| nt_copy.copy_(nt, non_blocking=False) |
            for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy.unbind()):
| self.assertEqual(nt_ub, nt_copy_ub) |
| |
| def test_fill_(self): |
| ntensors = 4 |
| nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) |
| nt.fill_(10.) |
| for nt_ub in nt.unbind(): |
| t = torch.empty_like(nt_ub) |
| t.fill_(10.) |
| self.assertEqual(nt_ub, t) |
| |
| fill_tensor = torch.tensor([11.]) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "fill_ only supports 0-dimension value tensor", |
| lambda: nt.fill_(fill_tensor) |
| ) |
| |
| nt.fill_(fill_tensor[0]) |
| for nt_ub in nt.unbind(): |
| t = torch.empty_like(nt_ub) |
| t.fill_(11.) |
| self.assertEqual(nt_ub, t) |
| |
| def test_zero_(self): |
| ntensors = 4 |
| nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) |
| nt.zero_() |
| for nt_ub in nt.unbind(): |
| t = torch.empty_like(nt_ub) |
| t.fill_(0.) |
| self.assertEqual(nt_ub, t) |
| |
| @parametrize("func", [torch.ones_like, torch.zeros_like, torch.randn_like], |
| name_fn=lambda f: f.__name__) |
| def test_like_functions(self, func): |
| ntensors = 4 |
| nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4)) |
| torch.manual_seed(1) |
| nt_like = func(nt) |
| |
| torch.manual_seed(1) |
| for nt_ub in nt_like.unbind(): |
| t_like = func(nt_ub) |
| self.assertEqual(nt_ub, t_like) |
| |
| def test_cat(self): |
| # dim=0 success case |
| # No constraints on ragged structures matching. |
| x = random_nt_from_dims([5, None, 10]) |
| y = random_nt_from_dims([3, 4, None]) |
| output = torch.cat([x, y], dim=0) |
| for out_component, xy_component in zip( |
| output.unbind(), itertools.chain(x.unbind(), y.unbind())): |
| self.assertEqual(out_component, xy_component) |
| |
| # dim=-1 success case |
| # shape (B, *, D) |
| x = random_nt_from_dims([5, None, 10]) |
| # shape (B, *, D'); same structure as x but dim=-1 differs |
| y = random_nt_from_similar(x, dims=[-1, -1, 8]) |
| # should be shape (B, *, D + D') when supported |
| output = torch.cat([x, y], dim=-1) |
| for out_component, x_component, y_component in zip(output.unbind(), x.unbind(), y.unbind()): |
| self.assertEqual(out_component, torch.cat([x_component, y_component], dim=-1)) |
| |
| # dim between 0 and -1 success case |
| x = random_nt_from_dims([5, None, 2, 3]) |
| # same structure as x but dim=2 differs |
| y = random_nt_from_similar(x, dims=[-1, -1, 4, -1]) |
| output = torch.cat([x, y], dim=2) |
| for out_component, x_component, y_component in zip(output.unbind(), x.unbind(), y.unbind()): |
| self.assertEqual(out_component, torch.cat([x_component, y_component], dim=1)) |
| |
| # error case: mixed NT / dense inputs |
| x = random_nt_from_dims([5, None, 2]) |
| y = torch.randn(5, 3, 2) |
| with self.assertRaisesRegex( |
| RuntimeError, "expected each tensor in given list to be nested"): |
| torch.cat([x, y], dim=-1) |
| |
| # error case: NTs with different dims |
| x = random_nt_from_dims([5, None, 2]) |
| y = random_nt_from_dims([5, None, 2, 3]) |
| with self.assertRaisesRegex( |
| RuntimeError, "expected all nested tensors to have matching ragged structures outside of the concatenated dim"): |
| torch.cat([x, y], dim=-1) |
| |
| # error case: non-contiguous NT |
| x, y = random_nt_noncontiguous_pair((2, 3, 4), dtype=torch.float32) |
| # transpose to put ragged dim next to batch dim |
| x, y = x.transpose(-2, -1), y.transpose(-2, -1) |
| with self.assertRaisesRegex( |
| RuntimeError, "only contiguous nested tensors are supported"): |
| torch.cat([x, y], dim=-1) |
| |
| # error case: multiple ragged dims in inputs |
| x = random_nt_from_dims([5, None, None, 2]) |
| y = random_nt_from_similar(x) |
| with self.assertRaisesRegex( |
| RuntimeError, "only nested tensors with a single ragged dim next to the batch dim are supported"): |
| torch.cat([x, y], dim=-1) |
| |
| # error case: ragged dim not next to batch dim |
| x = random_nt_from_dims([5, 2, None]) |
| y = random_nt_from_similar(x) |
| with self.assertRaisesRegex( |
| RuntimeError, "only nested tensors with a single ragged dim next to the batch dim are supported"): |
| torch.cat([x, y], dim=1) |
| |
| # error case: NTs with different batch sizes |
| x = random_nt_from_dims([5, None, 2]) |
| y = random_nt_from_dims([3, None, 2]) |
| with self.assertRaisesRegex( |
| RuntimeError, "expected all nested tensors to have matching ragged structures outside of the concatenated dim"): |
| torch.cat([x, y], dim=-1) |
| |
| # error case: NTs with different ragged structures |
| x = torch.nested.nested_tensor([ |
| torch.randn(2, 6), |
| torch.randn(4, 6), |
| torch.randn(5, 6), |
| ]) |
| y = torch.nested.nested_tensor([ |
| torch.randn(5, 6), |
| torch.randn(4, 6), |
| torch.randn(2, 6), |
| ]) |
| with self.assertRaisesRegex( |
| RuntimeError, "expected all nested tensors to have matching ragged structures outside of the concatenated dim"): |
| torch.cat([x, y], dim=-1) |
| |
| |
| @markDynamoStrictTest |
| class TestNestedTensorDeviceType(TestCase): |
| # Helper function to generate a pair of random nested tensors |
| # the 2 nested tensors have same shapes |
| def random_nt_pair(self, device, dtype, num_tensors, max_dims): |
| ts1 = [] |
| ts2 = [] |
| for _ in range(num_tensors): |
| tensor_dims = tuple([torch.randint(low=0, high=max_dim, size=(1,)).item() for max_dim in max_dims]) |
| t1 = torch.randn(tensor_dims, device=device, dtype=dtype) |
| t2 = torch.randn(tensor_dims, device=device, dtype=dtype) |
| ts1.append(t1) |
| ts2.append(t2) |
| return (torch.nested.nested_tensor(ts1, device=device, dtype=dtype), |
| torch.nested.nested_tensor(ts2, device=device, dtype=dtype)) |
| |
| @dtypes(*floating_types_and_half()) |
| def test_detach(self, device, dtype): |
| a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False) |
| b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False) |
| x = torch.nested.nested_tensor([a, b], requires_grad=True) |
| |
| x_detach = x.detach() |
| |
| z = x_detach * 4 |
| self.assertFalse(x_detach.requires_grad) |
| self.assertFalse(z.requires_grad) |
| |
| a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True) |
| b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True) |
| x = torch.nested.as_nested_tensor([a, b]) |
| |
| y = x * 2 |
| y = y.detach() |
| self.assertFalse(y.requires_grad) |
| self.assertIsNone(y.grad_fn) |
| |
| z = x + y |
| torch.nested.to_padded_tensor(z, 0).sum().backward() |
| # This is an incorrect gradient, but we assume that's what the user |
| # wanted. detach() is an advanced option. |
| self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype)) |
| self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype)) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_unbind_noncontiguous(self, device, dtype): |
| nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) |
| ub_contiguous = nt_contiguous.unbind() |
| ub_noncontiguous = nt_noncontiguous.unbind() |
| self.assertEqual(len(ub_contiguous), len(ub_noncontiguous)) |
| n = len(ub_contiguous) |
| for i in range(n): |
| self.assertEqual(ub_contiguous[i], ub_noncontiguous[i]) |
| |
| @dtypes(torch.float) |
| @skipMeta |
| def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype): |
| t = torch.randn(4, 4, 4, device=device, dtype=dtype) |
| ts = list(torch.unbind(t)) |
| ts[0] = ts[0][:-1] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) |
| padded = torch.nested.to_padded_tensor(nt, 0) |
| |
| nt_to = torch._nested_from_padded_and_nested_example(padded, nt) |
| |
| for (t1, t2) in zip(nt.unbind(), nt_to.unbind()): |
| self.assertEqual(t1, t2) |
| self.assertEqual(nt.device, nt_to.device) |
| |
| @dtypes(torch.float) |
| @dtypesIfCUDA(torch.float, torch.half) |
| @skipMeta |
| @torch.inference_mode() |
| def test_layer_norm(self, device, dtype): |
| def _test(size): |
| # Simple shapes test |
| t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False) |
| t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False) |
| ts = [t0, t1, t0, t1] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) |
| layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype) |
| nt_result = layer_norm(nt) |
| for (nt_subresult, t) in zip(nt_result.unbind(), ts): |
| t_result = layer_norm(t.reshape(1, -1, size).squeeze(0)) |
| self.assertEqual(nt_subresult, t_result) |
| |
| # More complex nt test with different lengths for each tensor |
| t0 = torch.randn(4, size, device=device, dtype=dtype, requires_grad=False) |
| t1 = torch.randn(10, size, device=device, dtype=dtype, requires_grad=False) |
| t2 = torch.randn(7, size, device=device, dtype=dtype, requires_grad=False) |
| ts = [t0, t1, t2, t0, t2] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) |
| layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype) |
| nt_result = layer_norm(nt) |
| for (nt_subresult, t) in zip(nt_result.unbind(), ts): |
| t_result = layer_norm(t.reshape(1, -1, size).squeeze(0)) |
| self.assertEqual(nt_subresult, t_result) |
| |
| if size <= 128: |
| # Test with multidimensional tensors after irregular dim |
| # (run only with smaller dimensions to ensure fast execution) |
| t0 = torch.randn(4, size, size, 4, device=device, dtype=dtype, requires_grad=False) |
| t1 = torch.randn(10, size, size, 4, device=device, dtype=dtype, requires_grad=False) |
| t2 = torch.randn(7, size, size, 4, device=device, dtype=dtype, requires_grad=False) |
| ts = [t0, t1, t2, t0, t2] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) |
| layer_norm = torch.nn.LayerNorm((size, size, 4), device=device, dtype=dtype) |
| nt_result = layer_norm(nt) |
| for (nt_subresult, t) in zip(nt_result.unbind(), ts): |
| t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0)) |
| self.assertEqual(nt_subresult, t_result) |
| |
                # Test where the normalizing dimensions are not all of the dimensions after the irregular dim
| layer_norm = torch.nn.LayerNorm((size, 4), device=device, dtype=dtype) |
| nt_result = layer_norm(nt) |
| for (nt_subresult, t) in zip(nt_result.unbind(), ts): |
| t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0)) |
| self.assertEqual(nt_subresult, t_result) |
| |
| for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32): |
| _test(size) |
| |
| @dtypes(torch.float) |
| @dtypesIfCUDA(torch.float, torch.half) |
| @skipMeta |
| @torch.inference_mode() |
| def test_layer_norm_breaking(self, device, dtype): |
| size = 128 |
| t0 = torch.randn(4, size, size, 4, device=device, dtype=dtype, requires_grad=False) |
| t1 = torch.randn(10, size, size, 4, device=device, dtype=dtype, requires_grad=False) |
| t2 = torch.randn(7, size, size, 4, device=device, dtype=dtype, requires_grad=False) |
| ts = [t0, t1, t2, t0, t2] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) |
| layer_norm = torch.nn.LayerNorm((4, size, size, 4), device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "normalized_shape extends into irregular dimensions for the nested tensor", |
| lambda: layer_norm(nt), |
| ) |
| layer_norm = torch.nn.LayerNorm((size + 1, size, 4), device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "The shape at dimension 0", |
| lambda: layer_norm(nt), |
| ) |
| |
| @decorateIf( |
| xfailIfTorchDynamo, |
| # only fails in python 3.11. TODO: Ensure this is fixed once views work! |
| lambda params: params["layout"] == torch.jagged and sys.version_info >= (3, 11) |
| ) |
| @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name) |
| def test_embedding(self, device, layout): |
| inputs = [ |
| torch.randint(100, (L,), device=device, dtype=torch.int64) |
| for L in torch.randint(5, 50, (8,)) |
| ] |
| x = torch.nested.nested_tensor(inputs, device=device, dtype=torch.int64, layout=layout) |
| emb = torch.nn.Embedding(100, 8, device=device) |
| y = emb(x) |
| ys = y.unbind() |
| for i, inp in enumerate(inputs): |
| self.assertEqual(emb(inp), ys[i]) |
| |
| @skipMeta |
| @torch.inference_mode() |
| @dtypes(*floating_types_and_half()) |
| def test_masked_fill(self, device, dtype): |
| # nested tensor * nested tensor |
| (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4)) |
| mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()]) |
| ref = torch.nested.nested_tensor([t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())]) |
| out = nt.masked_fill(mask, 0) |
| self.assertEqual(ref, out) |
| |
| @dtypes(torch.float, torch.float16) |
| def test_to_padded_tensor_simple(self, device, dtype): |
| t = torch.randn(4, 4, 4, device=device, dtype=dtype) |
| ts = list(torch.unbind(t)) |
| ts[0] = ts[0][:-1] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) |
| for padding_value in (0, 1): |
| padded = torch.nested.to_padded_tensor(nt, padding_value) |
| |
| correct_output = t.clone() |
| if padding_value == 0: |
| correct_output[0][-1] = torch.zeros_like(correct_output[0][-1]) |
| else: |
| correct_output[0][-1] = torch.ones_like(correct_output[0][-1]) |
| |
| self.assertEqual(padded, correct_output) |
| self.assertEqual(padded.device, torch.device(device)) |
| self.assertEqual(padded.dtype, dtype) |
| |
| @dtypes(torch.float, torch.float16) |
| def test_to_padded_tensor_output_size(self, device, dtype): |
| t = torch.randn(4, 4, 4, device=device, dtype=dtype) |
| output_size = (4, 6, 5) |
| ts = list(torch.unbind(t)) |
| ts[0] = ts[0][:-1] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) |
| for padding_value in (0, 1): |
| padded = torch.nested.to_padded_tensor(nt, padding_value, output_size=output_size) |
| correct_output = torch.ones(output_size, device=device, dtype=dtype) * padding_value |
            correct_output[:4, :4, :4] = t.clone()
| if padding_value == 0: |
| correct_output[0][3] = torch.zeros_like(correct_output[0][3]) |
| else: |
| correct_output[0][3] = torch.ones_like(correct_output[0][3]) |
| |
| self.assertEqual(padded, correct_output) |
| self.assertEqual(padded.device, torch.device(device)) |
| self.assertEqual(padded.dtype, dtype) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_to_padded_tensor_dim2(self, device, dtype): |
| ts = [ |
| torch.randn(160, device=device, dtype=dtype), |
| torch.randn(1240, device=device, dtype=dtype), |
| torch.randn(2400, device=device, dtype=dtype), |
| ] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) |
| pad = 42 |
| correct_output = [] |
| for t in ts: |
| next_output = torch.ones_like(ts[2]) * pad |
| correct_output.append(next_output) |
| next_output[:t.size(0)].copy_(t) |
| correct_output = torch.stack(correct_output) |
| padded = torch.nested.to_padded_tensor(nt, pad) |
| self.assertEqual(padded, correct_output) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_to_padded_tensor_dim3(self, device, dtype): |
| ts = [ |
| torch.randn(16, 21, device=device, dtype=dtype), |
| torch.randn(24, 32, device=device, dtype=dtype), |
| torch.randn(40, 53, device=device, dtype=dtype), |
| ] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) |
| pad = 42 |
| correct_output = [] |
| for t in ts: |
| next_output = torch.ones_like(ts[2]) * pad |
| correct_output.append(next_output) |
| next_output[:t.size(0), :t.size(1)].copy_(t) |
| correct_output = torch.stack(correct_output) |
| padded = torch.nested.to_padded_tensor(nt, pad) |
| self.assertEqual(padded, correct_output) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_to_padded_tensor_dim4(self, device, dtype): |
| ts = [ |
| torch.randn(16, 21, 13, device=device, dtype=dtype), |
| torch.randn(24, 32, 14, device=device, dtype=dtype), |
| torch.randn(40, 53, 16, device=device, dtype=dtype), |
| ] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) |
| pad = 42 |
| correct_output = [] |
| for t in ts: |
| next_output = torch.ones_like(ts[2]) * pad |
| correct_output.append(next_output) |
| next_output[:t.size(0), :t.size(1), :t.size(2)].copy_(t) |
| correct_output = torch.stack(correct_output) |
| padded = torch.nested.to_padded_tensor(nt, pad) |
| self.assertEqual(padded, correct_output) |
| |
| # TODO: test noncontiguous to_padded_tensor |
| # For now this tests the functionality of noncontiguous_to_padded_tensor |
| # and the error message of to_padded_tensor |
| # since to_padded_tensor does not support noncontiguous buffer yet |
| @dtypes(torch.float, torch.float16, torch.double) |
| @torch.inference_mode() |
| def test_to_padded_tensor_noncontiguous(self, device, dtype): |
| nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) |
| # test noncontiguous_to_padded_tensor functionality |
| self.assertEqual( |
| torch.nested.to_padded_tensor(nt_contiguous, 0.0), |
| noncontiguous_to_padded_tensor(nt_noncontiguous)) |
| # test to_padded_tensor error message |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"for now to_padded_tensor only supports contiguous nested tensor", |
| lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0) |
| ) |
| |
| @skipMeta |
| def test_device_checks(self, device): |
| nt = torch.nested.nested_tensor([], device=device) |
| is_cuda = 'cuda' in str(device) |
| self.assertEqual(nt.is_cuda, is_cuda) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_nested_tensor_indexing(self, device, dtype): |
| # edge case: empty nested tensor |
| nt0 = torch.nested.nested_tensor([]) |
| self.assertRaises(IndexError, lambda: nt0[0]) |
| # normal case |
| x0 = torch.randn((2, 5), device=device, dtype=dtype) |
| x1 = torch.randn((3, 4), device=device, dtype=dtype) |
| nt = torch.nested.nested_tensor([x0, x1]) |
| # single index: only support integer in the batch dimension |
| self.assertEqual(nt[0], x0) |
| self.assertEqual(nt[-1], x1) |
| self.assertRaises(IndexError, lambda: nt[2]) |
| self.assertRaises(IndexError, lambda: nt[-3]) |
| self.assertRaises(NotImplementedError, lambda: nt[:]) |
| self.assertRaises(NotImplementedError, lambda: nt[...]) |
| # tuple of indices: only support integer in the batch dimension |
| # + all possible indexing in the original tensor dimensions |
| self.assertEqual(nt[0, 0, 0], x0[0, 0]) |
| self.assertEqual(nt[0, 1, :], x0[1, :]) |
| self.assertEqual(nt[1, ...], x1) |
| self.assertRaises(IndexError, lambda: nt[1, 4, 2]) |
| self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1]) |
| # test select on non-batch dimensions |
| self.assertEqual(nt.select(1, 0)[0], x0.select(0, 0)) |
| self.assertEqual(nt.select(1, 0)[1], x1.select(0, 0)) |
| self.assertRaises(IndexError, lambda: nt.select(1, 3)) |
| self.assertEqual(nt.select(2, 0)[0], x0.select(1, 0)) |
| self.assertEqual(nt.select(2, 0)[1], x1.select(1, 0)) |
| self.assertRaises(IndexError, lambda: nt.select(2, 5)) |
| # make sure indexing returns a view |
| nt[0].fill_(100.0) |
| answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5)) |
| self.assertEqual(nt[0], answer) |
| nt[1, 1, :].fill_(200.0) |
| answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4) |
| self.assertEqual(nt[1, 1, :], answer) |
| |
| # Test that indexing works when requires_grad_(True) |
| # previously this was failing because the backward kernel for select.int uses .sizes() |
| nt = torch.nested.nested_tensor([x0, x1]).requires_grad_(True) |
| self.assertEqual(nt[0], x0) |
| self.assertEqual(nt[-1], x1) |
| grad_x0 = torch.randn((2, 5), device=device, dtype=dtype) |
| nt[0].backward(grad_x0) |
| expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)]) |
| self.assertEqual(nt.grad, expected_grad) |
| |
| @parametrize("func", [subtest(torch.nn.functional.relu, name='relu'), |
| subtest(torch.nn.functional.relu_, name='relu_'), |
| subtest(torch.nn.functional.gelu, name='gelu'), |
| subtest(torch._C._nn.gelu_, name='gelu_'), |
| subtest(torch.tanh, name='tanh'), |
| subtest(torch.tanh_, name='tanh_'), |
| subtest(torch.neg, name='neg'), |
| subtest(torch.nn.functional.silu, name='silu'), |
| subtest(partial(torch.nn.functional.silu, inplace=True), name='silu_'), |
| subtest(torch.abs, name="abs"), |
| subtest(torch.abs_, name="abs_"), |
| subtest(torch.sgn, name="sgn"), |
| subtest(torch.logical_not, name='logical_not'), |
| subtest(torch.sin, name='sin'), |
| subtest(torch.cos, name='cos')]) |
| def test_activations(self, device, func): |
| nt, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device=device, dtype=torch.float32) |
| nested_result = func(nt) |
| self.assertTrue(nested_result.is_nested) |
| for t, t_res in zip(nt.unbind(), nested_result.unbind()): |
| self.assertEqual(func(t), t_res) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "NestedTensor must be contiguous to get buffer.", |
| lambda: func(nt_noncontiguous)) |
| |
| @parametrize("func", [subtest(torch.ge, name='ge'), |
| subtest(torch.eq, name='eq')]) |
| def test_binary_ops_with_scalar(self, device, func): |
| nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair( |
| (2, 3, 6, 7), device=device, dtype=torch.float32) |
| scalar = 0.0 |
| |
| # should work regardless of contiguity |
| for nt in (nt_contiguous, nt_noncontiguous): |
| nested_result = func(nt, scalar) |
| self.assertTrue(nested_result.is_nested) |
| for t, t_res in zip(nt.unbind(), nested_result.unbind()): |
| self.assertEqual(func(t, scalar), t_res) |
| |
| @dtypes(*floating_types_and_half()) |
| def test_nested_tensor_chunk(self, device, dtype): |
| # Transformer use case |
| a = torch.randn(3, 3 * 4, device=device, dtype=dtype) |
| b = torch.randn(2, 3 * 4, device=device, dtype=dtype) |
| c = torch.randn(1, 3 * 4, device=device, dtype=dtype) |
| a_chunks = a.chunk(3, dim=-1) |
| b_chunks = b.chunk(3, dim=-1) |
| c_chunks = c.chunk(3, dim=-1) |
| |
| a_nt = [a_chunks[0], b_chunks[0], c_chunks[0]] |
| b_nt = [a_chunks[1], b_chunks[1], c_chunks[1]] |
| c_nt = [a_chunks[2], b_chunks[2], c_chunks[2]] |
| |
| nt = torch.nested.nested_tensor([a, b, c]) |
| chunked = nt.chunk(3, dim=-1) |
| |
| self.assertEqual(chunked[0], torch.nested.nested_tensor(a_nt)) |
| self.assertEqual(chunked[1], torch.nested.nested_tensor(b_nt)) |
| self.assertEqual(chunked[2], torch.nested.nested_tensor(c_nt)) |
| |
| for chunk in chunked: |
| self.assertFalse(chunk.is_contiguous()) |
| |
| # Failure chunking on ragged dimensions |
| self.assertRaisesRegex( |
| RuntimeError, "Chunk for nested tensors is currently only supported for the last dimension.", |
| lambda: torch.chunk(nt, 5, dim=1)) |
| self.assertRaisesRegex( |
| RuntimeError, "Chunk for nested tensors is currently only supported for the last dimension.", |
| lambda: torch.chunk(nt, 5, dim=0)) |
| |
| # Failure on non-contiguous nt |
| _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) |
| self.assertRaisesRegex( |
| RuntimeError, "chunk expects `self` to be contiguous.", lambda: torch.chunk(nt_noncontiguous, 5, dim=-1)) |
| |
| # Failure when calling non divisible n_chunks |
| self.assertRaisesRegex( |
| RuntimeError, "Chunk for nested tensors is only supported for " |
| "nested tensors with trailing dimension divisible by chunks.", |
| lambda: torch.chunk(nt, 5, dim=-1)) |
| |
| # Failure when calling backward on a chunk |
| a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True) |
| b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True) |
| nt_grad = torch.nested.as_nested_tensor([a, b]) |
| chunked = torch.chunk(nt_grad, 2, dim=-1) |
| self.assertRaisesRegex(RuntimeError, "derivative for aten::chunk is not implemented", |
| lambda: chunked[0].backward(chunked[0].clone())) |
| |
| @dtypes(*floating_types_and_half()) |
| def test_nested_tensor_split_with_sizes(self, device, dtype): |
| a = torch.randn(3, 20, device=device, dtype=dtype) |
| b = torch.randn(2, 20, device=device, dtype=dtype) |
| c = torch.randn(1, 20, device=device, dtype=dtype) |
| |
| split_sizes = [4, 6, 10] |
| a_splits = a.split_with_sizes(split_sizes, dim=-1) |
| b_splits = b.split_with_sizes(split_sizes, dim=-1) |
| c_splits = c.split_with_sizes(split_sizes, dim=-1) |
| |
| nt = torch.nested.nested_tensor([a, b, c]) |
| nt_splits = nt.split_with_sizes(split_sizes, dim=-1) |
| |
| for i, nt_split in enumerate(nt_splits): |
| self.assertEqual(nt_split, torch.nested.nested_tensor( |
| [a_splits[i], b_splits[i], c_splits[i]])) |
| dense_strides = torch.stack([ |
| torch.tensor(a_splits[i].stride()), |
| torch.tensor(b_splits[i].stride()), |
| torch.tensor(c_splits[i].stride()) |
| ]) |
| self.assertEqual(nt_split._nested_tensor_strides(), dense_strides) |
| self.assertFalse(nt_split.is_contiguous()) |
| |
| # Failure calling on ragged dimensions |
| self.assertRaisesRegex( |
| RuntimeError, "split_with_sizes for nested tensors is currently only supported for the last dimension.", |
| lambda: torch.split_with_sizes(nt, split_sizes, dim=1)) |
| |
| # Failure calling on non-last dimension |
| self.assertRaisesRegex( |
| RuntimeError, "split_with_sizes for nested tensors is currently only supported for the last dimension.", |
| lambda: torch.split_with_sizes(nt, split_sizes, dim=0)) |
| |
| # Failure on non-contiguous nt |
| _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) |
| self.assertRaisesRegex( |
| RuntimeError, "split_with_sizes expects `self` to be contiguous.", |
| lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1)) |
| |
| # Failure when calling with split_sizes that don't cover the full dim size |
| bad_split_sizes = [4, 6, 9] # don't add up to 20 |
| self.assertRaisesRegex( |
| RuntimeError, "split_with_sizes expects split_sizes to sum exactly to 20", |
| lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1)) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| @torch.inference_mode() |
| def test_nested_tensor_indexing_noncontiguous(self, device, dtype): |
| nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) |
| self.assertEqual(nt_contiguous.size(0), nt_noncontiguous.size(0)) |
| n = nt_contiguous.size(0) |
| for i in range(n): |
| self.assertEqual(nt_contiguous[i], nt_noncontiguous[i]) |
| |
| @dtypes(torch.float, torch.float16) |
| @skipMeta |
| @torch.inference_mode() |
| @parametrize("transpose", [True, False]) |
| def test_nested_tensor_add(self, device, dtype, transpose): |
| if transpose: |
| a = torch.randn(2, 2, 2, device=device, dtype=dtype) |
| b = torch.rand(2, 2, 2, device=device, dtype=dtype) |
| c = a.transpose(-1, -2).contiguous() |
| d = b.transpose(-1, -2).contiguous() |
| nt1 = torch.nested.nested_tensor([a, b, a, b]) |
| nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2) |
| else: |
| (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) |
| ref = torch.nested.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) |
| out = nt1 + nt2 |
| self.assertEqual(ref, out) |
| |
| @dtypes(torch.float, torch.float16) |
| @skipMeta |
| @torch.inference_mode() |
| @parametrize("transpose", [True, False]) |
| def test_nested_tensor_sub(self, device, dtype, transpose): |
| if transpose: |
| a = torch.randn(2, 2, 2, device=device, dtype=dtype) |
| b = torch.rand(2, 2, 2, device=device, dtype=dtype) |
| c = a.transpose(-1, -2).contiguous() |
| d = b.transpose(-1, -2).contiguous() |
| nt1 = torch.nested.nested_tensor([a, b, a, b]) |
| nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2) |
| else: |
| (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) |
| ref = torch.nested.nested_tensor([t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) |
| out = nt1 - nt2 |
| self.assertEqual(ref, out) |
| |
| @onlyCUDA |
| @dtypes(torch.float, torch.float16) |
| @torch.inference_mode() |
| @parametrize("embedding_dim", [8, 128, 256, 384]) |
| def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim): |
| def _test_add_mul(nt, t): |
| ref_add = torch.nested.nested_tensor( |
| [t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]) |
| ref_mul = torch.nested.nested_tensor( |
| [t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]) |
| self.assertEqual(nt.add(t), ref_add) |
| self.assertEqual(nt.mul(t), ref_mul) |
| |
| batch_size = 32 |
| seq_lens = torch.randint(low=0, high=10, size=(batch_size,)) |
| |
| # [B, *, D], [B, 1, D] case |
| ts = [torch.randn((seq_len, embedding_dim)) for seq_len in seq_lens] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) |
| t = torch.randn((batch_size, 1, embedding_dim), device=device, dtype=dtype) |
| _test_add_mul(nt, t) |
| |
| # [B, *], [B, 1] case |
| ts = [torch.randn(seq_len) for seq_len in seq_lens] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype) |
| t = torch.randn((batch_size, 1), device=device, dtype=dtype) |
| _test_add_mul(nt, t) |
| |
| @dtypes(torch.float, torch.float16) |
| @skipMeta |
| @torch.inference_mode() |
| def test_nested_tensor_mul(self, device, dtype): |
| # nested tensor * nested tensor |
| (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) |
| ref = torch.nested.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) |
| out = nt1 * nt2 |
| self.assertEqual(ref, out) |
| # nested tensor * scalar |
| number = 10.0 |
| scalar = torch.tensor(number).to(dtype).to(device) |
| ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()]) |
| out_number0 = nt1 * number |
| out_number1 = number * nt1 |
| out_scalar0 = nt1 * scalar |
| out_scalar1 = scalar * nt1 |
| self.assertEqual(out_number0, ref) |
| self.assertEqual(out_number1, ref) |
| self.assertEqual(out_scalar0, ref) |
| self.assertEqual(out_scalar1, ref) |
| # error case: numel == 1 but dim > 0 |
| vector = torch.tensor([number]).to(dtype).to(device) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Expected both self and other to be nested, but got a nested self and non-nested other", |
| lambda: nt1.mul(vector) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Expected both self and other to be nested, but got a non-nested self and nested other", |
| lambda: vector.mul(nt1) |
| ) |
| |
| @dtypes(torch.float, torch.float16) |
| @skipMeta |
| @torch.inference_mode() |
| def test_nested_tensor_div(self, device, dtype): |
| nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4)) |
| scale = 4.0 |
| ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()]) |
| out = nt / 4.0 |
| self.assertEqual(ref, out) |
| ref_transposed = ref.transpose(1, 2) |
| out = nt.transpose(1, 2) / 4.0 |
| self.assertEqual(ref_transposed, out) |
| |
| ref = torch.nested.nested_tensor([t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())]) |
| out = nt / nt2 |
| self.assertEqual(ref, out) |
| |
| out = nt.transpose(1, 2) / nt2.transpose(1, 2) |
| self.assertEqual(ref.transpose(1, 2), out) |
| |
| nt_transpose_copy = torch.nested.nested_tensor([t.transpose(0, 1) for t in nt.unbind()]) |
| |
| self.assertRaisesRegex( |
| RuntimeError, "div requires strides to match when given NestedTensors", |
| lambda: nt_transpose_copy.transpose(1, 2) / nt2) |
| |
| nt = torch.nested.nested_tensor([torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype) |
| nt_chunks = nt.chunk(2, -1) |
| self.assertRaisesRegex( |
| RuntimeError, "div requires offsets to match when given NestedTensors", |
| lambda: nt_chunks[0] / nt_chunks[1]) |
| |
| @dtypes(torch.float, torch.float16) |
| @skipMeta |
| @torch.inference_mode() |
| def test_nested_tensor_add_in_place(self, device, dtype): |
| (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) |
| ref = torch.nested.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) |
| nt1 += nt2 |
| self.assertEqual(ref, nt1) |
| |
| @dtypes(torch.float, torch.float16) |
| @skipMeta |
| @torch.inference_mode() |
| def test_nested_tensor_mul_in_place(self, device, dtype): |
| # nested tensor * nested tensor |
| (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4)) |
| ref = torch.nested.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]) |
| nt1 *= nt2 |
| self.assertEqual(ref, nt1) |
| # nested tensor * scalar |
| number = 10.0 |
| scalar = torch.tensor(number).to(dtype).to(device) |
| ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()]) |
| out_number = nt1.clone() |
| out_number *= number |
| out_scalar = nt1.clone() |
| out_scalar *= scalar |
| self.assertEqual(out_number, ref) |
| self.assertEqual(out_scalar, ref) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]", |
| lambda: scalar.mul_(nt1) |
| ) |
| # error case: numel == 1 but dim > 0 |
| vector = torch.tensor([number]).to(dtype).to(device) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Expected both self and other to be nested, but got a nested self and non-nested other", |
| lambda: nt1.mul_(vector) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Expected both self and other to be nested, but got a non-nested self and nested other", |
| lambda: vector.mul_(nt1) |
| ) |
| |
| @onlyCPU |
| @skipMeta |
| @dtypes(torch.float) |
| def test_nested_tensor_sum_dim(self, device, dtype): |
        params = ((2, (1, 1)), (4, (4, 4)), (10, (3, 5, 7)))
| |
| def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True): |
| nt = random_nt(device, dtype, ntensors, max_sizes, require_non_empty=False) |
| nt2 = nt.clone() |
| ub2 = nt2.unbind() |
| nt.requires_grad_(True) |
            for t in ub2:
                t.requires_grad_(True)
| nt_sum = nt.sum(dim=dim, keepdim=keepdim) |
| ub2_sum = [t.sum(-1, keepdim=keepdim) for t in ub2] |
| self.assertEqual(nt_sum, torch.nested.nested_tensor(ub2_sum)) |
| |
| # test backward |
| # generate gradient tensor that has the same size as the output |
| size = nt_sum._nested_tensor_size() |
| gt2 = [] |
| for i in range(ntensors): |
| gt2.append(torch.randn(size[i].tolist(), device=device, dtype=dtype)) |
| gt = torch.nested.nested_tensor(gt2).clone() |
| nt_sum.backward(gt) |
| for t2, g2 in zip(ub2_sum, gt2): |
| t2.backward(g2) |
| self.assertEqual(nt.grad, torch.nested.nested_tensor([t.grad for t in ub2])) |
| |
| for ntensors, max_sizes in params: |
| test_sum(device, dtype, ntensors, max_sizes, len(max_sizes)) |
| |
| # Test error inputs |
| with self.assertRaisesRegex(RuntimeError, "NestedTensor can only be reduced across the last"): |
| torch.nested.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(0, keepdim=True) |
| |
| with self.assertRaisesRegex(RuntimeError, "NestedTensor only allows reduction of a single"): |
| torch.nested.nested_tensor([torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])]).sum([0, 1], keepdim=True) |
| |
| with self.assertRaisesRegex(RuntimeError, "NestedTensor always requires keepdim=True for now."): |
| torch.nested.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(-1) |
| |
| @dtypes(torch.float, torch.float16) |
| def test_contiguous(self, device, dtype): |
        # Since we don't have access to the buffer in Python, this is harder to show
        # what we are testing for. When we call chunk on a consistent dim of an NT
        # with chunk_size > 1, the resulting tensors are views of the original NT
        # whose numel is now less than the size of the buffer. clone() previously
        # created a new NT with a buffer the same size as the original.
| nt_contiguous = torch.nested.nested_tensor([torch.randn(2, 20, device=device, dtype=dtype), |
| torch.randn(4, 20, device=device, dtype=dtype)]) |
| # Split up the last dimension which has a consistent size of 20 into 5 chunks |
| chunks = nt_contiguous.chunk(5, dim=-1) |
| |
        # Check chunks are contiguous after calling contiguous()
| for chunk in chunks: |
| self.assertFalse(chunk.is_contiguous()) |
| self.assertTrue(chunk.contiguous().is_contiguous()) |
| |
| @dtypes(torch.float, torch.float16) |
| @skipMeta |
| def test_clone(self, device, dtype): |
| nt1 = random_nt(device, dtype, 4, (4, 4), (1, 1)) |
| nt2 = nt1.clone() |
| # Verify the values match |
| self.assertEqual(nt1, nt2) |
| # Verify modifying nt2 doesn't affect nt1 |
| nt2.mul_(nt1) |
| ub1 = nt1.unbind() |
| ub2 = nt2.unbind() |
| for i in range(len(ub1)): |
| self.assertNotEqual(ub1[i], ub2[i]) |
| |
| nt1.clone(memory_format=torch.preserve_format) |
| msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast" |
| with self.assertRaisesRegex(RuntimeError, msg): |
| nt1.clone(memory_format=torch.channels_last) |
| |
| # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half' |
| @decorateIf(xfailIfTorchDynamo, lambda params: params["layout"] == torch.jagged) |
| @dtypes(torch.float, torch.double) |
| @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name) |
| def test_dropout(self, device, dtype, layout): |
| # edge case: empty nested tensor |
| # TODO: support empty NT in jagged layout |
| if layout == torch.strided: |
| nt0 = torch.nested.nested_tensor([], layout=layout) |
| y = torch.nn.functional.dropout(nt0, 0.5) |
| self.assertEqual(nt0, y) |
| # normal nested tensor |
| ntensors = 4 |
| if layout == torch.jagged: |
| nt = random_nt(device, dtype, ntensors, (4, 4), (0, 3), layout=layout) |
| else: |
| nt = random_nt(device, dtype, ntensors, (4, 4), layout=layout) |
| # edge case: invalid dropout |
| self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1)) |
| self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1)) |
| self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1)) |
| self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1)) |
| # edge case: no dropout |
| dropouter = torch.nn.Dropout(0.0) |
| y0 = dropouter(nt) |
| y1 = torch.nn.functional.dropout(nt, 0.0) |
| self.assertEqual(nt, y0) |
| self.assertEqual(nt, y1) |
| # edge case: all dropout |
| dropouter = torch.nn.Dropout(1.0) |
| y0 = dropouter(nt) |
| y1 = torch.nn.functional.dropout(nt, 1.0) |
| nt0 = torch.zeros_like(nt) |
| self.assertEqual(nt0, y0) |
| self.assertEqual(nt0, y1) |
| # normal case: normal dropout |
| p = 0.2 |
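        # Elements that survive dropout are scaled by 1 / (1 - p); zeros in the output mark the dropped positions.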
| y = torch.nn.functional.dropout(nt, p) |
        if layout == torch.jagged:
            expect = torch.where(y == 0.0, y, nt)
            expect /= 1.0 - p
            self.assertEqual(y, expect)
        else:
            expect = nt.clone()
| for i in range(ntensors): |
| actual_tensor = y[i].view(-1) |
| expect_tensor = expect[i].view(-1) |
| for j in range(actual_tensor.shape[0]): |
| if actual_tensor[j].item() == 0.0: |
| expect_tensor[j] = 0.0 |
| else: |
| expect_tensor[j] /= 1.0 - p |
| self.assertEqual(y, expect) |
| with freeze_rng_state(): |
| dropouter = torch.nn.Dropout(p) |
| y0 = dropouter(nt) |
| with freeze_rng_state(): |
| y1 = torch.nn.functional.dropout(nt, p) |
| self.assertEqual(y0, y1) |
| |
| @dtypes(torch.float, torch.double) |
| def test_dropout_noncontiguous(self, device, dtype): |
| ntensors = 4 |
| nt0 = random_nt(device, dtype, ntensors, (4, 4)) |
| nt1 = nt0.transpose(-1, -2) |
| p = 0.3 |
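        # With the RNG state frozen, both layouts should draw the same dropout mask, so the results must match after undoing the transpose.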
| with freeze_rng_state(): |
| dropouter = torch.nn.Dropout(p) |
| y0 = dropouter(nt0) |
| with freeze_rng_state(): |
| y1 = torch.nn.functional.dropout(nt1, p).transpose(-1, -2) |
| self.assertEqual(y0, y1) |
| |
| # cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half' |
| @dtypes(torch.float, torch.double) |
| def test_softmax(self, device, dtype): |
| # normal nested tensor |
| ntensors = 4 |
| nt = random_nt(device, dtype, ntensors, (4, 4)) |
| # error case: softmax across nested dimension |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Cannot apply softmax across nested dimension 0", |
| lambda: torch.nn.functional.softmax(nt, 0) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Cannot apply softmax across nested dimension 0", |
| lambda: torch.nn.functional.softmax(nt, -3) |
| ) |
| # error case: dimension out of range |
| self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3)) |
| self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4)) |
        # normal case: should match softmax on the tensor padded with -inf
| softmaxer = torch.nn.Softmax(1) |
| y0 = softmaxer(nt) |
| y1 = torch.nn.functional.softmax(nt, 1) |
| self.assertEqual(y0, y1) |
| pt = torch.nested.to_padded_tensor(nt, float("-inf")) |
        # if an entire slice is padded, then softmax will return 0.0 / 0.0 = nan;
        # however, semantically that should be 0.0
| expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0) |
| self.assertEqual(torch.nested.to_padded_tensor(y0, 0.0), expect) |
| # edge case: empty nested tensor |
| nt0 = torch.nested.nested_tensor([]) |
| y = torch.nn.functional.softmax(nt0, 1) |
| self.assertEqual(nt0, y) |
| # edge case: nesting scalars |
| nt1 = torch.nested.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)]) |
| self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0)) |
| self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1)) |
| |
| @dtypes(torch.float, torch.double) |
| @torch.inference_mode() |
| def test_softmax_noncontiguous(self, device, dtype): |
| nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) |
| self.assertEqual( |
| torch.nn.functional.softmax(nt_contiguous, -1), |
| torch.nn.functional.softmax(nt_noncontiguous, -1)) |
| |
| def _test_bmm(self, device, dtype): |
| # error case: one is nested but the other is not |
| nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) |
| t = torch.randn(4, device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Expected both to be nested, but got a nested self and non-nested other", |
| lambda: nt.bmm(t) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Expected both to be nested, but got a non-nested self and nested other", |
| lambda: t.bmm(nt) |
| ) |
| # error case: not 3D tensors |
| nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype) |
| nt1 = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) |
| nt2 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "batch1 must be a 3D tensor", |
| lambda: nt0.bmm(nt0) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "batch1 must be a 3D tensor", |
| lambda: nt0.bmm(nt1) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "batch1 must be a 3D tensor", |
| lambda: nt0.bmm(nt2) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "batch1 must be a 3D tensor", |
| lambda: nt1.bmm(nt0) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "batch1 must be a 3D tensor", |
| lambda: nt1.bmm(nt1) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "batch1 must be a 3D tensor", |
| lambda: nt1.bmm(nt2) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "batch2 must be a 3D tensor", |
| lambda: nt2.bmm(nt0) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "batch2 must be a 3D tensor", |
| lambda: nt2.bmm(nt1) |
| ) |
| # error case: incompatible batch size |
| nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) |
| nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), |
| torch.randn((4, 5)), |
| torch.randn((4, 7))], |
| device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.", |
| lambda: nt0.bmm(nt1) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.", |
| lambda: nt1.bmm(nt0) |
| ) |
| # error case: underlying matrices cannot be multiplied |
| nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)", |
| lambda: nt0.bmm(nt0) |
| ) |
| # normal nested tensor |
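        # bmm on nested tensors should match bmm on their zero-padded equivalents, since zero padding contributes nothing to each matrix product.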
| nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype) |
| nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype) |
| actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) |
| expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0)) |
| if dtype == torch.float16: |
| self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) |
| else: |
| self.assertEqual(actual, expect) |
| |
        # test the tensor core path (dims here are multiples of 8)
| nt0 = torch.nested.nested_tensor([torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype) |
| nt1 = torch.nested.nested_tensor([torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype) |
| actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0) |
| expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0)) |
| if dtype == torch.float16: |
| self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3) |
| else: |
| self.assertEqual(actual, expect) |
| |
| @onlyCUDA |
| @dtypes(torch.float, torch.double, torch.float16) |
| def test_bmm_cuda(self, device, dtype): |
| self._test_bmm(device, dtype) |
| |
| @onlyCPU |
| # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' |
| @dtypes(torch.float, torch.double) |
| def test_bmm_cpu(self, device, dtype): |
| self._test_bmm(device, dtype) |
| |
| # cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half' |
| @dtypes(torch.float, torch.double) |
| def test_bmm_noncontiguous(self, device, dtype): |
| nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) |
| nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype) |
| self.assertEqual( |
| nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous), |
| nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous)) |
| |
| @dtypes(torch.float, torch.double) |
| def test_matmul_with_bmm_path(self, device, dtype): |
| def unbind_rebind_matmul(nt1, nt2): |
| t1s = nt1.unbind() |
| t2s = nt2.unbind() |
| out_ts = [t1.matmul(t2) for t1, t2 in zip(t1s, t2s)] |
| return torch.nested.nested_tensor(out_ts) |
| |
| # [N, n_head, *, head_dim], [N, n_head, head_dim, *] |
| Ns = [1, 2, 5] |
| n_heads = np.random.randint(2, 5) |
| head_dim = 3 |
| t1s = [] |
| t2s = [] |
| for N in Ns: |
| for _ in range(N): |
| seq_len1 = np.random.randint(2, 5) |
| seq_len2 = np.random.randint(2, 5) |
| t1s.append(torch.randn(n_heads, seq_len1, head_dim)) |
| t2s.append(torch.randn(n_heads, head_dim, seq_len2)) |
| nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype) |
| nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype) |
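        # matmul over nested tensors should match the unbind/matmul/re-nest reference above.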
| self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2)) |
| |
| # test with noncontiguous |
| t3s = [] |
| t4s = [] |
| for _ in range(N): |
| seq_len = np.random.randint(2, 5) |
| t3s.append(torch.randn(seq_len, n_heads, head_dim)) |
| t4s.append(torch.randn(seq_len, n_heads, head_dim)) |
| nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(1, 2) |
| nt4 = torch.nested.nested_tensor(t4s, device=device, dtype=dtype).transpose(1, 2).transpose(2, 3) |
| self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4)) |
| |
| # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' |
| @dtypes(torch.float, torch.double) |
| def test_matmul(self, device, dtype): |
| # error case: one is nested but the other is not |
| nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) |
| t = torch.randn(4, device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Expected both to be nested, but got a nested self and non-nested other", |
| lambda: torch.matmul(nt, t) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Expected both to be nested, but got a non-nested self and nested other", |
| lambda: torch.matmul(t, nt) |
| ) |
| # error case: not 3+D tensors |
| nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype) |
| nt1 = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype) |
| nt2 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", |
| lambda: torch.matmul(nt0, nt0) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", |
| lambda: torch.matmul(nt0, nt1) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", |
| lambda: torch.matmul(nt0, nt2) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", |
| lambda: torch.matmul(nt1, nt0) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", |
| lambda: torch.matmul(nt1, nt1) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+", |
| lambda: torch.matmul(nt1, nt2) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+", |
| lambda: torch.matmul(nt2, nt0) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+", |
| lambda: torch.matmul(nt2, nt1) |
| ) |
| # error case: incompatible batch size |
| nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) |
| nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), |
| torch.randn((4, 5)), |
| torch.randn((4, 7))], |
| device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.", |
| lambda: torch.matmul(nt0, nt1) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.", |
| lambda: torch.matmul(nt1, nt0) |
| ) |
        # error case: mismatched batch sizes that would not even broadcast
| nt0 = torch.nested.nested_tensor([torch.randn((2, 2, 4)), |
| torch.randn((2, 3, 4))], |
| device=device, dtype=dtype) |
| nt1 = torch.nested.nested_tensor([torch.randn((3, 4, 6)), |
| torch.randn((3, 4, 5))], |
| device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "matmul(): For nested tensors, batch dimensions must have the same sizes,", |
| lambda: torch.matmul(nt0, nt1) |
| ) |
| # error case: incompatible batch sizes that should technically broadcast |
| nt0 = torch.nested.nested_tensor([torch.randn((2, 2, 4)), |
| torch.randn((1, 3, 4))], |
| device=device, dtype=dtype) |
| nt1 = torch.nested.nested_tensor([torch.randn((1, 4, 6)), |
| torch.randn((3, 4, 5))], |
| device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "matmul(): For nested tensors, batch dimensions must have the same sizes,", |
| lambda: torch.matmul(nt0, nt1) |
| ) |
| # error case: underlying matrices cannot be multiplied |
| nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "matmul(): Nested tensors cannot be matrix multiplied", |
| lambda: torch.matmul(nt0, nt0) |
| ) |
| # normal nested tensor: 3D |
| nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype) |
| nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype) |
| actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) |
| expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0)) |
| self.assertEqual(actual, expect) |
| # normal nested tensor: 4D (with testing for batch_size=1) |
| nt0 = torch.nested.nested_tensor([torch.randn((1, 2, 4)), |
| torch.randn((8, 3, 7))], |
| device=device, dtype=dtype) |
| nt1 = torch.nested.nested_tensor([torch.randn((1, 4, 6)), |
| torch.randn((8, 7, 5))], |
| device=device, dtype=dtype) |
| actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) |
| expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0)) |
| self.assertEqual(actual, expect) |
| # normal nested tensor: 5D |
| nt0 = torch.nested.nested_tensor([torch.randn((8, 9, 2, 4)), |
| torch.randn((8, 9, 3, 7))], |
| device=device, dtype=dtype) |
| nt1 = torch.nested.nested_tensor([torch.randn((8, 9, 4, 6)), |
| torch.randn((8, 9, 7, 5))], |
| device=device, dtype=dtype) |
| actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0) |
| expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0)) |
| self.assertEqual(actual, expect) |
| |
| # only supported on CUDA for now |
| @dtypes(torch.float, torch.double) |
| def test_matmul_nt_with_broadcasted_t(self, device, dtype): |
| # NT (B, *, C, D) with T (D, E) broadcasting case |
| nt = random_nt_from_dims([3, None, 4, 5], device=device, dtype=dtype) |
| t = torch.randn(5, 6, device=device, dtype=dtype) |
| output = torch.matmul(nt, t) |
| |
| # should be equivalent to matmul-ing each component with the dense tensor |
| self.assertEqual(nt.size(0), output.size(0)) |
| for component, out_component in zip(nt, output): |
| self.assertEqual(out_component, torch.matmul(component, t)) |
| |
| # cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half' |
| @dtypes(torch.float, torch.double) |
| def test_matmul_noncontiguous(self, device, dtype): |
| nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype) |
| nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype) |
| self.assertEqual( |
| torch.matmul(nt0_contiguous.transpose(-1, -2), nt1_contiguous), |
| torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous)) |
| |
| @dtypes(torch.float, torch.double) |
| def test_linear(self, device, dtype): |
| a = torch.randn(1, 2, device=device, dtype=dtype) |
| b = torch.randn(2, 2, device=device, dtype=dtype) |
| c = torch.randn(3, 2, device=device, dtype=dtype) |
| nt = torch.nested.nested_tensor([a, b, c]) |
| |
| weight = torch.randn(2, 2, device=device, dtype=dtype) |
| bias = torch.randn(2, device=device, dtype=dtype) |
| # success case |
| torch.functional.F.linear(nt, weight, bias) |
| |
| # invalid nested tensor dimension |
| msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2' |
| nt1 = torch.nested.nested_tensor([torch.randn(1, device=device, dtype=dtype), |
| torch.randn(2, device=device, dtype=dtype)]) |
| with self.assertRaisesRegex(RuntimeError, msg): |
| torch.functional.F.linear(nt1, weight, bias) |
| |
| # invalid weight shape |
| msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3' |
| weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype) |
| with self.assertRaisesRegex(RuntimeError, msg): |
| torch.functional.F.linear(nt, weight1, bias) |
| |
| # inconsistent last dim of nested tensor |
| msg = r"Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals:" |
| nt2 = torch.nested.nested_tensor([torch.randn(1, 2, device=device, dtype=dtype), |
| torch.randn(2, 3, device=device, dtype=dtype)]) |
| with self.assertRaisesRegex(RuntimeError, msg): |
| torch.functional.F.linear(nt2, weight, bias) |
| |
| # Mismatch of nested tensor last dim and weight dimension |
| weight2 = torch.randn(2, 4, device=device, dtype=dtype) |
| msg = r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'" \ |
| r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4" |
| with self.assertRaisesRegex(RuntimeError, msg): |
| torch.functional.F.linear(nt, weight2, bias) |
| |
| # Nested tensor input and nested weight |
| nt_weight = nt.clone() |
| msg = r"Linear does not support nested weight when input is a nested tensor." |
| with self.assertRaisesRegex(RuntimeError, msg): |
| torch.functional.F.linear(nt, nt_weight, bias) |
| |
| # TODO: test noncontiguous linear |
| # For now this tests the error message of linear |
| # since linear does not support noncontiguous buffer yet |
| @dtypes(torch.float, torch.double) |
| def test_linear_noncontiguous(self, device, dtype): |
| nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype) |
| weight = torch.randn((8, 5), device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"for now linear only supports contiguous nested tensor", |
| lambda: torch.nn.functional.linear(nt_noncontiguous, weight) |
| ) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_to_padded_tensor_zero_numel_errors(self, device, dtype): |
| ts = [torch.ones(1, 0), torch.ones(0, 0)] |
| nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype, layout=torch.strided) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"at least one constituent tensor should have non-zero numel", |
| lambda: torch.nested.to_padded_tensor(nt, 0.0) |
| ) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_transpose(self, device, dtype): |
| nt = random_nt(device, dtype, 4, (4, 4)) |
| # error case: transpose nested dimension |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Nested tensor dimension 0 cannot be transposed", |
| lambda: nt.transpose(0, 1) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "Nested tensor dimension 0 cannot be transposed", |
| lambda: nt.transpose(1, -3) |
| ) |
| # error case: dimension out of range |
| self.assertRaises(IndexError, lambda: nt.transpose(1, 3)) |
| self.assertRaises(IndexError, lambda: nt.transpose(-4, -1)) |
| # normal case |
| ntT = nt.transpose(-1, -2) |
| ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) |
| pt = torch.nested.to_padded_tensor(nt, 0.0) |
| ptT = pt.transpose(-1, -2) |
| self.assertEqual(ptT, ptT_from_ntT) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_squeeze_unsqueeze(self, device, dtype): |
| a = torch.arange(6).reshape(2, 3) |
| b = torch.arange(15).reshape(5, 3) |
| nt = torch.nested.nested_tensor([a, b], device=device, dtype=dtype) |
| # error case: squeeze no dimension |
| self.assertRaisesRegex( |
| RuntimeError, |
| "For nested tensors, squeeze without the dim argument", |
| lambda: nt.squeeze() |
| ) |
| # error case: squeeze nested dimension |
| self.assertRaisesRegex( |
| RuntimeError, |
| "For nested tensors, squeezing dimension 0", |
| lambda: nt.squeeze(0) |
| ) |
| # error case: dimension out of range |
| self.assertRaises(IndexError, lambda: nt.squeeze(3)) |
| # error case: squeeze nested tensor of singleton tensors |
| c = torch.ones(1) |
| nt_singleton = torch.nested.nested_tensor([c, c], device=device, dtype=dtype) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "For nested tensors, squeezing a nested tensor of singleton", |
| lambda: nt_singleton.squeeze(1) |
| ) |
| |
| # squeezing a dim which does not have size 1 should be a no-op |
| nt2 = nt.squeeze(-1) |
| self.assertEqual(nt, nt2) |
| |
| # test cases that should work |
| nt_sizes = nt._nested_tensor_size() |
| nt_strides = nt._nested_tensor_strides() |
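        # Round-trip: unsqueeze at each valid dim, verify the size/stride metadata of the result, then squeeze back and expect the original metadata.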
| for i in range(-2, 4): |
            if i == 0:
| # cannot unsqueeze batch dim |
| continue |
| nt_unsqueezed = nt.unsqueeze(i) |
| # negative dim will correspond to unsqueeze() applied at dim = dim + nt.dim() + 1 |
| wrapped_i = i + nt.dim() + 1 if i < 0 else i |
            # the column index into the nt size tensor requires subtracting 1 to ignore the batch dim
| size_idx = wrapped_i - 1 |
| self.assertEqual(nt_unsqueezed._nested_tensor_size()[:, size_idx], torch.ones(2, dtype=torch.long)) |
| unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx] |
            if i == nt.ndim or i == -1:
| self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long)) |
| else: |
| stride_col_after = nt_strides[:, size_idx] |
| size_col_after = nt_sizes[:, size_idx] |
| self.assertEqual(unsqueezed_stride, stride_col_after * size_col_after) |
| nt_squeezed = nt_unsqueezed.squeeze(i) |
| self.assertEqual(nt_squeezed, nt) |
| self.assertEqual(nt_squeezed._nested_tensor_size(), nt_sizes) |
| self.assertEqual(nt_squeezed._nested_tensor_strides(), nt_strides) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_transpose_inference_mode_interaction(self, device, dtype): |
| nt = random_nt(device, dtype, 4, (4, 4)) |
| # Construct in default mode and transpose while in inference mode |
| with torch.inference_mode(): |
| ntT = nt.transpose(-1, -2) |
| ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) |
| pt = torch.nested.to_padded_tensor(nt, 0.0) |
| ptT = pt.transpose(-1, -2) |
| self.assertEqual(ptT, ptT_from_ntT) |
| |
| # Construct and transpose while in inference mode |
| with torch.inference_mode(): |
| nt = random_nt(device, dtype, 4, (4, 4)) |
| ntT = nt.transpose(-1, -2) |
| ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) |
| pt = torch.nested.to_padded_tensor(nt, 0.0) |
| ptT = pt.transpose(-1, -2) |
| self.assertEqual(ptT, ptT_from_ntT) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_view(self, device, dtype): |
| nt = random_nt(device, dtype, 4, (4, 4)) |
| # error case: empty shape |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"shape '\[\]' is invalid for a nested tensor", |
| lambda: nt.view(()) |
| ) |
| # error case: empty nested tensor |
| nt_empty = torch.nested.nested_tensor([]) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "empty nested tensor cannot be reshaped", |
| lambda: nt_empty.view(-1) |
| ) |
| # error case: -1 for batch size |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"view: For now nested view cannot change or infer the implicit batch dimension", |
| lambda: nt.view(-1, 2, 3) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"shape '\[.*\]' is invalid for input of size [0-9]+", |
| lambda: nt.view(4, 2, 3) |
| ) |
| # normal case |
| x0 = torch.randn((2, 20), device=device, dtype=dtype) |
| x1 = torch.randn((3, 20), device=device, dtype=dtype) |
| nt = torch.nested.nested_tensor([x0, x1]) |
| pt = torch.nested.to_padded_tensor(nt, 0.0) |
        # error case: trying to change the implicit batch dim, even to an otherwise valid shape
| self.assertRaisesRegex( |
| RuntimeError, |
| r"For now nested view cannot change or infer the implicit batch dimension", |
| lambda: nt.transpose(-1, -2).view(40, -1) |
| ) |
| # inherit only the ragged dimension |
| # (2, 20) -> (2, 5, 4) |
| # (3, 20) -> (3, 5, 4) |
| nt1 = nt.view(2, -1, 5, 4) |
| # (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4) |
| pt1 = pt.view(2, -1, 5, 4) |
| self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1) |
| |
        # more than one -1 (even for "old" dims) should fail:
        # this attempts (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2),
        # but we ban "inherit the old size" for more than one dimension
| self.assertRaisesRegex( |
| RuntimeError, |
| r"only one dimension can be inferred", |
| lambda: nt1.view(2, -1, -1, 2, 2) |
| ) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_view_inference_mode_interaction(self, device, dtype): |
| # Construct in default mode and view while in inference mode |
| nt = torch.nested.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype) |
| with torch.inference_mode(): |
| ntT = nt.view(2, -1, 4, 5) |
| ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) |
| pt = torch.nested.to_padded_tensor(nt, 0.0) |
| ptT = pt.view(2, -1, 4, 5) |
| self.assertEqual(ptT, ptT_from_ntT) |
| # Construct and view while in inference mode |
| with torch.inference_mode(): |
| nt = torch.nested.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype) |
| ntT = nt.view(2, -1, 4, 5) |
| ptT_from_ntT = noncontiguous_to_padded_tensor(ntT) |
| pt = torch.nested.to_padded_tensor(nt, 0.0) |
| ptT = pt.view(2, -1, 4, 5) |
| self.assertEqual(ptT, ptT_from_ntT) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_reshape(self, device, dtype): |
| nt = random_nt(device, dtype, 4, (4, 4)) |
| # error case: empty shape |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"shape '\[\]' is invalid for a nested tensor", |
| lambda: nt.reshape(()) |
| ) |
| # error case: empty nested tensor |
| nt_empty = torch.nested.nested_tensor([]) |
| self.assertRaisesRegex( |
| RuntimeError, |
| "empty nested tensor cannot be reshaped", |
| lambda: nt_empty.reshape(-1) |
| ) |
| # error case: -1 for batch size |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"reshape: For now nested reshape cannot change or infer the implicit batch dimension", |
| lambda: nt.reshape(-1, 2, 3) |
| ) |
| self.assertRaisesRegex( |
| RuntimeError, |
| r"shape '\[.*\]' is invalid for input of size [0-9]+", |
| lambda: nt.reshape(4, 2, 3) |
| ) |
| # normal case |
| x0 = torch.randn((2, 20), device=device, dtype=dtype) |
| x1 = torch.randn((3, 20), device=device, dtype=dtype) |
| nt = torch.nested.nested_tensor([x0, x1]) # (2, (2, 3), 20) |
| pt = torch.nested.to_padded_tensor(nt, 0.0) |
        # error case: trying to change the implicit batch dim, even to an otherwise valid shape
| self.assertRaisesRegex( |
| RuntimeError, |
| r"reshape: For now nested reshape cannot change or infer the implicit batch dimension", |
| lambda: nt.transpose(-1, -2).reshape(40, -1) |
| ) |
| # inherit only the ragged dimension |
| # (2, 20) -> (2, 5, 4) |
| # (3, 20) -> (3, 5, 4) |
| nt1 = nt.reshape(2, -1, 5, 4) |
| # (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4) |
| pt1 = pt.reshape(2, -1, 5, 4) |
| self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1) |
| |
        # more than one -1 (even for "old" dims) should fail:
        # this attempts (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2),
        # but we ban "inherit the old size" for more than one dimension
| self.assertRaisesRegex( |
| RuntimeError, |
| r"only one dimension can be inferred", |
| lambda: nt1.reshape(2, -1, -1, 2, 2) |
| ) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_narrow(self, device, dtype): |
| nt = random_nt_from_dims([5, None, None, None], device=device, dtype=dtype) |
| |
| # narrow on dim=0 from start to end |
| bounds = [(0, 5), (0, 3), (1, 2), (1, 5), (2, 4)] |
| for start, end in bounds: |
| length = end - start |
| narrowed = nt.narrow(dim=0, start=start, length=length) |
| # ensure output is a view |
| self.assertTrue(narrowed._base is nt) |
| for nc, c in zip(narrowed.unbind(), nt.unbind()[start:end]): |
| self.assertEqual(nc, c) |
| |
| # dim != 0 is not supported |
| for dim in range(1, nt.dim()): |
| with self.assertRaisesRegex(RuntimeError, "only dim=0 supported for nested tensors"): |
| nt.narrow(dim=dim, start=0, length=1) |
| |
| # error case: non-contiguous NT |
| _, nt_noncont = random_nt_noncontiguous_pair((2, 3, 4)) |
| with self.assertRaisesRegex(RuntimeError, "only contiguous nested tensors supported"): |
| nt_noncont.narrow(dim=0, start=0, length=1) |
| |
| @parametrize("input_dim", [3, 4]) |
| def test_scaled_dot_product_attention(self, device, input_dim): |
| |
| def rand_tensor(*shape): |
| return torch.randn(shape, device=device) |
| |
| E = 8 |
| if input_dim == 3: |
| # Shape: (N, L, E); ragged L |
| query = torch.nested.nested_tensor([rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)]) |
| |
| # Shape: (N, S, E); ragged S |
| key = torch.nested.nested_tensor([rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]) |
| value = torch.nested.nested_tensor([rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]) |
| elif input_dim == 4: |
            # In the 4D case, N', L, and S are ragged
| # Shape: (N, N', L, E); ragged N' and L |
| query = torch.nested.nested_tensor([rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)]) |
| # Shape: (N, N', S, E); ragged N' and S |
| key = torch.nested.nested_tensor([rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]) |
| value = torch.nested.nested_tensor([rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]) |
| else: |
| self.fail(f"Invalid input_dim {input_dim} encountered in SDP test") |
| |
| def rand_mask(size): |
| return torch.randint(0, 2, size=size, dtype=torch.bool, device=device) |
| |
| # Shape: (N, L, S); ragged L and S matching above |
| attn_mask = torch.nested.nested_tensor([rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))]) |
| |
| dropout_p = 0.0 # no dropout for reproducibility |
| |
| # Success case: no attn_mask set and is_causal=False. |
| actual = torch.nn.functional.scaled_dot_product_attention( |
| query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p) |
| |
| expected_outputs = [] |
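        # Reference: run SDPA on each constituent independently (adding a singleton batch dim) and re-nest the results.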
| for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()): |
| output = torch.nn.functional.scaled_dot_product_attention( |
| q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_mask=None, dropout_p=dropout_p) |
| expected_outputs.append(output.squeeze(0)) |
| expected_output_nested = torch.nested.nested_tensor(expected_outputs) |
| self.assertEqual(actual, expected_output_nested) |
| |
| # Error case: explicit attn_mask set. |
| with self.assertRaisesRegex(RuntimeError, "not supported when an explicit attn_mask is set"): |
| torch.nn.functional.scaled_dot_product_attention( |
| query, key, value, attn_mask=attn_mask, dropout_p=dropout_p) |
| |
| # Error case: is_causal=True. |
| with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"): |
| torch.nn.functional.scaled_dot_product_attention( |
| query, key, value, dropout_p=dropout_p, is_causal=True) |
| |
| @dtypes(torch.float, torch.float16, torch.double) |
| def test_empty_like(self, device, dtype): |
| ntensors = 4 |
| nt = random_nt(device, dtype, ntensors, (4, 4)) |
| |
| # Create empty on same device as original nested tensor |
| nt_empty = torch.empty_like(nt) |
| assert nt.is_same_size(nt_empty) |
| self.assertEqual(nt.dtype, nt_empty.dtype) |
| self.assertEqual(nt.device, nt_empty.device) |
| self.assertEqual(nt.layout, nt_empty.layout) |
| |
| if torch.cuda.is_available(): |
| if device == "cpu": |
| nt_cuda = torch.empty_like(nt, device='cuda') |
| self.assertEqual(torch.device("cuda").type, nt_cuda.device.type) |
| else: |
| nt_cpu = torch.empty_like(nt, device='cpu') |
| self.assertEqual(torch.device("cpu").type, nt_cpu.device.type) |
| |
| # Check changing dtype of empty_like nested tensor output |
| dtype_set = {torch.float, torch.float16, torch.double} |
| for other_dtype in dtype_set - {dtype}: |
| nt_empty_other_dtype = torch.empty_like(nt, dtype=other_dtype) |
| self.assertEqual(nt.dtype, dtype) |
| self.assertEqual(nt_empty_other_dtype.dtype, other_dtype) |
            self.assertEqual(nt.device, nt_empty_other_dtype.device)
            self.assertEqual(nt.layout, nt_empty_other_dtype.layout)
| |
| # Create tensor for autograd |
| nt_empty_req_grad = torch.empty_like(nt, requires_grad=True) |
| self.assertEqual(nt_empty_req_grad.requires_grad, True) |
| |
| # Test noncontiguous tensor does not fail to copy |
| nt_cont, nt_noncont = random_nt_noncontiguous_pair((2, 3, 6, 7)) |
| nt_empty = torch.empty_like(nt_cont) |
| assert nt_cont.is_same_size(nt_empty) |
| nt_empty_non_contig = torch.empty_like(nt_noncont) |
| assert nt_noncont.is_same_size(nt_empty_non_contig) |
| |
| # Test the contiguous memory format option |
| nt_empty_contig = torch.empty_like(nt_cont, memory_format=torch.contiguous_format) |
| assert nt_cont.is_same_size(nt_empty_contig) |
| assert nt_empty_contig.is_contiguous() |
| |
| nt_empty_non_contig = torch.empty_like(nt_noncont, memory_format=torch.contiguous_format) |
| assert nt_noncont.is_same_size(nt_empty_non_contig) |
| assert nt_empty_non_contig.is_contiguous() |
| |
| # Test other memory formats fail |
| self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last)) |
| self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last)) |
| self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d)) |
| self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d)) |
| |
| @markDynamoStrictTest |
| class TestNestedTensorAutograd(TestCase): |
    # Note [Gradcheck args check_batched_grad=False]: the common_utils testing version of
    # gradcheck includes the default parameters used for testing ops with gradcheck. However,
    # nested tensors do not support the stack op, so we turn batched-grad checking off for
    # these tests.
| def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False): |
| return torch.nested.nested_tensor([torch.randn(1, 2,), |
| torch.randn(7, 8)], requires_grad=requires_grad, device=tensor_device) |
| |
| def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False): |
| return torch.nested.as_nested_tensor([torch.randn(1, 2, requires_grad=requires_grad), |
| torch.randn(7, 8, requires_grad=requires_grad)], device=tensor_device) |
| |
| def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False): |
| data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device) |
| mask = torch.ones_like(data[:, :, 0]).bool() |
| return torch._nested_tensor_from_mask(data, mask) |
| |
| def test_as_nested_tensor_propagates_gradients(self, device): |
| a = torch.arange(3, dtype=torch.float, device=device) |
| b = torch.arange(5, dtype=torch.float, device=device) |
| nt = torch.nested.as_nested_tensor([a, b]) |
| # tensors with requires_grad=False are leaves |
| self.assertTrue(nt.is_leaf) |
| self.assertTrue(not nt.requires_grad) |
| |
| a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device) |
| b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device) |
| nt2 = torch.nested.as_nested_tensor([a, b]) |
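        # Gradients should flow through as_nested_tensor back to the original constituent tensors.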
| fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)], device=device) |
| nt2.backward(fake_grad) |
| self.assertEqual(a.grad, fake_grad[0]) |
| self.assertEqual(b.grad, fake_grad[1]) |
| |
| def test_nested_tensor_generates_leaf(self, device): |
| a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device) |
| b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device) |
| |
| nt = torch.nested.nested_tensor([a, b], requires_grad=False) |
| self.assertTrue(nt.is_leaf) |
| self.assertTrue(not nt.requires_grad) |
| |
| nt2 = torch.nested.nested_tensor([a, b], requires_grad=True) |
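        # nested_tensor() constructs a new leaf (the inputs are copied), so gradients stop at nt2 and never reach a or b.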
| self.assertTrue(nt2.is_leaf) |
| self.assertTrue(nt2.requires_grad) |
| |
| fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)], device=device) |
| nt2.backward(fake_grad) |
| self.assertEqual(nt2.grad, fake_grad) |
| self.assertEqual(a.grad, None) |
| self.assertEqual(b.grad, None) |
| |
| def test_set_requires_grad_from_list(self, device): |
| nt = self._create_nested_tensor_from_list(device) |
| nt.requires_grad_() |
| assert nt.requires_grad |
| |
| def test_set_requires_grad_from_mask(self, device): |
| nt = self._create_nested_tensor_from_mask(device) |
| nt.requires_grad_() |
| assert nt.requires_grad |
| |
| def test_backward_for_add_op(self, device): |
| nt_1 = self._create_nested_tensor_from_mask(device) |
| nt_2 = self._create_nested_tensor_from_mask(device) |
| |
| nt_1.requires_grad_() |
| c = nt_1 + nt_2 |
| |
| assert nt_1.requires_grad |
| assert c.requires_grad |
| grad_output = self._create_nested_tensor_from_mask(device) |
| c.backward(grad_output) |
| |
        # Gradcheck doesn't work with nested tensors yet.
        # d/dnt_1 (nt_1 + nt_2) = 1 * grad_output
| self.assertEqual(nt_1.grad, grad_output) |
| |
| def test_backward_for_sub_op(self, device): |
| nt_1 = self._create_nested_tensor_from_mask(device) |
| nt_2 = self._create_nested_tensor_from_mask(device) |
| |
| nt_1.requires_grad_() |
| nt_2.requires_grad_() |
| c = nt_1 - nt_2 |
| |
| assert nt_1.requires_grad |
| assert nt_2.requires_grad |
| assert c.requires_grad |
| grad_output = self._create_nested_tensor_from_mask(device) |
| c.backward(grad_output) |
| |
| self.assertEqual(nt_1.grad, grad_output) |
| self.assertEqual(nt_2.grad, -1 * grad_output) |
| |
| def test_backward_sub_strided(self, device): |
| a = torch.nested.nested_tensor([torch.randn(9, 2, 4), torch.randn(12, 2, 4)], requires_grad=True, device=device) |
| b = torch.nested.nested_tensor([torch.randn(9, 4, 2), torch.randn(12, 4, 2)], requires_grad=True, device=device) |
| c = a - b.transpose(-1, -2) |
| grad_output = c.clone() |
| c.backward(grad_output) |
| self.assertEqual(a.grad, grad_output) |
| self.assertEqual(b.grad, -1 * grad_output.transpose(-1, -2)) |
| |
| def test_backward_add_strided(self, device): |
| a = torch.nested.nested_tensor([torch.randn(9, 2, 4), torch.randn(12, 2, 4)], requires_grad=True, device=device) |
| b = torch.nested.nested_tensor([torch.randn(9, 4, 2), torch.randn(12, 4, 2)], requires_grad=True, device=device) |
| c = a + b.transpose(-1, -2) |
| grad_output = c.clone() |
| c.backward(grad_output) |
| self.assertEqual(a.grad, grad_output) |
| self.assertEqual(b.grad, grad_output.transpose(-1, -2)) |
| |
| # Test Factory Functions |
| def test_nested_tensor_to_padded_tensor(self, device): |
| for padding_val in [0, 1]: |
| nt = self._create_leaf_nested_tensor_from_list(tensor_device=device, requires_grad=True) |
| |
| out = torch.nested.to_padded_tensor(nt, padding_val) |
| grad_output = torch.ones(out.shape, device=device) |
| out.backward(grad_output) |
| |
| self.assertEqual(nt.grad, torch.nested.nested_tensor([torch.ones(1, 2), torch.ones(7, 8)], device=device)) |
| |
| def test_nested_tensor_from_mask_and_to_padded(self, device): |
| N, L, D = 2, 4, 4 |
| mask = torch.ones(N, L, device=device) |
| for i in range(1, N): |
| end = torch.randint(1, L - 1, (1,), device=device) |
| mask[i, end:] = 0 |
| |
| mask[0, :] = 1 |
| mask = mask.bool() |
| |
| data = torch.randn(N, L, D, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(inpt): |
| nt = torch._nested_tensor_from_mask(inpt, mask) |
| # This implicitly tests to_padded_tensor grads |
| return torch.nested.to_padded_tensor(nt, 0) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_nested_tensor_from_padded(self, device): |
| nested_size = torch.tensor([[1, 2], [2, 2]]) |
| padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64, device=device) |
| padded_tensor[0, 1, :] = 0 |
| padded_tensor.requires_grad_() |
| |
| def grad_test_func(tensor, nested_size): |
| nt = torch._nested_from_padded(tensor, nested_size, fuse_transform_0213=False) |
| # This implicitly tests to_padded_tensor grads |
| return torch.nested.to_padded_tensor(nt, 0) |
| |
| data = (padded_tensor, nested_size) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_nested_tensor_from_padded_fused(self, device): |
| nested_size = torch.tensor([[1, 8], [2, 8]]) |
| padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64, device=device) |
| padded_tensor[0, 1, :] = 0 |
| padded_tensor.requires_grad_() |
| |
| def grad_test_func(tensor, nested_size): |
| nt = torch._nested_from_padded(tensor, nested_size, fuse_transform_0213=True) |
| # This implicitly tests to_padded_tensor grads |
| return torch.nested.to_padded_tensor(nt, 0) |
| data = (padded_tensor, nested_size) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_nested_tensor_from_list(self, device): |
| |
| a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| c = torch.nested.as_nested_tensor([a, b, c]) |
            # This implicitly tests to_padded_tensor grads
| return torch.nested.to_padded_tensor(c, 0) |
| data = (a, b, c) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| @decorateIf( |
| xfailIfTorchDynamo, |
| # only fails in python 3.11. TODO: Debug this! |
| lambda params: params["layout"] == torch.jagged and sys.version_info >= (3, 11) |
| ) |
| @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name) |
| def test_dropout_backward(self, layout): |
| if layout == torch.jagged: |
| nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 5))], requires_grad=True, layout=layout) |
| else: |
| nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))], requires_grad=True, layout=layout) |
| p = 0.2 |
| y = torch.nn.functional.dropout(nt, p) |
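        # Since grad_output below equals the input, the propagated gradient input * mask / (1 - p) is exactly the forward output y.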
| y.backward(nt.clone().detach()) |
| self.assertEqual(nt.grad, y) |
| |
| def test_nested_tensor_bmm_gradcheck(self, device): |
| a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device) |
| d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c, d): |
| nt0 = torch.nested.as_nested_tensor([a, b]) |
| nt1 = torch.nested.as_nested_tensor([c, d]) |
| result = nt0.bmm(nt1) |
| return torch.nested.to_padded_tensor(result, 0.0) |
| |
| data = (a, b, c, d) |
| assert torch.autograd.gradcheck(grad_test_func, inputs=data) |
| |
| def test_nested_tensor_bmm_backward(self, device): |
| nt0 = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True, device=device) |
| nt1 = torch.nested.nested_tensor([torch.randn((6, 4)), torch.randn((6, 5))], requires_grad=True, device=device) |
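        # Compare nested backward against a padded-tensor reference; zero padding keeps the forward outputs and propagated gradients consistent.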
| with torch.no_grad(): |
| pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True) |
| pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True) |
| |
| ynt = nt0.bmm(nt1) |
| ypt = pt0.bmm(pt1) |
| ynt.backward(ynt.clone()) |
| ypt.backward(ypt.clone()) |
| |
| self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad) |
| self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad) |
| |
| def test_nested_tensor_matmul_gradcheck(self, device): |
| a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device) |
| d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c, d): |
| nt0 = torch.nested.as_nested_tensor([a, b]) |
| nt1 = torch.nested.as_nested_tensor([c, d]) |
| result = torch.matmul(nt0, nt1) |
| return torch.nested.to_padded_tensor(result, 0.0) |
| |
| data = (a, b, c, d) |
| assert torch.autograd.gradcheck(grad_test_func, inputs=data) |
| |
| def test_nested_tensor_matmul_backward(self, device): |
| nt0 = torch.nested.nested_tensor([torch.randn((7, 2, 6)), torch.randn((7, 3, 6))], requires_grad=True, device=device) |
| nt1 = torch.nested.nested_tensor([torch.randn((7, 6, 4)), torch.randn((7, 6, 5))], requires_grad=True, device=device) |
| with torch.no_grad(): |
| pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True) |
| pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True) |
| |
| ynt = torch.matmul(nt0, nt1) |
| ypt = torch.matmul(pt0, pt1) |
| ynt.backward(ynt.clone()) |
| ypt.backward(ypt.clone()) |
| |
| self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad) |
| self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad) |
| |
| def test_nested_tensor_transpose_gradcheck(self, device): |
| a = torch.randn(2, 5, requires_grad=True, device=device) |
| b = torch.randn(3, 4, requires_grad=True, device=device) |
| |
| def grad_test_func(a, b): |
| nt = torch.nested.as_nested_tensor([a, b]) |
| result = nt.transpose(-2, -1).transpose(-2, -1) |
| return torch.nested.to_padded_tensor(result, 0.0) |
| |
| data = (a, b) |
| assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) |
| |
| def test_nested_tensor_transpose_backward(self, device): |
| nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))], requires_grad=True, device=device) |
| with torch.no_grad(): |
| pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) |
| |
| ynt = nt.transpose(-2, -1) |
| ypt = pt.transpose(-2, -1) |
| ynt.backward(ynt.clone()) |
| ypt.backward(ypt.clone()) |
| |
| self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) |
| |
| def test_nested_tensor_reshape_gradcheck(self, device): |
| a = torch.randn(2, 6, requires_grad=True, device=device) |
| b = torch.randn(3, 6, requires_grad=True, device=device) |
| |
| def grad_test_func(a, b): |
| nt = torch.nested.as_nested_tensor([a, b]) |
| result = nt.reshape(2, -1, 2, 3) |
| return torch.nested.to_padded_tensor(result, 0.0) |
| |
| data = (a, b) |
| assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3) |
| |
| def test_nested_tensor_reshape_backward(self): |
| nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True) |
| with torch.no_grad(): |
| pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) |
| |
| ynt = nt.reshape(2, -1, 2, 3) |
| ypt = pt.reshape(2, -1, 2, 3) |
| ynt.backward(ynt.clone()) |
| ypt.backward(ypt.clone()) |
| |
| self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) |
| |
| def test_nested_tensor_squeeze_backward(self, device): |
| nt = torch.nested.nested_tensor([torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], requires_grad=True, device=device) |
| with torch.no_grad(): |
| pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) |
| |
| ynt = nt.squeeze(-1) |
| ypt = pt.squeeze(-1) |
| ynt.backward(ynt.clone()) |
| ypt.backward(ypt.clone()) |
| |
| self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) |
| |
| def test_nested_tensor_squeeze_gradcheck(self, device): |
| a = torch.randn((2, 6, 1), dtype=torch.float64, requires_grad=True, device=device) |
| b = torch.randn((3, 6, 1), dtype=torch.float64, requires_grad=True, device=device) |
| |
| def grad_test_func(a, b): |
| nt = torch.nested.as_nested_tensor([a, b]) |
| result = nt.squeeze(-1) |
| return torch.nested.to_padded_tensor(result, 0.0) |
| |
| assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3) |
| |
| def test_nested_tensor_unsqueeze_backward(self, device): |
| nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True, device=device) |
| with torch.no_grad(): |
| pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True) |
| |
| ynt = nt.unsqueeze(2) |
| ypt = pt.unsqueeze(2) |
| ynt.backward(ynt.clone()) |
| ypt.backward(ypt.clone()) |
| |
| self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad) |
| |
| def test_nested_tensor_unsqueeze_gradcheck(self, device): |
| a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True, device=device) |
| b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True, device=device) |
| |
| def grad_test_func(a, b): |
| nt = torch.nested.as_nested_tensor([a, b]) |
| result = nt.unsqueeze(-1) |
| return torch.nested.to_padded_tensor(result, 0.0) |
| |
| assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3) |
| |
| def test_nested_tensor_linear(self, device): |
| |
| a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) |
| |
| weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) |
| bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c, weight, bias=None): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| # This implicitly tests to_padded_tensor grads |
| d = torch.functional.F.linear(nt, weight, bias) |
| return torch.nested.to_padded_tensor(d, 0) |
| data = (a, b, c, weight, bias) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| # Test linear with no bias added |
| data = (a, b, c, weight) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_nested_tensor_linear_plus_transpose(self, device): |
| a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) |
| |
| weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) |
| bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c, weight, bias=None): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| # This implicitly tests to_padded_tensor grads |
| d = torch.functional.F.linear(nt, weight, bias) |
| d = d.transpose(-1, -2).contiguous() |
| return torch.nested.to_padded_tensor(d, 0) |
| data = (a, b, c, weight, bias) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| # Test linear with no bias added |
| data = (a, b, c, weight) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_nested_tensor_softmax(self, device): |
| a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c, dim): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| # This implicitly tests to_padded_tensor grads |
| d = torch.functional.F.softmax(nt, dim=dim) |
| return torch.nested.to_padded_tensor(d, 0) |
| |
| # softmax over last dim |
| data = (a, b, c, -1) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_nested_tensor_linear_backward(self, device): |
| a = torch.randn(1, 2, requires_grad=False, device=device) |
| b = torch.randn(2, 2, requires_grad=False, device=device) |
| c = torch.randn(3, 2, requires_grad=False, device=device) |
| |
| weight = torch.randn(2, 2, requires_grad=True, device=device) |
| bias = torch.randn(2, requires_grad=True, device=device) |
| nt = torch.nested.as_nested_tensor([a, b, c], device=device) |
| |
| out = torch.functional.F.linear(nt, weight, bias) |
| |
| out.backward(out.clone()) |
| |
| assert weight.grad is not None |
| assert bias.grad is not None |
| |
| assert a.grad is None |
| assert b.grad is None |
| assert c.grad is None |
| |
| def test_values_grad_with_broadcast(self, device): |
| a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| buffer = nt.values() |
| return buffer.sum() |
| |
| data = (a, b, c) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_to_buffer_series_ops_grad_with_broadcast(self, device): |
| a = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| buffer = nt.values() |
| buffer = buffer * 2 |
| return buffer.exp() |
| |
| data = (a, b, c) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_unbind_flow_through(self, device): |
| a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| ntT = nt.transpose(-1, -2) |
| unbound = ntT.unbind() |
| d = unbound[0] |
| d = torch.pow(d, 2) |
| return d |
| |
| data = (a, b, c) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_split_with_sizes_flow_through(self, device): |
| a = torch.randn(2, 5, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(3, 5, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(4, 5, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| splits = nt.split_with_sizes([2, 3], dim=-1) |
| unbound = splits[1].unbind() |
| d = unbound[0] |
| d = torch.pow(d, 2) |
| return d |
| |
| data = (a, b, c) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_indexing_backward(self, device): |
| x0 = torch.randn((2, 5)) |
| x1 = torch.randn((3, 4)) |
| nt = torch.nested.nested_tensor([x0, x1], device=device, requires_grad=True) |
| self.assertEqual(nt[0], x0) |
| self.assertEqual(nt[-1], x1) |
| grad_x0 = torch.randn((2, 5), device=device) |
| nt[0].backward(grad_x0) |
| expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device)]) |
| self.assertEqual(nt.grad, expected_grad) |
| |
| def test_masked_fill_backward(self, device): |
| a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| mask = nt.detach().clone().to(bool) |
| out = nt.masked_fill(mask, 0) |
| out = torch.nested.to_padded_tensor(out, 0) |
| return out |
| data = (a, b, c) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_gelu_backward(self, device): |
| a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| nt_gelu = torch.nn.functional.gelu(nt) |
| return torch.nested.to_padded_tensor(nt_gelu, 0) |
| |
| data = (a, b, c) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_relu_backward(self, device): |
| a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| nt_relu = torch.nn.functional.relu(nt) |
| return torch.nested.to_padded_tensor(nt_relu, 0) |
| |
| data = (a, b, c) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_silu_backward(self, device):
| a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| nt_silu = torch.nn.functional.silu(nt)
| return torch.nested.to_padded_tensor(nt_silu, 0)
| |
| data = (a, b, c) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| def test_abs_backward(self, device): |
| a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| nt_abs = torch.abs(nt) |
| return torch.nested.to_padded_tensor(nt_abs, 0) |
| |
| data = (a, b, c) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| # Previously would error when input NT doesn't require grad |
| # NotImplementedError: Cannot access storage of UndefinedTensorImpl |
| def test_layer_norm_backward_edge_case(self, device): |
| size = 4 |
| a = torch.randn(1, 2, size, requires_grad=False, dtype=torch.float64, device=device) |
| nt = torch.nested.nested_tensor([a]) |
| nt_layer_norm = torch.nn.LayerNorm(nt.size(-1), device=device, dtype=torch.float64) |
| out = nt_layer_norm(nt) |
| out.backward(out.clone()) |
| |
| def test_accumulate_grad_different_strides(self, device): |
| a = torch.rand(1, 4, 2, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.rand(1, 8, 2, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b): |
| nt_1 = torch.nested.as_nested_tensor([a, b]) |
| nt_2 = nt_1.clone() |
| out = torch.nn.functional.scaled_dot_product_attention(nt_1, nt_2, nt_2) |
| return torch.nested.to_padded_tensor(out, 0) |
| |
| data = (a, b) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| # https://github.com/pytorch/pytorch/issues/95562 |
| @skipIfSlowGradcheckEnv |
| @parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2]) |
| def test_layer_norm_backward(self, device, size): |
| a = torch.randn(1, 2, size, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(2, 2, size, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(3, 2, size, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| layer_norm = torch.nn.LayerNorm(nt.size(-1), device=device, dtype=torch.float64) |
| nt_layer_norm = layer_norm(nt) |
| return torch.nested.to_padded_tensor(nt_layer_norm, 0) |
| |
| data = (a, b, c) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| # https://github.com/pytorch/pytorch/issues/95562 |
| @skipIfSlowGradcheckEnv |
| # Could either mark slow or reduce size |
| @parametrize("size", [128, 32, 4, 2]) |
| def test_layer_norm_backward_5d(self, device, size): |
| a = torch.randn(4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| nt = torch.nested.as_nested_tensor([a, b, c]) |
| layer_norm = torch.nn.LayerNorm((size, size, nt.size(-1)), device=device, dtype=torch.float64) |
| nt_layer_norm = layer_norm(nt) |
| return torch.nested.to_padded_tensor(nt_layer_norm, 0) |
| |
| data = (a, b, c) |
| assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False) |
| |
| # Found in torch/testing/_comparison.py |
| default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5} |
| default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6} |
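| # NB: only reduced/single precision dtypes appear here; the helpers below
| # assume computed_value has one of these dtypes.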
| |
| def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float: |
| deviation = true_value - computed_value |
| deviation = torch.abs(deviation / true_value) |
| # Fill in the nans with the default rtol |
| torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype]) |
| return deviation.max().item() |
| |
| |
| def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float: |
| deviation = true_value - computed_value |
| atol = torch.abs(deviation).max().item() |
| return atol |
| |
| |
| def get_tolerances( |
| true_value: torch.Tensor, |
| computed_value: torch.Tensor, |
| fudge_factor: Optional[float] = None, |
| ) -> Tuple[float, float]: |
| """Returns the absolute and relative tolerances for comparing two tensors.""" |
| fudge_factor = fudge_factor if fudge_factor is not None else 1.0 |
| atol = get_atol(true_value, computed_value) |
| rtol = get_rtol(true_value, computed_value) |
| |
| atol = fudge_factor * max(atol, default_atol[computed_value.dtype]) |
| rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype]) |
| # Guard against an enormous rtol (e.g. when true_value is near zero);
| # torch.isclose() behaves strangely in that regime, see:
| # https://github.com/pytorch/pytorch/issues/102400
| if rtol > 1e30: |
| rtol = default_rtol[computed_value.dtype] |
| return atol, rtol |
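| 
| # A minimal usage sketch (hypothetical tensor names): derive tolerances from a
| # high-precision reference, then compare the low-precision result against it:
| #   atol, rtol = get_tolerances(out_ref, out_lp, fudge_factor=4.0)
| #   torch.testing.assert_close(out_lp, out_ref.to(out_lp.dtype), atol=atol, rtol=rtol)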
| |
| # We can probably parametrize existing tests instead of keeping a separate
| # test class as we begin to support more ops. Also consider rewriting with OpInfos.
| @markDynamoStrictTest |
| class TestNestedTensorSubclass(TestCase): |
| # TODO: consolidate with the below |
| def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True): |
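| # nested_size is (ragged_sizes, *constant_dims); e.g. ((2, 3, 4), 3, 4)
| # yields three float64 tensors of shapes (2, 3, 4), (3, 3, 4), and (4, 3, 4).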
| Ds = nested_size[1:] |
| out = [] |
| for s in nested_size[0]: |
| out.append( |
| torch.randn(s, *Ds, requires_grad=requires_grad, device=device, dtype=torch.float64) |
| ) |
| return out |
| |
| def _get_example_tensor_lists(self, include_list_of_lists=True, include_requires_grad=True): |
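| # Returns lists of example components for NT construction: a (B, *, D) list
| # with B=4, a (B, *, D_0, D_1) list with B=5, and optionally a list-of-lists
| # variant with B=3.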
| |
| def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_grad=True): |
| return torch.randn( |
| *shape, |
| requires_grad=(requires_grad if include_requires_grad else False) |
| ) |
| |
| # Purposefully introduce mixed requires_grad settings for the components |
| # when include_requires_grad=True. |
| example_lists = [ |
| # (B, *, D) with B=4 |
| [ |
| _make_tensor(2, 5), |
| _make_tensor(3, 5, requires_grad=False), |
| _make_tensor(4, 5, requires_grad=False), |
| _make_tensor(6, 5) |
| ], |
| # (B, *, D_0, D_1) with B=5 |
| [ |
| _make_tensor(2, 5, 6), |
| _make_tensor(3, 5, 6), |
| _make_tensor(4, 5, 6, requires_grad=False), |
| _make_tensor(5, 5, 6), |
| _make_tensor(6, 5, 6), |
| ], |
| ] |
| |
| if include_list_of_lists: |
| example_lists.append( |
| # (B, *, D) with B=3 in list form |
| [ |
| _make_tensor(2, 5, requires_grad=False).tolist(), |
| _make_tensor(3, 5).tolist(), |
| _make_tensor(4, 5).tolist(), |
| ]) |
| |
| return example_lists |
| |
| def test_tensor_attributes(self, device): |
| a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) |
| nt, _offsets = jagged_from_list([a, b, c], None) |
| |
| for op in ( |
| torch.ops.aten.is_non_overlapping_and_dense.default, |
| torch.ops.aten.sym_size.default, |
| torch.ops.aten.dim.default, |
| torch.ops.aten.sym_numel.default, |
| torch.ops.aten.sym_stride.default, |
| torch.ops.aten.sym_storage_offset.default, |
| ): |
| op(nt) |
| |
| with self.assertRaisesRegex(RuntimeError, |
| "directly calling torch.ops.aten.size"): |
| torch.ops.aten.size.default(nt) |
| |
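| # The ragged dim is represented by a singleton SymInt derived from the offsets
| # tensor, so sizes along that dim compare equal only for NTs sharing offsets.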
| singleton_int = torch.nested._internal.nested_tensor.get_tensor_symint(_offsets, coeff=1) |
| self.assertEqual(nt.size(), (3, singleton_int, 3)) |
| self.assertEqual(nt.shape, (3, singleton_int, 3)) |
| self.assertEqual(nt.dim(), 3) |
| self.assertEqual(nt.numel(), 27) |
| |
| def test_linear(self, device): |
| a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) |
| weight = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c, weight): |
| nt, _ = jagged_from_list([a, b, c], None) |
| out = torch.nn.functional.linear(nt, weight) |
| return buffer_from_jagged(out) |
| |
| gradcheck(grad_test_func, inputs=(a, b, c, weight), check_batched_grad=False) |
| |
| def test_unary_pointwise(self, device): |
| a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c): |
| nt, _ = jagged_from_list([a, b, c], None) |
| out = torch.nn.functional.silu(nt.sin().cos()) |
| return buffer_from_jagged(out) |
| |
| gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) |
| |
| def test_unary_pointwise_transposed_inputs(self, device): |
| a, b, c = ( |
| torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3) |
| ) |
| |
| nt, _ = jagged_from_list([a.detach(), b.detach(), c.detach()], None) |
| nt_t = nt.transpose(1, 2) |
| self.assertFalse(nt_t.is_contiguous()) |
| out = torch.nn.functional.silu(nt_t.sin().cos()) |
| self.assertEqual(out.is_contiguous(), torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous()) |
| |
| self.assertEqual(nt_t.shape, out.shape) |
| |
| a, b, c = ( |
| torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3) |
| ) |
| |
| def grad_test_func(a, b, c): |
| nt, _ = jagged_from_list([a, b, c], None) |
| nt_t = nt.transpose(1, 2) |
| out = torch.nn.functional.silu(nt_t.sin().cos()) |
| return buffer_from_jagged(out) |
| |
| gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) |
| |
| |
| def test_binary_pointwise(self, device): |
| a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) |
| |
| # Incorrect usage: the shape check will fail if the offsets tensors are
| # not the exact same tensor object
| nt1, _ = jagged_from_list([a, b, c], None) |
| nt2, _ = jagged_from_list([a, b, c], None) |
| |
| self.assertRaisesRegex( |
| RuntimeError, |
| "cannot call binary pointwise function .* with inputs of shapes", |
| lambda: nt1 * nt2) |
| |
| # Correct usage: chain the calls using the same offsets tensor object |
| def grad_test_func(a, b, c): |
| nt1, offsets = jagged_from_list([a, b, c], None) |
| nt2, offsets = jagged_from_list([a, b, c], offsets) |
| out = nt1 * nt2 |
| return buffer_from_jagged(out) |
| |
| gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) |
| |
| def test_binary_pointwise_transposed(self, device): |
| a, b, c = ( |
| torch.randn(i + 2, 5, dtype=torch.float64, device=device) for i in range(3) |
| ) |
| |
| nt1, offsets = jagged_from_list([a, b, c], None) |
| nt2, offsets = jagged_from_list([a, b, c], offsets) |
| |
| nt1_t = nt1.transpose(1, 2) |
| nt2_t = nt2.transpose(1, 2) |
| |
| out = nt1_t * nt2_t |
| self.assertFalse(nt1_t.is_contiguous()) |
| self.assertEqual(out.is_contiguous(), (b.transpose(-1, -2) * b.transpose(-1, -2)).is_contiguous()) |
| self.assertEqual(out.shape, nt1_t.shape) |
| |
| self.assertRaisesRegex( |
| RuntimeError, |
| "cannot call binary pointwise function mul.Tensor with inputs of shapes", |
| lambda: nt1 * nt2_t, |
| ) |
| |
| a, b, c = ( |
| torch.randn(i + 2, 5, requires_grad=True, dtype=torch.float64, device=device) for i in range(3) |
| ) |
| |
| # Correct usage: chain the calls using the same offsets tensor object |
| def grad_test_func(a, b, c): |
| nt1, offsets = jagged_from_list([a, b, c], None) |
| nt2, offsets = jagged_from_list([a, b, c], offsets) |
| nt1_t = nt1.transpose(1, 2) |
| nt2_t = nt2.transpose(1, 2) |
| out = nt1_t * nt2_t |
| return buffer_from_jagged(out) |
| |
| gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False) |
| |
| def test_split(self, device): |
| a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) |
| |
| nt, _ = jagged_from_list([a, b, c], None) |
| out = torch.split(nt, 2, -1) |
| self.assertEqual(len(out), 2) |
| self.assertEqual( |
| out[0], jagged_from_list([a[:, 0:2], b[:, 0:2], c[:, 0:2]], None)[0] |
| ) |
| self.assertEqual( |
| out[1], jagged_from_list([a[:, 2:], b[:, 2:], c[:, 2:]], None)[0] |
| ) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"split\(\): not supported for NestedTensor on dim=0 or dim=1", |
| ): |
| torch.split(nt, 2, 1) |
| |
| def test_split_with_sizes(self, device): |
| a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) |
| |
| nt, _ = jagged_from_list([a, b, c], None) |
| out = torch.split(nt, [1, 2], -1) |
| self.assertEqual(len(out), 2) |
| self.assertEqual( |
| out[0], jagged_from_list([a[:, 0:1], b[:, 0:1], c[:, 0:1]], None)[0] |
| ) |
| self.assertEqual( |
| out[1], jagged_from_list([a[:, 1:], b[:, 1:], c[:, 1:]], None)[0] |
| ) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"split_with_sizes\(\): not supported for NestedTensor on dim=0 or dim=1", |
| ): |
| torch.split(nt, [1, 2], 1) |
| |
| def test_views_inherit_ragged_dim(self, device): |
| # view |
| nt = random_nt_from_dims( |
| [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged) |
| # inherit ragged dim via -1 |
| view = nt.view(4, -1, 80) |
| self.assertEqual(nt.shape[1], view.shape[1]) |
| # inherit batch and ragged dims via -1 |
| view2 = nt.view(-1, -1, 80) |
| self.assertEqual(nt.shape[:2], view2.shape[:2]) |
| |
| # expand |
| nt = random_nt_from_dims( |
| [3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged) |
| # inherit batch and ragged dims via -1 |
| view = nt.expand(-1, -1, 5) |
| self.assertEqual(nt.shape[:2], view.shape[:2]) |
| |
| @xfailIfTorchDynamo |
| def test_view_ragged_idx_not_one(self, device): |
| nt = random_nt_from_dims([2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged) |
| |
| view_transposed = nt.transpose(1, 2).view(2, 20, nt.size(1)) |
| self.assertEqual((2, 20, nt.size(1)), view_transposed.size())
| self.assertEqual(view_transposed._base, nt) |
| |
| def test_unsafe_view(self, device): |
| nt = random_nt_from_dims([4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged) |
| # basic view |
| view1 = torch.ops.aten._unsafe_view(nt, (4, -1, 80)) |
| self.assertEqual((4, nt.size(1), 80), tuple(view1.size())) |
| # _unsafe_view differs from view in that the view information is not tracked |
| self.assertTrue(view1._base is None) |
| |
| # test _unsafe_view when ragged_idx != 1; currently only the identity view is supported
| nt_t = nt.transpose(1, 2) |
| view2 = torch.ops.aten._unsafe_view(nt_t, (4, 8, nt.size(1), 10)) |
| self.assertEqual((4, 8, nt.size(1), 10), tuple(view2.size())) |
| self.assertTrue(view2._base is None) |
| |
| @xfailIfTorchDynamo |
| @parametrize("requires_grad", [False, True]) |
| def test_reshape_decomp(self, device, requires_grad): |
| # contiguous NT should result in view |
| nt = random_nt_from_dims( |
| [3, None, 10], |
| device=device, |
| dtype=torch.float32, |
| layout=torch.jagged, |
| requires_grad=requires_grad |
| ) |
| view = nt.reshape(-1, -1, 5, 2) |
| self.assertEqual(view.shape[:2], nt.shape[:2]) |
| self.assertTrue(view._is_view() and view._base is nt) |
| # make sure gradients flow back |
| if requires_grad: |
| view.backward(torch.ones_like(view)) |
| self.assertEqual(nt.grad, torch.ones_like(nt)) |
| |
| # non-contiguous NT should result in contiguous copy |
| nt = random_nt_from_dims( |
| [3, None, 5, 2], |
| device=device, |
| dtype=torch.float32, |
| layout=torch.jagged, |
| requires_grad=requires_grad |
| ) |
| nt_noncontig = nt.transpose(-1, -2) |
| self.assertFalse(nt_noncontig.is_contiguous()) |
| copy = nt_noncontig.reshape(-1, -1, 10) |
| self.assertTrue(copy.is_contiguous()) |
| self.assertEqual(copy.shape[:2], nt.shape[:2]) |
| # make sure gradients flow back |
| if requires_grad: |
| copy.backward(torch.ones_like(copy)) |
| self.assertEqual(nt.grad, torch.ones_like(nt)) |
| |
| def test_flatten_decomp(self, device): |
| nt = random_nt_from_dims( |
| [3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged) |
| flattened = nt.flatten(-2, -1) |
| self.assertEqual(flattened.shape, nt.view(3, -1, 10).shape) |
| |
| nt = random_nt_from_dims( |
| [3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged) |
| flattened = nt.flatten(-3, -2) |
| self.assertEqual(flattened.shape, nt.view(3, -1, 10, 6).shape) |
| |
| def test_chunk(self, device): |
| # normal case |
| D = 30 |
| nt = random_nt_from_dims( |
| [4, None, D], device=device, dtype=torch.float32, layout=torch.jagged) |
| NUM_CHUNKS = 3 |
| chunks = nt.chunk(NUM_CHUNKS, dim=-1) |
| self.assertEqual(len(chunks), NUM_CHUNKS) |
| for i in range(NUM_CHUNKS): |
| self.assertEqual(chunks[i].shape[-1], D // NUM_CHUNKS) |
| |
| # chunk on batch dim not supported |
| with self.assertRaisesRegex( |
| RuntimeError, "chunk.* not supported for NestedTensor on dim=0 or dim=1"): |
| nt.chunk(2, dim=0) |
| |
| # chunk on ragged dim not supported |
| with self.assertRaisesRegex( |
| RuntimeError, "chunk.* not supported for NestedTensor on dim=0 or dim=1"): |
| nt.chunk(2, dim=1) |
| |
| def test_squeeze(self, device): |
| B = 4 |
| D = 6 |
| # squeeze middle dim |
| nt = random_nt_from_dims( |
| [B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged) |
| j0 = nt.shape[1] |
| |
| for dim_arg in [-2, 2]: |
| out = nt.squeeze(dim_arg) |
| self.assertEqual(out.shape, (B, j0, D)) |
| self.assertEqual(out.unsqueeze(-2), nt) |
| |
| # squeeze last dim |
| nt = random_nt_from_dims( |
| [B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged) |
| j1 = nt.shape[1] |
| |
| for dim_arg in [-1, 2]: |
| out = nt.squeeze(dim_arg) |
| self.assertEqual(out.shape, (B, j1)) |
| self.assertEqual(out.unsqueeze(-1), nt) |
| |
| # squeeze on batch dim not supported |
| with self.assertRaisesRegex( |
| RuntimeError, "squeeze.* not supported for NestedTensor on dim=0 or dim=1"): |
| nt.squeeze(0) |
| |
| # squeeze on ragged dim not supported |
| with self.assertRaisesRegex( |
| RuntimeError, "squeeze.* not supported for NestedTensor on dim=0 or dim=1"): |
| nt.squeeze(1) |
| |
| def test_binary_pointwise_broadcasting(self, device): |
| # (B, j0, 3, 4) |
| ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device, requires_grad=True) |
| # (B, j0, ?, ?) + (?) -> (B, j0, ?, ?) |
| # (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?) |
| # (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?) |
| # Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?) |
| t_sizes = ( |
| (4,), |
| (1, 4), |
| (3, 1), |
| (1, 3, 1), |
| (1, 1, 1, 4), |
| # (1, 1, 1, 1, 4), (unsupported today) |
| ) |
| |
| def grad_test_func(t, *ts): |
| nt, _ = jagged_from_list(ts, None) |
| out = nt + t |
| return buffer_from_jagged(out) |
| |
| for t_size in t_sizes: |
| t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64) |
| gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False) |
| |
| def test_threshold_backward(self, device): |
| ts1 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False) |
| ts2 = self._get_list_for_jagged_tensor(((2, 3, 4), 16), device=device, requires_grad=False) |
| |
| nt1, offsets = jagged_from_list(ts1, None) |
| nt2, offsets = jagged_from_list(ts2, offsets) |
| buf1 = buffer_from_jagged(nt1).detach().clone() |
| buf2 = buffer_from_jagged(nt2).detach().clone() |
| |
| res_nt = torch.ops.aten.threshold_backward(nt1, nt2, 0.0) |
| res_dense = torch.ops.aten.threshold_backward(buf1, buf2, 0.0) |
| |
| self.assertEqual(res_dense, buffer_from_jagged(res_nt)) |
| |
| |
| @parametrize("keepdim", [False, True]) |
| def test_sum_int_DimList(self, device, keepdim): |
| # (B, j0, 3, 4) |
| ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device=device, requires_grad=True) |
| |
| # Check shape correctness |
| reduce_dims = ( |
| # dims, expected shape, expected keepdim shape |
| # j0 is represented as None |
| ((0, 1), (3, 4), (1, 1, 3, 4)), |
| ((1, 2), None, None), |
| ((2, 3), (3, None), (3, None, 1, 1)), |
| ((0, 1, 3), (3,), (1, 1, 3, 1)), |
| ((0, 1, 2), (4,), (1, 1, 1, 4)), |
| ((0, 1, 2, 3), tuple(), (1, 1, 1, 1)), |
| ) |
| for rd, ref_shape_no_keepdim, ref_shape_keepdim in reduce_dims: |
| if (0 in rd) ^ (1 in rd): |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "applying over the ragged dimension, but not the batch dimension"): |
| nt, _ = jagged_from_list(ts, None) |
| out = torch.sum(nt, dim=rd, keepdim=keepdim) |
| continue |
| |
| nt, _ = jagged_from_list(ts, None) |
| out = torch.sum(nt, dim=rd, keepdim=keepdim) |
| ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim |
| self.assertEqual(len(out.shape), len(ref_shape)) |
| for o, r in zip(out.shape, ref_shape): |
| if r is not None: |
| self.assertEqual(o, r) |
| else: |
| self.assertTrue(isinstance(o, torch.SymInt)) |
| |
| # Check values correctness |
| # raggedness not reduced |
| nt, _ = jagged_from_list(ts, None) |
| out = torch.sum(nt, dim=(2, 3), keepdim=keepdim) |
| out_ref = torch.sum(nt.values(), dim=(1, 2)) |
| self.assertIsInstance(out, NestedTensor) |
| # flatten to avoid having to replicate unsqueeze logic depending on keepdim |
| self.assertTrue(torch.allclose(out.values().view(-1), out_ref.view(-1))) |
| |
| # raggedness reduced away |
| nt, _ = jagged_from_list(ts, None) |
| out = torch.sum(nt, dim=(0, 1), keepdim=keepdim) |
| out_ref = torch.sum(nt.values(), dim=(0,)) |
| self.assertNotIsInstance(out, NestedTensor) |
| self.assertTrue(torch.allclose(out, out_ref)) |
| 
| 
| @dtypes(torch.float, torch.double, torch.half) |
| @parametrize("requires_grad", [False, True]) |
| @parametrize("weights_only", [False, True]) |
| def test_serialization(self, device, dtype, requires_grad, weights_only): |
| |
| def compare_metadata(nt1, nt2): |
| self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size()) |
| self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides()) |
| self.assertEqual(nt1._nested_tensor_storage_offsets(), |
| nt2._nested_tensor_storage_offsets()) |
| |
| nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) |
| for a in [nt_contiguous, nt_noncontiguous]: |
| buffer = io.BytesIO() |
| torch.save(a, buffer)
| buffer.seek(0) |
| b = torch.load(buffer, weights_only=weights_only) |
| # should be both conceptually equal and metadata equivalent |
| self.assertEqual(a, b) |
| compare_metadata(a, b) |
| # should be conceptually equal but not necessarily metadata equivalent |
| self.assertEqual(b, nt_contiguous) |
| self.assertEqual(b, nt_noncontiguous) |
| |
| @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses a failing CUDA call to detect the pointer property")
| @onlyCUDA |
| def test_pin_memory(self, device): |
| nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7)) |
| for nt in [nt_contiguous, nt_noncontiguous]: |
| self.assertFalse(nt.is_pinned()) |
| pinned = nt.pin_memory(device) |
| self.assertTrue(pinned.is_pinned()) |
| self.assertEqual(nt, pinned) |
| self.assertNotEqual(nt.data_ptr(), pinned.data_ptr()) |
| # test that pin_memory on already pinned tensor has no effect |
| self.assertIs(pinned, pinned.pin_memory()) |
| self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr()) |
| |
| def _validate_nt(self, nt, tensor_list, device, dtype, requires_grad): |
| # Validate a bunch of properties after NT construction. |
| device = torch.device(device) |
| first_t = torch.as_tensor(tensor_list[0]) |
| expected_dim = first_t.dim() + 1 |
| batch_size = len(tensor_list) |
| self.assertEqual(nt.dim(), expected_dim) |
| self.assertEqual(nt.device, device) |
| self.assertEqual(nt.dtype, dtype) |
| self.assertEqual(nt.layout, torch.jagged) |
| self.assertEqual(nt.requires_grad, requires_grad) |
| self.assertEqual(nt.values().device, device) |
| self.assertEqual(nt.offsets().device, device) |
| self.assertEqual(nt.shape[0], batch_size) |
| self.assertTrue(isinstance(nt.shape[1], torch.SymInt)) |
| self.assertEqual(nt.shape[2:], first_t.shape[1:]) |
| |
| @xfailIfTorchDynamo |
| @dtypes(torch.float, torch.double, torch.half) |
| @parametrize("requires_grad", [False, True]) |
| @parametrize("components_require_grad", [False, True]) |
| def test_jagged_layout_construction_nested_tensor( |
| self, device, dtype, requires_grad, components_require_grad): |
| for tensor_list in self._get_example_tensor_lists( |
| include_list_of_lists=True, include_requires_grad=components_require_grad): |
| nt = torch.nested.nested_tensor( |
| tensor_list, |
| device=device, |
| dtype=dtype, |
| layout=torch.jagged, |
| requires_grad=requires_grad) |
| self._validate_nt(nt, tensor_list, device, dtype, requires_grad) |
| |
| # Make sure grads -don't- flow back into original tensors for nested_tensor() |
| if requires_grad: |
| (nt * 2).backward(torch.ones_like(nt)) |
| for t in tensor_list: |
| t = t if isinstance(t, torch.Tensor) else torch.as_tensor(t) |
| self.assertTrue(t.grad is None) |
| |
| @xfailIfTorchDynamo |
| @dtypes(torch.float, torch.double, torch.half) |
| @parametrize("components_require_grad", [False, True]) |
| def test_jagged_layout_construction_as_nested_tensor( |
| self, device, dtype, components_require_grad): |
| # NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list |
| for tensor_list in self._get_example_tensor_lists( |
| include_list_of_lists=False, include_requires_grad=components_require_grad): |
| nt = torch.nested.as_nested_tensor( |
| tensor_list, |
| device=device, |
| dtype=dtype, |
| layout=torch.jagged) |
| |
| # nt.requires_grad should be True if at least one component requires grad
| self._validate_nt(nt, tensor_list, device, dtype, components_require_grad) |
| |
| # Make sure grads flow back into original tensors for as_nested_tensor() |
| if components_require_grad: |
| (nt * 2).backward(torch.ones_like(nt)) |
| for t in tensor_list: |
| if t.requires_grad: |
| self.assertEqual(t.grad, torch.ones_like(t) * 2) |
| else: |
| self.assertTrue(t.grad is None) |
| |
| @xfailIfTorchDynamo |
| @unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses a failing CUDA call to detect the pointer property")
| @onlyCUDA |
| def test_jagged_layout_construction_with_pinned_memory(self, device): |
| for tensor_list in self._get_example_tensor_lists(): |
| nt = torch.nested.nested_tensor( |
| tensor_list, |
| layout=torch.jagged, |
| device="cpu", |
| pin_memory=True) |
| |
| self._validate_nt(nt, tensor_list, "cpu", torch.float32, requires_grad=False) |
| self.assertTrue(nt.is_pinned()) |
| |
| @dtypes(torch.double, torch.half) |
| @onlyCUDA |
| def test_device_dtype_transfer_maintains_offsets(self, device, dtype): |
| for tensor_list in self._get_example_tensor_lists(): |
| orig_device = torch.device("cpu") |
| orig_dtype = torch.float32 |
| nt = torch.nested.nested_tensor( |
| tensor_list, |
| layout=torch.jagged, |
| device=orig_device, |
| dtype=orig_dtype) |
| |
| self.assertEqual(torch.int64, nt.offsets().dtype) |
| nt = nt.to(device=device).to(dtype=dtype) |
| |
| # offsets should still be int64 on the original device |
| self.assertEqual(orig_device, nt.offsets().device) |
| self.assertEqual(torch.int64, nt.offsets().dtype) |
| |
| def test_unbind(self, device): |
| for tensor_list in self._get_example_tensor_lists(): |
| nt = torch.nested.nested_tensor( |
| tensor_list, |
| layout=torch.jagged, |
| device=device) |
| out = nt.unbind() |
| self.assertEqual(len(out), len(tensor_list)) |
| for i, t in enumerate(out): |
| self.assertEqual(t, tensor_list[i]) |
| |
| @xfailIfTorchDynamo |
| def test_layer_norm_2(self, device): |
| test_tensor_list = self._get_list_for_jagged_tensor( |
| ((2, 3, 4), 3), device=device, requires_grad=True |
| ) |
| bias = torch.randn(3, requires_grad=False, dtype=torch.float64, device=device) |
| |
| def grad_test_func(a, b, c, bias): |
| nt, _ = jagged_from_list([a, b, c], None) |
| out = torch.nn.functional.layer_norm(nt, (nt.shape[-1],), bias=bias) |
| return buffer_from_jagged(out) |
| |
| gradcheck( |
| grad_test_func, inputs=(*test_tensor_list, bias), check_batched_grad=False |
| ) |
| |
| with self.assertRaisesRegex( |
| RuntimeError, |
| r"layer_norm\(\): normalizing over ragged dim not supported for nested tensors", |
| ): |
| nt, _ = jagged_from_list(test_tensor_list, None) |
| _ = torch.nn.functional.layer_norm(nt, (nt.shape[-2], nt.shape[-1])) |
| |
| def test_narrow(self, device): |
| starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64) |
| lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64) |
| nt = torch.nested.narrow( |
| torch.arange(0, 10, device=device, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone().detach(), |
| 1, |
| starts, |
| lengths, |
| layout=torch.jagged |
| ) |
| |
| # TODO: Use this approach once unbind is functional
| # unbound_nt = nt.unbind()
| # for i in range(starts.shape[0]):
| #     self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbound_nt[i])
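| # For now, check each row directly against the jagged buffer: row i lives at
| # values()[offsets()[i] : offsets()[i] + lengths()[i]].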
| for i in range(starts.shape[0]): |
| self.assertEqual( |
| torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), |
| nt.values()[nt.offsets()[i]:(nt.offsets()[i] + nt.lengths()[i])] |
| ) |
| |
| def test_is_contiguous(self, device): |
| a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device) |
| nt_contiguous, _ = jagged_from_list([a, b, c], None) |
| |
| starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64) |
| lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64) |
| narrow_base = torch.arange(0, 10, device=device, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone() |
| nt_noncontiguous = torch.nested.narrow( |
| narrow_base, |
| 1, |
| starts_nc, |
| lengths_nc, |
| layout=torch.jagged |
| ) |
| |
| starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64) |
| lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64) |
| nt_contiguous_narrow = torch.nested.narrow( |
| narrow_base, |
| 1, |
| starts_c, |
| lengths_c, |
| layout=torch.jagged |
| ) |
| |
| # Test contiguous case |
| assert nt_contiguous.is_contiguous() |
| |
| # Test narrow case |
| assert not nt_noncontiguous.is_contiguous() |
| assert nt_contiguous_narrow.is_contiguous() |
| |
| # Test querying by memory_format |
| self.assertTrue(nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)) |
| self.assertFalse(nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format))
| self.assertTrue(nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format)) |
| |
| def test_noncontiguous_pointwise(self, device): |
| a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device) |
| b = torch.randn(3, 3, 4, requires_grad=True, dtype=torch.float64, device=device) |
| c = torch.randn(4, 3, 4, requires_grad=True, dtype=torch.float64, device=device) |
| nt, _ = jagged_from_list([a, b, c], None) |
| # transpose ragged dim |
| transposed = nt.transpose(1, 2) |
| self.assertFalse(transposed.is_contiguous()) |
| clone = transposed.clone() |
| |
| def check_nt_equality(x, y): |
| self.assertEqual(x.values(), y.values()) |
| self.assertEqual(x.offsets(), y.offsets()) |
| self.assertEqual(x._ragged_idx, y._ragged_idx) |
| self.assertEqual(x.shape, y.shape) |
| |
| self.assertFalse(clone.is_contiguous()) |
| check_nt_equality(clone, transposed) |
| |
| clone_contig = transposed.clone(memory_format=torch.contiguous_format) |
| self.assertTrue(clone_contig.is_contiguous()) |
| check_nt_equality(clone_contig, transposed) |
| |
| detached = transposed.detach() |
| self.assertFalse(detached.is_contiguous())
| check_nt_equality(detached, transposed) |
| |
| def test_to_copy(self, device): |
| nt, _ = jagged_from_list( |
| [torch.randn(i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device) for i in range(3)], None |
| ) |
| |
| nt_copy_dtype = torch.ops.aten._to_copy(nt, dtype=torch.float16) |
| self.assertEqual(torch.float16, nt_copy_dtype.dtype) |
| |
| nt_t = nt.transpose(1, 2) |
| nt_t_copy_dtype = torch.ops.aten._to_copy(nt_t, dtype=torch.float16) |
| self.assertEqual(torch.float16, nt_t_copy_dtype.dtype) |
| |
| def test_is_same_size(self, device): |
| def get_3_tensors(): |
| return [torch.randn(i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device) for i in range(3)] |
| |
| nt1, offsets1 = jagged_from_list(get_3_tensors(), None) |
| nt2, offsets1 = jagged_from_list(get_3_tensors(), offsets1) |
| |
| nt3, offsets2 = jagged_from_list(get_3_tensors(), None) |
| nt4, offsets2 = jagged_from_list(get_3_tensors(), offsets2) |
| |
| def check_size(nt1, nt2, nt3, nt4): |
| self.assertTrue(torch.ops.aten.is_same_size(nt1, nt2)) |
| self.assertTrue(torch.ops.aten.is_same_size(nt3, nt4)) |
| self.assertFalse(torch.ops.aten.is_same_size(nt1, nt3)) |
| |
| check_size(nt1, nt2, nt3, nt4) |
| |
| nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4)) |
| check_size(nt1_t, nt2_t, nt3_t, nt4_t) |
| |
| # Note 1: Math fallback doesn't work with bfloat16 on CUDA |
| # Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT |
| @xfailIfTorchDynamo |
| @unittest.skipIf( |
| TEST_WITH_ROCM, |
| "ROCm doesn't support flash attention or mem_efficient attention for NT", |
| ) |
| @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if |
| SM80OrLater else [torch.float16, torch.float32]) |
| def test_sdpa(self, device, dtype): |
| batch_size = 1 |
| emb_dims = 128 |
| n_heads = 8 |
| head_dims = emb_dims // n_heads |
| |
| sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) |
| sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) |
| |
| query = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) |
| key = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) |
| value = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) |
| |
| # Simplest case: 1 sentence, no batching |
| x_d1 = sen1.unsqueeze(0) |
| x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged) |
| |
| # See note below for why we detach here. |
| q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) |
| q_d1_t = q_d1.transpose(1, 2) |
| k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) |
| k_d1_t = k_d1.transpose(1, 2) |
| v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) |
| v_d1_t = v_d1.transpose(1, 2) |
| |
| q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) |
| q_nt_t = q_nt.transpose(1, 2) |
| k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) |
| k_nt_t = k_nt.transpose(1, 2) |
| v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) |
| v_nt_t = v_nt.transpose(1, 2) |
| |
| # High Precision Math Reference |
| q_d1_f32 = q_d1.to(torch.float32) |
| k_d1_f32 = k_d1.to(torch.float32) |
| v_d1_f32 = v_d1.to(torch.float32) |
| q_d1_f32_t = q_d1_f32.transpose(1, 2) |
| k_d1_f32_t = k_d1_f32.transpose(1, 2) |
| v_d1_f32_t = v_d1_f32.transpose(1, 2) |
| out_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_f32_t, k_d1_f32_t, v_d1_f32_t)[0] |
| grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32)) |
| |
| # Low Precision Math Reference |
| out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_t, k_d1_t, v_d1_t)[0] |
| grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1)) |
| |
| # Compute tolerances |
| output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) |
| grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(grads_ref[0], grads_lp_ref[0]) |
| grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(grads_ref[1], grads_lp_ref[1]) |
| grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(grads_ref[2], grads_lp_ref[2]) |
| grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol] |
| grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol] |
| |
| attn_d1 = torch.nn.functional.scaled_dot_product_attention(q_d1_t, k_d1_t, v_d1_t).transpose(1, 2) |
| attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt_t, k_nt_t, v_nt_t).transpose(1, 2) |
| |
| self.assertEqual(attn_d1, attn_nt.unbind()[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) |
| |
| # Simple case: 2 sentences, no extra params |
| x_d2 = sen2.unsqueeze(0) |
| x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged) |
| |
| # NB: we make sure the leaf tensor we compute gradients for is the viewed tensor
| # before it is transposed, because today we cannot backward through a view or
| # unbind of a transposed tensor.
| q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) |
| q_d2_t = q_d2.transpose(1, 2) |
| k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) |
| k_d2_t = k_d2.transpose(1, 2) |
| v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True) |
| v_d2_t = v_d2.transpose(1, 2) |
| |
| q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) |
| q_nt_t = q_nt.transpose(1, 2) |
| k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) |
| k_nt_t = k_nt.transpose(1, 2) |
| v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True) |
| v_nt_t = v_nt.transpose(1, 2) |
| |
| attn_d2 = torch.nn.functional.scaled_dot_product_attention(q_d2_t, k_d2_t, v_d2_t).transpose(1, 2) |
| d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1)) |
| d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2)) |
| |
| def check_forward_backward(): |
| attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt_t, k_nt_t, v_nt_t).transpose(1, 2) |
| |
| attn_nts = attn_nt.unbind() |
| self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) |
| self.assertEqual(attn_d2, attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) |
| |
| nt_grads = torch.autograd.grad(buffer_from_jagged(attn_nt).sum(), (q_nt, k_nt, v_nt)) |
| for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip(nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols): |
| unbound_nt_grads = nt_grad.unbind() |
| self.assertEqual(d1_grad, unbound_nt_grads[0].unsqueeze(0), atol=grad_atol, rtol=grad_rtol) |
| self.assertEqual(d2_grad, unbound_nt_grads[1].unsqueeze(0), atol=grad_atol, rtol=grad_rtol) |
| |
| # Default |
| check_forward_backward() |
| |
| # Test that the dispatcher works by enabling only mem-efficient and math (safe for all devices)
| with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=True): |
| check_forward_backward() |
| |
| # Test math fallback |
| with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): |
| # Math fallback doesn't work with bfloat16 on CUDA because |
| # "group_gemm_dispatch" not implemented for 'BFloat16' |
| if not (str(device).startswith("cuda") and dtype == torch.bfloat16): |
| check_forward_backward() |
| |
| # This requires NT -> NT views to work in inductor, which is a TODO |
| @unittest.expectedFailure # noqa: E301 |
| @onlyCUDA |
| @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if |
| SM80OrLater else [torch.float16, torch.float32]) |
| def test_sdpa_compile(self, device, dtype): |
| batch_size = 1 |
| emb_dims = 1024 |
| n_heads = 8 |
| head_dims = emb_dims // n_heads |
| |
| sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device) |
| sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device) |
| |
| query = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) |
| key = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) |
| value = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype) |
| |
| # Simplest case: 1 sentence, no batching |
| x_d1 = sen1.unsqueeze(0) |
| x_d2 = sen2.unsqueeze(0) |
| x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged) |
| |
| q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) |
| k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) |
| v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) |
| q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) |
| k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) |
| v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2) |
| |
| q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2) |
| k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2) |
| v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2) |
| |
| # High Precision Math Reference |
| q_d1_f32 = q_d1.to(torch.float32) |
| k_d1_f32 = k_d1.to(torch.float32) |
| v_d1_f32 = v_d1.to(torch.float32) |
| out_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_f32, k_d1_f32, v_d1_f32)[0] |
| # Low Precision Math Reference |
| out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1, k_d1, v_d1)[0] |
| output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref) |
| |
| attn_d1 = torch.nn.functional.scaled_dot_product_attention(q_d1, k_d1, v_d1).transpose(1, 2) |
| attn_d2 = torch.nn.functional.scaled_dot_product_attention(q_d2, k_d2, v_d2).transpose(1, 2) |
| |
| compiled_sdpa = torch.compile(torch.nn.functional.scaled_dot_product_attention) |
| attn_nt = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2) |
| |
| attn_nts = attn_nt.unbind() |
| self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) |
| self.assertEqual(attn_d2, attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol) |
| |
| @dtypes(torch.float32, torch.double, torch.half) |
| def test_sdpa_with_constant_sequence_length(self, device, dtype): |
| # shape (B, P*, S, D) |
| # B: batch size |
| # P*: ragged number of prompts |
| # S: (constant) sequence length |
| # D: embedding size |
| query = random_nt_from_dims( |
| [4, None, 8, 10], device=device, dtype=dtype, layout=torch.jagged) |
| key = random_nt_from_similar(query) |
| value = random_nt_from_similar(query) |
| output = F.scaled_dot_product_attention(query, key, value) |
| self.assertTrue(isinstance(output, NestedTensor)) |
| |
| # should be equivalent to just running the buffers through |
| output_dense = F.scaled_dot_product_attention(query._values, key._values, value._values) |
| self.assertEqual(output._values, output_dense) |
| |
| |
| instantiate_parametrized_tests(TestNestedTensor) |
| instantiate_device_type_tests(TestNestedTensorDeviceType, globals()) |
| instantiate_device_type_tests(TestNestedTensorAutograd, globals()) |
| instantiate_device_type_tests(TestNestedTensorSubclass, globals()) |
| |
| if __name__ == '__main__': |
| run_tests() |