[metal] Parameterize group_size in int4_mm test, fix int4mm shader for group_size > 128 (#129628)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129628
Approved by: https://github.com/kimishpatel
diff --git a/aten/src/ATen/native/mps/operations/Quantized.mm b/aten/src/ATen/native/mps/operations/Quantized.mm
index 326239e..83371f9 100644
--- a/aten/src/ATen/native/mps/operations/Quantized.mm
+++ b/aten/src/ATen/native/mps/operations/Quantized.mm
@@ -165,22 +165,16 @@
   // affecting performance. This is the trick applied in MLX kernels.
   float4 act_div_scales = {1.f, 1 / 16.f, 1 / 256.f, 1 / 4096.f};
 
-  // Find specific group to which channels handled by this thread
-  // belong.
-  uint k_block_index = k / group_size;
-  // Since scales_and_zeros are packed as [num_groups, N, 2].
-  // Finding a specific's group's scales and zero points requires jump by factor
-  // of N*2
-  uint scales_group_offset = (k_block_index * N + n) * 2;
-  uint zeros_gruop_offset = scales_group_offset + 1;
-  const uint scales_jump =
-      N * 2 *
-      (k_jump /
-       group_size); /* the last term accounts for identifying the group this
-                      thread will have to process in each iteration. This mean
-                      each iteration it must jump to a different group. Thus
-                      k_jump must be > group_size */
   for (; k < K; k += k_jump) {
+    // Find specific group to which channels handled by this thread
+    // belong.
+    uint k_block_index = k / group_size;
+    // Since scales_and_zeros are packed as [num_groups, N, 2].
+    // Finding a specific group's scales and zero points requires a jump by a factor
+    // of N*2
+    uint scales_group_offset = (k_block_index * N + n) * 2;
+    uint zeros_gruop_offset = scales_group_offset + 1;
+
     const T scale0 = scales_and_zeros[scales_group_offset];
     // Adding zero point results in 10% perf penalty.
     const T zero0 = scales_and_zeros[zeros_gruop_offset] - scale0 * T(8);
@@ -194,9 +188,6 @@
     const T scale3 = scales_and_zeros[scales_group_offset + 6];
     const T zero3 = scales_and_zeros[zeros_gruop_offset + 6] - scale3 * T(8);
 
-    scales_group_offset += scales_jump;
-    zeros_gruop_offset += scales_jump;
-
     const float4 zeros = float4(zero0, zero1, zero2, zero3);
 
     float4 a_val = float4(A_ptr[k / 4]);
@@ -240,9 +231,9 @@
   int4pack_mm<DTYPE, GSIZE>(                                                   \
       constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]],      \
       constant DTYPE * scales_and_zeros [[buffer(2)]],                         \
-      device DTYPE * output_data [[buffer(3)]],                                 \
+      device DTYPE * output_data [[buffer(3)]],                                \
       constant uint3 & sizes [[buffer(4)]],                                    \
-      uint3 thread_index [[thread_position_in_grid]],                           \
+      uint3 thread_index [[thread_position_in_grid]],                          \
       uint tid_in_simdgroup [[thread_index_in_simdgroup]])
 
 INSTANTIATE_INT4MV(float, 32);
diff --git a/test/test_mps.py b/test/test_mps.py
index cfa3e21..69b484b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9153,10 +9153,11 @@
                     raise e
 
     @parametrize("m", [1, 32, 64])
-    @parametrize("k", [32, 64])
     @parametrize("n", [48, 64])
-    def test__int4_mm(self, m, k, n):
-        q_group = 32
+    @parametrize("q_group", [32, 64, 128, 256])
+    @parametrize("num_groups", [1, 2])
+    def test__int4_mm(self, m, n, q_group, num_groups):
+        k = q_group * num_groups
         inner_k_tiles = 2
 
         torch.manual_seed(1)