[MPS] Add softplus backward (#79873)
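Extend the MPS Softplus test to also run a backward pass with a random upstream gradient and check that the input gradient computed on MPS matches the CPU result.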
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79873
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 1bf01c6..8e06c58 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -3598,7 +3598,14 @@
softplus_result = torch.nn.Softplus(beta=0.5, threshold=0.5)(x)
softplus_result_cpu = torch.nn.Softplus(beta=0.5, threshold=0.5)(cpu_x)
+ cpu_grad = torch.randn(softplus_result.shape)
+ grad = cpu_grad.to('mps')
+
+ softplus_result.backward(gradient=grad)
+ softplus_result_cpu.backward(gradient=cpu_grad)
+
self.assertEqual(softplus_result, softplus_result_cpu)
+ self.assertEqual(x.grad, cpu_x.grad)
# Test empty shape too
for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]: