[MPS] Native nonzero implementation (#125355)
Fixes https://github.com/pytorch/pytorch/issues/124850
Replace the previous MPSGraph-based nonzero construction with the native nonzero op. On older OSes, fall back to CPU (the previous implementation was unreliable and only comparable to CPU in speed).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125355
Approved by: https://github.com/kulinseth
diff --git a/test/test_mps.py b/test/test_mps.py
index 38fea5b..1bc1ca8 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -227,6 +227,7 @@
'__rmul__',
'__getitem__',
'add',
+ 'argwhere',
'atleast_1d',
'atleast_2d',
'atleast_3d',
@@ -287,6 +288,7 @@
'nn.functional.padcircular',
'nn.functional.feature_alpha_dropoutwithout_train',
'nn.functional.unfold',
+ 'nonzero',
'ones',
'outer',
'permute',
@@ -340,7 +342,6 @@
'any',
'addcdiv',
'addcmul',
- 'argwhere',
'asin',
'atan',
'atanh',
@@ -408,7 +409,6 @@
'nn.functional.pixel_shuffle',
'nn.functional.pixel_unshuffle',
'nn.functional.tanhshrink',
- 'nonzero',
'prod',
'reciprocal',
'roll',