[MPS] Add support for torch.linalg.cross (#91642)

* Add support for `torch.linalg.cross` (and, through it, `torch.cross`) on the MPS backend
* Use `metal::cross` for the float and half dtypes; for the remaining dtypes, compute the cross product manually (see the sketch below)
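
For dtypes that `metal::cross` does not cover, the kernel evaluates the cross-product component formula directly. A minimal Python sketch of that formula for illustration only; the actual implementation is a Metal shader, and `manual_cross` is a hypothetical name:

```python
import torch

def manual_cross(a: torch.Tensor, b: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Sketch of the component formula the manual path evaluates for dtypes
    # not handled by metal::cross (e.g. the integer types). Both inputs must
    # have size 3 along `dim`, as torch.linalg.cross requires.
    a0, a1, a2 = a.unbind(dim)
    b0, b1, b2 = b.unbind(dim)
    return torch.stack((
        a1 * b2 - a2 * b1,  # x-component
        a2 * b0 - a0 * b2,  # y-component
        a0 * b1 - a1 * b0,  # z-component
    ), dim=dim)
```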

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91642
Approved by: https://github.com/razarmehr, https://github.com/malfet
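
With this change, `torch.linalg.cross` dispatches to the MPS backend like any other supported op. A usage sketch, assuming a PyTorch build where MPS is available:

```python
import torch

if torch.backends.mps.is_available():
    a = torch.randn(4, 3, device="mps")
    b = torch.randn(4, 3, device="mps")
    out = torch.linalg.cross(a, b, dim=-1)  # now runs on the MPS backend
    # Results should match the CPU implementation:
    torch.testing.assert_close(out.cpu(), torch.linalg.cross(a.cpu(), b.cpu(), dim=-1))
```
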
diff --git a/test/test_mps.py b/test/test_mps.py
index c070c75..b226b9a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -441,6 +441,56 @@
         D = torch.mm(B, C).cpu()
         torch.testing.assert_close(D, torch.full((5, 5), 6.0))
 
+    def test_linalg_cross(self):
+        def helper(dtype):
+            device = "mps"
+            if dtype is torch.int32 or dtype is torch.int64:
+                x = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
+                y = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
+            else:
+                x = torch.rand(100, 3, 100, dtype=dtype, device=device)
+                y = torch.rand(100, 3, 100, dtype=dtype, device=device)
+            x_cpu = x.to("cpu")
+            y_cpu = y.to("cpu")
+            res1 = torch.linalg.cross(x, y, dim=1)
+            res2 = torch.tensor((), dtype=dtype, device=device)
+            res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
+            res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
+            torch.linalg.cross(x, y, dim=1, out=res2)
+            torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
+            self.assertEqual(res1, res2)
+            self.assertEqual(res1, res1_cpu)
+            self.assertEqual(res2, res2_cpu)
+
+            # test for broadcastable inputs
+            if dtype is torch.int32 or dtype is torch.int64:
+                x = torch.randint(0, 99999, (1, 3, 2), dtype=dtype, device=device)
+                y = torch.randint(0, 99999, (4, 3, 1), dtype=dtype, device=device)
+            else:
+                x = torch.rand(1, 3, 2, dtype=dtype, device=device)
+                y = torch.rand(4, 3, 1, dtype=dtype, device=device)
+            x_cpu = x.to("cpu")
+            y_cpu = y.to("cpu")
+            res1 = torch.linalg.cross(x, y, dim=1)
+            res2 = torch.tensor((), dtype=dtype, device=device)
+            res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
+            res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
+            torch.linalg.cross(x, y, dim=1, out=res2)
+            torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
+            self.assertEqual(res1, res2)
+            self.assertEqual(res1, res1_cpu)
+            self.assertEqual(res2, res2_cpu)
+        [helper(dtype) for dtype in [torch.int32, torch.int64, torch.float32]]
+
+    def test_cross(self):
+        a = torch.randn(4, 3, device="mps")
+        b = torch.randn(4, 3, device="mps")
+        a_cpu = a.to("cpu")
+        b_cpu = b.to("cpu")
+        res = torch.cross(a, b, dim=1)
+        res_cpu = torch.cross(a_cpu, b_cpu, dim=1)
+        self.assertEqual(res, res_cpu)
+
     def test_addmm(self):
         A = torch.ones(5, 5).to("mps")
         B = torch.ones(5, 6).to("mps")
@@ -8314,6 +8364,8 @@
         'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'nonzero': ['f32', 'i16', 'i32', 'i64'],
         'unique_consecutive': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'linalg.cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
     }