[MPS] Fix channels last copies in ELU, ReLU and Hardswish (#94664)

Fixes the following test_modules.py tests:
```
test_memory_format_nn_Hardswish_mps_float32
test_non_contiguous_tensors_nn_Hardswish_mps_float32
test_memory_format_nn_ReLU_mps_float32
```
Also fixes ELU when run with the `ChannelsLast` memory format.
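
A minimal repro sketch of the fixed behavior, assuming an MPS-capable PyTorch build; the shape and alpha below are illustrative (mirroring the updated `test_elu` case) and not part of this change:
```
import torch

# Build a channels_last input on CPU and mirror it on MPS.
x_cpu = torch.randn(2, 8, 4, 5).to(memory_format=torch.channels_last)
x_mps = x_cpu.detach().clone().to('mps')

# Forward results on channels_last inputs should match the CPU reference.
for act in (torch.nn.ReLU(), torch.nn.Hardswish(), torch.nn.ELU(alpha=0.34)):
    torch.testing.assert_close(act(x_mps).cpu(), act(x_cpu))

# ELU backward with channels_last gradients should match as well.
x_cpu.requires_grad_()
x_mps.requires_grad_()
torch.nn.functional.elu(x_cpu, alpha=0.34).sum().backward()
torch.nn.functional.elu(x_mps, alpha=0.34).sum().backward()
torch.testing.assert_close(x_mps.grad.cpu(), x_cpu.grad)
```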
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94664
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/Activation.mm b/aten/src/ATen/native/mps/operations/Activation.mm
index a5dae09..9e643eb 100644
--- a/aten/src/ATen/native/mps/operations/Activation.mm
+++ b/aten/src/ATen/native/mps/operations/Activation.mm
@@ -18,14 +18,15 @@
 Tensor relu_mps(const Tensor& self) {
   using namespace mps;
   using CachedGraph = MPSUnaryCachedGraph;
-  Tensor output = at::empty_like(self);
-  resize_tensor(&output);
-  TORCH_CHECK(output.is_mps());
-
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
 
   MPSStream* stream = getCurrentMPSStream();
 
+  bool executeGatherOp = !(self.is_contiguous(MemoryFormat::Contiguous) ||
+                           self.is_contiguous(MemoryFormat::ChannelsLast) ||
+                           self.is_contiguous(MemoryFormat::ChannelsLast3d));
+  Tensor output = at::empty_like(self, executeGatherOp ? MemoryFormat::Contiguous : MemoryFormat::Preserve);
+
   @autoreleasepool {
     string key = "relu" + getTensorsStringKey({self});
     CachedGraph* cachedGraph = cache_->LookUpAs<CachedGraph>(key);
@@ -51,8 +52,8 @@
       cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
 
-    Placeholder selfPlaceholder   = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
+    Placeholder selfPlaceholder   = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output, nil, false);
 
     // Create dictionary of inputs and outputs
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -75,7 +76,13 @@
   using CachedGraph = MPSUnaryCachedGraph;
   // Inplace relu
   Tensor &output = self;
-  TORCH_CHECK(output.is_mps());
+  bool executeGatherOp = !(self.is_contiguous(MemoryFormat::Contiguous) ||
+                           self.is_contiguous(MemoryFormat::ChannelsLast) ||
+                           self.is_contiguous(MemoryFormat::ChannelsLast3d));
+  Tensor out;
+  if (executeGatherOp) {
+    out = at::empty_like(self, MemoryFormat::Contiguous);
+  }
 
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
 
@@ -106,8 +113,8 @@
       cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
 
-    Placeholder selfPlaceholder   = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output);
+    Placeholder selfPlaceholder   = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, executeGatherOp ? out : output, nil, false);
 
     // Create dictionary of inputs and outputs
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -119,7 +126,9 @@
     };
 
     runMPSGraph(stream, cachedGraph->graph(), feeds, results);
-
+    if (executeGatherOp) {
+      output.copy_(out);
+    }
   }
 
   return output;
@@ -1052,11 +1061,17 @@
   string func_name) {
 
   using namespace mps;
-  TORCH_CHECK(self.is_mps());
+  auto resultMemFormat = result.suggest_memory_format();
+  bool executeGatherOp = !(self.is_contiguous(resultMemFormat) && result.is_contiguous(resultMemFormat));
+  Tensor out;
+  if (executeGatherOp && resultMemFormat == MemoryFormat::ChannelsLast) {
+    out = at::empty_like(result, MemoryFormat::Contiguous);
+  }
 
   // Empty output
-  if(result.numel() == 0)
+  if(result.numel() == 0) {
     return;
+  }
 
   struct CachedGraph : public MPSCachedGraph
   {
@@ -1137,8 +1152,8 @@
       cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
 
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, result);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor_, out.has_storage() ? out : result, nil, false);
 
     // Create dictionary of inputs and outputs
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -1150,8 +1165,10 @@
     };
 
     runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+    if (out.has_storage()) {
+      result.copy_(out);
+    }
   }
-
 }
 
 // scale * (max(0, x) + min(0, alpha * (exp(input_scale * x) - 1) ))
@@ -1174,13 +1191,18 @@
   const Tensor& self_or_result,
   const Tensor& grad_input
 ) {
-
   using namespace mps;
-  TORCH_CHECK(grad_output.is_mps());
+  auto gradMemFormat = grad_input.suggest_memory_format();
+  bool executeGatherOp = !(grad_output.is_contiguous(gradMemFormat) && self_or_result.is_contiguous(gradMemFormat) && grad_input.is_contiguous(gradMemFormat));
+  Tensor out;
+  if (executeGatherOp && gradMemFormat == MemoryFormat::ChannelsLast) {
+    out = at::empty_like(grad_input, MemoryFormat::Contiguous);
+  }
 
   // Empty output
-  if(grad_input.numel() == 0)
+  if(grad_input.numel() == 0) {
     return;
+  }
 
   struct CachedGraph : public MPSCachedGraph
   {
@@ -1281,14 +1303,14 @@
       cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
 
-    Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output);
+    Placeholder gradOutputPlaceholder = Placeholder(cachedGraph->gradOutputTensor_, grad_output, nil, executeGatherOp);
     Placeholder selfPlaceholder = Placeholder();
     Placeholder resultPlaceholder = Placeholder();
     if(is_result)
-      resultPlaceholder = Placeholder(cachedGraph->resultTensor_, self_or_result);
+      resultPlaceholder = Placeholder(cachedGraph->resultTensor_, self_or_result, nil, executeGatherOp);
     else
-      selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self_or_result);
-    Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, grad_input);
+      selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self_or_result, nil, executeGatherOp);
+    Placeholder gradInputPlaceholder = Placeholder(cachedGraph->gradInputTensor_, out.has_storage() ? out : grad_input, nil, false);
 
     // Create dictionary of inputs and outputs
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = nil;
@@ -1309,8 +1331,10 @@
     };
 
     runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+    if (out.has_storage()) {
+      grad_input.copy_(out);
+    }
   }
-
 }
 
 TORCH_IMPL_FUNC(glu_out_mps) (
@@ -1390,7 +1414,6 @@
     runMPSGraph(stream, cachedGraph->graph(), feeds, results);
 
   }
-
 }
 
 Tensor& glu_backward_mps_out (
@@ -2210,12 +2233,17 @@
   using namespace mps;
   using CachedGraph = MPSUnaryCachedGraph;
 
-  TORCH_CHECK(self.is_mps());
-
   if (output.numel() == 0) {
     return output;
   }
 
+  auto resultMemFormat = output.suggest_memory_format();
+  bool executeGatherOp = !(self.is_contiguous(resultMemFormat) && output.is_contiguous(resultMemFormat));
+  Tensor out;
+  if (executeGatherOp && !output.is_contiguous(MemoryFormat::Contiguous)) {
+    out = at::empty_like(output, MemoryFormat::Contiguous);
+  }
+
   MPSGraphCache* cache_ = MPSGraphCache::getInstance();
 
   MPSStream* stream = at::mps::getCurrentMPSStream();
@@ -2296,9 +2324,9 @@
           });
       cachedGraph = static_cast<CachedGraph*>(tmpCachedGraph);
     }
-    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->inputTensor_, self, nil, executeGatherOp);
     Placeholder outputPlaceholder =
-        Placeholder(cachedGraph->outputTensor_, output);
+        Placeholder(cachedGraph->outputTensor_, out.has_storage() ? out : output, nil, false);
 
     // Create dictionary of inputs and outputs
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
@@ -2312,6 +2340,9 @@
     };
 
     runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+    if (out.has_storage()) {
+      output.copy_(out);
+    }
   }
   return output;
 }
diff --git a/aten/src/ATen/native/mps/operations/Copy.mm b/aten/src/ATen/native/mps/operations/Copy.mm
index 1e47b57..eade505 100644
--- a/aten/src/ATen/native/mps/operations/Copy.mm
+++ b/aten/src/ATen/native/mps/operations/Copy.mm
@@ -103,18 +103,20 @@
 
 static at::Tensor& copy_from_mps_(at::Tensor& dst_, const at::Tensor& src_, bool non_blocking)
 {
+  auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
+
   id<MTLDevice> device = MPSDevice::getInstance()->device();
   MPSStream* stream = getCurrentMPSStream();
   Tensor dst;
   Tensor src;
-  if (!dst_.is_contiguous()) {
+  if (!dst_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) {
     dst = at::empty_like(dst_, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
   } else {
     dst = dst_;
   }
 
   auto storage_byte_offset = src_.storage_offset() * src_.itemsize();
-  if (!src_.is_contiguous()) {
+  if (!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) {
     Tensor emptyShell = Tensor();
     src = gatherViewTensor(src_, emptyShell);
     if (src.has_storage()) {
@@ -250,8 +252,9 @@
   // gather into dst. This reduces the overhead of doing an additional blit for most cases
   bool returnGatherOutput = (dst_.is_contiguous() && !dst_byte_offset && src_.dtype() == dst_.dtype());
   Tensor src;
+  auto sameMemFormat = src_.is_contiguous(dst_.suggest_memory_format()) && dst_.is_contiguous(dst_.suggest_memory_format());
 
-  if (src_.is_view() || !src_.is_contiguous()) {
+  if (!src_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) {
     Tensor emptyShell = Tensor();
     src = gatherViewTensor(src_, returnGatherOutput ? dst_ : emptyShell);
 
@@ -273,7 +276,7 @@
   // Scatter to `dst` if the memory is not contiguous
   // If the memory is not contiguous, it means that the tensor has strides and we would not be
   // able to do the copy using a single blit
-  if (!dst_.is_contiguous()) {
+  if (!dst_.is_contiguous(MemoryFormat::Contiguous) && !sameMemFormat) {
     return scatterViewTensor(src, dst_);
   }
   src._set_conj(src_.is_conj());
diff --git a/test/test_mps.py b/test/test_mps.py
index bd3f5c1..8b282a9 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4739,10 +4739,11 @@
 
     # Test selu, elu, celu
     def test_elu(self):
-        def helper(shape, alpha=1.0):
-            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
-            x = cpu_x.detach().clone().to('mps').requires_grad_()
+        def helper(shape, alpha=1.0, memory_format=torch.contiguous_format):
+            cpu_x = torch.randn(shape, device='cpu', dtype=torch.float)
+            cpu_x = cpu_x.to(memory_format=memory_format).requires_grad_()
 
+            x = cpu_x.detach().clone().to('mps').requires_grad_(True)
             for activation_func in [torch.nn.ELU(alpha=alpha), torch.nn.CELU(alpha=alpha), torch.nn.SELU()]:
                 elu_result = activation_func(x)
                 elu_result_cpu = activation_func(cpu_x)
@@ -4757,9 +4758,10 @@
                 self.assertEqual(x.grad, cpu_x.grad)
 
         # Test empty shape too
-        for shape in [[], (2, 3), (2, 8, 4, 5)]:
-            for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
-                helper(shape, alpha)
+        for memory_format in [torch.channels_last, torch.contiguous_format]:
+            for shape in [(2, 8, 4, 5)]:
+                for alpha in [0.000001, 1.0, 2.3, 0.34, 23]:
+                    helper(shape, alpha, memory_format)
 
     # Test glu
     def test_glu(self):