[metal] Parameterize group_size in int4_mm test, fix int4mm shader for group_size > 128 (#129628)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129628
Approved by: https://github.com/kimishpatel
diff --git a/test/test_mps.py b/test/test_mps.py
index cfa3e21..69b484b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9153,10 +9153,11 @@
raise e
@parametrize("m", [1, 32, 64])
- @parametrize("k", [32, 64])
@parametrize("n", [48, 64])
- def test__int4_mm(self, m, k, n):
- q_group = 32
+ @parametrize("q_group", [32, 64, 128, 256])
+ @parametrize("num_groups", [1, 2])
+ def test__int4_mm(self, m, n, q_group, num_groups):
+ k = q_group * num_groups
inner_k_tiles = 2
torch.manual_seed(1)