[MPS] Support for median with dim (#88807)
## Summary ⚡
**Aim**: Add support for `aten::median` on the MPS backend (Fixes #87220)
This is a fresh, clean PR superseding the previous [PR](https://github.com/pytorch/pytorch/pull/88554).
- Implement the new median functions in `aten/src/ATen/native/mps/operations/ReduceOps.mm`
- Register the MPS dispatch entries in `aten/src/ATen/native/native_functions.yaml`
- Add a `test_median` case to `test/test_mps.py`
### **How it works** 🪶
Median of the entire input tensor on MPS:
`torch.median(mps_inputTensor)`
Median along a given dim:
`torch.median(mps_inputTensor, dim=[int], keepdim=[Bool])`
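A minimal sketch of both call forms (tensor names are illustrative):

```python
import torch

x = torch.randn(3, 4, 5, device="mps")

# Median of the flattened tensor -> a 0-dim tensor
m = torch.median(x)

# Median along dim 1; keepdim=True retains the reduced dim with size 1
values, indices = torch.median(x, dim=1, keepdim=True)
```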
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88807
Approved by: https://github.com/kulinseth
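For reference, the kernel follows a sort-and-slice strategy: sort along the reduction axis, then slice out the lower-middle element. A rough CPU-side Python sketch of the same computation (the helper name is illustrative, not part of this PR):

```python
import torch

def median_lower(x: torch.Tensor, dim: int):
    # For an even count, PyTorch's median is the lower of the two middle values.
    n = x.size(dim)
    sorted_vals, sorted_idx = x.sort(dim=dim)
    mid = (n + 1) // 2 - 1
    # narrow keeps `dim` with size 1, matching keepdim=True
    return sorted_vals.narrow(dim, mid, 1), sorted_idx.narrow(dim, mid, 1)
```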
diff --git a/aten/src/ATen/native/mps/MPSGraphVenturaOps.h b/aten/src/ATen/native/mps/MPSGraphVenturaOps.h
index 86153b5..b77db66 100644
--- a/aten/src/ATen/native/mps/MPSGraphVenturaOps.h
+++ b/aten/src/ATen/native/mps/MPSGraphVenturaOps.h
@@ -6,4 +6,12 @@
- (MPSGraphTensor *)cumulativeSumWithTensor:(MPSGraphTensor *)tensor
axis:(NSInteger)axis
name:(NSString *)name;
+
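+// These sort/argSort methods are available in MPSGraph starting with macOS 13.0
+// (Ventura); callers fall back to CPU on older systems.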
+- (MPSGraphTensor *)sortWithTensor:(MPSGraphTensor *)tensor
+ axis:(NSInteger)axis
+ name:(NSString *)name;
+
+- (MPSGraphTensor *)argSortWithTensor:(MPSGraphTensor *)tensor
+ axis:(NSInteger)axis
+ name:(NSString *)name;
@end
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 91aa245..c99f22d 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -9,6 +9,7 @@
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/Pool.h>
#include <torch/library.h>
+#include <ATen/native/mps/MPSGraphVenturaOps.h>
namespace at {
namespace native {
@@ -1638,5 +1639,319 @@
return min_max_mps(input_t, dim, keepdim, MPSReductionType::MIN, "min_mps");
}
+// Median of entire tensor into scalar result
+Tensor median_mps(const Tensor& input_t) {
+
+  if (!is_macos_13_or_newer()) {
+    TORCH_WARN_ONCE("MPS: median op is supported natively starting from macOS 13.0. ",
+                    "Falling back on CPU. This may have performance implications.");
+    return at::median(input_t.to("cpu"));
+  }
+
+  TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "median not supported for Long dtype on MPS");
+
+ namespace native_mps = at::native::mps;
+ using CachedGraph = native_mps::MPSUnaryCachedGraph;
+
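+  // Graphs are cached and looked up by a string key (op name, dtype, shape)
+  // to avoid rebuilding them on every call.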
+ native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
+
+  // Total number of elements in the input; the graph below flattens it to 1-D.
+  int64_t num_in_elements = input_t.numel();
+
+ Tensor output_t = at::native::empty_mps({}, input_t.scalar_type(), c10::nullopt, kMPS, c10::nullopt, c10::nullopt);
+
+ if (output_t.numel() == 0 || num_in_elements == 0) {
+ return output_t;
+ }
+
+ @autoreleasepool {
+    string key = "median_mps:" + mps::getMPSTypeString(input_t.scalar_type()) + mps::getTensorsStringKey(input_t);
+ CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
+ // Initialize once if configuration not found in cache
+ if(!cachedGraph) {
+ native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
+
+ CachedGraph *newCachedGraph = nil;
+
+ @autoreleasepool {
+ MPSGraph* mpsGraph = native_mps::make_mps_graph();
+ newCachedGraph = new CachedGraph(mpsGraph);
+
+ MPSGraphTensor* inputTensor = native_mps::mpsGraphRankedPlaceHolder(mpsGraph, input_t);
+
+          // Flatten the input to 1-D, sort it, and slice out the median element.
+          MPSGraphTensor* reshapedTensor = [mpsGraph reshapeTensor:inputTensor
+                                                         withShape:@[@-1]
+                                                              name:nil];
+          MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:reshapedTensor
+                                                             axis:0
+                                                             name:nil];
+          // For an even number of elements, PyTorch's median is the lower of the
+          // two middle values, i.e. index (n + 1) / 2 - 1 of the sorted data.
+          MPSGraphTensor* outputTensor = [mpsGraph sliceTensor:sortedTensor
+                                                     dimension:0
+                                                         start:((num_in_elements + 1) / 2 - 1)
+                                                        length:1
+                                                          name:nil];
+
+ newCachedGraph->inputTensor_ = inputTensor;
+ newCachedGraph->outputTensor_ = outputTensor;
+ }
+ return newCachedGraph;
+ });
+ cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+ }
+
+ auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
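+    // The graph yields a length-1 tensor; bind the 0-dim output with an explicit [1] shape.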
+ auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, @[@1]);
+
+ NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
+ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
+ };
+
+ NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
+ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
+ };
+
+ native_mps::runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
+ }
+
+ return output_t;
+}
+
+// Compute the median values and indices along `dim`, writing into the
+// preallocated output_t / indices_t tensors.
+void median_out_mps(const Tensor& input_t,
+                    int64_t dim,
+                    bool keepdim,
+                    const Tensor& output_t,
+                    const Tensor& indices_t,
+                    const std::string& func_name) {
+
+ namespace native_mps = at::native::mps;
+
+ if (output_t.numel() == 0) {
+ return;
+ }
+ if (input_t.numel() == 1 && input_t.dim() == 0) {
+ output_t.fill_(input_t);
+ indices_t.fill_(0);
+ return;
+ }
+
+ // Derive from MPSCachedGraph
+ struct CachedGraph : public native_mps::MPSCachedGraph
+ {
+ CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+ MPSGraphTensor *inputTensor_ = nil;
+ MPSGraphTensor *outputTensor_ = nil;
+ MPSGraphTensor *indicesTensor_ = nil;
+ };
+
+ native_mps::MPSGraphCache* cache_ = native_mps::MPSGraphCache::getInstance();
+
+ int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
+
+  // Output shape with the reduced dim kept as size 1 (keepdim=True view),
+  // used to bind the output placeholders below.
+  IntArrayRef input_shape = input_t.sizes();
+  int64_t num_input_dims = input_shape.size();
+  NSMutableArray<NSNumber*>* apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
+  for (int i = 0; i < num_input_dims; i++) {
+    if (dim_ == i)
+      apparent_out_shape[i] = @1;
+    else
+      apparent_out_shape[i] = [NSNumber numberWithInteger:input_shape[i]];
+  }
+  int64_t dim_total_elements = input_shape[dim_];
+
+ auto stream = at::mps::getCurrentMPSStream();
+
+ @autoreleasepool {
+    // The slice position baked into the graph depends on the input size along
+    // `dim`, so the cache key must include the input shape as well.
+    string key = func_name + ":" + to_string(dim_) + ":" + native_mps::getTensorsStringKey(input_t);
+ CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+
+ if(!cachedGraph) {
+ native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
+
+ CachedGraph *newCachedGraph = nil;
+
+ @autoreleasepool {
+ MPSGraph* mpsGraph = native_mps::make_mps_graph();
+ newCachedGraph = new CachedGraph(mpsGraph);
+
+ MPSGraphTensor* inputTensor = native_mps::mpsGraphUnrankedPlaceHolder(mpsGraph, native_mps::getMPSDataType(input_t.scalar_type()));
+          // Sort along the reduction dim, then slice the lower-middle element
+          // (PyTorch's median for an even count is the lower of the two middle values).
+          MPSGraphTensor* sortedTensor = [mpsGraph sortWithTensor:inputTensor
+                                                             axis:(NSInteger)dim_
+                                                             name:nil];
+          MPSGraphTensor* outputTensor = [mpsGraph sliceTensor:sortedTensor
+                                                     dimension:dim_
+                                                         start:((dim_total_elements + 1) / 2 - 1)
+                                                        length:1
+                                                          name:nil];
+          // argSort over the same dim yields the indices of the sorted order;
+          // slicing at the same position gives the median's index in the input.
+          MPSGraphTensor* argSortedTensor = [mpsGraph argSortWithTensor:inputTensor
+                                                                   axis:(NSInteger)dim_
+                                                                   name:@"argsort_out"];
+          MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argSortedTensor
+                                                        dimension:dim_
+                                                            start:((dim_total_elements + 1) / 2 - 1)
+                                                           length:1
+                                                             name:nil];
+
+ newCachedGraph->inputTensor_ = inputTensor;
+ newCachedGraph->outputTensor_ = outputTensor;
+ newCachedGraph->indicesTensor_ = argOutputTensor;
+ }
+ return newCachedGraph;
+ });
+ cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+ }
+
+ auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, input_t);
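+    // output_t / indices_t may have keepdim=false shapes; they are bound with
+    // the keepdim=true (apparent) shape, which has the same number of elements.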
+ auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output_t, apparent_out_shape);
+ auto indicesPlaceholder = native_mps::Placeholder(cachedGraph->indicesTensor_, indices_t, apparent_out_shape);
+
+ NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
+ inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
+ };
+
+ NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
+ outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData(),
+ indicesPlaceholder.getMPSGraphTensor() : indicesPlaceholder.getMPSGraphTensorData()
+ };
+
+ native_mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+
+ }
+
+}
+
+// CPU fallback for macOS versions where MPSGraph sortWithTensor is not available.
+std::tuple<Tensor&, Tensor&> median_from_cpu(
+    const Tensor& self,
+    int64_t dim,
+    bool keepdim,
+    Tensor& valuesI,
+    Tensor& indicesI,
+    IntArrayRef vec_out_shape,
+    IntArrayRef vec_apparent_out_shape) {
+  Tensor values;
+  Tensor indices;
+  if (!keepdim) {
+    values = at::empty(vec_out_shape, self.options());
+    indices = at::empty(vec_out_shape, self.options().dtype(kLong));
+  } else {
+    values = at::empty(vec_apparent_out_shape, self.options());
+    indices = at::empty(vec_apparent_out_shape, self.options().dtype(kLong));
+  }
+ at::median_out(values, indices, self, dim, keepdim);
+
+ valuesI.copy_(values);
+ indicesI.copy_(indices);
+ return std::forward_as_tuple(valuesI, indicesI);
+}
+
+TORCH_API std::tuple<Tensor&, Tensor&> median_out_mps(
+    const Tensor& input_t,
+    int64_t dim,
+    bool keepdim,
+    Tensor& values,
+    Tensor& indices) {
+
+  TORCH_CHECK(input_t.scalar_type() != ScalarType::Long, "median not supported for Long dtype on MPS");
+
+ namespace native_mps = at::native::mps;
+ int64_t dim_ = maybe_wrap_dim(dim, input_t.dim());
+  native::zero_numel_check_dims(input_t, dim_, "median()");
+
+  // Calculate the output shapes for both keepdim modes.
+ IntArrayRef input_shape = input_t.sizes();
+ int64_t num_input_dims = input_shape.size();
+ NSMutableArray<NSNumber*> *apparent_out_shape = nil;
+ // Use this if keepdim is false
+ int64_t num_output_dims = num_input_dims - 1;
+
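+  // vec_out_shape drops the reduced dim (keepdim=false);
+  // vec_apparent_out_shape keeps it with size 1 (keepdim=true).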
+ std::vector<int64_t> vec_apparent_out_shape(num_input_dims);
+ std::vector<int64_t> vec_out_shape(num_output_dims);
+
+ apparent_out_shape = [NSMutableArray<NSNumber*> arrayWithCapacity:num_input_dims];
+ // Counter for shape when keepdim is false
+ int out_i = 0;
+ for(int i = 0; i < num_input_dims; i++) {
+ if(dim_ == i) {
+ apparent_out_shape[i] = @1;
+ vec_apparent_out_shape[i] = 1;
+ }
+ else {
+      apparent_out_shape[i] = [NSNumber numberWithInteger:input_shape[i]];
+ vec_apparent_out_shape[i] = input_shape[i];
+ vec_out_shape[out_i] = input_shape[i];
+ out_i++;
+ }
+ }
+
+ if(!keepdim) {
+ values = at::native::empty_mps(
+ IntArrayRef(vec_out_shape),
+ input_t.scalar_type(),
+ c10::nullopt,
+ kMPS,
+ c10::nullopt,
+ c10::nullopt);
+ indices = at::native::empty_mps(
+ IntArrayRef(vec_out_shape),
+ ScalarType::Long,
+ c10::nullopt,
+ kMPS,
+ c10::nullopt,
+ c10::nullopt);
+ } else {
+ values = at::native::empty_mps(
+ IntArrayRef(vec_apparent_out_shape),
+ input_t.scalar_type(),
+ c10::nullopt,
+ kMPS,
+ c10::nullopt,
+ c10::nullopt);
+ indices = at::native::empty_mps(
+ IntArrayRef(vec_apparent_out_shape),
+ ScalarType::Long,
+ c10::nullopt,
+ kMPS,
+ c10::nullopt,
+ c10::nullopt);
+ }
+
+ if (values.numel() == 0 || input_t.numel() == 0) {
+ return std::tuple<Tensor&, Tensor&>{values, indices};
+ }
+
+  if (!is_macos_13_or_newer()) {
+    TORCH_WARN_ONCE("MPS: median op is supported natively starting from macOS 13.0. ",
+                    "Falling back on CPU. This may have performance implications.");
+    return median_from_cpu(input_t.to("cpu"), dim, keepdim, values, indices,
+                           IntArrayRef(vec_out_shape), IntArrayRef(vec_apparent_out_shape));
+  }
+
+ median_out_mps(input_t, dim, keepdim, values, indices, "median_out_mps");
+
+ return std::tuple<Tensor&, Tensor&>{values, indices};
+}
+
} // native
} // at
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 8046b4f..b1d1094 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -3492,6 +3492,7 @@
dispatch:
CPU: median_cpu
CUDA: median_cuda
+ MPS: median_mps
autogen: median.out
- func: median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
@@ -3503,6 +3504,7 @@
dispatch:
CPU: median_out_cpu
CUDA: median_out_cuda
+ MPS: median_out_mps
- func: median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)
variants: function, method
diff --git a/test/test_mps.py b/test/test_mps.py
index 31e2e36..52d6695 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2624,6 +2624,47 @@
helper(2, 8, 4, 5, torch.int32)
# helper(2, 8, 4, 5, torch.int64)
+    def test_median(self):
+        def helper(n1, n2, n3, dtype):
+            if dtype.is_floating_point:
+                cpu_x = torch.randn(n1, n2, n3, device='cpu', dtype=dtype)
+            else:
+                cpu_x = torch.randint(50, (n1, n2, n3), device='cpu', dtype=dtype)
+            mps_x = cpu_x.detach().clone().to('mps')
+
+            result_cpu = torch.median(cpu_x)
+            result_mps = torch.median(mps_x)
+
+            self.assertEqual(result_cpu, result_mps)
+
+            for dim in [0, 1, 2]:
+                for keepdim in [True, False]:
+                    refy, refidx = torch.median(cpu_x, dim=dim, keepdim=keepdim)
+                    y, idx = torch.median(mps_x, dim=dim, keepdim=keepdim)
+                    self.assertEqual(y, refy)
+                    self.assertEqual(idx, refidx)
+
+        helper(10, 10, 10, torch.int32)  # even number of elements along each dim
+        helper(3, 3, 3, torch.int32)  # odd number of elements along each dim
+        helper(1, 1, 1, torch.int32)
+        helper(1, 2, 3, torch.int32)
+        helper(10, 10, 10, torch.float32)
+        helper(3, 3, 3, torch.float32)
+        helper(1, 1, 1, torch.float32)
+
def test_any(self):
def helper(shape):
input_xs = []