[metal] Parameterize group_size in int4_mm test, fix int4mm shader for group_size > 128 (#129628)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129628
Approved by: https://github.com/kimishpatel
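Why the old kernel breaks for group_size > 128: it computed a single scales_jump = N * 2 * (k_jump / group_size) before the K loop and bumped the scale/zero offsets by that fixed amount each iteration. The division is integer, so the jump truncates to zero whenever group_size exceeds k_jump, and the kernel keeps reading group 0's scale and zero point for the entire K dimension. The fix recomputes the group index from k inside the loop. Below is a minimal Python model of the offset arithmetic (not the shader itself; k_jump = 128 is an assumption inferred from the "group_size > 128" threshold in the title, and k is started at 0 for simplicity):

    def old_scale_offsets(K, N, n, group_size, k_jump=128):
        # Offset computed once before the loop, then bumped by a fixed jump.
        offset = ((0 // group_size) * N + n) * 2
        scales_jump = N * 2 * (k_jump // group_size)  # truncates to 0 when group_size > k_jump
        offsets = []
        for _ in range(0, K, k_jump):
            offsets.append(offset)
            offset += scales_jump
        return offsets

    def new_scale_offsets(K, N, n, group_size, k_jump=128):
        # Group index recomputed from k on every iteration, as the fixed shader does.
        return [((k // group_size) * N + n) * 2 for k in range(0, K, k_jump)]

    # With K=512, N=64, n=0, group_size=256:
    #   old_scale_offsets(512, 64, 0, 256) == [0, 0, 0, 0]      (stuck on group 0)
    #   new_scale_offsets(512, 64, 0, 256) == [0, 0, 128, 128]  (moves to group 1 at k=256)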
diff --git a/aten/src/ATen/native/mps/operations/Quantized.mm b/aten/src/ATen/native/mps/operations/Quantized.mm
index 326239e..83371f9 100644
--- a/aten/src/ATen/native/mps/operations/Quantized.mm
+++ b/aten/src/ATen/native/mps/operations/Quantized.mm
@@ -165,22 +165,16 @@
// affecting performance. This is the trick applied in MLX kernels.
float4 act_div_scales = {1.f, 1 / 16.f, 1 / 256.f, 1 / 4096.f};
- // Find specific group to which channels handled by this thread
- // belong.
- uint k_block_index = k / group_size;
- // Since scales_and_zeros are packed as [num_groups, N, 2].
- // Finding a specific's group's scales and zero points requires jump by factor
- // of N*2
- uint scales_group_offset = (k_block_index * N + n) * 2;
- uint zeros_gruop_offset = scales_group_offset + 1;
- const uint scales_jump =
- N * 2 *
- (k_jump /
- group_size); /* the last term accounts for identifying the group this
- thread will have to process in each iteration. This mean
- each iteration it must jump to a different group. Thus
- k_jump must be > group_size */
for (; k < K; k += k_jump) {
+ // Find the specific group to which the channels handled by this
+ // thread belong.
+ uint k_block_index = k / group_size;
+ // Since scales_and_zeros is packed as [num_groups, N, 2], finding a
+ // specific group's scales and zero points requires a jump by a factor
+ // of N*2.
+ uint scales_group_offset = (k_block_index * N + n) * 2;
+ uint zeros_gruop_offset = scales_group_offset + 1;
+
const T scale0 = scales_and_zeros[scales_group_offset];
// Adding zero point results in 10% perf penalty.
const T zero0 = scales_and_zeros[zeros_gruop_offset] - scale0 * T(8);
@@ -194,9 +188,6 @@
const T scale3 = scales_and_zeros[scales_group_offset + 6];
const T zero3 = scales_and_zeros[zeros_gruop_offset + 6] - scale3 * T(8);
- scales_group_offset += scales_jump;
- zeros_gruop_offset += scales_jump;
-
const float4 zeros = float4(zero0, zero1, zero2, zero3);
float4 a_val = float4(A_ptr[k / 4]);
@@ -240,9 +231,9 @@
int4pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]], \
constant DTYPE * scales_and_zeros [[buffer(2)]], \
- device DTYPE * output_data [[buffer(3)]], \
+ device DTYPE * output_data [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
- uint3 thread_index [[thread_position_in_grid]], \
+ uint3 thread_index [[thread_position_in_grid]], \
uint tid_in_simdgroup [[thread_index_in_simdgroup]])
INSTANTIATE_INT4MV(float, 32);
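For context, the scales_and_zeros tensor the kernel indexes is laid out as [num_groups, N, 2]: index 0 holds the scale and index 1 the zero point for each (group, output column) pair, which is why the zero offset is the scale offset plus one. The "- scale0 * T(8)" above folds the usual -8 offset of unsigned 4-bit codes into the zero point so the inner loop stays a plain multiply-add. A rough Python reference for the dequantization this implies (a sketch assuming q holds already-unpacked codes in 0..15, not the kernel's actual unpacking):

    import torch

    def dequantize_int4(q, scales_and_zeros, group_size):
        # q: [K, N] tensor of 4-bit codes (0..15), already unpacked.
        # scales_and_zeros: [K // group_size, N, 2] with (scale, zero) per group and column.
        scales = scales_and_zeros[:, :, 0].repeat_interleave(group_size, dim=0)  # -> [K, N]
        zeros = scales_and_zeros[:, :, 1].repeat_interleave(group_size, dim=0)   # -> [K, N]
        # Same math as the shader: q * scale + (zero - 8 * scale) == (q - 8) * scale + zero
        return (q.to(scales.dtype) - 8) * scales + zeros

With that model, A @ dequantize_int4(q, scales_and_zeros, group_size) is what the fused kernel is expected to reproduce, with each run of group_size consecutive K values sharing one scale/zero per output column.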
diff --git a/test/test_mps.py b/test/test_mps.py
index cfa3e21..69b484b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9153,10 +9153,11 @@
raise e
@parametrize("m", [1, 32, 64])
- @parametrize("k", [32, 64])
@parametrize("n", [48, 64])
- def test__int4_mm(self, m, k, n):
- q_group = 32
+ @parametrize("q_group", [32, 64, 128, 256])
+ @parametrize("num_groups", [1, 2])
+ def test__int4_mm(self, m, n, q_group, num_groups):
+ k = q_group * num_groups
inner_k_tiles = 2
torch.manual_seed(1)
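The reworked parametrization derives k from the group size, so k is always a multiple of q_group and the sweep covers one and two groups per column for q_group in {32, 64, 128, 256}. The q_group = 256, num_groups = 2 combination (k = 512) is where the old offset bug would actually surface, since with a single group, group 0 is the correct group anyway. An illustrative reference check in the spirit of the test, reusing the dequantize_int4 sketch above rather than the int4 packing path the real test exercises (shapes here are placeholders, not the test's):

    m, n, q_group, num_groups = 32, 48, 256, 2
    k = q_group * num_groups
    a = torch.rand(m, k)
    q = torch.randint(0, 16, (k, n), dtype=torch.uint8)
    scales_and_zeros = torch.rand(k // q_group, n, 2)
    expected = a @ dequantize_int4(q, scales_and_zeros, q_group)
    # An int4 matmul kernel fed the packed form of q and these scales/zeros
    # should match `expected` to within the test's tolerances.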