[MPS] Fix the failure with ReplicatePad3D (#96988)
- Only ReflectPad needs the torch checks on its input arguments (each padding size must be smaller than the corresponding input dimension); ReplicatePad does not
- Added a test case
- The failure was originally caught in test_modules by `test_forward_nn_ReplicationPad3d_mps_float32`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96988
Approved by: https://github.com/DenisVieriu97
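To show why the new test case is only valid for ReplicationPad3d, here is a minimal plain-Python sketch of the two padding rules. `padded_shape_3d` is a hypothetical helper, not part of PyTorch or of this patch. It assumes the `(left, right, top, bottom, front, back)` ordering that `torch.nn.functional.pad` uses for 3D padding. It also assumes that reflection requires each pad to be strictly smaller than the dimension it pads, while replication only needs that dimension to be non-empty.

```python
# Hypothetical helper (not a PyTorch API): models the padding-argument rules
# that distinguish ReflectionPad3d from ReplicationPad3d.
from typing import Sequence, Tuple


def padded_shape_3d(
    input_shape: Sequence[int],
    pad: Sequence[int],
    mode: str,
) -> Tuple[int, ...]:
    """Return the output shape of a 3D pad, or raise ValueError.

    `input_shape` is (N, C, D, H, W) or (C, D, H, W); `pad` is
    (left, right, top, bottom, front, back).
    """
    if len(pad) != 6:
        raise ValueError("3D padding expects 6 values")
    if len(input_shape) not in (4, 5):
        raise ValueError("expected a 4D or 5D input shape")
    *batch, d, h, w = input_shape
    left, right, top, bottom, front, back = pad
    # Pair each spatial dim with its (before, after) padding amounts.
    dims = [(d, front, back), (h, top, bottom), (w, left, right)]
    out = []
    for size, before, after in dims:
        if mode == "reflect":
            # Reflection mirrors interior elements, so each pad must be
            # strictly smaller than the dimension it pads.
            if before >= size or after >= size:
                raise ValueError(
                    f"reflect padding ({before}, {after}) must be < input size {size}"
                )
        elif mode == "replicate":
            # Replication repeats the edge element, so a pad of any size
            # is allowed as long as the dimension is non-empty.
            if size <= 0:
                raise ValueError("replicate padding needs a non-empty dimension")
        else:
            raise ValueError(f"unknown mode: {mode!r}")
        out.append(size + before + after)
    return (*batch, *out)
```

With the new test's arguments, input `(3, 4, 5, 6, 7)` and padding `(1, 2, 3, 4, 5, 6)`, the depth dimension is 5 while `pad_front` is 5 and `pad_back` is 6. Replication therefore gives depth 5 + 5 + 6 = 16 and an output shape of `(3, 4, 16, 13, 10)`. Reflection with the same padding is rejected.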
diff --git a/test/test_mps.py b/test/test_mps.py
index bf0d0b0..127e53a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5723,6 +5723,8 @@
helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReflectionPad3d)
# verify if a change in shape of padding would cause problems with graph caching
helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ReplicationPad3d)
+ # case where input_d == pad_front/back for ReplicationPad3d
+ helper((3, 4, 5, 6, 7), (1, 2, 3, 4, 5, 6), nn.ReplicationPad3d)
# Constant Pad 3D
helper((2, 4, 6, 8, 4), (1, 3, 3, 5, 3, 4), nn.ConstantPad3d)
# input size < pad size