[MPS] Native nonzero implementation (#125355)

Fixes https://github.com/pytorch/pytorch/issues/124850

Replace the previous MPSGraph-based nonzero construction with the native nonzero op. For older OSes, fall back to CPU (the previous implementation was not reliable and was only comparable to CPU in speed).
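For reference, a minimal usage sketch of the affected op (not part of the patch; it assumes a machine where the MPS backend is available — on older macOS releases the op falls back to CPU with a one-time warning):

```python
import torch

if torch.backends.mps.is_available():
    x = torch.tensor([1, 0, 0, 3], device="mps")
    # nonzero returns the indices of nonzero elements, one row per element
    print(torch.nonzero(x).cpu())   # tensor([[0], [3]])
    # argwhere is an alias and exercises the same MPS code path
    print(torch.argwhere(x).cpu())
```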

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125355
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm
index d86f57c..38c3212 100644
--- a/aten/src/ATen/native/mps/operations/Indexing.mm
+++ b/aten/src/ATen/native/mps/operations/Indexing.mm
@@ -241,14 +241,20 @@
 } // namespace mps
 
 static Tensor nonzero_fallback(const Tensor& self) {
-  TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ",
-                  "Falling back on CPU. This may have performance implications.");
-
   return at::nonzero(self.to("cpu")).clone().to("mps");
 }
 
 Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
-  if (!is_macos_13_or_newer()) {
+  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
+    TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ",
+                    "Falling back on CPU. This may have performance implications.");
+    Tensor out_fallback = nonzero_fallback(self);
+    at::native::resize_output(out_, out_fallback.sizes());
+    out_.copy_(out_fallback.to("mps"));
+    return out_;
+  } else if (self.is_complex()) {
+    TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes. ",
+                    "Falling back on CPU. This may have performance implications.");
     Tensor out_fallback = nonzero_fallback(self);
     at::native::resize_output(out_, out_fallback.sizes());
     out_.copy_(out_fallback.to("mps"));
@@ -282,7 +288,6 @@
     CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
     MPSGraphTensor* inputTensor_ = nil;
     MPSGraphTensor* outputTensor_ = nil;
-    MPSGraphTensor* scatterDataTensor_ = nil;
   };
 
   dispatch_sync(stream->queue(), ^() {
@@ -300,93 +305,20 @@
     out = at::empty(out_.sizes(), out_.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
   }
 
-  int64_t _apparentInputShape = 1;
-  for (auto dim : self.sizes()) {
-    _apparentInputShape *= dim;
-  }
-  MPSShape* apparentOutputShape = @[ @(total_nonzero * nDim) ];
-  MPSShape* apparentInputShape = @[ @(_apparentInputShape) ];
-
-  // Pseudocode:
-  //
-  // inputTensor     = [1,  0,  0,  3]
-  // inputNonZero    = [1,  0,  0,  1]
-  // indices         = [1,  1,  1,  2]
-  // maskedIndices   = [0, -1, -1,  1]
-  // coordinates     = [0,  1,  2,  3]
-  // scatterResult   = [0,  3]
-
   @autoreleasepool {
     string key = "nonzero_out_mps" + getTensorsStringKey(self);
     auto cachedGraph = LookUpOrCreateCachedGraph<CachedGraph>(key, [&](auto mpsGraph, auto newCachedGraph) {
-      MPSDataType inputDataType = getMPSDataType(self);
-      MPSShape* inputShape = getMPSShape(self);
+      MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(self), getMPSShape(self));
 
-      MPSGraphTensor* inputTensor =
-          mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), apparentInputShape);
-      MPSGraphTensor* scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(out.scalar_type()));
-      MPSGraphTensor* zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputDataType];
-      MPSGraphTensor* oneTensor = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeInt32];
-      MPSGraphTensor* minusMaxDimTensor = [mpsGraph constantWithScalar:-maxDimensions dataType:MPSDataTypeInt32];
-      MPSGraphTensor* inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor
-                                                                      secondaryTensor:zeroTensor
-                                                                                 name:nil];
-      MPSGraphTensor* maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor
-                                                 toType:MPSDataTypeInt32
-                                                   name:@"castToInt32"];
-      MPSGraphTensor* indicesTensor = [mpsGraph cumulativeSumWithTensor:maskTensor axis:0 name:nil];
-      MPSGraphTensor* indicesMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:indicesTensor
-                                                                     secondaryTensor:oneTensor
-                                                                                name:nil];
-      MPSGraphTensor* maskedIndicesTensor = [mpsGraph selectWithPredicateTensor:inputNotEqualToZeroTensor
-                                                            truePredicateTensor:indicesMinusOneTensor
-                                                           falsePredicateTensor:minusMaxDimTensor
-                                                                           name:nil];
-      MPSGraphTensor* coordinatesTensor = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:0
-                                                                                      withShape:inputShape
-                                                                                           name:nil]
-                                                        withShape:@[ @-1 ]
-                                                             name:nil];
-      if (nDim > 1) {
-        NSMutableArray<MPSGraphTensor*>* maskedIndicesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
-        NSMutableArray<MPSGraphTensor*>* coordinatesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
-
-        MPSGraphTensor* constantRankTensor = [mpsGraph constantWithScalar:nDim dataType:MPSDataTypeInt32];
-        maskedIndicesTensorArray[0] = [mpsGraph multiplicationWithPrimaryTensor:maskedIndicesTensor
-                                                                secondaryTensor:constantRankTensor
-                                                                           name:nil];
-        coordinatesTensorArray[0] = coordinatesTensor;
-        for (int i = 1; i < nDim; i++) {
-          maskedIndicesTensorArray[i] = [mpsGraph additionWithPrimaryTensor:maskedIndicesTensorArray[i - 1]
-                                                            secondaryTensor:oneTensor
-                                                                       name:nil];
-          coordinatesTensorArray[i] = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:i
-                                                                                  withShape:inputShape
-                                                                                       name:nil]
-                                                    withShape:@[ @-1 ]
-                                                         name:nil];
-        }
-        maskedIndicesTensor = [mpsGraph concatTensors:maskedIndicesTensorArray dimension:0 interleave:YES name:nil];
-        coordinatesTensor = [mpsGraph concatTensors:coordinatesTensorArray dimension:0 interleave:YES name:nil];
-      }
-
-      MPSGraphTensor* outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor
-                                                       updatesTensor:coordinatesTensor
-                                                       indicesTensor:maskedIndicesTensor
-                                                                axis:0
-                                                                mode:MPSGraphScatterModeSet
-                                                                name:nil];
+      MPSGraphTensor* outputTensor = [mpsGraph nonZeroIndicesOfTensor:inputTensor name:nil];
 
       newCachedGraph->inputTensor_ = inputTensor;
-      newCachedGraph->scatterDataTensor_ = scatterDataTensor;
       newCachedGraph->outputTensor_ = outputTensor;
     });
 
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparentInputShape);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out, apparentOutputShape);
-    Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, out, apparentOutputShape);
-
-    auto feeds = dictionaryFromPlaceholders(selfPlaceholder, scatterPlaceholder);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
+    auto feeds = dictionaryFromPlaceholders(selfPlaceholder);
     runMPSGraph(stream, cachedGraph->graph(), feeds, outputPlaceholder);
   }
 
@@ -398,7 +330,13 @@
 }
 
 Tensor nonzero_mps(const Tensor& self) {
-  if (!is_macos_13_or_newer()) {
+  if (!is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS)) {
+    TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ",
+                    "Falling back on CPU. This may have performance implications.");
+    return nonzero_fallback(self);
+  } else if (self.is_complex()) {
+    TORCH_WARN_ONCE("MPS: nonzero op is not supported for complex datatypes ",
+                    "Falling back on CPU. This may have performance implications.");
     return nonzero_fallback(self);
   }
 
diff --git a/test/test_mps.py b/test/test_mps.py
index 38fea5b..1bc1ca8 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -227,6 +227,7 @@
         '__rmul__',
         '__getitem__',
         'add',
+        'argwhere',
         'atleast_1d',
         'atleast_2d',
         'atleast_3d',
@@ -287,6 +288,7 @@
         'nn.functional.padcircular',
         'nn.functional.feature_alpha_dropoutwithout_train',
         'nn.functional.unfold',
+        'nonzero',
         'ones',
         'outer',
         'permute',
@@ -340,7 +342,6 @@
         'any',
         'addcdiv',
         'addcmul',
-        'argwhere',
         'asin',
         'atan',
         'atanh',
@@ -408,7 +409,6 @@
         'nn.functional.pixel_shuffle',
         'nn.functional.pixel_unshuffle',
         'nn.functional.tanhshrink',
-        'nonzero',
         'prod',
         'reciprocal',
         'roll',