[MPS] Add MPS implementation for constant_pad_nd() (#75) (#82366)

MPS has a native implementation of constant_pad_nd. Using it directly, instead of lowering through the view ops, improves performance on several torchbench benchmarks.
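For illustration only (a minimal usage sketch, not part of this change; it assumes an MPS-capable machine and uses arbitrary shapes and pad values), a constant pad issued from Python now dispatches to the native MPS kernel and should match the CPU result:

    import torch
    import torch.nn.functional as F

    # Illustrative check: constant padding on the MPS device vs. CPU.
    x_cpu = torch.randn(2, 3, 4)
    x_mps = x_cpu.to("mps")

    # mode="constant" routes through constant_pad_nd, which this change
    # registers as constant_pad_nd_mps for the MPS backend.
    y_cpu = F.pad(x_cpu, (1, 2), mode="constant", value=3.5)
    y_mps = F.pad(x_mps, (1, 2), mode="constant", value=3.5)

    assert torch.allclose(y_cpu, y_mps.cpu())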

Pull Request resolved: https://github.com/pytorch/pytorch/pull/82366
Approved by: https://github.com/malfet, https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm
index f4ffa1e..977f9f1 100644
--- a/aten/src/ATen/native/mps/operations/Shape.mm
+++ b/aten/src/ATen/native/mps/operations/Shape.mm
@@ -21,7 +21,7 @@
 // Pad operations (1D/2D/3D forward and backward)
 Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef padding,
                          const c10::optional<Tensor>& grad_output_opt,
-                         MPSGraphPaddingMode mode, const string op_name)
+                         MPSGraphPaddingMode mode, double constantValue, const string op_name)
 {
   const int padding_size = (int) padding.size();
   const int padding_dim = padding_size / 2; // either 1D, 2D, or 3D
@@ -150,7 +150,7 @@
                                                  withPaddingMode:mode
                                                      leftPadding:leftPadding
                                                     rightPadding:rightPadding
-                                                   constantValue:0
+                                                   constantValue:constantValue
                                                             name:nil];
             } else {
               newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
@@ -187,101 +187,116 @@
 TORCH_IMPL_FUNC(reflection_pad1d_out_mps)
 (const Tensor& input, IntArrayRef padding, const Tensor& output)
 {
-  mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt, MPSGraphPaddingModeReflect, "reflection_pad1d_out_mps");
+  mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
+                        MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_out_mps");
 }
 
 TORCH_IMPL_FUNC(reflection_pad1d_backward_out_mps)
 (const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
 {
   grad_input.resize_as_(input).zero_();
-  mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output, MPSGraphPaddingModeReflect, "reflection_pad1d_backward_out_mps");
+  mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
+                        MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_backward_out_mps");
 }
 
 TORCH_IMPL_FUNC(replication_pad1d_out_mps)
 (const Tensor& input, IntArrayRef padding, const Tensor& output)
 {
-  mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt, MPSGraphPaddingModeClampToEdge, "replication_pad1d_out_mps");
+  mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
+                        MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_out_mps");
 }
 
 TORCH_IMPL_FUNC(replication_pad1d_backward_out_mps)
 (const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
 {
   grad_input.resize_as_(input).zero_();
-  mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output, MPSGraphPaddingModeClampToEdge, "replication_pad1d_backward_out_mps");
+  mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
+                        MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_backward_out_mps");
 }
 
 // 2D Reflection and Replication Padding
 Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output)
 {
-  return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, __func__);
+  return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
 }
 
 Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding)
 {
   Tensor output = at::empty({0}, input.options());
-  return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, __func__);
+  return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
 }
 
 Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
 {
   grad_input.resize_as_(input).zero_();
-  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, __func__);
+  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
 }
 
 Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
 {
   auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, __func__);
+  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
 }
 
 TORCH_IMPL_FUNC(replication_pad2d_out_mps)
 (const Tensor& input, IntArrayRef padding, const Tensor& output)
 {
-  mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt, MPSGraphPaddingModeClampToEdge, "replication_pad2d_out_mps");
+  mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
+                        MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad2d_out_mps");
 }
 
 Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
 {
   grad_input.resize_as_(input).zero_();
-  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, __func__);
+  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
 }
 
 Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
 {
   auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, __func__);
+  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
 }
 
 // 3D Reflection and Replication Padding
 TORCH_IMPL_FUNC(reflection_pad3d_out_mps)
 (const Tensor& input, IntArrayRef padding, const Tensor& output)
 {
-  mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt, MPSGraphPaddingModeReflect, "reflection_pad3d_out_mps");
+  mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
+                        MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_out_mps");
 }
 
 TORCH_IMPL_FUNC(reflection_pad3d_backward_out_mps)
 (const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
 {
   grad_input.resize_as_(input).zero_();
-  mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output, MPSGraphPaddingModeReflect, "reflection_pad3d_backward_out_mps");
+  mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
+                        MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_backward_out_mps");
 }
 
 TORCH_IMPL_FUNC(replication_pad3d_out_mps)
 (const Tensor& input, IntArrayRef padding, const Tensor& output)
 {
-  mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt, MPSGraphPaddingModeClampToEdge, "replication_pad3d_out_mps");
+  mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
+                        MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad3d_out_mps");
 }
 
 Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
 {
   grad_input.resize_as_(input).zero_();
-  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, __func__);
+  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
 }
 
 Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
 {
   auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, __func__);
+  return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
+}
+
+// backward pass is explicitly handled in autograd by negating the "pad" argument
+Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value)
+{
+  Tensor output = at::empty({0}, self.options());
+  return mps::pad_out_template(output, self, pad, c10::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__);
 }
 
 // topk
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 77f3b5b..6f5c5cd 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1373,6 +1373,7 @@
   variants: function
   dispatch:
     CompositeExplicitAutograd: constant_pad_nd
+    MPS: constant_pad_nd_mps
 
 - func: contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)
   variants: method
diff --git a/test/test_mps.py b/test/test_mps.py
index c0737b3..c1127cb 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3418,12 +3418,15 @@
         self.assertEqual(y_cpu, y_mps.cpu())
 
     def test_pad(self):
-        def helper(shape, padding, op):
+        def helper(shape, padding, op, value=0):
             inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
             inputCPU.retain_grad()
             inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
 
-            padCriteria = op(padding)
+            if op in [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]:
+                padCriteria = op(padding, value)
+            else:
+                padCriteria = op(padding)
             outputCPU = padCriteria(inputCPU)
             outputMPS = padCriteria(inputMPS)
             self.assertEqual(outputCPU, outputMPS)
@@ -3439,6 +3442,8 @@
         helper((2, 4, 4), (1, 3), nn.ReflectionPad1d)
         # Replication 1D
         helper((2, 1, 6), 3, nn.ReplicationPad1d)
+        # Constant Pad 1D
+        helper((2, 3, 4), 2, nn.ConstantPad1d)
 
         # 2D Padding
         helper((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
@@ -3448,11 +3453,15 @@
         helper((2, 1, 6, 8), 2, nn.ReplicationPad2d)
         # verify if a change in shape of padding would cause problems with graph caching
         helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d)
+        # Constant Pad 2D
+        helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d)
 
         # 3D Padding
         helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
         # verify if a change in shape of padding would cause problems with graph caching
         helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d)
+        # Constant Pad 3D
+        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
 
     # Test stack forward
     def test_stack(self):