[MPS] Fix softplus with f16 input (#101948)
Fixes #101946

Softplus on MPS created its beta and threshold scalars (and the matching MPSGraph placeholders) as Float regardless of the input dtype, so a float16 input failed with a broadcast/type-mismatch error. Create them with the input's scalar type instead, extend test_softplus to cover float16, and drop the now-unneeded float16 skips in common_methods_invocations.py.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101948
Approved by: https://github.com/malfet
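
For reference, a minimal repro of the failure this patch addresses (a sketch, assuming a machine with an MPS device; before this change the call raised a type-mismatch error along the lines of "input types 'tensor<20xf16>' and 'tensor<1xf32>' are not broadcast compatible"):

```python
import torch

# softplus builds internal beta/threshold scalars; before this fix they were
# always float32, which clashed with a float16 input tensor on the MPS backend.
x = torch.randn(20, dtype=torch.float16, device="mps", requires_grad=True)
y = torch.nn.functional.softplus(x, beta=1.0, threshold=20.0)
y.sum().backward()
print(y.dtype, x.grad.dtype)  # both torch.float16 once the fix is applied
```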
diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm
index e41a184..8fefe61 100644
--- a/aten/src/ATen/native/mps/operations/Activation.mm
+++ b/aten/src/ATen/native/mps/operations/Activation.mm
@@ -1196,8 +1196,8 @@
};
MPSStream* stream = getCurrentMPSStream();
- MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float);
- MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float);
+ MPSScalar beta_scalar = getMPSScalar(beta, self.scalar_type());
+ MPSScalar threshold_scalar = getMPSScalar(threshold, self.scalar_type());
@autoreleasepool {
string key = "softplus_out_mps:" + getTensorsStringKey({self}) + ":" + std::to_string(beta.to<double>()) + ":" +
@@ -1206,9 +1206,9 @@
auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
- MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(ScalarType::Float));
+ MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);
- MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSDataType(ScalarType::Float));
+ MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);
MPSGraphTensor* reluTensor = [mpsGraph reLUWithTensor:inputTensor name:nil];
@@ -1258,8 +1258,8 @@
if (grad_input.numel() == 0)
return;
- MPSScalar beta_scalar = getMPSScalar(beta, ScalarType::Float);
- MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float);
+ MPSScalar beta_scalar = getMPSScalar(beta, self.scalar_type());
+ MPSScalar threshold_scalar = getMPSScalar(threshold, self.scalar_type());
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
@@ -1281,9 +1281,9 @@
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, self);
- MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(ScalarType::Float));
+ MPSGraphTensor* betaTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);
- MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, getMPSScalarType(ScalarType::Float));
+ MPSGraphTensor* thresholdTensor = mpsGraphScalarPlaceHolder(mpsGraph, inputTensor.dataType);
MPSGraphTensor* unitTensor = [mpsGraph constantWithScalar:1.0 shape:@[ @1 ] dataType:getMPSDataType(self)];
MPSGraphTensor* bxTensor = [mpsGraph multiplicationWithPrimaryTensor:inputTensor
diff --git a/test/test_mps.py b/test/test_mps.py
index 6f0f79d..a828487 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6187,8 +6187,8 @@
# Test softplus
def test_softplus(self):
- def helper(shape, beta=1, threshold=20):
- cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
+ def helper(shape, beta, threshold, dtype):
+ cpu_x = torch.randn(shape, device='cpu', dtype=dtype, requires_grad=True)
x = cpu_x.detach().clone().to('mps').requires_grad_()
softplus_result = torch.nn.Softplus(beta=beta, threshold=threshold)(x)
@@ -6204,10 +6204,13 @@
self.assertEqual(x.grad, cpu_x.grad)
# Test empty shape too
- for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]:
- for beta in [0.5, 1, 2, 3, 4]:
- for threshold in [0.5, 20, 30, 40, 50]:
- helper(shape, beta, threshold)
+ for shape, beta, threshold, dtype in product(
+ [(), (2, 3), (10, 10), (2, 3, 4, 5)],
+ [0.5, 1, 2, 3, 4],
+ [0.5, 20, 30, 40, 50],
+ [torch.float16, torch.float32]
+ ):
+ helper(shape, beta, threshold, dtype)
# Test silu
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 04dd0b6..de4801d 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -17582,12 +17582,6 @@
}),
'TestUnaryUfuncs'),
),
- skips=(
- # Error: input types 'tensor<20xf16>' and 'tensor<1xf32>' are not broadcast compatible
- # See issue: https://github.com/pytorch/pytorch/issues/101946
- DecorateInfo(unittest.skip('Skipped!'), "TestConsistency", "test_output_match", dtypes=(torch.float16,),),
- DecorateInfo(unittest.skip('Skipped!'), "TestConsistency", "test_output_grad_match", dtypes=(torch.float16,),),
- ),
),
OpInfo(
"nn.functional.mse_loss",