[MPS] Enable conditional indexing tests (#97871)

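The conditional (boolean-mask) indexing tests in test_mps.py were previously commented out with a FIXME. They now appear to work on MPS, so re-enable them. On macOS versions before 13.0, `test_boolean_array_indexing` skips uint8, because MPS only supports binary ops on uint8 natively starting from macOS 13.0.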

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97871
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 67f3c78..9f6a00a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9237,21 +9237,20 @@
         # FIXME: use supported_dtypes once uint8 is fixed
         [helper(dtype) for dtype in [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16]]
 
-    # FIXME: conditional indexing not working
-    # def test_boolean_array_indexing_1(self):
-    #     def helper(dtype):
-    #         x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
-    #         x_mps = x_cpu.detach().clone().to("mps")
+    def test_boolean_array_indexing(self):
+        def helper(dtype):  # boolean-mask indexing on MPS must match the CPU result for this dtype
+            x_cpu = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]], dtype=dtype)
+            x_mps = x_cpu.detach().clone().to("mps")
 
-    #         res_cpu = x_cpu[x_cpu > 5]
-    #         res_mps = x_mps[x_mps > 5]
+            res_cpu = x_cpu[x_cpu > 5]
+            res_mps = x_mps[x_mps > 5]
 
-    #         print(res_cpu)
-    #         print(res_mps)
-
-    #         self.assertEqual(res_cpu, res_mps, str(dtype))
-    #     [helper(dtype) for dtype in self.supported_dtypes]
-
+            self.assertEqual(res_cpu, res_mps, str(dtype))
+        for dtype in self.supported_dtypes:
+            # MPS supports binary ops on uint8 natively only from macOS 13.0, so skip uint8 on older versions
+            if product_version < 13.0 and dtype == torch.uint8:
+                continue
+            helper(dtype)
 
     def test_advanced_indexing_3D_get(self):
         def helper(x_cpu):
@@ -9566,24 +9565,23 @@
         r = v[c > 0]
         self.assertEqual(r.shape, (num_ones, 3))
 
-    # FIXME: conditional indexing not working
-    # def test_jit_indexing(self, device="mps"):
-    #     def fn1(x):
-    #         x[x < 50] = 1.0
-    #         return x
+    def test_jit_indexing(self, device="mps"):
+        def fn1(x):
+            x[x < 50] = 1.0  # boolean-mask (conditional) assignment, the path this test re-enables
+            return x
 
-    #     def fn2(x):
-    #         x[0:50] = 1.0
-    #         return x
+        def fn2(x):
+            x[0:50] = 1.0  # slice assignment; on arange(100) input it must yield the same ref as fn1
+            return x
 
-    #     scripted_fn1 = torch.jit.script(fn1)
-    #     scripted_fn2 = torch.jit.script(fn2)
-    #     data = torch.arange(100, device=device, dtype=torch.float)
-    #     out = scripted_fn1(data.detach().clone())
-    #     ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
-    #     self.assertEqual(out, ref)
-    #     out = scripted_fn2(data.detach().clone())
-    #     self.assertEqual(out, ref)
+        scripted_fn1 = torch.jit.script(fn1)
+        scripted_fn2 = torch.jit.script(fn2)
+        data = torch.arange(100, device=device, dtype=torch.float)
+        out = scripted_fn1(data.detach().clone())
+        ref = torch.tensor(np.concatenate((np.ones(50), np.arange(50, 100))), device=device, dtype=torch.float)
+        self.assertEqual(out, ref)
+        out = scripted_fn2(data.detach().clone())
+        self.assertEqual(out, ref)
 
     def test_int_indices(self, device="mps"):
         v = torch.randn(5, 7, 3, device=device)