Fix torch.clamp in MPS to handle NaN correctly (#121381) Fixes #120899 MPSGraph provides dedicated minimum/maximum variants that propagate NaN instead of clamping it to a finite value; using them makes MPS clamp match CPU semantics. https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/3857573-maximumwithnanpropagationwithpri Co-authored-by: Nikita Shulga <[email protected]> Pull Request resolved: https://github.com/pytorch/pytorch/pull/121381 Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py index d1f4b82..553ebc3 100644 --- a/test/test_mps.py +++ b/test/test_mps.py
@@ -5768,6 +5768,25 @@ self.assertEqual(clamp_result_mps, clamp_result_cpu) + def test_clamp_nan(self): + t_mps = torch.tensor([torch.nan, 1, 2], device="mps") + t_cpu = torch.tensor([torch.nan, 1, 2], device="cpu") + + clamp_min_max_mps = torch.clamp(t_mps, min=-100, max=100) + clamp_min_max_cpu = torch.clamp(t_cpu, min=-100, max=100) + + self.assertEqual(clamp_min_max_mps, clamp_min_max_cpu) + + clamp_min_mps = torch.clamp(t_mps, min=-100) + clamp_min_cpu = torch.clamp(t_cpu, min=-100) + + self.assertEqual(clamp_min_mps, clamp_min_cpu) + + clamp_max_mps = torch.clamp(t_mps, max=100) + clamp_max_cpu = torch.clamp(t_cpu, max=100) + + self.assertEqual(clamp_max_mps, clamp_max_cpu) + # Test clamp_min def test_clamp_min(self): def helper(n, c, h, w):