[MPS] Register unfold key for MPS (#91266)
Register the unfold key for MPS (it reuses the generic implementation that already exists).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91266
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index ad2bab4..5f85f2d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2039,7 +2039,29 @@
strided_mps_out = strided_mps1 - strided_mps2
self.assertEqual(strided_cpu_out, strided_mps_out)
+ def test_unfold(self):
+ x = torch.arange(1., 8)
+ x_mps = torch.arange(1., 8, device="mps")
+ y = x.unfold(0, 2, 1)
+ y_mps = x_mps.unfold(0, 2, 1)
+ # size-2, step-1 windows over the 7 values 1..7 give 6 rows; MPS must equal the CPU reference
+ self.assertEqual(y, y_mps)
+
+ def test_unfold_all_devices_and_dtypes(self):
+ supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
+ for dt in supported_dtypes:
+ x = torch.empty((0, 1, 3, 0), dtype=dt, device="mps")
+ self.assertEqual((0, 1, 1, 0, 3), x.unfold(2, 3, 2).shape)
+
+ def test_unfold_scalars(self):
+ x = torch.tensor(0.5, device="mps")
+ # unfold on a 0-dimensional tensor should always return a 1-dimensional
+ # tensor of shape [size] (i.e., the second parameter to unfold), whatever the
+ # step (third parameter) is -- the size-0 cases below use steps 1 and 2.
+ self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 1))
+ self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 2))
+ self.assertEqual(torch.tensor([0.5], device="mps"), x.unfold(0, 1, 1))
def test_sum_backward(self):
def helper(n, c):
@@ -5726,14 +5748,13 @@
v[0, 1] = 0
self.assertEqual(t[1, 0], v[0, 1])
- # requires aten::unfold
- # def test_unfold_view(self, device="mps"):
- # t = torch.ones(10, device=device)
- # v = t.unfold(0, 3, 2)
- # self.assertTrue(self.is_view_of(t, v))
+ def test_unfold_view(self, device="mps"):
+ t = torch.ones(10, device=device)
+ v = t.unfold(0, 3, 2)
+ self.assertTrue(self.is_view_of(t, v))
- # v[1, 0] = 0
- # self.assertEqual(t[2], v[1, 0])
+ v[1, 0] = 0
+ self.assertEqual(t[2], v[1, 0])
def test_squeeze_view(self, device="mps"):
t = torch.ones(5, 1, 5, device=device)