[MPS] Fix the ChannelsLast memory format in cat_out_mps() (#91786)

- Fixed the memory leak caused by the raw `malloc()` of the cached graph's input tensor array (replaced with `std::vector`)
- Introduced optional shortened data type strings to avoid extra-long cached graph string keys with ops such as cat_out() (see the key-format sketch below)
- Fixed data type issues on macOS Monterey
- Removed the unused `use_scalar_value` argument from `getTensorsStringKey()`
- General cleanup and refactoring
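
As a rough illustration of the shortened keys (an approximation only; the exact separators come from `getTensorsStringKey()` and the `cat_out_mps` key builder), a two-input float32 concatenation along dim 1 now yields a cache key along the lines of:

```
cat_out_mps:1:f32[1,3,4,4]::f32[1,3,4,4]::NCHW
```

instead of spelling out `Float32` once for the result type and again for every input, as the previous key did.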

Fixes #89353
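
A minimal sketch of the scenario the memory-format fix targets, assuming an MPS-capable machine (shapes are arbitrary):

```python
import torch

# Concatenate two channels_last (NHWC) tensors on the MPS device and compare
# against the CPU reference; before this fix the MPS result could come back
# with an incorrect layout/values for channels_last inputs.
a = torch.randn(2, 3, 4, 4, device="mps").to(memory_format=torch.channels_last)
b = torch.randn(2, 5, 4, 4, device="mps").to(memory_format=torch.channels_last)

out = torch.cat([a, b], dim=1)
ref = torch.cat([a.cpu(), b.cpu()], dim=1)
torch.testing.assert_close(out.cpu(), ref)
```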

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91786
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index 5cba91d..e6a6ef4 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -42,12 +42,12 @@
 MPSDataType getMPSDataType(ScalarType scalar_type);
 MPSDataType getMPSScalarType(ScalarType scalar_type);
 MPSScalar   getMPSScalar(const Scalar& scalar, ScalarType type);
-std::string getMPSTypeString(ScalarType scalar_type);
+std::string getMPSTypeString(ScalarType scalar_type, bool short_name = false);
 std::string scalarToMetalTypeString(const c10::ScalarType& scalar_type);
 NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
 NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim);
 std::string getMPSShapeString(MPSShape* shape);
-std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value = false);
+std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype = false);
 std::string getArrayRefString(const IntArrayRef s);
 // use has_storage() on the returned tensor to determine if src actually is a view
 Tensor gatherViewTensor(const at::Tensor& src, at::Tensor& dst);
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index ebcdacd..1fe3f29 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -63,25 +63,26 @@
   }
 }
 
-std::string getMPSTypeString(ScalarType scalar_type) {
+// use short_name to avoid getting extra long cached graph keys with ops such as cat_out(), etc.
+std::string getMPSTypeString(ScalarType scalar_type, bool short_name) {
   switch (scalar_type) {
     case ScalarType::Double:
     case ScalarType::Float:
-      return "Float32";
+      return short_name ? "f32" : "Float32";
     case ScalarType::Half:
-      return "Float16";
+      return short_name ? "f16" : "Float16";
     case ScalarType::Int:
-      return "Int32";
+      return short_name ? "i32" : "Int32";
     case ScalarType::Long:
-      return "Int64";
+      return short_name ? "i64" : "Int64";
     case ScalarType::Short:
-      return "Int16";
+      return short_name ? "i16" : "Int16";
     case ScalarType::Char:
-      return "Int8";
+      return short_name ? "i8" : "Int8";
     case ScalarType::Byte:
-      return "UInt8";
+      return short_name ? "u8" : "UInt8";
     case ScalarType::Bool:
-      return "Bool";
+      return short_name ? "b8" : "Bool";
     default:
       return "Undefined";
   }
@@ -149,16 +150,16 @@
   return ss.str();
 }
 
-std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value) {
+std::string getTensorsStringKey(const TensorList& tensors, bool short_dtype) {
     std::string str;
     // The key format per tensor would look like ":Float32[1,1,1,10]:"
     for (const Tensor& tensor: tensors) {
       str += ":";
       if (tensor.defined()) {
-        str += getMPSTypeString(tensor.scalar_type()) + "[";
+        str += getMPSTypeString(tensor.scalar_type(), short_dtype) + "[";
         // if tensor is a scalar
         if (tensor.dim() == 0) {
-          str += (use_scalar_value ? std::to_string(tensor.item().to<double>()) : "Scalar");
+          str += "Scalar";
         } else {
           const NSString* ns_shape_key = [[getMPSShape(tensor) valueForKey:@"description"] componentsJoinedByString:@","];
           str += std::string(ns_shape_key.UTF8String);
diff --git a/aten/src/ATen/native/mps/operations/PointwiseOps.mm b/aten/src/ATen/native/mps/operations/PointwiseOps.mm
index 8da6b94..c10b656 100644
--- a/aten/src/ATen/native/mps/operations/PointwiseOps.mm
+++ b/aten/src/ATen/native/mps/operations/PointwiseOps.mm
@@ -34,7 +34,7 @@
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
 
   @autoreleasepool {
-    string key = op_name + getTensorsStringKey({self, tensor1, tensor2}, false);
+    string key = op_name + getTensorsStringKey({self, tensor1, tensor2});
 
     CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
 
diff --git a/aten/src/ATen/native/mps/operations/Shape.mm b/aten/src/ATen/native/mps/operations/Shape.mm
index 3190991..37a65c7 100644
--- a/aten/src/ATen/native/mps/operations/Shape.mm
+++ b/aten/src/ATen/native/mps/operations/Shape.mm
@@ -182,32 +182,6 @@
   }
 }
 
-inline c10::MemoryFormat compute_output_memory_format(const TensorList &inputs) {
-  c10::optional<c10::MemoryFormat> format = c10::nullopt;
-  for (auto &t : inputs) {
-    auto f = t.suggest_memory_format();
-    if (!format.has_value()) {
-      format = f;
-      continue;
-    }
-    if (format.value() == f) {
-      continue;
-    }
-    bool contiguous = (format.value() == c10::MemoryFormat::Contiguous || f == c10::MemoryFormat::Contiguous || format.value() != f);
-    if (contiguous) {
-      return c10::MemoryFormat::Contiguous;
-    }
-  }
-  return format.value();
-}
-
-//Tensor cat_mps(TensorList inputs, int64_t dimension) {
-  //ScalarType high_type = result_type(inputs);
-  //Tensor out = at::empty({0}, inputs.front().options().dtype(high_type));
-  //at::native::cat_out_mps(inputs, dimension, out);
-  //return out;
-//}
-
 TORCH_IMPL_FUNC(cat_out_mps)
       (const ITensorListRef& inputs,
        int64_t dimension,
@@ -217,21 +191,30 @@
        bool all_same_sizes_and_stride,
        MemoryFormat memory_format,
        const Tensor& out) {
+
   using namespace mps;
   if (out.numel() == 0) {
     return;
   }
-
   auto materialized_inputs = inputs.materialize();
+  auto out_dtype = at::native::result_type(inputs);
 
   int idx = 0;
-  for(const Tensor& t : materialized_inputs) {
-    TORCH_CHECK(t.dim() > 0,
-             "zero-dimensional tensor (at position ", idx, ") cannot be concatenated");
+  for (const Tensor& t : materialized_inputs) {
+    TORCH_CHECK(t.dim() > 0, "zero-dimensional tensor (at position ", idx, ") cannot be concatenated");
+    auto lap = at::get_overlap_status(out, t);
+    TORCH_CHECK(lap != at::MemOverlapStatus::Partial && lap != at::MemOverlapStatus::Full,
+        "torch.cat(): unsupported operation: the input tensors cannot refer to any "
+        "of the output memory locations. Found overlap in input tensor ", idx);
     idx++;
   }
+  // Check for type promotion
+  TORCH_CHECK(canCast(out_dtype, out.scalar_type()),
+              "torch.cat(): input types can't be cast to the desired output type ", out.scalar_type());
+  TORCH_CHECK(inputs.size() > 0,"torch.cat(): invalid number of inputs ", inputs.size());
 
   dimension = legacy_cat_wrap_dim(dimension, materialized_inputs);
+  TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);
 
   // previously, size [0] tensors were the only possible empty tensors; thus, it
   // wasn't possible to cat empty tensors unless all the other tensors were
@@ -242,214 +225,154 @@
   auto should_skip = [](const Tensor& t) {
     return t.dim() == 1 && at::native::size(t, 0) == 0;
   };
-
-  const Tensor* notSkippedTensor = NULL; // non-owning reference
-
-  // Check for type promotion
-  TORCH_CHECK(
-      canCast(result_type(inputs), out.scalar_type()),
-      "torch.cat(): input types ",
-      " can't be cast to the desired output type ",
-      out.scalar_type());
-
-  // Inputs cannot alias the output tensor
-  idx = 0;
-  for(const Tensor& t : materialized_inputs) {
-    auto lap = at::get_overlap_status(out, t);
-    TORCH_CHECK(
-        lap != at::MemOverlapStatus::Partial &&
-            lap != at::MemOverlapStatus::Full,
-        "torch.cat(): unsupported operation: the input tensors cannot refer to any "
-        "of the output memory locations. Found overlap in input "
-        "tensor ",
-        idx);
-    idx++;
-  }
   at::assert_no_internal_overlap(out);
 
+  Tensor notSkippedTensor;
   // Indices of tensors to be skipped because they're empty
   std::vector<int64_t> skipped_tensor_indices;
   // Tensors to be read
-  std::vector<const Tensor*> input_tensors;
+  std::vector<Tensor> input_tensors;
   int tensor_idx = 0;
-  for(const Tensor& t : materialized_inputs) {
-    if(t.numel() == 0 || should_skip(t)) {
+  for (const Tensor& t : materialized_inputs) {
+    if (t.numel() == 0 || should_skip(t)) {
       skipped_tensor_indices.push_back(tensor_idx);
       tensor_idx++;
       continue;
     }
-    input_tensors.push_back(&t);
+    input_tensors.push_back(t);
     // TODO: Is this OK?
-    notSkippedTensor = &t;
+    notSkippedTensor = t;
     tensor_idx++;
   }
-
   // If all inputs are empty tensors, return an empty tensor
-  if (notSkippedTensor == NULL) {
+  if (!notSkippedTensor.defined()) {
     return;
   }
-
-  TORCH_CHECK(
-      inputs.size() > 0,
-      "torch.cat(): invalid number of inputs ",
-      inputs.size());
-  TORCH_CHECK(dimension >= 0, "torch.cat(): invalid dimension ", dimension);
-
   for (const Tensor& t : inputs) {
-    TORCH_CHECK(
-        t.device() == notSkippedTensor->device(),
-        "torch.cat(): all input tensors must be on the same device. Received ",
-        t.device(),
-        " and ",
-        notSkippedTensor->device());
+    TORCH_CHECK(t.device() == notSkippedTensor.device(),
+                "torch.cat(): all input tensors must be on the same device. Received ",
+                t.device(), " and ", notSkippedTensor.device());
   }
+  TORCH_CHECK(out.device() == notSkippedTensor.device(),
+              "torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
+              notSkippedTensor.device(), " and out is on ", out.device());
 
-  TORCH_CHECK(
-      out.device() == notSkippedTensor->device(),
-      "torch.cat(): all input tensors and out must be on the same device, but inputs are on ",
-      notSkippedTensor->device(),
-      " and out is on ",
-      out.device());
-
-  // TODO: memory_format is now an argument?
-  // // TODO: Factor out `compute_output_memory_format`
-  // c10::MemoryFormat memory_format = compute_output_memory_format(inputs);
-
-  std::vector<int64_t> size(notSkippedTensor->sizes().vec());
+  if (out.suggest_memory_format() == MemoryFormat::ChannelsLast) {
+    out.unsafeGetTensorImpl()->empty_tensor_restride(MemoryFormat::Contiguous);
+  }
+  std::vector<int64_t> size(notSkippedTensor.sizes().vec());
 
   // Compute size of the result in the cat dimension
   int64_t cat_dim_size = 0;
   idx = 0;
-  for(const Tensor& tensor : materialized_inputs) {
-    if (should_skip(tensor)) {
-      continue;
+  for (const Tensor& tensor : materialized_inputs) {
+    if (!should_skip(tensor)) {
+      // TODO: Factor out `check_shape_except_dim`
+      check_shape_except_dim(notSkippedTensor, tensor, dimension, idx);
+      cat_dim_size += at::native::size(tensor, dimension);
+      idx++;
     }
-    // TODO: Factor out `check_shape_except_dim`
-    check_shape_except_dim(*notSkippedTensor, tensor, dimension, idx);
-    cat_dim_size += at::native::size(tensor, dimension);
-    idx++;
   }
-
   // Compute the size of the result
   size[dimension] = cat_dim_size;
-
   // skip resizing if size of result is same as expected
   if (out.sizes() != size) {
     out.resize_(size, memory_format);
   }
-
   if (out.numel() == 0) {
     return;
   }
 
-  // Get stream
-  MPSStream* stream = getCurrentMPSStream();
-
-  struct CachedGraph : public MPSCachedGraph
-  {
+  struct CachedGraph : public MPSCachedGraph {
     CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
-    // TODO: Free this when no longer needed globally
-    MPSGraphTensor** inputMPSGraphTensors_ = nil;
+    std::vector<MPSGraphTensor*> inputTensors_;
     MPSGraphTensor* outputTensor_ = nil;
   };
-
   MPSGraphCache *cache_ = MPSGraphCache::getInstance();
 
-  // Make string out of skipped tensor indices
-  string skipped_indices_string = "";
-  for(int idx : skipped_tensor_indices)
-    skipped_indices_string += (std::to_string(idx)+",");
-  string input_types = "";
-  for(const Tensor& tensor : materialized_inputs)
-    input_types += (getMPSTypeString(tensor.scalar_type())+",");
-
   @autoreleasepool {
-    string key = "cat_out_mps:" + getMPSTypeString(result_type(inputs))
-                                + ":" + to_string(inputs.size())
-                                + ":" + skipped_indices_string
-                                + ":" + input_types
-                                + ":" + to_string(dimension);
-    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
-    if(!cachedGraph) {
-      MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ MPSCachedGraph * () {
+    string key = "cat_out_mps:" + to_string(dimension) + getTensorsStringKey(input_tensors, /*short_dtype*/true) + ":" +
+                 (memory_format == MemoryFormat::ChannelsLast ? "NHWC" : "NCHW");
+
+    CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
+    if (!cachedGraph) {
+      cachedGraph = cache_->CreateCachedGraphAs<CachedGraph>(key, ^ MPSCachedGraph * () {
         CachedGraph *newCachedGraph = nil;
 
         @autoreleasepool {
-          // Initialize graph
           MPSGraph *mpsGraph = make_mps_graph();
           newCachedGraph = new CachedGraph(mpsGraph);
 
-          // Create placeholders
           auto len_tensor_array = inputs.size() - skipped_tensor_indices.size();
-          std::vector<MPSGraphTensor*> inputMPSGraphTensors(len_tensor_array);
-          std::vector<MPSGraphTensor*> castInputMPSGraphTensors(len_tensor_array);
+          std::vector<MPSGraphTensor*> castInputTensors(len_tensor_array);
+          newCachedGraph->inputTensors_.reserve(len_tensor_array);
 
-          int graph_tensor_idx = 0;
-          for(const Tensor* tensor : input_tensors) {
-            inputMPSGraphTensors[graph_tensor_idx] = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(tensor->scalar_type()) );
-            if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool) {
-              castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor:inputMPSGraphTensors[graph_tensor_idx]
-                                                                           toType:MPSDataTypeFloat32
-                                                                             name:[NSString stringWithFormat:@"castInput%@", [NSNumber numberWithInt:graph_tensor_idx]]];
+          for (const auto idx : c10::irange(len_tensor_array)) {
+            const Tensor& tensor = input_tensors[idx];
+            auto scalar_type = getMPSScalarType(tensor.scalar_type());
+            if (tensor.scalar_type() == kBool) {
+              scalar_type = MPSDataTypeInt8;
             }
-            else {
-              if(tensor->scalar_type() != result_type(inputs))
-                castInputMPSGraphTensors[graph_tensor_idx] = [mpsGraph castTensor:inputMPSGraphTensors[graph_tensor_idx]
-                                                                           toType:getMPSDataType(result_type(inputs))
-                                                                             name:[NSString stringWithFormat:@"castInput%@", [NSNumber numberWithInt:graph_tensor_idx]]];
-              else
-                castInputMPSGraphTensors[graph_tensor_idx] = inputMPSGraphTensors[graph_tensor_idx];
+            newCachedGraph->inputTensors_[idx] = mpsGraphRankedPlaceHolder(mpsGraph, scalar_type, getMPSShape(tensor, memory_format));
+            if (tensor.scalar_type() != out_dtype) {
+              castInputTensors[idx] = [mpsGraph castTensor:newCachedGraph->inputTensors_[idx]
+                                                    toType:getMPSDataType(out_dtype)
+                                                      name:@"castInput"];
+            } else {
+              castInputTensors[idx] = newCachedGraph->inputTensors_[idx];
             }
-            graph_tensor_idx++;
           }
 
-          auto inputTensorsArray = [NSArray arrayWithObjects:castInputMPSGraphTensors.data()
+          auto inputTensorsArray = [NSArray arrayWithObjects:castInputTensors.data()
                                                        count:len_tensor_array];
-          // Use concatTensors to concatenate
           MPSGraphTensor* outputTensor = [mpsGraph concatTensors:inputTensorsArray
                                                        dimension:dimension // Maybe convert this from int64_t -> int32
                                                             name:nil];
-
-          newCachedGraph->inputMPSGraphTensors_ = (MPSGraphTensor**)malloc(len_tensor_array * sizeof(MPSGraphTensor*));
-
-          for(int i = 0; i < len_tensor_array; i++)
-            newCachedGraph->inputMPSGraphTensors_[i] = inputMPSGraphTensors[i];
-          if(getMPSDataType(result_type(inputs)) == MPSDataTypeBool)
+          if (getMPSDataType(out_dtype) == MPSDataTypeBool) {
             outputTensor = [mpsGraph castTensor:outputTensor
                                          toType:MPSDataTypeBool
                                            name:@"outputTensor"];
-          newCachedGraph->outputTensor_ = outputTensor;
+          }
+          newCachedGraph->outputTensor_ = memory_format == MemoryFormat::ChannelsLast ?
+                                         convertNHWCtoNCHW(mpsGraph, outputTensor) : outputTensor;
         }
         return newCachedGraph;
       });
-      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
 
     std::vector<Placeholder> inputPlaceholders;
     int i = 0;
     int t_idx = 0;
-    for(const Tensor& tensor : materialized_inputs) {
-      if(std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) {
-        Placeholder currentInputPlaceholder = Placeholder(cachedGraph->inputMPSGraphTensors_[t_idx], tensor);
-        inputPlaceholders.push_back(currentInputPlaceholder);
+    for (const Tensor& tensor : materialized_inputs) {
+      if (std::find(skipped_tensor_indices.begin(), skipped_tensor_indices.end(), i) == skipped_tensor_indices.end()) {
+        auto scalar_type = getMPSScalarType(tensor.scalar_type());
+        if (tensor.scalar_type() == kBool) {
+          scalar_type = MPSDataTypeInt8;
+        }
+        inputPlaceholders.emplace_back(cachedGraph->inputTensors_[t_idx], tensor,
+                                       getMPSShape(tensor, memory_format),
+                                       memory_format != MemoryFormat::ChannelsLast, scalar_type);
         t_idx++;
       }
       i++;
     }
 
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out);
+    auto outputDataType = getMPSScalarType(out.scalar_type());
+    if (!is_macos_13_or_newer() && out.scalar_type() == kBool) {
+      outputDataType = MPSDataTypeInt8;
+    }
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out, nil, false, outputDataType);
 
     NSMutableDictionary *feeds = [[NSMutableDictionary new] autorelease];
-    for (int i = 0; i < inputPlaceholders.size(); i++) {
-      feeds[(inputPlaceholders[i]).getMPSGraphTensor()] = (inputPlaceholders[i]).getMPSGraphTensorData();
+    for (auto& inputPlaceholder : inputPlaceholders) {
+      feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
     }
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
       outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
     };
 
-    mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+    runMPSGraph(getCurrentMPSStream(), cachedGraph->graph(), feeds, results);
   }
-
 }
 
 } // namespace native
diff --git a/test/test_mps.py b/test/test_mps.py
index bb10422..0525a0c 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -8873,7 +8873,7 @@
         # does not support float64 Tensors.
         # A few ops are currently broken on their reference inputs, but not their sample inputs. These should
         # get patched up and this workaround removed.
-        broken_on_ref_inputs = op.name in ['cat', 'clamp', 'where']
+        broken_on_ref_inputs = op.name in ['clamp', 'where']
         inputs = op.reference_inputs(device, dtype) if not broken_on_ref_inputs else op.sample_inputs(device, dtype)
         for sample_input in inputs:
             self.compare_with_reference(op, op.ref, sample_input)