[MPS] Fix upsample for NHWC output (#94963)
Fixes https://github.com/huggingface/diffusers/issues/941
**Before**:
<img width="1144" alt="Screenshot 2023-02-15 at 8 11 53 PM" src="https://user-images.githubusercontent.com/104024078/219266709-6a77636a-2fc0-4802-b130-85069b95953f.png">
**After**:
<img width="1144" alt="Screenshot 2023-02-15 at 8 12 02 PM" src="https://user-images.githubusercontent.com/104024078/219266694-ea743c02-fb55-44f1-b7d6-5946106527c3.png">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94963
Approved by: https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/operations/UpSample.mm b/aten/src/ATen/native/mps/operations/UpSample.mm
index 17895e1..3b781de 100644
--- a/aten/src/ATen/native/mps/operations/UpSample.mm
+++ b/aten/src/ATen/native/mps/operations/UpSample.mm
@@ -26,6 +26,11 @@
} else {
native::upsample_2d_common_check(input.sizes(), output_size);
}
+ Tensor out;
+ if (!output.is_contiguous()) {
+ out = at::empty_like(output, MemoryFormat::Contiguous);
+ }
+
bool centerResults = false;
MPSGraphResizeMode resizeMode = MPSGraphResizeNearest;
MPSGraphResizeNearestRoundingMode nearestRoundingMode = MPSGraphResizeNearestRoundingModeFloor;
@@ -199,7 +204,7 @@
MPSGraphTensorData* sizeTensorData = [[[MPSGraphTensorData alloc] initWithMPSNDArray: sizeNDArray] autorelease];
Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input);
- Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, output);
+ Placeholder outputPlaceholder = Placeholder(cachedGraph->outputTensor, out.has_storage() ? out : output, nil, false);
NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
@@ -209,6 +214,10 @@
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+
+ if (out.has_storage()) {
+ output.copy_(out);
+ }
}
}
diff --git a/test/test_mps.py b/test/test_mps.py
index 0f4b2ea..05b42c7 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4545,9 +4545,9 @@
)
def test_upsample_nearest2d(self):
- def helper(N, C, H, W):
+ def helper(N, C, H, W, memory_format):
inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
- requires_grad=True).reshape(N, C, H, W)
+ requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format)
inputCPU.retain_grad()
inputMPS = inputCPU.detach().to('mps').requires_grad_()
@@ -4573,8 +4573,9 @@
self.assertEqual(inputCPU.grad, inputMPS.grad)
- helper(1, 1, 4, 4)
- helper(7, 5, 3, 2)
+ for memory_format in [torch.channels_last, torch.contiguous_format]:
+ helper(1, 1, 4, 4, memory_format=memory_format)
+ helper(7, 5, 3, 2, memory_format=memory_format)
def test_upsample_bilinear2d(self):
def helper(N, C, H, W):