[MPS] Gather sliced inputs to batch norm (#133610)
This PR removes the `executeGatherOp` flag from batch norm in favor of relying on the logic in https://github.com/pytorch/pytorch/blob/4aa66f68a803927ddd127ceaaa1521b8d6e90e5f/aten/src/ATen/native/mps/OperationUtils.mm#L372 to decide whether gathering is necessary.
This is not the most efficient way to solve the issue, but it ensures correctness for sliced inputs.
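For context, here is a minimal repro of the mismatch, mirroring the regression test added below; `torch.testing.assert_close` stands in for the test's `assertEqual`, so the tolerances are an assumption:
```
# Minimal repro of https://github.com/pytorch/pytorch/issues/133520:
# batch norm on a sliced MPS input used to diverge from the CPU result.
import torch
import torch.nn as nn

bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')

x_cpu = torch.randn(100, 100, 35, 45)
x_mps = x_cpu.to('mps')

# Fails without this PR, passes with it (default float32 tolerances).
torch.testing.assert_close(bn_cpu(x_cpu[5:]), bn_mps(x_mps[5:]).cpu())
```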
### Performance impact
#### With fix
```
python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
100 loops, best of 5: 282 usec per loop
python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
100 loops, best of 5: 448 usec per loop
python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 705 usec per loop
python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 1.11 msec per loop
python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 7.16 msec per loop
python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 11.7 msec per loop
```
#### Without fix
```
python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
100 loops, best of 5: 284 usec per loop
python -m timeit -n 100 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
100 loops, best of 5: 265 usec per loop
python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 715 usec per loop
python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(100, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 675 usec per loop
python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x)"
1000 loops, best of 5: 7.19 msec per loop
python -m timeit -n 1000 -s "import torch; import torch.nn as nn; bn = nn.BatchNorm2d(100, affine=False, device='mps');x = torch.randn(1000, 100, 35, 45).to('mps')" "bn(x[5:])"
1000 loops, best of 5: 7.13 msec per loop
```
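To summarize the numbers: contiguous inputs are unaffected (282 vs 284 usec), while sliced inputs pay roughly 1.6x for the extra gather (448 vs 265 usec; 11.7 vs 7.13 msec), which is the price of correctness here. Note that `x[5:]` is still contiguous; what the linked placeholder logic keys on, as I read it, is the view's nonzero storage offset, as this sketch illustrates:
```
import torch

x = torch.randn(100, 100, 35, 45, device='mps')
y = x[5:]
print(y.is_contiguous())   # True: slicing dim 0 leaves the strides unchanged
print(y.storage_offset())  # 787500 (= 5 * 100 * 35 * 45): the view starts mid-buffer
```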
Please feel free to push back or request changes.
Fixes #133520
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133610
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 86b181e..f4aa051 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2541,6 +2541,19 @@
# This used to crash, see https://github.com/pytorch/pytorch/issues/98602
outputs.sum().backward()
+ # Regression test for https://github.com/pytorch/pytorch/issues/133520
+ def test_batch_norm_slices(self):
+ bn_cpu = nn.BatchNorm2d(100, affine=False, device='cpu')
+ bn_mps = nn.BatchNorm2d(100, affine=False, device='mps')
+
+ x_cpu = torch.randn(100, 100, 35, 45).to('cpu')
+ x_mps = x_cpu.to('mps')
+
+ res_cpu = bn_cpu(x_cpu[5:])
+ res_mps = bn_mps(x_mps[5:])
+
+ self.assertEqual(res_cpu, res_mps)
+
def test_layer_norm_backward(self):
inputs = torch.rand(4, 4, device="mps", requires_grad=True)
x = torch.nn.LayerNorm(4).to("mps")