[MPS] Add naive quantized intmm and `.gputrace` capture hooks (#125163)
- Implement a very straightforward Metal port of the CPU int4mm kernel
- Implement the int8mm kernel by constructing a graph consisting of an upcast, a transpose and a matmul (a Python sketch of the equivalent numerics follows this list)
- Add `isCapturing`, `isCaptureEnabled`, `startCapture` and `stopCapture` methods to `MPSProfiler`, which can be used to debug/profile Metal kernels by wrapping the kernel dispatch with the following
```cpp
if (getMPSProfiler().isCaptureEnabled()) {
getMPSProfiler().startCapture(__func__, mpsStream);
}
...
if (getMPSProfiler().isCapturing()) {
getMPSProfiler().stopCapture(mpsStream);
}
```
which, when run with the `MTL_CAPTURE_ENABLED` environment variable set to 1, will produce `.gputrace` files in the current working directory that can later be loaded to debug or profile the kernel (a sketch for driving a capture from Python also follows this list)
<img width="1093" alt="image" src="https://github.com/pytorch/pytorch/assets/2453524/a2bf27e8-df8a-442c-a525-1df67b8a376a">
- Added `test__int4_mm` to `TestLinalgMPS`, which is mostly a copy-and-paste of the corresponding test from `test_linalg`
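
For reference, here is a minimal Python sketch of the numerics the int8mm graph is intended to express (upcast, transpose, matmul, per-channel scale). The shapes and the commented-out `_weight_int8pack_mm` call follow my reading of the CPU op and are illustrative assumptions, not the PR's code:
```python
import torch

def int8mm_reference(a, w_int8, scales):
    # a: [M, K] float, w_int8: [N, K] int8, scales: [N] (assumed layout)
    # upcast -> transpose -> mm, then apply per-output-channel scales
    return torch.mm(a, w_int8.t().to(a.dtype)) * scales

a = torch.rand(32, 64, device="mps")
w_int8 = torch.randint(-128, 128, (48, 64), dtype=torch.int8).to("mps")
scales = torch.rand(48, device="mps")

ref = int8mm_reference(a, w_int8, scales)  # shape [32, 48]
# The new MPS path should agree with this within tolerance:
# res = torch._weight_int8pack_mm(a, w_int8, scales)
# torch.testing.assert_close(res, ref, rtol=1e-3, atol=1e-3)
```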
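And a rough sketch of driving a capture from a standalone Python script. It assumes (not verified here) that setting `MTL_CAPTURE_ENABLED` before any MPS work is early enough for the Metal capture manager to see it, and that the op being run is one whose dispatch was wrapped with the hooks above; `torch.mm` below is only a placeholder:
```python
import os

# Assumption: the variable must be set before the Metal capture manager is
# created, i.e. before any MPS work (setting it in the shell works as well).
os.environ["MTL_CAPTURE_ENABLED"] = "1"

import torch

a = torch.rand(32, 32, device="mps")
b = torch.rand(32, 32, device="mps")
out = torch.mm(a, b)  # placeholder: substitute the instrumented operator
torch.mps.synchronize()

# Each instrumented dispatch should leave a .gputrace bundle in the cwd.
print([f for f in os.listdir(".") if f.endswith(".gputrace")])
```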
TODOs:
- Add weight packing on MPS (the test currently packs the weight on the CPU and copies it over; see the sketch after the TODOs)
- Perf-tune both kernels
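
For context on the weight-packing TODO, a sketch of what the new test does today: the weight is group-quantized and packed on the CPU (there is no MPS packing kernel yet) and only then moved to the MPS device. Shapes and the group/tile sizes are the test's values, not requirements:
```python
import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor

k, n = 64, 48
b = torch.rand(k, n)  # float weight, kept on the CPU for packing

# 4-bit group quantization followed by packing; the packed weight and the
# scales/zeros are then copied to the MPS device for _weight_int4pack_mm.
b_int32, scales_and_zeros = _group_quantize_tensor(b, n_bit=4, q_group_size=32)
b_int4pack = torch._convert_weight_to_int4pack(b_int32, 2).to("mps")
scales_and_zeros = scales_and_zeros.to("mps")
```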
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125163
Approved by: https://github.com/mikekgfb
diff --git a/test/test_mps.py b/test/test_mps.py
index 1bc1ca8..314ac4b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -22,8 +22,8 @@
from torch.nn import Parameter
from torch.testing._internal import opinfo
from torch.testing._internal.common_utils import \
- (gradcheck, gradgradcheck, run_tests, TestCase, download_file, IS_CI, NoTest,
- skipIfSlowGradcheckEnv, suppress_warnings)
+ (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, IS_CI,
+ NoTest, skipIfSlowGradcheckEnv, suppress_warnings)
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
import torch.backends.mps
@@ -40,6 +40,7 @@
)
from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests, OpDTypes
from torch.testing._internal.common_nn import NNTestCase
+from torch.testing._internal.common_quantization import _group_quantize_tensor
import numpy as np
import torch
import torch.utils._pytree as pytree
@@ -9050,6 +9051,45 @@
"The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
raise e
+ @parametrize("m", [32, 64])
+ @parametrize("k", [32, 64])
+ @parametrize("n", [48, 64])
+ def test__int4_mm(self, m, k, n):
+ q_group = 32
+ inner_k_tiles = 2
+
+ torch.manual_seed(1)
+ a_f32 = torch.rand((m, k), device="mps")
+ b_f32 = torch.rand((k, n), device="mps")
+
+ def convert_weight_to_int4pack(b):
+ b_int32, b_scales_and_zeros = _group_quantize_tensor(
+ b, n_bit=4, q_group_size=q_group
+ )
+ b_int4pack = torch._convert_weight_to_int4pack(
+ b_int32.cpu(), inner_k_tiles
+ ).to(device="mps")
+
+ return b_int4pack, b_scales_and_zeros
+
+ def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
+ return torch._weight_int4pack_mm(
+ a, b_int4pack, q_group, b_scales_and_zeros
+ ).to(device="mps")
+
+ b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b_f32)
+
+ for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if product_version > 14.0 else []):
+ a = a_f32.to(dtype=dtype)
+ b = b_f32.to(dtype=dtype)
+ b_scales_and_zeros = b_scales_and_zeros_f32.to(dtype=dtype)
+ ref = torch.mm(a, b)
+ res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
+
+ mean_err = ((res - ref).abs() / ref).mean()
+ self.assertTrue(mean_err < 0.05)
+
+
@@ -11844,6 +11884,7 @@
cpu_tensor = ones("cpu")
self.assertEqual(mps_tensor.cpu(), cpu_tensor)
+
# TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
# This requires mps to be properly registered in the device generic test framework which is not the
# case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
@@ -11851,6 +11892,7 @@
instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
+instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")
if __name__ == "__main__":
run_tests()