[metal] Parameterize group_size in int4_mm test, fix int4mm shader for group_size > 128 (#129628)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129628
Approved by: https://github.com/kimishpatel
diff --git a/test/test_mps.py b/test/test_mps.py
index cfa3e21..69b484b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9153,10 +9153,11 @@
                     raise e
 
     @parametrize("m", [1, 32, 64])
-    @parametrize("k", [32, 64])
     @parametrize("n", [48, 64])
-    def test__int4_mm(self, m, k, n):
-        q_group = 32
+    @parametrize("q_group", [32, 64, 128, 256])
+    @parametrize("num_groups", [1, 2])
+    def test__int4_mm(self, m, n, q_group, num_groups):
+        k = q_group * num_groups
         inner_k_tiles = 2
 
         torch.manual_seed(1)