[metal] Add int4mm weight packing MPS kernel and improved int4mm shader (#128965). Adds an MPS kernel for _convert_weight_to_int4pack. Replaces the previous int4mm Metal shader with a shader authored by @kimishpatel, which improves performance by ~40%. Pull Request resolved: https://github.com/pytorch/pytorch/pull/128965 Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py index 0693a59..1f4d1a6 100644 --- a/test/test_mps.py +++ b/test/test_mps.py
@@ -9162,8 +9162,8 @@ b, n_bit=4, q_group_size=q_group ) b_int4pack = torch._convert_weight_to_int4pack( - b_int32.cpu(), inner_k_tiles - ).to(device="mps") + b_int32, inner_k_tiles + ) return b_int4pack, b_scales_and_zeros