[MPS] Fix and unblock TestConsistency for median (#94489)
- fix num_output_dims calculation
- fix median_out_mps key
- cast tensor sent to sortWithTensor and argSortWithTensor
- note down same issue for unique
- unblock median from blocklist
- add test_median_int16 test
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94489
Approved by: https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 88df3af..6f3b8d7 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -1751,11 +1751,21 @@
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
-
auto inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
auto reshapedTensor = [mpsGraph reshapeTensor: inputTensor
withShape: @[@-1]
name: nil];
+ MPSDataType dataType = [inputTensor dataType];
+ // #issue 104398441 sortWithTensor only supports following types, cast if necessary
+ if (dataType != MPSDataTypeInt32 &&
+ dataType != MPSDataTypeFloat32 &&
+ dataType != MPSDataTypeFloat16) {
+ dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+ reshapedTensor = [mpsGraph castTensor:reshapedTensor
+ toType:dataType
+ name:@"castReshapedTensor"];
+ }
+
auto sortedTensor = [mpsGraph sortWithTensor: reshapedTensor
axis: ((NSUInteger) (int)0)
name: nil];
@@ -1835,7 +1845,7 @@
auto stream = at::mps::getCurrentMPSStream();
@autoreleasepool {
- string key = func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t);
+ string key = func_name + ":" + to_string(dim_) + ":" + getTensorsStringKey(input_t) + ":" + getTensorsStringKey(indices_t);
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
if (!cachedGraph) {
@@ -1847,24 +1857,39 @@
auto mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
- MPSGraphTensor* inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_t);
- auto sortedTensor = [mpsGraph sortWithTensor: inputTensor
- axis: (NSUInteger)dim_
- name: nil];
- const NSUInteger midpoint = (dim_total_elements + 1) / 2 - 1;
- auto outputTensor = [mpsGraph sliceTensor:sortedTensor
- dimension:dim_
- start:midpoint
- length:1
- name:nil];
- auto argreduceOutTensor = [mpsGraph argSortWithTensor:inputTensor
- axis:(NSInteger)dim_
- name:@"argmax_out"];
- auto argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor
- dimension:dim_
- start:midpoint
- length:1
- name:nil];
+ MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()));
+ MPSGraphTensor* outputTensor = nil;
+ MPSGraphTensor* castInputTensor = inputTensor;
+ MPSDataType dataType = getMPSDataType(input_t.scalar_type());
+ // #issue 104398441 sortWithTensor only supports following types, cast if necessary
+ if (dataType != MPSDataTypeInt32 &&
+ dataType != MPSDataTypeFloat32 &&
+ dataType != MPSDataTypeFloat16) {
+ dataType = (dataType & MPSDataTypeFloatBit) ? MPSDataTypeFloat32 : MPSDataTypeInt32;
+ castInputTensor = [mpsGraph castTensor:inputTensor
+ toType:dataType
+ name:@"castInputTensor"];
+ }
+
+ MPSGraphTensor * sortedTensor = [mpsGraph
+ sortWithTensor:castInputTensor
+ axis:((NSUInteger) (int)dim_)
+ name:nil];
+
+ outputTensor = [mpsGraph sliceTensor:sortedTensor
+ dimension:dim_
+ start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
+ length:1
+ name:nil];
+ MPSGraphTensor* argreduceOutTensor = nil;
+ argreduceOutTensor = [mpsGraph argSortWithTensor:castInputTensor
+ axis:(NSInteger)dim_
+ name:@"argmax_out"];
+ MPSGraphTensor* argOutputTensor = [mpsGraph sliceTensor:argreduceOutTensor
+ dimension:dim_
+ start:((NSUInteger) (int)((dim_total_elements+1)/2 ) - 1)
+ length:1
+ name:nil];
newCachedGraph->inputTensor_ = inputTensor;
newCachedGraph->outputTensor_ = outputTensor;
@@ -1934,7 +1959,7 @@
int64_t num_input_dims = input_shape.size();
NSMutableArray<NSNumber*> *apparent_out_shape = nil;
// Use this if keepdim is false
- int64_t num_output_dims = num_input_dims - 1;
+ int64_t num_output_dims = num_input_dims - 1 < 0 ? 0 : num_input_dims - 1;
std::vector<int64_t> vec_apparent_out_shape(num_input_dims);
std::vector<int64_t> vec_out_shape(num_output_dims);
diff --git a/aten/src/ATen/native/mps/operations/Unique.mm b/aten/src/ATen/native/mps/operations/Unique.mm
index 109244b..eac16a7 100644
--- a/aten/src/ATen/native/mps/operations/Unique.mm
+++ b/aten/src/ATen/native/mps/operations/Unique.mm
@@ -57,7 +57,7 @@
return {resultTensor, inverseIndicesTensor, countTensor, lengthTensor};
}
- // Sort only supports following types, cast if necessary
+ // #issue 104398441 sortWithTensor only supports following types, cast if necessary
if (dataType != MPSDataTypeInt32 &&
dataType != MPSDataTypeFloat32 &&
dataType != MPSDataTypeFloat16) {
diff --git a/test/test_mps.py b/test/test_mps.py
index 9002a0a..3cd98df 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2325,6 +2325,17 @@
helper(dtype, noncontiguous, dim)
+ def test_median_int16(self):
+ def helper(shape, dtype):
+ cpu_x = torch.randint(-9999, 9999, shape, device='cpu', dtype=dtype)
+ x = cpu_x.detach().clone().to('mps')
+
+ median_result = torch.median(x)
+ median_result_cpu = torch.median(cpu_x)
+ self.assertEqual(median_result, median_result_cpu)
+
+ helper((2, 8, 4, 5), torch.int16)
+
class TestLogical(TestCase):
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)