[MPS] Fix softplus with f16 input (#101948)
Fixes #101946
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101948
Approved by: https://github.com/malfet
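A minimal repro sketch, not taken from the PR itself, assuming a machine with an MPS device; the beta/threshold values mirror the test below:

```python
import torch

# Hypothetical repro for the linked issue (#101946): softplus on a float16 MPS
# tensor should match the CPU float16 reference once this fix is in place.
if torch.backends.mps.is_available():
    cpu_x = torch.randn(2, 3, dtype=torch.float16)
    mps_x = cpu_x.to("mps")

    cpu_y = torch.nn.functional.softplus(cpu_x, beta=0.5, threshold=20)
    mps_y = torch.nn.functional.softplus(mps_x, beta=0.5, threshold=20)

    # With the fix, the MPS result agrees with CPU up to f16 tolerance.
    print(torch.allclose(cpu_y, mps_y.cpu(), atol=1e-3, rtol=1e-3))
```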
diff --git a/test/test_mps.py b/test/test_mps.py
index 6f0f79d..a828487 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6187,8 +6187,8 @@
 
     # Test softplus
     def test_softplus(self):
-        def helper(shape, beta=1, threshold=20):
-            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+        def helper(shape, beta, threshold, dtype):
+            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
             x = cpu_x.detach().clone().to('mps').requires_grad_()
 
             softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x)
@@ -6204,10 +6204,13 @@
             self.assertEqual(x.grad, cpu_x.grad)
 
         # Test empty shape too
-        for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]:
-            for beta in [0.5, 1, 2, 3, 4]:
-                for threshold in [0.5, 20, 30, 40, 50]:
-                    helper(shape, beta, threshold)
+        for shape, beta, threshold, dtype in product(
+            [(), (2, 3), (10, 10), (2, 3, 4, 5)],
+            [0.5, 1, 2, 3, 4],
+            [0.5, 20, 30, 40, 50],
+            [torch.float16, torch.float32]
+        ):
+            helper(shape, beta, threshold, dtype)
 
     # Test silu