[MPS] Add support for torch.linalg.cross (#91642)
* Add support for torch.linalg.cross
* Use `metal::cross` for float and half; for the other dtypes, compute the cross product manually
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91642
Approved by: https://github.com/razarmehr, https://github.com/malfet
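
For reference, this is the component-wise arithmetic the manual path has to reproduce. A minimal Python sketch against the public API; `manual_cross` is a hypothetical helper for illustration only, not code from this PR and not the Metal kernel itself:

```python
import torch

def manual_cross(a, b, dim=-1):
    # Hypothetical reference implementation, for illustration only.
    # Cross product along a size-3 dimension:
    #   c0 = a1*b2 - a2*b1, c1 = a2*b0 - a0*b2, c2 = a0*b1 - a1*b0
    a0, a1, a2 = a.unbind(dim)
    b0, b1, b2 = b.unbind(dim)
    return torch.stack((a1 * b2 - a2 * b1,
                        a2 * b0 - a0 * b2,
                        a0 * b1 - a1 * b0), dim=dim)

x = torch.randint(0, 99999, (4, 3), dtype=torch.int64)
y = torch.randint(0, 99999, (4, 3), dtype=torch.int64)
assert torch.equal(manual_cross(x, y, dim=1), torch.linalg.cross(x, y, dim=1))
```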
diff --git a/test/test_mps.py b/test/test_mps.py
index c070c75..b226b9a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -441,6 +441,56 @@
         D = torch.mm(B, C).cpu()
         torch.testing.assert_close(D, torch.full((5, 5), 6.0))
 
+    def test_linalg_cross(self):
+        def helper(dtype):
+            device = "mps"
+            if dtype is torch.int32 or dtype is torch.int64:
+                x = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
+                y = torch.randint(0, 99999, (100, 3, 100), dtype=dtype, device=device)
+            else:
+                x = torch.rand(100, 3, 100, dtype=dtype, device=device)
+                y = torch.rand(100, 3, 100, dtype=dtype, device=device)
+            x_cpu = x.to("cpu")
+            y_cpu = y.to("cpu")
+            res1 = torch.linalg.cross(x, y, dim=1)
+            res2 = torch.tensor((), dtype=dtype, device=device)
+            res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
+            res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
+            torch.linalg.cross(x, y, dim=1, out=res2)
+            torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
+            self.assertEqual(res1, res2)
+            self.assertEqual(res1, res1_cpu)
+            self.assertEqual(res2, res2_cpu)
+
+            # test for broadcastable inputs
+            if dtype is torch.int32 or dtype is torch.int64:
+                x = torch.randint(0, 99999, (1, 3, 2), dtype=dtype, device=device)
+                y = torch.randint(0, 99999, (4, 3, 1), dtype=dtype, device=device)
+            else:
+                x = torch.rand(1, 3, 2, dtype=dtype, device=device)
+                y = torch.rand(4, 3, 1, dtype=dtype, device=device)
+            x_cpu = x.to("cpu")
+            y_cpu = y.to("cpu")
+            res1 = torch.linalg.cross(x, y, dim=1)
+            res2 = torch.tensor((), dtype=dtype, device=device)
+            res1_cpu = torch.linalg.cross(x_cpu, y_cpu, dim=1)
+            res2_cpu = torch.tensor((), dtype=dtype, device="cpu")
+            torch.linalg.cross(x, y, dim=1, out=res2)
+            torch.linalg.cross(x_cpu, y_cpu, dim=1, out=res2_cpu)
+            self.assertEqual(res1, res2)
+            self.assertEqual(res1, res1_cpu)
+            self.assertEqual(res2, res2_cpu)
+        [helper(dtype) for dtype in [torch.int32, torch.int64, torch.float32]]
+
+    def test_cross(self):
+        a = torch.randn(4, 3, device="mps")
+        b = torch.randn(4, 3, device="mps")
+        a_cpu = a.to("cpu")
+        b_cpu = b.to("cpu")
+        res = torch.cross(a, b, dim=1)
+        res_cpu = torch.cross(a_cpu, b_cpu, dim=1)
+        self.assertEqual(res, res_cpu)
+
     def test_addmm(self):
         A = torch.ones(5, 5).to("mps")
         B = torch.ones(5, 6).to("mps")
@@ -8314,6 +8364,8 @@
         'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'nonzero': ['f32', 'i16', 'i32', 'i64'],
         'unique_consecutive': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'linalg.cross': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
     }
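
For context, the broadcasting case the second half of `test_linalg_cross` exercises looks like this in user code (a usage sketch, assuming a machine with an MPS device available):

```python
import torch

x = torch.rand(1, 3, 2, device="mps")
y = torch.rand(4, 3, 1, device="mps")
# Inputs broadcast against each other on every dimension except `dim`,
# which must have size 3 on both: the result has shape (4, 3, 2).
out = torch.linalg.cross(x, y, dim=1)
print(out.shape)  # torch.Size([4, 3, 2])
```

Note that `torch.cross` without an explicit `dim` picks the first size-3 dimension (a deprecated default), while `torch.linalg.cross` defaults to `dim=-1`; both tests above pass `dim=1` explicitly.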