[MPS] Fix crash if nonzero is called concurrently (#108996)
Surrounds the `stream->synchronize()` call with `dispatch_sync(stream->queue(), ^{});`, which is a no-op for a single-threaded program but serializes synchronize calls across threads using the same stream.
Prevents the non-recoverable `[IOGPUMetalCommandBuffer validate]:215: failed assertion 'commit an already committed command buffer'` exception, which is triggered every time one uses PyCharm to inspect tensors on the MPS device.
Fixes https://github.com/pytorch/pytorch/issues/100285
<!--
copilot:poem
-->
### <samp>🤖 Generated by Copilot at 1662ce2</samp>
> _Sing, O Muse, of the swift and skillful coders_
> _Who fixed the dreadful deadlock of the stream_
> _That crashed the mighty tensors of the MPS_
> _When they sought out the nonzero elements._
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108996
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index d307ef6..59c73e2 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -12,6 +12,7 @@
 import os
 import copy
 import gc
+import threading
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -9761,6 +9762,17 @@
         nz = x.nonzero()
         self.assertFalse(nz.requires_grad)
 
+    def test_nonzero_multi_threading(self):
+        # Test that MPS does not crash if nonzero is called concurrently
+        # See https://github.com/pytorch/pytorch/issues/100285
+        x = torch.rand(3, 3, device="mps")
+        t1 = threading.Thread(target=torch.nonzero, args=(x,))
+        t2 = threading.Thread(target=torch.nonzero, args=(x,))
+        t1.start()
+        t2.start()
+        t1.join()
+        t2.join()
+
     def test_masked_select(self):
         x = torch.randn(3, 4)
         x_mps = x.to("mps")