[MPS] Fix bidirectional LSTM & small one-direction LSTM fix (#95563)

Fixes #94754

With this PR I hope to finish my breathtaking journey of fixing MPS LSTM.

Here, I enable `bidirectional` LSTM on MPS. I also noticed that the graph cache key did not account for all parameters (`bidirectional`, `has_biases`, `dropout`, `batch_first`), so even a one-directional LSTM could pick up a stale cached graph when one was first created without bias or dropout and another was then created with one of them.
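
For illustration, a minimal sketch (not part of the PR) of both behaviors this change touches, assuming an MPS-capable machine with macOS 13+ for the multi-layer case; all sizes here are arbitrary:

```python
import torch
import torch.nn as nn

device = "mps"  # assumes PyTorch built with MPS support

# 1) Bidirectional LSTM: previously rejected on the MPS path, now dispatches to _lstm_mps.
rnn = nn.LSTM(input_size=11, hidden_size=7, num_layers=2, bidirectional=True).to(device)
x = torch.randn(3, 5, 11, device=device)      # [seq_len, batch, input_size]
h0 = torch.randn(2 * 2, 5, 7, device=device)  # [num_layers * num_directions, batch, hidden_size]
c0 = torch.randn(2 * 2, 5, 7, device=device)
out, (hn, cn) = rnn(x, (h0, c0))              # out: [3, 5, 2 * 7]

# 2) Cache-key scenario: two one-directional LSTMs differing only in `bias`.
# Before this fix they produced the same cache key, so the second could
# silently reuse the graph compiled for the first.
x1 = torch.randn(3, 5, 11, device=device)
out_no_bias, _ = nn.LSTM(11, 7, bias=False).to(device)(x1)
out_bias, _ = nn.LSTM(11, 7, bias=True).to(device)(x1)
```

The key built in `RnnOps.mm` now includes `bidirectional`, `has_biases`, `dropout`, and `batch_first` alongside the tensor shapes and `num_layers`.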

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95563
Approved by: https://github.com/jhavukainen, https://github.com/kulinseth, https://github.com/malfet
diff --git a/aten/src/ATen/native/RNN.cpp b/aten/src/ATen/native/RNN.cpp
index 6b2b985..c2dbb99 100644
--- a/aten/src/ATen/native/RNN.cpp
+++ b/aten/src/ATen/native/RNN.cpp
@@ -1422,7 +1422,7 @@
     return std::make_tuple(std::move(output), std::move(hy), std::move(cy));
   }
 #ifdef USE_MPS
-  if (_input.is_mps() && !bidirectional) {
+  if (_input.is_mps()) {
     std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> output = at::_lstm_mps(_input, hx, _params, has_biases,
             num_layers, dropout_p, train, bidirectional, batch_first);
     std::tuple<Tensor, Tensor, Tensor> return_values = std::make_tuple(std::get<0>(output), std::get<1>(output), std::get<2>(output));
diff --git a/aten/src/ATen/native/mps/operations/RnnOps.mm b/aten/src/ATen/native/mps/operations/RnnOps.mm
index 9e59a6c..fbf0a99 100644
--- a/aten/src/ATen/native/mps/operations/RnnOps.mm
+++ b/aten/src/ATen/native/mps/operations/RnnOps.mm
@@ -23,6 +23,85 @@
     return output_dimensions;
 }
 
+/**
+ * Accepts tensors in PyTorch API format and returns tensors in MPS API format
+ * @return tuple of tensors to use with MPS API in order:
+ * stateTensor, cellStateTensor, recurrentWeight, inputWeight, biasTensor
+ */
+static std::tuple<MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*, MPSGraphTensor*>
+    getMPSTensorsFromPytorchTensors(MPSGraph* mpsGraph,
+                                MPSGraphTensor* stateTensor, MPSGraphTensor* cellStateTensor,
+                                NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList,
+                                NSMutableArray<MPSGraphTensor*> *kernelWeightsList,
+                                NSMutableArray<MPSGraphTensor*> *kernelBiasList,
+                                NSMutableArray<MPSGraphTensor*> *recurrentBiasList,
+                                bool has_biases, bool bidirectional, size_t layer_no) {
+    MPSGraphTensor* biasTensor_ = nil;
+    MPSGraphTensor* stateTensor_ = nil, *cellStateTensor_ = nil;
+    MPSGraphTensor* recurrentWeight_ = nil, *inputWeight_ = nil;
+
+    if (bidirectional) {
+        stateTensor_ = [mpsGraph sliceTensor:stateTensor
+                                   dimension:0
+                                       start:layer_no * 2
+                                      length:2
+                                        name:nil];
+        // [2, N, H] -> [N, 2, H]
+        stateTensor_ = [mpsGraph transposeTensor:stateTensor_ dimension: 0 withDimension: 1 name:nil];
+        // [N, 2, H] -> [N, 2 * H]
+        stateTensor_ = [mpsGraph flatten2DTensor:stateTensor_ axis:1 name:nil];
+        cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
+                                       dimension:0
+                                           start:layer_no * 2
+                                          length:2
+                                            name:nil];
+        cellStateTensor_ = [mpsGraph transposeTensor:cellStateTensor_ dimension: 0 withDimension: 1 name:nil];
+        cellStateTensor_ = [mpsGraph flatten2DTensor:cellStateTensor_ axis:1 name:nil];
+
+        recurrentWeight_ = [mpsGraph
+            concatTensor: [mpsGraph expandDimsOfTensor: recurrentKernelWeightsList[layer_no * 2] axis: 0 name: nil]
+              withTensor: [mpsGraph expandDimsOfTensor: recurrentKernelWeightsList[layer_no * 2 + 1] axis: 0 name: nil]
+               dimension: 0
+                    name: nil
+        ];
+        inputWeight_ = [mpsGraph
+            concatTensor: kernelWeightsList[layer_no * 2]
+              withTensor: kernelWeightsList[layer_no * 2 + 1]
+               dimension: 0
+                    name: nil
+        ];
+        if (has_biases) {
+          auto biasTensorFwd_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no * 2]
+                                                    secondaryTensor:recurrentBiasList[layer_no * 2]
+                                                               name:nil];
+          auto biasTensorBack_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no * 2 + 1]
+                                                     secondaryTensor:recurrentBiasList[layer_no * 2 + 1]
+                                                                name:nil];
+
+          biasTensor_ = [mpsGraph concatTensor:biasTensorFwd_ withTensor:biasTensorBack_ dimension:0 name:nil];
+        }
+    } else {
+        stateTensor_ = [mpsGraph sliceTensor:stateTensor
+                                   dimension:0
+                                       start:layer_no
+                                      length:1
+                                        name:nil];
+        cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
+                                       dimension:0
+                                           start:layer_no
+                                          length:1
+                                            name:nil];
+        recurrentWeight_ = recurrentKernelWeightsList[layer_no];
+        inputWeight_ = kernelWeightsList[layer_no];
+        if (has_biases) {
+          biasTensor_ = [mpsGraph additionWithPrimaryTensor:kernelBiasList[layer_no]
+                                            secondaryTensor:recurrentBiasList[layer_no]
+                                                       name:nil];
+        }
+    }
+    return std::make_tuple(stateTensor_, cellStateTensor_, recurrentWeight_, inputWeight_, biasTensor_);
+}
+
 std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> _lstm_mps(const Tensor& input, TensorList hx, TensorList params, bool has_biases, int64_t num_layers, double dropout_p, bool train, bool bidirectional, bool batch_first) {
     using namespace mps;
 
@@ -38,15 +117,17 @@
     std::vector<Tensor> recurrent_kernel_weights;
     std::vector<Tensor> biases;
     std::vector<Tensor> recurrent_biases;
-    for (size_t i = 0; i < num_layers; i+=1) {
+
+    const int64_t total_layers = num_layers * (bidirectional ? 2 : 1);
+
+    for (const auto i : c10::irange(total_layers)) {
+        const int stride = (has_biases ? 4 : 2);
+        kernel_weights.push_back(params[i*stride]);
+        recurrent_kernel_weights.push_back(params[i*stride+1]);
+
         if (has_biases) {
-            kernel_weights.push_back(params[i*4]);
-            recurrent_kernel_weights.push_back(params[i*4+1]);
-            biases.push_back(params[i*4+2]);
-            recurrent_biases.push_back(params[i*4+3]);
-        } else {
-            kernel_weights.push_back(params[i*2]);
-            recurrent_kernel_weights.push_back(params[i*2+1]);
+          biases.push_back(params[i*stride+2]);
+          recurrent_biases.push_back(params[i*stride+3]);
         }
     }
 
@@ -65,7 +146,7 @@
     MPSStream* stream = getCurrentMPSStream();
 
     @autoreleasepool {
-      string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input.scalar_type()) + "_num_layers_" + std::to_string(num_layers);
+      string key = "lstm_" + getTensorsStringKey({input, hx[0], hx[1]}) + getMPSTypeString(input.scalar_type()) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_dropout_" + std::to_string(dropout_p) + "_batch_first_" + std::to_string(batch_first);
       CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
       if(!cachedGraph) {
         MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
@@ -81,7 +162,7 @@
             NSMutableArray<MPSGraphTensor*> *recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
             NSMutableArray<MPSGraphTensor*> *layersOutputsList = [[NSMutableArray alloc] initWithCapacity:num_layers];
 
-            for (size_t i = 0; i < num_layers; i += 1) {
+            for (const auto i : c10::irange(total_layers)) {
                 [kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(kernel_weights[i]))];
                 [recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_kernel_weights[i]))];
                 if(has_biases) {
@@ -100,7 +181,7 @@
             MPSGraphTensor* cellStateTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(hx[1]));
             std::vector<MPSGraphTensor*> inputTensors = {inputTensor, stateTensor, cellStateTensor,};
 
-            if(batch_first) {
+            if (batch_first) {
                 inputTensor = [mpsGraph transposeTensor:inputTensor
                                                 dimension:0
                                                 withDimension:1
@@ -113,49 +194,61 @@
             NSMutableArray<MPSGraphTensor*>* outputCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
             NSMutableArray<MPSGraphTensor*>* outputZStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
             NSMutableArray<MPSGraphTensor*>* outputCellStateFwdArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
-            for(int i = 0; i < num_layers; i++) {
-                MPSGraphTensor* biasTensor = nil;
-                if(has_biases) {
-                    biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i]
-                                                     secondaryTensor:recurrentBiasList[i]
-                                                                name:nil];
-                }
-                MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
-                                                           dimension:0
-                                                               start:i
-                                                              length:1
-                                                                name:nil];
-                MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
-                                                               dimension:0
-                                                                   start:i
-                                                                  length:1
-                                                                    name:nil];
+            for (int i = 0; i < num_layers; i++) {
+                auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph, stateTensor, cellStateTensor,
+                                                                   recurrentKernelWeightsList, kernelWeightsList,
+                                                                   kernelBiasList, recurrentBiasList, has_biases,
+                                                                   bidirectional, i);
+                MPSGraphTensor* stateTensor_ = std::get<0>(tensorsData), *cellStateTensor_ = std::get<1>(tensorsData);
+                MPSGraphTensor* recurrentWeight_ = std::get<2>(tensorsData), *inputWeight_ = std::get<3>(tensorsData);
+                MPSGraphTensor* biasTensor_ = std::get<4>(tensorsData);
+
+
                 outputs = [mpsGraph LSTMWithSourceTensor:inputTensor_
-                                        recurrentWeight:recurrentKernelWeightsList[i]
-                                            inputWeight:kernelWeightsList[i]
-                                                   bias:biasTensor
+                                        recurrentWeight:recurrentWeight_
+                                            inputWeight:inputWeight_
+                                                   bias:biasTensor_
                                               initState:stateTensor_
                                                initCell:cellStateTensor_
                                              descriptor:opDesc
                                                    name:nil];
 
                 inputTensor_ = [outputs objectAtIndex:0];
-                // no need to keep a final layer output copy as it is
+                // no need to keep the final layer output copy as it is
                 // returned anyway and not used in backprop
-                if(i != num_layers - 1) {
+                if (i != num_layers - 1) {
                     [layersOutputsList addObject:[mpsGraph expandDimsOfTensor:inputTensor_
                                                                          axis:0
                                                                          name:nil]];
                 }
-                if(dropout_p>0.0 && train && (i!=num_layers-1)) {
+                if (dropout_p>0.0 && train && (i!=num_layers-1)) {
                     inputTensor_ = [mpsGraph dropoutTensor:inputTensor_
                                                       rate:dropout_p
                                                       name:nil];
 
                 }
 
-                [outputStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil]];
-                [outputCellStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil]];
+                if (bidirectional) {
+                    // [1, N, 2 * H]
+                    auto stateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil];
+                    auto stateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:0 length:1 name:nil];
+                    // [1, N, H] ([1, N, 0:H])
+                    auto stateForward = [mpsGraph sliceTensor:stateLastT dimension: -1 start:0 length:hx[0].sizes()[2] name:nil];
+                    // [1, N, H] ([1, N, H:2H])
+                    auto stateBack = [mpsGraph sliceTensor:stateFirstT dimension: -1 start:hx[0].sizes()[2] length:hx[0].sizes()[2] name:nil];
+                    [outputStateArray addObject:stateForward];
+                    [outputStateArray addObject:stateBack];
+
+                    auto cellStateLastT = [mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil];
+                    auto cellStateFirstT = [mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:0 length:1 name:nil];
+                    auto cellStateForward = [mpsGraph sliceTensor:cellStateLastT dimension: -1 start:0 length:hx[1].sizes()[2] name:nil];
+                    auto cellStateBack = [mpsGraph sliceTensor:cellStateFirstT dimension: -1 start:hx[1].sizes()[2] length:hx[1].sizes()[2] name:nil];
+                    [outputCellStateArray addObject:cellStateForward];
+                    [outputCellStateArray addObject:cellStateBack];
+                } else {
+                    [outputStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:0] dimension:0 start:-1 length:1 name:nil]];
+                    [outputCellStateArray addObject:[mpsGraph sliceTensor:[outputs objectAtIndex:1] dimension:0 start:-1 length:1 name:nil]];
+                }
                 [outputCellStateFwdArray addObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:1]
                                                                             axis:0
                                                                             name:nil]];
@@ -205,21 +298,18 @@
       NSMutableArray<MPSGraphTensor*> *biasList = cachedGraph->biasList_;
       NSMutableArray<MPSGraphTensor*> *recurrentBiasList = cachedGraph->recurrentBiasList_;
 
-      Placeholder kernelWeight, recurrentKernelWeight, bias, recurrentBias;
-
       NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*> *feeds = [[[NSMutableDictionary alloc] init] autorelease];
-      for (size_t i = 0; i < num_layers; i+=1) {
-          kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
-          recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
+      for (const auto i : c10::irange(total_layers)) {
+          Placeholder kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
+          Placeholder recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
           [feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()];
           [feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()];
-          if(has_biases) {
-            bias = Placeholder([biasList objectAtIndex:i], biases[i]);
-            recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
+          if (has_biases) {
+            Placeholder bias = Placeholder([biasList objectAtIndex:i], biases[i]);
+            Placeholder recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
             [feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()];
             [feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()];
           }
-
       }
       Placeholder selfPlaceholder   = Placeholder(cachedGraph->inputTensors_[0], input);
       Placeholder selfState   = Placeholder(cachedGraph->inputTensors_[1], hx[0]);
@@ -274,22 +364,22 @@
     std::vector<Tensor> recurrent_kernel_weights;
     std::vector<Tensor> biases;
     std::vector<Tensor> recurrent_biases;
-    for (size_t i = 0; i < num_layers; i+=1) {
-        if(has_biases) {
-            kernel_weights.push_back(params[i*4]);
-            recurrent_kernel_weights.push_back(params[i*4+1]);
-            biases.push_back(params[i*4+2]);
-            recurrent_biases.push_back(params[i*4+3]);
-        } else {
-            kernel_weights.push_back(params[i*2]);
-            recurrent_kernel_weights.push_back(params[i*2+1]);
-        }
+
+    const int64_t total_layers = num_layers * (bidirectional ? 2 : 1);
+
+    for (const auto i : c10::irange(total_layers)) {
+      const int stride = (has_biases ? 4 : 2);
+      kernel_weights.push_back(params[i*stride]);
+      recurrent_kernel_weights.push_back(params[i*stride+1]);
+      if(has_biases) {
+          biases.push_back(params[i*stride + 2]);
+          recurrent_biases.push_back(params[i*stride + 3]);
+      }
     }
 
     struct CachedGraph : public MPSCachedGraph {
       CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
       std::vector<MPSGraphTensor*> inputTensors_;
-      std::vector<MPSGraphTensor*> outputTensors_;
       NSMutableArray<MPSGraphTensor*> *kernelWeightsList_ = nil;
       NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList_ = nil;
       NSMutableArray<MPSGraphTensor*> *biasList_ = nil;
@@ -308,9 +398,9 @@
     MPSStream* stream = getCurrentMPSStream();
     @autoreleasepool {
 
-        string key = "lstm_backward_" + getTensorsStringKey({input, z_state, cell_state_fwd, grad_y, grad_cy, grad_hy})+ getMPSTypeString(input.scalar_type()) + "_num_layers_" + std::to_string(num_layers);
+        string key = "lstm_backward_" + getTensorsStringKey({input, z_state, cell_state_fwd, grad_y, grad_cy, grad_hy})+ getMPSTypeString(input.scalar_type()) + "_num_layers_" + std::to_string(num_layers) + "_bidirectional_" + std::to_string(bidirectional) + "_has_biases_" + std::to_string(has_biases) + "_batch_first_" + std::to_string(batch_first);
         CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
-        if(!cachedGraph) {
+        if (!cachedGraph) {
             MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
 
                 CachedGraph *newCachedGraph = nil;
@@ -323,10 +413,10 @@
                     NSMutableArray<MPSGraphTensor*> *kernelBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
                     NSMutableArray<MPSGraphTensor*> *recurrentBiasList = [[NSMutableArray alloc] initWithCapacity:params.size()];
 
-                    for (size_t i = 0; i < num_layers; i += 1) {
+                    for (const auto i : c10::irange(total_layers)) {
                         [kernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()), getMPSShape(kernel_weights[i]))];
                         [recurrentKernelWeightsList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_kernel_weights[i]))];
-                        if(has_biases) {
+                        if (has_biases) {
                             [kernelBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(biases[i]))];
                             [recurrentBiasList addObject:mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(input.scalar_type()),getMPSShape(recurrent_biases[i]))];
                         }
@@ -377,6 +467,8 @@
                     NSMutableArray<MPSGraphTensor*>* gradStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
                     NSMutableArray<MPSGraphTensor*>* gradCellStateArray = [[NSMutableArray alloc] initWithCapacity:num_layers];
 
+                    auto hidden_size = hx[0].sizes()[2];
+
                     for (int i = num_layers - 1; i >= 0; i--) {
                         MPSGraphTensor* zState = [mpsGraph sliceTensor:zStateTensor
                                                                 dimension:0
@@ -394,46 +486,57 @@
                         cellStateFwd = [mpsGraph squeezeTensor:cellStateFwd
                                                     axis:0
                                                     name:nil];
-                        MPSGraphTensor* biasTensor = nil;
-                        if(has_biases) {
-                            biasTensor = [mpsGraph additionWithPrimaryTensor:kernelBiasList[i]
-                                                            secondaryTensor:recurrentBiasList[i]
-                                                            name:nil];
+                        auto tensorsData = getMPSTensorsFromPytorchTensors(mpsGraph, stateTensor, cellStateTensor,
+                                                                           recurrentKernelWeightsList, kernelWeightsList,
+                                                                           kernelBiasList, recurrentBiasList, has_biases,
+                                                                           bidirectional, i);
+                        MPSGraphTensor* stateTensor_ = std::get<0>(tensorsData), *cellStateTensor_ = std::get<1>(tensorsData);
+                        MPSGraphTensor* recurrentWeight_ = std::get<2>(tensorsData), *inputWeight_ = std::get<3>(tensorsData);
+                        MPSGraphTensor* biasTensor_ = std::get<4>(tensorsData);
+
+                        MPSGraphTensor* gradientHyTensor_ = nil, *gradientCyTensor_ = nil;
+                        if (bidirectional) {
+                            gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor
+                                                            dimension:0
+                                                                start:i * 2
+                                                               length:2
+                                                                 name:nil];
+                            // [2, N, H] -> [N, 2, H]
+                            gradientHyTensor_ = [mpsGraph transposeTensor:gradientHyTensor_ dimension: 0 withDimension: 1 name:nil];
+                            // [N, 2, H] -> [N, 2 * H]
+                            gradientHyTensor_ = [mpsGraph flatten2DTensor:gradientHyTensor_ axis:1 name:nil];
+
+
+                            gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor
+                                                            dimension:0
+                                                                start:i * 2
+                                                               length:2
+                                                                 name:nil];
+                            gradientCyTensor_ = [mpsGraph transposeTensor:gradientCyTensor_ dimension: 0 withDimension: 1 name:nil];
+                            gradientCyTensor_ = [mpsGraph flatten2DTensor:gradientCyTensor_ axis:1 name:nil];
                         } else {
-                            biasTensor = [mpsGraph constantWithScalar:0.0
-                                                            dataType:inputTensor.dataType];
+                            gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor
+                                                            dimension:0
+                                                                start:i
+                                                               length:1
+                                                                 name:nil];
+
+                            gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor
+                                                            dimension:0
+                                                                start:i
+                                                               length:1
+                                                                 name:nil];
                         }
 
-                        MPSGraphTensor* stateTensor_ = [mpsGraph sliceTensor:stateTensor
-                                                                    dimension:0
-                                                                    start:i
-                                                                    length:1
-                                                                    name:nil];
-                        MPSGraphTensor* cellStateTensor_ = [mpsGraph sliceTensor:cellStateTensor
-                                                                            dimension:0
-                                                                            start:i
-                                                                            length:1
-                                                                            name:nil];
-                        MPSGraphTensor* gradientHyTensor_ = [mpsGraph sliceTensor:gradientHyTensor
-                                                                    dimension:0
-                                                                    start:i
-                                                                    length:1
-                                                                    name:nil];
-
-                        MPSGraphTensor* gradientCyTensor_ = [mpsGraph sliceTensor:gradientCyTensor
-                                                                            dimension:0
-                                                                            start:i
-                                                                            length:1
-                                                                            name:nil];
-
                         MPSGraphTensor* iterationInputTensor_ = nil;
                         if (i == 0) {
                             iterationInputTensor_ = inputTensor;
                         } else {
                             iterationInputTensor_ = [mpsGraph sliceTensor:layersOutputsTensor
                                                                 dimension: 0
-                                                                    // last element in layersOutputsTensor contains
-                                                                    // **inputs** for the last layer
+                                                                    // the last element in layersOutputsTensor
+                                                                    // contains **inputs** for the **last** layer
+                                                                    // and so on
                                                                     start: i - num_layers
                                                                    length: 1
                                                                      name: nil];
@@ -443,14 +546,14 @@
                         }
 
                         outputs = [mpsGraph LSTMGradientsWithSourceTensor: iterationInputTensor_
-                                             recurrentWeight: recurrentKernelWeightsList[i]
+                                             recurrentWeight: recurrentWeight_
                                               sourceGradient: gradientTensor_
                                                       zState: zState
                                                cellOutputFwd: cellStateFwd
                                                stateGradient: gradientHyTensor_
                                                 cellGradient: gradientCyTensor_
-                                                 inputWeight: kernelWeightsList[i]
-                                                        bias: biasTensor
+                                                 inputWeight: inputWeight_
+                                                        bias: biasTensor_
                                                    initState: stateTensor_
                                                     initCell: cellStateTensor_
                                                         mask: nil
@@ -459,14 +562,103 @@
                                                         name: nil];
 
                         gradientTensor_ = [outputs objectAtIndex:0];
-                        [gradRecWeightsArray insertObject:[outputs objectAtIndex:1] atIndex:0];
-                        [gradWeightsArray insertObject:[outputs objectAtIndex:2] atIndex:0];
-                        [gradBiasArray insertObject: [outputs objectAtIndex:3] atIndex:0];
-                        [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:4] axis:0 name:nil]  atIndex:0];
-                        [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:5] axis:0 name:nil] atIndex:0];
-                    }
-                    std::vector<MPSGraphTensor*> outputTensors = {[outputs objectAtIndex:0],[outputs objectAtIndex:1],[outputs objectAtIndex:2],[outputs objectAtIndex:3], [outputs objectAtIndex:4], [outputs objectAtIndex:5]};
+                        if (bidirectional) {
+                            int outputIter = 1;
+                            auto gradRecWeightsBidirectional = [outputs objectAtIndex:outputIter++];
+                            auto gradRecWeightFwd = [mpsGraph sliceTensor:gradRecWeightsBidirectional
+                                                                dimension: 0
+                                                                    start: 0
+                                                                   length: 1
+                                                                     name: nil];
+                            gradRecWeightFwd = [mpsGraph squeezeTensor:gradRecWeightFwd axis:0 name: nil];
+                            auto gradRecWeightBack = [mpsGraph sliceTensor:gradRecWeightsBidirectional
+                                                                dimension: 0
+                                                                    start: 1
+                                                                   length: 1
+                                                                     name: nil];
+                            gradRecWeightBack = [mpsGraph squeezeTensor:gradRecWeightBack axis:0 name: nil];
 
+                            // layers are iterated in reverse; inserting backward then forward
+                            // at index 0 keeps the array ordered [forward, backward] per layer
+                            [gradRecWeightsArray insertObject:gradRecWeightBack atIndex:0];
+                            [gradRecWeightsArray insertObject:gradRecWeightFwd atIndex:0];
+
+                            auto gradWeightsBidirectional = [outputs objectAtIndex:outputIter++];
+                            auto gradWeightFwd = [mpsGraph sliceTensor:gradWeightsBidirectional
+                                                                dimension: 0
+                                                                    start: 0
+                                                                   length: hidden_size * 4
+                                                                     name: nil];
+                            auto gradWeightBack = [mpsGraph sliceTensor:gradWeightsBidirectional
+                                                             dimension: 0
+                                                                 start: hidden_size * 4
+                                                                length: hidden_size * 4
+                                                                  name: nil];
+
+                            [gradWeightsArray insertObject:gradWeightBack atIndex:0];
+                            [gradWeightsArray insertObject:gradWeightFwd atIndex:0];
+
+                            if (has_biases) {
+                              // the gradient has shape [1, 1, 8H] instead of the expected [8H],
+                              // so squeeze the first two dimensions
+                              auto gradBiasBidirectional = [outputs objectAtIndex:outputIter++];
+                              gradBiasBidirectional = [mpsGraph squeezeTensor: gradBiasBidirectional
+                                                                         axes: @[@0, @1]
+                                                                         name: nil];
+                              auto gradBiasFwd = [mpsGraph sliceTensor:gradBiasBidirectional
+                                                             dimension: 0
+                                                                 start: 0
+                                                                length: hidden_size * 4
+                                                                  name: nil];
+                              auto gradBiasBack = [mpsGraph sliceTensor:gradBiasBidirectional
+                                                             dimension: 0
+                                                                 start: hidden_size * 4
+                                                                length: hidden_size * 4
+                                                                  name: nil];
+
+                              [gradBiasArray insertObject: gradBiasBack atIndex:0];
+                              [gradBiasArray insertObject: gradBiasFwd atIndex:0];
+                            }
+
+                            auto gradStateBidirectional = [outputs objectAtIndex:outputIter++];
+                            auto gradStateFwd = [mpsGraph sliceTensor:gradStateBidirectional
+                                                            dimension: 1
+                                                                start: 0
+                                                               length: hidden_size
+                                                                 name: nil];
+                            auto gradStateBack = [mpsGraph sliceTensor:gradStateBidirectional
+                                                            dimension: 1
+                                                                start: hidden_size
+                                                               length: hidden_size
+                                                                 name: nil];
+
+                            [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:gradStateBack axis:0 name:nil]  atIndex:0];
+                            [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:gradStateFwd axis:0 name:nil]  atIndex:0];
+
+                            auto gradCellStateBidirectional = [outputs objectAtIndex:outputIter++];
+                            auto gradCellStateFwd = [mpsGraph sliceTensor:gradCellStateBidirectional
+                                                                dimension: 1
+                                                                    start: 0
+                                                                   length: hidden_size
+                                                                     name: nil];
+                            auto gradCellStateBack = [mpsGraph sliceTensor:gradCellStateBidirectional
+                                                                 dimension: 1
+                                                                     start: hidden_size
+                                                                    length: hidden_size
+                                                                      name: nil];
+
+                            [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:gradCellStateBack axis:0 name:nil]  atIndex:0];
+                            [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:gradCellStateFwd axis:0 name:nil]  atIndex:0];
+                        } else {
+                            int outputIter = 1;
+                            [gradRecWeightsArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0];
+                            [gradWeightsArray insertObject:[outputs objectAtIndex:outputIter++] atIndex:0];
+                            if (has_biases) {
+                              [gradBiasArray insertObject: [outputs objectAtIndex:outputIter++] atIndex:0];
+                            }
+                            [gradStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++] axis:0 name:nil]  atIndex:0];
+                            [gradCellStateArray insertObject: [mpsGraph expandDimsOfTensor:[outputs objectAtIndex:outputIter++] axis:0 name:nil] atIndex:0];
+                        }
+                    }
                     if (batch_first) {
                         MPSGraphTensor* gradientTensorTransposed = [mpsGraph transposeTensor:gradientTensor_
                                                                                    dimension: 0
@@ -477,7 +669,6 @@
                         newCachedGraph->gradOutput_ = gradientTensor_;
                     }
 
-                    newCachedGraph->outputTensors_ = outputTensors;
                     newCachedGraph->gradRecWeights_ = gradRecWeightsArray;
                     newCachedGraph->gradWeights_ = gradWeightsArray;
                     newCachedGraph->gradBias_ = gradBiasArray;
@@ -514,18 +705,15 @@
         NSMutableArray<MPSGraphTensor*> *recurrentKernelWeightsList = cachedGraph->recurrentKernelWeightsList_;
         NSMutableArray<MPSGraphTensor*> *biasList = cachedGraph->biasList_;
         NSMutableArray<MPSGraphTensor*> *recurrentBiasList = cachedGraph->recurrentBiasList_;
-        Placeholder kernelWeight;
-        Placeholder recurrentKernelWeight;
-        Placeholder bias;
-        Placeholder recurrentBias;
-        for (size_t i = 0; i < num_layers; i+=1) {
-            kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
-            recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
+
+        for (const auto i : c10::irange(total_layers)) {
+            Placeholder kernelWeight = Placeholder([kernelWeightsList objectAtIndex:i], kernel_weights[i]);
+            Placeholder recurrentKernelWeight = Placeholder([recurrentKernelWeightsList objectAtIndex:i], recurrent_kernel_weights[i]);
             [feeds setObject:kernelWeight.getMPSGraphTensorData() forKey:kernelWeight.getMPSGraphTensor()];
             [feeds setObject:recurrentKernelWeight.getMPSGraphTensorData() forKey:recurrentKernelWeight.getMPSGraphTensor()];
-            if(has_biases) {
-                bias = Placeholder([biasList objectAtIndex:i], biases[i]);
-                recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
+            if (has_biases) {
+              Placeholder bias = Placeholder([biasList objectAtIndex:i], biases[i]);
+              Placeholder recurrentBias = Placeholder([recurrentBiasList objectAtIndex:i], recurrent_biases[i]);
                 [feeds setObject:bias.getMPSGraphTensorData() forKey:bias.getMPSGraphTensor()];
                 [feeds setObject:recurrentBias.getMPSGraphTensorData() forKey:recurrentBias.getMPSGraphTensor()];
             }
@@ -556,25 +744,32 @@
         Placeholder gradRecWeightsPlaceholder, gradWeightsPlaceholder, gradBiasPlaceholder;
 
         std::vector<Tensor> weights;
-        for (int i = 0; i < num_layers; i++) {
+        for (const auto i : c10::irange(total_layers)) {
             Tensor grad_rec_weights = at::empty_like(recurrent_kernel_weights[i]);
             Tensor grad_weights = at::empty_like(kernel_weights[i]);
-            Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options());
+
             weights.push_back(grad_weights);
             weights.push_back(grad_rec_weights);
 
-            if(has_biases) {
-                weights.push_back(grad_bias);
-                weights.push_back(grad_bias);
-            }
-
             gradRecWeightsPlaceholder = Placeholder([gradRecWeightsArray objectAtIndex: i], grad_rec_weights);
             gradWeightsPlaceholder = Placeholder([gradWeightsArray objectAtIndex: i], grad_weights);
-            gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex: i], grad_bias);
 
-            [results setObject:gradBiasPlaceholder.getMPSGraphTensorData() forKey:gradBiasPlaceholder.getMPSGraphTensor()];
             [results setObject:gradRecWeightsPlaceholder.getMPSGraphTensorData() forKey:gradRecWeightsPlaceholder.getMPSGraphTensor()];
             [results setObject:gradWeightsPlaceholder.getMPSGraphTensorData() forKey:gradWeightsPlaceholder.getMPSGraphTensor()];
+
+            if (has_biases) {
+                Tensor grad_bias = at::empty((kernel_weights[i].size(0)), kernel_weights[i].options());
+
+                // The PyTorch LSTM API exposes two bias vectors; the second exists only
+                // for CuDNN compatibility. This implementation adds the two biases
+                // together and uses the sum, so both vectors share the same gradient,
+                // which is therefore pushed twice, once for each bias vector.
+                weights.push_back(grad_bias);
+                weights.push_back(grad_bias);
+
+                gradBiasPlaceholder = Placeholder([gradBiasArray objectAtIndex: i], grad_bias);
+                [results setObject:gradBiasPlaceholder.getMPSGraphTensorData() forKey:gradBiasPlaceholder.getMPSGraphTensor()];
+            }
         }
 
         runMPSGraph(stream, cachedGraph->graph(), feeds, results);
diff --git a/test/test_mps.py b/test/test_mps.py
index 062a513..bce494c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9255,91 +9255,91 @@
         self.assertEqual(out, torch.zeros(2, device=device), atol=0, rtol=0)
 
 class TestRNNMPS(TestCaseMPS):
-    def test_lstm_1(self, device="mps", dtype=torch.float32):
-        for layers in [1] if product_version < 13.0 else [1, 2, 5]:
-            torch.random.manual_seed(42)
-            rnn = nn.LSTM(7, 4, layers, device="cpu")
-            input = torch.randn(2, 3, 7, device="cpu")
-            hx = torch.randn(layers, 3, 4, device="cpu")
-            cx = torch.randn(layers, 3, 4, device="cpu")
+    def _lstm_helper(self, num_layers, dtype, device, bidirectional=False, bias=True, batch_first=False,
+                     seq_len=3, batch_size=5, hidden_size=7, input_size=11, backward=False):
+        rnn = nn.LSTM(
+            input_size=input_size,
+            hidden_size=hidden_size,
+            num_layers=num_layers,
+            bias=bias,
+            bidirectional=bidirectional,
+            batch_first=batch_first,
+            device="cpu"
+        )
+        bidirectional_mul = 2 if bidirectional else 1
 
-            cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
+        if batch_first:
+            input = torch.randn(batch_size, seq_len, input_size, device="cpu", dtype=dtype, requires_grad=backward)
+            hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
+                             requires_grad=backward)
+            cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
+                             requires_grad=backward)
+        else:
+            input = torch.randn(seq_len, batch_size, input_size, device="cpu", dtype=dtype, requires_grad=backward)
+            hx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
+                             requires_grad=backward)
+            cx = torch.randn(num_layers * bidirectional_mul, batch_size, hidden_size, device="cpu", dtype=dtype,
+                             requires_grad=backward)
 
+        cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
+
+        rnn = rnn.to(device)
+        input = input.to(device)
+        hx = hx.to(device)
+        cx = cx.to(device)
+        output, (hn, cn) = rnn(input, (hx, cx))
+
+        self.assertEqual(cpu_output, output)
+        self.assertEqual(cpu_hn, hn)
+        self.assertEqual(cpu_cn, cn)
+
+        def get_backward_results(rnn, device, inp, hx, cx):
             rnn = rnn.to(device)
-            input = input.to(device)
-            hx = hx.to(device)
-            cx = cx.to(device)
-            output, (hn, cn) = rnn(input, (hx, cx))
+            inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
 
-            self.assertEqual(cpu_output, output)
-            self.assertEqual(cpu_hn, hn)
-            self.assertEqual(cpu_cn, cn)
+            output, _ = rnn(inp, (hx, cx))
+            f = 3 * output.sum() + (hx * cx).sum()
 
-            # test batch_first
-            rnn = nn.LSTM(7, 4, layers, device="cpu", batch_first=True)
-            input = torch.randn(3, 2, 7, device="cpu")
-            hx = torch.randn(layers, 3, 4, device="cpu")
-            cx = torch.randn(layers, 3, 4, device="cpu")
-            cpu_output, (cpu_hn, cpu_cn) = rnn(input, (hx, cx))
+            param_names, params = zip(*rnn.named_parameters())
+            param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
 
-            rnn = rnn.to(device)
-            input = input.to(device)
-            hx = hx.to(device)
-            cx = cx.to(device)
-            output, (hn, cn) = rnn(input, (hx, cx))
+            input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
+            return output, param_grads, input_grad, hx_grad, cx_grad
 
-            self.assertEqual(cpu_output, output)
-            self.assertEqual(cpu_hn, hn)
-            self.assertEqual(cpu_cn, cn)
+        if backward:
+            cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
+                get_backward_results(rnn, "cpu", input, hx, cx)
+            mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
+                get_backward_results(rnn, device, input, hx, cx)
+
+            self.assertEqual(cpu_hx_grad, mps_hx_grad)
+            self.assertEqual(cpu_cx_grad, mps_cx_grad)
+            self.assertEqual(cpu_output, mps_output)
+            self.assertEqual(cpu_input_grad, mps_input_grad)
+            for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
+                self.assertEqual(cpu_weight_grad, mps_weight_grad,
+                                 f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")
+
+    LSTM_TEST_CASES = [
+        dict(),  # default
+        dict(batch_first=True),
+        dict(bias=False),
+        dict(bidirectional=True),
+        dict(batch_first=True, bias=False),
+        dict(bidirectional=True, bias=False),
+        dict(bidirectional=True, batch_first=True),
+        dict(bidirectional=True, batch_first=True, bias=False)
+    ]
+
+    def test_lstm_forward(self, device="mps", dtype=torch.float32):
+        for num_layers in [1] if product_version < 13.0 else [1, 2, 5]:
+            for test_options in self.LSTM_TEST_CASES:
+                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, **test_options)
 
     def test_lstm_backward(self, device="mps", dtype=torch.float32):
-        for layers in [1] if product_version < 13.0 else [1, 2, 5]:
-            lstm = nn.LSTM(2, 4, layers)  # initialized globally for consistent parameters init
-            lstm.train()
-
-            def get_results(device, inp, hx, cx):
-                rnn = lstm.to(device)
-                inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
-
-                output, _ = rnn(inp, (hx, cx))
-                f = output.sum()
-
-                param_names, params = zip(*rnn.named_parameters())
-                param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
-
-                input_grad, hx_grad, cx_grad = torch.autograd.grad(f, [inp, hx, cx])
-                return output, param_grads, input_grad, hx_grad, cx_grad
-
-            inp = torch.randn((5, 3, 2), requires_grad=True, dtype=dtype, device=device)
-            hx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device)
-            cx = torch.randn((layers, 3, 4), requires_grad=True, dtype=dtype, device=device)
-
-            cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx)
-            mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx)
-
-            self.assertEqual(cpu_hx_grad, mps_hx_grad)
-            self.assertEqual(cpu_cx_grad, mps_cx_grad)
-            self.assertEqual(cpu_output, mps_output)
-            self.assertEqual(cpu_input_grad, mps_input_grad)
-            for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
-                self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}")
-
-            # test batch_first backward
-            lstm = nn.LSTM(2, 4, layers, batch_first=True)
-            lstm.train()
-
-            hx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device)
-            cx = torch.randn((layers, 5, 4), requires_grad=True, dtype=dtype, device=device)
-
-            cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad = get_results("cpu", inp, hx, cx)
-            mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad = get_results(device, inp, hx, cx)
-
-            self.assertEqual(cpu_hx_grad, mps_hx_grad)
-            self.assertEqual(cpu_cx_grad, mps_cx_grad)
-            self.assertEqual(cpu_output, mps_output)
-            self.assertEqual(cpu_input_grad, mps_input_grad)
-            for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
-                self.assertEqual(cpu_weight_grad, mps_weight_grad, f"mismatch in cpu:{cpu_name} vs mps:{mps_name}")
+        for num_layers in [1] if product_version < 13.0 else [1, 2, 5]:
+            for test_options in self.LSTM_TEST_CASES:
+                self._lstm_helper(num_layers=num_layers, dtype=dtype, device=device, backward=True, **test_options)
 
 
     def test_RNN_cell_no_broadcasting(self):