[MPS] Fix gelu forward and backward ops (#94529)
Forward pass:
```
fix gelu_out_mps key
add calculation for gelu with tanh approximation
remove gelu from blocklist
```
Backward pass:
```
fix gelu_backward_out_mps key
unify formatting
add calculation for tanh approximate backward pass
remove grad test from blocklist
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94529
Approved by: https://github.com/razarmehr, https://github.com/kulinseth
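
For reference, the graph code in this patch implements the tanh approximation of GELU and its closed-form derivative. The sketch below restates both formulas in plain PyTorch and checks them against the built-in op on CPU; the helper names `gelu_tanh` and `gelu_tanh_grad` are illustrative only and not part of this PR:
```
import math

import torch

BETA = math.sqrt(2.0 / math.pi)  # same value as kBeta = M_SQRT2 * M_2_SQRTPI * 0.5
KAPPA = 0.044715                 # same value as kKappa

def gelu_tanh(x):
    # forward: 0.5 * x * (1 + tanh(beta * (x + kappa * x^3)))
    inner = BETA * (x + KAPPA * x ** 3)
    return 0.5 * x * (1.0 + torch.tanh(inner))

def gelu_tanh_grad(x):
    # backward: 0.5 * (1 + tanh(inner))
    #         + 0.5 * x * (1 - tanh(inner)^2) * beta * (1 + 3 * kappa * x^2)
    inner = BETA * (x + KAPPA * x ** 3)
    t = torch.tanh(inner)
    return 0.5 * (1.0 + t) + 0.5 * x * (1.0 - t * t) * BETA * (1.0 + 3.0 * KAPPA * x ** 2)

x = torch.randn(128, dtype=torch.double, requires_grad=True)
out = torch.nn.functional.gelu(x, approximate='tanh')
torch.testing.assert_close(out, gelu_tanh(x))
out.sum().backward()
torch.testing.assert_close(x.grad, gelu_tanh_grad(x.detach()))
```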
diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm
index ee1c3ee..a5dae09 100644
--- a/aten/src/ATen/native/mps/operations/Activation.mm
+++ b/aten/src/ATen/native/mps/operations/Activation.mm
@@ -753,6 +753,50 @@
return erfTensor;
}
+MPSGraphTensor* tanh(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor) {
+  // 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
+ auto dataType = [inputTensor dataType];
+ constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
+ constexpr float kKappa = 0.044715f;
+ MPSGraphTensor *betaf = [mpsGraph constantWithScalar: kBeta
+ shape: @[@1]
+ dataType: dataType];
+ MPSGraphTensor *kappaf = [mpsGraph constantWithScalar: kKappa
+ shape: @[@1]
+ dataType: dataType];
+ MPSGraphTensor *onef = [mpsGraph constantWithScalar: 1.0f
+ shape: @[@1]
+ dataType: dataType];
+ MPSGraphTensor *halff = [mpsGraph constantWithScalar: 0.5f
+ shape: @[@1]
+ dataType: dataType];
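+  // note: erfTensor is reused as a scratch tensor below; it ends up holding
+  // 0.5 * (1 + tanh(beta * (x + kappa * x^3))), which the caller multiplies by x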
+  MPSGraphTensor *erfTensor = [mpsGraph multiplicationWithPrimaryTensor: inputTensor
+                                                        secondaryTensor: inputTensor
+                                                                   name: nil];
+  erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor
+                                        secondaryTensor: inputTensor
+                                                   name: nil];
+  erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor
+                                        secondaryTensor: kappaf
+                                                   name: nil];
+  erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor
+                                  secondaryTensor: inputTensor
+                                             name: nil];
+  erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor
+                                        secondaryTensor: betaf
+                                                   name: nil];
+  erfTensor = [mpsGraph tanhWithTensor: erfTensor
+                                  name: nil];
+  erfTensor = [mpsGraph additionWithPrimaryTensor: erfTensor
+                                  secondaryTensor: onef
+                                             name: nil];
+  erfTensor = [mpsGraph multiplicationWithPrimaryTensor: erfTensor
+                                        secondaryTensor: halff
+                                                   name: nil];
+
+ return erfTensor;
+}
+
TORCH_IMPL_FUNC(gelu_out_mps) (
const Tensor& self, c10::string_view approximate, const Tensor& output
) {
@@ -776,7 +820,7 @@
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
- string key = "gelu_out_mps" + getTensorsStringKey({self});
+ string key = "gelu_out_mps" + getTensorsStringKey({self}) + ":" + c10::str(approximate);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
@@ -791,7 +835,12 @@
getMPSDataType(self.scalar_type()),
getMPSShape(self));
- MPSGraphTensor* outputTensor = normcdf(mpsGraph, inputTensor);
+ MPSGraphTensor* outputTensor = nil;
+ if(approximate == "tanh") {
+ outputTensor = tanh(mpsGraph, inputTensor);
+ } else {
+ outputTensor = normcdf(mpsGraph, inputTensor);
+ }
outputTensor = [mpsGraph multiplicationWithPrimaryTensor:outputTensor
secondaryTensor:inputTensor
name:nil];
@@ -824,7 +873,6 @@
const Tensor& grad, const Tensor& self, c10::string_view approximate, const Tensor& grad_input
) {
using namespace mps;
- constexpr float kBeta = M_2_SQRTPI * M_SQRT1_2 * (0.5);
// Empty output
if(grad_input.numel() == 0)
@@ -843,7 +891,7 @@
MPSStream* stream = getCurrentMPSStream();
@autoreleasepool {
- string key = "gelu_backward_out_mps" + getTensorsStringKey({self, grad});
+ string key = "gelu_backward_out_mps" + getTensorsStringKey({self, grad}) + ":" + c10::str(approximate);
CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
if(!cachedGraph) {
MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
@@ -861,32 +909,110 @@
MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph,
dataType,
getMPSShape(self));
- MPSGraphTensor* cdf = normcdf(mpsGraph, inputTensor);
- MPSGraphTensor *halff = [mpsGraph constantWithScalar: -0.5f
- shape: @[@1]
- dataType: dataType];
- MPSGraphTensor *betaf = [mpsGraph constantWithScalar :kBeta
- shape :@[@1]
- dataType:dataType];
- MPSGraphTensor *pdfMul = [mpsGraph squareWithTensor : inputTensor
- name : nil];
- pdfMul = [mpsGraph multiplicationWithPrimaryTensor : pdfMul
- secondaryTensor : halff
- name : nil];
- pdfMul = [mpsGraph exponentWithTensor : pdfMul
- name : nil];
- MPSGraphTensor* pdf = [mpsGraph multiplicationWithPrimaryTensor : pdfMul
- secondaryTensor : betaf
- name : nil];
- pdf = [mpsGraph multiplicationWithPrimaryTensor : inputTensor
- secondaryTensor : pdf
- name : nil];
- pdf = [mpsGraph additionWithPrimaryTensor : pdf
- secondaryTensor : cdf
- name : nil];
- MPSGraphTensor* outputTensor = [mpsGraph multiplicationWithPrimaryTensor : gradTensor
- secondaryTensor : pdf
- name : nil];
+ MPSGraphTensor* outputTensor = nil;
+ if(approximate == "tanh") {
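+            // d/dx [0.5 * x * (1 + tanh(inner))] with inner = beta * (x + kappa * x^3):
+            //   0.5 * (1 + tanh(inner)) + 0.5 * x * (1 - tanh(inner)^2) * beta * (1 + 3 * kappa * x^2)
+            // the graph below builds these two product-rule terms as left_derivative and right_derivative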
+ constexpr float kBeta = M_SQRT2 * M_2_SQRTPI * (0.5f);
+ constexpr float kKappa = 0.044715f;
+ MPSGraphTensor *betaf = [mpsGraph constantWithScalar: kBeta
+ shape: @[@1]
+ dataType: dataType];
+ MPSGraphTensor *kappaf = [mpsGraph constantWithScalar: kKappa
+ shape: @[@1]
+ dataType: dataType];
+ MPSGraphTensor *halff = [mpsGraph constantWithScalar: 0.5f
+ shape: @[@1]
+ dataType: dataType];
+ MPSGraphTensor *onef = [mpsGraph constantWithScalar: 1.0f
+ shape: @[@1]
+ dataType: dataType];
+ MPSGraphTensor *threef = [mpsGraph constantWithScalar: 3.0f
+ shape: @[@1]
+ dataType: dataType];
+ MPSGraphTensor* x_sq = [mpsGraph multiplicationWithPrimaryTensor: inputTensor
+ secondaryTensor: inputTensor
+ name: nil];
+ MPSGraphTensor *x_cube = [mpsGraph multiplicationWithPrimaryTensor: x_sq
+ secondaryTensor: inputTensor
+ name: nil];
+ MPSGraphTensor *inner = [mpsGraph multiplicationWithPrimaryTensor: kappaf
+ secondaryTensor: x_cube
+ name: nil];
+ inner = [mpsGraph additionWithPrimaryTensor: inner
+ secondaryTensor: inputTensor
+ name: nil];
+ inner = [mpsGraph multiplicationWithPrimaryTensor: betaf
+ secondaryTensor: inner
+ name: nil];
+ MPSGraphTensor *tanhInner = [mpsGraph tanhWithTensor: inner
+ name: nil];
+ MPSGraphTensor *left = [mpsGraph multiplicationWithPrimaryTensor: halff
+ secondaryTensor: inputTensor
+ name: nil];
+ MPSGraphTensor *right = [mpsGraph additionWithPrimaryTensor: onef
+ secondaryTensor: tanhInner
+ name: nil];
+ MPSGraphTensor *left_derivative = [mpsGraph multiplicationWithPrimaryTensor: halff
+ secondaryTensor: right
+ name: nil];
+ MPSGraphTensor *tanh_derivative = [mpsGraph multiplicationWithPrimaryTensor: tanhInner
+ secondaryTensor: tanhInner
+ name: nil];
+ tanh_derivative = [mpsGraph subtractionWithPrimaryTensor: onef
+ secondaryTensor: tanh_derivative
+ name: nil];
+ MPSGraphTensor *inner_derivative = [mpsGraph multiplicationWithPrimaryTensor: threef
+ secondaryTensor: kappaf
+ name: nil];
+ inner_derivative = [mpsGraph multiplicationWithPrimaryTensor: inner_derivative
+ secondaryTensor: x_sq
+ name: nil];
+ inner_derivative = [mpsGraph additionWithPrimaryTensor: inner_derivative
+ secondaryTensor: onef
+ name: nil];
+ inner_derivative = [mpsGraph multiplicationWithPrimaryTensor: betaf
+ secondaryTensor: inner_derivative
+ name: nil];
+ MPSGraphTensor *right_derivative = [mpsGraph multiplicationWithPrimaryTensor: left
+ secondaryTensor: tanh_derivative
+ name: nil];
+ right_derivative = [mpsGraph multiplicationWithPrimaryTensor: right_derivative
+ secondaryTensor: inner_derivative
+ name: nil];
+ outputTensor = [mpsGraph additionWithPrimaryTensor: left_derivative
+ secondaryTensor: right_derivative
+ name: nil];
+ outputTensor = [mpsGraph multiplicationWithPrimaryTensor: gradTensor
+ secondaryTensor: outputTensor
+ name: nil];
+ } else {
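+            // exact backward: gelu'(x) = cdf(x) + x * pdf(x), where pdf(x) = exp(-x^2 / 2) / sqrt(2 * pi)
+            // kBeta below is 1 / sqrt(2 * pi)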
+ constexpr float kBeta = M_2_SQRTPI * M_SQRT1_2 * (0.5);
+ MPSGraphTensor *halff = [mpsGraph constantWithScalar: -0.5f
+ shape: @[@1]
+ dataType: dataType];
+ MPSGraphTensor *betaf = [mpsGraph constantWithScalar: kBeta
+ shape: @[@1]
+ dataType: dataType];
+ MPSGraphTensor* cdf = normcdf(mpsGraph, inputTensor);
+ MPSGraphTensor *pdfMul = [mpsGraph squareWithTensor: inputTensor
+ name: nil];
+ pdfMul = [mpsGraph multiplicationWithPrimaryTensor: pdfMul
+ secondaryTensor: halff
+ name: nil];
+ pdfMul = [mpsGraph exponentWithTensor: pdfMul
+ name: nil];
+ MPSGraphTensor* pdf = [mpsGraph multiplicationWithPrimaryTensor: pdfMul
+ secondaryTensor: betaf
+ name: nil];
+ pdf = [mpsGraph multiplicationWithPrimaryTensor: inputTensor
+ secondaryTensor: pdf
+ name: nil];
+ pdf = [mpsGraph additionWithPrimaryTensor: pdf
+ secondaryTensor: cdf
+ name: nil];
+ outputTensor = [mpsGraph multiplicationWithPrimaryTensor: gradTensor
+ secondaryTensor: pdf
+ name: nil];
+ }
newCachedGraph->gradTensor_ = gradTensor;
newCachedGraph->inputTensor_ = inputTensor;
diff --git a/test/test_mps.py b/test/test_mps.py
index e3329a4..cd40f44 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5005,6 +5005,17 @@
finally:
torch.set_num_threads(num_threads)
+ def test_gelu_tanh(self):
+ def helper(shape):
+ cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
+ x = cpu_x.detach().clone().to('mps')
+
+ gelu_tanh_result = torch.nn.functional.gelu(x, approximate='tanh')
+ gelu_tanh_result_cpu = torch.nn.functional.gelu(cpu_x, approximate='tanh')
+ self.assertEqual(gelu_tanh_result, gelu_tanh_result_cpu)
+
+ helper((2, 8, 4, 5))
+
# Test hardtanh
def test_hardtanh(self):
def helper(shape, min_val, max_val, inplace=False):
@@ -9175,6 +9186,7 @@
'_native_batch_norm_legit': ['f32'],
'native_batch_norm': ['f32'],
'native_layer_norm': ['f32'],
+ 'nn.functional.gelu': ['f32'],
}
# These ops are problematic, so never run them even when