[MPS] Add MPS implementation for constant_pad_nd() (#75) (#82366)
MPS has a native implementation of constant_pad_nd. Using it instead of going through the view ops improves performance on several torchbench benchmarks.
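
For context, a minimal usage sketch (not from the PR): F.pad with mode='constant' dispatches to constant_pad_nd, so constant padding on the mps device now runs through the native kernel.

import torch
import torch.nn.functional as F

# F.pad with mode='constant' lowers to aten::constant_pad_nd, the op this
# PR implements natively for MPS. Assumes an MPS-capable build of PyTorch.
x = torch.randn(2, 3, 4, device='mps')
y = F.pad(x, (1, 2), mode='constant', value=0.5)  # pad last dim: 1 left, 2 right

# Cross-check against the CPU reference, mirroring the test below.
y_cpu = F.pad(x.cpu(), (1, 2), mode='constant', value=0.5)
torch.testing.assert_close(y.cpu(), y_cpu)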
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82366
Approved by: https://github.com/malfet, https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index c0737b3..c1127cb 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3418,12 +3418,15 @@
         self.assertEqual(y_cpu, y_mps.cpu())
 
     def test_pad(self):
-        def helper(shape, padding, op):
+        def helper(shape, padding, op, value=0):
             inputCPU = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
             inputCPU.retain_grad()
             inputMPS = inputCPU.detach().clone().to('mps').requires_grad_()
-            padCriteria = op(padding)
+            if op in [nn.ConstantPad1d, nn.ConstantPad2d, nn.ConstantPad3d]:
+                padCriteria = op(padding, value)
+            else:
+                padCriteria = op(padding)
             outputCPU = padCriteria(inputCPU)
             outputMPS = padCriteria(inputMPS)
             self.assertEqual(outputCPU, outputMPS)
@@ -3439,6 +3442,8 @@
         helper((2, 4, 4), (1, 3), nn.ReflectionPad1d)
         # Replication 1D
         helper((2, 1, 6), 3, nn.ReplicationPad1d)
+        # Constant Pad 1D
+        helper((2, 3, 4), 2, nn.ConstantPad1d)
 
         # 2D Padding
         helper((1, 2, 3, 4), (1, 1, 2, 0), nn.ReflectionPad2d)
@@ -3448,11 +3453,15 @@
         helper((2, 1, 6, 8), 2, nn.ReplicationPad2d)
         # verify if a change in shape of padding would cause problems with graph caching
         helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ReplicationPad2d)
+        # Constant Pad 2D
+        helper((2, 1, 6, 8), (2, 4, 3, 5), nn.ConstantPad2d)
 
         # 3D Padding
         helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
         # verify if a change in shape of padding would cause problems with graph caching
         helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d)
+        # Constant Pad 3D
+        helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
 
     # Test stack forward
     def test_stack(self):
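
Side note on the helper change above: the branch is needed because the ConstantPad modules take a fill value in addition to the padding widths, while the Reflection/Replication modules take padding only. A quick illustration (not part of the PR):

import torch
import torch.nn as nn

# ConstantPad1d takes (padding, value); ReplicationPad1d takes padding only.
const_pad = nn.ConstantPad1d(2, 3.5)  # pad 2 on each side, filled with 3.5
repl_pad = nn.ReplicationPad1d(2)     # pad 2 on each side by edge replication

x = torch.randn(2, 3, 4)
print(const_pad(x).shape)  # torch.Size([2, 3, 8])
print(repl_pad(x).shape)   # torch.Size([2, 3, 8])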