[MPS] Add scalar params to the softplus key. (#94256)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94256
Approved by: https://github.com/razarmehr, https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index f0f507b..cc74589 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4651,7 +4651,7 @@
# Test softplus
def test_softplus(self):
- def helper(shape, beta=0.5, threshold=0.5):
+ def helper(shape, beta=1, threshold=20):
cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
x = cpu_x.detach().clone().to('mps').requires_grad_()
@@ -4669,9 +4669,9 @@
# Test empty shape too
for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]:
- helper(shape)
- helper(shape, beta=0.6, threshold=0.6) # relu path
- helper(shape, beta=1, threshold=20) # softplus path
+ for beta in [0.5, 1, 2, 3, 4]:
+ for threshold in [0.5, 20, 30, 40, 50]:
+ helper(shape, beta, threshold)
# Test silu