[metal] Add int4mm weight packing MPS kernel and improved int4mm shader (#128965)

Adds a _convert_weight_to_int4pack MPS kernel.
Replaces the previous int4mm Metal shader with one authored by @kimishpatel, which improves perf by ~40%.
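
For reference, a minimal usage sketch of the new MPS path (shapes, group size, and tensor contents below are illustrative, not taken from this PR; producing real quantized weights and scales is assumed to happen elsewhere):

    import torch

    K, N, group_size, inner_k_tiles = 128, 32, 32, 2
    a = torch.rand(1, K, dtype=torch.float32, device="mps")
    # int32 tensor of unsigned 4-bit values in [0, 15], one per weight element
    b_int32 = torch.randint(0, 16, (N, K), dtype=torch.int32, device="mps")
    # scales and zero points packed as [K // group_size, N, 2]
    scales_and_zeros = torch.ones(K // group_size, N, 2, dtype=torch.float32, device="mps")

    # Packing now runs directly on MPS instead of round-tripping through the CPU
    b_int4pack = torch._convert_weight_to_int4pack(b_int32, inner_k_tiles)
    c = torch._weight_int4pack_mm(a, b_int4pack, group_size, scales_and_zeros)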

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128965
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/Quantized.mm b/aten/src/ATen/native/mps/operations/Quantized.mm
index 860d53b..326239e 100644
--- a/aten/src/ATen/native/mps/operations/Quantized.mm
+++ b/aten/src/ATen/native/mps/operations/Quantized.mm
@@ -4,6 +4,7 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_convert_weight_to_int4pack_native.h>
 #include <ATen/ops/_weight_int4pack_mm_native.h>
 #include <ATen/ops/_weight_int8pack_mm_native.h>
 #include <ATen/ops/empty.h>
@@ -54,96 +55,209 @@
 };
 #endif
 
-template<typename T, unsigned groupSize>
-kernel void int4pack_mm(
-    constant T                 * A              [[buffer(0)]],
-    constant uchar             * B              [[buffer(1)]],
-    constant T                 * scalesAndZeros [[buffer(2)]],
-    device   T                 * outputData     [[buffer(3)]],
-    constant uint3             & sizes          [[buffer(4)]], // M, K, N
-    uint3 group_index [[threadgroup_position_in_grid]],
-    uint3 threadgroup_index [[thread_position_in_threadgroup]]) {
-
-    const uint K = sizes.y;
-    const uint N = sizes.z;
-    const uint nb = group_index.x; // 0..N/32-1
-    const uint n2 = 16 * nb + threadgroup_index.x; // 0..N/2-1
-    const uint m = group_index.z;
-    const uint ldb = min(32U,  N - nb * 32);
-    const uint32_t k_block = (K + groupSize - 1) / groupSize;
-
-    using vec2T = typename Vec2Type<T>::type;
-    using vec4T = typename Vec4Type<T>::type;
-
-    constant vec4T *A_ptr = reinterpret_cast<constant vec4T *>(A + m * K);
-    constant uchar *B_ptr = B + (nb * 16 * K);
-
-    float2 rc = 0.0;
-    uint k = threadgroup_index.y * 4;
-    for (uint32_t kb = 0; kb < k_block ; kb ++) {
-      float2 scales, zeros;
-      for (int i = 0; i < 2; ++i) {
-        scales[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 0];
-        zeros[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 1] - scales[i] * T(8);
-      }
-
-      for(uint idx = k % groupSize; idx < groupSize && k < K; idx += 16, k += 16) {
-        threadgroup_barrier(mem_flags::mem_none);
-
-        const auto a_vec = float4(A_ptr[k/4]);
-        uchar4 b_byte;
-        for (int i = 0; i < 4; i++) {
-          b_byte[i] = B_ptr[((k + i) * ldb + (2*n2 % 32))/2];
-        }
-
-        float4x2 b_mat;
-
-        for (int i = 0; i < 4; i++) {
-          b_mat[i] = scales * float2(
-            float(b_byte[i] & 0x0f),
-            float(b_byte[i] >> 4)) + zeros;
-        }
-
-        rc += b_mat * a_vec;
-      }
-    }
-
-    threadgroup float2 tgp_memory[16][4];
-    tgp_memory[threadgroup_index.x][threadgroup_index.y] = rc;
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (threadgroup_index.y == 0) {
-      for (unsigned i = 1; i < 4; i++) {
-        rc += tgp_memory[threadgroup_index.x][i];
-      }
-      reinterpret_cast<device vec2T*>(outputData + m * N)[n2] = vec2T(rc);
-    }
+kernel void weight_to_int4pack(constant int *W [[buffer(0)]],
+                               device uchar *outputData [[buffer(1)]],
+                               constant uint2 &sizes [[buffer(2)]],
+                               uint2 thread_index [[thread_position_in_grid]]) {
+  const uint N = sizes.x;
+  const uint K = sizes.y;
+  const uint n = thread_index.x; // 0..N-1
+  const uint k2 = thread_index.y; // 0..K/2-1
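+  // Pack two adjacent 4-bit values along K into one byte: the even k element
+  // goes into the low nibble, the odd k element into the high nibble.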
+  int32_t src_val0 = W[n * K + 2 * k2];
+  int32_t src_val1 = W[n * K + 2 * k2 + 1];
+  outputData[n * (K / 2) + k2] = (uint8_t(src_val1) << 4) | uint8_t(src_val0);
 }
 
-#define INSTANTIATE_INT4MM(DTYPE, GSIZE)                                 \
-template                                                                 \
-[[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]]                          \
-kernel void int4pack_mm<DTYPE, GSIZE>(                                   \
-    constant DTYPE             * A              [[buffer(0)]],           \
-    constant uchar             * B              [[buffer(1)]],           \
-    constant DTYPE             * scalesAndZeros [[buffer(2)]],           \
-    device   DTYPE             * outputData     [[buffer(3)]],           \
-    constant uint3             & sizes          [[buffer(4)]],           \
-    uint3 group_index [[threadgroup_position_in_grid]], \
-    uint3 threadgroup_index [[thread_position_in_threadgroup]])
+/*
+   This code takes heavy inspiration from the MLX qvm kernel here:
+   https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/quantized.metal#L381
+   Specifically:
+     - Multiplying the activation by the inverse scaling factor to reduce
+       compute boundedness.
+     - Handling the zero point by accumulating the activation in a separate sum
+       term, which is needed with the optimization above.
+   MLX MIT License: https://github.com/ml-explore/mlx/blob/main/LICENSE
+*/
 
-INSTANTIATE_INT4MM(float, 32);
-INSTANTIATE_INT4MM(half, 32);
-INSTANTIATE_INT4MM(float, 64);
-INSTANTIATE_INT4MM(half, 64);
-INSTANTIATE_INT4MM(float, 128);
-INSTANTIATE_INT4MM(half, 128);
-INSTANTIATE_INT4MM(float, 256);
-INSTANTIATE_INT4MM(half, 256);
+/*
+   The A matrix is [M x K] (right now this kernel does not support M > 1, but
+   that is an easy fix that will follow right after). The B matrix is [N x K].
+   For 4-bit quantization, two k values are packed into one byte, so from a
+   layout perspective you can think of B as an [N x K/2] matrix.
+
+   Since this kernel is optimized for the gemv case, we split the work along
+   the reduction dim K among the threads of the same simdgroup. For example, if
+   K = 4096 and the simdgroup size is 32 (the current algorithm should work as
+   long as the simdgroup size is >= 32), then each thread accumulates
+   4096/32 = 128 k values. However, the 128 values handled by each thread are
+   not laid out contiguously: each thread handles 4 contiguous k values and
+   then jumps ahead by k_jump = threads_per_channel (32) * ks_per_thread (4)
+   elements. Take a simpler example where the simdgroup is of size 4, so
+   threads_per_channel = 4, and assume K = 32:
+      k                thread
+   [0, 1, 2, 3,          0
+    4, 5, 6, 7,          1
+    8, 9, 10, 11,        2
+    12, 13, 14, 15,      3
+    16, 17, 18, 19,      0
+    20, 21, 22, 23,      1
+    24, 25, 26, 27,      2
+    28, 29, 30, 31]      3
+   The right column is the id of the thread in the simdgroup that handles the
+   corresponding ks. Thread 0 handles (0, 1, 2, 3) and then (16, 17, 18, 19),
+   which are k_jump = 4 * 4 = 16 apart. This is done to improve memory access
+   locality among threads that are working cooperatively. Once each thread has
+   accumulated its partial sums, we use a tree reduction (Metal offers
+   simd_sum, but it is not used here so that we can support simdgroup
+   size = 64). In the example above we end up with 4 partial sums.
+
+   Each thread also handles 4 different output channels (rows of B). Thus each
+   simdgroup is responsible for a (1x4) tile of the output. We haven't
+   evaluated whether a different tile size is better or not; we will probably
+   do some auto-tuning once the initial work is done.
+*/
+
+/*
+   @brief This shader implements 4-bit matrix-vector multiplication where the A
+   matrix is fp16, bfloat or float and the B matrix is a 4-bit
+   groupwise-quantized weight matrix.
+   @param [in] A is the activation matrix of size M x K.
+   @param [in] B is the weight matrix of size N x K. Each byte contains 2 4-bit
+   values, packed together along the K dim.
+   @param [in] scales_and_zeros holds the scales and zero points for each
+   output channel x group, packed as [num_groups = ceil(K / group_size), N, 2],
+   where N is the number of output channels.
+   @param [out] output_data is the output matrix of size M x N.
+   @param [in] sizes array contains the values of M, K and N.
+   @param [in] thread_index is the global thread id.
+   @param [in] tid_in_simdgroup is the thread id within the simdgroup, e.g. in
+   a simdgroup of size 32 it is in [0-31].
+*/
+template <typename T, unsigned group_size>
+kernel void int4pack_mm(constant T *A [[buffer(0)]],
+                        constant uchar *B [[buffer(1)]],
+                        constant T *scales_and_zeros [[buffer(2)]],
+                        device T *output_data [[buffer(3)]],
+                        constant uint3 &sizes [[buffer(4)]], // M, K, N
+                        uint3 thread_index [[thread_position_in_grid]],
+                        uint tid_in_simdgroup [[thread_index_in_simdgroup]]) {
+  constexpr uint threads_per_channel = 32;
+  constexpr uint ks_per_thread = 4;
+  constexpr uint k_pack_factor = 2;
+  const uint K = sizes.y;
+  const uint N = sizes.z;
+  uint n = thread_index.x; // 0..N/4*32-1; mapped to a base output channel below
+  uint m = thread_index.z; // 0..M-1
+  n = n / threads_per_channel;
+  n = n * 4;
+  // This is starting k for each thread. In the example above, for thread 1 this
+  // value will be 4.
+  uint k = (tid_in_simdgroup % threads_per_channel) * ks_per_thread;
+  constexpr int k_jump = threads_per_channel * ks_per_thread;
+
+  using vecT = typename Vec4Type<T>::type;
+  constant vecT *A_ptr = reinterpret_cast<constant vecT *>(A + m * K);
+  constant uchar *B_ptr = B + ((n * K) / k_pack_factor);
+
+  thread float4 result = float4(0.0);
+  // We multiply the group of 4 activation values by these inverse scales,
+  // because the corresponding values from the weight matrix are effectively
+  // left-shifted. This avoids doing a right shift on those values, which ends
+  // up affecting performance. This is the trick applied in the MLX kernels.
+  float4 act_div_scales = {1.f, 1 / 16.f, 1 / 256.f, 1 / 4096.f};
+
+  // Find the specific quantization group to which the k values handled by
+  // this thread belong.
+  uint k_block_index = k / group_size;
+  // Since scales_and_zeros is packed as [num_groups, N, 2], finding a specific
+  // group's scales and zero points requires a jump by a factor of N * 2.
+  uint scales_group_offset = (k_block_index * N + n) * 2;
+  uint zeros_group_offset = scales_group_offset + 1;
+  // N * 2 is the per-group stride; the last factor identifies how many groups
+  // this thread advances per iteration of the loop below. Each iteration it
+  // must jump to a different group, thus k_jump must be >= group_size.
+  const uint scales_jump = N * 2 * (k_jump / group_size);
+  for (; k < K; k += k_jump) {
+    const T scale0 = scales_and_zeros[scales_group_offset];
+    // Adding zero point results in 10% perf penalty.
+    const T zero0 = scales_and_zeros[zeros_group_offset] - scale0 * T(8);
+
+    const T scale1 = scales_and_zeros[scales_group_offset + 2];
+    const T zero1 = scales_and_zeros[zeros_group_offset + 2] - scale1 * T(8);
+
+    const T scale2 = scales_and_zeros[scales_group_offset + 4];
+    const T zero2 = scales_and_zeros[zeros_group_offset + 4] - scale2 * T(8);
+
+    const T scale3 = scales_and_zeros[scales_group_offset + 6];
+    const T zero3 = scales_and_zeros[zeros_group_offset + 6] - scale3 * T(8);
+
+    scales_group_offset += scales_jump;
+    zeros_group_offset += scales_jump;
+
+    const float4 zeros = float4(zero0, zero1, zero2, zero3);
+
+    float4 a_val = float4(A_ptr[k / 4]);
+    // We skip the right-shifts on the weights and instead divide the
+    // activation by the corresponding factors (see act_div_scales above).
+    float4 a_vec = a_val * act_div_scales;
+    float a_val_sum = a_val[0] + a_val[1] + a_val[2] + a_val[3];
+
+    float4x4 b_mat;
+    ushort b_val0 = (reinterpret_cast<constant ushort *>(
+        B_ptr + (k + 0 * K) / k_pack_factor))[0];
+    ushort b_val1 = (reinterpret_cast<constant ushort *>(
+        B_ptr + (k + 1 * K) / k_pack_factor))[0];
+    ushort b_val2 = (reinterpret_cast<constant ushort *>(
+        B_ptr + (k + 2 * K) / k_pack_factor))[0];
+    ushort b_val3 = (reinterpret_cast<constant ushort *>(
+        B_ptr + (k + 3 * K) / k_pack_factor))[0];
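+    // Each ushort holds 4 consecutive packed 4-bit k values for one output
+    // channel. The masks below extract them without shifting down, leaving
+    // them scaled by 1, 16, 256 and 4096; act_div_scales compensates for this.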
+    b_mat[0] = scale0 * float4(float(b_val0 & 0x000f), float(b_val0 & 0x00f0),
+                               float(b_val0 & 0x0f00), float(b_val0 & 0xf000));
+    b_mat[1] = scale1 * float4(float(b_val1 & 0x000f), float(b_val1 & 0x00f0),
+                               float(b_val1 & 0x0f00), float(b_val1 & 0xf000));
+    b_mat[2] = scale2 * float4(float(b_val2 & 0x000f), float(b_val2 & 0x00f0),
+                               float(b_val2 & 0x0f00), float(b_val2 & 0xf000));
+    b_mat[3] = scale3 * float4(float(b_val3 & 0x000f), float(b_val3 & 0x00f0),
+                               float(b_val3 & 0x0f00), float(b_val3 & 0xf000));
+
+    result += a_vec * b_mat;
+    result += a_val_sum * zeros;
+  }
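+  // Tree-reduce the partial sums across the 32 threads that share these 4
+  // output channels (simd_sum is not used so that simdgroup size 64 also works).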
+  result += simd_shuffle_down(result, 1);
+  result += simd_shuffle_down(result, 2);
+  result += simd_shuffle_down(result, 4);
+  result += simd_shuffle_down(result, 8);
+  result += simd_shuffle_down(result, 16);
+  if (tid_in_simdgroup % threads_per_channel == 0) {
+    reinterpret_cast<device vecT *>(output_data + m * N)[n / 4] = vecT(result);
+  }
+}
+
+#define INSTANTIATE_INT4MV(DTYPE, GSIZE)                                       \
+  template [[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] kernel void         \
+  int4pack_mm<DTYPE, GSIZE>(                                                   \
+      constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]],      \
+      constant DTYPE * scales_and_zeros [[buffer(2)]],                         \
+      device DTYPE * output_data [[buffer(3)]],                                 \
+      constant uint3 & sizes [[buffer(4)]],                                    \
+      uint3 thread_index [[thread_position_in_grid]],                           \
+      uint tid_in_simdgroup [[thread_index_in_simdgroup]])
+
+INSTANTIATE_INT4MV(float, 32);
+INSTANTIATE_INT4MV(half, 32);
+INSTANTIATE_INT4MV(float, 64);
+INSTANTIATE_INT4MV(half, 64);
+INSTANTIATE_INT4MV(float, 128);
+INSTANTIATE_INT4MV(half, 128);
+INSTANTIATE_INT4MV(float, 256);
+INSTANTIATE_INT4MV(half, 256);
 #if __METAL_VERSION__ >= 310
-INSTANTIATE_INT4MM(bfloat, 32);
-INSTANTIATE_INT4MM(bfloat, 64);
-INSTANTIATE_INT4MM(bfloat, 128);
-INSTANTIATE_INT4MM(bfloat, 256);
+INSTANTIATE_INT4MV(bfloat, 32);
+INSTANTIATE_INT4MV(bfloat, 64);
+INSTANTIATE_INT4MV(bfloat, 128);
+INSTANTIATE_INT4MV(bfloat, 256);
 #endif
 
 // ------------------------------ int8 MM For M >= 12 ------------------------------------
@@ -601,6 +715,52 @@
 
 )METAL_QUANTIZED");
 
+Tensor _convert_weight_to_int4pack_mps(const Tensor& in, int64_t innerKTiles) {
+  TORCH_CHECK(in.dim() == 2, __func__, " : expect weight to be 2D tensor.");
+  TORCH_CHECK(in.dtype() == at::kInt, __func__, " : expect weight to be kInt.");
+  TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8,
+              __func__,
+              " : innerKTiles need to be 2, 4, or 8, got ",
+              innerKTiles);
+
+  auto weight = in.contiguous();
+  auto N = weight.size(0);
+  auto K = weight.size(1);
+
+  // Create a fake shape matching the CPU packed layout. The meta registration
+  // in dynamo requires the operator to have the same output shape on every
+  // device, so we use the fake shape
+  // {N / 8, K / (16 * innerKTiles), 32, innerKTiles / 2}.
+  auto weight_packed = at::empty({N / 8, K / (16 * innerKTiles), 32, innerKTiles / 2},
+                                 at::TensorOptions().dtype(at::kInt).device(at::kMPS));
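+  // For example (illustrative values): N = 32, K = 64, innerKTiles = 2 yields
+  // a packed tensor of shape {4, 2, 32, 1}, i.e. 256 int32 values = 1024 bytes,
+  // which matches the N * K / 2 bytes written by the weight_to_int4pack kernel.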
+
+  MPSStream* mpsStream = getCurrentMPSStream();
+  std::array<uint32_t, 4> sizes = {static_cast<uint32_t>(N), static_cast<uint32_t>(K), 0, 0};
+  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
+    @autoreleasepool {
+#if _CAPTURE_KERNEL
+      if (getMPSProfiler().isCaptureEnabled()) {
+        getMPSProfiler().startCapture(fmt::format("weight_to_int4pack_{}x{}", N, K), mpsStream);
+      }
+#endif
+      id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+      const std::string kernel = "weight_to_int4pack";
+      id<MTLComputePipelineState> quantizedPSO = lib.getPipelineStateForFunc(kernel);
+      const auto maxThreadsPerGroup = [quantizedPSO maxTotalThreadsPerThreadgroup];
+      [computeEncoder setComputePipelineState:quantizedPSO];
+      mtl_setBuffer(computeEncoder, weight, 0);
+      mtl_setBuffer(computeEncoder, weight_packed, 1);
+      [computeEncoder setBytes:sizes.data() length:sizeof(uint32_t) * sizes.size() atIndex:2];
+      [computeEncoder dispatchThreads:MTLSizeMake(N, K / 2, 1) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)];
+#if _CAPTURE_KERNEL
+      if (getMPSProfiler().isCapturing()) {
+        getMPSProfiler().stopCapture(mpsStream);
+      }
+#endif
+    }
+  });
+  return weight_packed;
+}
+
 Tensor _weight_int4pack_mm_mps(const Tensor& A, const Tensor& B, int64_t qGroupSize, const Tensor& qScaleAndZeros) {
   constexpr int64_t kNTileSize = 8;
 
@@ -649,7 +809,7 @@
       mtl_setBuffer(computeEncoder, qScaleAndZeros, 2);
       mtl_setBuffer(computeEncoder, C, 3);
       [computeEncoder setBytes:sizes.data() length:sizeof(uint32_t) * sizes.size() atIndex:4];
-      [computeEncoder dispatchThreads:MTLSizeMake(N / 2, 4, M) threadsPerThreadgroup:MTLSizeMake(16, 4, 1)];
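+      // Each simdgroup of 32 threads computes a (1x4) output tile, so launch
+      // N / 4 * 32 threads along x and one output row (m) per z.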
+      [computeEncoder dispatchThreads:MTLSizeMake(N / 4 * 32, 1, M) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)];
 #if _CAPTURE_KERNEL
       if (getMPSProfiler().isCapturing()) {
         getMPSProfiler().stopCapture(mpsStream);
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b4cde40..ade3820 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4121,6 +4121,7 @@
   dispatch:
     CPU: _convert_weight_to_int4pack_cpu
     CUDA: _convert_weight_to_int4pack_cuda
+    MPS: _convert_weight_to_int4pack_mps
 
 - func: _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor
   dispatch:
diff --git a/test/test_mps.py b/test/test_mps.py
index 0693a59..1f4d1a6 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9162,8 +9162,8 @@
                 b, n_bit=4, q_group_size=q_group
             )
             b_int4pack = torch._convert_weight_to_int4pack(
-                b_int32.cpu(), inner_k_tiles
-            ).to(device="mps")
+                b_int32, inner_k_tiles
+            )
 
             return b_int4pack, b_scales_and_zeros