[MPS] Fix softplus with f16 input (#101948)

Fixes #101946
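The MPS softplus forward and backward paths built the `beta` and `threshold` scalars and their graph placeholders with a hardcoded `ScalarType::Float`, so a `float16` input produced mixed-dtype operands inside the MPSGraph and the graph compiler rejected them ("input types 'tensor<20xf16>' and 'tensor<1xf32>' are not broadcast compatible"). The scalars and placeholders now follow the input's dtype via `self.scalar_type()` / `inputTensor.dataType`.

Minimal repro sketch (assumes an MPS-capable machine; the shape and error text mirror the TestConsistency skip removed below):

```python
import torch

# float16 softplus on MPS used to fail in the graph compiler;
# with this change it runs and matches the CPU result's dtype.
x = torch.randn(20, device="mps", dtype=torch.float16, requires_grad=True)
y = torch.nn.functional.softplus(x, beta=0.5, threshold=20)
y.sum().backward()
print(y.dtype, x.grad.dtype)  # torch.float16 torch.float16
```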
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101948
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm
index e41a184..8fefe61 100644
--- a/aten/src/ATen/native/mps/operations/Activation.mm
+++ b/aten/src/ATen/native/mps/operations/Activation.mm
@@ -1196,8 +1196,8 @@
   };
 
   MPSStream* stream = getCurrentMPSStream();
-  MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float);
-  MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float);
+  MPSScalar beta_scalar = getMPSScalar(beta, self.scalar_type());
+  MPSScalar threshold_scalar = getMPSScalar(threshold, self.scalar_type());
 
   @autoreleasepool {
     string key = "softplus_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(beta.to<double>()) + ":" +
@@ -1206,9 +1206,9 @@
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
 
-      MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(ScalarType::Float));
+      MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);
 
-      MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(ScalarType::Float));
+      MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);
 
       MPSGraphTensor* reluTensor = [mpsGraph reLUWithTensor:inputTensor name:nil];
 
@@ -1258,8 +1258,8 @@
   if (grad_input.numel() == 0)
     return;
 
-  MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float);
-  MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float);
+  MPSScalar beta_scalar = getMPSScalar(beta, self.scalar_type());
+  MPSScalar threshold_scalar = getMPSScalar(threshold, self.scalar_type());
 
   struct CachedGraph : public MPSCachedGraph {
     CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
@@ -1281,9 +1281,9 @@
 
       MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
 
-      MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(ScalarType::Float));
+      MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);
 
-      MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(ScalarType::Float));
+      MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);
 
       MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
       MPSGraphTensor* bxTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
diff --git a/test/test_mps.py b/test/test_mps.py
index 6f0f79d..a828487 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6187,8 +6187,8 @@
 
     # Test softplus
     def test_softplus(self):
-        def helper(shape, beta=1, threshold=20):
-            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+        def helper(shape, beta, threshold, dtype):
+            cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
             x = cpu_x.detach().clone().to('mps').requires_grad_()
 
             softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x)
@@ -6204,10 +6204,13 @@
             self.assertEqual(x.grad, cpu_x.grad)
 
         # Test empty shape too
-        for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]:
-            for beta in [0.5, 1, 2, 3, 4]:
-                for threshold in [0.5, 20, 30, 40, 50]:
-                    helper(shape, beta, threshold)
+        for shape, beta, threshold, dtype in product(
+            [(), (2, 3), (10, 10), (2, 3, 4, 5)],
+            [0.5, 1, 2, 3, 4],
+            [0.5, 20, 30, 40, 50],
+            [torch.float16, torch.float32]
+        ):
+            helper(shape, beta, threshold, dtype)
 
     # Test silu
 
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 04dd0b6..de4801d 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -17582,12 +17582,6 @@
                 }),
                 'TestUnaryUfuncs'),
         ),
-        skips=(
-            # Error: input types 'tensor<20xf16>' and 'tensor<1xf32>' are not broadcast compatible
-            # See issue: https://github.com/pytorch/pytorch/issues/101946
-            DecorateInfo(unittest.skip('Skipped!'), "TestConsistency", "test_output_match", dtypes=(torch.float16,),),
-            DecorateInfo(unittest.skip('Skipped!'), "TestConsistency", "test_output_grad_match", dtypes=(torch.float16,),),
-        ),
     ),
     OpInfo(
         "nn.functional.mse_loss",