[MPS] Fix `layer_norm_backward_mps` key (#100295)
Follow-up to https://github.com/pytorch/pytorch/pull/98794
See the report in https://github.com/pytorch/pytorch/issues/98602#issuecomment-1527312211 and the reproducer in https://github.com/pytorch/pytorch/issues/98602#issuecomment-1528214175
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100295
Approved by: https://github.com/kit1980, https://github.com/izaitsevfb
diff --git a/test/test_mps.py b/test/test_mps.py
index 21fd749..14ef58f 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2012,7 +2012,7 @@
track_running_stats=track_running_stats, test_module=test_module)
def test_batch_norm_backward(self):
- inputs = torch.rand(1, 8, 4, 4, device='mps', requires_grad=True)
+ inputs = torch.rand(1, 8, 4, 4, device="mps", requires_grad=True)
x = torch.nn.BatchNorm2d(8).to("mps")
y = torch.nn.BatchNorm2d(8).to("mps")
y.weight.requires_grad = False
@@ -2021,6 +2021,16 @@
# This used to crash, see https://github.com/pytorch/pytorch/issues/98602
outputs.sum().backward()
+ def test_layer_norm_backward(self):
+ inputs = torch.rand(4, 4, device="mps", requires_grad=True)
+ x = torch.nn.LayerNorm(4).to("mps")
+ y = torch.nn.LayerNorm(4).to("mps")
+ y.weight.requires_grad = False  # freeze the second LayerNorm's affine weight
+ y.bias.requires_grad = False  # ...and its bias, so only x's params and inputs need grads
+ outputs = y(x(inputs))
+ # Backward through a LayerNorm with frozen weight/bias used to crash on MPS, see https://github.com/pytorch/pytorch/issues/98602
+ outputs.sum().backward()
+
def test_norm(self):
a = torch.arange(9, dtype=torch.float, device="mps") - 4
b = a.reshape((3, 3))