[MPS] Add support for MPSProfiler Python bindings (#101002)
- Added torch.mps.profiler.[start() and stop()] APIs with RST documentation
- Added test case in test_mps
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101002
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h
index 7d67d63..114abde 100644
--- a/aten/src/ATen/detail/MPSHooksInterface.h
+++ b/aten/src/ATen/detail/MPSHooksInterface.h
@@ -59,6 +59,14 @@
virtual void setMemoryFraction(double /*ratio*/) const {
AT_ERROR("Cannot execute setMemoryFraction() without MPS backend.");
}
+
+ virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const {
+ AT_ERROR("Cannot execute profilerStartTrace() without MPS backend.");
+ }
+
+ virtual void profilerStopTrace() const {
+ AT_ERROR("Cannot execute profilerStopTrace() without MPS backend.");
+ }
};
struct TORCH_API MPSHooksArgs {};
diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h
index 9e913b3..61ff9a8 100644
--- a/aten/src/ATen/mps/MPSHooks.h
+++ b/aten/src/ATen/mps/MPSHooks.h
@@ -21,6 +21,8 @@
size_t getCurrentAllocatedMemory() const override;
size_t getDriverAllocatedMemory() const override;
void setMemoryFraction(double ratio) const override;
+ void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
+ void profilerStopTrace() const override;
};
}} // at::mps
diff --git a/aten/src/ATen/mps/MPSHooks.cpp b/aten/src/ATen/mps/MPSHooks.mm
similarity index 81%
rename from aten/src/ATen/mps/MPSHooks.cpp
rename to aten/src/ATen/mps/MPSHooks.mm
index 1186f3a..f0c5289 100644
--- a/aten/src/ATen/mps/MPSHooks.cpp
+++ b/aten/src/ATen/mps/MPSHooks.mm
@@ -1,9 +1,11 @@
// Copyright © 2022 Apple Inc.
-#include <ATen/mps/MPSHooks.h>
+#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSGeneratorImpl.h>
-#include <ATen/mps/MPSAllocatorInterface.h>
+#include <ATen/mps/MPSHooks.h>
+#include <ATen/mps/MPSProfiler.h>
+#include <c10/util/Logging.h>
namespace at {
namespace mps {
@@ -28,7 +30,7 @@
case 3:
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
default:
- TORCH_WARN("Can't check whether running on 13.",minor,"+ returning one for 13.3+");
+ TORCH_WARN("Can't check whether running on 13.", minor, "+ returning one for 13.3+");
return is_macos_13_or_newer(MacOSVersion::MACOS_VER_13_3_PLUS);
}
}
@@ -61,6 +63,14 @@
at::mps::getIMPSAllocator()->setHighWatermarkRatio(ratio);
}
+void MPSHooks::profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const {
+ at::mps::getMPSProfiler().StartTrace(mode, waitUntilCompleted);
+}
+
+void MPSHooks::profilerStopTrace() const {
+ at::mps::getMPSProfiler().StopTrace();
+}
+
using at::MPSHooksRegistry;
using at::RegistererMPSHooksRegistry;
diff --git a/docs/source/mps.rst b/docs/source/mps.rst
index 91662aa..7ed30f9 100644
--- a/docs/source/mps.rst
+++ b/docs/source/mps.rst
@@ -15,4 +15,14 @@
empty_cache
set_per_process_memory_fraction
current_allocated_memory
- driver_allocated_memory
\ No newline at end of file
+ driver_allocated_memory
+
+MPS Profiler
+------------
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
+
+ profiler.start
+ profiler.stop
+ profiler.profile
diff --git a/test/test_mps.py b/test/test_mps.py
index 60ef643..f281c7a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -7325,6 +7325,25 @@
self.assertTrue(current_alloc_after > current_alloc_before)
self.assertTrue(driver_alloc_after > driver_alloc_before)
+    # to verify this test, run Xcode Instruments "Metal System Trace" or "Logging" tool,
+ # press record, then run this python test, and press stop. Next expand
+ # the os_signposts->PyTorchMPS and check if events or intervals are logged
+ # like this example:
+ # "aten::mps_convolution_backward_input:f32[1,128,6,6]:f32[128,64,3,3]:1,128,6,6 (id=G2, run=2)"
+ def test_mps_profiler_module(self):
+ with torch.mps.profiler.profile(mode="event", wait_until_completed=False) as p:
+ # just running some ops to capture the OS Signposts traces for profiling
+ net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
+ .to(device='mps', dtype=torch.float)
+ x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
+ x = net1(x)
+
+ torch.mps.profiler.start(mode="interval", wait_until_completed=True)
+ # just running some ops to capture the OS Signposts traces for profiling
+ x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
+ x = net1(x)
+ torch.mps.profiler.stop()
+
# Test random_, random_.to and random_.from
def test_random(self):
def helper(shape, low, high, dtype=torch.int32):
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index a2bc11e..0fb1976 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1450,6 +1450,8 @@
def _mps_driverAllocatedMemory() -> _int: ...
def _mps_is_available() -> _bool: ...
def _mps_is_on_macos_13_or_newer(minor: _int) -> _bool: ...
+def _mps_profilerStartTrace(mode: str, wait_until_completed: _bool) -> None: ...
+def _mps_profilerStopTrace() -> None: ...
# Defined in torch/csrc/cuda/Module.cpp
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp
index 0a1c45c..7433158 100644
--- a/torch/csrc/mps/Module.cpp
+++ b/torch/csrc/mps/Module.cpp
@@ -4,6 +4,7 @@
#include <torch/csrc/THP.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/python_numbers.h>
+#include <torch/csrc/utils/python_strings.h>
// pthread.h is included for tracking bad forks
#ifndef WIN32
@@ -94,8 +95,8 @@
THPUtils_checkDouble(args), "invalid argument to setMemoryFraction()");
double fraction = THPUtils_unpackDouble(args);
at::detail::getMPSHooks().setMemoryFraction(fraction);
- END_HANDLE_TH_ERRORS
Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
}
static PyObject* MPSModule_currentAllocatedMemory(
@@ -116,6 +117,33 @@
END_HANDLE_TH_ERRORS
}
+static PyObject* MPSModule_profilerStartTrace(
+ PyObject* _unused,
+ PyObject* args) {
+ HANDLE_TH_ERRORS
+ PyObject* mode_string_o = nullptr;
+ PyObject* wait_until_completed_string_o = nullptr;
+ if (!PyArg_ParseTuple(
+ args, "OO", &mode_string_o, &wait_until_completed_string_o)) {
+ return nullptr;
+ }
+ const std::string mode = THPUtils_unpackString(mode_string_o);
+ const bool waitUntilCompleted =
+ THPUtils_unpackBool(wait_until_completed_string_o);
+ at::detail::getMPSHooks().profilerStartTrace(mode, waitUntilCompleted);
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_profilerStopTrace(
+ PyObject* _unused,
+ PyObject* noargs) {
+ HANDLE_TH_ERRORS
+ at::detail::getMPSHooks().profilerStopTrace();
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
// NOLINTNEXTLINE(modernize-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables,
// cppcoreguidelines-avoid-c-arrays)
@@ -141,6 +169,14 @@
MPSModule_driverAllocatedMemory,
METH_NOARGS,
nullptr},
+ {"_mps_profilerStartTrace",
+ MPSModule_profilerStartTrace,
+ METH_VARARGS,
+ nullptr},
+ {"_mps_profilerStopTrace",
+ MPSModule_profilerStopTrace,
+ METH_NOARGS,
+ nullptr},
{nullptr}};
PyMethodDef* python_functions() {
diff --git a/torch/mps/__init__.py b/torch/mps/__init__.py
index 2ab9555..beb0072 100644
--- a/torch/mps/__init__.py
+++ b/torch/mps/__init__.py
@@ -98,7 +98,9 @@
"""
return torch._C._mps_driverAllocatedMemory()
+from . import profiler
+
__all__ = [
'get_rng_state', 'manual_seed', 'seed', 'set_rng_state', 'synchronize',
'empty_cache', 'set_per_process_memory_fraction', 'current_allocated_memory',
- 'driver_allocated_memory']
+ 'driver_allocated_memory', 'profiler']
diff --git a/torch/mps/profiler.py b/torch/mps/profiler.py
new file mode 100644
index 0000000..5ad94d0
--- /dev/null
+++ b/torch/mps/profiler.py
@@ -0,0 +1,55 @@
+import torch
+import contextlib
+
+__all__ = ["start", "stop", "profile"]
+
+def start(mode: str = "interval", wait_until_completed: bool = False) -> None:
+ r"""Start OS Signpost tracing from MPS backend.
+
+    The generated OS Signposts can be recorded and viewed in
+    Xcode Instruments Logging tool.
+
+ Args:
+        mode(str): OS Signpost tracing mode can be "interval", "event",
+ or both "interval,event".
+ The interval mode traces the duration of execution of the operations,
+ whereas event mode marks the completion of executions.
+ See document `Recording Performance Data`_ for more info.
+        wait_until_completed(bool): Waits until the MPS Stream completes
+            executing each encoded GPU operation. This helps generate single
+            dispatches on the trace's timeline.
+            Note that enabling this option would negatively affect performance.
+
+ .. _Recording Performance Data:
+ https://developer.apple.com/documentation/os/logging/recording_performance_data
+ """
+ mode_normalized = mode.lower().replace(" ", "")
+ torch._C._mps_profilerStartTrace(mode_normalized, wait_until_completed)
+
+def stop():
+ r"""Stops generating OS Signpost tracing from MPS backend."""
+ torch._C._mps_profilerStopTrace()
+
[email protected]
+def profile(mode: str = "interval", wait_until_completed: bool = False):
+    r"""Context manager to enable generating OS Signpost tracing from MPS backend.
+
+ Args:
+        mode(str): OS Signpost tracing mode can be "interval", "event",
+ or both "interval,event".
+ The interval mode traces the duration of execution of the operations,
+ whereas event mode marks the completion of executions.
+ See document `Recording Performance Data`_ for more info.
+        wait_until_completed(bool): Waits until the MPS Stream completes
+            executing each encoded GPU operation. This helps generate single
+            dispatches on the trace's timeline.
+            Note that enabling this option would negatively affect performance.
+
+ .. _Recording Performance Data:
+ https://developer.apple.com/documentation/os/logging/recording_performance_data
+ """
+ try:
+ start(mode, wait_until_completed)
+ yield
+ finally:
+ stop()