[MPS] Fix Clamp with strided outputs/inputs (#97858)
Fixes #94396
Fixes #87348
1. If the output is non-contiguous (strided) or a view, we don't gather the input tensors.
2. If the output is contiguous but min_t or max_t is non-contiguous, we make min_t or max_t contiguous before building its placeholder.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97858
Approved by: https://github.com/kulinseth
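For context, a minimal sketch of the failure mode this fixes (shapes and clamp bounds here are illustrative, not taken from the original reports):

```python
import torch

# Before this fix, clamping on MPS could gather (re-lay-out) a strided
# input while writing into a strided output, so results diverged from CPU.
if torch.backends.mps.is_available():
    cpu_x = torch.randn(3, 4, 5)
    x = cpu_x.to("mps")

    # movedim(0, -1) yields a non-contiguous, strided view
    out = torch.clamp(x.movedim(0, -1), min=-0.5, max=0.5)
    out_cpu = torch.clamp(cpu_x.movedim(0, -1), min=-0.5, max=0.5)
    assert torch.equal(out.cpu(), out_cpu)  # used to fail on MPS
```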
diff --git a/aten/src/ATen/native/mps/operations/TensorCompare.mm b/aten/src/ATen/native/mps/operations/TensorCompare.mm
index 18a9ad9..76f4dee 100644
--- a/aten/src/ATen/native/mps/operations/TensorCompare.mm
+++ b/aten/src/ATen/native/mps/operations/TensorCompare.mm
@@ -164,17 +164,31 @@
clamp_mps_graph(newCachedGraph, input_t, min_opt_tensor, max_opt_tensor);
});
- auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t);
- auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t);
+
+ bool gatherTensorData = true;
+ if (!output_t.is_contiguous() || output_t.is_view()) {
+ gatherTensorData = false;
+ }
+
+ auto inputPlaceholder =
+ Placeholder(cachedGraph->inputTensor, input_t, /*mpsShape=*/nil, /*gatherTensorData=*/gatherTensorData);
+ auto outputPlaceholder =
+ Placeholder(cachedGraph->outputTensor, output_t, /*mpsShape=*/nil, /*gatherTensorData=*/false);
NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease];
feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
if (has_min) {
- auto minPlaceholder = Placeholder(cachedGraph->minTensor, min_opt_tensor);
+ min_opt_tensor =
+ gatherTensorData && !min_opt_tensor.is_contiguous() ? min_opt_tensor.contiguous() : min_opt_tensor;
+ auto minPlaceholder =
+ Placeholder(cachedGraph->minTensor, min_opt_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/gatherTensorData);
feeds[minPlaceholder.getMPSGraphTensor()] = minPlaceholder.getMPSGraphTensorData();
}
if (has_max) {
- auto maxPlaceholder = Placeholder(cachedGraph->maxTensor, max_opt_tensor);
+ max_opt_tensor =
+ gatherTensorData && !max_opt_tensor.is_contiguous() ? max_opt_tensor.contiguous() : max_opt_tensor;
+ auto maxPlaceholder =
+ Placeholder(cachedGraph->maxTensor, max_opt_tensor, /*mpsShape=*/nil, /*gatherTensorData=*/gatherTensorData);
feeds[maxPlaceholder.getMPSGraphTensor()] = maxPlaceholder.getMPSGraphTensorData();
}
@@ -224,8 +238,15 @@
clamp_mps_graph(newCachedGraph, input_t, input_t, input_t);
});
- auto inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_t);
- auto outputPlaceholder = Placeholder(cachedGraph->outputTensor, output_t);
+ bool gatherTensorData = true;
+ if (!output_t.is_contiguous() || output_t.is_view()) {
+ gatherTensorData = false;
+ }
+
+ auto inputPlaceholder =
+ Placeholder(cachedGraph->inputTensor, input_t, /*mpsShape=*/nil, /*gatherTensorData=*/gatherTensorData);
+ auto outputPlaceholder =
+ Placeholder(cachedGraph->outputTensor, output_t, /*mpsShape=*/nil, /*gatherTensorData=*/false);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
diff --git a/test/test_mps.py b/test/test_mps.py
index a672de2..19e2273 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5613,6 +5613,29 @@
clamp_topt_result_cpu = torch.clamp(cpu_x, max=cpu_max_t)
self.assertEqual(clamp_topt_result, clamp_topt_result_cpu)
+ # test strided x
+ clamp_result = torch.clamp(x.movedim(0, -1), min=200.0, max=600.0)
+ clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=200.0, max=600.0)
+ self.assertEqual(clamp_result, clamp_result_cpu)
+
+ # test strided x, min_t, max_t
+ clamp_result = torch.clamp(x.movedim(0, -1), min=min_t.movedim(0, -1), max=max_t.movedim(0, -1))
+ clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=cpu_min_t.movedim(0, -1), max=cpu_max_t.movedim(0, -1))
+ self.assertEqual(clamp_result, clamp_result_cpu)
+
+ # test strided min_t, max_t
+ clamp_result = torch.clamp(
+ x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
+ min=min_t.movedim(0, -1),
+ max=max_t.movedim(0, -1)
+ )
+ clamp_result_cpu = torch.clamp(
+ cpu_x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
+ min=cpu_min_t.movedim(0, -1),
+ max=cpu_max_t.movedim(0, -1)
+ )
+ self.assertEqual(clamp_result, clamp_result_cpu)
+
# test inplace clamping
x.clamp_(min=200.0, max=600.0)
cpu_x.clamp_(min=200.0, max=600.0)
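A side note on the new tests: `movedim(0, -1)` produces a strided, non-contiguous view, and `.clone(memory_format=torch.contiguous_format)` restores a contiguous copy so that only min_t/max_t remain strided in the third case. A quick illustration (the tensor shape is arbitrary):

```python
import torch

t = torch.randn(2, 3, 4)
v = t.movedim(0, -1)      # shape (3, 4, 2): same storage, permuted strides
print(v.is_contiguous())  # False
c = v.clone(memory_format=torch.contiguous_format)
print(c.is_contiguous())  # True
```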