[MPS] Fix Clamp with strided outputs/inputs (#97858)
Fixes #94396
Fixes #87348
1. If the output is strided, we don't gather the input tensors.
2. If the output is not strided but min_t or max_t is strided, we make min_t or max_t contiguous.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97858
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index a672de2..19e2273 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5613,6 +5613,29 @@
clamp_topt_result_cpu = torch.clamp(cpu_x, max=cpu_max_t)
self.assertEqual(clamp_topt_result, clamp_topt_result_cpu)
+ # test strided x
+ clamp_result = torch.clamp(x.movedim(0, -1), min=200.0, max=600.0)
+ clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=200.0, max=600.0)
+ self.assertEqual(clamp_result, clamp_result_cpu)
+
+ # test strided x, min_t, max_t
+ clamp_result = torch.clamp(x.movedim(0, -1), min=min_t.movedim(0, -1), max=max_t.movedim(0, -1))
+ clamp_result_cpu = torch.clamp(cpu_x.movedim(0, -1), min=cpu_min_t.movedim(0, -1), max=cpu_max_t.movedim(0, -1))
+ self.assertEqual(clamp_result, clamp_result_cpu)
+
+ # test strided min_t, max_t
+ clamp_result = torch.clamp(
+ x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
+ min=min_t.movedim(0, -1),
+ max=max_t.movedim(0, -1)
+ )
+ clamp_result_cpu = torch.clamp(
+ cpu_x.movedim(0, -1).clone(memory_format=torch.contiguous_format),
+ min=cpu_min_t.movedim(0, -1),
+ max=cpu_max_t.movedim(0, -1)
+ )
+ self.assertEqual(clamp_result, clamp_result_cpu)
+
# test inplace clamping
x.clamp_(min=200.0, max=600.0)
cpu_x.clamp_(min=200.0, max=600.0)