[MPS] Add naive quantized intmm and `.gputrace` capture hooks (#125163)

- Implement a very straightforward Metal port of the CPU int4mm kernel
- Implement the int8mm kernel by constructing a graph consisting of upcast, transpose, and mm operations (see the reference sketch after this list)
- Add `isCapturing`, `isCaptureEnabled`, `startCapture` and `stopCapture` methods to `MPSProfiler`, which can be used to debug/profile Metal kernels by wrapping the calls as follows:
  ```cpp
   if (getMPSProfiler().isCaptureEnabled()) {
     getMPSProfiler().startCapture(__func__, mpsStream);
   }
   ...
   if (getMPSProfiler().isCapturing()) {
     getMPSProfiler().stopCapture(mpsStream);
   }
  ```
  that, if run with the `MTL_CAPTURE_ENABLED` environment variable set to 1, will produce `.gputrace` files in the current working directory, which can later be loaded (e.g. in Xcode) to debug or profile the kernel; see the usage sketch after the TODO list below.
<img width="1093" alt="image" src="https://github.com/pytorch/pytorch/assets/2453524/a2bf27e8-df8a-442c-a525-1df67b8a376a">

- Added `test__int4_mm` to `TestLinalgMPS`, which is mostly a copy of the corresponding test from `test_linalg` (a minimal sketch of the flow it exercises follows immediately below)
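
For reference, a minimal sketch of the user-facing flow the new test exercises (it assumes an MPS-enabled build; `_group_quantize_tensor` is a private test helper from `torch.testing._internal.common_quantization`, so treat this as illustrative rather than public API):

```python
import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor

m, k, n, q_group, inner_k_tiles = 32, 64, 64, 32, 2
a = torch.rand((m, k), device="mps")
b = torch.rand((k, n), device="mps")

# Group-quantize B to 4 bits, then pack it into the tiled int4 layout.
# Packing runs on the CPU for now, hence the "Add weight packing" TODO below.
b_int32, b_scales_and_zeros = _group_quantize_tensor(b, n_bit=4, q_group_size=q_group)
b_int4pack = torch._convert_weight_to_int4pack(b_int32.cpu(), inner_k_tiles).to("mps")

res = torch._weight_int4pack_mm(a, b_int4pack, q_group, b_scales_and_zeros)
ref = torch.mm(a, b)
print(((res - ref).abs() / ref).mean())  # the test asserts this stays below 0.05
```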

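The int8mm bullet above describes composing existing graph primitives rather than writing a custom shader. As a plain-PyTorch reference for what such a graph computes, a hedged sketch (it assumes the int8 weight is stored transposed as `(n, k)` with one scale per output channel; the actual layout is defined by the kernel, not by this snippet):

```python
import torch

def int8mm_reference(a, b_int8, scales):
    # upcast -> transpose -> mm, followed by per-output-channel scaling
    return torch.mm(a, b_int8.to(a.dtype).t()) * scales

a = torch.rand(4, 8, device="mps", dtype=torch.float16)
w = torch.randint(-128, 127, (16, 8), dtype=torch.int8, device="mps")
s = torch.rand(16, device="mps", dtype=torch.float16)
out = int8mm_reference(a, w, s)  # shape (4, 16)
```
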
TODOs:
 - Add weight packing
 - Perf-tune both kernels
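
To actually trigger a capture, the process must run with `MTL_CAPTURE_ENABLED=1`. A hedged sketch of driving the wrapped int4 kernel from Python (it assumes the capture hooks above are compiled into the kernel being invoked; the script name is made up):

```python
# Run as: MTL_CAPTURE_ENABLED=1 python capture_int4mm.py
# A .gputrace bundle should then appear in the current working directory.
import torch
from torch.testing._internal.common_quantization import _group_quantize_tensor

a = torch.rand(32, 32, device="mps")
b_int32, szs = _group_quantize_tensor(
    torch.rand(32, 32, device="mps"), n_bit=4, q_group_size=32
)
b_pack = torch._convert_weight_to_int4pack(b_int32.cpu(), 2).to("mps")
torch._weight_int4pack_mm(a, b_pack, 32, szs)  # the wrapped Metal kernel runs here
torch.mps.synchronize()  # ensure the GPU work completes before the process exits
```
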
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125163
Approved by: https://github.com/mikekgfb
diff --git a/test/test_mps.py b/test/test_mps.py
index 1bc1ca8..314ac4b 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -22,8 +22,8 @@
 from torch.nn import Parameter
 from torch.testing._internal import opinfo
 from torch.testing._internal.common_utils import \
-    (gradcheck, gradgradcheck, run_tests, TestCase, download_file, IS_CI, NoTest,
-     skipIfSlowGradcheckEnv, suppress_warnings)
+    (gradcheck, gradgradcheck, parametrize, run_tests, TestCase, download_file, IS_CI,
+     NoTest, skipIfSlowGradcheckEnv, suppress_warnings)
 from torch.testing import make_tensor
 from torch.testing._internal.common_dtype import get_all_dtypes, integral_types
 import torch.backends.mps
@@ -40,6 +40,7 @@
 )
 from torch.testing._internal.common_device_type import ops, dtypes, instantiate_device_type_tests, OpDTypes
 from torch.testing._internal.common_nn import NNTestCase
+from torch.testing._internal.common_quantization import _group_quantize_tensor
 import numpy as np
 import torch
 import torch.utils._pytree as pytree
@@ -9050,6 +9051,45 @@
                         "The operator 'aten::_linalg_eigh.eigenvalues' is not currently implemented for the MPS device."):
                     raise e
 
+    @parametrize("m", [32, 64])
+    @parametrize("k", [32, 64])
+    @parametrize("n", [48, 64])
+    def test__int4_mm(self, m, k, n):
+        q_group = 32
+        inner_k_tiles = 2
+
+        torch.manual_seed(1)
+        a_f32 = torch.rand((m, k), device="mps")
+        b_f32 = torch.rand((k, n), device="mps")
+
+        def convert_weight_to_int4pack(b):
+            b_int32, b_scales_and_zeros = _group_quantize_tensor(
+                b, n_bit=4, q_group_size=q_group
+            )
+            b_int4pack = torch._convert_weight_to_int4pack(
+                b_int32.cpu(), inner_k_tiles
+            ).to(device="mps")
+
+            return b_int4pack, b_scales_and_zeros
+
+        def weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros):
+            return torch._weight_int4pack_mm(
+                a, b_int4pack, q_group, b_scales_and_zeros
+            ).to(device="mps")
+
+        b_int4pack, b_scales_and_zeros_f32 = convert_weight_to_int4pack(b_f32)
+
+        for dtype in [torch.float16, torch.float32] + ([torch.bfloat16] if product_version > 14.0 else []):
+            a = a_f32.to(dtype=dtype)
+            b = b_f32.to(dtype=dtype)
+            b_scales_and_zeros = b_scales_and_zeros_f32.to(dtype=dtype)
+            ref = torch.mm(a, b)
+            res = weight_int4pack_mm(a, b_int4pack, b_scales_and_zeros)
+
+            mean_err = ((res - ref).abs() / ref).mean()
+            self.assertTrue(mean_err < 0.05)
+
+
 
 
 
@@ -11844,6 +11884,7 @@
             cpu_tensor = ones("cpu")
             self.assertEqual(mps_tensor.cpu(), cpu_tensor)
 
+
 # TODO: Actually instantiate that test for the "mps" device to better reflect what it is doing.
 # This requires mps to be properly registered in the device generic test framework which is not the
 # case right now. We can probably use `allow_mps` introduced in https://github.com/pytorch/pytorch/pull/87342
@@ -11851,6 +11892,7 @@
 instantiate_device_type_tests(TestConsistency, globals(), only_for="cpu")
 instantiate_device_type_tests(TestErrorInputs, globals(), allow_mps=True, only_for="mps")
 instantiate_device_type_tests(TestCommon, globals(), allow_mps=True, only_for="mps")
+instantiate_device_type_tests(TestLinalgMPS, globals(), allow_mps=True, only_for="mps")
 
 if __name__ == "__main__":
     run_tests()