[metal] Add int4mm weight packing MPS kernel and improved int4mm shader (#128965)
Adds a _convert_weight_to_int4pack MPS kernel.
Replaces the previous int4mm Metal shader with a shader authored by @kimishpatel, which improves performance by ~40%.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128965
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/Quantized.mm b/aten/src/ATen/native/mps/operations/Quantized.mm
index 860d53b..326239e 100644
--- a/aten/src/ATen/native/mps/operations/Quantized.mm
+++ b/aten/src/ATen/native/mps/operations/Quantized.mm
@@ -4,6 +4,7 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
+#include <ATen/ops/_convert_weight_to_int4pack_native.h>
#include <ATen/ops/_weight_int4pack_mm_native.h>
#include <ATen/ops/_weight_int8pack_mm_native.h>
#include <ATen/ops/empty.h>
@@ -54,96 +55,209 @@
};
#endif
-template<typename T, unsigned groupSize>
-kernel void int4pack_mm(
- constant T * A [[buffer(0)]],
- constant uchar * B [[buffer(1)]],
- constant T * scalesAndZeros [[buffer(2)]],
- device T * outputData [[buffer(3)]],
- constant uint3 & sizes [[buffer(4)]], // M, K, N
- uint3 group_index [[threadgroup_position_in_grid]],
- uint3 threadgroup_index [[thread_position_in_threadgroup]]) {
-
- const uint K = sizes.y;
- const uint N = sizes.z;
- const uint nb = group_index.x; // 0..N/32-1
- const uint n2 = 16 * nb + threadgroup_index.x; // 0..N/2-1
- const uint m = group_index.z;
- const uint ldb = min(32U, N - nb * 32);
- const uint32_t k_block = (K + groupSize - 1) / groupSize;
-
- using vec2T = typename Vec2Type<T>::type;
- using vec4T = typename Vec4Type<T>::type;
-
- constant vec4T *A_ptr = reinterpret_cast<constant vec4T *>(A + m * K);
- constant uchar *B_ptr = B + (nb * 16 * K);
-
- float2 rc = 0.0;
- uint k = threadgroup_index.y * 4;
- for (uint32_t kb = 0; kb < k_block ; kb ++) {
- float2 scales, zeros;
- for (int i = 0; i < 2; ++i) {
- scales[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 0];
- zeros[i] = scalesAndZeros[(kb * N + 2*n2 + i) * 2 + 1] - scales[i] * T(8);
- }
-
- for(uint idx = k % groupSize; idx < groupSize && k < K; idx += 16, k += 16) {
- threadgroup_barrier(mem_flags::mem_none);
-
- const auto a_vec = float4(A_ptr[k/4]);
- uchar4 b_byte;
- for (int i = 0; i < 4; i++) {
- b_byte[i] = B_ptr[((k + i) * ldb + (2*n2 % 32))/2];
- }
-
- float4x2 b_mat;
-
- for (int i = 0; i < 4; i++) {
- b_mat[i] = scales * float2(
- float(b_byte[i] & 0x0f),
- float(b_byte[i] >> 4)) + zeros;
- }
-
- rc += b_mat * a_vec;
- }
- }
-
- threadgroup float2 tgp_memory[16][4];
- tgp_memory[threadgroup_index.x][threadgroup_index.y] = rc;
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (threadgroup_index.y == 0) {
- for (unsigned i = 1; i < 4; i++) {
- rc += tgp_memory[threadgroup_index.x][i];
- }
- reinterpret_cast<device vec2T*>(outputData + m * N)[n2] = vec2T(rc);
- }
+kernel void weight_to_int4pack(constant int *W [[buffer(0)]],
+ device uchar *outputData [[buffer(1)]],
+ constant uint2 &sizes [[buffer(2)]],
+ uint2 thread_index [[thread_position_in_grid]]) {
+ const uint N = sizes.x;
+ const uint K = sizes.y;
+ const uint n = thread_index.x; // 0..N-1
+ const uint k2 = thread_index.y; // 0..K/2-1
+ int32_t src_val0 = W[n * K + 2 * k2];
+ int32_t src_val1 = W[n * K + 2 * k2 + 1];
+ outputData[n * (K / 2) + k2] = (uint8_t(src_val1) << 4) | uint8_t(src_val0);
}
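For reference, a minimal NumPy sketch of the nibble layout this kernel writes (the helper name below is ours, not part of the PR; it assumes the weight values are already quantized to [0, 15]):

import numpy as np

def pack_int4_reference(w):
    # w: [N, K] int32 with values in [0, 15]; mirrors weight_to_int4pack:
    # the odd-k value goes to the high nibble, the even-k value to the low nibble.
    lo = w[:, 0::2].astype(np.uint8)
    hi = w[:, 1::2].astype(np.uint8)
    return (hi << 4) | lo              # [N, K/2] uint8

w = np.random.randint(0, 16, size=(8, 32), dtype=np.int32)
assert pack_int4_reference(w).shape == (8, 16)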
-#define INSTANTIATE_INT4MM(DTYPE, GSIZE) \
-template \
-[[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] \
-kernel void int4pack_mm<DTYPE, GSIZE>( \
- constant DTYPE * A [[buffer(0)]], \
- constant uchar * B [[buffer(1)]], \
- constant DTYPE * scalesAndZeros [[buffer(2)]], \
- device DTYPE * outputData [[buffer(3)]], \
- constant uint3 & sizes [[buffer(4)]], \
- uint3 group_index [[threadgroup_position_in_grid]], \
- uint3 threadgroup_index [[thread_position_in_threadgroup]])
+/*
+  This code takes heavy inspiration from the MLX qvm kernel here:
+  https://github.com/ml-explore/mlx/blob/main/mlx/backend/metal/kernels/quantized.metal#L381
+  Specifically:
+  - Multiplying the activation by the inverse scaling factor to reduce compute
+    boundedness
+  - Handling the zero point by accumulating the activation in a separate sum
+    term; this is needed with the optimization above.
+  MLX MIT License: https://github.com/ml-explore/mlx/blob/main/LICENSE
+*/
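As a standalone sanity check of those two ideas (pure NumPy with invented values, not code from this PR): the zero point folds into a term weighted by the activation sum, and the nibble right-shifts can be traded for scaling the activation by inverse powers of 16.

import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(4)                      # 4 activation values
q = rng.integers(0, 16, 4).astype(np.float64)   # 4-bit weight codes
s, z = 0.02, 0.1                                # group scale / zero point

# Assumes the kernel's dequantization convention w = s * (q - 8) + z.
direct = np.dot(a, s * (q - 8) + z)             # dequantize, then dot

q_shifted = q * (16.0 ** np.arange(4))          # codes as they sit in the nibbles
a_scaled = a / (16.0 ** np.arange(4))           # the act_div_scales trick
fused = np.dot(a_scaled, s * q_shifted) + (z - 8 * s) * a.sum()

assert np.allclose(direct, fused)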
-INSTANTIATE_INT4MM(float, 32);
-INSTANTIATE_INT4MM(half, 32);
-INSTANTIATE_INT4MM(float, 64);
-INSTANTIATE_INT4MM(half, 64);
-INSTANTIATE_INT4MM(float, 128);
-INSTANTIATE_INT4MM(half, 128);
-INSTANTIATE_INT4MM(float, 256);
-INSTANTIATE_INT4MM(half, 256);
+/*
+  The A matrix is [M x K] (right now this kernel does not support M > 1, but
+  that is an easy fix that will follow shortly). The B matrix is [N x K]. For
+  4 bits, two of the k values are packed into one byte, so from a layout
+  perspective you can think of B as an [N x K/2] matrix.
+
+  Since this kernel optimizes for the gemv case, we split the work along the
+  reduction dim k among the threads of the same simdgroup. E.g., if K = 4096
+  and the simdgroup size is 32 (the current algorithm should work as long as
+  the simdgroup size is >= 32), then each thread accumulates 4096/32 = 128 k
+  values. However, the 128 values handled by each thread are not laid out
+  contiguously. Each thread handles 4 contiguous k values and then jumps by
+  k_jump = threads_per_channel (32) * ks_per_thread (4) = 128 elements. Take a
+  simpler example where the simdgroup is of size 4, so threads_per_channel = 4,
+  and assume K = 32:
+
+      k                   thread
+  [ 0,  1,  2,  3,          0
+    4,  5,  6,  7,          1
+    8,  9, 10, 11,          2
+   12, 13, 14, 15,          3
+   16, 17, 18, 19,          0
+   20, 21, 22, 23,          1
+   24, 25, 26, 27,          2
+   28, 29, 30, 31]          3
+
+  The right-hand column is the thread id in the simdgroup that handles the
+  corresponding ks. Thread 0 here handles (0, 1, 2, 3) and then (16, 17, 18, 19);
+  the two chunks are k_jump = 4 * 4 = 16 apart. This is done to improve memory
+  access locality among threads that are working co-operatively. Once each
+  thread has accumulated its partial sums, we use a tree reduction (Metal
+  offers simd_sum, but it is not used so that we can also support simdgroup
+  size = 64). In the example above we will have 4 partial sums.
+
+  Each thread also handles 4 different output channels (rows of B). Thus each
+  simdgroup is responsible for a (1x4) tile of the output. We haven't evaluated
+  whether a different tile size is better or not. We probably will do some
+  auto-tuning once the initial work is done.
+*/
+
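A small Python sketch of that lane-to-k mapping (the helper name is ours), reproducing the simdgroup-of-4 example above:

def ks_for_lane(lane, K, threads_per_channel=32, ks_per_thread=4):
    # Each lane takes ks_per_thread contiguous k values, then jumps by k_jump.
    k_jump = threads_per_channel * ks_per_thread
    start = (lane % threads_per_channel) * ks_per_thread
    return [base + i for base in range(start, K, k_jump) for i in range(ks_per_thread)]

# Simdgroup of 4 lanes, K = 32, as in the diagram:
assert ks_for_lane(0, 32, threads_per_channel=4) == [0, 1, 2, 3, 16, 17, 18, 19]
assert ks_for_lane(3, 32, threads_per_channel=4) == [12, 13, 14, 15, 28, 29, 30, 31]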
+/*
+   @brief This shader implements 4-bit matrix-vector multiplication where the A
+   matrix is fp16, bfloat or float and the B matrix is a 4-bit groupwise-quantized
+   weight matrix.
+   @param [in] A is the activation matrix of size M x K.
+   @param [in] B is the weight matrix of size N x K. Each byte contains two 4-bit
+   values, packed together along the K dim.
+   @param [in] scales_and_zeros holds the scales and zero points for each
+   output channel x group. These are packed as
+   [num_groups = ceil(K / group_size), N, 2], where N is the number of output channels.
+   @param [out] output_data is the output matrix of size M x N.
+   @param [in] sizes array contains the values of M, N and K.
+   @param [in] tid_in_simdgroup is the thread id in the simdgroup, e.g. in a
+   simdgroup of size 32 it is in [0-31].
+   @param [in] thread_index is the global thread id.
+*/
+template <typename T, unsigned group_size>
+kernel void int4pack_mm(constant T *A [[buffer(0)]],
+ constant uchar *B [[buffer(1)]],
+ constant T *scales_and_zeros [[buffer(2)]],
+ device T *output_data [[buffer(3)]],
+ constant uint3 &sizes [[buffer(4)]], // M, K, N
+ uint3 thread_index [[thread_position_in_grid]],
+ uint tid_in_simdgroup [[thread_index_in_simdgroup]]) {
+ constexpr uint threads_per_channel = 32;
+ constexpr uint ks_per_thread = 4;
+ constexpr uint k_pack_factor = 2;
+ const uint K = sizes.y;
+ const uint N = sizes.z;
+ uint n = thread_index.x; // 0..(N / 4) * 32 - 1
+ uint m = thread_index.z; // 0..M-1
+ n = n / threads_per_channel;
+ n = n * 4;
+ // This is starting k for each thread. In the example above, for thread 1 this
+ // value will be 4.
+ uint k = (tid_in_simdgroup % threads_per_channel) * ks_per_thread;
+ constexpr int k_jump = threads_per_channel * ks_per_thread;
+
+ using vecT = typename Vec4Type<T>::type;
+ constant vecT *A_ptr = reinterpret_cast<constant vecT *>(A + m * K);
+ constant uchar *B_ptr = B + ((n * K) / k_pack_factor);
+
+ thread float4 result = float4(0.0);
+ // We multiply the activation by these inverse scales because the
+ // corresponding values from the weight matrix are effectively left
+ // shifted. This avoids doing a right shift on those values, which ends up
+ // hurting performance. This is the trick applied in the MLX kernels.
+ float4 act_div_scales = {1.f, 1 / 16.f, 1 / 256.f, 1 / 4096.f};
+
+ // Find the specific group to which the k values handled by this thread
+ // belong.
+ uint k_block_index = k / group_size;
+ // scales_and_zeros are packed as [num_groups, N, 2], so finding a specific
+ // group's scales and zero points requires a jump by a factor of N * 2.
+ uint scales_group_offset = (k_block_index * N + n) * 2;
+ uint zeros_group_offset = scales_group_offset + 1;
+ // The last term accounts for identifying the group this thread processes in
+ // each iteration; each iteration it must jump to a different group, so
+ // k_jump must be > group_size.
+ const uint scales_jump = N * 2 * (k_jump / group_size);
+ for (; k < K; k += k_jump) {
+ const T scale0 = scales_and_zeros[scales_group_offset];
+ // Adding zero point results in 10% perf penalty.
+ const T zero0 = scales_and_zeros[zeros_group_offset] - scale0 * T(8);
+
+ const T scale1 = scales_and_zeros[scales_group_offset + 2];
+ const T zero1 = scales_and_zeros[zeros_group_offset + 2] - scale1 * T(8);
+
+ const T scale2 = scales_and_zeros[scales_group_offset + 4];
+ const T zero2 = scales_and_zeros[zeros_group_offset + 4] - scale2 * T(8);
+
+ const T scale3 = scales_and_zeros[scales_group_offset + 6];
+ const T zero3 = scales_and_zeros[zeros_group_offset + 6] - scale3 * T(8);
+
+ scales_group_offset += scales_jump;
+ zeros_group_offset += scales_jump;
+
+ const float4 zeros = float4(zero0, zero1, zero2, zero3);
+
+ float4 a_val = float4(A_ptr[k / 4]);
+ // We skip right-shifts of the weights and instead divide by the corresponding factor.
+ float4 a_vec = a_val * act_div_scales;
+ float a_val_sum = a_val[0] + a_val[1] + a_val[2] + a_val[3];
+
+ float4x4 b_mat;
+ ushort b_val0 = (reinterpret_cast<constant ushort *>(
+ B_ptr + (k + 0 * K) / k_pack_factor))[0];
+ ushort b_val1 = (reinterpret_cast<constant ushort *>(
+ B_ptr + (k + 1 * K) / k_pack_factor))[0];
+ ushort b_val2 = (reinterpret_cast<constant ushort *>(
+ B_ptr + (k + 2 * K) / k_pack_factor))[0];
+ ushort b_val3 = (reinterpret_cast<constant ushort *>(
+ B_ptr + (k + 3 * K) / k_pack_factor))[0];
+ b_mat[0] = scale0 * float4(float(b_val0 & 0x000f), float(b_val0 & 0x00f0),
+ float(b_val0 & 0x0f00), float(b_val0 & 0xf000));
+ b_mat[1] = scale1 * float4(float(b_val1 & 0x000f), float(b_val1 & 0x00f0),
+ float(b_val1 & 0x0f00), float(b_val1 & 0xf000));
+ b_mat[2] = scale2 * float4(float(b_val2 & 0x000f), float(b_val2 & 0x00f0),
+ float(b_val2 & 0x0f00), float(b_val2 & 0xf000));
+ b_mat[3] = scale3 * float4(float(b_val3 & 0x000f), float(b_val3 & 0x00f0),
+ float(b_val3 & 0x0f00), float(b_val3 & 0xf000));
+
+ result += a_vec * b_mat;
+ result += a_val_sum * zeros;
+ }
+ result += simd_shuffle_down(result, 1);
+ result += simd_shuffle_down(result, 2);
+ result += simd_shuffle_down(result, 4);
+ result += simd_shuffle_down(result, 8);
+ result += simd_shuffle_down(result, 16);
+ if (tid_in_simdgroup % threads_per_channel == 0) {
+ reinterpret_cast<device vecT *>(output_data + m * N)[n / 4] = vecT(result);
+ }
+}
+
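For readers less used to simd_shuffle_down, a plain-Python emulation of the tree reduction above: all lanes step together, each adding the pre-step value held by the lane delta positions above it, and lane 0 ends up with the full sum.

def simd_tree_reduce(vals):
    # vals: one partial sum per lane (32 lanes); emulates shuffle_down deltas 1, 2, 4, 8, 16.
    lanes = len(vals)
    for delta in (1, 2, 4, 8, 16):
        pre = list(vals)                               # every lane reads pre-step values
        vals = [pre[i] + (pre[i + delta] if i + delta < lanes else 0.0)
                for i in range(lanes)]
    return vals[0]                                     # lane 0 holds the full sum

assert simd_tree_reduce([float(i) for i in range(32)]) == sum(range(32))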
+#define INSTANTIATE_INT4MV(DTYPE, GSIZE) \
+ template [[host_name("int4pack_mm_" #GSIZE "_" #DTYPE)]] kernel void \
+ int4pack_mm<DTYPE, GSIZE>( \
+ constant DTYPE * A [[buffer(0)]], constant uchar * B [[buffer(1)]], \
+ constant DTYPE * scales_and_zeros [[buffer(2)]], \
+ device DTYPE * output_data [[buffer(3)]], \
+ constant uint3 & sizes [[buffer(4)]], \
+ uint3 thread_index [[thread_position_in_grid]], \
+ uint tid_in_simdgroup [[thread_index_in_simdgroup]])
+
+INSTANTIATE_INT4MV(float, 32);
+INSTANTIATE_INT4MV(half, 32);
+INSTANTIATE_INT4MV(float, 64);
+INSTANTIATE_INT4MV(half, 64);
+INSTANTIATE_INT4MV(float, 128);
+INSTANTIATE_INT4MV(half, 128);
+INSTANTIATE_INT4MV(float, 256);
+INSTANTIATE_INT4MV(half, 256);
#if __METAL_VERSION__ >= 310
-INSTANTIATE_INT4MM(bfloat, 32);
-INSTANTIATE_INT4MM(bfloat, 64);
-INSTANTIATE_INT4MM(bfloat, 128);
-INSTANTIATE_INT4MM(bfloat, 256);
+INSTANTIATE_INT4MV(bfloat, 32);
+INSTANTIATE_INT4MV(bfloat, 64);
+INSTANTIATE_INT4MV(bfloat, 128);
+INSTANTIATE_INT4MV(bfloat, 256);
#endif
// ------------------------------ int8 MM For M >= 12 ------------------------------------
@@ -601,6 +715,52 @@
)METAL_QUANTIZED");
+Tensor _convert_weight_to_int4pack_mps(const Tensor& in, int64_t innerKTiles) {
+ TORCH_CHECK(in.dim() == 2, __func__, " : expect weight to be 2D tensor.");
+ TORCH_CHECK(in.dtype() == at::kInt, __func__, " : expect weight to be kInt.");
+ TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8,
+ __func__,
+ " : innerKTiles need to be 2, 4, or 8, got ",
+ innerKTiles);
+
+ auto weight = in.contiguous();
+ auto N = weight.size(0);
+ auto K = weight.size(1);
+
+ // Create a fake shape that matches the CPU op's output. The meta registration
+ // in dynamo requires the operator to have the same output shape on every
+ // device, so we use the fake shape {N / 8, K / (16 * innerKTiles), 32, innerKTiles / 2}.
+ auto weight_packed = at::empty({N / 8, K / (16 * innerKTiles), 32, innerKTiles / 2},
+ at::TensorOptions().dtype(at::kInt).device(at::kMPS));
+
+ MPSStream* mpsStream = getCurrentMPSStream();
+ std::array<uint32_t, 4> sizes = {static_cast<uint32_t>(N), static_cast<uint32_t>(K), 0, 0};
+ dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
+ @autoreleasepool {
+#if _CAPTURE_KERNEL
+ if (getMPSProfiler().isCaptureEnabled()) {
+ getMPSProfiler().startCapture(fmt::format("weight_to_int4pack_{}x{}", N, K), mpsStream);
+ }
+#endif
+ id<MTLComputeCommandEncoder> computeEncoder = mpsStream->commandEncoder();
+ const std::string kernel = fmt::format("weight_to_int4pack");
+ id<MTLComputePipelineState> quantizedPSO = lib.getPipelineStateForFunc(kernel);
+ const auto maxThreadsPerGroup = [quantizedPSO maxTotalThreadsPerThreadgroup];
+ [computeEncoder setComputePipelineState:quantizedPSO];
+ mtl_setBuffer(computeEncoder, weight, 0);
+ mtl_setBuffer(computeEncoder, weight_packed, 1);
+ [computeEncoder setBytes:sizes.data() length:sizeof(uint32_t) * sizes.size() atIndex:2];
+ [computeEncoder dispatchThreads:MTLSizeMake(N, K / 2, 1) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)];
+#if _CAPTURE_KERNEL
+ if (getMPSProfiler().isCapturing()) {
+ getMPSProfiler().stopCapture(mpsStream);
+ }
+#endif
+ }
+ });
+ return weight_packed;
+}
+
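A quick standalone check (sizes invented) that the fake int32 shape above holds exactly the N * K / 2 packed bytes the Metal kernel writes, one byte per two int4 values:

N, K, innerKTiles = 4096, 4096, 2
fake_shape = (N // 8, K // (16 * innerKTiles), 32, innerKTiles // 2)

fake_bytes = 4                                 # sizeof(int32)
for dim in fake_shape:
    fake_bytes *= dim

packed_bytes = N * (K // 2)                    # two int4 values per byte
assert fake_bytes == packed_bytes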
Tensor _weight_int4pack_mm_mps(const Tensor& A, const Tensor& B, int64_t qGroupSize, const Tensor& qScaleAndZeros) {
constexpr int64_t kNTileSize = 8;
@@ -649,7 +809,7 @@
mtl_setBuffer(computeEncoder, qScaleAndZeros, 2);
mtl_setBuffer(computeEncoder, C, 3);
[computeEncoder setBytes:sizes.data() length:sizeof(uint32_t) * sizes.size() atIndex:4];
- [computeEncoder dispatchThreads:MTLSizeMake(N / 2, 4, M) threadsPerThreadgroup:MTLSizeMake(16, 4, 1)];
+ [computeEncoder dispatchThreads:MTLSizeMake(N / 4 * 32, 1, M) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)];
#if _CAPTURE_KERNEL
if (getMPSProfiler().isCapturing()) {
getMPSProfiler().stopCapture(mpsStream);
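The new dispatch reflects the kernel's work split: 32 lanes cooperate on each group of 4 output channels, grouped into 64-thread threadgroups (two simdgroups). A sketch of the grid arithmetic with invented sizes:

M, N = 1, 4096
threads_per_output_tile = 32                        # one lane set per 4 output channels
grid = ((N // 4) * threads_per_output_tile, 1, M)   # matches MTLSizeMake(N / 4 * 32, 1, M)
threadgroup = (64, 1, 1)                            # two simdgroups of 32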
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index b4cde40..ade3820 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4121,6 +4121,7 @@
dispatch:
CPU: _convert_weight_to_int4pack_cpu
CUDA: _convert_weight_to_int4pack_cuda
+ MPS: _convert_weight_to_int4pack_mps
- func: _weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor
dispatch:
diff --git a/test/test_mps.py b/test/test_mps.py
index 0693a59..1f4d1a6 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9162,8 +9162,8 @@
b, n_bit=4, q_group_size=q_group
)
b_int4pack = torch._convert_weight_to_int4pack(
- b_int32.cpu(), inner_k_tiles
- ).to(device="mps")
+ b_int32, inner_k_tiles
+ )
return b_int4pack, b_scales_and_zeros
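Putting it together, a hedged end-to-end sketch of the flow after this change (tensor contents and sizes are invented, and the random scales/zeros stand in for the test's group-quantization helper), showing that packing now runs directly on MPS with no CPU round-trip:

import torch

M, K, N, q_group, inner_k_tiles = 1, 4096, 4096, 32, 2
a = torch.randn(M, K, dtype=torch.float16, device="mps")
b_int32 = torch.randint(0, 16, (N, K), dtype=torch.int32, device="mps")
b_scales_and_zeros = torch.rand(K // q_group, N, 2, dtype=torch.float16, device="mps")

b_int4pack = torch._convert_weight_to_int4pack(b_int32, inner_k_tiles)   # now an MPS kernel
c = torch._weight_int4pack_mm(a, b_int4pack, q_group, b_scales_and_zeros)
assert c.shape == (M, N)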