[MPS] Fix batch_norm_backwards key (#98794)

One needs different graphs for batch_norm_backwards depending on whether or not
gradients are required for some of the parameters, so the key used to look up the
cached graph must reflect which gradients are actually requested.
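
To illustrate the keying idea only (the real fix lives in the MPS backend, not in
Python), here is a minimal sketch assuming a hypothetical `backward_graph_key`
helper: if the cache key ignored the per-parameter grad mask, a frozen and an
unfrozen BatchNorm layer would wrongly share one cached backward graph.

```python
import torch

# Hypothetical illustration, not the actual MPS implementation:
# include the (input, weight, bias) grad mask in the cache key so
# layers with different requires_grad settings get distinct graphs.
def backward_graph_key(op_name, input_shape, grad_input_mask):
    # grad_input_mask mirrors the mask batch_norm_backward receives;
    # each entry says whether that gradient must be computed.
    mask_tag = "".join("1" if m else "0" for m in grad_input_mask)
    return f"{op_name}:{tuple(input_shape)}:grad_mask={mask_tag}"

bn_train = torch.nn.BatchNorm2d(8)       # weight and bias need grads
bn_frozen = torch.nn.BatchNorm2d(8)
bn_frozen.weight.requires_grad = False   # frozen affine parameters
bn_frozen.bias.requires_grad = False

shape = (1, 8, 4, 4)
key_train = backward_graph_key("batch_norm_backward", shape, (True, True, True))
key_frozen = backward_graph_key("batch_norm_backward", shape, (True, False, False))
assert key_train != key_frozen  # distinct keys -> distinct cached graphs
```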

Fixes https://github.com/pytorch/pytorch/issues/98602

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98794
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index b2cefa0..da92ece 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1950,6 +1950,16 @@
                         helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
                                track_running_stats=track_running_stats, test_module=test_module)
 
+    def test_batch_norm_backward(self):
+        inputs = torch.rand(1, 8, 4, 4, device='mps', requires_grad=True)
+        x = torch.nn.BatchNorm2d(8).to("mps")
+        y = torch.nn.BatchNorm2d(8).to("mps")
+        y.weight.requires_grad = False
+        y.bias.requires_grad = False
+        outputs = y(x(inputs))
+        # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
+        outputs.sum().backward()
+
     def test_norm(self):
         a = torch.arange(9, dtype=torch.float, device="mps") - 4
         b = a.reshape((3, 3))