Enable batch matmul for result sizes > 2**32 by splitting the tensor along the batch axis (#133430)

Fixes #131865. Addresses the assert seen when running the llama v3.1 8B parameter model on the MPS backend, where the batch matmul output size can exceed the 32-bit indexing limit of MPS tensors.

Test case to reproduce the issue with the dimensions encountered in llama v3.1 and verify this fix works around it:

```
import torch

device = 'mps'
a = torch.randn([32, 20064, 128], dtype=torch.float32, device=device)
b = torch.randn([32, 128, 20064], dtype=torch.float32, device=device)
res = torch.bmm(a, b)
```

Note that this change only works as long as each individual output matrix in the bmm stays under 2**32 elements; in that case the computation can be split along the batch axis to avoid going over the limit. A TORCH_CHECK was added to raise an error when the individual matrix dimensions are too large for this op, until a more general workaround that tiles the matmuls is available.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133430
Approved by: https://github.com/malfet
Co-authored-by: Nikita Shulga <[email protected]>
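The batch-axis splitting idea described above can be sketched in pure Python. This is a hypothetical illustration, not the actual MPS-backend code: the helper name `bmm_in_batches` and the chunking arithmetic are our own, chosen so that each partial `bmm` output stays under the 2**32-element limit, mirroring the TORCH_CHECK for the case where a single output matrix is already too large.

```python
import torch

MAX_ELEMS = 2**32  # 32-bit indexing limit assumed for MPS tensors

def bmm_in_batches(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Sketch: run torch.bmm over slices of the batch dimension so that
    each partial output tensor stays under MAX_ELEMS elements."""
    B, M, K = a.shape
    _, _, N = b.shape
    per_matrix = M * N
    # Analogous to the TORCH_CHECK in the PR: if one output matrix alone
    # exceeds the limit, batch splitting cannot help.
    assert per_matrix <= MAX_ELEMS, "individual output matrix too large for this workaround"
    chunk = max(1, MAX_ELEMS // per_matrix)  # how many batches fit per slice
    outs = [torch.bmm(a[i:i + chunk], b[i:i + chunk]) for i in range(0, B, chunk)]
    return torch.cat(outs, dim=0)

# Small CPU sanity check with modest sizes (the real trigger needs MPS and
# the large dimensions shown in the reproduction above):
a = torch.randn(4, 8, 3)
b = torch.randn(4, 3, 8)
assert torch.allclose(bmm_in_batches(a, b), torch.bmm(a, b), atol=1e-5)
```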
diff --git a/test/test_mps.py b/test/test_mps.py
index 281c68b..f7f36e5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1914,6 +1914,20 @@
         self.assertEqual(output_cpu, output_mps)
         self.assertEqual(output_cpu.size(), output_mps.size())
 
+    @xfailIf(product_version < 15.0)
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_large_bmm(self, dtype):
+        batch1 = torch.randn(11, 20064, 128, dtype=dtype, device='mps')
+        batch2 = torch.randn(11, 128, 20064, dtype=dtype, device='mps')
+        output_cpu = torch.bmm(batch1.cpu(), batch2.cpu())
+        output_mps = torch.bmm(batch1, batch2)
+
+        # Using the low precision comparison for FP16
+        tol = 1e-2 if dtype == torch.float16 else None
+        self.assertEqual(output_cpu, output_mps, atol=tol, rtol=tol)
+        self.assertEqual(output_cpu.size(), output_mps.size())
+
+
     def test_addr(self):
         A = torch.ones(5, 10).to("mps")
         B = torch.ones(5).to("mps")