[metal] Add int4mm weight packing mps kernel, and improved int4mm shader (#128965)
Adds _convert_weight_to_int4pack MPS kernel
Replaces the previous int4mm Metal shader with a new shader authored by @kimishpatel, which improves perf by ~40%
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128965
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 0693a59..1f4d1a6 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -9162,8 +9162,8 @@
b, n_bit=4, q_group_size=q_group
)
b_int4pack = torch._convert_weight_to_int4pack(
- b_int32.cpu(), inner_k_tiles
- ).to(device="mps")
+ b_int32, inner_k_tiles
+ )
return b_int4pack, b_scales_and_zeros