[MPS] Register unfold key for MPS (#91266)

Register the unfold key for MPS (it reuses the existing generic implementation).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91266
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index ad2bab4..5f85f2d 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2039,7 +2039,29 @@
         strided_mps_out = strided_mps1 - strided_mps2
         self.assertEqual(strided_cpu_out, strided_mps_out)
 
+    def test_unfold(self):
+        cpu_input = torch.arange(1., 8)
+        mps_input = torch.arange(1., 8, device="mps")
 
+        cpu_windows = cpu_input.unfold(0, 2, 1)
+        mps_windows = mps_input.unfold(0, 2, 1)
+        # Windows of size 2, step 1 along dim 0 must match between CPU and MPS.
+        self.assertEqual(cpu_windows, mps_windows)
+
+    def test_unfold_all_devices_and_dtypes(self):
+        # Empty input: unfolding dim 2 (length 3) with size=3, step=2 yields exactly one window.
+        for dtype in (torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8):
+            empty = torch.empty((0, 1, 3, 0), dtype=dtype, device="mps")
+            self.assertEqual((0, 1, 1, 0, 3), empty.unfold(2, 3, 2).shape)
+
+    def test_unfold_scalars(self):
+        x = torch.tensor(0.5, device="mps")
+        # unfold on a 0-dimensional tensor should always return a 1-dimensional
+        # tensor of shape [size] (i.e., the second parameter to unfold)
+
+        self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 1))
+        self.assertEqual(torch.empty(0, device="mps"), x.unfold(0, 0, 2))
+        self.assertEqual(torch.tensor([0.5], device="mps"), x.unfold(0, 1, 1))
 
     def test_sum_backward(self):
         def helper(n, c):
@@ -5726,14 +5748,13 @@
             v[0, 1] = 0
             self.assertEqual(t[1, 0], v[0, 1])
 
-    # requires aten::unfold
-    # def test_unfold_view(self, device="mps"):
-    #     t = torch.ones(10, device=device)
-    #     v = t.unfold(0, 3, 2)
-    #     self.assertTrue(self.is_view_of(t, v))
+    def test_unfold_view(self, device="mps"):
+        base = torch.ones(10, device=device)
+        windows = base.unfold(0, 3, 2)  # 4 windows of size 3, step 2
+        self.assertTrue(self.is_view_of(base, windows))
 
-    #     v[1, 0] = 0
-    #     self.assertEqual(t[2], v[1, 0])
+        windows[1, 0] = 0  # window 1 starts at base[2]; a view must write through
+        self.assertEqual(base[2], windows[1, 0])
 
     def test_squeeze_view(self, device="mps"):
         t = torch.ones(5, 1, 5, device=device)