[MPS] LSTM grad_y missing fix (#96601)

Fixes #96416
Added tests that do not use the LSTM output, similarly to the issue

Seems like this fix once again introduces backward incompatibility.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96601
Approved by: https://github.com/albanD, https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 351aa76..bf0d0b0 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9878,12 +9878,18 @@
         self.assertEqual(cpu_hn, hn)
         self.assertEqual(cpu_cn, cn)
 
-        def get_backward_results(rnn, device, inp, hx, cx):
+        def get_backward_results(rnn, device, inp, hx, cx, output_grad_presented=True, states_grad_presented=True):
             rnn = rnn.to(device)
             inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
 
-            output, _ = rnn(inp, (hx, cx))
-            f = 3 * output.sum() + (hx * cx).sum()
+            output, (hx_out, cx_out) = rnn(inp, (hx, cx))
+            assert output_grad_presented or states_grad_presented, "At least some outputs must be used"
+
+            f = 0
+            if output_grad_presented:
+                f = f + 3 * output.sum()
+            if states_grad_presented:
+                f = f + (hx_out * cx_out).sum()
 
             param_names, params = zip(*rnn.named_parameters())
             param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
@@ -9892,18 +9898,25 @@
             return output, param_grads, input_grad, hx_grad, cx_grad
 
         if backward:
-            cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
-                get_backward_results(rnn, "cpu", input, hx, cx)
-            mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
-                get_backward_results(rnn, device, input, hx, cx)
+            grad_cases = [
+                dict(output_grad_presented=True, states_grad_presented=True),
+                dict(output_grad_presented=False, states_grad_presented=True),
+                dict(output_grad_presented=True, states_grad_presented=False),
+            ]
 
-            self.assertEqual(cpu_hx_grad, mps_hx_grad)
-            self.assertEqual(cpu_cx_grad, mps_cx_grad)
-            self.assertEqual(cpu_output, mps_output)
-            self.assertEqual(cpu_input_grad, mps_input_grad)
-            for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
-                self.assertEqual(cpu_weight_grad, mps_weight_grad,
-                                 f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")
+            for grad_case in grad_cases:
+                cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
+                    get_backward_results(rnn, "cpu", input, hx, cx, **grad_case)
+                mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
+                    get_backward_results(rnn, device, input, hx, cx, **grad_case)
+
+                self.assertEqual(cpu_hx_grad, mps_hx_grad)
+                self.assertEqual(cpu_cx_grad, mps_cx_grad)
+                self.assertEqual(cpu_output, mps_output)
+                self.assertEqual(cpu_input_grad, mps_input_grad)
+                for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
+                    self.assertEqual(cpu_weight_grad, mps_weight_grad,
+                                     f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")
 
     LSTM_TEST_CASES = [
         dict(),  # default