Add scatter support for view operations (#79939)
* Add scatter support for view operations; #78074, #78886, #79672
* Update test_slicing_replace_column to properly test different sizes
* Handle in-place changes for binary ops; add a new test case
* Add new view-op tests for scatter; add the MPSDebugConfig.h config file for debugging purposes
* Merge gatherViewTensor and scatterViewTensor into a generic function
* Add scatter on demand in scatterViewOperation instead of caching it into a generic graph
* Create separate graphs for scatter and gather
* Create the scatter graph at scatter time
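For reference: with this change, writes into a non-contiguous view go through the
scatter path, and reads of such a view go through the gather path. A minimal
illustrative sketch (assumes an MPS-capable build; not part of this patch):

    import torch

    base = torch.zeros(4, 4, device="mps")
    base[:, 0] = 1.0          # writing through a column view exercises scatter
    col = base[:, 0].clone()  # materializing the view exercises gather
    assert torch.equal(col.cpu(), torch.ones(4))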
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79939
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index 0bb7ffc..6264541 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -15,10 +15,12 @@
import itertools
from torch._six import inf
from torch.nn import Parameter
-from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN
+from torch.testing._internal.common_utils import run_tests, TestCase, download_file, TEST_WITH_UBSAN, gradcheck, gradgradcheck
+from torch.testing import make_tensor
from torch.testing._comparison import TensorLikePair
import torch.backends.mps
from torch.distributions import Uniform, Exponential
+from functools import partial
from torch.testing._internal.common_nn import NNTestCase
import numpy as np
@@ -4782,6 +4784,841 @@
m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)
+class TestGatherScatter(TestCase):
+ def test_slicing_with_step(self):
+ # Slicing with step
+ # https://github.com/pytorch/pytorch/issues/78886
+ x_mps = torch.zeros(10, dtype=torch.float32, device="mps")
+ x_mps[::2] = 1.0
+
+        x_cpu = torch.zeros(10, dtype=torch.float32, device="cpu")
+ x_cpu[::2] = 1.0
+
+ self.assertEqual(x_cpu, x_mps)
+
+ def test_slicing_replace_column(self):
+ # https://github.com/pytorch/pytorch/issues/78074
+ def _helper(tensor_data):
+ x_cpu = torch.tensor(tensor_data)
+ x_mps = x_cpu.to('mps')
+
+ x_cpu[:, 0] = 7
+ x_mps[:, 0] = 7
+
+ self.assertEqual(x_cpu, x_mps)
+
+ _helper([[1, 2, 3], [4, 5, 6]])
+ _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+ _helper([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
+
+ def test_inplace_scatter(self):
+ # https://github.com/pytorch/pytorch/issues/79672
+        a_mps = torch.ones((2, 2)).to(torch.device("mps"))
+        b_mps = torch.ones((2, 2)).to(torch.device("mps"))
+
+        a_cpu = torch.ones((2, 2)).to(torch.device("cpu"))
+        b_cpu = torch.ones((2, 2)).to(torch.device("cpu"))
+
+ a_mps[:, 0] += b_mps[:, 0]
+ a_cpu[:, 0] += b_cpu[:, 0]
+ self.assertEqual(a_cpu, a_mps)
+
+ a_mps[:, 0] = a_mps[:, 0] + b_mps[:, 0]
+ a_cpu[:, 0] = a_cpu[:, 0] + b_cpu[:, 0]
+ self.assertEqual(a_cpu, a_mps)
+
+class TestViewOpsMPS(TestCase):
+ exact_dtype = True
+
+ def is_view_of(self, base, other):
+ if (not other._is_view() or
+ other is base or
+ other._base is not base or
+ base.device != other.device):
+ return False
+ # Note: only validates storage on native device types
+ # because some accelerators, like XLA, do not expose storage
+ if base.device.type == 'cpu' or base.device.type == 'cuda':
+ if base.storage().data_ptr() != other.storage().data_ptr():
+ return False
+
+ return True
+
+ # Returns true if v1 and v2 are views of the same base
+ def is_view_of_same_base(self, v1, v2):
+ if (not v1._is_view() or v1 is v2):
+ return False
+ return self.is_view_of(v1._base, v2)
+
+    # Returns the input tensor as-is if contiguous=True; otherwise returns its transpose
+ def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1):
+ if contiguous:
+ return x
+ else:
+ return x.transpose(dim0, dim1)
+
+ def test_diagonal_view(self, device="mps"):
+ t = torch.ones((5, 5), device=device)
+ v = torch.diagonal(t)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0] = 0
+ self.assertEqual(t[0, 0], v[0])
+
+        t = torch.ones((3, 3, 3), device=device)
+ v = torch.diagonal(t, offset=1, dim1=1, dim2=2)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0, 0] = 0
+ self.assertEqual(t[0, 0, 1], v[0, 0])
+
+ def test_select_view(self, device="mps") -> None:
+ t = torch.ones((5, 5), device=device)
+ v = t.select(0, 2)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0] = 0
+ self.assertEqual(t[2, 0], v[0])
+
+ def test_unbind_view(self, device="mps") -> None:
+ t = torch.zeros((5, 5), device=device)
+ tup = torch.unbind(t)
+
+ for idx, v in enumerate(tup):
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0] = idx + 1
+ self.assertEqual(t[idx, 0], v[0])
+
+ def test_expand_view(self, device="mps") -> None:
+ t = torch.ones((5, 1), device=device)
+ v = t.expand(5, 5)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[2, 2] = 0
+ self.assertEqual(t[2, 0], v[2, 2])
+
+ def test_expand_as_view(self, device="mps"):
+ t = torch.ones((5, 1), device=device)
+ e = torch.empty((5, 5), device=device)
+ v = t.expand_as(e)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[2, 2] = 0
+ self.assertEqual(t[2, 0], v[2, 2])
+
+ def test_narrow_view(self, device="mps"):
+ t = torch.ones((5, 5), device=device)
+ v = torch.narrow(t, 1, 2, 2)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0, 0] = 0
+ self.assertEqual(t[0, 2], v[0, 0])
+
+ def test_permute_view(self, device="mps") -> None:
+ t = torch.ones((5, 5), device=device)
+ v = t.permute(1, 0)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0, 1] = 0
+ self.assertEqual(t[1, 0], v[0, 1])
+
+ def test_transpose_view(self, device="mps"):
+ for fn in (torch.swapdims, torch.swapaxes, torch.transpose):
+ t = torch.ones((5, 5), device=device)
+ v = fn(t, 0, 1)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0, 1] = 0
+ self.assertEqual(t[1, 0], v[0, 1])
+
+ def test_transpose_inplace_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.swapdims_(0, 1)
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 1] = 0
+ self.assertEqual(t[1, 0], v[0, 1])
+
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.swapaxes_(0, 1)
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 1] = 0
+ self.assertEqual(t[1, 0], v[0, 1])
+
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.transpose_(0, 1)
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 1] = 0
+ self.assertEqual(t[1, 0], v[0, 1])
+
+ def test_t_view(self, device="mps"):
+ t = torch.ones((5, 5), device=device)
+ v = t.t()
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0, 1] = 0
+ self.assertEqual(t[1, 0], v[0, 1])
+
+ def test_t_inplace_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.t_()
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 1] = 0
+ self.assertEqual(t[1, 0], v[0, 1])
+
+ def test_T_view(self, device="mps"):
+ for op in ("T", "H", "mT", "mH"):
+ t = torch.ones((5, 5), device=device)
+ v = getattr(t, op)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0, 1] = 0
+ self.assertEqual(t[1, 0], v[0, 1])
+
+ # requires aten::unfold
+ # def test_unfold_view(self, device="mps"):
+ # t = torch.ones(10, device=device)
+ # v = t.unfold(0, 3, 2)
+ # self.assertTrue(self.is_view_of(t, v))
+
+ # v[1, 0] = 0
+ # self.assertEqual(t[2], v[1, 0])
+
+ def test_squeeze_view(self, device="mps"):
+ t = torch.ones(5, 1, 5, device=device)
+ v = torch.squeeze(t)
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 1] = 0
+ self.assertEqual(t, v._base)
+
+ def test_squeeze_inplace_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.squeeze_()
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 1] = 0
+ self.assertEqual(t, v._base)
+
+ def test_unsqueeze_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ v = torch.unsqueeze(t, 1)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0, 0, 1] = 0
+ self.assertEqual(t[0, 1], v[0, 0, 1])
+
+ def test_unsqueeze_inplace_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.unsqueeze_(1)
+ self.assertTrue(self.is_view_of(t, v))
+ v[0, 0, 1] = 0
+ self.assertEqual(t[0, 1], v[0, 0, 1])
+
+ def test_as_strided_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ v = torch.as_strided(t, (25,), (1,))
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[6] = 0
+ self.assertEqual(t[1, 1], v[6])
+
+ def test_as_strided_inplace_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ v = t.view_as(t)
+ v = v.as_strided_((25,), (1,))
+ self.assertTrue(self.is_view_of(t, v))
+ v[6] = 0
+ self.assertEqual(t[1, 1], v[6])
+
+ def test_view_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ v = t.view(25)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[6] = 0
+ self.assertEqual(t[1, 1], v[6])
+
+ def test_view_as_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ e = torch.empty((25,))
+ v = t.view_as(e)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[6] = 0
+ self.assertEqual(t[1, 1], v[6])
+
+ def test_contiguous_self(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ s = t.contiguous()
+ self.assertTrue(s is t)
+
+ def test_contiguous_nonview(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ nv = t.t().contiguous()
+ self.assertTrue(not self.is_view_of(t, nv))
+
+ nv[0, 0] = 0
+ self.assertNotEqual(t[0, 0], nv[0, 0])
+
+ def test_reshape_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ v = torch.reshape(t, (25,))
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[6] = 0
+ self.assertEqual(t[1, 1], v[6])
+
+ def test_reshape_as_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ e = torch.empty((25,), device=device)
+ v = t.reshape_as(e)
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[6] = 0
+ self.assertEqual(t[1, 1], v[6])
+
+ def test_reshape_nonview(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ nv = torch.reshape(t.t(), (25,))
+ self.assertTrue(not self.is_view_of(t, nv))
+
+ nv[6] = 0
+ self.assertNotEqual(t[1, 1], nv[6])
+
+ def test_flatten_view(self, device="mps"):
+ def test_writes_propagate(t, v):
+ idx_t = (0,) * t.ndim
+ idx_v = (0,) * v.ndim
+ v[idx_v] = 0
+ self.assertEqual(t[idx_t], v[idx_v])
+
+ t = torch.ones(1, 2, 3, 4, device=device)
+ v = t.flatten()
+ self.assertTrue(self.is_view_of(t, v))
+ test_writes_propagate(t, v)
+
+ # zero-dimensional tensor
+ t = torch.tensor(1, device=device)
+ v = t.flatten()
+ test_writes_propagate(t, v)
+ self.assertTrue(self.is_view_of(t, v))
+
+ t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3)
+ v = t.flatten(0, 1)
+ test_writes_propagate(t, v)
+ self.assertTrue(self.is_view_of_same_base(t, v))
+
+ # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups:
+ t = torch.ones(720, device=device) \
+ .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0))
+ # [--1--|---2---|-3-] [--1--|----2---|-3-]
+ v1 = t.flatten(0, 1)
+ v2 = v1.flatten(1, 3)
+ v3 = v2.flatten(2, 2)
+ test_writes_propagate(t, v1)
+ self.assertTrue(self.is_view_of_same_base(t, v1))
+ test_writes_propagate(t, v2)
+ self.assertTrue(self.is_view_of_same_base(t, v2))
+ test_writes_propagate(t, v3)
+ self.assertTrue(self.is_view_of_same_base(t, v3))
+
+ def test_flatten_nonview(self, device="mps"):
+ def assert_is_nonview(t, nv):
+ idx_t = (0,) * t.ndim
+ idx_nv = (0,) * nv.ndim
+ self.assertTrue(not nv._is_view())
+ nv[idx_nv] = 0
+ self.assertNotEqual(t[idx_t], nv[idx_nv])
+ t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3)
+ nv = t.flatten(1, 3)
+ assert_is_nonview(t, nv)
+
+ t = torch.ones(2, 2, device=device).T
+ nv = t.flatten()
+ assert_is_nonview(t, nv)
+
+ # flatten returns the original object if start_dim=end_dim
+        t = torch.ones(2, 2, device=device)
+ nv = t.flatten(1, 1)
+ self.assertTrue(t is nv)
+
+ def test_basic_indexing_slice_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ v = t[:2, :3]
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0, 0] = 0
+ self.assertEqual(t[0, 0], v[0, 0])
+
+ def test_basic_indexing_ellipses_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ v = t[..., :2]
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0, 0] = 0
+ self.assertEqual(t[0, 0], v[0, 0])
+
+ def test_basic_indexing_newaxis_view(self, device="mps"):
+ t = torch.ones(5, 5, device=device)
+ v = t[None, :2, 3]
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0, 0] = 0
+ self.assertEqual(t[0, 3], v[0, 0])
+
+ def test_chunk_view(self, device="mps"):
+ t = torch.zeros(3, 3, device=device)
+ l = torch.chunk(t, 3)
+
+ for idx, v in enumerate(l):
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0, 0] = idx + 1
+ self.assertEqual(t[idx, 0], v[0, 0])
+
+ def test_split_view(self, device="mps"):
+ t = torch.zeros(3, 3, device=device)
+ l = torch.split(t, [1, 1, 1])
+
+ for idx, v in enumerate(l):
+ self.assertTrue(self.is_view_of(t, v))
+
+ v[0, 0] = idx + 1
+ self.assertEqual(t[idx, 0], v[0, 0])
+
+ def test_movedim_view(self, device="mps"):
+ def run_test(device, op):
+ t = torch.zeros(3, 3, device=device)
+ out = op(t)
+
+ self.assertTrue(self.is_view_of(t, out))
+
+ # Randomly change values in output
+ # and verify that original is changed
+ # as well.
+ for _ in range(3):
+ idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2)
+ out[idx_1, idx_2] = random.random()
+ self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2])
+
+ for fn in [torch.movedim, torch.moveaxis]:
+ op = partial(fn, source=(0, 1), destination=(1, 0))
+ run_test(device, op)
+
+ op = partial(fn, source=0, destination=1)
+ run_test(device, op)
+
+ # Testing that the generated view_copy kernel and its derivative are implemented correctly
+ def test_view_copy(self, device="mps"):
+ a = torch.randn(4, device=device, requires_grad=True)
+ a_ref = a.clone().detach().requires_grad_()
+ a_view = a_ref.view(2, 2)
+ a_view_copy = torch.view_copy(a, (2, 2))
+
+ # view_copy ops don't preserve view relationship
+ self.assertTrue(self.is_view_of(a_ref, a_view))
+ self.assertFalse(self.is_view_of(a, a_view_copy))
+
+ a_view_copy.sum().backward()
+ a_view.sum().backward()
+
+ # forward and backward give the same shape + result
+ self.assertEqual(a_view_copy, a_view)
+ self.assertEqual(a.grad, a_ref.grad)
+
+ def test_view_copy_out(self, device="mps"):
+ a = torch.randn(2, 2, device=device)
+ out = torch.empty(2, device=device)
+
+ torch.diagonal_copy(a, out=out)
+ expected = torch.diagonal_copy(a)
+
+ self.assertEqual(expected, out)
+
+ a = torch.randn(4, device=device)
+ out1 = torch.empty(2, device=device)
+ out2 = torch.empty(2, device=device)
+
+ torch.split_copy(a, 2, out=(out1, out2))
+ expected1, expected2 = torch.split_copy(a, 2)
+
+ self.assertEqual(expected1, out1)
+ self.assertEqual(expected2, out2)
+
+ def test_empty_reshape(self, device="mps"):
+ x = torch.randn(0, 6, device=device)
+ self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape)
+ # should be viewable -- i.e. data_ptr is the same.
+ self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr())
+
+ # match NumPy semantics -- don't infer the size of dimension with a degree of freedom
+ self.assertRaises(RuntimeError, lambda: x.reshape(0, -1))
+
+ def test_expand(self, device="mps"):
+ tensor = torch.rand(1, 8, 1, device=device)
+ tensor2 = torch.rand(5, device=device)
+ template = torch.rand(4, 8, 5, device=device)
+ target = template.size()
+ self.assertEqual(tensor.expand_as(template).size(), target)
+ self.assertEqual(tensor.expand(4, 8, 5).size(), target)
+ self.assertEqual(tensor.expand(target).size(), target)
+ self.assertEqual(tensor2.expand_as(template).size(), target)
+ self.assertEqual(tensor2.expand(4, 8, 5).size(), target)
+ self.assertEqual(tensor2.expand(target).size(), target)
+
+ # test double expand
+ self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1))
+
+ # test non-contiguous
+ noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0]
+ self.assertFalse(noncontig.is_contiguous())
+ self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1))
+
+ # make sure it's compatible with unsqueeze
+ expanded = tensor2.expand(1, 1, 5)
+ unsqueezed = tensor2.unsqueeze(0).unsqueeze(1)
+ self.assertEqual(expanded, unsqueezed)
+ self.assertEqual(expanded.stride(), unsqueezed.stride())
+
+ # test -1 as target size
+ self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5))
+ self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1))
+
+ # test expanding empty to empty
+ self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device))
+
+ def test_view_empty(self, device="mps"):
+ x = torch.randn(0, 6, device=device)
+ self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape)
+
+ def test_reshape(self, device="mps"):
+ x = torch.randn(3, 3, device=device)
+ self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr())
+ self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr())
+ self.assertEqual(torch.reshape(x, (9,)), x.reshape(9))
+ self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1))
+
+ y = torch.randn(4, 4, 4, device=device)[:, 0, :]
+ # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape
+ if device != "meta":
+ self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr())
+ self.assertEqual(y.contiguous().view(-1), y.reshape(-1))
+ self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr())
+
+ s = torch.randn((), device=device)
+ self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr())
+ self.assertEqual(s.reshape(-1).shape, (1,))
+ self.assertRaises(RuntimeError, lambda: s.reshape(2))
+
+ empty = torch.tensor([], device=device)
+ self.assertEqual(empty, empty.reshape(-1))
+ self.assertEqual(empty, empty.reshape([0]))
+ # TODO: fix these once we have multi-dimensional empty tensors
+ self.assertEqual(empty.reshape([0, 1]).shape, (0, 1))
+ self.assertEqual(empty.reshape([1, -1]).shape, (1, 0))
+ self.assertRaises(RuntimeError, lambda: empty.reshape(1))
+
+ x = torch.randn(3, 3, device=device)
+ self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr())
+ self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr())
+ self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device)))
+
+ def test_narrow(self, device="mps"):
+ x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
+ self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]]))
+ self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]]))
+ self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]]))
+ self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]]))
+ self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]]))
+ self.assertEqual(x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
+ self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]]))
+ self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]]))
+
+ def test_narrow_tensor(self, device="mps"):
+ x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
+ self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]]))
+ with self.assertRaises(Exception):
+ x.narrow(0, torch.tensor(0.), 1)
+ with self.assertRaises(Exception):
+ x.narrow(0, torch.tensor([0]), 1)
+ with self.assertRaises(Exception):
+ x.narrow(0, torch.tensor([0, 1]), 1)
+
+ def test_t(self, device="mps"):
+ # Test 0D tensors
+ x = torch.randn(())
+ self.assertEqual(x, x.t())
+ x = x.to_sparse()
+ self.assertEqual(x, x.t())
+
+ # Test 1D tensors
+ x = torch.arange(4)
+ self.assertEqual(x, x.t())
+ x = x.to_sparse()
+ self.assertEqual(x, x.t())
+
+ # Test 2D tensors
+ x = torch.rand((2, 2))
+ self.assertEqual(x.t(), x.transpose(0, 1))
+ x = x.to_sparse()
+ self.assertEqual(x.t(), x.transpose(0, 1))
+
+ # Test 3D tensor
+ x = torch.rand((2, 2, 2))
+ with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'):
+ x.t()
+ x = x.to_sparse()
+ with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'):
+ x.t()
+
+ def test_split(self, device="mps"):
+ tensor = torch.rand(7, 4)
+ split_size = 3
+ dim = 0
+ target_sizes = ([3, 4], [3, 4], [1, 4])
+ splits = tensor.split(split_size, dim)
+ start = 0
+ for target_size, split in zip(target_sizes, splits):
+ self.assertEqual(split.size(), target_size)
+ self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
+ start = start + target_size[dim]
+
+ # Variable sections split
+ tensor = torch.randn(20, 10)
+ dim = 0
+ split_sizes = [5, 5, 10]
+ target_sizes = ([[5, 10], [5, 10], [10, 10]])
+ splits = tensor.split(split_sizes, dim)
+ start = 0
+ for target_size, split in zip(target_sizes, splits):
+ self.assertEqual(split.size(), target_size)
+ self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
+ start = start + target_size[dim]
+
+ split_sizes = [2, 2, 6]
+ target_sizes = ([20, 2], [20, 2], [20, 6])
+ dim = 1
+ splits = tensor.split(split_sizes, dim)
+ start = 0
+ for target_size, split in zip(target_sizes, splits):
+ self.assertEqual(split.size(), target_size)
+ self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0)
+ start = start + target_size[dim]
+
+ def test_chunk(self, device="mps"):
+ tensor = torch.rand(4, 7)
+ num_chunks = 3
+ dim = 1
+ target_sizes = ([4, 3], [4, 3], [4, 1])
+ splits = tensor.chunk(num_chunks, dim)
+ start = 0
+ for target_size, split in zip(target_sizes, splits):
+ self.assertEqual(split.size(), target_size)
+ self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split,
+ atol=0, rtol=0)
+ start = start + target_size[dim]
+
+ # Invalid chunk sizes
+ error_regex = 'chunk expects.*greater than 0'
+ with self.assertRaisesRegex(RuntimeError, error_regex):
+ tensor.chunk(0)
+ with self.assertRaisesRegex(RuntimeError, error_regex):
+ tensor.chunk(-2)
+
+ def test_unsqueeze(self, device="mps") -> None:
+ x = torch.randn(2, 3, 4)
+ y = x.unsqueeze(1)
+ self.assertEqual(y, x.view(2, 1, 3, 4))
+ y = x.clone().unsqueeze_(2)
+ self.assertEqual(y, x.view(2, 3, 1, 4))
+
+ x = x[:, 1]
+ self.assertFalse(x.is_contiguous())
+ y = x.unsqueeze(1)
+ self.assertEqual(y, x.contiguous().view(2, 1, 4))
+ y = x.clone().unsqueeze_(2)
+ self.assertEqual(y, x.contiguous().view(2, 4, 1))
+
+ # unit test for special case transposed copy (see ATen/native/Copy.cpp for details)
+ def test_big_transpose(self, device="mps"):
+ t = torch.rand(456, 789, device=device)
+ t1 = t.t().contiguous()
+ t2 = torch.from_numpy(t.cpu().numpy().transpose())
+ self.assertEqual(t1, t2)
+
+ def test_T(self, device="mps"):
+ a = torch.randn(2, 3, 4, device=device)
+ t1 = a.T
+ t2 = a.permute(2, 1, 0)
+ self.assertEqual(t2, t1)
+ b = torch.randn(10, device=device)
+ self.assertEqual(b, b.T)
+ scalar = torch.tensor(5, device=device)
+ self.assertEqual(scalar, scalar.T)
+
+ def test_transposes(self, device="mps", dtype=torch.float32):
+ for op in ("T", "H", "mT", "mH", "adjoint"):
+ shapes = ((), (2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((), (2, 3),)
+ for shape in shapes:
+ a = make_tensor(shape, device=device, dtype=dtype)
+ t1 = getattr(a, op)
+ if op == "adjoint":
+ t1 = t1()
+ t2 = a
+ if a.ndim != 0:
+ t2 = t2.transpose(-2, -1)
+ if op[-1] == "H" or op == "adjoint":
+ t2 = t2.conj()
+ self.assertEqual(t2, t1)
+
+ def test_transposes_errors(self, device="mps", dtype=torch.float32):
+ for op in ("H", "mT", "mH", "adjoint"):
+ shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),)
+ for shape in shapes:
+ a = make_tensor(shape, device=device, dtype=dtype)
+ with self.assertRaisesRegex(RuntimeError, "only supported on matrices"):
+ t1 = getattr(a, op)
+ if op == "adjoint":
+ t1 = t1()
+
+ def test_python_types(self, device="mps"):
+ a1 = torch.randn((1, 2), device=device, dtype=torch.float32)
+ a2 = torch.randn((1, 2), device=device, dtype=torch.float32)
+ self.assertEqual(a1.dtype, a2.dtype)
+
+ b1 = torch.arange(10, 20, dtype=torch.int64, device=device)
+ b2 = torch.arange(10, 20, dtype=int, device=device)
+ self.assertEqual(b1.dtype, b2.dtype)
+
+ c1 = torch.tensor([True, False], dtype=torch.bool, device=device)
+ c2 = torch.tensor([True, False], dtype=bool, device=device)
+ self.assertEqual(c1.dtype, c2.dtype)
+
+    # TODO: does resize belong in test_view_ops?
+ def test_resize_as_preserves_strides(self, device="mps"):
+ x = torch.empty(2, 3).t()
+ old_strides = x.stride()
+ x.resize_as_(x)
+ self.assertEqual(x.stride(), old_strides)
+
+ def test_memory_format_resize_as(self, device="mps"):
+ def test_helper(shape, memory_format, device="mps"):
+ xc = torch.randn(shape, device=device).contiguous(memory_format=memory_format)
+ flat = torch.randn(xc.numel(), device=device)
+ flat.resize_as_(xc, memory_format=torch.preserve_format)
+ self.assertTrue(flat.is_contiguous(memory_format=memory_format))
+
+ test_helper((10, 3, 32, 32), torch.channels_last, device="mps")
+ test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device="mps")
+
+ def test_memory_format_resize_(self, device="mps"):
+ def test_helper(shape, numel, memory_format, device="mps"):
+ flat = torch.randn(numel, device=device)
+ flat.resize_(shape, memory_format=memory_format)
+ self.assertTrue(flat.is_contiguous(memory_format=memory_format))
+
+ test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device="mps")
+ test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device="mps")
+
+ # TODO: OpInfo this
+ def _test_atleast(self, device, torch_fn):
+ # 0-dim
+ s = torch.tensor(0.5, dtype=torch.double, requires_grad=True)
+
+ gradcheck(lambda x: torch_fn(x), s)
+ gradgradcheck(lambda x: torch_fn(x), s)
+
+ # 1-dim
+ a = torch.rand(4, dtype=torch.double, requires_grad=True)
+
+ gradcheck(lambda x: torch_fn(x), a)
+ gradgradcheck(lambda x: torch_fn(x), a)
+
+ # 2,3,4-dim
+ b = torch.rand(4, 3, dtype=torch.double, requires_grad=True)
+ c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True)
+ d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True)
+
+ input_tuple = (s, a, b, c, d)
+ gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)
+ gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple)
+
+ def test_atleast_gradient(self, device="mps"):
+ self._test_atleast(device, torch.atleast_1d)
+ self._test_atleast(device, torch.atleast_2d)
+ self._test_atleast(device, torch.atleast_3d)
+
+ def test_view(self, device="mps"):
+ tensor = torch.rand(15, device=device)
+ template = torch.rand(3, 5, device=device)
+ empty = torch.empty(0, device=device)
+ target = template.size()
+ self.assertEqual(tensor.view_as(template).size(), target)
+ self.assertEqual(tensor.view(3, 5).size(), target)
+ self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target)
+ self.assertEqual(tensor.view(-1, 5).size(), target)
+ self.assertEqual(tensor.view(3, -1).size(), target)
+ tensor_view = tensor.view(5, 3)
+ tensor_view.fill_(random.uniform(0, 1))
+ self.assertEqual(empty.view_as(empty), empty)
+ self.assertEqual(empty.view(0), empty)
+ self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1]))
+ self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty)
+
+ # test size inference with empty tensors
+ self.assertEqual(empty.view(-1).size(), torch.Size([0]))
+ self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0]))
+
+ with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
+ empty.view(-1, 0)
+
+ with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"):
+ empty.view(3, 0, -1, 0)
+
+ self.assertRaises(RuntimeError, lambda: tensor.view(15, 0))
+ self.assertRaises(RuntimeError, lambda: tensor.view(7, -1))
+ self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1))
+
+ # RuntimeError: Invalid device for storage: mps
+ # def test_contiguous(self, device="mps"):
+ # x = torch.randn(1, 16, 5, 5, device=device)
+ # self.assertTrue(x.is_contiguous())
+ # stride = list(x.stride())
+ # stride[0] = 20
+ # # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1
+ # x.set_(x.storage(), 0, x.size(), stride)
+ # self.assertTrue(x.is_contiguous())
+
+ def test_resize_all_dtypes_and_devices(self, device="mps"):
+ shape = (2, 2)
+ for dt in (torch.half, torch.bfloat16, torch.bool):
+ x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
+ x.resize_(shape)
+ self.assertEqual(shape, x.shape)
+
+ def test_resize_as_all_dtypes_and_devices(self, device="mps"):
+ for dt in (torch.half, torch.bfloat16, torch.bool):
+ x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
+ y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device)
+ x.resize_as_(y)
+ self.assertEqual(y.shape, x.shape)
+
+ def test_resize_overflow(self, device="mps"):
+ x = torch.empty((), dtype=torch.float64)
+ with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'):
+ x.resize_([2, 4, 2**29, 2**29])
+ with self.assertRaisesRegex(RuntimeError, 'overflow'):
+ x.resize_([8, 8, 2**29, 2**29])
+
+ def test_view_all_dtypes_and_devices(self, device="mps"):
+ for dt in (torch.float, torch.bool):
+ x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device)
+ self.assertEqual(x.view(6).shape, [6])
class TestRNNMPS(TestCase):
def test_lstm_1(self, device="mps", dtype=torch.float32):