[MPS] Fix batch_norm_backwards key (#98794)
One needs different graphs for batch_norm_backwards depending on whether or not gradients are required for some of the parameters.
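Below is a minimal Python sketch of the caching idea, assuming a simple dict-based graph cache; the actual fix lives in the Objective-C++ MPS backend, and the names `BackwardGraph`, `graph_cache`, and `batch_norm_backward_graph` are hypothetical, not the real backend API. The point is that the cache key must encode which of the gradients (input, weight, bias) are requested, so graphs built for different `requires_grad` combinations do not collide.

```python
# Hypothetical sketch of the cache-key idea, not the actual MPS backend code.
from dataclasses import dataclass
from typing import Dict, Tuple

@dataclass
class BackwardGraph:
    # Records which of (grad_input, grad_weight, grad_bias) this graph computes.
    grad_input_mask: Tuple[bool, bool, bool]

graph_cache: Dict[str, BackwardGraph] = {}

def batch_norm_backward_graph(shape, grad_input_mask):
    # The key must include the grad mask: a graph built when weight/bias grads
    # are required produces different outputs than one built when they are
    # frozen, so the two must not share a cache entry.
    key = f"batch_norm_backward:{shape}:{grad_input_mask}"
    graph = graph_cache.get(key)
    if graph is None:
        graph = BackwardGraph(grad_input_mask)
        graph_cache[key] = graph
    return graph

# Two BatchNorm layers with different requires_grad settings now map to
# distinct cached graphs instead of reusing the wrong one.
g1 = batch_norm_backward_graph((1, 8, 4, 4), (True, True, True))
g2 = batch_norm_backward_graph((1, 8, 4, 4), (True, False, False))
assert g1 is not g2
```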
Fixes https://github.com/pytorch/pytorch/issues/98602
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98794
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index b2cefa0..da92ece 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1950,6 +1950,16 @@
helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
track_running_stats=track_running_stats, test_module=test_module)
+ def test_batch_norm_backward(self):
+ inputs = torch.rand(1, 8, 4, 4, device='mps', requires_grad=True)
+ x = torch.nn.BatchNorm2d(8).to("mps")
+ y = torch.nn.BatchNorm2d(8).to("mps")
+ y.weight.requires_grad = False
+ y.bias.requires_grad = False
+ outputs = y(x(inputs))
+ # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
+ outputs.sum().backward()
+
def test_norm(self):
a = torch.arange(9, dtype=torch.float, device="mps") - 4
b = a.reshape((3, 3))