[MPS] Add MPS implementation for constant_pad_nd() (#75) (#82366)
MPS has a native implementation of constant_pad_nd. Dispatching to it directly, instead of going through the composite view ops, improves performance on several torchbench benchmarks.
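A minimal sketch of how the new path is exercised from Python (assumes an MPS-capable machine; the coverage actually added by this PR lives in `test_mps.py` below). `F.pad` with `mode="constant"` lowers to `aten::constant_pad_nd`, which now routes to the MPS kernel:

```python
import torch
import torch.nn.functional as F

# Constant padding (aten::constant_pad_nd) now dispatches to the native MPS
# kernel instead of the composite view-based implementation.
x_cpu = torch.randn(2, 3, 4, 4)
x_mps = x_cpu.to("mps")

# Pad last dim by (1, 2) and second-to-last dim by (3, 0) with value 3.5.
y_cpu = F.pad(x_cpu, (1, 2, 3, 0), mode="constant", value=3.5)
y_mps = F.pad(x_mps, (1, 2, 3, 0), mode="constant", value=3.5)

# The MPS result should match the CPU reference.
torch.testing.assert_close(y_cpu, y_mps.cpu())
```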
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82366
Approved by: https://github.com/malfet, https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm
index f4ffa1e..977f9f1 100644
--- a/aten/src/ATen/native/mps/operations/Shape.mm
+++ b/aten/src/ATen/native/mps/operations/Shape.mm
@@ -21,7 +21,7 @@
// Pad operations (1D/2D/3D forward and backward)
Tensor& pad_out_template(Tensor &output, const Tensor &input_, IntArrayRef padding,
const c10::optional<Tensor>& grad_output_opt,
- MPSGraphPaddingMode mode, const string op_name)
+ MPSGraphPaddingMode mode, double constantValue, const string op_name)
{
const int padding_size = (int) padding.size();
const int padding_dim = padding_size / 2; // either 1D, 2D, or 3D
@@ -150,7 +150,7 @@
withPaddingMode:mode
leftPadding:leftPadding
rightPadding:rightPadding
- constantValue:0
+ constantValue:constantValue
name:nil];
} else {
newCachedGraph->gradOutputTensor = mpsGraphRankedPlaceHolder(mpsGraph, grad_output);
@@ -187,101 +187,116 @@
TORCH_IMPL_FUNC(reflection_pad1d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output)
{
- mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt, MPSGraphPaddingModeReflect, "reflection_pad1d_out_mps");
+ mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
+ MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_out_mps");
}
TORCH_IMPL_FUNC(reflection_pad1d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
{
grad_input.resize_as_(input).zero_();
- mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output, MPSGraphPaddingModeReflect, "reflection_pad1d_backward_out_mps");
+ mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
+ MPSGraphPaddingModeReflect, 0.0, "reflection_pad1d_backward_out_mps");
}
TORCH_IMPL_FUNC(replication_pad1d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output)
{
- mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt, MPSGraphPaddingModeClampToEdge, "replication_pad1d_out_mps");
+ mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
+ MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_out_mps");
}
TORCH_IMPL_FUNC(replication_pad1d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
{
grad_input.resize_as_(input).zero_();
- mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output, MPSGraphPaddingModeClampToEdge, "replication_pad1d_backward_out_mps");
+ mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
+ MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad1d_backward_out_mps");
}
// 2D Reflection and Replication Padding
Tensor& reflection_pad2d_out_mps(const Tensor& input, IntArrayRef padding, Tensor& output)
{
- return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, __func__);
+ return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
}
Tensor reflection_pad2d_mps(const Tensor& input, IntArrayRef padding)
{
Tensor output = at::empty({0}, input.options());
- return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, __func__);
+ return mps::pad_out_template(output, input, padding, c10::nullopt, MPSGraphPaddingModeReflect, 0.0, __func__);
}
Tensor& reflection_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
{
grad_input.resize_as_(input).zero_();
- return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, __func__);
+ return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
}
Tensor reflection_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
{
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, __func__);
+ return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeReflect, 0.0, __func__);
}
TORCH_IMPL_FUNC(replication_pad2d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output)
{
- mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt, MPSGraphPaddingModeClampToEdge, "replication_pad2d_out_mps");
+ mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
+ MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad2d_out_mps");
}
Tensor& replication_pad2d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
{
grad_input.resize_as_(input).zero_();
- return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, __func__);
+ return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}
Tensor replication_pad2d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
{
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, __func__);
+ return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}
// 3D Reflection and Replication Padding
TORCH_IMPL_FUNC(reflection_pad3d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output)
{
- mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt, MPSGraphPaddingModeReflect, "reflection_pad3d_out_mps");
+ mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
+ MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_out_mps");
}
TORCH_IMPL_FUNC(reflection_pad3d_backward_out_mps)
(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, const Tensor& grad_input)
{
grad_input.resize_as_(input).zero_();
- mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output, MPSGraphPaddingModeReflect, "reflection_pad3d_backward_out_mps");
+ mps::pad_out_template(const_cast<Tensor&>(grad_input), input, padding, grad_output,
+ MPSGraphPaddingModeReflect, 0.0, "reflection_pad3d_backward_out_mps");
}
TORCH_IMPL_FUNC(replication_pad3d_out_mps)
(const Tensor& input, IntArrayRef padding, const Tensor& output)
{
- mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt, MPSGraphPaddingModeClampToEdge, "replication_pad3d_out_mps");
+ mps::pad_out_template(const_cast<Tensor&>(output), input, padding, c10::nullopt,
+ MPSGraphPaddingModeClampToEdge, 0.0, "replication_pad3d_out_mps");
}
Tensor& replication_pad3d_backward_out_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding, Tensor& grad_input)
{
grad_input.resize_as_(input).zero_();
- return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, __func__);
+ return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
}
Tensor replication_pad3d_backward_mps(const Tensor& grad_output, const Tensor& input, IntArrayRef padding)
{
auto grad_input = at::zeros_like(input, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
- return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, __func__);
+ return mps::pad_out_template(grad_input, input, padding, grad_output, MPSGraphPaddingModeClampToEdge, 0.0, __func__);
+}
+
+// backward pass is explicitly handled in autograd by negating the "pad" argument
+Tensor constant_pad_nd_mps(const Tensor& self, IntArrayRef pad, const Scalar& value)
+{
+ Tensor output = at::empty({0}, self.options());
+ return mps::pad_out_template(output, self, pad, c10::nullopt, MPSGraphPaddingModeConstant, value.toDouble(), __func__);
}
// topk
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 77f3b5b..6f5c5cd 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1373,6 +1373,7 @@
variants: function
dispatch:
CompositeExplicitAutograd: constant_pad_nd
+ MPS: constant_pad_nd_mps
- func: contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)
variants: method
diff --git a/test/test_mps.py b/test/test_mps.py
index c0737b3..c1127cb 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3418,12 +3418,15 @@
self.assertEqual(y_cpu, y_mps.cpu())
def test_pad(self):
- def helper(shape, padding, op):
+ def helper(shape, padding, op, value=0):
inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
inputCPU.retain_grad()
inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
- padCriteria = op(padding)
+ if (op in [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]):
+ padCriteria = op(padding, value)
+ else:
+ padCriteria = op(padding)
outputCPU = padCriteria(inputCPU)
outputMPS = padCriteria(inputMPS)
self.assertEqual(outputCPU, outputMPS)
@@ -3439,6 +3442,8 @@
helper((2, 4, 4), (1, 3), nn.ReflectionPad1d)
# Replication 1D
helper((2, 1, 6), 3, nn.ReplicationPad1d)
+ # Constant Pad 1D
+ helper((2, 3, 4), 2, nn.ConstantPad1d)
# 2D Padding
helper((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
@@ -3448,11 +3453,15 @@
helper((2, 1, 6, 8), 2, nn.ReplicationPad2d)
# verify if a change in shape of padding would cause problems with graph caching
helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d)
+ # Constant Pad 2D
+ helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d)
# 3D Padding
helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
# verify if a change in shape of padding would cause problems with graph caching
helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d)
+ # Constant Pad 3D
+ helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
# Test stack forward
def test_stack(self):