[MPS] Fix upsample for NHWC output (#94963)
Fixes https://github.com/huggingface/diffusers/issues/941
**Before**:
<img width="1144" alt="Screenshot 2023-02-15 at 8 11 53 PM" src="https://user-images.githubusercontent.com/104024078/219266709-6a77636a-2fc0-4802-b130-85069b95953f.png">
**After**:
<img width="1144" alt="Screenshot 2023-02-15 at 8 12 02 PM" src="https://user-images.githubusercontent.com/104024078/219266694-ea743c02-fb55-44f1-b7d6-5946106527c3.png">
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94963
Approved by: https://github.com/razarmehr
diff --git a/test/test_mps.py b/test/test_mps.py
index 0f4b2ea..05b42c7 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -4545,9 +4545,9 @@
)
def test_upsample_nearest2d(self):
- def helper(N, C, H, W):
+ def helper(N, C, H, W, memory_format):
inputCPU = torch.arange(N * C * H * W, device='cpu', dtype=torch.float,
- requires_grad=True).reshape(N, C, H, W)
+ requires_grad=True).reshape(N, C, H, W).to(memory_format=memory_format)
inputCPU.retain_grad()
inputMPS = inputCPU.detach().to('mps').requires_grad_()
@@ -4573,8 +4573,9 @@
self.assertEqual(inputCPU.grad, inputMPS.grad)
- helper(1, 1, 4, 4)
- helper(7, 5, 3, 2)
+ for memory_format in [torch.channels_last, torch.contiguous_format]:
+ helper(1, 1, 4, 4, memory_format=memory_format)
+ helper(7, 5, 3, 2, memory_format=memory_format)
def test_upsample_bilinear2d(self):
def helper(N, C, H, W):