[MPS] Gather sliced inputs to batch norm (#133610)

This PR removes the `executeGatherOp` flag from batch norm and instead relies on the logic in https://github.com/pytorch/pytorch/blob/4aa66f68a803927ddd127ceaaa1521b8d6e90e5f/aten/src/ATen/native/mps/OperationUtils.mm#L372 to decide whether gathering is necessary.

This is not the most efficient way to solve the issue, but it ensures correctness for sliced inputs.
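To illustrate why gathering matters for a slice such as `x[5:]`, here is a toy model in plain Python. It is only a sketch: a flat list stands in for the tensor storage, `channel_means` is a hypothetical stand-in for the batch norm statistics pass, and the failure mode shown (reading from the start of the storage instead of the view's offset) is an assumption about how the bug manifests, not a description of the actual MPS kernel.

```python
# Toy model only: a flat Python list stands in for the tensor's storage.
def channel_means(buf, start, n, c, hw):
    """Per-channel mean over an (n, c, hw) block that begins at buf[start]."""
    means = []
    for ch in range(c):
        total = 0.0
        for i in range(n):
            base = start + (i * c + ch) * hw
            total += sum(buf[base:base + hw])
        means.append(total / (n * hw))
    return means


N, C, HW = 4, 2, 3
storage = [float(v) for v in range(N * C * HW)]  # contiguous (N, C, HW) data

skip = 1                 # models x[1:], a view into the same storage
offset = skip * C * HW   # storage offset of that view
n_view = N - skip

# Correct result: honor the view's storage offset.
expected = channel_means(storage, offset, n_view, C, HW)

# Assumed failure mode: the kernel starts reading at the beginning of storage.
ignoring_offset = channel_means(storage, 0, n_view, C, HW)

# Gathering: copy the view's elements into a fresh buffer that starts at 0.
gathered = storage[offset:offset + n_view * C * HW]
after_gather = channel_means(gathered, 0, n_view, C, HW)

assert after_gather == expected
assert ignoring_offset != expected
```

The gathered copy is what makes the result correct in this model, and the extra copy is also a plausible reason for the additional time on sliced inputs.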

### Performance impact

#### With fix

```
python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
100 loops, best of 5: 282 usec per loop

python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
100 loops, best of 5: 448 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 705 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 1.11 msec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 7.16 msec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 11.7 msec per loop
```

#### Without fix

```
python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
100 loops, best of 5: 284 usec per loop

python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
100 loops, best of 5: 265 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 715 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 675 usec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 7.19 msec per loop

python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 7.13 msec per loop
```
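As a rough summary of the two sets of timings above, the ratios can be computed directly from the reported numbers. The snippet below only does arithmetic on the values copied from the runs (converted to microseconds); the dictionary keys are labels made up for this sketch. Sliced inputs come out roughly 1.6x to 1.7x slower with the fix, while contiguous inputs are within about 1.5% of the previous timings.

```python
# Timings copied from the runs above, in microseconds: (with fix, without fix).
contiguous = {
    "-n 100, batch 100": (282.0, 284.0),
    "-n 1000, batch 100": (705.0, 715.0),
    "-n 1000, batch 1000": (7160.0, 7190.0),
}
sliced = {
    "-n 100, batch 100": (448.0, 265.0),
    "-n 1000, batch 100": (1110.0, 675.0),
    "-n 1000, batch 1000": (11700.0, 7130.0),
}


def ratios(timings):
    """Return with-fix / without-fix time for each labeled run."""
    return {label: with_fix / without_fix
            for label, (with_fix, without_fix) in timings.items()}


contiguous_ratios = ratios(contiguous)
sliced_ratios = ratios(sliced)

for label in sliced:
    print(f"{label}: contiguous x{contiguous_ratios[label]:.2f}, "
          f"sliced x{sliced_ratios[label]:.2f}")
```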

Please feel free to push back or request changes.

Fixes #133520
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133610
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 86b181e..f4aa051 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2541,6 +2541,19 @@
         # This used to crash, see https://github.com/pytorch/pytorch/issues/98602
         outputs.sum().backward()
 
+    # Regression test for https://github.com/pytorch/pytorch/issues/133520
+    def test_batch_norm_slices(self):
+        bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
+        bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')
+
+        x_cpu = torch.randn(100, 100, 35, 45).to('cpu')
+        x_mps = x_cpu.to('mps')
+
+        res_cpu = bn_cpu(x_cpu[5:])
+        res_mps = bn_mps(x_mps[5:])
+
+        self.assertEqual(res_cpu, res_mps)
+
     def test_layer_norm_backward(self):
         inputs = torch.rand(4, 4, device="mps", requires_grad=True)
         x = torch.nn.LayerNorm(4).to("mps")