[MPS] Introduce torch.mps.Event() APIs (#102121)
- Implement `MPSEventPool` to recycle events.
- Implement Python bindings via the `torch.mps.Event` class, backed by `MPSEventPool`. The Event class currently provides the member functions `record()`, `wait()`, `synchronize()`, `query()`, and `elapsed_time()`.
- Add API to measure elapsed time between two event recordings.
- Add documentation for the Event class to `mps.rst`.
- Add a test case to `test_mps.py`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102121
Approved by: https://github.com/albanD, https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 19e2273..63e2f5f 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -7493,6 +7493,18 @@
x = net1(x)
torch.mps.profiler.stop()
+    def test_mps_event_module(self):
+        """Time a ConvTranspose2d forward pass with torch.mps.Event and check elapsed_time() > 0."""
+        start_event = torch.mps.Event(enable_timing=True)
+        start_event.record()
+        net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
+            .to(device='mps', dtype=torch.float)
+        net1(torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True))
+        end_event = torch.mps.Event(enable_timing=True)
+        end_event.record()
+        end_event.synchronize()  # wait for the work captured by end_event before reading timestamps
+        self.assertGreater(start_event.elapsed_time(end_event), 0.0)
+
def test_jit_save_load(self):
m = torch.nn.Module()
m.x = torch.rand(3, 3, device='mps')