Min and max NaN propagation fix in MPS backend (#130445)
Partial fix to issue #130295
Moves min and max ops to use the NaN-propagating API in MPS to align with the PyTorch convention. Adds a regression test to validate the fix achieves parity with the CPU backend.
Co-authored-by: Nikita Shulga <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130445
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 5540837..74bf819 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8299,6 +8299,29 @@
[helper(dtype) for dtype in [torch.float32, torch.float16, torch.int32, torch.int16, torch.uint8, torch.int8, torch.bool]]
+ def test_min_max_nan_propagation(self):
+ def helper(dtype):
+ cpu_x = torch.tensor([1.0, float("nan"), 3.0], device="cpu")
+ mps_x = cpu_x.detach().clone().to('mps')
+
+ cpu_max = torch.max(cpu_x)
+ mps_max = torch.max(mps_x).to('cpu')
+
+ cpu_amax = torch.amax(cpu_x)
+ mps_amax = torch.amax(mps_x).to('cpu')
+
+ cpu_min = torch.min(cpu_x)
+ mps_min = torch.min(mps_x).to('cpu')
+
+ cpu_amin = torch.amin(cpu_x)
+ mps_amin = torch.amin(mps_x).to('cpu')
+
+ self.assertEqual(cpu_max, mps_max)
+ self.assertEqual(cpu_amax, mps_amax)
+ self.assertEqual(cpu_min, mps_min)
+ self.assertEqual(cpu_amin, mps_amin)
+ [helper(dtype) for dtype in [torch.float32, torch.float16, torch.bfloat16]]
+
def test_isin(self):
def helper(dtype):
shapes = [([2, 5], [3, 5, 2]), ([10, 3, 5], [20, 1, 3]),