[MPS] Add nonzero mps support (#91616)

Adds native `nonzero` support for MPS on macOS 13.0+, with a CPU fallback (and a one-time warning) on older systems:

  **Pseudocode** (1-D example; a runnable PyTorch sketch follows):
  ```
  // inputTensor   = [1,  0,  0,  3]
  // inputNonZero  = [1,  0,  0,  1] (input != 0)
  // scan          = [1,  1,  1,  2] (prefix sum of the mask)
  // maskedIndices = [0, -1, -1,  1] (scan - 1 where nonzero, sentinel elsewhere)
  // coordinates   = [0,  1,  2,  3] (coordinateAlongAxis)
  // scatterResult = [0,  3]         (coordinates scattered at maskedIndices)
  ```
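For illustration, the same scheme can be written with plain PyTorch ops. This is a minimal eager-mode sketch of the algorithm, not the MPSGraph implementation; `nonzero_via_scan` is a hypothetical name used only here:

```python
import torch

def nonzero_via_scan(t: torch.Tensor) -> torch.Tensor:
    # Sketch of the prefix-sum + scatter compaction; assumes t.dim() >= 1.
    flat = t.reshape(-1)
    mask = flat != 0                                  # inputNonZero
    scan = torch.cumsum(mask.to(torch.int64), dim=0)  # prefix sum
    total = int(scan[-1]) if flat.numel() > 0 else 0
    rows = (scan - 1)[mask]                           # maskedIndices, kept entries only
    # coordinateAlongAxis equivalent: one index grid per input dimension.
    grids = torch.meshgrid(*[torch.arange(s) for s in t.shape], indexing="ij")
    out = torch.empty(total, t.dim(), dtype=torch.int64)
    for d, grid in enumerate(grids):
        # Scatter each kept coordinate into its compacted output row.
        out[:, d].scatter_(0, rows, grid.reshape(-1)[mask])
    return out

t = torch.tensor([[1, 0], [0, 3]])
assert torch.equal(nonzero_via_scan(t), torch.nonzero(t))
```

The graph version below instead performs a single scatter over a flattened `(total_nonzero * nDim)` buffer, interleaving the per-dimension indices and coordinates.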
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91616
Approved by: https://github.com/razarmehr
diff --git a/aten/src/ATen/mps/MPSFallback.mm b/aten/src/ATen/mps/MPSFallback.mm
index f1c0dbb..e5dfde1 100644
--- a/aten/src/ATen/mps/MPSFallback.mm
+++ b/aten/src/ATen/mps/MPSFallback.mm
@@ -61,7 +61,6 @@
   m.impl("_fft_r2c", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
   m.impl("linalg_vector_norm", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
   m.impl("sgn.out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
-  m.impl("nonzero", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
   m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps);
 }
 
diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm
index 331114d..a97f96b 100644
--- a/aten/src/ATen/native/mps/operations/Indexing.mm
+++ b/aten/src/ATen/native/mps/operations/Indexing.mm
@@ -12,6 +12,7 @@
 #include <ATen/native/LinearAlgebraUtils.h>
 #include <ATen/native/mps/OperationUtils.h>
 #include <ATen/native/mps/operations/Indexing.h>
+#include <ATen/native/mps/MPSGraphVenturaOps.h>
 #include <ATen/native/Resize.h>
 #include <ATen/AccumulateType.h>
 #include <torch/library.h>
@@ -211,6 +212,185 @@
   return result;
 }
 
+static Tensor nonzero_fallback(const Tensor& self) {
+  TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ",
+                  "Falling back on CPU. This may have performance implications.");
+
+  return at::nonzero(self.to("cpu")).clone().to("mps");
+}
+
+Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_) {
+  if (!is_macos_13_or_newer()) {
+    Tensor out_fallback = nonzero_fallback(self);
+    at::native::resize_output(out_, out_fallback.sizes());
+    out_.copy_(out_fallback);
+    return out_;
+  }
+
+  using namespace mps;
+  const uint32_t maxDimensions = 16;
+
+  TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(),
+              "nonzero is not supported for tensors with more than INT_MAX elements; file a support request");
+  TORCH_CHECK(out_.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out_.dtype());
+  TORCH_CHECK(self.device() == out_.device(), "expected self and out to be on the same device, but got out on ",
+              out_.device(), " and self on ", self.device());
+  TORCH_CHECK(self.dim() <= maxDimensions, "nonzero is not supported for tensors with more than ", maxDimensions, " dimensions");
+  TORCH_CHECK(out_.is_mps(), "Expected out tensor to be on the MPS device");
+
+  MPSStream *stream = getCurrentMPSStream();
+  struct CachedGraph : public MPSCachedGraph
+  {
+    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+    MPSGraphTensor* inputTensor_ = nil;
+    MPSGraphTensor* outputTensor_ = nil;
+    MPSGraphTensor* scatterDataTensor_ = nil;
+  };
+
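+  // nonzero has a data-dependent output shape, so the number of nonzero
+  // elements is computed eagerly to size the output before running the graph.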
+  int64_t total_nonzero = at::count_nonzero(self).item<int64_t>();
+  int64_t nDim = self.dim();
+  at::native::resize_output(out_, {total_nonzero, nDim});
+  if (out_.numel() == 0) {
+    return out_;
+  }
+
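+  // The graph scatters straight into the destination buffer, so a
+  // non-contiguous out tensor is staged through a contiguous temporary.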
+  bool contiguous_output = (out_.is_contiguous() && !out_.is_view());
+  Tensor out = out_;
+  if (!contiguous_output) {
+    out = at::native::empty_mps(
+           out_.sizes(),
+           out_.scalar_type(),
+           c10::nullopt,
+           kMPS,
+           c10::nullopt,
+           c10::nullopt);
+  }
+
+  MPSShape *apparentOutputShape = @[@(total_nonzero * nDim)];
+  MPSShape *apparentInputShape = @[@(self.numel())];
+
+  // Pseudocode:
+  //
+  // inputTensor     = [1,  0,  0,  3]
+  // inputNonZero    = [1,  0,  0,  1]  (input != 0)
+  // indices         = [1,  1,  1,  2]  (cumulative sum of the mask)
+  // maskedIndices   = [0, -1, -1,  1]  (indices - 1 where nonzero, negative sentinel elsewhere)
+  // coordinates     = [0,  1,  2,  3]  (coordinateAlongAxis)
+  // scatterResult   = [0,  3]          (coordinates scattered at maskedIndices)
+
+  MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+  @autoreleasepool {
+    string key = "nonzero_out_mps" + getTensorsStringKey(self);
+    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+
+    if (!cachedGraph) {
+      MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
+        CachedGraph *newCachedGraph = nil;
+        @autoreleasepool {
+          MPSDataType inputDataType = getMPSDataType(self.scalar_type());
+          MPSShape* inputShape = getMPSShape(self);
+          MPSGraph* mpsGraph = make_mps_graph();
+          newCachedGraph = new CachedGraph(mpsGraph);
+
+          MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), apparentInputShape);
+          MPSGraphTensor *scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(out.scalar_type()));
+          MPSGraphTensor *zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputDataType];
+          MPSGraphTensor *oneTensor = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeInt32];
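+          // -maxDimensions is an out-of-range sentinel index for zero elements;
+          // it stays negative even after the rank scaling applied below.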
+          MPSGraphTensor *minusMaxDimTensor = [mpsGraph constantWithScalar:-maxDimensions dataType:MPSDataTypeInt32];
+          MPSGraphTensor *inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor
+                                                                          secondaryTensor:zeroTensor
+                                                                                     name:nil];
+          MPSGraphTensor *maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor
+                                                     toType:MPSDataTypeInt32
+                                                       name:@"castToInt32"];
+          MPSGraphTensor *indicesTensor = [mpsGraph cumulativeSumWithTensor:maskTensor
+                                                                       axis:0
+                                                                       name:nil];
+          MPSGraphTensor *indicesMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:indicesTensor
+                                                                        secondaryTensor:oneTensor
+                                                                                   name:nil];
+          MPSGraphTensor *maskedIndicesTensor = [mpsGraph selectWithPredicateTensor:inputNotEqualToZeroTensor
+                                                                truePredicateTensor:indicesMinusOneTensor
+                                                               falsePredicateTensor:minusMaxDimTensor
+                                                                               name:nil];
+          MPSGraphTensor *coordinatesTensor = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:0 withShape:inputShape name:nil]
+                                                            withShape:@[@-1]
+                                                                name:nil];
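+          // For rank > 1, interleave per-dimension row indices and coordinates
+          // so each nonzero element emits nDim consecutive values.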
+          if (nDim > 1) {
+            NSMutableArray<MPSGraphTensor*> *maskedIndicesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
+            NSMutableArray<MPSGraphTensor*> *coordinatesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
+
+            MPSGraphTensor *constantRankTensor = [mpsGraph constantWithScalar:nDim
+                                                                     dataType:MPSDataTypeInt32];
+            maskedIndicesTensorArray[0] = [mpsGraph multiplicationWithPrimaryTensor:maskedIndicesTensor
+                                                                    secondaryTensor:constantRankTensor
+                                                                               name:nil];
+            coordinatesTensorArray[0] = coordinatesTensor;
+            for (int i = 1; i < nDim; i++){
+              maskedIndicesTensorArray[i] = [mpsGraph additionWithPrimaryTensor:maskedIndicesTensorArray[i - 1]
+                                                                secondaryTensor:oneTensor
+                                                                           name:nil];
+              coordinatesTensorArray[i] = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:i withShape:inputShape name:nil]
+                                                        withShape:@[@-1]
+                                                             name:nil];
+            }
+            maskedIndicesTensor = [mpsGraph concatTensors:maskedIndicesTensorArray dimension:0 interleave:YES name:nil];
+            coordinatesTensor = [mpsGraph concatTensors:coordinatesTensorArray dimension:0 interleave:YES name:nil];
+          }
+
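+          // Scatter each coordinate to its compacted row; the negative sentinel
+          // indices fall outside the output range and are not written.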
+          MPSGraphTensor *outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor
+                                                           updatesTensor:coordinatesTensor
+                                                           indicesTensor:maskedIndicesTensor
+                                                                    axis:0
+                                                                    mode:MPSGraphScatterModeSet
+                                                                    name:nil];
+
+          newCachedGraph->inputTensor_ = inputTensor;
+          newCachedGraph->scatterDataTensor_ = scatterDataTensor;
+          newCachedGraph->outputTensor_ = outputTensor;
+        }
+        return newCachedGraph;
+      });
+      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+    }
+
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparentInputShape);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, contiguous_output ? out_ : out, apparentOutputShape);
+    Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, contiguous_output ? out_ : out, apparentOutputShape);
+
+    // Create dictionary of inputs and outputs
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
+      selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
+      scatterPlaceholder.getMPSGraphTensor() : scatterPlaceholder.getMPSGraphTensorData()
+    };
+
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
+      outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
+    };
+
+    runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+    if (!contiguous_output) {
+      out_.copy_(out);
+    }
+  }
+
+  return out_;
+}
+
+Tensor nonzero_mps(const Tensor& self) {
+  if (!is_macos_13_or_newer()) {
+    return nonzero_fallback(self);
+  }
+
+  Tensor out = at::empty({0}, self.options().dtype(kLong));
+  return nonzero_out_mps(self, out);
+}
+
 Tensor masked_select_mps(const Tensor & self, const Tensor & mask) {
   namedinference::compute_broadcast_outnames(self, mask);
   Tensor result = at::empty({0}, self.options());
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3e485d6..9b42c97 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8521,6 +8521,7 @@
   dispatch:
     CPU: nonzero_out_cpu
     CUDA: nonzero_out_cuda
+    MPS: nonzero_out_mps
   tags: dynamic_output_shape
 
 - func: nonzero(Tensor self) -> Tensor
@@ -8528,6 +8529,7 @@
   dispatch:
     CPU: nonzero_cpu
     CUDA: nonzero_cuda
+    MPS: nonzero_mps
   tags: [dynamic_output_shape, canonical]
 
 - func: nonzero_numpy(Tensor self) -> Tensor[]
diff --git a/test/test_mps.py b/test/test_mps.py
index d3308d9..00dadbe 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6687,6 +6687,116 @@
     supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
     supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
 
+    def test_nonzero_no_warning(self):
+        device = "mps"
+        t = torch.randn((2, 2), device=device)
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            torch.nonzero(t)
+            t.nonzero()
+            self.assertEqual(len(w), 0)
+
+    def test_nonzero(self):
+        def helper(dtype):
+            device = "mps"
+            shapes = [
+                torch.Size((12,)),
+                torch.Size((12, 1)),
+                torch.Size((1, 12)),
+                torch.Size((6, 2)),
+                torch.Size((3, 2, 2)),
+                torch.Size((5, 5, 5)),
+            ]
+
+            def gen_nontrivial_input(shape, dtype, device):
+                if dtype != torch.bfloat16:
+                    return torch.randint(2, shape, device=device, dtype=dtype)
+                else:
+                    # randint does not support bfloat16 on all platforms; generate as float, then cast
+                    return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)
+
+            for shape in shapes:
+                tensor = gen_nontrivial_input(shape, dtype, device)
+                dst1 = torch.nonzero(tensor, as_tuple=False)
+                dst2 = tensor.nonzero(as_tuple=False)
+                dst3 = torch.empty([], dtype=torch.long, device=device)
+                dst3 = dst3.resize_(0)
+                torch.nonzero(tensor, out=dst3)
+                np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy()
+                np_result = torch.from_numpy(np.stack(np_array.nonzero())).t()
+                self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0)
+                self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0)
+                self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0)
+                tup1 = torch.nonzero(tensor, as_tuple=True)
+                tup2 = tensor.nonzero(as_tuple=True)
+                tup1 = torch.stack(tup1).t().cpu()
+                tup2 = torch.stack(tup2).t().cpu()
+                self.assertEqual(tup1, np_result, atol=0, rtol=0)
+                self.assertEqual(tup2, np_result, atol=0, rtol=0)
+        for dtype in self.supported_dtypes:
+            helper(dtype)
+
+    def test_nonzero_astuple_out(self):
+        device = "mps"
+        t = torch.randn((3, 3, 3), device=device)
+        out = torch.empty([], dtype=torch.long, device=device)
+        out = out.resize_(0)
+
+        with self.assertRaises(RuntimeError):
+            torch.nonzero(t, as_tuple=True, out=out)
+
+        self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out))
+
+        # Verifies that JIT script cannot handle the as_tuple kwarg
+        # See Issue https://github.com/pytorch/pytorch/issues/45499.
+        def _foo(t):
+            tuple_result = torch.nonzero(t, as_tuple=True)
+            nontuple_result = torch.nonzero(t, as_tuple=False)
+            out = torch.empty_like(nontuple_result)
+            torch.nonzero(t, as_tuple=False, out=out)
+            return tuple_result, nontuple_result, out
+
+        with self.assertRaises(RuntimeError):
+            scripted_foo = torch.jit.script(_foo)
+
+        # Verifies that JIT tracing works fine
+        traced_foo = torch.jit.trace(_foo, t)
+        traced_tuple, traced_nontuple, traced_out = traced_foo(t)
+        expected_tuple = torch.nonzero(t, as_tuple=True)
+        expected_nontuple = torch.nonzero(t)
+
+        self.assertEqual(traced_tuple, expected_tuple)
+        self.assertEqual(traced_nontuple, expected_nontuple)
+        self.assertEqual(traced_out, expected_nontuple)
+
+    def test_nonzero_discontiguous(self):
+        device = "mps"
+        shape = (4, 4)
+        tensor = torch.randint(2, shape, device=device)
+        tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor)
+        dst1 = tensor.nonzero(as_tuple=False)
+        dst2 = tensor_nc.nonzero(as_tuple=False)
+        self.assertEqual(dst1, dst2, atol=0, rtol=0)
+        dst3 = torch.empty_like(dst1)
+        data_ptr = dst3.data_ptr()
+        # expect dst3 storage to be reused
+        torch.nonzero(tensor, out=dst3)
+        self.assertEqual(data_ptr, dst3.data_ptr())
+        self.assertEqual(dst1, dst3, atol=0, rtol=0)
+        # discontiguous out
+        dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2]
+        data_ptr = dst4.data_ptr()
+        strides = dst4.stride()
+        torch.nonzero(tensor, out=dst4)
+        self.assertEqual(data_ptr, dst4.data_ptr())
+        self.assertEqual(dst1, dst4, atol=0, rtol=0)
+        self.assertEqual(strides, dst4.stride())
+
+    def test_nonzero_non_diff(self):
+        device = "mps"
+        x = torch.randn(10, requires_grad=True)
+        nz = x.nonzero()
+        self.assertFalse(nz.requires_grad)
+
     def test_masked_select(self):
         x = torch.randn(3, 4)
         x_mps = x.to("mps")
@@ -7841,7 +7951,8 @@
         'vsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'vstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'zero_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8']
+        'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'nonzero': ['f32', 'i16', 'i32', 'i64']
     }
 
 
@@ -8066,6 +8177,8 @@
         'slice_scatter': [torch.uint8],
         'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8],  # moved from section below
 
+        # count_nonzero returns wrong results for these dtypes
+        'nonzero': [torch.uint8, torch.float16],
         # ALLOW_LIST doesn't know about variants
         'nn.functional.padconstant': None,
 
@@ -8141,7 +8254,6 @@
         'eq': None,
         'mul': None,
         'cartesian_prod': None,
-        'nonzero': None,
         'bool': None,
         'inner': None,
         'dstack': None,