Fix MPS constant pad (#89864)
Support arbitrary dimensions for constant padding on MPS
Fixes #89624
Fixes #87277
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89864
Approved by: https://github.com/kulinseth, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 75b3c35..520e6cc 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3635,6 +3635,15 @@
r_mps = m(input_mps)
self.assertEqual(r_cpu, r_mps.to("cpu"))
+ # Arbitrary input dimensions
+ pad = (1, 1, 0, 0, 0, 0)
+ value = 3.5
+ input_cpu = torch.randn((1, 1, 3, 3, 3, 3, 3, 3, 3, 3))
+ input_mps = input_cpu.detach().clone().to("mps")
+ r_cpu = F.pad(input_cpu, pad=pad, value=value)
+ r_mps = F.pad(input_mps, pad=pad, value=value)
+ self.assertEqual(r_cpu, r_mps.to("cpu"))
+
def test_circular_pad(self):
# https://github.com/pytorch/pytorch/issues/80856
k_cpu = torch.ones(3, 3, 9, 9)
@@ -3700,6 +3709,14 @@
helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d)
# input size < pad size
helper((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d)
+ # pad dims < input dims
+ helper((50, 9, 300), (0, 0, 0, 31), nn.ConstantPad2d)
+ # pad dims == input dims
+ helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d)
+ # input.numel() == 0 but output.numel() > 0
+ helper((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d)
+ # pad dims < input dims - 2
+ helper((1, 2, 3, 4), (1, 2), nn.ConstantPad2d)
# 3D Padding
helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
@@ -3707,6 +3724,8 @@
helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d)
# Constant Pad 3D
helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
+ # input size < pad size
+ helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
# Test stack forward
def test_stack(self):