Fix torch.clamp in MPS to handle NaN correctly (#121381)

Fixes #120899

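MPSGraph provides dedicated minimum/maximum variants that propagate NaN instead of clamping it to a real number. This fix relies on them, so that `torch.clamp` on MPS returns NaN for NaN inputs, matching the CPU behavior: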
https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/3857573-maximumwithnanpropagationwithpri

Co-authored-by: Nikita Shulga <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121381
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index d1f4b82..553ebc3 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5768,6 +5768,25 @@
 
         self.assertEqual(clamp_result_mps, clamp_result_cpu)
 
+    def test_clamp_nan(self):
+        t_mps = torch.tensor([torch.nan, 1, 2], device="mps")
+        t_cpu = torch.tensor([torch.nan, 1, 2], device="cpu")
+
+        clamp_min_max_mps = torch.clamp(t_mps, min=-100, max=100)
+        clamp_min_max_cpu = torch.clamp(t_cpu, min=-100, max=100)
+
+        self.assertEqual(clamp_min_max_mps, clamp_min_max_cpu)
+
+        clamp_min_mps = torch.clamp(t_mps, min=-100)
+        clamp_min_cpu = torch.clamp(t_cpu, min=-100)
+
+        self.assertEqual(clamp_min_mps, clamp_min_cpu)
+
+        clamp_max_mps = torch.clamp(t_mps, max=100)
+        clamp_max_cpu = torch.clamp(t_cpu, max=100)
+
+        self.assertEqual(clamp_max_mps, clamp_max_cpu)
+
     # Test clamp_min
     def test_clamp_min(self):
         def helper(n, c, h, w):