Fixes #128429: NaN in triu op on MPS (#128575)

Fixes triu op when k > 0 and the lower triangle of the input tensor contains inf leading to NaNs in the computation through complement. Fixed by using select API instead.

Fixes #128429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128575
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/TriangularOps.mm b/aten/src/ATen/native/mps/operations/TriangularOps.mm
index 5fa0b22..dcea978 100644
--- a/aten/src/ATen/native/mps/operations/TriangularOps.mm
+++ b/aten/src/ATen/native/mps/operations/TriangularOps.mm
@@ -35,11 +35,16 @@
 
       if (k > 0) {
         MPSGraphTensor* diagMinusOneTensor = [mpsGraph constantWithScalar:(k - 1) dataType:MPSDataTypeInt32];
-        MPSGraphTensor* complementTensor = [mpsGraph bandPartWithTensor:inputTensor
-                                                         numLowerTensor:minusOneTensor
-                                                         numUpperTensor:diagMinusOneTensor
-                                                                   name:nil];
-        outputTensor = [mpsGraph subtractionWithPrimaryTensor:inputTensor secondaryTensor:complementTensor name:nil];
+        MPSGraphTensor* onesTensor = [mpsGraph constantWithScalar:1 dataType:MPSDataTypeInt32];
+        onesTensor = [mpsGraph broadcastTensor:onesTensor toShape:inputTensor.shape name:nil];
+        MPSGraphTensor* maskTensor = [mpsGraph bandPartWithTensor:onesTensor
+                                                   numLowerTensor:minusOneTensor
+                                                   numUpperTensor:diagMinusOneTensor
+                                                             name:nil];
+        outputTensor = [mpsGraph selectWithPredicateTensor:maskTensor
+                                       truePredicateTensor:[mpsGraph constantWithScalar:0 dataType:inputTensor.dataType]
+                                      falsePredicateTensor:inputTensor
+                                                      name:nil];
       } else {
         MPSGraphTensor* minusDiagTensor = [mpsGraph constantWithScalar:(-k) dataType:MPSDataTypeInt32];
         outputTensor = [mpsGraph bandPartWithTensor:inputTensor
diff --git a/test/test_mps.py b/test/test_mps.py
index 275013f..311cf82 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1617,6 +1617,14 @@
             a = torch.tensor(v, dtype=dtype, device="mps") * b
             self.compare_with_numpy(torch.exp, np.exp, a)
 
+    def test_triu_inf(self, device="mps", dtype=torch.float):
+        for diag in [-1, 0, 1]:
+            mask = torch.full((3, 6, 6), float("-inf"))
+            mask_mps = mask.clone().detach().to('mps')
+            cpu_ref = torch.triu(mask, diagonal=diag)
+            mps_out = torch.triu(mask_mps, diagonal=diag)
+            self.assertEqual(cpu_ref, mps_out)
+
     def test_exp1(self, device="mps", dtype=torch.float):
         input = torch.tensor([-0.1, 1.0, -0.9, 0.1], device=device, dtype=dtype)
         output = torch.exp(input)