[MPS] Fix batch norm for NHWC (#94760)

Fixes `test_modules.py` batch norm NHWC testcases:
- `test_memory_format_nn_BatchNorm2d_eval_mode_mps_float32`
- `test_memory_format_nn_BatchNorm2d_eval_mode_mps_float32`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94760
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm
index 1b4258e..34dd5f7 100644
--- a/aten/src/ATen/native/mps/operations/Normalization.mm
+++ b/aten/src/ATen/native/mps/operations/Normalization.mm
@@ -95,7 +95,7 @@
   const bool has_weight = (weight_opt.has_value() && weight_opt->defined());
   const bool has_bias = (bias_opt.has_value() && bias_opt->defined());
 
-  const auto memory_format = self.suggest_memory_format();
+  auto memory_format = self.suggest_memory_format();
 
   if (output.numel() == 0) {
     return std::tuple<Tensor&, Tensor&, Tensor&>(output, save_mean, save_var);;
@@ -147,6 +147,12 @@
     else
       channelsDim = num_input_dims - 1;
 
+    bool executeGatherOp = true;
+    if (self.is_contiguous(memory_format)) {
+      memory_format = MemoryFormat::Contiguous;
+      executeGatherOp = false;
+    }
+
     if(!cachedGraph) {
       native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
 
@@ -318,7 +324,7 @@
       cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
     }
 
-    auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, self, input_shape);
+    auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, self, input_shape, executeGatherOp);
     auto weightPlaceholder = native_mps::Placeholder();
     if(has_weight)
       weightPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_opt.value(), new_mean_shape);
@@ -340,7 +346,7 @@
       runningVarInplaceUpdatePlaceholder = native_mps::Placeholder(cachedGraph->runningVarInplaceUpdate_, running_var_opt.value());
     }
 
-    auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output, input_shape);
+    auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output, input_shape, false);
     auto saveMeanPlaceholder = native_mps::Placeholder(cachedGraph->saveMeanTensor_, save_mean);
     auto saveVarPlaceholder = native_mps::Placeholder(cachedGraph->saveVarTensor_, save_var);
 
diff --git a/test/test_mps.py b/test/test_mps.py
index d23027e..9c03c76 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5762,22 +5762,22 @@
             self.assertEqual(scatter_result, scatter_result_cpu)
 
         # for reduce in ["sum", "prod", "amax", "amin"]:
-        for reduce in ["add", "multiply"]:
-            helper((2, 3), 0, (5, 3), (5, 3), reduce_str=reduce)
-            helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce)
-            helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce)
-            helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2), reduce_str=reduce)
-            helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2), reduce_str=reduce)
-            helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5), reduce_str=reduce)
+        for reduce_type in ["add", "multiply"]:
+            helper((2, 3), 0, (5, 3), (5, 3), reduce_str=reduce_type)
+            helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type)
+            helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type)
+            helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type)
+            helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type)
+            helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5), reduce_str=reduce_type)
 
-            helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5), reduce_str=reduce)
-            helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2), reduce_str=reduce)
-            helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3), reduce_str=reduce)
-            helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3), reduce_str=reduce)
+            helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5), reduce_str=reduce_type)
+            helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2), reduce_str=reduce_type)
+            helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3), reduce_str=reduce_type)
+            helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3), reduce_str=reduce_type)
 
-            helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8), reduce_str=reduce)
-            helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6), reduce_str=reduce)
-            helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6), reduce_str=reduce)
+            helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8), reduce_str=reduce_type)
+            helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6), reduce_str=reduce_type)
+            helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6), reduce_str=reduce_type)
 
     def test_is_nonzero(self):
         self.assertFalse(torch.is_nonzero(torch.tensor([0.]).to('mps')))