[MPS] Introduce torch.mps.Event() APIs (#102121)
- Implement `MPSEventPool` to recycle events.
- Implement python bindings with `torch.mps.Event` class using the MPSEventPool backend. The current member functions of the Event class are `record()`, `wait()`, `synchronize()`, `query()`, and `elapsed_time()`.
- Add API to measure elapsed time between two event recordings.
- Added documentation for Event class to `mps.rst`.
- Added test case to `test_mps.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102121
Approved by: https://github.com/albanD, https://github.com/kulinseth
diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h
index a399d71..690fee4 100644
--- a/aten/src/ATen/detail/MPSHooksInterface.h
+++ b/aten/src/ATen/detail/MPSHooksInterface.h
@@ -20,7 +20,7 @@
// this fails the implementation if MPSHooks functions are called, but
// MPS backend is not present.
#define FAIL_MPSHOOKS_FUNC(func) \
- TORCH_CHECK(false, "Cannot execute ", func ,"() without MPS backend.");
+ TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend.");
virtual ~MPSHooksInterface() = default;
@@ -64,16 +64,35 @@
virtual void setMemoryFraction(double /*ratio*/) const {
FAIL_MPSHOOKS_FUNC(__func__);
}
-
virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const {
FAIL_MPSHOOKS_FUNC(__func__);
}
-
virtual void profilerStopTrace() const {
FAIL_MPSHOOKS_FUNC(__func__);
}
+ virtual uint32_t acquireEvent(bool enable_timing) const {
+ FAIL_MPSHOOKS_FUNC(__func__);
+ }
+ virtual void releaseEvent(uint32_t event_id) const {
+ FAIL_MPSHOOKS_FUNC(__func__);
+ }
+ virtual void recordEvent(uint32_t event_id) const {
+ FAIL_MPSHOOKS_FUNC(__func__);
+ }
+ virtual void waitForEvent(uint32_t event_id) const {
+ FAIL_MPSHOOKS_FUNC(__func__);
+ }
+ virtual void synchronizeEvent(uint32_t event_id) const {
+ FAIL_MPSHOOKS_FUNC(__func__);
+ }
+ virtual bool queryEvent(uint32_t event_id) const {
+ FAIL_MPSHOOKS_FUNC(__func__);
+ }
+ virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
+ FAIL_MPSHOOKS_FUNC(__func__);
+ }
- #undef FAIL_MPSHOOKS_FUNC
+ #undef FAIL_MPSHOOKS_FUNC
};
struct TORCH_API MPSHooksArgs {};
diff --git a/aten/src/ATen/mps/MPSEvent.h b/aten/src/ATen/mps/MPSEvent.h
new file mode 100644
index 0000000..880ff1c
--- /dev/null
+++ b/aten/src/ATen/mps/MPSEvent.h
@@ -0,0 +1,100 @@
+// Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <ATen/mps/MPSStream.h>
+#include <ctime>
+#include <stack>
+
+namespace at::mps {
+
+// NOTE: don't create instances of this class directly.
+// Use MPSEventPool to acquire instances of MPSEvent.
+class MPSEvent {
+public:
+ explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
+ ~MPSEvent();
+
+ // records an event on the stream
+ void record(bool needsLock, bool syncEvent = false);
+ // makes all future work submitted to the stream wait for this event.
+ bool wait(bool needsLock, bool syncEvent = false);
+ // schedules a notifyListener callback for the event.
+ bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
+ // checks if events are already signaled.
+ bool query() const;
+ // blocks the CPU thread until all the GPU work that was scheduled
+ // prior to recording this event is completed.
+ bool synchronize();
+ // resets this event with new parameters in case it gets reused from the event pool
+ void reset(MPSStream* stream, bool enable_timing);
+ // returns the unique ID of the event instance
+ id_t getID() const { return m_id; }
+ // returns the completion timestamp of the event
+ uint64_t getCompletionTime() const { return m_completion_time; }
+ // if already recorded, waits for cpu_sync_cv to be signaled
+ void waitForCpuSync();
+
+private:
+ id_t m_id;
+ // enables measuring the completion time of the notifyListener of this event
+ bool m_enable_timing;
+ uint64_t m_signalCounter = 0;
+ MPSStream* m_stream = nullptr;
+ MTLSharedEvent_t m_event = nullptr;
+ MTLSharedEventListener* m_listener = nullptr;
+ // used to sync the events created on this Stream with CPU
+ std::mutex m_cpu_sync_mutex{};
+ std::condition_variable m_cpu_sync_cv{};
+ // CondVar predicate to sync the events created on this Stream with CPU
+ bool m_cpu_sync_completed = false;
+ // used to compute elapsed time
+ uint64_t m_completion_time = 0;
+
+ void recordLocked(bool syncEvent);
+ bool waitLocked(bool syncEvent);
+ bool notifyLocked(MTLSharedEventNotificationBlock block);
+ void notifyCpuSync();
+ static uint64_t getTime() {
+ return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
+ }
+};
+
+typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;
+
+class MPSEventPool {
+public:
+ explicit MPSEventPool(MPSStream* default_stream);
+ ~MPSEventPool();
+
+ MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
+ void emptyCache();
+
+ // these are mainly used for MPSHooks and torch.mps.Event() bindings
+ id_t acquireEvent(bool enable_timing);
+ void releaseEvent(id_t event_id);
+ void recordEvent(id_t event_id, bool syncEvent);
+ void waitForEvent(id_t event_id, bool syncEvent);
+ void synchronizeEvent(id_t event_id);
+ bool queryEvent(id_t event_id);
+ // returns elapsed time between two recorded events in milliseconds
+ double elapsedTime(id_t start_event_id, id_t end_event_id);
+
+private:
+ MPSStream* m_default_stream = nullptr;
+ std::recursive_mutex m_mutex;
+ std::stack<std::unique_ptr<MPSEvent>> m_pool{};
+ // dictionary to associate event IDs with event objects
+ // used to retain in-use events out of the pool
+ // for torch.mps.Event() bindings.
+ std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
+ uint64_t m_event_counter = 0;
+ std::function<void(MPSEvent*)> m_default_deleter;
+
+ MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
+};
+
+// shared_ptr is used to get MPSEventPool destroyed after dependent instances
+std::shared_ptr<MPSEventPool> getMPSEventPool();
+
+} // namespace at::mps
diff --git a/aten/src/ATen/mps/MPSEvent.mm b/aten/src/ATen/mps/MPSEvent.mm
new file mode 100644
index 0000000..ac46461
--- /dev/null
+++ b/aten/src/ATen/mps/MPSEvent.mm
@@ -0,0 +1,257 @@
+// Copyright © 2023 Apple Inc.
+
+#include <ATen/mps/MPSEvent.h>
+
+namespace at::mps {
+
+MPSEvent::MPSEvent(id_t ID, MPSStream* stream, bool enable_timing)
+ : m_id(ID), m_enable_timing(enable_timing), m_stream(stream), m_event([stream->device() newSharedEvent]) {}
+
+MPSEvent::~MPSEvent() {
+ if (m_event) {
+ [m_event release];
+ m_event = nil;
+ }
+ if (m_listener) {
+ [m_listener release];
+ m_listener = nil;
+ }
+}
+
+void MPSEvent::recordLocked(bool syncEvent) {
+ // active encoders must end before encoding or waiting
+ m_stream->endKernelCoalescing();
+ ++m_signalCounter;
+ if (m_enable_timing) {
+ notifyLocked(^(id<MTLSharedEvent>, uint64_t) {
+ m_completion_time = getTime();
+ notifyCpuSync();
+ });
+ }
+ id<MTLCommandBuffer> commandBuffer = m_stream->commandBuffer();
+ [commandBuffer encodeSignalEvent:m_event value:m_signalCounter];
+ if (syncEvent) {
+ m_stream->synchronize(SyncType::COMMIT);
+ }
+}
+
+bool MPSEvent::waitLocked(bool syncEvent) {
+ // nothing to wait for if the event was never recorded or is already signaled
+ if (m_event.signaledValue >= m_signalCounter) {
+ return false;
+ }
+ // active encoders must end before encoding or waiting
+ m_stream->endKernelCoalescing();
+ id<MTLCommandBuffer> commandBuffer = m_stream->commandBuffer();
+ [commandBuffer encodeWaitForEvent:m_event value:m_signalCounter];
+ if (syncEvent) {
+ m_stream->synchronize(SyncType::COMMIT);
+ }
+ return true;
+}
+
+bool MPSEvent::notifyLocked(MTLSharedEventNotificationBlock block) {
+ // nothing to notify if the event was never recorded or is already signaled
+ if (m_event.signaledValue >= m_signalCounter) {
+ return false;
+ }
+ if (!m_listener) {
+ m_listener = [[MTLSharedEventListener alloc] init];
+ }
+ [m_event notifyListener:m_listener atValue:m_signalCounter block:block];
+ return true;
+}
+
+void MPSEvent::record(bool needsLock, bool syncEvent) {
+ if (!needsLock) {
+ recordLocked(syncEvent);
+ return;
+ }
+ dispatch_sync(m_stream->queue(), ^() {
+ @autoreleasepool {
+ recordLocked(syncEvent);
+ }
+ });
+}
+
+bool MPSEvent::wait(bool needsLock, bool syncEvent) {
+ __block bool waited = false;
+ if (!needsLock) {
+ return waitLocked(syncEvent);
+ }
+ dispatch_sync(m_stream->queue(), ^() {
+ @autoreleasepool {
+ waited = waitLocked(syncEvent);
+ }
+ });
+ return waited;
+}
+
+bool MPSEvent::notify(bool needsLock, MTLSharedEventNotificationBlock block) {
+ if (!needsLock) {
+ return notifyLocked(block);
+ }
+ __block bool scheduledNotify = false;
+ dispatch_sync(m_stream->queue(), ^() {
+ @autoreleasepool {
+ scheduledNotify = notifyLocked(block);
+ }
+ });
+ return scheduledNotify;
+}
+
+void MPSEvent::notifyCpuSync() {
+ std::lock_guard<std::mutex> lock(m_cpu_sync_mutex);
+ m_cpu_sync_completed = true;
+ m_cpu_sync_cv.notify_one();
+}
+
+// Blocks the calling thread until notifyCpuSync() has set the completion
+// flag, then clears the flag again before returning.
+void MPSEvent::waitForCpuSync() {
+  std::unique_lock<std::mutex> guard(m_cpu_sync_mutex);
+  // re-check the flag after every wakeup (same semantics as the predicate
+  // overload of condition_variable::wait)
+  while (!m_cpu_sync_completed) {
+    m_cpu_sync_cv.wait(guard);
+  }
+  m_cpu_sync_completed = false;
+}
+
+bool MPSEvent::synchronize() {
+ bool scheduledNotify = notifyLocked(^(id<MTLSharedEvent>, uint64_t) {
+ m_completion_time = getTime();
+ notifyCpuSync();
+ });
+
+ if (scheduledNotify) {
+ waitForCpuSync();
+ return true;
+ }
+ return false;
+}
+
+// Reports whether the latest recording of this event has been signaled.
+bool MPSEvent::query() const {
+  // an event that was never recorded is reported as not completed
+  if (m_signalCounter == 0) {
+    return false;
+  }
+  return m_event.signaledValue >= m_signalCounter;
+}
+
+// Re-initializes an event before it is handed out again by the pool.
+// The signal counter and the shared event's signaled value are rewound only
+// when the event is bound to a different stream; the timing-related state is
+// always cleared.
+void MPSEvent::reset(MPSStream* stream, bool enable_timing) {
+  if (stream != m_stream) {
+    m_signalCounter = 0;
+    m_event.signaledValue = 0;
+    m_stream = stream;
+  }
+  // reset record time
+  m_completion_time = 0;
+  m_enable_timing = enable_timing;
+  m_cpu_sync_completed = false;
+}
+
+//-----------------------------------------------------------------
+// MPSEventPool
+//-----------------------------------------------------------------
+
+// Builds the pool around `default_stream` and installs the deleter that is
+// attached to every MPSEventPtr handed out by the pool.
+MPSEventPool::MPSEventPool(MPSStream* default_stream) : m_default_stream(default_stream) {
+  // instead of destroying a released event, take ownership of it again and
+  // park it in the pool for reuse
+  m_default_deleter = [this](MPSEvent* released_event) {
+    std::lock_guard<std::recursive_mutex> guard(m_mutex);
+    m_pool.emplace(released_event);
+  };
+}
+
+MPSEventPool::~MPSEventPool() {
+ emptyCache();
+}
+
+// Hands out an event bound to `stream` (the pool's default stream when
+// `stream` is null), reusing a cached instance when one is available.
+// The deleter attached to the returned pointer pushes the event back into
+// the pool instead of destroying it.
+MPSEventPtr MPSEventPool::acquireEvent(bool enable_timing, MPSStream* stream) {
+  if (!stream) {
+    stream = m_default_stream;
+  }
+  // Hold the lock for the whole call: m_event_counter is shared state as well,
+  // and incrementing it after the lock was dropped raced between callers.
+  std::lock_guard<std::recursive_mutex> lock(m_mutex);
+  if (!m_pool.empty()) {
+    auto event = m_pool.top().release();
+    m_pool.pop();
+    event->reset(stream, enable_timing);
+    return MPSEventPtr(event, m_default_deleter);
+  }
+  auto new_event = std::make_unique<MPSEvent>(++m_event_counter, stream, enable_timing);
+  return MPSEventPtr(new_event.release(), m_default_deleter);
+}
+
+void MPSEventPool::emptyCache() {
+ std::lock_guard<std::recursive_mutex> lock(m_mutex);
+ while (!m_pool.empty()) {
+ m_pool.pop();
+ }
+}
+
+id_t MPSEventPool::acquireEvent(bool enable_timing) {
+ std::lock_guard<std::recursive_mutex> lock(m_mutex);
+ MPSEventPtr event = acquireEvent(enable_timing, nullptr);
+ TORCH_INTERNAL_ASSERT(event);
+ id_t event_id = event->getID();
+ m_in_use_events.emplace(event_id, std::move(event));
+ return event_id;
+}
+
+// Drops the pool's entry for an in-use event. Destroying the stored
+// MPSEventPtr runs its deleter, which returns the event to the pool.
+void MPSEventPool::releaseEvent(id_t event_id) {
+  std::lock_guard<std::recursive_mutex> guard(m_mutex);
+  const auto entry = m_in_use_events.find(event_id);
+  TORCH_CHECK(entry != m_in_use_events.end(), "Invalid Event ID: ", event_id);
+  // returns the event back to the MPSEventPool
+  m_in_use_events.erase(entry);
+}
+
+// Records the in-use event identified by `event_id`. needsLock is true, so
+// MPSEvent::record() dispatches onto the stream's queue itself.
+void MPSEventPool::recordEvent(id_t event_id, bool syncEvent) {
+  getInUseEvent(event_id)->record(/*needsLock*/ true, syncEvent);
+}
+
+// Encodes a wait on the in-use event identified by `event_id`. needsLock is
+// true, so MPSEvent::wait() dispatches onto the stream's queue itself; its
+// boolean result is not propagated.
+void MPSEventPool::waitForEvent(id_t event_id, bool syncEvent) {
+  getInUseEvent(event_id)->wait(/*needsLock*/ true, syncEvent);
+}
+
+// Forwards to MPSEvent::synchronize() for the in-use event identified by
+// `event_id`; the boolean result is not propagated.
+void MPSEventPool::synchronizeEvent(id_t event_id) {
+  getInUseEvent(event_id)->synchronize();
+}
+
+// Returns MPSEvent::query() for the in-use event identified by `event_id`.
+bool MPSEventPool::queryEvent(id_t event_id) {
+  return getInUseEvent(event_id)->query();
+}
+
+double MPSEventPool::elapsedTime(id_t start_event_id, id_t end_event_id) {
+ // first make sure notifyListeners are called to capture events' completion times
+ dispatch_sync(m_default_stream->queue(), ^() {
+ m_default_stream->synchronize(SyncType::COMMIT_AND_WAIT);
+ });
+ std::lock_guard<std::recursive_mutex> lock(m_mutex);
+ MPSEvent* start_event = getInUseEvent(start_event_id, false);
+ MPSEvent* end_event = getInUseEvent(end_event_id, false);
+ // the notify is called on a separate thread, so this waits for that
+ end_event->waitForCpuSync();
+ const uint64_t start_time = start_event->getCompletionTime();
+ const uint64_t end_time = end_event->getCompletionTime();
+
+ TORCH_CHECK(start_time > 0 && end_time > 0, "Events were not created with argument 'enable_timing=True'");
+ TORCH_CHECK(
+ end_time > start_time, "End event ", end_event_id, " was not recorded after start event ", start_event_id);
+ return double(end_time - start_time) * 1e-6;
+}
+
+// Looks up the event registered under `event_id` in m_in_use_events and
+// returns a non-owning pointer to it. Fails with "Invalid Event ID" when the
+// ID is unknown. Pass locked = false when the caller already holds m_mutex.
+MPSEvent* MPSEventPool::getInUseEvent(id_t event_id, bool locked) {
+  // RAII lock: TORCH_CHECK below throws on an unknown ID, and a manual
+  // lock()/unlock() pair would then leave m_mutex held.
+  std::unique_lock<std::recursive_mutex> lock(m_mutex, std::defer_lock);
+  if (locked) {
+    lock.lock();
+  }
+  // single lookup; operator[] would default-insert on a missing key
+  const auto it = m_in_use_events.find(event_id);
+  TORCH_CHECK(it != m_in_use_events.end(), "Invalid Event ID: ", event_id);
+  return it->second.get();
+}
+
+// Returns the shared event pool; it is created on the first call and bound to
+// the default MPS stream.
+std::shared_ptr<MPSEventPool> getMPSEventPool() {
+  static auto shared_pool = std::make_shared<MPSEventPool>(getDefaultMPSStream());
+  return shared_pool;
+}
+
+} // namespace at::mps
diff --git a/aten/src/ATen/mps/MPSGuardImpl.h b/aten/src/ATen/mps/MPSGuardImpl.h
index e122bf3..dd17d0e 100644
--- a/aten/src/ATen/mps/MPSGuardImpl.h
+++ b/aten/src/ATen/mps/MPSGuardImpl.h
@@ -6,6 +6,7 @@
#include <c10/util/Exception.h>
#include <ATen/Context.h>
#include <ATen/mps/MPSStream.h>
+#include <ATen/mps/MPSEvent.h>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
@@ -26,6 +27,8 @@
namespace at {
namespace mps {
+typedef MPSEvent* mpsEvent_t;
+
// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
// https://github.com/pytorch/pytorch/issues/77170
struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
diff --git a/aten/src/ATen/mps/MPSGuardImpl.mm b/aten/src/ATen/mps/MPSGuardImpl.mm
index 9c20471..0a1ba6b 100644
--- a/aten/src/ATen/mps/MPSGuardImpl.mm
+++ b/aten/src/ATen/mps/MPSGuardImpl.mm
@@ -28,19 +28,19 @@
auto mps_event = static_cast<mpsEvent_t>(*event);
MPSStream mps_stream{stream};
- mps_event->recordEvent(true);
+ mps_event->record(true);
}
void MPSGuardImpl::block(void* event, const Stream& stream) const {
auto mps_event = static_cast<mpsEvent_t>(event);
MPSStream mps_stream{stream};
- mps_event->waitForEvent(true);
+ mps_event->wait(true, false);
}
bool MPSGuardImpl::queryEvent(void* event) const {
auto mps_event = static_cast<mpsEvent_t>(event);
- return mps_event->queryEvent();
+ return mps_event->query();
}
}
diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h
index dc4a0a3..e2a5e0c 100644
--- a/aten/src/ATen/mps/MPSHooks.h
+++ b/aten/src/ATen/mps/MPSHooks.h
@@ -4,6 +4,7 @@
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/Generator.h>
+#include <ATen/mps/MPSEvent.h>
#include <c10/util/Optional.h>
namespace at { namespace mps {
@@ -32,8 +33,19 @@
size_t getCurrentAllocatedMemory() const override;
size_t getDriverAllocatedMemory() const override;
void setMemoryFraction(double ratio) const override;
+
+ // MPSProfiler interface
void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
void profilerStopTrace() const override;
+
+ // MPSEvent interface
+ uint32_t acquireEvent(bool enable_timing) const override;
+ void releaseEvent(uint32_t event_id) const override;
+ void recordEvent(uint32_t event_id) const override;
+ void waitForEvent(uint32_t event_id) const override;
+ void synchronizeEvent(uint32_t event_id) const override;
+ bool queryEvent(uint32_t event_id) const override;
+ double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
};
}} // at::mps
diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm
index 62debfd..d94b985 100644
--- a/aten/src/ATen/mps/MPSHooks.mm
+++ b/aten/src/ATen/mps/MPSHooks.mm
@@ -84,6 +84,34 @@
at::mps::getMPSProfiler().StopTrace();
}
+uint32_t MPSHooks::acquireEvent(bool enable_timing) const {
+ return at::mps::getMPSEventPool()->acquireEvent(enable_timing);
+}
+
+void MPSHooks::releaseEvent(uint32_t event_id) const {
+ at::mps::getMPSEventPool()->releaseEvent(event_id);
+}
+
+void MPSHooks::recordEvent(uint32_t event_id) const {
+ at::mps::getMPSEventPool()->recordEvent(event_id, /* syncEvent*/ true);
+}
+
+void MPSHooks::waitForEvent(uint32_t event_id) const {
+ at::mps::getMPSEventPool()->waitForEvent(event_id, /* syncEvent*/ true);
+}
+
+void MPSHooks::synchronizeEvent(uint32_t event_id) const {
+ at::mps::getMPSEventPool()->synchronizeEvent(event_id);
+}
+
+bool MPSHooks::queryEvent(uint32_t event_id) const {
+ return at::mps::getMPSEventPool()->queryEvent(event_id);
+}
+
+double MPSHooks::elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
+ return at::mps::getMPSEventPool()->elapsedTime(start_event_id, end_event_id);
+}
+
using at::MPSHooksRegistry;
using at::RegistererMPSHooksRegistry;
diff --git a/aten/src/ATen/mps/MPSStream.h b/aten/src/ATen/mps/MPSStream.h
index 360d950..aa11e65 100644
--- a/aten/src/ATen/mps/MPSStream.h
+++ b/aten/src/ATen/mps/MPSStream.h
@@ -131,38 +131,5 @@
MPSStreamImpl();
};
-
-//-----------------------------------------------------------------
-// MPSEvent
-//-----------------------------------------------------------------
-
-struct TORCH_API MPSEvent
-{
- // for a new instance of MPSEvent, sometimes we want an empty shell and don't
- // necessarily want to create events or listeners. So we defer initialization
- // until we actually use the event (e.g., record, notify, etc.)
- MPSEvent(bool deferInitialization = true);
- ~MPSEvent();
- MTLSharedEvent_t event() const {return _event; }
-
- void recordEvent(bool syncEvent = false);
- void waitForEvent(bool syncEvent = false); // waits on the cpu
- void notifyEvent(MTLSharedEventNotificationBlock block);
- bool queryEvent() const;
- uint64_t getCurrentValue() const { return _signalCounter; }
- void setCurrentValue(uint64_t currValue) { _signalCounter = currValue; }
-private:
- bool is_initialized;
- uint64_t _signalCounter;
- MPSStream* _stream;
- MTLSharedEvent_t _event;
- MTLSharedEventListener* _listener;
-
- void initialize();
-};
-
-typedef MPSEvent* mpsEvent_t;
-
-
} // namespace mps
} // namespace at
diff --git a/aten/src/ATen/mps/MPSStream.mm b/aten/src/ATen/mps/MPSStream.mm
index f78e303..e9bb648 100644
--- a/aten/src/ATen/mps/MPSStream.mm
+++ b/aten/src/ATen/mps/MPSStream.mm
@@ -264,77 +264,5 @@
return MPSStreamImpl::getInstance();
}
-//-----------------------------------------------------------------
-// MPSEvent
-//-----------------------------------------------------------------
-
-MPSEvent::MPSEvent(bool deferInitialization)
- : is_initialized(false), _signalCounter(0), _stream(nil), _event(nil), _listener(nil) {
- if (!deferInitialization) {
- initialize();
- }
-}
-
-MPSEvent::~MPSEvent() {
- if (_event) {
- [_event release];
- _event = nil;
- }
- if (_listener) {
- [_listener release];
- _listener = nil;
- }
-}
-
-void MPSEvent::initialize() {
- _stream = getDefaultMPSStream();
- _event = [_stream->device() newSharedEvent];
- _listener = [[MTLSharedEventListener alloc] init];
- is_initialized = true;
-}
-
-void MPSEvent::recordEvent(bool syncEvent) {
- if (!is_initialized)
- initialize();
-
- dispatch_sync(_stream->queue(), ^() {
- @autoreleasepool {
- ++_signalCounter;
- id<MTLCommandBuffer> commandBuffer = _stream->commandBuffer();
- [commandBuffer encodeSignalEvent:_event value:_signalCounter];
- if (syncEvent)
- _stream->synchronize(SyncType::COMMIT);
- }
- });
-}
-
-void MPSEvent::waitForEvent(bool syncEvent) {
- TORCH_INTERNAL_ASSERT(is_initialized);
- dispatch_sync(_stream->queue(), ^() {
- @autoreleasepool {
- id<MTLCommandBuffer> commandBuffer = _stream->commandBuffer();
- [commandBuffer encodeWaitForEvent:_event value:_signalCounter];
- if (syncEvent)
- _stream->synchronize(SyncType::COMMIT);
- }
- });
-}
-
-void MPSEvent::notifyEvent(MTLSharedEventNotificationBlock block) {
- if (!is_initialized)
- initialize();
- dispatch_sync(_stream->queue(), ^() {
- @autoreleasepool {
- ++_signalCounter;
- [_event notifyListener:_listener atValue:_signalCounter block:block];
- }
- });
-}
-
-bool MPSEvent::queryEvent() const {
- // return false if not recorded or signaled yet
- return _signalCounter && (_event.signaledValue >= _signalCounter);
-}
-
} // namespace mps
} // namespace at
diff --git a/docs/source/mps.rst b/docs/source/mps.rst
index 7ed30f9..03ec57c 100644
--- a/docs/source/mps.rst
+++ b/docs/source/mps.rst
@@ -26,3 +26,11 @@
profiler.start
profiler.stop
profiler.profile
+
+MPS Event
+------------
+.. autosummary::
+ :toctree: generated
+ :nosignatures:
+
+ event.Event
diff --git a/test/test_mps.py b/test/test_mps.py
index 19e2273..63e2f5f 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -7493,6 +7493,18 @@
x = net1(x)
torch.mps.profiler.stop()
+ def test_mps_event_module(self):
+ startEvent = torch.mps.Event(enable_timing=True)
+ startEvent.record()
+ net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
+ .to(device='mps', dtype=torch.float)
+ x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
+ x = net1(x)
+ endEvent = torch.mps.Event(enable_timing=True)
+ endEvent.record()
+ elapsedTime = startEvent.elapsed_time(endEvent)
+ self.assertTrue(elapsedTime > 0.0)
+
def test_jit_save_load(self):
m = torch.nn.Module()
m.x = torch.rand(3, 3, device='mps')
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index d74d887..dd07b7f 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1508,6 +1508,14 @@
def _mps_is_on_macos_13_or_newer(minor: _int) -> _bool: ...
def _mps_profilerStartTrace(mode: str, wait_until_completed: _bool) -> None: ...
def _mps_profilerStopTrace() -> None: ...
+def _mps_acquireEvent(enable_timing: _bool) -> _int: ...
+def _mps_releaseEvent(event_id: _int) -> None: ...
+def _mps_recordEvent(event_id: _int) -> None: ...
+def _mps_waitForEvent(event_id: _int) -> None: ...
+def _mps_synchronizeEvent(event_id: _int) -> None: ...
+def _mps_queryEvent(event_id: _int) -> _bool: ...
+def _mps_elapsedTimeOfEvents(start_event_id: _int, end_event_id: _int) -> _float: ...
+
# Defined in torch/csrc/cuda/Module.cpp
def _cuda_getCurrentStream(device: _int) -> Tuple: ...
diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp
index a5021bb..87e9900 100644
--- a/torch/csrc/mps/Module.cpp
+++ b/torch/csrc/mps/Module.cpp
@@ -105,7 +105,7 @@
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
- return PyLong_FromUnsignedLongLong(
+ return THPUtils_packUInt64(
at::detail::getMPSHooks().getCurrentAllocatedMemory());
END_HANDLE_TH_ERRORS
}
@@ -114,7 +114,7 @@
PyObject* _unused,
PyObject* noargs) {
HANDLE_TH_ERRORS
- return PyLong_FromUnsignedLongLong(
+ return THPUtils_packUInt64(
at::detail::getMPSHooks().getDriverAllocatedMemory());
END_HANDLE_TH_ERRORS
}
@@ -146,6 +146,74 @@
END_HANDLE_TH_ERRORS
}
+static PyObject* MPSModule_acquireEvent(PyObject* _unused, PyObject* args) {
+ HANDLE_TH_ERRORS
+ const bool enable_timing = THPUtils_unpackBool(args);
+ return THPUtils_packUInt32(
+ at::detail::getMPSHooks().acquireEvent(enable_timing));
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_releaseEvent(PyObject* _unused, PyObject* args) {
+ HANDLE_TH_ERRORS
+ const uint32_t event_id = THPUtils_unpackUInt32(args);
+ at::detail::getMPSHooks().releaseEvent(event_id);
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_recordEvent(PyObject* _unused, PyObject* args) {
+ HANDLE_TH_ERRORS
+ const uint32_t event_id = THPUtils_unpackUInt32(args);
+ at::detail::getMPSHooks().recordEvent(event_id);
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_waitForEvent(PyObject* _unused, PyObject* args) {
+ HANDLE_TH_ERRORS
+ const uint32_t event_id = THPUtils_unpackUInt32(args);
+ at::detail::getMPSHooks().waitForEvent(event_id);
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_synchronizeEvent(PyObject* _unused, PyObject* args) {
+ HANDLE_TH_ERRORS
+ const uint32_t event_id = THPUtils_unpackUInt32(args);
+ at::detail::getMPSHooks().synchronizeEvent(event_id);
+ Py_RETURN_NONE;
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_queryEvent(PyObject* _unused, PyObject* args) {
+ HANDLE_TH_ERRORS
+ const uint32_t event_id = THPUtils_unpackUInt32(args);
+
+ if (at::detail::getMPSHooks().queryEvent(event_id)) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+ END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_elapsedTimeOfEvents(
+ PyObject* _unused,
+ PyObject* args) {
+ HANDLE_TH_ERRORS
+ PyObject* start_event_o = nullptr;
+ PyObject* end_event_o = nullptr;
+ if (!PyArg_ParseTuple(args, "OO", &start_event_o, &end_event_o)) {
+ return nullptr;
+ }
+ const uint32_t start_event_id = THPUtils_unpackUInt32(start_event_o);
+ const uint32_t end_event_id = THPUtils_unpackUInt32(end_event_o);
+ return PyFloat_FromDouble(at::detail::getMPSHooks().elapsedTimeOfEvents(
+ start_event_id, end_event_id));
+ END_HANDLE_TH_ERRORS
+}
+
// NOLINTNEXTLINE(modernize-avoid-c-arrays,
// cppcoreguidelines-avoid-non-const-global-variables,
// cppcoreguidelines-avoid-c-arrays)
@@ -182,6 +250,16 @@
MPSModule_profilerStopTrace,
METH_NOARGS,
nullptr},
+ {"_mps_acquireEvent", MPSModule_acquireEvent, METH_O, nullptr},
+ {"_mps_releaseEvent", MPSModule_releaseEvent, METH_O, nullptr},
+ {"_mps_recordEvent", MPSModule_recordEvent, METH_O, nullptr},
+ {"_mps_waitForEvent", MPSModule_waitForEvent, METH_O, nullptr},
+ {"_mps_synchronizeEvent", MPSModule_synchronizeEvent, METH_O, nullptr},
+ {"_mps_queryEvent", MPSModule_queryEvent, METH_O, nullptr},
+ {"_mps_elapsedTimeOfEvents",
+ MPSModule_elapsedTimeOfEvents,
+ METH_VARARGS,
+ nullptr},
{nullptr}};
PyMethodDef* python_functions() {
diff --git a/torch/mps/__init__.py b/torch/mps/__init__.py
index 13ba91d..52cda4f 100644
--- a/torch/mps/__init__.py
+++ b/torch/mps/__init__.py
@@ -113,6 +113,7 @@
from . import profiler
+from .event import Event
__all__ = [
"get_rng_state",
@@ -124,5 +125,6 @@
"set_per_process_memory_fraction",
"current_allocated_memory",
"driver_allocated_memory",
+ "Event",
"profiler",
]
diff --git a/torch/mps/event.py b/torch/mps/event.py
new file mode 100644
index 0000000..a206b64
--- /dev/null
+++ b/torch/mps/event.py
@@ -0,0 +1,45 @@
+import torch
+
+
+class Event:
+    r"""Wrapper around an MPS event.
+
+    MPS events are synchronization markers that can be used to monitor the
+    device's progress, to accurately measure timing, and to synchronize MPS streams.
+
+    Args:
+        enable_timing (bool, optional): indicates if the event should measure time
+            (default: ``False``)
+    """
+
+    def __init__(self, enable_timing=False):
+        # Set a sentinel first: if acquiring the event raises, __del__ still
+        # runs on the half-built object and must find the attribute.
+        self.__eventId = 0
+        self.__eventId = torch._C._mps_acquireEvent(enable_timing)
+
+    def __del__(self):
+        # checks if torch._C is already destroyed
+        if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0:
+            torch._C._mps_releaseEvent(self.__eventId)
+
+    def record(self):
+        r"""Records the event in the default stream."""
+        torch._C._mps_recordEvent(self.__eventId)
+
+    def wait(self):
+        r"""Makes all future work submitted to the default stream wait for this event."""
+        torch._C._mps_waitForEvent(self.__eventId)
+
+    def query(self):
+        r"""Returns True if all work currently captured by event has completed."""
+        return torch._C._mps_queryEvent(self.__eventId)
+
+    def synchronize(self):
+        r"""Waits until the completion of all work currently captured in this event.
+        This prevents the CPU thread from proceeding until the event completes.
+        """
+        torch._C._mps_synchronizeEvent(self.__eventId)
+
+    def elapsed_time(self, end_event):
+        r"""Returns the time elapsed in milliseconds after the event was
+        recorded and before the end_event was recorded.
+
+        Args:
+            end_event (Event): the event that marks the end of the interval.
+        """
+        return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId)