# blob: a8069b05f45c95cdb5928fab5ab4be5bf2d67b6f [file] [log] [blame]
# Owner(s): ["module: nestedtensor"]
import torch
import torch.nn
import unittest
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCUDA,
instantiate_device_type_tests,
skipMeta,
)
from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state
from torch import nested_tensor
# Tests are ported from pytorch/nestedtensor.
# This makes porting as_nested_tensor easier in the future.
def _iter_constructors():
    """Yield each nested-tensor constructor the parametrized tests loop over."""
    # Only `nested_tensor` is produced today; `as_nested_tensor` stays
    # disabled until it has been ported, at which point it can simply be
    # appended to this tuple.
    yield from (nested_tensor,)
class TestNestedTensor(TestCase):
    """Device-independent tests for building and inspecting nested tensors."""

    @torch.inference_mode()
    def _test_unbind_case(self, a, b):
        """Shared body of the ``test_unbind_*`` tests.

        Builds nested tensors from ``a`` and ``b`` and checks that ``unbind``
        hands back new tensor objects whose values match the inputs.
        """
        nt = nested_tensor([a, b])
        a1, b1 = nt.unbind()
        # The unbound pieces must not be the very objects that were passed in.
        self.assertIsNot(a, a1)
        self.assertIsNot(b, b1)

        nt = nested_tensor([a, b], dtype=a.dtype)
        a1, b1 = nt.unbind(0)
        self.assertEqual(a, a1)
        self.assertEqual(b, b1)

        # A tensor that was modified in place before being nested.
        a = torch.randn((2, 3)).add_(1)
        nt = nested_tensor([a])
        self.assertEqual(a, nt.unbind(0)[0])

    @torch.inference_mode()
    def test_unbind_0(self):
        self._test_unbind_case(
            torch.tensor([1, 2]), torch.tensor([7, 8]),
        )

    @torch.inference_mode()
    def test_unbind_1(self):
        self._test_unbind_case(
            torch.tensor([1]), torch.tensor([7]),
        )

    # @torch.inference_mode()
    # def test_unbind_2(self):
    #     self._test_unbind_case(
    #         torch.tensor(1), torch.tensor(7),
    #     )

    @torch.inference_mode()
    def test_unbind_3(self):
        self._test_unbind_case(
            torch.tensor([1.0]), torch.tensor([]),
        )

    @torch.inference_mode()
    def test_unbind_4(self):
        self._test_unbind_case(
            torch.tensor([]), torch.tensor([]),
        )

    @torch.inference_mode()
    def test_unbind_dim(self):
        # Unbinding along a dimension other than 0 is expected to raise.
        def _test_fn(unbind_fn):
            a = torch.rand(3, 2)
            b = torch.rand(2, 3)
            nt = nested_tensor([a, b])
            self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))

        # Both of these tests are necessary, because we're using
        # torch_function.
        _test_fn(lambda x, dim: x.unbind(dim))
        # TODO: Re-enable this once using torch_dispatch
        # _test_fn(lambda x, dim: torch.unbind(x, dim))

    @torch.inference_mode()
    def test_nested_tensor(self):
        # The constructor only accepts a list of tensors.
        self.assertRaises(TypeError, lambda: nested_tensor([3.0]))
        self.assertRaises(TypeError, lambda: nested_tensor(torch.tensor([3.0])))
        self.assertRaises(TypeError, lambda: nested_tensor(4.0))

    @torch.inference_mode()
    def test_nested_tensor_matching_dim(self):
        # All constituent tensors must share the same number of dimensions.
        self.assertRaisesRegex(
            RuntimeError,
            "Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.",
            lambda: nested_tensor([torch.tensor(1.0), torch.tensor([])]),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.",
            lambda: nested_tensor(
                [torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])]
            ),
        )

    @torch.inference_mode()
    def test_default_nested_tensor(self):
        # An empty nested tensor should report the same basic properties as
        # an empty regular tensor.
        self.assertRaises(TypeError, lambda: nested_tensor())
        default_nested_tensor = nested_tensor([])
        default_tensor = torch.tensor([])
        # self.assertEqual(default_nested_tensor.nested_dim(), 1)
        # self.assertEqual(default_nested_tensor.nested_size(), ())
        self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
        self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
        self.assertEqual(default_nested_tensor.device, default_tensor.device)
        self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
        self.assertEqual(
            default_nested_tensor.requires_grad, default_tensor.requires_grad
        )
        self.assertIsNone(default_tensor.grad)
        # Also check the nested tensor itself; asserting only on the regular
        # tensor says nothing about the object under test.
        self.assertIsNone(default_nested_tensor.grad)
        # TODO: Re-enable once we have a performance driven
        # use case and implementation.
        # self.assertEqual(default_nested_tensor.is_pinned(),
        #                  default_tensor.is_pinned())

    @torch.inference_mode()
    def test_dim(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertEqual(a1.dim(), 1)
            a1 = constructor([torch.tensor(3.0)])
            self.assertEqual(a1.dim(), 1)
            a1 = constructor([torch.tensor([1, 2, 3, 4])])
            self.assertEqual(a1.dim(), 2)

    @unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.")
    @torch.inference_mode()
    def test_numel(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertEqual(a1.numel(), 0)
            a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)])
            self.assertEqual(a1.numel(), 2)
            a1 = constructor([torch.randn(2, 2, 2)])
            self.assertEqual(a1.numel(), 8)
            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)])
            self.assertEqual(a1.numel(), 12)
            a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)])
            self.assertEqual(a1.numel(), 27)
            a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)])
            self.assertEqual(a1.numel(), 341)

            # Interesting edge case
            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)])
            self.assertEqual(a1.numel(), 6)

    @torch.inference_mode()
    def test_size(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            # The expected message differs between fbcode and OSS builds.
            self.assertRaisesRegex(
                RuntimeError,
                "Tensors of type NestedTensorImpl do not have sym sizes"
                if IS_FBCODE
                else "NestedTensorImpl doesn't support sizes",
                lambda: a1.size(),
            )

    @unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.")
    @torch.inference_mode()
    def test_stride(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertRaisesRegex(
                RuntimeError,
                "NestedTensorImpl doesn't support strides",
                lambda: a1.stride(),
            )

    @unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.")
    @torch.inference_mode()
    def test_is_contiguous(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertRaisesRegex(
                RuntimeError, "is_contiguous is disabled", lambda: a1.is_contiguous()
            )

    @torch.inference_mode()
    def test_repr_string(self):
        # str() and repr() must agree and list every constituent tensor.
        a = nested_tensor([])
        expected = "nested_tensor([" "\n\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

        a = nested_tensor([torch.tensor(1.0)])
        expected = "nested_tensor([" "\n tensor(1.)" "\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

        a = nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
        expected = (
            "nested_tensor([" "\n tensor([[1, 2]])" "," "\n tensor([[4, 5]])" "\n])"
        )
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

    @torch.inference_mode()
    def test_activations(self):
        # Each activation must keep the result nested and match the dense op.
        for func in (torch.nn.functional.relu, torch.nn.functional.relu_, torch.nn.functional.gelu, torch._C._nn.gelu_):
            t = torch.tensor([-1, 0, 1], dtype=torch.float)
            nt = nested_tensor([t])
            nested_result = func(nt)
            self.assertTrue(nested_result.is_nested)
            self.assertEqual(func(t), nested_result.unbind()[0])

    def test_to_padded_tensor_on_empty_tensor(self):
        nt = torch.nested_tensor([])
        empty = nt.to_padded_tensor(4)
        self.assertEqual(empty, torch.tensor([]))
class TestNestedTensorDeviceType(TestCase):
    """Device/dtype-generic nested tensor tests.

    This class is passed to ``instantiate_device_type_tests`` at the bottom
    of the file; the test methods take ``device`` (and, when decorated with
    ``@dtypes``, ``dtype``) arguments.
    """

    @dtypes(torch.float)
    @skipMeta
    def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype):
        # Round trip: nested -> padded -> nested must reproduce the input.
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        ts = list(torch.unbind(t))
        ts[0] = ts[0][:-1]
        nt = torch.nested_tensor(ts, device=device, dtype=dtype)
        padded = nt.to_padded_tensor(0)

        nt_to = torch._nested_from_padded_and_nested_example(padded, nt)

        for (t1, t2) in zip(nt.unbind(), nt_to.unbind()):
            self.assertEqual(t1, t2)
        self.assertEqual(nt.device, nt_to.device)

    @dtypes(torch.float)
    @dtypesIfCUDA(torch.float, torch.half)
    @skipMeta
    @torch.inference_mode()
    def test_layer_norm(self, device, dtype):
        # Compare the nested layer norm against LayerNorm applied per tensor.
        def _test(size):
            t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
            t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
            ts = [t0, t1, t0, t1]
            nt = torch.nested_tensor(ts, device=device, dtype=dtype)
            layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
            nt_result = nt._nested_tensor_layer_norm(
                layer_norm.weight, layer_norm.bias, 1e-5
            )
            for (nt_subresult, t) in zip(nt_result.unbind(), ts):
                t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
                self.assertEqual(nt_subresult, t_result)

        for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32):
            _test(size)

    @skipMeta
    @torch.inference_mode()
    def test_embedding(self, device):
        inputs = [
            torch.randint(100, (length,), device=device, dtype=torch.int64)
            for length in torch.randint(5, 50, (8,))
        ]
        x = torch.nested_tensor(inputs, device=device, dtype=torch.int64)
        emb = torch.nn.Embedding(100, 8, device=device)
        y = emb(x)
        ys = y.unbind()
        # Guard the pairing below: zip would silently drop unmatched items.
        self.assertEqual(len(inputs), len(ys))
        for inp, nested_out in zip(inputs, ys):
            self.assertEqual(emb(inp), nested_out)

    @dtypes(torch.float, torch.float16)
    def test_to_padded_tensor_simple(self, device, dtype):
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        ts = list(torch.unbind(t))
        # Drop the last row of the first tensor so it needs padding.
        ts[0] = ts[0][:-1]
        nt = torch.nested_tensor(ts, device=device, dtype=dtype)
        for padding_value in (0, 1):
            padded = nt.to_padded_tensor(padding_value)

            correct_output = t.clone()
            # The dropped row is the only place padding should show up.
            correct_output[0][-1] = torch.full_like(correct_output[0][-1], padding_value)

            self.assertEqual(padded, correct_output)
            self.assertEqual(padded.device, torch.device(device))
            self.assertEqual(padded.dtype, dtype)

    @dtypes(torch.float, torch.float16)
    def test_to_padded_tensor_output_size(self, device, dtype):
        t = torch.randn(4, 4, 4, device=device, dtype=dtype)
        output_size = (4, 6, 5)
        ts = list(torch.unbind(t))
        ts[0] = ts[0][:-1]
        nt = torch.nested_tensor(ts, device=device, dtype=dtype)
        for padding_value in (0, 1):
            padded = nt.to_padded_tensor(padding_value, output_size=output_size)
            # Start from an all-padding tensor and paste the real data in.
            correct_output = torch.full(output_size, padding_value, device=device, dtype=dtype)
            correct_output[:4, :4, :4] = t.clone()
            # Row 3 of the first tensor was dropped above, so it is padding.
            correct_output[0][3] = torch.full_like(correct_output[0][3], padding_value)

            self.assertEqual(padded, correct_output)
            self.assertEqual(padded.device, torch.device(device))
            self.assertEqual(padded.dtype, dtype)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_dim2(self, device, dtype):
        ts = [
            torch.randn(160, device=device, dtype=dtype),
            torch.randn(1240, device=device, dtype=dtype),
            torch.randn(2400, device=device, dtype=dtype),
        ]
        nt = torch.nested_tensor(ts, device=device, dtype=dtype)
        pad = 42
        correct_output = []
        for t in ts:
            # ts[2] is the largest tensor, so it defines the padded shape.
            next_output = torch.full_like(ts[2], pad)
            correct_output.append(next_output)
            next_output[:t.size(0)].copy_(t)
        correct_output = torch.stack(correct_output)
        padded = nt.to_padded_tensor(pad)
        self.assertEqual(padded, correct_output)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_dim3(self, device, dtype):
        ts = [
            torch.randn(16, 21, device=device, dtype=dtype),
            torch.randn(24, 32, device=device, dtype=dtype),
            torch.randn(40, 53, device=device, dtype=dtype),
        ]
        nt = torch.nested_tensor(ts, device=device, dtype=dtype)
        pad = 42
        correct_output = []
        for t in ts:
            # ts[2] is the largest tensor, so it defines the padded shape.
            next_output = torch.full_like(ts[2], pad)
            correct_output.append(next_output)
            next_output[:t.size(0), :t.size(1)].copy_(t)
        correct_output = torch.stack(correct_output)
        padded = nt.to_padded_tensor(pad)
        self.assertEqual(padded, correct_output)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_to_padded_tensor_dim4(self, device, dtype):
        ts = [
            torch.randn(16, 21, 13, device=device, dtype=dtype),
            torch.randn(24, 32, 14, device=device, dtype=dtype),
            torch.randn(40, 53, 16, device=device, dtype=dtype),
        ]
        nt = torch.nested_tensor(ts, device=device, dtype=dtype)
        pad = 42
        correct_output = []
        for t in ts:
            # ts[2] is the largest tensor, so it defines the padded shape.
            next_output = torch.full_like(ts[2], pad)
            correct_output.append(next_output)
            next_output[:t.size(0), :t.size(1), :t.size(2)].copy_(t)
        correct_output = torch.stack(correct_output)
        padded = nt.to_padded_tensor(pad)
        self.assertEqual(padded, correct_output)

    @skipMeta
    def test_device_checks(self, device):
        nt = torch.nested_tensor([], device=device)
        is_cuda = 'cuda' in str(device)
        self.assertEqual(nt.is_cuda, is_cuda)

    @dtypes(torch.float, torch.float16, torch.double)
    def test_nested_tensor_indexing(self, device, dtype):
        # edge case: empty nested tensor
        nt0 = torch.nested_tensor([])
        self.assertRaises(IndexError, lambda: nt0[0])
        # normal case
        x0 = torch.randn((2, 5), device=device, dtype=dtype)
        x1 = torch.randn((3, 4), device=device, dtype=dtype)
        nt = torch.nested_tensor([x0, x1])
        # single index: only support integer in the batch dimension
        self.assertEqual(nt[0], x0)
        self.assertEqual(nt[-1], x1)
        self.assertRaises(IndexError, lambda: nt[2])
        self.assertRaises(IndexError, lambda: nt[-3])
        self.assertRaises(NotImplementedError, lambda: nt[:])
        self.assertRaises(NotImplementedError, lambda: nt[None])
        self.assertRaises(NotImplementedError, lambda: nt[...])
        # tuple of indices: only support integer in the batch dimension
        #                 + all possible indexing in the original tensor dimensions
        self.assertEqual(nt[0, 0, 0], x0[0, 0])
        self.assertEqual(nt[0, 1, :], x0[1, :])
        self.assertEqual(nt[1, ...], x1)
        self.assertRaises(IndexError, lambda: nt[1, 4, 2])
        self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1])
        # make sure indexing returns a view
        nt[0].fill_(100.0)
        answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5))
        self.assertEqual(nt[0], answer)
        nt[1, 1, :].fill_(200.0)
        answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4)
        self.assertEqual(nt[1, 1, :], answer)

    # Helper functions for testing elementwise ops
    def random_nt(self, device, dtype, num_tensors, max_dims, min_dims=None):
        """Build a nested tensor of ``num_tensors`` random tensors.

        Each dimension size is drawn with ``torch.randint`` from
        ``[min_dim, max_dim)``; ``min_dims`` defaults to all zeros.
        """
        if min_dims is None:
            min_dims = (0,) * len(max_dims)
        ts1 = []
        for _ in range(num_tensors):
            tensor_dims = tuple(
                torch.randint(low=min_dim, high=max_dim, size=(1,)).item()
                for (min_dim, max_dim) in zip(min_dims, max_dims)
            )
            ts1.append(torch.randn(tensor_dims, device=device, dtype=dtype))
        return torch.nested_tensor(ts1, device=device, dtype=dtype)

    def random_nt_pair(self, device, dtype, num_tensors, max_dims):
        """Build two random nested tensors whose constituents share shapes.

        Each dimension size is drawn with ``torch.randint`` from
        ``[0, max_dim)`` and used for the matching tensor in both results.
        """
        ts1 = []
        ts2 = []
        for _ in range(num_tensors):
            tensor_dims = tuple(
                torch.randint(low=0, high=max_dim, size=(1,)).item()
                for max_dim in max_dims
            )
            t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
            t2 = torch.randn(tensor_dims, device=device, dtype=dtype)
            ts1.append(t1)
            ts2.append(t2)
        return (torch.nested_tensor(ts1, device=device, dtype=dtype),
                torch.nested_tensor(ts2, device=device, dtype=dtype))

    def nt_equal(self, nt1, nt2):
        """Assert two nested tensors match in dtype, device, length and values."""
        self.assertEqual(nt1.dtype, nt2.dtype)
        self.assertEqual(nt1.device, nt2.device)
        ub1 = nt1.unbind()
        ub2 = nt2.unbind()
        # Lengths are checked first so the pairwise zip cannot hide a mismatch.
        self.assertEqual(len(ub1), len(ub2))
        for t1, t2 in zip(ub1, ub2):
            self.assertEqual(t1, t2)

    def _apply_dropout_expectation(self, expect, actual, ntensors, p):
        """Turn ``expect`` (a clone of the dropout input) into the expected output.

        Modifies ``expect`` in place: positions where ``actual`` is exactly
        zero are zeroed, every other position is divided by ``1 - p``.
        """
        for i in range(ntensors):
            actual_tensor = actual[i].view(-1)
            expect_tensor = expect[i].view(-1)
            for j in range(actual_tensor.shape[0]):
                if actual_tensor[j].item() == 0.0:
                    expect_tensor[j] = 0.0
                else:
                    expect_tensor[j] /= 1.0 - p

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_add(self, device, dtype):
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
        out = nt1 + nt2
        self.nt_equal(ref, out)

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_mul(self, device, dtype):
        # nested tensor * nested tensor
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
        out = nt1 * nt2
        self.nt_equal(ref, out)
        # nested tensor * scalar
        number = 10.0
        scalar = torch.tensor(number).to(dtype).to(device)
        ref = torch.nested_tensor([t * number for t in nt1.unbind()])
        out_number0 = nt1 * number
        out_number1 = number * nt1
        out_scalar0 = nt1 * scalar
        out_scalar1 = scalar * nt1
        self.nt_equal(out_number0, ref)
        self.nt_equal(out_number1, ref)
        self.nt_equal(out_scalar0, ref)
        self.nt_equal(out_scalar1, ref)
        # error case: numel == 1 but dim > 0
        vector = torch.tensor([number]).to(dtype).to(device)
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a nested self and non-nested other",
            lambda: nt1.mul(vector)
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a non-nested self and nested other",
            lambda: vector.mul(nt1)
        )

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_add_in_place(self, device, dtype):
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
        nt1 += nt2
        self.nt_equal(ref, nt1)

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_mul_in_place(self, device, dtype):
        # nested tensor * nested tensor
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
        nt1 *= nt2
        self.nt_equal(ref, nt1)
        # nested tensor * scalar
        number = 10.0
        scalar = torch.tensor(number).to(dtype).to(device)
        ref = torch.nested_tensor([t * number for t in nt1.unbind()])
        out_number = nt1.clone()
        out_number *= number
        out_scalar = nt1.clone()
        out_scalar *= scalar
        self.nt_equal(out_number, ref)
        self.nt_equal(out_scalar, ref)
        self.assertRaisesRegex(
            RuntimeError,
            r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]",
            lambda: scalar.mul_(nt1)
        )
        # error case: numel == 1 but dim > 0
        vector = torch.tensor([number]).to(dtype).to(device)
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a nested self and non-nested other",
            lambda: nt1.mul_(vector)
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a non-nested self and nested other",
            lambda: vector.mul_(nt1)
        )

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_clone(self, device, dtype):
        nt1 = self.random_nt(device, dtype, 4, (4, 4), (1, 1))
        nt2 = nt1.clone()
        # Verify the values match
        self.nt_equal(nt1, nt2)
        # Verify modifying nt2 doesn't affect nt1
        nt2.mul_(nt1)
        for t1, t2 in zip(nt1.unbind(), nt2.unbind()):
            self.assertNotEqual(t1, t2)

        nt1.clone(memory_format=torch.preserve_format)
        msg = "clone_nested only supports memory format Preserve, but got ChannelsLast instead."
        with self.assertRaisesRegex(RuntimeError, msg):
            nt1.clone(memory_format=torch.channels_last)

    # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    @torch.inference_mode()
    def test_dropout(self, device, dtype):
        # edge case: empty nested tensor
        nt0 = torch.nested_tensor([])
        y = torch.nn.functional.dropout(nt0, 0.5)
        self.nt_equal(nt0, y)
        # normal nested tensor
        ntensors = 4
        nt = self.random_nt(device, dtype, ntensors, (4, 4))
        # edge case: invalid dropout
        self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
        self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
        # edge case: no dropout
        dropouter = torch.nn.Dropout(0.0)
        y0 = dropouter(nt)
        y1 = torch.nn.functional.dropout(nt, 0.0)
        self.nt_equal(nt, y0)
        self.nt_equal(nt, y1)
        # edge case: all dropout
        dropouter = torch.nn.Dropout(1.0)
        y0 = dropouter(nt)
        y1 = torch.nn.functional.dropout(nt, 1.0)
        nt0 = nt.clone()
        for i in range(ntensors):
            nt0[i].fill_(0.0)
        self.nt_equal(nt0, y0)
        self.nt_equal(nt0, y1)
        # normal case: normal dropout
        p = 0.2
        y = torch.nn.functional.dropout(nt, p)
        expect = nt.clone()
        self._apply_dropout_expectation(expect, y, ntensors, p)
        self.nt_equal(y, expect)
        with freeze_rng_state():
            dropouter = torch.nn.Dropout(p)
            y0 = dropouter(nt)
        with freeze_rng_state():
            y1 = torch.nn.functional.dropout(nt, p)
        self.nt_equal(y0, y1)
        # inplace
        # in principle, since we have established the correctness of functional, we could simply compare inplace vs functional
        # in practice, cuda functional has its own implementation to skip `bernoulli_`
        # so cuda functional will differ from cuda inplace causing test failure
        # in `test_dropout_cuda_float64 (__main__.TestNestedTensorDeviceTypeCUDA)`
        # on `linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)`
        expect = nt.clone()
        torch.nn.functional.dropout(nt, p, inplace=True)
        self._apply_dropout_expectation(expect, nt, ntensors, p)
        self.nt_equal(nt, expect)

    # cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    @torch.inference_mode()
    def test_softmax(self, device, dtype):
        # normal nested tensor
        ntensors = 4
        nt = self.random_nt(device, dtype, ntensors, (4, 4))
        # error case: softmax across nested dimension
        self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt, 0))
        self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt, -3))
        # error case: dimension out of range
        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3))
        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4))
        # normal case: should equal to padding -inf
        softmaxer = torch.nn.Softmax(1)
        y0 = softmaxer(nt)
        y1 = torch.nn.functional.softmax(nt, 1)
        self.nt_equal(y0, y1)
        pt = nt.to_padded_tensor(float("-inf"))
        # if an entire slice is padded, then softmax will return 0.0 / 0.0 = nan
        # however, physically speaking that should be 0.0
        expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0)
        self.assertEqual(y0.to_padded_tensor(0.0), expect)
        # edge case: empty nested tensor
        nt0 = torch.nested_tensor([])
        y = torch.nn.functional.softmax(nt0, 1)
        self.nt_equal(nt0, y)
        # edge case: nesting scalars
        nt1 = torch.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)])
        self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0))
        self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1))

    @dtypes(torch.float, torch.float16, torch.double)
    @torch.inference_mode()
    def test_bmm(self, device, dtype):
        # NOTE(review): `device` and `dtype` are never forwarded to the
        # tensors built in this test, so every parametrization appears to
        # run on the default device/dtype - confirm whether that is intended.
        # error case: not 3D tensors
        nt0 = torch.nested_tensor([])
        nt1 = torch.nested_tensor([torch.randn(2), torch.randn(3)])
        nt2 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))])
        self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0))
        self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1))
        self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2))
        self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0))
        self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt1))
        self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2))
        self.assertRaisesRegex(RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0))
        self.assertRaisesRegex(RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1))
        # error case: incompatible batch size
        nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))])
        nt1 = torch.nested_tensor([torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))])
        self.assertRaisesRegex(
            RuntimeError,
            "Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.",
            lambda: nt0.bmm(nt1)
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.",
            lambda: nt1.bmm(nt0)
        )
        # error case: underlying matrices cannot be multiplied
        nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))])
        self.assertRaisesRegex(
            RuntimeError,
            r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)",
            lambda: nt0.bmm(nt0)
        )
        # normal nested tensor
        nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))])
        nt1 = torch.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))])
        actual = nt0.bmm(nt1)
        expect = nt0.to_padded_tensor(0.0).bmm(nt1.to_padded_tensor(0.0))
        self.assertEqual(actual.to_padded_tensor(0.0), expect)

    @dtypes(torch.float, torch.double)
    def test_linear(self, device, dtype):
        a = torch.randn(1, 2, device=device, dtype=dtype)
        b = torch.randn(2, 2, device=device, dtype=dtype)
        c = torch.randn(3, 2, device=device, dtype=dtype)
        nt = torch.nested_tensor([a, b, c])

        weight = torch.randn(2, 2, device=device, dtype=dtype)
        bias = torch.randn(2, device=device, dtype=dtype)
        # success case
        torch.nn.functional.linear(nt, weight, bias)

        # invalid nested tensor dimension
        msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2'
        nt1 = torch.nested_tensor([torch.randn(1, device=device, dtype=dtype),
                                   torch.randn(2, device=device, dtype=dtype)])
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.nn.functional.linear(nt1, weight, bias)

        # invalid weight shape
        msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3'
        weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype)
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.nn.functional.linear(nt, weight1, bias)

        # inconsistent last dim of nested tensor
        msg = r"all tensors in NestedTensor must have the same trailing dim"
        nt2 = torch.nested_tensor([torch.randn(1, 2, device=device, dtype=dtype),
                                   torch.randn(2, 3, device=device, dtype=dtype)])
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.nn.functional.linear(nt2, weight, bias)

        # Mismatch of nested tensor last dim and weight dimension
        weight2 = torch.randn(2, 4, device=device, dtype=dtype)
        msg = (
            r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'"
            r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4"
        )
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.nn.functional.linear(nt, weight2, bias)

        # Nested tensor input and nested weight
        nt_weight = nt.clone()
        msg = r"Linear does not support nested weight when input is a nested tensor."
        with self.assertRaisesRegex(RuntimeError, msg):
            torch.nn.functional.linear(nt, nt_weight, bias)
class TestNestedTensorAutograd(TestCase):
    """Autograd tests for nested tensors.

    Checks use ``self.assert*`` rather than bare ``assert`` statements so
    they still run when Python is started with ``-O``.
    """

    def nt_equal(self, nt1, nt2):
        """Assert two nested tensors match in dtype, device, length and values."""
        self.assertEqual(nt1.dtype, nt2.dtype)
        self.assertEqual(nt1.device, nt2.device)
        ub1 = nt1.unbind()
        ub2 = nt2.unbind()
        # Lengths are checked first so the pairwise zip cannot hide a mismatch.
        self.assertEqual(len(ub1), len(ub2))
        for t1, t2 in zip(ub1, ub2):
            self.assertEqual(t1, t2)

    def _create_nested_tensor_from_list(self, requires_grad=False):
        """Return a nested tensor built from a (1, 2) and a (7, 8) random tensor."""
        return torch.nested_tensor([torch.randn(1, 2, requires_grad=requires_grad),
                                    torch.randn(7, 8, requires_grad=requires_grad)])

    def _create_nested_tensor_from_mask(self, requires_grad=False):
        """Return a nested tensor built from (2, 3, 4) random data and an all-True mask."""
        data = torch.randn(2, 3, 4, requires_grad=requires_grad)
        mask = torch.ones_like(data[:, :, 0]).bool()
        return torch._nested_tensor_from_mask(data, mask)

    def test_set_requires_grad_from_list(self):
        nt = self._create_nested_tensor_from_list()
        nt.requires_grad_()
        self.assertTrue(nt.requires_grad)

    def test_set_requires_grad_from_mask(self):
        nt = self._create_nested_tensor_from_mask()
        nt.requires_grad_()
        self.assertTrue(nt.requires_grad)

    def test_backward_for_add_op(self):
        nt_1 = self._create_nested_tensor_from_mask()
        nt_2 = self._create_nested_tensor_from_mask()

        nt_1.requires_grad_()
        c = nt_1 + nt_2

        self.assertTrue(nt_1.requires_grad)
        self.assertTrue(c.requires_grad)
        grad_output = self._create_nested_tensor_from_mask()
        c.backward(grad_output)

        # Grad check doesn't work with nested yet.
        # d/dnt_1 (nt + nt_1) = 1*grad_output
        self.nt_equal(nt_1.grad, grad_output)

    # Test Factory Functions
    def test_nested_tensor_to_padded_tensor(self):
        for padding_val in [0, 1]:
            nt = torch.nested_tensor([torch.randn(1, 2), torch.randn(7, 8)])
            nt.requires_grad_()

            out = nt.to_padded_tensor(padding_val)
            grad_output = torch.ones(out.shape)
            out.backward(grad_output)

            self.nt_equal(nt.grad, torch.nested_tensor([torch.ones(1, 2), torch.ones(7, 8)]))

    def test_nested_tensor_from_mask_and_to_padded(self):
        N, L, D = 2, 4, 4
        mask = torch.ones(N, L)
        for i in range(1, N):
            end = torch.randint(1, L - 1, (1,))
            mask[i, end:] = 0

        mask[0, :] = 1
        mask = mask.bool()

        data = torch.randn(N, L, D, requires_grad=True, dtype=torch.float64)

        def grad_test_func(inpt):
            nt = torch._nested_tensor_from_mask(inpt, mask)
            # This implicitly tests to_padded_tensor grads
            return nt.to_padded_tensor(0)

        self.assertTrue(torch.autograd.gradcheck(grad_test_func, inputs=data))

    def test_nested_tensor_from_padded(self):
        nested_size = torch.tensor([[1, 2], [2, 2]])
        padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64)
        padded_tensor[0, 1, :] = 0
        padded_tensor.requires_grad_()

        def grad_test_func(tensor, nested_size):
            nt = torch._nested_from_padded(tensor, nested_size, fuse_transform_0213=False)
            # This implicitly tests to_padded_tensor grads
            return nt.to_padded_tensor(0)

        data = (padded_tensor, nested_size)
        self.assertTrue(torch.autograd.gradcheck(grad_test_func, inputs=data))

    def test_nested_tensor_from_padded_fused(self):
        nested_size = torch.tensor([[1, 8], [2, 8]])
        padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64)
        padded_tensor[0, 1, :] = 0
        padded_tensor.requires_grad_()

        def grad_test_func(tensor, nested_size):
            nt = torch._nested_from_padded(tensor, nested_size, fuse_transform_0213=True)
            # This implicitly tests to_padded_tensor grads
            return nt.to_padded_tensor(0)

        data = (padded_tensor, nested_size)
        self.assertTrue(torch.autograd.gradcheck(grad_test_func, inputs=data))

    def test_nested_tensor_from_list(self):
        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64)
        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64)
        c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64)

        def grad_test_func(a, b, c):
            # Use a separate name so the input `c` is not shadowed.
            nt = torch.nested_tensor([a, b, c])
            # This implicitly tests to_padded_tensor grads
            return nt.to_padded_tensor(0)

        data = (a, b, c)
        self.assertTrue(torch.autograd.gradcheck(grad_test_func, inputs=data))

    def test_size_dim(self):
        a = torch.nested_tensor([])
        self.assertEqual(a.size(0), 0)

        a = torch.nested_tensor([torch.tensor(1)])
        self.assertEqual(a.size(0), 1)

        a = torch.nested_tensor([torch.tensor(1), torch.tensor(2)])
        self.assertEqual(a.size(0), 2)

        a = torch.nested_tensor([torch.rand(1, 2),
                                 torch.rand(1, 8)])
        self.assertEqual(a.size(0), 2)
        self.assertEqual(a.size(1), 1)
        self.assertRaisesRegex(
            RuntimeError, "Given dimension 2 is irregular and does not have a size", lambda: a.size(2))

        a = torch.nested_tensor([torch.rand(3, 4),
                                 torch.rand(5, 4)])
        self.assertEqual(a.size(0), 2)
        self.assertRaisesRegex(
            RuntimeError, "Given dimension 1 is irregular and does not have a size", lambda: a.size(1))
        self.assertEqual(a.size(2), 4)

    def test_nested_tensor_linear(self):
        a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64)
        b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64)
        c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64)

        weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64)
        bias = torch.randn(2, requires_grad=True, dtype=torch.float64)

        def grad_test_func(a, b, c, weight, bias=None):
            nt = torch.nested_tensor([a, b, c])
            # This implicitly tests to_padded_tensor grads
            d = torch.nn.functional.linear(nt, weight, bias)
            return d.to_padded_tensor(0)

        data = (a, b, c, weight, bias)
        self.assertTrue(torch.autograd.gradcheck(grad_test_func, inputs=data))

        # Test linear with no bias added
        data = (a, b, c, weight)
        self.assertTrue(torch.autograd.gradcheck(grad_test_func, inputs=data))

    def test_nested_tensor_linear_backward(self):
        a = torch.randn(1, 2, requires_grad=False)
        b = torch.randn(2, 2, requires_grad=False)
        c = torch.randn(3, 2, requires_grad=False)

        weight = torch.randn(2, 2, requires_grad=True)
        bias = torch.randn(2, requires_grad=True)
        nt = torch.nested_tensor([a, b, c])

        out = torch.nn.functional.linear(nt, weight, bias)

        out.backward(out.clone())

        # Only the parameters were created with requires_grad=True.
        self.assertIsNotNone(weight.grad)
        self.assertIsNotNone(bias.grad)

        self.assertIsNone(a.grad)
        self.assertIsNone(b.grad)
        self.assertIsNone(c.grad)
# Register the device-generic test class with the device-type test
# machinery, handing it this module's namespace.
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())

# Run the suite only when this file is executed as a script.
if __name__ == '__main__':
    run_tests()