[MPS] Enable conditional indexing tests (#97871)
The previously disabled conditional (boolean-mask) indexing tests appear to pass on MPS now, so this change re-enables them.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97871
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 67f3c78..9f6a00a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9237,21 +9237,20 @@
# FIXME: use supported_dtypes once uint8 is fixed
[helper(dtype) for dtype in [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16]]
- # FIXME: conditional indexing not working
- # def test_boolean_array_indexing_1(self):
- # def helper(dtype):
- # x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
- # x_mps = x_cpu.detach().clone().to("mps")
+    def test_boolean_array_indexing(self):
+        # Boolean-mask (conditional) indexing, x[x > 5], must give the same
+        # result on MPS as on CPU for every supported dtype.
+        def helper(dtype):
+            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
+            x_mps = x_cpu.detach().clone().to("mps")
- # res_cpu = x_cpu[x_cpu > 5]
- # res_mps = x_mps[x_mps > 5]
+            res_cpu = x_cpu[x_cpu > 5]
+            res_mps = x_mps[x_mps > 5]
- # print(res_cpu)
- # print(res_mps)
-
- # self.assertEqual(res_cpu, res_mps, str(dtype))
- # [helper(dtype) for dtype in self.supported_dtypes]
-
+            # The dtype is passed as the message so a failure names the dtype.
+            self.assertEqual(res_cpu, res_mps, str(dtype))
+
+        for dtype in self.supported_dtypes:
+            # MPS supports binary ops (such as the `> 5` comparison above) on
+            # uint8 natively only from macOS 13.0 onward, so skip uint8 on
+            # older macOS versions.
+            if product_version < 13.0 and dtype == torch.uint8:
+                continue
+            helper(dtype)
def test_advanced_indexing_3D_get(self):
def helper(x_cpu):
@@ -9566,24 +9565,23 @@
r = v[c > 0]
self.assertEqual(r.shape, (num_ones, 3))
- # FIXME: conditional indexing not working
- # def test_jit_indexing(self, device="mps"):
- # def fn1(x):
- # x[x < 50] = 1.0
- # return x
+    def test_jit_indexing(self, device="mps"):
+        # TorchScript-compiled in-place index assignment on MPS: fn1 uses a
+        # boolean mask, fn2 a slice. On arange(100) both overwrite exactly the
+        # first 50 elements with 1.0, so they share one reference tensor.
+        def fn1(x):
+            x[x < 50] = 1.0
+            return x
- # def fn2(x):
- # x[0:50] = 1.0
- # return x
+
+        def fn2(x):
+            x[0:50] = 1.0
+            return x
- # scripted_fn1 = torch.jit.script(fn1)
- # scripted_fn2 = torch.jit.script(fn2)
- # data = torch.arange(100, device=device, dtype=torch.float)
- # out = scripted_fn1(data.detach().clone())
- # ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
- # self.assertEqual(out, ref)
- # out = scripted_fn2(data.detach().clone())
- # self.assertEqual(out, ref)
+
+        scripted_fn1 = torch.jit.script(fn1)
+        scripted_fn2 = torch.jit.script(fn2)
+        data = torch.arange(100, device=device, dtype=torch.float)
+        # Both functions mutate their argument in place, so each call gets a
+        # fresh clone of `data`.
+        out = scripted_fn1(data.detach().clone())
+        ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
+        self.assertEqual(out, ref)
+        out = scripted_fn2(data.detach().clone())
+        self.assertEqual(out, ref)
def test_int_indices(self, device="mps"):
v = torch.randn(5, 7, 3, device=device)