[MPS] Fix the ChannelsLast memory format in cat_out_mps() (#91786)
- Fixed a memory leak: the cached graph's `malloc()`'d array of input graph tensors was never freed; it is now a `std::vector` (sketched below)
- Introduced optional shortened data type strings (e.g., `f32` instead of `Float32`) to avoid extra-long cached graph string keys with ops such as cat_out() (illustrated below)
- Fixed data type issues on macOS Monterey (Bool tensors are handled as Int8 on macOS versions before 13.0)
- Removed the unused `use_scalar_value` argument from `getTensorsStringKey()`
- General cleanup and refactoring
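
For illustration, here is a minimal standalone sketch of the key-shortening effect. The `FakeTensor` struct and `tensors_string_key()` helper below are hypothetical stand-ins for ATen tensors and `getTensorsStringKey()`; real keys also encode the concat dimension and memory format:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for an ATen tensor: a dtype name pair plus a shape.
struct FakeTensor {
  std::string long_dtype;
  std::string short_dtype;
  std::string shape;
};

// Mirrors the ":<dtype>[<shape>]" fragments that getTensorsStringKey() emits.
std::string tensors_string_key(const std::vector<FakeTensor>& tensors, bool short_dtype) {
  std::string key;
  for (const auto& t : tensors) {
    key += ":" + (short_dtype ? t.short_dtype : t.long_dtype) + "[" + t.shape + "]";
  }
  return key;
}

int main() {
  // cat_out() repeats the dtype once per input, so long names inflate the key:
  std::vector<FakeTensor> inputs(8, {"Float32", "f32", "1,1,1,10"});
  std::cout << "cat_out_mps:1" << tensors_string_key(inputs, /*short_dtype=*/false) << "\n";
  std::cout << "cat_out_mps:1" << tensors_string_key(inputs, /*short_dtype=*/true) << "\n";
}
```

With eight Float32 inputs the short form saves 4 characters per tensor (32 total here), and the savings grow with the input count.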
Fixes #89353
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91786
Approved by: https://github.com/kulinseth
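
The memory-leak fix amounts to replacing a raw `malloc()`'d pointer array, owned by a cache entry that never freed it, with a `std::vector` member that releases its storage automatically. A simplified sketch of the before/after pattern (the `LeakyCachedGraph`/`FixedCachedGraph` names are hypothetical; the real member is `inputTensors_` in `Shape.mm`'s `CachedGraph`):

```cpp
#include <cstdlib>
#include <vector>

// Before: the cached graph held a raw malloc()'d pointer array. Cache entries
// live for the process lifetime, and nothing ever called free() on the array.
struct LeakyCachedGraph {
  void** inputTensors_ = nullptr;  // stands in for MPSGraphTensor**
  void allocate(std::size_t n) {
    inputTensors_ = static_cast<void**>(std::malloc(n * sizeof(void*)));
  }
  // no matching free() anywhere -> one leaked allocation per cached graph
};

// After: a std::vector member owns the storage and releases it when the
// cached graph is destroyed. Note it must be resize()'d (not merely
// reserve()'d) before elements are assigned through operator[].
struct FixedCachedGraph {
  std::vector<void*> inputTensors_;
  void allocate(std::size_t n) { inputTensors_.resize(n); }
};

int main() {
  FixedCachedGraph g;
  g.allocate(4);                 // storage is owned by the vector
  g.inputTensors_[0] = nullptr;  // safe after resize()
}                                // freed automatically here
```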
diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 5cba91d..e6a6ef4 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -42,12 +42,12 @@
MPSDataType getMPSDataType(ScalarType scalar_type);
MPSDataType getMPSScalarType(ScalarType scalar_type);
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
-std::string getMPSTypeString(ScalarType scalar_type);
+std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
-std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value = false);
+std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = false);
std::string getArrayRefString(const IntArrayRef s);
// use has_storage() on the returned tensor to determine if src actually is a view
Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index ebcdacd..1fe3f29 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -63,25 +63,26 @@
}
}
-std::string getMPSTypeString(ScalarType scalar_type) {
+// use short_name to avoid extra-long cached graph keys with ops such as cat_out(), etc.
+std::string getMPSTypeString(ScalarType scalar_type, bool short_name) {
switch (scalar_type) {
case ScalarType::Double:
case ScalarType::Float:
- return "Float32";
+ return short_name ? "f32" : "Float32";
case ScalarType::Half:
- return "Float16";
+ return short_name ? "f16" : "Float16";
case ScalarType::Int:
- return "Int32";
+ return short_name ? "i32" : "Int32";
case ScalarType::Long:
- return "Int64";
+ return short_name ? "i64" : "Int64";
case ScalarType::Short:
- return "Int16";
+ return short_name ? "i16" : "Int16";
case ScalarType::Char:
- return "Int8";
+ return short_name ? "i8" : "Int8";
case ScalarType::Byte:
- return "UInt8";
+ return short_name ? "u8" : "UInt8";
case ScalarType::Bool:
- return "Bool";
+ return short_name ? "b8" : "Bool";
default:
return "Undefined";
}
@@ -149,16 +150,16 @@
return ss.str();
}
-std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value) {
+std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype) {
std::string str;
// The key format per tensor would look like ":Float32[1,1,1,10]:"
for (const Tensor& tensor: tensors) {
str += ":";
if (tensor.defined()) {
- str += getMPSTypeString(tensor.scalar_type()) + "[";
+ str += getMPSTypeString(tensor.scalar_type(), short_dtype) + "[";
// if tensor is a scalar
if (tensor.dim() == 0) {
- str += (use_scalar_value ? std::to_string(tensor.item().to<double>()) : "Scalar");
+ str += "Scalar";
} else {
const NSString* ns_shape_key = [[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","];
str += std::string(ns_shape_key.UTF8String);
diff --git a/aten/src/ATen/native/mps/operations/PointwiseOps.mm b/aten/src/ATen/native/mps/operations/PointwiseOps.mm
index 8da6b94..c10b656 100644
--- a/aten/src/ATen/native/mps/operations/PointwiseOps.mm
+++ b/aten/src/ATen/native/mps/operations/PointwiseOps.mm
@@ -34,7 +34,7 @@
MPSGraphCache* cache_ = MPSGraphCache::getInstance();
@autoreleasepool {
- string key = op_name + getTensorsStringKey({self, tensor1, tensor2}, false);
+ string key = op_name + getTensorsStringKey({self, tensor1, tensor2});
CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm
index 3190991..37a65c7 100644
--- a/aten/src/ATen/native/mps/operations/Shape.mm
+++ b/aten/src/ATen/native/mps/operations/Shape.mm
@@ -182,32 +182,6 @@
}
}
-inline c10::MemoryFormat compute_output_memory_format(const TensorList &inputs) {
- c10::optional<c10::MemoryFormat> format = c10::nullopt;
- for (auto &t : inputs) {
- auto f = t.suggest_memory_format();
- if (!format.has_value()) {
- format = f;
- continue;
- }
- if (format.value() == f) {
- continue;
- }
- bool contiguous = (format.value() == c10::MemoryFormat::Contiguous || f == c10::MemoryFormat::Contiguous || format.value() != f);
- if (contiguous) {
- return c10::MemoryFormat::Contiguous;
- }
- }
- return format.value();
-}
-
-//Tensor cat_mps(TensorList inputs, int64_t dimension) {
- //ScalarType high_type = result_type(inputs);
- //Tensor out = at::empty({0}, inputs.front().options().dtype(high_type));
- //at::native::cat_out_mps(inputs, dimension, out);
- //return out;
-//}
-
TORCH_IMPL_FUNC(cat_out_mps)
(const ITensorListRef& inputs,
int64_t dimension,
@@ -217,21 +191,30 @@
bool all_same_sizes_and_stride,
MemoryFormat memory_format,
const Tensor& out) {
+
using namespace mps;
if (out.numel() == 0) {
return;
}
-
auto materialized_inputs = inputs.materialize();
+ auto out_dtype = at::native::result_type(inputs);
int idx = 0;
- for(const Tensor& t : materialized_inputs) {
- TORCH_CHECK(t.dim() > 0,
- "zero-dimensional tensor (at position ", idx, ") cannot be concatenated");
+ for (const Tensor& t : materialized_inputs) {
+ TORCH_CHECK(t.dim() > 0, "zero-dimensional tensor (at position ", idx, ") cannot be concatenated");
+ auto lap = at::get_overlap_status(out, t);
+ TORCH_CHECK(lap != at::MemOverlapStatus::Partial && lap != at::MemOverlapStatus::Full,
+ "torch.cat(): unsupported operation: the input tensors cannot refer to any "
+ "of the output memory locations. Found overlap in input tensor ", idx);
idx++;
}
+ // Check for type promotion
+ TORCH_CHECK(canCast(out_dtype, out.scalar_type()),
+ "torch.cat(): input types can't be cast to the desired output type ", out.scalar_type());
+ TORCH_CHECK(inputs.size() > 0, "torch.cat(): invalid number of inputs ", inputs.size());
dimension = legacy_cat_wrap_dim(dimension, materialized_inputs);
+ TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);
// previously, size [0] tensors were the only possible empty tensors; thus, it
// wasn't possible to cat empty tensors unless all the other tensors were
@@ -242,214 +225,154 @@
auto should_skip = [](const Tensor& t) {
return t.dim() == 1 && at::native::size(t, 0) == 0;
};
-
- const Tensor* notSkippedTensor = NULL; // non-owning reference
-
- // Check for type promotion
- TORCH_CHECK(
- canCast(result_type(inputs), out.scalar_type()),
- "torch.cat(): input types ",
- " can't be cast to the desired output type ",
- out.scalar_type());
-
- // Inputs cannot alias the output tensor
- idx = 0;
- for(const Tensor& t : materialized_inputs) {
- auto lap = at::get_overlap_status(out, t);
- TORCH_CHECK(
- lap != at::MemOverlapStatus::Partial &&
- lap != at::MemOverlapStatus::Full,
- "torch.cat(): unsupported operation: the input tensors cannot refer to any "
- "of the output memory locations. Found overlap in input "
- "tensor ",
- idx);
- idx++;
- }
at::assert_no_internal_overlap(out);
+ Tensor notSkippedTensor;
// Indices of tensors to be skipped because they're empty
std::vector<int64_t> skipped_tensor_indices;
// Tensors to be read
- std::vector<const Tensor*> input_tensors;
+ std::vector<Tensor> input_tensors;
int tensor_idx = 0;
- for(const Tensor& t : materialized_inputs) {
- if(t.numel() == 0 || should_skip(t)) {
+ for (const Tensor& t : materialized_inputs) {
+ if (t.numel() == 0 || should_skip(t)) {
skipped_tensor_indices.push_back(tensor_idx);
tensor_idx++;
continue;
}
- input_tensors.push_back(&t);
+ input_tensors.push_back(t);
// TODO: Is this OK?
- notSkippedTensor = &t;
+ notSkippedTensor = t;
tensor_idx++;
}
-
// If all inputs are empty tensors, return an empty tensor
- if (notSkippedTensor == NULL) {
+ if (!notSkippedTensor.defined()) {
return;
}
-
- TORCH_CHECK(
- inputs.size() > 0,
- "torch.cat(): invalid number of inputs ",
- inputs.size());
- TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);
-
for (const Tensor& t : inputs) {
- TORCH_CHECK(
- t.device() == notSkippedTensor->device(),
- "torch.cat(): all input tensors must be on the same device. Received ",
- t.device(),
- " and ",
- notSkippedTensor->device());
+ TORCH_CHECK(t.device() == notSkippedTensor.device(),
+ "torch.cat(): all input tensors must be on the same device. Received ",
+ t.device(), " and ", notSkippedTensor.device());
}
+ TORCH_CHECK(out.device() == notSkippedTensor.device(),
+ "torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
+ notSkippedTensor.device(), " and out is on ", out.device());
- TORCH_CHECK(
- out.device() == notSkippedTensor->device(),
- "torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
- notSkippedTensor->device(),
- " and out is on ",
- out.device());
-
- // TODO: memory_format is now an argument?
- // // TODO: Factor out `compute_output_memory_format`
- // c10::MemoryFormat memory_format = compute_output_memory_format(inputs);
-
- std::vector<int64_t> size(notSkippedTensor->sizes().vec());
+ if (out.suggest_memory_format() == MemoryFormat::ChannelsLast) {
+ out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
+ }
+ std::vector<int64_t> size(notSkippedTensor.sizes().vec());
// Compute size of the result in the cat dimension
int64_t cat_dim_size = 0;
idx = 0;
- for(const Tensor& tensor : materialized_inputs) {
- if (should_skip(tensor)) {
- continue;
+ for (const Tensor& tensor : materialized_inputs) {
+ if (!should_skip(tensor)) {
+ // TODO: Factor out `check_shape_except_dim`
+ check_shape_except_dim(notSkippedTensor, tensor, dimension, idx);
+ cat_dim_size += at::native::size(tensor, dimension);
+ idx++;
}
- // TODO: Factor out `check_shape_except_dim`
- check_shape_except_dim(*notSkippedTensor, tensor, dimension, idx);
- cat_dim_size += at::native::size(tensor, dimension);
- idx++;
}
-
// Compute the size of the result
size[dimension] = cat_dim_size;
-
// skip resizing if size of result is same as expected
if (out.sizes() != size) {
out.resize_(size, memory_format);
}
-
if (out.numel() == 0) {
return;
}
- // Get stream
- MPSStream* stream = getCurrentMPSStream();
-
- struct CachedGraph : public MPSCachedGraph
- {
+ struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
- // TODO: Free this when no longer needed globally
- MPSGraphTensor** inputMPSGraphTensors_ = nil;
+ std::vector<MPSGraphTensor*> inputTensors_;
MPSGraphTensor* outputTensor_ = nil;
};
-
MPSGraphCache *cache_ = MPSGraphCache::getInstance();
- // Make string out of skipped tensor indices
- string skipped_indices_string = "";
- for(int idx : skipped_tensor_indices)
- skipped_indices_string += (std::to_string(idx)+",");
- string input_types = "";
- for(const Tensor& tensor : materialized_inputs)
- input_types += (getMPSTypeString(tensor.scalar_type())+",");
-
@autoreleasepool {
- string key = "cat_out_mps:" + getMPSTypeString(result_type(inputs))
- + ":" + to_string(inputs.size())
- + ":" + skipped_indices_string
- + ":" + input_types
- + ":" + to_string(dimension);
- CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
- if(!cachedGraph) {
- MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
+ string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/true) + ":" +
+ (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
+
+ CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
+ if (!cachedGraph) {
+ cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
CachedGraph *newCachedGraph = nil;
@autoreleasepool {
- // Initialize graph
MPSGraph *mpsGraph = make_mps_graph();
newCachedGraph = new CachedGraph(mpsGraph);
- // Create placeholders
auto len_tensor_array = inputs.size() - skipped_tensor_indices.size();
- std::vector<MPSGraphTensor*> inputMPSGraphTensors(len_tensor_array);
- std::vector<MPSGraphTensor*> castInputMPSGraphTensors(len_tensor_array);
+ std::vector<MPSGraphTensor*> castInputTensors(len_tensor_array);
+ newCachedGraph->inputTensors_.resize(len_tensor_array);
- int graph_tensor_idx = 0;
- for(const Tensor* tensor : input_tensors) {
- inputMPSGraphTensors[graph_tensor_idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(tensor->scalar_type()) );
- if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool) {
- castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor:inputMPSGraphTensors[graph_tensor_idx]
- toType:MPSDataTypeFloat32
- name:[NSString stringWithFormat:@"castInput%@", [NSNumber numberWithInt:graph_tensor_idx]]];
+ for (const auto idx : c10::irange(len_tensor_array)) {
+ const Tensor& tensor = input_tensors[idx];
+ auto scalar_type = getMPSScalarType(tensor.scalar_type());
+ if (tensor.scalar_type() == kBool) {
+ scalar_type = MPSDataTypeInt8;
}
- else {
- if(tensor->scalar_type() != result_type(inputs))
- castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor:inputMPSGraphTensors[graph_tensor_idx]
- toType:getMPSDataType(result_type(inputs))
- name:[NSString stringWithFormat:@"castInput%@", [NSNumber numberWithInt:graph_tensor_idx]]];
- else
- castInputMPSGraphTensors[graph_tensor_idx] = inputMPSGraphTensors[graph_tensor_idx];
+ newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, memory_format));
+ if (tensor.scalar_type() != out_dtype) {
+ castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
+ toType:getMPSDataType(out_dtype)
+ name:@"castInput"];
+ } else {
+ castInputTensors[idx] = newCachedGraph->inputTensors_[idx];
}
- graph_tensor_idx++;
}
- auto inputTensorsArray = [NSArray arrayWithObjects:castInputMPSGraphTensors.data()
+ auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data()
count:len_tensor_array];
- // Use concatTensors to concatenate
MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
dimension:dimension // Maybe convert this from int64_t -> int32
name:nil];
-
- newCachedGraph->inputMPSGraphTensors_ = (MPSGraphTensor**)malloc(len_tensor_array * sizeof(MPSGraphTensor*));
-
- for(int i = 0; i < len_tensor_array; i++)
- newCachedGraph->inputMPSGraphTensors_[i] = inputMPSGraphTensors[i];
- if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool)
+ if (getMPSDataType(out_dtype) == MPSDataTypeBool) {
outputTensor = [mpsGraph castTensor:outputTensor
toType:MPSDataTypeBool
name:@"outputTensor"];
- newCachedGraph->outputTensor_ = outputTensor;
+ }
+ newCachedGraph->outputTensor_ = memory_format == MemoryFormat::ChannelsLast ?
+ convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor;
}
return newCachedGraph;
});
- cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}
std::vector<Placeholder> inputPlaceholders;
int i = 0;
int t_idx = 0;
- for(const Tensor& tensor : materialized_inputs) {
- if(std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) {
- Placeholder currentInputPlaceholder = Placeholder(cachedGraph->inputMPSGraphTensors_[t_idx], tensor);
- inputPlaceholders.push_back(currentInputPlaceholder);
+ for (const Tensor& tensor : materialized_inputs) {
+ if (std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) {
+ auto scalar_type = getMPSScalarType(tensor.scalar_type());
+ if (tensor.scalar_type() == kBool) {
+ scalar_type = MPSDataTypeInt8;
+ }
+ inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor,
+ getMPSShape(tensor, memory_format),
+ memory_format != MemoryFormat::ChannelsLast, scalar_type);
t_idx++;
}
i++;
}
- Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
+ auto outputDataType = getMPSScalarType(out.scalar_type());
+ if (!is_macos_13_or_newer() && out.scalar_type() == kBool) {
+ outputDataType = MPSDataTypeInt8;
+ }
+ Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out, nil, false, outputDataType);
NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
- for (int i = 0; i < inputPlaceholders.size(); i++) {
- feeds[(inputPlaceholders[i]).getMPSGraphTensor()] = (inputPlaceholders[i]).getMPSGraphTensorData();
+ for (auto& inputPlaceholder : inputPlaceholders) {
+ feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
}
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
- mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+ runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
}
-
}
} // namespace native
diff --git a/test/test_mps.py b/test/test_mps.py
index bb10422..0525a0c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8873,7 +8873,7 @@
# does not support float64 Tensors.
# A few ops are currently broken on their reference inputs, but not their sample inputs. These should
# get patched up and this workaround removed.
- broken_on_ref_inputs = op.name in ['cat', 'clamp', 'where']
+ broken_on_ref_inputs = op.name in ['clamp', 'where']
inputs = op.reference_inputs(device, dtype) if not broken_on_ref_inputs else op.sample_inputs(device, dtype)
for sample_input in inputs:
self.compare_with_reference(op, op.ref, sample_input)