[MPS] Add support for MPSProfiler Python bindings (#101002)

- Added the `torch.mps.profiler.start()` and `torch.mps.profiler.stop()` APIs, with RST documentation
- Added a test case in test/test_mps.py

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101002
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py index 60ef643..f281c7a 100644 --- a/test/test_mps.py +++ b/test/test_mps.py
@@ -7325,6 +7325,25 @@ self.assertTrue(current_alloc_after > current_alloc_before) self.assertTrue(driver_alloc_after > driver_alloc_before) + # to verify this test, run XCode Instruments "Metal System Trace" or "Logging" tool, + # press record, then run this python test, and press stop. Next expand + # the os_signposts->PyTorchMPS and check if events or intervals are logged + # like this example: + # "aten::mps_convolution_backward_input:f32[1,128,6,6]:f32[128,64,3,3]:1,128,6,6 (id=G2, run=2)" + def test_mps_profiler_module(self): + with torch.mps.profiler.profile(mode="event", wait_until_completed=False) as p: + # just running some ops to capture the OS Signposts traces for profiling + net1 = torch.nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)\ + .to(device='mps', dtype=torch.float) + x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True) + x = net1(x) + + torch.mps.profiler.start(mode="interval", wait_until_completed=True) + # just running some ops to capture the OS Signposts traces for profiling + x = torch.rand(1, 128, 6, 6, device='mps', dtype=torch.float, requires_grad=True) + x = net1(x) + torch.mps.profiler.stop() + # Test random_, random_.to and random_.from def test_random(self): def helper(shape, low, high, dtype=torch.int32):