[MPS] Fix upsample for NHWC output (#94963)

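Fixes https://github.com/huggingface/diffusers/issues/941

When the output tensor is not contiguous (e.g. channels-last / NHWC), the result is now computed into a temporary contiguous tensor and then copied back into `output`. `test_upsample_nearest2d` now runs with both `torch.channels_last` and `torch.contiguous_format` inputs.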

**Before**:
<img width="1144" alt="Screenshot 2023-02-15 at 8 11 53 PM" src="https://user-images.githubusercontent.com/104024078/219266709-6a77636a-2fc0-4802-b130-85069b95953f.png">

**After**:
<img width="1144" alt="Screenshot 2023-02-15 at 8 12 02 PM" src="https://user-images.githubusercontent.com/104024078/219266694-ea743c02-fb55-44f1-b7d6-5946106527c3.png">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94963
Approved by: https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm
index 17895e1..3b781de 100644
--- a/aten/src/ATen/native/mps/operations/UpSample.mm
+++ b/aten/src/ATen/native/mps/operations/UpSample.mm
@@ -26,6 +26,11 @@
   } else {
     native::upsample_2d_common_check(input.sizes(), output_size);
   }
+  Tensor out;
+  if (!output.is_contiguous()) {
+    out = at::empty_like(output, MemoryFormat::Contiguous);
+  }
+
   bool centerResults = false;
   MPSGraphResizeMode resizeMode = MPSGraphResizeNearest;
   MPSGraphResizeNearestRoundingMode nearestRoundingMode = MPSGraphResizeNearestRoundingModeFloor;
@@ -199,7 +204,7 @@
     MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: sizeNDArray] autorelease];
 
     Placeholder inputPlaceholder  = Placeholder(cachedGraph->inputTensor, input);
-    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
+    Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false);
 
     NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
         inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
@@ -209,6 +214,10 @@
         outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
     };
     runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+
+    if (out.has_storage()) {
+      output.copy_(out);
+    }
   }
 }
 
diff --git a/test/test_mps.py b/test/test_mps.py
index 0f4b2ea..05b42c7 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4545,9 +4545,9 @@
             )
 
     def test_upsample_nearest2d(self):
-        def helper(N, C, H, W):
+        def helper(N, C, H, W, memory_format):
             inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
-                                    requires_grad=True).reshape(N, C, H, W)
+                                    requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format)
             inputCPU.retain_grad()
             inputMPS = inputCPU.detach().to('mps').requires_grad_()
 
@@ -4573,8 +4573,9 @@
 
                     self.assertEqual(inputCPU.grad, inputMPS.grad)
 
-        helper(1, 1, 4, 4)
-        helper(7, 5, 3, 2)
+        for memory_format in [torch.channels_last, torch.contiguous_format]:
+            helper(1, 1, 4, 4, memory_format=memory_format)
+            helper(7, 5, 3, 2, memory_format=memory_format)
 
     def test_upsample_bilinear2d(self):
         def helper(N, C, H, W):