[MPS] Add support for MPSProfiler Python bindings (#101002)
- Added the torch.mps.profiler start(), stop(), and profile() APIs with RST documentation
- Added test case test_mps_profiler_module in test/test_mps.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101002
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 60ef643..f281c7a 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -7325,6 +7325,25 @@
self.assertTrue(current_alloc_after > current_alloc_before)
self.assertTrue(driver_alloc_after > driver_alloc_before)
+ # Smoke test for torch.mps.profiler: it makes no assertions, so its output must be checked
+ # manually. Run Xcode Instruments' "Metal System Trace" or "Logging" tool, press record,
+ # run this Python test, then press stop. Next expand os_signposts->PyTorchMPS and check
+ # that events or intervals are logged, like this example:
+ # "aten::mps_convolution_backward_input:f32[1,128,6,6]:f32[128,64,3,3]:1,128,6,6 (id=G2, run=2)"
+ def test_mps_profiler_module(self):
+ with torch.mps.profiler.profile(mode="event", wait_until_completed=False) as p:
+ # Part 1: run a few ops inside the profile() context manager (mode="event") to emit OS Signposts traces
+ net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\
+ .to(device='mps', dtype=torch.float)
+ x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
+ x = net1(x)
+
+ torch.mps.profiler.start(mode="interval", wait_until_completed=True)
+ # Part 2: same ops, reusing net1, now bracketed by explicit start()/stop() calls in "interval" mode
+ x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True)
+ x = net1(x)
+ torch.mps.profiler.stop()
+
# Test random_, random_.to and random_.from
def test_random(self):
def helper(shape, low, high, dtype=torch.int32):