[MPS] Fix `layer_norm_backward_mps` key (#100295)

Follow-up to https://github.com/pytorch/pytorch/pull/98794
See the report in https://github.com/pytorch/pytorch/issues/98602#issuecomment-1527312211 and the reproducer in https://github.com/pytorch/pytorch/issues/98602#issuecomment-1528214175
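
For context, a minimal sketch of the crash pattern from the linked reproducer (the new `test_layer_norm_backward` below exercises the same path): two `LayerNorm` modules chained on the `mps` device, with the second module's parameters frozen, used to crash during the backward pass.

```python
import torch

# Sketch of the reproducer pattern from the linked issue comments;
# requires an MPS-capable machine.
inputs = torch.rand(4, 4, device="mps", requires_grad=True)
x = torch.nn.LayerNorm(4).to("mps")
y = torch.nn.LayerNorm(4).to("mps")
# Freezing the second module's weight/bias changes which gradients the
# layer_norm backward kernel has to produce.
y.weight.requires_grad = False
y.bias.requires_grad = False
outputs = y(x(inputs))
outputs.sum().backward()  # used to crash before this fix
```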

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100295
Approved by: https://github.com/kit1980, https://github.com/izaitsevfb
diff --git a/test/test_mps.py b/test/test_mps.py
index 21fd749..14ef58f 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2012,7 +2012,7 @@
                                track_running_stats=track_running_stats, test_module=test_module)
 
     def test_batch_norm_backward(self):
-        inputs = torch.rand(1, 8, 4, 4, device='mps', requires_grad=True)
+        inputs = torch.rand(1, 8, 4, 4, device="mps", requires_grad=True)
         x = torch.nn.BatchNorm2d(8).to("mps")
         y = torch.nn.BatchNorm2d(8).to("mps")
         y.weight.requires_grad = False
@@ -2021,6 +2021,16 @@
         # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
         outputs.sum().backward()
 
+    def test_layer_norm_backward(self):
+        inputs = torch.rand(4, 4, device="mps", requires_grad=True)
+        x = torch.nn.LayerNorm(4).to("mps")
+        y = torch.nn.LayerNorm(4).to("mps")
+        y.weight.requires_grad = False
+        y.bias.requires_grad = False
+        outputs = y(x(inputs))
+        # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
+        outputs.sum().backward()
+
     def test_norm(self):
         a = torch.arange(9, dtype=torch.float, device="mps") - 4
         b = a.reshape((3, 3))