[MPS] Fix Clamp with strided outputs/inputs (#97858)

Fixes #94396
Fixes #87348

1. If the output is strided (non-contiguous or a view), we don't gather the input tensors.
2. If the output is not strided but min_t or max_t is, we make min_t or max_t contiguous first. A repro sketch follows below.
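
A minimal sketch of the failure mode this addresses; shapes and clamp bounds are illustrative, but the movedim pattern matches the tests added below:

    import torch

    x_cpu = torch.rand(3, 4)
    x_mps = x_cpu.to("mps")

    # movedim returns a strided (non-contiguous) view with permuted strides.
    out_cpu = torch.clamp(x_cpu.movedim(0, -1), min=0.2, max=0.8)
    out_mps = torch.clamp(x_mps.movedim(0, -1), min=0.2, max=0.8)

    # Before this fix, the MPS result could disagree with CPU when the
    # inputs/outputs were strided; with it, the two should match.
    torch.testing.assert_close(out_mps.cpu(), out_cpu)
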
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97858
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm
index 18a9ad9..76f4dee 100644
--- a/aten/src/ATen/native/mps/operations/TensorCompare.mm
+++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm
@@ -164,17 +164,31 @@
 
       clamp_mps_graph(newCachedGraph, input_t, min_opt_tensor, max_opt_tensor);
     });
-    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t);
-    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t);
+
+    bool gatherTensorData = true;
+    if (!output_t.is_contiguous() || output_t.is_view()) {
+      gatherTensorData = false;
+    }
+
+    auto inputPlaceholder =
+        Placeholder(cachedGraph->inputTensor, input_t, /*mpsShape=*/nil, /*gatherTensorData=*/gatherTensorData);
+    auto outputPlaceholder =
+        Placeholder(cachedGraph->outputTensor, output_t, /*mpsShape=*/nil, /*gatherTensorData=*/false);
 
     NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
     feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
     if (has_min) {
-      auto minPlaceholder = Placeholder(cachedGraph->minTensor, min_opt_tensor);
+      min_opt_tensor =
+          gatherTensorData && !min_opt_tensor.is_contiguous() ? min_opt_tensor.contiguous() : min_opt_tensor;
+      auto minPlaceholder =
+          Placeholder(cachedGraph->minTensor, min_opt_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/gatherTensorData);
       feeds[minPlaceholder.getMPSGraphTensor()] = minPlaceholder.getMPSGraphTensorData();
     }
     if (has_max) {
-      auto maxPlaceholder = Placeholder(cachedGraph->maxTensor, max_opt_tensor);
+      max_opt_tensor =
+          gatherTensorData && !max_opt_tensor.is_contiguous() ? max_opt_tensor.contiguous() : max_opt_tensor;
+      auto maxPlaceholder =
+          Placeholder(cachedGraph->maxTensor, max_opt_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/gatherTensorData);
       feeds[maxPlaceholder.getMPSGraphTensor()] = maxPlaceholder.getMPSGraphTensorData();
     }
 
@@ -224,8 +238,15 @@
       clamp_mps_graph(newCachedGraph, input_t, input_t, input_t);
     });
 
-    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t);
-    auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t);
+    bool gatherTensorData = true;
+    if (!output_t.is_contiguous() || output_t.is_view()) {
+      gatherTensorData = false;
+    }
+
+    auto inputPlaceholder =
+        Placeholder(cachedGraph->inputTensor, input_t, /*mpsShape=*/nil, /*gatherTensorData=*/gatherTensorData);
+    auto outputPlaceholder =
+        Placeholder(cachedGraph->outputTensor, output_t, /*mpsShape=*/nil, /*gatherTensorData=*/false);
 
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
       inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
diff --git a/test/test_mps.py b/test/test_mps.py
index a672de2..19e2273 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5613,6 +5613,29 @@
             clamp_topt_result_cpu = torch.clamp(cpu_x, max=cpu_max_t)
             self.assertEqual(clamp_topt_result, clamp_topt_result_cpu)
 
+            # test strided x
+            clamp_result = torch.clamp(x.movedim(0, -1), min=200.0, max=600.0)
+            clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=200.0, max=600.0)
+            self.assertEqual(clamp_result, clamp_result_cpu)
+
+            # test strided x, min_t, max_t
+            clamp_result = torch.clamp(x.movedim(0, -1), min=min_t.movedim(0, -1), max=max_t.movedim(0, -1))
+            clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=cpu_min_t.movedim(0, -1), max=cpu_max_t.movedim(0, -1))
+            self.assertEqual(clamp_result, clamp_result_cpu)
+
+            # test strided min_t, max_t
+            clamp_result = torch.clamp(
+                x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
+                min=min_t.movedim(0, -1),
+                max=max_t.movedim(0, -1)
+            )
+            clamp_result_cpu = torch.clamp(
+                cpu_x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
+                min=cpu_min_t.movedim(0, -1),
+                max=cpu_max_t.movedim(0, -1)
+            )
+            self.assertEqual(clamp_result, clamp_result_cpu)
+
             # test inplace clamping
             x.clamp_(min=200.0, max=600.0)
             cpu_x.clamp_(min=200.0, max=600.0)