Fix MPS constant pad (#89864)

Support constant padding of inputs with an arbitrary number of dimensions on MPS.

Fixes #89624
Fixes #87277

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89864
Approved by: https://github.com/kulinseth, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 75b3c35..520e6cc 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3635,6 +3635,15 @@
         r_mps = m(input_mps)
         self.assertEqual(r_cpu, r_mps.to("cpu"))
 
+        # Arbitrary input dimensions
+        pad = (1, 1, 0, 0, 0, 0)
+        value = 3.5
+        input_cpu = torch.randn((1, 1, 3, 3, 3, 3, 3, 3, 3, 3))
+        input_mps = input_cpu.detach().clone().to("mps")
+        r_cpu = F.pad(input_cpu, pad=pad, value=value)
+        r_mps = F.pad(input_mps, pad=pad, value=value)
+        self.assertEqual(r_cpu, r_mps.to("cpu"))
+
     def test_circular_pad(self):
         # https://github.com/pytorch/pytorch/issues/80856
         k_cpu = torch.ones(3, 3, 9, 9)
@@ -3700,6 +3709,14 @@
         helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d)
         # input size < pad size
         helper((1, 2, 3), (0, 0, 0, 1), nn.ConstantPad2d)
+        # pad dims < input dims
+        helper((50, 9, 300), (0, 0, 0, 31), nn.ConstantPad2d)
+        # pad dims == input dims
+        helper((1, 3), (0, 2, 0, 1), nn.ConstantPad2d)
+        # input.numel() == 0 but output.numel() > 0
+        helper((0, 3, 3), (1, 1, 1, 1, 1, 1), nn.ConstantPad2d)
+        # pad dims < input dims - 2
+        helper((1, 2, 3, 4), (1, 2), nn.ConstantPad2d)
 
         # 3D Padding
         helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
@@ -3707,6 +3724,8 @@
         helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d)
         # Constant Pad 3D
         helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
+        # input size < pad size
+        helper((2, 4, 6), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
 
     # Test stack forward
     def test_stack(self):