[MPS] Introduce torch.mps.Event() APIs (#102121)

- Implement `MPSEventPool` to recycle events.
- Implement Python bindings with the `torch.mps.Event` class using the MPSEventPool backend. The current member functions of the Event class are `record()`, `wait()`, `synchronize()`, `query()`, and `elapsed_time()`.
- Add an API to measure the elapsed time between two event recordings.
- Add documentation for the Event class to `mps.rst`.
- Add a test case to `test_mps.py`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102121
Approved by: https://github.com/albanD, https://github.com/kulinseth
diff --git a/aten/src/ATen/detail/MPSHooksInterface.h b/aten/src/ATen/detail/MPSHooksInterface.h
index a399d71..690fee4 100644
--- a/aten/src/ATen/detail/MPSHooksInterface.h
+++ b/aten/src/ATen/detail/MPSHooksInterface.h
@@ -20,7 +20,7 @@
   // this fails the implementation if MPSHooks functions are called, but
   // MPS backend is not present.
   #define FAIL_MPSHOOKS_FUNC(func) \
-    TORCH_CHECK(false, "Cannot execute ", func ,"() without MPS backend.");
+    TORCH_CHECK(false, "Cannot execute ", func, "() without MPS backend.");
 
   virtual ~MPSHooksInterface() = default;
 
@@ -64,16 +64,35 @@
   virtual void setMemoryFraction(double /*ratio*/) const {
     FAIL_MPSHOOKS_FUNC(__func__);
   }
-
   virtual void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const {
     FAIL_MPSHOOKS_FUNC(__func__);
   }
-
   virtual void profilerStopTrace() const {
     FAIL_MPSHOOKS_FUNC(__func__);
   }
+  virtual uint32_t acquireEvent(bool enable_timing) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void releaseEvent(uint32_t event_id) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void recordEvent(uint32_t event_id) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void waitForEvent(uint32_t event_id) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual void synchronizeEvent(uint32_t event_id) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual bool queryEvent(uint32_t event_id) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
+  virtual double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
+    FAIL_MPSHOOKS_FUNC(__func__);
+  }
 
-    #undef FAIL_MPSHOOKS_FUNC
+  #undef FAIL_MPSHOOKS_FUNC
 };
 
 struct TORCH_API MPSHooksArgs {};
diff --git a/aten/src/ATen/mps/MPSEvent.h b/aten/src/ATen/mps/MPSEvent.h
new file mode 100644
index 0000000..880ff1c
--- /dev/null
+++ b/aten/src/ATen/mps/MPSEvent.h
@@ -0,0 +1,100 @@
+//  Copyright © 2023 Apple Inc.
+
+#pragma once
+
+#include <ATen/mps/MPSStream.h>
+#include <ctime>
+#include <stack>
+
+namespace at::mps {
+
+// NOTE: don't create instances of this class directly.
+// Use MPSEventPool to acquire instances of MPSEvent.
+class MPSEvent {
+public:
+  explicit MPSEvent(id_t ID, MPSStream* stream, bool enable_timing);
+  ~MPSEvent();
+
+  // records an event on the stream
+  void record(bool needsLock, bool syncEvent = false);
+  // makes all future work submitted to the stream wait for this event.
+  bool wait(bool needsLock, bool syncEvent = false);
+  // schedules a notifyListener callback for the event.
+  bool notify(bool needsLock, MTLSharedEventNotificationBlock block);
+  // checks if events are already signaled.
+  bool query() const;
+  // blocks the CPU thread until all the GPU work that were scheduled
+  // prior to recording this event are completed.
+  bool synchronize();
+  // resets this event with new parameters in case it gets reused from the event pool
+  void reset(MPSStream* stream, bool enable_timing);
+  // returns the unique ID of the event instance
+  id_t getID() const { return m_id; }
+  // returns the completion timestamp of the event
+  uint64_t getCompletionTime() const { return m_completion_time; }
+  // if already recorded, waits for cpu_sync_cv to be signaled
+  void waitForCpuSync();
+
+private:
+  id_t m_id;
+  // enables measuring the completion time of the notifyListener of this event
+  bool m_enable_timing;
+  uint64_t m_signalCounter = 0;
+  MPSStream* m_stream = nullptr;
+  MTLSharedEvent_t m_event = nullptr;
+  MTLSharedEventListener* m_listener = nullptr;
+  // used to sync the events created on this Stream with CPU
+  std::mutex m_cpu_sync_mutex{};
+  std::condition_variable m_cpu_sync_cv{};
+  // CondVar predicate to sync the events created on this Stream with CPU
+  bool m_cpu_sync_completed = false;
+  // used to compute elapsed time
+  uint64_t m_completion_time = 0;
+
+  void recordLocked(bool syncEvent);
+  bool waitLocked(bool syncEvent);
+  bool notifyLocked(MTLSharedEventNotificationBlock block);
+  void notifyCpuSync();
+  static uint64_t getTime() {
+    return clock_gettime_nsec_np(CLOCK_MONOTONIC_RAW);
+  }
+};
+
+typedef std::unique_ptr<MPSEvent, std::function<void(MPSEvent*)>> MPSEventPtr;
+
+class MPSEventPool {
+public:
+  explicit MPSEventPool(MPSStream* default_stream);
+  ~MPSEventPool();
+
+  MPSEventPtr acquireEvent(bool enable_timing, MPSStream* stream);
+  void emptyCache();
+
+  // these are mainly used for MPSHooks and torch.mps.Event() bindings
+  id_t acquireEvent(bool enable_timing);
+  void releaseEvent(id_t event_id);
+  void recordEvent(id_t event_id, bool syncEvent);
+  void waitForEvent(id_t event_id, bool syncEvent);
+  void synchronizeEvent(id_t event_id);
+  bool queryEvent(id_t event_id);
+  // returns elapsed time between two recorded events in milliseconds
+  double elapsedTime(id_t start_event_id, id_t end_event_id);
+
+private:
+  MPSStream* m_default_stream = nullptr;
+  std::recursive_mutex m_mutex;
+  std::stack<std::unique_ptr<MPSEvent>> m_pool{};
+  // dictionary to associate event IDs with event objects
+  // used to retain in-use events out of the pool
+  // for torch.mps.Event() bindings.
+  std::unordered_map<id_t, MPSEventPtr> m_in_use_events{};
+  uint64_t m_event_counter = 0;
+  std::function<void(MPSEvent*)> m_default_deleter;
+
+  MPSEvent* getInUseEvent(id_t event_id, bool locked = true);
+};
+
+// shared_ptr is used to get MPSEventPool destroyed after dependent instances
+std::shared_ptr<MPSEventPool> getMPSEventPool();
+
+} // namespace at::mps
diff --git a/aten/src/ATen/mps/MPSEvent.mm b/aten/src/ATen/mps/MPSEvent.mm
new file mode 100644
index 0000000..ac46461
--- /dev/null
+++ b/aten/src/ATen/mps/MPSEvent.mm
@@ -0,0 +1,257 @@
+//  Copyright © 2023 Apple Inc.
+
+#include <ATen/mps/MPSEvent.h>
+
+namespace at::mps {
+
+MPSEvent::MPSEvent(id_t ID, MPSStream* stream, bool enable_timing)
+    : m_id(ID), m_enable_timing(enable_timing), m_stream(stream), m_event([stream->device() newSharedEvent]) {}
+
+MPSEvent::~MPSEvent() {
+  if (m_event) {
+    [m_event release];
+    m_event = nil;
+  }
+  if (m_listener) {
+    [m_listener release];
+    m_listener = nil;
+  }
+}
+
+void MPSEvent::recordLocked(bool syncEvent) {
+  // active encoders must end before encoding or waiting
+  m_stream->endKernelCoalescing();
+  ++m_signalCounter;
+  if (m_enable_timing) {
+    notifyLocked(^(id<MTLSharedEvent>, uint64_t) {
+      m_completion_time = getTime();
+      notifyCpuSync();
+    });
+  }
+  id<MTLCommandBuffer> commandBuffer = m_stream->commandBuffer();
+  [commandBuffer encodeSignalEvent:m_event value:m_signalCounter];
+  if (syncEvent) {
+    m_stream->synchronize(SyncType::COMMIT);
+  }
+}
+
+bool MPSEvent::waitLocked(bool syncEvent) {
+  // nothing to wait for if the event was never recorded or is already signaled
+  if (m_event.signaledValue >= m_signalCounter) {
+    return false;
+  }
+  // active encoders must end before encoding or waiting
+  m_stream->endKernelCoalescing();
+  id<MTLCommandBuffer> commandBuffer = m_stream->commandBuffer();
+  [commandBuffer encodeWaitForEvent:m_event value:m_signalCounter];
+  if (syncEvent) {
+    m_stream->synchronize(SyncType::COMMIT);
+  }
+  return true;
+}
+
+bool MPSEvent::notifyLocked(MTLSharedEventNotificationBlock block) {
+  // no listener is scheduled if the event was never recorded or is already signaled
+  if (m_event.signaledValue >= m_signalCounter) {
+    return false;
+  }
+  if (!m_listener) {
+    m_listener = [[MTLSharedEventListener alloc] init];
+  }
+  [m_event notifyListener:m_listener atValue:m_signalCounter block:block];
+  return true;
+}
+
+void MPSEvent::record(bool needsLock, bool syncEvent) {
+  if (!needsLock) {
+    recordLocked(syncEvent);
+    return;
+  }
+  dispatch_sync(m_stream->queue(), ^() {
+    @autoreleasepool {
+      recordLocked(syncEvent);
+    }
+  });
+}
+
+bool MPSEvent::wait(bool needsLock, bool syncEvent) {
+  __block bool waited = false;
+  if (!needsLock) {
+    return waitLocked(syncEvent);
+  }
+  dispatch_sync(m_stream->queue(), ^() {
+    @autoreleasepool {
+      waited = waitLocked(syncEvent);
+    }
+  });
+  return waited;
+}
+
+bool MPSEvent::notify(bool needsLock, MTLSharedEventNotificationBlock block) {
+  if (!needsLock) {
+    return notifyLocked(block);
+  }
+  __block bool scheduledNotify = false;
+  dispatch_sync(m_stream->queue(), ^() {
+    @autoreleasepool {
+      scheduledNotify = notifyLocked(block);
+    }
+  });
+  return scheduledNotify;
+}
+
+void MPSEvent::notifyCpuSync() {
+  std::lock_guard<std::mutex> lock(m_cpu_sync_mutex);
+  m_cpu_sync_completed = true;
+  m_cpu_sync_cv.notify_one();
+}
+
+void MPSEvent::waitForCpuSync() {
+  std::unique_lock<std::mutex> lock(m_cpu_sync_mutex);
+  m_cpu_sync_cv.wait(lock, [&] { return m_cpu_sync_completed; });
+  m_cpu_sync_completed = false;
+}
+
+bool MPSEvent::synchronize() {
+  bool scheduledNotify = notifyLocked(^(id<MTLSharedEvent>, uint64_t) {
+    m_completion_time = getTime();
+    notifyCpuSync();
+  });
+
+  if (scheduledNotify) {
+    waitForCpuSync();
+    return true;
+  }
+  return false;
+}
+
+bool MPSEvent::query() const {
+  // return false if not recorded or signaled yet
+  return m_signalCounter && (m_event.signaledValue >= m_signalCounter);
+}
+
+void MPSEvent::reset(MPSStream* stream, bool enable_timing) {
+  if (stream != m_stream) {
+    m_signalCounter = 0;
+    m_event.signaledValue = 0;
+    m_stream = stream;
+  }
+  // clear per-use state so a recycled event starts without stale timing/sync data
+  m_completion_time = 0;
+  m_enable_timing = enable_timing;
+  m_cpu_sync_completed = false;
+}
+
+//-----------------------------------------------------------------
+//  MPSEventPool
+//-----------------------------------------------------------------
+
+MPSEventPool::MPSEventPool(MPSStream* default_stream) : m_default_stream(default_stream) {
+  // default deleter to return the event back to pool after it gets released
+  m_default_deleter = [&](MPSEvent* event) {
+    std::lock_guard<std::recursive_mutex> lock(m_mutex);
+    m_pool.push(std::unique_ptr<MPSEvent>(event));
+  };
+}
+
+MPSEventPool::~MPSEventPool() {
+  emptyCache();
+}
+
+MPSEventPtr MPSEventPool::acquireEvent(bool enable_timing, MPSStream* stream) {
+  if (!stream) {
+    stream = m_default_stream;
+  }
+  // Hold the lock for the whole call: it guards both the pool and
+  // m_event_counter, so concurrent callers never race on the ID increment.
+  std::lock_guard<std::recursive_mutex> lock(m_mutex);
+  if (!m_pool.empty()) {
+    auto event = m_pool.top().release();
+    m_pool.pop();
+    event->reset(stream, enable_timing);
+    return MPSEventPtr(event, m_default_deleter);
+  }
+  auto new_event = std::make_unique<MPSEvent>(++m_event_counter, stream, enable_timing);
+  return MPSEventPtr(new_event.release(), m_default_deleter);
+}
+
+void MPSEventPool::emptyCache() {
+  std::lock_guard<std::recursive_mutex> lock(m_mutex);
+  while (!m_pool.empty()) {
+    m_pool.pop();
+  }
+}
+
+id_t MPSEventPool::acquireEvent(bool enable_timing) {
+  std::lock_guard<std::recursive_mutex> lock(m_mutex);
+  MPSEventPtr event = acquireEvent(enable_timing, nullptr);
+  TORCH_INTERNAL_ASSERT(event);
+  id_t event_id = event->getID();
+  m_in_use_events.emplace(event_id, std::move(event));
+  return event_id;
+}
+
+void MPSEventPool::releaseEvent(id_t event_id) {
+  std::lock_guard<std::recursive_mutex> lock(m_mutex);
+  TORCH_CHECK(m_in_use_events.count(event_id) > 0, "Invalid Event ID: ", event_id);
+  // returns the event back to the MPSEventPool
+  m_in_use_events.erase(event_id);
+}
+
+void MPSEventPool::recordEvent(id_t event_id, bool syncEvent) {
+  MPSEvent* event = getInUseEvent(event_id);
+  event->record(/*needsLock*/ true, syncEvent);
+}
+
+void MPSEventPool::waitForEvent(id_t event_id, bool syncEvent) {
+  MPSEvent* event = getInUseEvent(event_id);
+  event->wait(/*needsLock*/ true, syncEvent);
+}
+
+void MPSEventPool::synchronizeEvent(id_t event_id) {
+  MPSEvent* event = getInUseEvent(event_id);
+  event->synchronize();
+}
+
+bool MPSEventPool::queryEvent(id_t event_id) {
+  MPSEvent* event = getInUseEvent(event_id);
+  return event->query();
+}
+
+double MPSEventPool::elapsedTime(id_t start_event_id, id_t end_event_id) {
+  // first make sure notifyListeners are called to capture events' completion times
+  dispatch_sync(m_default_stream->queue(), ^() {
+    m_default_stream->synchronize(SyncType::COMMIT_AND_WAIT);
+  });
+  std::lock_guard<std::recursive_mutex> lock(m_mutex);
+  MPSEvent* start_event = getInUseEvent(start_event_id, false);
+  MPSEvent* end_event = getInUseEvent(end_event_id, false);
+  // waits for the notify block run on a separate thread. NOTE(review): may never return without enable_timing
+  end_event->waitForCpuSync();
+  const uint64_t start_time = start_event->getCompletionTime();
+  const uint64_t end_time = end_event->getCompletionTime();
+
+  TORCH_CHECK(start_time > 0 && end_time > 0, "Events were not created with argument 'enable_timing=True'");
+  TORCH_CHECK(
+      end_time > start_time, "End event ", end_event_id, " was not recorded after start event ", start_event_id);
+  return double(end_time - start_time) * 1e-6;
+}
+
+MPSEvent* MPSEventPool::getInUseEvent(id_t event_id, bool locked) {
+  // RAII lock: m_mutex is released even if TORCH_CHECK throws on an unknown ID.
+  // Pass locked=false only when the caller already holds m_mutex.
+  std::unique_lock<std::recursive_mutex> lock(m_mutex, std::defer_lock);
+  if (locked) {
+    lock.lock();
+  }
+  const auto it = m_in_use_events.find(event_id);
+  TORCH_CHECK(it != m_in_use_events.end(), "Invalid Event ID: ", event_id);
+  return it->second.get();
+}
+
+std::shared_ptr<MPSEventPool> getMPSEventPool() {
+  static std::shared_ptr<MPSEventPool> event_pool = std::make_shared<MPSEventPool>(getDefaultMPSStream());
+  return event_pool;
+}
+
+} // namespace at::mps
diff --git a/aten/src/ATen/mps/MPSGuardImpl.h b/aten/src/ATen/mps/MPSGuardImpl.h
index e122bf3..dd17d0e 100644
--- a/aten/src/ATen/mps/MPSGuardImpl.h
+++ b/aten/src/ATen/mps/MPSGuardImpl.h
@@ -6,6 +6,7 @@
 #include <c10/util/Exception.h>
 #include <ATen/Context.h>
 #include <ATen/mps/MPSStream.h>
+#include <ATen/mps/MPSEvent.h>
 
 #ifdef __OBJC__
 #include <Foundation/Foundation.h>
@@ -26,6 +27,8 @@
 namespace at {
 namespace mps {
 
+typedef MPSEvent* mpsEvent_t;
+
 // TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
 // https://github.com/pytorch/pytorch/issues/77170
 struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
diff --git a/aten/src/ATen/mps/MPSGuardImpl.mm b/aten/src/ATen/mps/MPSGuardImpl.mm
index 9c20471..0a1ba6b 100644
--- a/aten/src/ATen/mps/MPSGuardImpl.mm
+++ b/aten/src/ATen/mps/MPSGuardImpl.mm
@@ -28,19 +28,19 @@
 
   auto mps_event = static_cast<mpsEvent_t>(*event);
   MPSStream mps_stream{stream};
-  mps_event->recordEvent(true);
+  mps_event->record(true);
 }
 
 void MPSGuardImpl::block(void* event, const Stream& stream) const {
   auto mps_event = static_cast<mpsEvent_t>(event);
   MPSStream mps_stream{stream};
 
-  mps_event->waitForEvent(true);
+  mps_event->wait(true, false);
 }
 
 bool MPSGuardImpl::queryEvent(void* event) const {
   auto mps_event = static_cast<mpsEvent_t>(event);
-  return mps_event->queryEvent();
+  return mps_event->query();
 }
 
 }
diff --git a/aten/src/ATen/mps/MPSHooks.h b/aten/src/ATen/mps/MPSHooks.h
index dc4a0a3..e2a5e0c 100644
--- a/aten/src/ATen/mps/MPSHooks.h
+++ b/aten/src/ATen/mps/MPSHooks.h
@@ -4,6 +4,7 @@
 
 #include <ATen/detail/MPSHooksInterface.h>
 #include <ATen/Generator.h>
+#include <ATen/mps/MPSEvent.h>
 #include <c10/util/Optional.h>
 
 namespace at { namespace mps {
@@ -32,8 +33,19 @@
   size_t getCurrentAllocatedMemory() const override;
   size_t getDriverAllocatedMemory() const override;
   void setMemoryFraction(double ratio) const override;
+
+  // MPSProfiler interface
   void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
   void profilerStopTrace() const override;
+
+  // MPSEvent interface
+  uint32_t acquireEvent(bool enable_timing) const override;
+  void releaseEvent(uint32_t event_id) const override;
+  void recordEvent(uint32_t event_id) const override;
+  void waitForEvent(uint32_t event_id) const override;
+  void synchronizeEvent(uint32_t event_id) const override;
+  bool queryEvent(uint32_t event_id) const override;
+  double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
 };
 
 }} // at::mps
diff --git a/aten/src/ATen/mps/MPSHooks.mm b/aten/src/ATen/mps/MPSHooks.mm
index 62debfd..d94b985 100644
--- a/aten/src/ATen/mps/MPSHooks.mm
+++ b/aten/src/ATen/mps/MPSHooks.mm
@@ -84,6 +84,34 @@
   at::mps::getMPSProfiler().StopTrace();
 }
 
+uint32_t MPSHooks::acquireEvent(bool enable_timing) const {
+  return at::mps::getMPSEventPool()->acquireEvent(enable_timing);
+}
+
+void MPSHooks::releaseEvent(uint32_t event_id) const {
+  at::mps::getMPSEventPool()->releaseEvent(event_id);
+}
+
+void MPSHooks::recordEvent(uint32_t event_id) const {
+  at::mps::getMPSEventPool()->recordEvent(event_id, /* syncEvent*/ true);
+}
+
+void MPSHooks::waitForEvent(uint32_t event_id) const {
+  at::mps::getMPSEventPool()->waitForEvent(event_id, /* syncEvent*/ true);
+}
+
+void MPSHooks::synchronizeEvent(uint32_t event_id) const {
+  at::mps::getMPSEventPool()->synchronizeEvent(event_id);
+}
+
+bool MPSHooks::queryEvent(uint32_t event_id) const {
+  return at::mps::getMPSEventPool()->queryEvent(event_id);
+}
+
+double MPSHooks::elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const {
+  return at::mps::getMPSEventPool()->elapsedTime(start_event_id, end_event_id);
+}
+
 using at::MPSHooksRegistry;
 using at::RegistererMPSHooksRegistry;
 
diff --git a/aten/src/ATen/mps/MPSStream.h b/aten/src/ATen/mps/MPSStream.h
index 360d950..aa11e65 100644
--- a/aten/src/ATen/mps/MPSStream.h
+++ b/aten/src/ATen/mps/MPSStream.h
@@ -131,38 +131,5 @@
   MPSStreamImpl();
 };
 
-
-//-----------------------------------------------------------------
-//  MPSEvent
-//-----------------------------------------------------------------
-
-struct TORCH_API MPSEvent
-{
-  // for a new instance of MPSEvent, sometimes we want an empty shell and don't
-  // necessarily want to create events or listeners. So we defer initialization
-  // until we actually use the event (e.g., record, notify, etc.)
-  MPSEvent(bool deferInitialization = true);
-  ~MPSEvent();
-  MTLSharedEvent_t event() const {return _event; }
-
-  void recordEvent(bool syncEvent = false);
-  void waitForEvent(bool syncEvent = false); // waits on the cpu
-  void notifyEvent(MTLSharedEventNotificationBlock block);
-  bool queryEvent() const;
-  uint64_t getCurrentValue() const { return _signalCounter; }
-  void setCurrentValue(uint64_t currValue) { _signalCounter = currValue; }
-private:
-  bool is_initialized;
-  uint64_t _signalCounter;
-  MPSStream* _stream;
-  MTLSharedEvent_t _event;
-  MTLSharedEventListener* _listener;
-
-  void initialize();
-};
-
-typedef MPSEvent* mpsEvent_t;
-
-
 } // namespace mps
 } // namespace at
diff --git a/aten/src/ATen/mps/MPSStream.mm b/aten/src/ATen/mps/MPSStream.mm
index f78e303..e9bb648 100644
--- a/aten/src/ATen/mps/MPSStream.mm
+++ b/aten/src/ATen/mps/MPSStream.mm
@@ -264,77 +264,5 @@
   return MPSStreamImpl::getInstance();
 }
 
-//-----------------------------------------------------------------
-//  MPSEvent
-//-----------------------------------------------------------------
-
-MPSEvent::MPSEvent(bool deferInitialization)
-    : is_initialized(false), _signalCounter(0), _stream(nil), _event(nil), _listener(nil) {
-  if (!deferInitialization) {
-    initialize();
-  }
-}
-
-MPSEvent::~MPSEvent() {
-  if (_event) {
-    [_event release];
-    _event = nil;
-  }
-  if (_listener) {
-    [_listener release];
-    _listener = nil;
-  }
-}
-
-void MPSEvent::initialize() {
-  _stream = getDefaultMPSStream();
-  _event = [_stream->device() newSharedEvent];
-  _listener = [[MTLSharedEventListener alloc] init];
-  is_initialized = true;
-}
-
-void MPSEvent::recordEvent(bool syncEvent) {
-  if (!is_initialized)
-    initialize();
-
-  dispatch_sync(_stream->queue(), ^() {
-    @autoreleasepool {
-      ++_signalCounter;
-      id<MTLCommandBuffer> commandBuffer = _stream->commandBuffer();
-      [commandBuffer encodeSignalEvent:_event value:_signalCounter];
-      if (syncEvent)
-        _stream->synchronize(SyncType::COMMIT);
-    }
-  });
-}
-
-void MPSEvent::waitForEvent(bool syncEvent) {
-  TORCH_INTERNAL_ASSERT(is_initialized);
-  dispatch_sync(_stream->queue(), ^() {
-    @autoreleasepool {
-      id<MTLCommandBuffer> commandBuffer = _stream->commandBuffer();
-      [commandBuffer encodeWaitForEvent:_event value:_signalCounter];
-      if (syncEvent)
-        _stream->synchronize(SyncType::COMMIT);
-    }
-  });
-}
-
-void MPSEvent::notifyEvent(MTLSharedEventNotificationBlock block) {
-  if (!is_initialized)
-    initialize();
-  dispatch_sync(_stream->queue(), ^() {
-    @autoreleasepool {
-      ++_signalCounter;
-      [_event notifyListener:_listener atValue:_signalCounter block:block];
-    }
-  });
-}
-
-bool MPSEvent::queryEvent() const {
-  // return false if not recorded or signaled yet
-  return _signalCounter && (_event.signaledValue >= _signalCounter);
-}
-
 } // namespace mps
 } // namespace at
diff --git a/docs/source/mps.rst b/docs/source/mps.rst
index 7ed30f9..03ec57c 100644
--- a/docs/source/mps.rst
+++ b/docs/source/mps.rst
@@ -26,3 +26,11 @@
     profiler.start
     profiler.stop
     profiler.profile
+
+MPS Event
+------------
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    event.Event
diff --git a/test/test_mps.py b/test/test_mps.py
index 19e2273..63e2f5f 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -7493,6 +7493,18 @@
         x = net1(x)
         torch.mps.profiler.stop()
 
+    def test_mps_event_module(self):
+        startEvent = torch.mps.Event(enable_timing=True)
+        startEvent.record()
+        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
+            .to(device='mps', dtype=torch.float)
+        x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
+        x = net1(x)
+        endEvent = torch.mps.Event(enable_timing=True)
+        endEvent.record()
+        elapsedTime = startEvent.elapsed_time(endEvent)
+        self.assertTrue(elapsedTime > 0.0)
+
     def test_jit_save_load(self):
         m = torch.nn.Module()
         m.x = torch.rand(3, 3, device='mps')
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index d74d887..dd07b7f 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1508,6 +1508,14 @@
 def _mps_is_on_macos_13_or_newer(minor: _int) -> _bool: ...
 def _mps_profilerStartTrace(mode: str, wait_until_completed: _bool) -> None: ...
 def _mps_profilerStopTrace() -> None: ...
+def _mps_acquireEvent(enable_timing: _bool) -> _int: ...
+def _mps_releaseEvent(event_id: _int) -> None: ...
+def _mps_recordEvent(event_id: _int) -> None: ...
+def _mps_waitForEvent(event_id: _int) -> None: ...
+def _mps_synchronizeEvent(event_id: _int) -> None: ...
+def _mps_queryEvent(event_id: _int) -> _bool: ...
+def _mps_elapsedTimeOfEvents(start_event_id: _int, end_event_id: _int) -> _float: ...
+
 
 # Defined in torch/csrc/cuda/Module.cpp
 def _cuda_getCurrentStream(device: _int) -> Tuple: ...
diff --git a/torch/csrc/mps/Module.cpp b/torch/csrc/mps/Module.cpp
index a5021bb..87e9900 100644
--- a/torch/csrc/mps/Module.cpp
+++ b/torch/csrc/mps/Module.cpp
@@ -105,7 +105,7 @@
     PyObject* _unused,
     PyObject* noargs) {
   HANDLE_TH_ERRORS
-  return PyLong_FromUnsignedLongLong(
+  return THPUtils_packUInt64(
       at::detail::getMPSHooks().getCurrentAllocatedMemory());
   END_HANDLE_TH_ERRORS
 }
@@ -114,7 +114,7 @@
     PyObject* _unused,
     PyObject* noargs) {
   HANDLE_TH_ERRORS
-  return PyLong_FromUnsignedLongLong(
+  return THPUtils_packUInt64(
       at::detail::getMPSHooks().getDriverAllocatedMemory());
   END_HANDLE_TH_ERRORS
 }
@@ -146,6 +146,74 @@
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* MPSModule_acquireEvent(PyObject* _unused, PyObject* args) {
+  HANDLE_TH_ERRORS
+  const bool enable_timing = THPUtils_unpackBool(args);
+  return THPUtils_packUInt32(
+      at::detail::getMPSHooks().acquireEvent(enable_timing));
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_releaseEvent(PyObject* _unused, PyObject* args) {
+  HANDLE_TH_ERRORS
+  const uint32_t event_id = THPUtils_unpackUInt32(args);
+  at::detail::getMPSHooks().releaseEvent(event_id);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_recordEvent(PyObject* _unused, PyObject* args) {
+  HANDLE_TH_ERRORS
+  const uint32_t event_id = THPUtils_unpackUInt32(args);
+  at::detail::getMPSHooks().recordEvent(event_id);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_waitForEvent(PyObject* _unused, PyObject* args) {
+  HANDLE_TH_ERRORS
+  const uint32_t event_id = THPUtils_unpackUInt32(args);
+  at::detail::getMPSHooks().waitForEvent(event_id);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_synchronizeEvent(PyObject* _unused, PyObject* args) {
+  HANDLE_TH_ERRORS
+  const uint32_t event_id = THPUtils_unpackUInt32(args);
+  at::detail::getMPSHooks().synchronizeEvent(event_id);
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_queryEvent(PyObject* _unused, PyObject* args) {
+  HANDLE_TH_ERRORS
+  const uint32_t event_id = THPUtils_unpackUInt32(args);
+
+  if (at::detail::getMPSHooks().queryEvent(event_id)) {
+    Py_RETURN_TRUE;
+  } else {
+    Py_RETURN_FALSE;
+  }
+  END_HANDLE_TH_ERRORS
+}
+
+static PyObject* MPSModule_elapsedTimeOfEvents(
+    PyObject* _unused,
+    PyObject* args) {
+  HANDLE_TH_ERRORS
+  PyObject* start_event_o = nullptr;
+  PyObject* end_event_o = nullptr;
+  if (!PyArg_ParseTuple(args, "OO", &start_event_o, &end_event_o)) {
+    return nullptr;
+  }
+  const uint32_t start_event_id = THPUtils_unpackUInt32(start_event_o);
+  const uint32_t end_event_id = THPUtils_unpackUInt32(end_event_o);
+  return PyFloat_FromDouble(at::detail::getMPSHooks().elapsedTimeOfEvents(
+      start_event_id, end_event_id));
+  END_HANDLE_TH_ERRORS
+}
+
 // NOLINTNEXTLINE(modernize-avoid-c-arrays,
 // cppcoreguidelines-avoid-non-const-global-variables,
 // cppcoreguidelines-avoid-c-arrays)
@@ -182,6 +250,16 @@
      MPSModule_profilerStopTrace,
      METH_NOARGS,
      nullptr},
+    {"_mps_acquireEvent", MPSModule_acquireEvent, METH_O, nullptr},
+    {"_mps_releaseEvent", MPSModule_releaseEvent, METH_O, nullptr},
+    {"_mps_recordEvent", MPSModule_recordEvent, METH_O, nullptr},
+    {"_mps_waitForEvent", MPSModule_waitForEvent, METH_O, nullptr},
+    {"_mps_synchronizeEvent", MPSModule_synchronizeEvent, METH_O, nullptr},
+    {"_mps_queryEvent", MPSModule_queryEvent, METH_O, nullptr},
+    {"_mps_elapsedTimeOfEvents",
+     MPSModule_elapsedTimeOfEvents,
+     METH_VARARGS,
+     nullptr},
     {nullptr}};
 
 PyMethodDef* python_functions() {
diff --git a/torch/mps/__init__.py b/torch/mps/__init__.py
index 13ba91d..52cda4f 100644
--- a/torch/mps/__init__.py
+++ b/torch/mps/__init__.py
@@ -113,6 +113,7 @@
 
 
 from . import profiler
+from .event import Event
 
 __all__ = [
     "get_rng_state",
@@ -124,5 +125,6 @@
     "set_per_process_memory_fraction",
     "current_allocated_memory",
     "driver_allocated_memory",
+    "Event",
     "profiler",
 ]
diff --git a/torch/mps/event.py b/torch/mps/event.py
new file mode 100644
index 0000000..a206b64
--- /dev/null
+++ b/torch/mps/event.py
@@ -0,0 +1,45 @@
+import torch
+
+
+class Event:
+    r"""Wrapper around an MPS event.
+
+    MPS events are synchronization markers that can be used to monitor the
+    device's progress, to accurately measure timing, and to synchronize MPS streams.
+
+    Args:
+        enable_timing (bool, optional): indicates if the event should measure time
+            (default: ``False``)
+    """
+
+    def __init__(self, enable_timing=False):
+        self.__eventId = torch._C._mps_acquireEvent(enable_timing)
+
+    def __del__(self):
+        # tolerates a failed __init__ and torch._C being already destroyed
+        if getattr(self, "_Event__eventId", 0) > 0 and hasattr(torch._C, "_mps_releaseEvent"):
+            torch._C._mps_releaseEvent(self.__eventId)
+
+    def record(self):
+        r"""Records the event in the default stream."""
+        torch._C._mps_recordEvent(self.__eventId)
+
+    def wait(self):
+        r"""Makes all future work submitted to the default stream wait for this event."""
+        torch._C._mps_waitForEvent(self.__eventId)
+
+    def query(self):
+        r"""Returns True if all work currently captured by event has completed."""
+        return torch._C._mps_queryEvent(self.__eventId)
+
+    def synchronize(self):
+        r"""Waits until the completion of all work currently captured in this event.
+        This prevents the CPU thread from proceeding until the event completes.
+        """
+        torch._C._mps_synchronizeEvent(self.__eventId)
+
+    def elapsed_time(self, end_event):
+        r"""Returns the time elapsed in milliseconds after the event was
+        recorded and before the end_event was recorded.
+        """
+        return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId)