[MPS] Fix batch norm for NHWC (#94760)
Fixes `test_modules.py` batch norm NHWC test cases:
- `test_memory_format_nn_BatchNorm2d_eval_mode_mps_float32`
- `test_memory_format_nn_BatchNorm2d_train_mode_mps_float32`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94760
Approved by: https://github.com/kulinseth
diff --git a/aten/src/ATen/native/mps/operations/Normalization.mm b/aten/src/ATen/native/mps/operations/Normalization.mm
index 1b4258e..34dd5f7 100644
--- a/aten/src/ATen/native/mps/operations/Normalization.mm
+++ b/aten/src/ATen/native/mps/operations/Normalization.mm
@@ -95,7 +95,7 @@
const bool has_weight = (weight_opt.has_value() && weight_opt->defined());
const bool has_bias = (bias_opt.has_value() && bias_opt->defined());
- const auto memory_format = self.suggest_memory_format();
+ auto memory_format = self.suggest_memory_format();
if (output.numel() == 0) {
return std::tuple<Tensor&, Tensor&, Tensor&>(output, save_mean, save_var);;
@@ -147,6 +147,12 @@
else
channelsDim = num_input_dims - 1;
+ bool executeGatherOp = true;
+ if (self.is_contiguous(memory_format)) {
+ memory_format = MemoryFormat::Contiguous;
+ executeGatherOp = false;
+ }
+
if(!cachedGraph) {
native_mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ native_mps::MPSCachedGraph * () {
@@ -318,7 +324,7 @@
cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
}
- auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, self, input_shape);
+ auto inputPlaceholder = native_mps::Placeholder(cachedGraph->inputTensor_, self, input_shape, executeGatherOp);
auto weightPlaceholder = native_mps::Placeholder();
if(has_weight)
weightPlaceholder = native_mps::Placeholder(cachedGraph->weightTensor_, weight_opt.value(), new_mean_shape);
@@ -340,7 +346,7 @@
runningVarInplaceUpdatePlaceholder = native_mps::Placeholder(cachedGraph->runningVarInplaceUpdate_, running_var_opt.value());
}
- auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output, input_shape);
+ auto outputPlaceholder = native_mps::Placeholder(cachedGraph->outputTensor_, output, input_shape, false);
auto saveMeanPlaceholder = native_mps::Placeholder(cachedGraph->saveMeanTensor_, save_mean);
auto saveVarPlaceholder = native_mps::Placeholder(cachedGraph->saveVarTensor_, save_var);
diff --git a/test/test_mps.py b/test/test_mps.py
index d23027e..9c03c76 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -5762,22 +5762,22 @@
self.assertEqual(scatter_result, scatter_result_cpu)
# for reduce in ["sum", "prod", "amax", "amin"]:
- for reduce in ["add", "multiply"]:
- helper((2, 3), 0, (5, 3), (5, 3), reduce_str=reduce)
- helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce)
- helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce)
- helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2), reduce_str=reduce)
- helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2), reduce_str=reduce)
- helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5), reduce_str=reduce)
+ for reduce_type in ["add", "multiply"]:
+ helper((2, 3), 0, (5, 3), (5, 3), reduce_str=reduce_type)
+ helper((2, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type)
+ helper((8, 8, 4, 5), 0, (10, 8, 4, 5), (10, 8, 4, 5), reduce_str=reduce_type)
+ helper((8, 8, 4, 5), 0, (4, 7, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type)
+ helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (4, 7, 3, 2), reduce_str=reduce_type)
+ helper((8, 8, 4, 5), 0, (4, 6, 3, 2), (8, 8, 4, 5), reduce_str=reduce_type)
- helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5), reduce_str=reduce)
- helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2), reduce_str=reduce)
- helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3), reduce_str=reduce)
- helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3), reduce_str=reduce)
+ helper((2, 8, 4, 5), 1, (2, 20, 4, 5), (2, 20, 4, 5), reduce_str=reduce_type)
+ helper((2, 8, 4, 5), 1, (2, 13, 3, 2), (2, 13, 3, 2), reduce_str=reduce_type)
+ helper((8, 8, 4, 5), 1, (6, 5, 2, 3), (6, 5, 2, 3), reduce_str=reduce_type)
+ helper((8, 8, 4, 5), 1, (3, 4, 2, 2), (6, 5, 2, 3), reduce_str=reduce_type)
- helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8), reduce_str=reduce)
- helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6), reduce_str=reduce)
- helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6), reduce_str=reduce)
+ helper((4, 5, 9, 8), 2, (4, 5, 13, 8), (4, 5, 13, 8), reduce_str=reduce_type)
+ helper((4, 5, 9, 8), 2, (3, 4, 10, 6), (3, 4, 10, 6), reduce_str=reduce_type)
+ helper((4, 5, 9, 8), 2, (3, 3, 7, 5), (3, 4, 10, 6), reduce_str=reduce_type)
def test_is_nonzero(self):
self.assertFalse(torch.is_nonzero(torch.tensor([0.]).to('mps')))