[MPS] LSTM grad_y missing fix (#96601)
Fixes #96416
Added tests that do not use LSTM output, similarly to the issue.
It seems this fix once again introduces a backward incompatibility.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96601
Approved by: https://github.com/albanD, https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 351aa76..bf0d0b0 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9878,12 +9878,18 @@
self.assertEqual(cpu_hn, hn)
self.assertEqual(cpu_cn, cn)
- def get_backward_results(rnn, device, inp, hx, cx):
+ def get_backward_results(rnn, device, inp, hx, cx, output_grad_presented=True, states_grad_presented=True):
rnn = rnn.to(device)
inp, hx, cx = inp.to(device), hx.to(device), cx.to(device)
- output, _ = rnn(inp, (hx, cx))
- f = 3 * output.sum() + (hx * cx).sum()
+ output, (hx_out, cx_out) = rnn(inp, (hx, cx))
+ assert output_grad_presented or states_grad_presented, "At least some outputs must be used"
+
+ f = 0
+ if output_grad_presented:
+ f = f + 3 * output.sum()
+ if states_grad_presented:
+ f = f + (hx_out * cx_out).sum()
param_names, params = zip(*rnn.named_parameters())
param_grads = zip(param_names, torch.autograd.grad(f, params, retain_graph=True))
@@ -9892,18 +9898,25 @@
return output, param_grads, input_grad, hx_grad, cx_grad
if backward:
- cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
- get_backward_results(rnn, "cpu", input, hx, cx)
- mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
- get_backward_results(rnn, device, input, hx, cx)
+ grad_cases = [
+ dict(output_grad_presented=True, states_grad_presented=True),
+ dict(output_grad_presented=False, states_grad_presented=True),
+ dict(output_grad_presented=True, states_grad_presented=False),
+ ]
- self.assertEqual(cpu_hx_grad, mps_hx_grad)
- self.assertEqual(cpu_cx_grad, mps_cx_grad)
- self.assertEqual(cpu_output, mps_output)
- self.assertEqual(cpu_input_grad, mps_input_grad)
- for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
- self.assertEqual(cpu_weight_grad, mps_weight_grad,
- f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")
+ for grad_case in grad_cases:
+ cpu_output, cpu_weights_grad, cpu_input_grad, cpu_hx_grad, cpu_cx_grad =\
+ get_backward_results(rnn, "cpu", input, hx, cx, **grad_case)
+ mps_output, mps_weights_grad, mps_input_grad, mps_hx_grad, mps_cx_grad =\
+ get_backward_results(rnn, device, input, hx, cx, **grad_case)
+
+ self.assertEqual(cpu_hx_grad, mps_hx_grad)
+ self.assertEqual(cpu_cx_grad, mps_cx_grad)
+ self.assertEqual(cpu_output, mps_output)
+ self.assertEqual(cpu_input_grad, mps_input_grad)
+ for (cpu_name, cpu_weight_grad), (mps_name, mps_weight_grad) in zip(cpu_weights_grad, mps_weights_grad):
+ self.assertEqual(cpu_weight_grad, mps_weight_grad,
+ f"mismatch in cpu:{cpu_name} vs mps:{mps_name}, layers: {num_layers}")
LSTM_TEST_CASES = [
dict(), # default