[MPS] Add support for MPSProfiler Python bindings (#101002)

- Added torch.mps.profiler.start()/stop() APIs and a profile() context manager, with RST documentation
- Added a test case in test_mps
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101002
Approved by: https://github.com/malfet
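
Below is a minimal usage sketch of the bindings this PR adds (torch.mps.profiler.start/stop and the profile context manager). It assumes an MPS-enabled build of PyTorch on Apple Silicon; the tensor shapes and ops are illustrative only.

    import torch

    # The profiler bindings require the MPS backend; guard accordingly.
    if torch.backends.mps.is_available():
        # Context-manager form: trace durations of the enclosed ops.
        with torch.mps.profiler.profile(mode="interval", wait_until_completed=False):
            x = torch.rand(64, 64, device="mps")
            y = (x @ x).relu()

        # Explicit start/stop form: mark op completions as signpost events.
        torch.mps.profiler.start(mode="event", wait_until_completed=True)
        z = torch.rand(64, 64, device="mps").sum()
        torch.mps.profiler.stop()
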
diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h
index 7d67d63..114abde 100644
--- a/aten/src/ATen/detail/MPSHooksInterface.h
+++ b/aten/src/ATen/detail/MPSHooksInterface.h
@@ -59,6 +59,14 @@
   virtual void setMemoryFraction(double /*ratio*/) const {
     AT_ERROR("Cannot execute setMemoryFraction() without MPS backend.");
   }
+
+  virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const {
+    AT_ERROR("Cannot execute profilerStartTrace() without MPS backend.");
+  }
+
+  virtual void profilerStopTrace() const {
+    AT_ERROR("Cannot execute profilerStopTrace() without MPS backend.");
+  }
 };
 
 struct TORCH_API MPSHooksArgs {};
diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h
index 9e913b3..61ff9a8 100644
--- a/aten/src/ATen/mps/MPSHooks.h
+++ b/aten/src/ATen/mps/MPSHooks.h
@@ -21,6 +21,8 @@
   size_t getCurrentAllocatedMemory() const override;
   size_t getDriverAllocatedMemory() const override;
   void setMemoryFraction(double ratio) const override;
+  void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
+  void profilerStopTrace() const override;
 };
 
 }} // at::mps
diff --git a/aten/src/ATen/mps/MPSHooks.cpp b/aten/src/ATen/mps/MPSHooks.mm
similarity index 81%
rename from aten/src/ATen/mps/MPSHooks.cpp
rename to aten/src/ATen/mps/MPSHooks.mm
index 1186f3a..f0c5289 100644
--- a/aten/src/ATen/mps/MPSHooks.cpp
+++ b/aten/src/ATen/mps/MPSHooks.mm
@@ -1,9 +1,11 @@
 //  Copyright © 2022 Apple Inc.
 
-#include <ATen/mps/MPSHooks.h>
+#include <ATen/mps/MPSAllocatorInterface.h>
 #include <ATen/mps/MPSDevice.h>
 #include <ATen/mps/MPSGeneratorImpl.h>
-#include <ATen/mps/MPSAllocatorInterface.h>
+#include <ATen/mps/MPSHooks.h>
+#include <ATen/mps/MPSProfiler.h>
+#include <c10/util/Logging.h>
 
 namespace at {
 namespace mps {
@@ -28,7 +30,7 @@
     case 3:
       return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
     default:
-      TORCH_WARN("Can't check whether running on 13.",minor,"+ returning one for 13.3+");
+      TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.3+");
       return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
   }
 }
@@ -61,6 +63,14 @@
   at::mps::getIMPSAllocator()->setHighWatermarkRatio(ratio);
 }
 
+void MPSHooks::profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const {
+  at::mps::getMPSProfiler().StartTrace(mode, waitUntilCompleted);
+}
+
+void MPSHooks::profilerStopTrace() const {
+  at::mps::getMPSProfiler().StopTrace();
+}
+
 using at::MPSHooksRegistry;
 using at::RegistererMPSHooksRegistry;
 
diff --git a/docs/source/mps.rst b/docs/source/mps.rst
index 91662aa..7ed30f9 100644
--- a/docs/source/mps.rst
+++ b/docs/source/mps.rst
@@ -15,4 +15,14 @@
     empty_cache
     set_per_process_memory_fraction
     current_allocated_memory
-    driver_allocated_memory
\ No newline at end of file
+    driver_allocated_memory
+
+MPS Profiler
+------------
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    profiler.start
+    profiler.stop
+    profiler.profile
diff --git a/test/test_mps.py b/test/test_mps.py
index 60ef643..f281c7a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -7325,6 +7325,25 @@
         self.assertTrue(current_alloc_after > current_alloc_before)
         self.assertTrue(driver_alloc_after > driver_alloc_before)
 
+    # To verify this test, run the Xcode Instruments "Metal System Trace" or "Logging"
+    # tool, press record, then run this Python test, and press stop. Next, expand
+    # the os_signposts->PyTorchMPS entry and check whether events or intervals are
+    # logged, like this example:
+    # "aten::mps_convolution_backward_input:f32[1,128,6,6]:f32[128,64,3,3]:1,128,6,6 (id=G2, run=2)"
+    def test_mps_profiler_module(self):
+        with torch.mps.profiler.profile(mode="event", wait_until_completed=False) as p:
+            # just running some ops to capture the OS Signposts traces for profiling
+            net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
+                .to(device='mps', dtype=torch.float)
+            x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
+            x = net1(x)
+
+        torch.mps.profiler.start(mode="interval", wait_until_completed=True)
+        # just running some ops to capture the OS Signposts traces for profiling
+        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
+        x = net1(x)
+        torch.mps.profiler.stop()
+
     # Test random_, random_.to and random_.from
     def test_random(self):
         def helper(shape, low, high, dtype=torch.int32):
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index a2bc11e..0fb1976 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1450,6 +1450,8 @@
 def _mps_driverAllocatedMemory() -> _int: ...
 def _mps_is_available() -> _bool: ...
 def _mps_is_on_macos_13_or_newer(minor: _int) -> _bool: ...
+def _mps_profilerStartTrace(mode: str, wait_until_completed: _bool) -> None: ...
+def _mps_profilerStopTrace() -> None: ...
 
 # Defined in torch/csrc/cuda/Module.cpp
 def _cuda_getCurrentStream(device: _int) -> Tuple: ...
diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp
index 0a1c45c..7433158 100644
--- a/torch/csrc/mps/Module.cpp
+++ b/torch/csrc/mps/Module.cpp
@@ -4,6 +4,7 @@
 #include <torch/csrc/THP.h>
 #include <torch/csrc/python_headers.h>
 #include <torch/csrc/utils/python_numbers.h>
+#include <torch/csrc/utils/python_strings.h>
 
 // pthread.h is included for tracking bad forks
 #ifndef WIN32
@@ -94,8 +95,8 @@
       THPUtils_checkDouble(args), "invalid argument to setMemoryFraction()");
   double fraction = THPUtils_unpackDouble(args);
   at::detail::getMPSHooks().setMemoryFraction(fraction);
-  END_HANDLE_TH_ERRORS
   Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
 }
 
 static PyObject* MPSModule_currentAllocatedMemory(
@@ -116,6 +117,33 @@
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* MPSModule_profilerStartTrace(
+    PyObject* _unused,
+    PyObject* args) {
+  HANDLE_TH_ERRORS
+  PyObject* mode_string_o = nullptr;
+  PyObject* wait_until_completed_string_o = nullptr;
+  if (!PyArg_ParseTuple(
+          args, "OO", &mode_string_o, &wait_until_completed_string_o)) {
+    return nullptr;
+  }
+  const std::string mode = THPUtils_unpackString(mode_string_o);
+  const bool waitUntilCompleted =
+      THPUtils_unpackBool(wait_until_completed_string_o);
+  at::detail::getMPSHooks().profilerStartTrace(mode, waitUntilCompleted);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_profilerStopTrace(
+    PyObject* _unused,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  at::detail::getMPSHooks().profilerStopTrace();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 // NOLINTNEXTLINE(modernize-avoid-c-arrays,
 // cppcoreguidelines-avoid-non-const-global-variables,
 // cppcoreguidelines-avoid-c-arrays)
@@ -141,6 +169,14 @@
      MPSModule_driverAllocatedMemory,
      METH_NOARGS,
      nullptr},
+    {"_mps_profilerStartTrace",
+     MPSModule_profilerStartTrace,
+     METH_VARARGS,
+     nullptr},
+    {"_mps_profilerStopTrace",
+     MPSModule_profilerStopTrace,
+     METH_NOARGS,
+     nullptr},
     {nullptr}};
 
 PyMethodDef* python_functions() {
diff --git a/torch/mps/__init__.py b/torch/mps/__init__.py
index 2ab9555..beb0072 100644
--- a/torch/mps/__init__.py
+++ b/torch/mps/__init__.py
@@ -98,7 +98,9 @@
     """
     return torch._C._mps_driverAllocatedMemory()
 
+from . import profiler
+
 __all__ = [
     'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize',
     'empty_cache', 'set_per_process_memory_fraction', 'current_allocated_memory',
-    'driver_allocated_memory']
+    'driver_allocated_memory', 'profiler']
diff --git a/torch/mps/profiler.py b/torch/mps/profiler.py
new file mode 100644
index 0000000..5ad94d0
--- /dev/null
+++ b/torch/mps/profiler.py
@@ -0,0 +1,55 @@
+import torch
+import contextlib
+
+__all__ = ["start", "stop", "profile"]
+
+def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
+    r"""Start OS Signpost tracing from MPS backend.
+
+    The generated OS Signposts can be recorded and viewed in the
+    Xcode Instruments Logging tool.
+
+    Args:
+        mode(str): OS Signpost tracing mode; can be "interval", "event",
+            or both "interval,event".
+            The interval mode traces the duration of execution of the operations,
+            whereas event mode marks the completion of executions.
+            See document `Recording Performance Data`_ for more info.
+        wait_until_completed(bool): Waits until the MPS stream has completed
+            executing each encoded GPU operation. This helps generate single
+            dispatches on the trace's timeline.
+            Note that enabling this option negatively affects performance.
+
+    .. _Recording Performance Data:
+       https://developer.apple.com/documentation/os/logging/recording_performance_data
+    """
+    mode_normalized = mode.lower().replace(" ", "")
+    torch._C._mps_profilerStartTrace(mode_normalized, wait_until_completed)
+
+def stop():
+    r"""Stops generating OS Signpost tracing from MPS backend."""
+    torch._C._mps_profilerStopTrace()
+
[email protected]
+def profile(mode: str = "interval", wait_until_completed: bool = False):
+    r"""Context Manager to enabling generating OS Signpost tracing from MPS backend.
+
+    Args:
+        mode(str): OS Signpost tracing mode; can be "interval", "event",
+            or both "interval,event".
+            The interval mode traces the duration of execution of the operations,
+            whereas event mode marks the completion of executions.
+            See document `Recording Performance Data`_ for more info.
+        wait_until_completed(bool): Waits until the MPS stream has completed
+            executing each encoded GPU operation. This helps generate single
+            dispatches on the trace's timeline.
+            Note that enabling this option negatively affects performance.
+
+    .. _Recording Performance Data:
+       https://developer.apple.com/documentation/os/logging/recording_performance_data
+    """
+    try:
+        start(mode, wait_until_completed)
+        yield
+    finally:
+        stop()
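
For completeness, a hedged sketch of the combined tracing mode described in the docstrings above: "interval,event" emits both duration intervals and completion events as OS Signposts, viewable in Xcode Instruments. The workload shown is an illustrative assumption; wait_until_completed=True serializes each GPU dispatch, so expect slower runs.

    import torch

    if torch.backends.mps.is_available():
        # Combined mode: duration intervals plus completion events.
        # Serializing dispatches keeps them distinct on the trace timeline,
        # at a performance cost.
        with torch.mps.profiler.profile(mode="interval,event", wait_until_completed=True):
            a = torch.randn(128, 128, device="mps")
            b = torch.nn.functional.gelu(a @ a.t())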