[MPS] Add nonzero mps support (#91616)
Adds native `nonzero` support for the MPS backend:
**Pseudocode**:
```
//
// inputTensor = [1, 0, 0, 3]
// inputNonZero = [1, 0, 0, 1] (input != 0)
// scan = [1, 1, 1, 2] (prefix sum)
// maskedIndices = [0, -1, -1, 1] (select)
// coordinates = [0, 1, 2, 3] (coordinateAlongAxis)
// scatterResult = [0, 3] (scatter)
```
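
For illustration, here is a rough CPU sketch of the same algorithm in plain PyTorch ops (1-D case). It is only a sketch; the kernel itself builds the equivalent MPSGraph nodes and relies on the scatter skipping out-of-range indices:
```
import torch

x = torch.tensor([1, 0, 0, 3])
mask = (x != 0).long()                      # inputNonZero  = [1, 0, 0, 1]
scan = mask.cumsum(dim=0)                   # prefix sum    = [1, 1, 1, 2]
idx = torch.where(mask.bool(), scan - 1,    # maskedIndices = [0, -1, -1, 1]
                  torch.full_like(scan, -1))
coords = torch.arange(x.numel())            # coordinates   = [0, 1, 2, 3]
out = torch.empty(int(mask.sum()), dtype=torch.long)
keep = idx >= 0                             # emulate the scatter skipping
out[idx[keep]] = coords[keep]               # the negative sentinel indices
print(out)                                  # tensor([0, 3])
```
On macOS versions before 13.0 the op falls back to the CPU implementation with a one-time performance warning.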
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91616
Approved by: https://github.com/razarmehr
diff --git a/aten/src/ATen/mps/MPSFallback.mm b/aten/src/ATen/mps/MPSFallback.mm
index f1c0dbb..e5dfde1 100644
--- a/aten/src/ATen/mps/MPSFallback.mm
+++ b/aten/src/ATen/mps/MPSFallback.mm
@@ -61,7 +61,6 @@
m.impl("_fft_r2c", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("linalg_vector_norm", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("sgn.out", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
- m.impl("nonzero", torch::CppFunction::makeFromBoxedFunction<&mps_fallback>());
m.impl("_slow_conv2d_forward", slow_conv2d_forward_mps);
}
diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm
index 331114d..a97f96b 100644
--- a/aten/src/ATen/native/mps/operations/Indexing.mm
+++ b/aten/src/ATen/native/mps/operations/Indexing.mm
@@ -12,6 +12,7 @@
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/native/mps/OperationUtils.h>
#include <ATen/native/mps/operations/Indexing.h>
+#include <ATen/native/mps/MPSGraphVenturaOps.h>
#include <ATen/native/Resize.h>
#include <ATen/AccumulateType.h>
#include <torch/library.h>
@@ -211,6 +212,185 @@
return result;
}
+static Tensor nonzero_fallback(const Tensor& self) {
+ TORCH_WARN_ONCE("MPS: nonzero op is supported natively starting from macOS 13.0. ",
+ "Falling back on CPU. This may have performance implications.");
+
+ return at::nonzero(self.to("cpu")).clone().to("mps");
+}
+
+Tensor& nonzero_out_mps(const Tensor& self, Tensor& out_){
+ if (!is_macos_13_or_newer()) {
+ Tensor out_fallback = nonzero_fallback(self);
+ at::native::resize_output(out_, out_fallback.sizes());
+    out_.copy_(out_fallback);
+ return out_;
+ }
+
+ using namespace mps;
+  const int32_t maxDimensions = 16; // signed, since it is negated below for the scatter sentinel
+
+  TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(),
+              "nonzero is not supported for tensors with more than INT_MAX elements, file a support request");
+ TORCH_CHECK(out_.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out_.dtype());
+ TORCH_CHECK(self.device() == out_.device(), "expected self and out to be on the same device, but got out on ",
+ out_.device(), " and self on ", self.device());
+  TORCH_CHECK(self.dim() <= maxDimensions, "nonzero is not supported for tensors with more than ", maxDimensions, " dimensions");
+  TORCH_CHECK(out_.is_mps(), "Expected out tensor to be on the MPS device, but got ", out_.device());
+
+ MPSStream *stream = getCurrentMPSStream();
+ struct CachedGraph : public MPSCachedGraph
+ {
+ CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+ MPSGraphTensor* inputTensor_ = nil;
+ MPSGraphTensor* outputTensor_ = nil;
+ MPSGraphTensor* scatterDataTensor_ = nil;
+ };
+
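+  // nonzero has a data-dependent output shape (tagged dynamic_output_shape),
+  // so the number of nonzero elements is computed eagerly here to size the
+  // output before the graph runs.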
+ int64_t total_nonzero = at::count_nonzero(self).item<int64_t>();
+ int64_t nDim = self.dim();
+ at::native::resize_output(out_, {total_nonzero, nDim});
+ if (out_.numel() == 0) {
+ return out_;
+ }
+
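+  // The graph scatters straight into the output buffer, so a non-contiguous
+  // (or view) out tensor gets a contiguous scratch tensor that is copied back
+  // after the graph runs.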
+ bool contiguous_output = (out_.is_contiguous() && !out_.is_view());
+ Tensor out = out_;
+ if (!contiguous_output) {
+ out = at::native::empty_mps(
+ out_.sizes(),
+ out_.scalar_type(),
+ c10::nullopt,
+ kMPS,
+ c10::nullopt,
+ c10::nullopt);
+ }
+
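+  // The graph works on flattened views: a 1-D input of numel() elements and a
+  // 1-D output of total_nonzero * nDim elements.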
+  MPSShape *apparentInputShape = @[@(self.numel())];
+  MPSShape *apparentOutputShape = @[@(total_nonzero * nDim)];
+
+  // Pseudocode:
+  //
+  // inputTensor   = [1, 0, 0, 3]
+  // inputNonZero  = [1, 0, 0, 1]   (input != 0)
+  // indices       = [1, 1, 1, 2]   (prefix sum of the mask)
+  // maskedIndices = [0, -1, -1, 1] (indices - 1 where nonzero, negative sentinel elsewhere)
+  // coordinates   = [0, 1, 2, 3]   (coordinateAlongAxis)
+  // scatterResult = [0, 3]         (scatter coordinates at maskedIndices)
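+  //
+  // The actual sentinel is -maxDimensions rather than -1: after the
+  // multi-dimensional expansion below (index * nDim + dim, with dim < nDim <= 16),
+  // it is guaranteed to stay negative, so the scatter leaves those
+  // out-of-range slots untouched.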
+
+ MPSGraphCache* cache_ = MPSGraphCache::getInstance();
+ @autoreleasepool {
+ string key = "nonzero_out_mps" + getTensorsStringKey(self);
+ CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+
+ if(!cachedGraph) {
+ MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
+ CachedGraph *newCachedGraph = nil;
+ @autoreleasepool {
+ MPSDataType inputDataType = getMPSDataType(self.scalar_type());
+ MPSShape* inputShape = getMPSShape(self);
+ MPSGraph* mpsGraph = make_mps_graph();
+ newCachedGraph = new CachedGraph(mpsGraph);
+
+ MPSGraphTensor *inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, getMPSScalarType(self.scalar_type()), apparentInputShape);
+ MPSGraphTensor *scatterDataTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSScalarType(out.scalar_type()));
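+          // The scatter's data tensor is fed with the output buffer itself
+          // (see scatterPlaceholder below), so MPSGraphScatterModeSet only
+          // overwrites the slots addressed by valid indices.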
+ MPSGraphTensor *zeroTensor = [mpsGraph constantWithScalar:0.0 dataType:inputDataType];
+ MPSGraphTensor *oneTensor = [mpsGraph constantWithScalar:1.0 dataType:MPSDataTypeInt32];
+ MPSGraphTensor *minusMaxDimTensor = [mpsGraph constantWithScalar:-maxDimensions dataType:MPSDataTypeInt32];
+ MPSGraphTensor *inputNotEqualToZeroTensor = [mpsGraph notEqualWithPrimaryTensor:inputTensor
+ secondaryTensor:zeroTensor
+ name:nil];
+ MPSGraphTensor *maskTensor = [mpsGraph castTensor:inputNotEqualToZeroTensor
+ toType:MPSDataTypeInt32
+ name:@"castToInt32"];
+ MPSGraphTensor *indicesTensor = [mpsGraph cumulativeSumWithTensor:maskTensor
+ axis:0
+ name:nil];
+ MPSGraphTensor *indicesMinusOneTensor = [mpsGraph subtractionWithPrimaryTensor:indicesTensor
+ secondaryTensor:oneTensor
+ name:nil];
+ MPSGraphTensor *maskedIndicesTensor = [mpsGraph selectWithPredicateTensor:inputNotEqualToZeroTensor
+ truePredicateTensor:indicesMinusOneTensor
+ falsePredicateTensor:minusMaxDimTensor
+ name:nil];
+ MPSGraphTensor *coordinatesTensor = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:0 withShape:inputShape name:nil]
+ withShape:@[@-1]
+ name:nil];
+ if (nDim > 1) {
+ NSMutableArray<MPSGraphTensor*> *maskedIndicesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
+ NSMutableArray<MPSGraphTensor*> *coordinatesTensorArray = [NSMutableArray arrayWithCapacity:nDim];
+
+ MPSGraphTensor *constantRankTensor = [mpsGraph constantWithScalar:nDim
+ dataType:MPSDataTypeInt32];
+ maskedIndicesTensorArray[0] = [mpsGraph multiplicationWithPrimaryTensor:maskedIndicesTensor
+ secondaryTensor:constantRankTensor
+ name:nil];
+ coordinatesTensorArray[0] = coordinatesTensor;
+ for (int i = 1; i < nDim; i++){
+ maskedIndicesTensorArray[i] = [mpsGraph additionWithPrimaryTensor:maskedIndicesTensorArray[i - 1]
+ secondaryTensor:oneTensor
+ name:nil];
+ coordinatesTensorArray[i] = [mpsGraph reshapeTensor:[mpsGraph coordinateAlongAxis:i withShape:inputShape name:nil]
+ withShape:@[@-1]
+ name:nil];
+ }
+ maskedIndicesTensor = [mpsGraph concatTensors:maskedIndicesTensorArray dimension:0 interleave:YES name:nil];
+ coordinatesTensor = [mpsGraph concatTensors:coordinatesTensorArray dimension:0 interleave:YES name:nil];
+ }
+
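+          // For nDim > 1, the interleaved concats emit, for each input
+          // element, one (scatter index, coordinate) pair per dimension,
+          // matching the flattened row-major (total_nonzero, nDim) output
+          // that the scatter fills.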
+ MPSGraphTensor *outputTensor = [mpsGraph scatterWithDataTensor:scatterDataTensor
+ updatesTensor:coordinatesTensor
+ indicesTensor:maskedIndicesTensor
+ axis:0
+ mode:MPSGraphScatterModeSet
+ name:nil];
+
+ newCachedGraph->inputTensor_ = inputTensor;
+ newCachedGraph->scatterDataTensor_ = scatterDataTensor;
+ newCachedGraph->outputTensor_ = outputTensor;
+ }
+ return newCachedGraph;
+ });
+ cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+ }
+
+ Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, apparentInputShape);
+ Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, contiguous_output ? out_ : out, apparentOutputShape);
+ Placeholder scatterPlaceholder = Placeholder(cachedGraph->scatterDataTensor_, contiguous_output ? out_ : out, apparentOutputShape);
+
+ // Create dictionary of inputs and outputs
+ NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
+ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData(),
+ scatterPlaceholder.getMPSGraphTensor() : scatterPlaceholder.getMPSGraphTensorData()
+ };
+
+ NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
+ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
+ };
+
+ runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+ if (!contiguous_output) {
+ out_.copy_(out);
+ }
+ }
+
+ return out_;
+}
+
+Tensor nonzero_mps(const Tensor& self){
+ if (!is_macos_13_or_newer()) {
+ return nonzero_fallback(self);
+ }
+
+ Tensor out = at::empty({0}, self.options().dtype(kLong));
+ return nonzero_out_mps(self, out);
+}
+
Tensor masked_select_mps(const Tensor & self, const Tensor & mask) {
namedinference::compute_broadcast_outnames(self, mask);
Tensor result = at::empty({0}, self.options());
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 3e485d6..9b42c97 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -8521,6 +8521,7 @@
dispatch:
CPU: nonzero_out_cpu
CUDA: nonzero_out_cuda
+ MPS: nonzero_out_mps
tags: dynamic_output_shape
- func: nonzero(Tensor self) -> Tensor
@@ -8528,6 +8529,7 @@
dispatch:
CPU: nonzero_cpu
CUDA: nonzero_cuda
+ MPS: nonzero_mps
tags: [dynamic_output_shape, canonical]
- func: nonzero_numpy(Tensor self) -> Tensor[]
diff --git a/test/test_mps.py b/test/test_mps.py
index d3308d9..00dadbe 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -6687,6 +6687,116 @@
supported_dtypes = [torch.float32, torch.float16, torch.int64, torch.int32, torch.int16, torch.uint8]
supported_np_dtypes = [np.float32, np.float16, np.int64, np.int32, np.int16, np.uint8]
+ def test_nonzero_no_warning(self):
+ device = "mps"
+ t = torch.randn((2, 2), device=device)
+ with warnings.catch_warnings(record=True) as w:
+ warnings.simplefilter("always")
+ torch.nonzero(t)
+ t.nonzero()
+ self.assertEqual(len(w), 0)
+
+ def test_nonzero(self):
+ def helper(dtype):
+ device = "mps"
+ shapes = [
+ torch.Size((12,)),
+ torch.Size((12, 1)),
+ torch.Size((1, 12)),
+ torch.Size((6, 2)),
+ torch.Size((3, 2, 2)),
+ torch.Size((5, 5, 5)),
+ ]
+
+ def gen_nontrivial_input(shape, dtype, device):
+ if dtype != torch.bfloat16:
+ return torch.randint(2, shape, device=device, dtype=dtype)
+ else:
+                    # Windows does not support bfloat16 randint; generate as float and cast
+ return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)
+
+ for shape in shapes:
+ tensor = gen_nontrivial_input(shape, dtype, device)
+ dst1 = torch.nonzero(tensor, as_tuple=False)
+ dst2 = tensor.nonzero(as_tuple=False)
+ dst3 = torch.empty([], dtype=torch.long, device=device)
+ dst3 = dst3.resize_(0)
+ torch.nonzero(tensor, out=dst3)
+ np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy()
+ np_result = torch.from_numpy(np.stack(np_array.nonzero())).t()
+ self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0)
+ self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0)
+ self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0)
+ tup1 = torch.nonzero(tensor, as_tuple=True)
+ tup2 = tensor.nonzero(as_tuple=True)
+ tup1 = torch.stack(tup1).t().cpu()
+ tup2 = torch.stack(tup2).t().cpu()
+ self.assertEqual(tup1, np_result, atol=0, rtol=0)
+ self.assertEqual(tup2, np_result, atol=0, rtol=0)
+        for dtype in self.supported_dtypes:
+            helper(dtype)
+
+ def test_nonzero_astuple_out(self):
+ device = "mps"
+ t = torch.randn((3, 3, 3), device=device)
+ out = torch.empty([], dtype=torch.long, device=device)
+ out = out.resize_(0)
+
+ with self.assertRaises(RuntimeError):
+ torch.nonzero(t, as_tuple=True, out=out)
+
+ self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out))
+
+ # Verifies that JIT script cannot handle the as_tuple kwarg
+ # See Issue https://github.com/pytorch/pytorch/issues/45499.
+ def _foo(t):
+ tuple_result = torch.nonzero(t, as_tuple=True)
+ nontuple_result = torch.nonzero(t, as_tuple=False)
+ out = torch.empty_like(nontuple_result)
+ torch.nonzero(t, as_tuple=False, out=out)
+ return tuple_result, nontuple_result, out
+
+ with self.assertRaises(RuntimeError):
+ scripted_foo = torch.jit.script(_foo)
+
+ # Verifies that JIT tracing works fine
+ traced_foo = torch.jit.trace(_foo, t)
+ traced_tuple, traced_nontuple, traced_out = traced_foo(t)
+ expected_tuple = torch.nonzero(t, as_tuple=True)
+ expected_nontuple = torch.nonzero(t)
+
+ self.assertEqual(traced_tuple, expected_tuple)
+ self.assertEqual(traced_nontuple, expected_nontuple)
+ self.assertEqual(traced_out, expected_nontuple)
+
+ def test_nonzero_discontiguous(self):
+ device = "mps"
+ shape = (4, 4)
+ tensor = torch.randint(2, shape, device=device)
+ tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor)
+ dst1 = tensor.nonzero(as_tuple=False)
+ dst2 = tensor_nc.nonzero(as_tuple=False)
+ self.assertEqual(dst1, dst2, atol=0, rtol=0)
+ dst3 = torch.empty_like(dst1)
+ data_ptr = dst3.data_ptr()
+ # expect dst3 storage to be reused
+ torch.nonzero(tensor, out=dst3)
+ self.assertEqual(data_ptr, dst3.data_ptr())
+ self.assertEqual(dst1, dst3, atol=0, rtol=0)
+ # discontiguous out
+ dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2]
+ data_ptr = dst4.data_ptr()
+ strides = dst4.stride()
+ torch.nonzero(tensor, out=dst4)
+ self.assertEqual(data_ptr, dst4.data_ptr())
+ self.assertEqual(dst1, dst4, atol=0, rtol=0)
+ self.assertEqual(strides, dst4.stride())
+
+ def test_nonzero_non_diff(self):
+ device = "mps"
+ x = torch.randn(10, requires_grad=True)
+ nz = x.nonzero()
+ self.assertFalse(nz.requires_grad)
+
def test_masked_select(self):
x = torch.randn(3, 4)
x_mps = x.to("mps")
@@ -7841,7 +7951,8 @@
'vsplit': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'vstack': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'zero_': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
- 'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8']
+ 'where': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
+ 'nonzero': ['f32', 'i16', 'i32', 'i64']
}
@@ -8066,6 +8177,8 @@
'slice_scatter': [torch.uint8],
'square': [torch.bool, torch.int16, torch.int32, torch.int64, torch.uint8], # moved from section below
+ # count_nonzero returns wrong results for these dtypes
+ 'nonzero': [torch.uint8, torch.float16],
# ALLOW_LIST doesn't know about variants
'nn.functional.padconstant': None,
@@ -8141,7 +8254,6 @@
'eq': None,
'mul': None,
'cartesian_prod': None,
- 'nonzero': None,
'bool': None,
'inner': None,
'dstack': None,