Enable batch matmul for result sizes > 2**32 when the tensor can be split along the batch axis (#133430)

Fixes #131865. Addresses the issue seen when running the llama v3.1 8B parameter model on the MPS backend, where the batch matmul output size can exceed the 32-bit indexing limit of MPS tensors and trigger an assert.

Test case that reproduces the issue with the dimensions encountered in llama v3.1 and verifies that this fix works around it:

```
import torch

device = 'mps'
a = torch.randn([32, 20064, 128], dtype=torch.float32, device=device)
b = torch.randn([32, 128, 20064], dtype=torch.float32, device=device)
# The [32, 20064, 20064] result has 32 * 20064 * 20064 = 12,882,051,072 elements (> 2**32)
res = torch.bmm(a, b)
```
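
As a rough check of the arithmetic behind the reproducer, the sketch below counts output elements in plain Python. Only the output shape matters (the shared inner dimension of 128 does not), and it assumes the limit is 2**32 elements as stated in the title. `output_elements` is a hypothetical helper for illustration only.

```python
from typing import Tuple

LIMIT = 2**32  # 32-bit indexing limit on the number of elements


def output_elements(batch: int, m: int, n: int) -> Tuple[int, int]:
    """Return (elements per output matrix, total elements) for [batch, m, k] @ [batch, k, n]."""
    per_matrix = m * n
    return per_matrix, batch * per_matrix


# Reproducer shapes: [32, 20064, 128] @ [32, 128, 20064]
per_matrix, total = output_elements(32, 20064, 20064)
print(f"per matrix: {per_matrix:,}")  # per matrix: 402,564,096
print(f"total:      {total:,}")       # total:      12,882,051,072
print(total > LIMIT, per_matrix <= LIMIT)  # True True

# Largest number of whole output matrices that fit under the limit at once
print(LIMIT // per_matrix)  # 10
```

The same numbers are consistent with `test_large_bmm` below using a batch of 11: ten 20064 x 20064 output matrices still fit under the limit, eleven do not.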

Note that this change only works as long as each individual output matrix of the bmm does not exceed 2**32 elements. That is what makes it possible to split the computation along the batch axis so that no single call goes over the limit.
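
A minimal sketch of the batch-axis split, written in plain Python on nested lists so it runs anywhere; `chunked_bmm` and `batches_per_chunk` are hypothetical names for illustration, and this is not the MPS implementation. Because every batch entry is an independent matmul, computing contiguous slices of the batch separately and concatenating the results gives the same output as one large call, while each call only produces a bounded number of output matrices.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Naive product of an m x k matrix and a k x n matrix."""
    k, n = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def bmm(a_batch: List[Matrix], b_batch: List[Matrix]) -> List[Matrix]:
    """Reference batched matmul: one independent matmul per batch entry."""
    return [matmul(a, b) for a, b in zip(a_batch, b_batch)]


def chunked_bmm(a_batch: List[Matrix], b_batch: List[Matrix],
                batches_per_chunk: int) -> List[Matrix]:
    """Run bmm on contiguous batch slices and concatenate the partial results."""
    if batches_per_chunk < 1:
        raise ValueError("batches_per_chunk must be at least 1")
    out: List[Matrix] = []
    for start in range(0, len(a_batch), batches_per_chunk):
        stop = start + batches_per_chunk
        # Each call materializes at most batches_per_chunk output matrices
        out.extend(bmm(a_batch[start:stop], b_batch[start:stop]))
    return out


# Small example: a batch of 5 (2 x 3) @ (3 x 4) products, split into slices of 2
a = [[[float(i + r + c) for c in range(3)] for r in range(2)] for i in range(5)]
b = [[[float(i * r - c) for c in range(4)] for r in range(3)] for i in range(5)]
assert chunked_bmm(a, b, batches_per_chunk=2) == bmm(a, b)
```

On real tensors the same idea would slice the batch dimension and concatenate along it; the slice size is what keeps each individual call under the element limit.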

Added a TORCH_CHECK that raises an error when an individual output matrix is too large for this op to handle, until a more general workaround that tiles the matmuls is available.
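
The check described above can be sketched at the level of shapes with a hypothetical helper, `plan_batch_slices`, which is not part of PyTorch and does not reproduce the actual TORCH_CHECK condition or message. It assumes the limit applies to the element count of each output slice and that a slice may hold exactly `limit` elements. It returns the batch ranges to run, and rejects the case batch splitting cannot fix: a single output matrix that is already over the limit.

```python
from typing import List, Sequence, Tuple

LIMIT = 2**32  # assumed element limit for a single output, per the description above


def plan_batch_slices(a_shape: Sequence[int], b_shape: Sequence[int],
                      limit: int = LIMIT) -> List[Tuple[int, int]]:
    """Return [start, stop) batch ranges whose outputs each hold at most `limit` elements.

    Raises ValueError if one output matrix alone exceeds `limit`, the case that
    splitting along the batch axis cannot handle.
    """
    batch, m, k = a_shape
    batch_b, k_b, n = b_shape
    if batch != batch_b or k != k_b:
        raise ValueError(f"incompatible bmm shapes {tuple(a_shape)} and {tuple(b_shape)}")
    per_matrix = m * n
    if per_matrix > limit:
        raise ValueError(f"a single {m} x {n} output matrix exceeds {limit} elements; "
                         "splitting along the batch axis is not enough")
    # Whole output matrices per slice (empty matrices: one slice covering the batch)
    step = limit // per_matrix if per_matrix else max(batch, 1)
    return [(start, min(start + step, batch)) for start in range(0, batch, step)]


print(plan_batch_slices((32, 20064, 128), (32, 128, 20064)))
# [(0, 10), (10, 20), (20, 30), (30, 32)]

try:
    plan_batch_slices((1, 70000, 128), (1, 128, 70000))  # 4.9e9 elements in one matrix
except ValueError as err:
    print("rejected:", err)
```

Handling the rejected case would require tiling inside each matmul rather than across the batch.
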
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133430
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <[email protected]>
diff --git a/test/test_mps.py b/test/test_mps.py
index 281c68b..f7f36e5 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -1914,6 +1914,20 @@
         self.assertEqual(output_cpu, output_mps)
         self.assertEqual(output_cpu.size(), output_mps.size())
 
+    @xfailIf(product_version < 15.0)
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_large_bmm(self, dtype):
+        batch1 = torch.randn(11, 20064, 128, dtype=dtype, device='mps')
+        batch2 = torch.randn(11, 128, 20064, dtype=dtype, device='mps')
+        output_cpu = torch.bmm(batch1.cpu(), batch2.cpu())
+        output_mps = torch.bmm(batch1, batch2)
+
+        # Using the low precision comparison for FP16
+        tol = 1e-2 if dtype == torch.float16 else None
+        self.assertEqual(output_cpu, output_mps, atol=tol, rtol=tol)
+        self.assertEqual(output_cpu.size(), output_mps.size())
+
+
     def test_addr(self):
         A = torch.ones(5, 10).to("mps")
         B = torch.ones(5).to("mps")