[MPS] Add Conv3D support for MPS (#114183)
Fixes #77818
I saw that PR #99246 was approved, but the rebase conflicts were never resolved, so I am bringing this up again so it can be merged.
I am leveraging @mattiaspaul's work. Quoting the description here:
> * this pull request enables 3D convolutions (forward/backward) for MPS (Apple Silicon) within the same Convolution.mm file as conv2d.
> * does not support channels_last (since PyTorch doesn't implement channels_last for 3D tensors)
> * does not support conv3d_transpose, and handles depthwise-separable convolutions through the regular convolution path rather than as a special case (there are no dedicated MPS kernels for either of those so far)
> * requires macOS >= 13.2 (Ventura)
Please let me know if any other changes are needed, and I'll be happy to implement them.
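For context, here is a minimal sketch of what this enables, assuming an Apple Silicon machine running macOS >= 13.2 with the MPS backend available (the shapes and channel counts are illustrative only):

```python
import torch
import torch.nn as nn

# Conv3d on the MPS backend additionally requires macOS >= 13.2.
if torch.backends.mps.is_available():
    conv = nn.Conv3d(in_channels=3, out_channels=8, kernel_size=3).to("mps")
    # Input is (N, C, D, H, W); channels_last is not supported for 3D tensors.
    x = torch.randn(2, 3, 10, 10, 10, device="mps", requires_grad=True)
    y = conv(x)          # forward pass on the MPS device
    y.sum().backward()   # backward pass is supported as well
    print(y.shape)       # torch.Size([2, 8, 8, 8, 8])
```

The tests below exercise these paths: weight_norm over a Conv3d module, a backward-pass collision case, and varying strides.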
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114183
Approved by: https://github.com/malfet
diff --git a/test/test_mps.py b/test/test_mps.py
index 6347d2e..078ef90 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -2641,8 +2641,38 @@
cpu_norm = torch.nn.utils.weight_norm(cpu_conv, dim=dim)
norm = torch.nn.utils.weight_norm(conv, dim=dim)
- cpu_out = cpu_conv(cpu_x)
- out = conv(x)
+ cpu_out = cpu_norm(cpu_x)
+ out = norm(x)
+
+ self.assertEqual(cpu_out, out)
+
+ cpu_grad = torch.randn(cpu_out.shape)
+ grad = cpu_grad.to('mps')
+ cpu_out.backward(gradient=cpu_grad)
+ out.backward(gradient=grad)
+
+ self.assertEqual(cpu_conv.weight_g.grad, conv.weight_g.grad)
+ self.assertEqual(cpu_conv.weight_v.grad, conv.weight_v.grad)
+
+ self.assertEqual(x.grad, cpu_x.grad)
+
+ # conv layer
+ if layer == 'conv3d':
+ cpu_x = torch.randn((3, 5, 5, 4), device='cpu', dtype=dtype, requires_grad=True)
+ x = cpu_x.detach().clone().to('mps').requires_grad_()
+
+ cpu_conv = torch.nn.Conv3d(3, 3, 3, device='cpu')
+ conv = torch.nn.Conv3d(3, 3, 3, device='mps')
+
+ with torch.no_grad():
+ conv.weight.copy_(cpu_conv.weight)
+ conv.bias.copy_(cpu_conv.bias)
+
+ cpu_norm = torch.nn.utils.weight_norm(cpu_conv, dim=dim)
+ norm = torch.nn.utils.weight_norm(conv, dim=dim)
+
+ cpu_out = cpu_norm(cpu_x)
+ out = norm(x)
self.assertEqual(cpu_out, out)
@@ -2666,6 +2696,15 @@
helper(3, layer='conv')
helper(-1, layer='conv')
+ if product_version >= 13.2:
+ # Conv3d is only available from macOS 13.2 onwards
+ helper(0, layer='conv3d')
+ helper(1, layer='conv3d')
+ helper(2, layer='conv3d')
+ helper(3, layer='conv3d')
+ helper(4, layer='conv3d')
+ helper(-1, layer='conv3d')
+
# Test conv2d
def test_conv2d_unit(self):
def helper(input_shape, wt_shape,
@@ -8229,6 +8268,17 @@
# This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
y2.sum().backward()
+ @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
+ def test_conv3d_backward_collision(self):
+ # Conv3d is only available from macOS 13.2 onwards
+ x = torch.rand(1, 1, 10, 10, 20, device="mps", requires_grad=True)
+ m1 = nn.Conv3d(1, 1, 3, stride=2, padding=1).to("mps")
+ m2 = nn.Conv3d(1, 1, 4, stride=2, padding=1).to("mps")
+ y1, y2 = m1(x), m2(x)
+ self.assertEqual(y1.shape, y2.shape)
+ y1.sum().backward()
+ # This used to crash with MPSNDArrayConvolutionA14.mm:4352: failed assertion
+ y2.sum().backward()
def test_gemm_permute_transpose(self):
batch_size = 32
@@ -9556,6 +9606,18 @@
x_gpu = conv_gpu(y_gpu)
self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)
+ @unittest.skipIf(product_version < 13.2, "Skipped on macOS 12")
+ def test_conv3d_single_stride(self):
+ # Conv3d is only available from macOS 13.2 onwards
+ y_cpu = torch.randn(2, 2, 3, 6)
+ y_gpu = y_cpu.to(device='mps')
+ for stride in range(1, 4):
+ conv_cpu = torch.nn.Conv3d(in_channels=2, out_channels=2, kernel_size=2, stride=stride)
+ conv_gpu = copy.deepcopy(conv_cpu).to(device='mps')
+ x_cpu = conv_cpu(y_cpu)
+ x_gpu = conv_gpu(y_gpu)
+ self.assertEqual(x_cpu, x_gpu.cpu(), rtol=1e-03, atol=1e-05)
+
def test_grid_sample(self):
def test(N, C, H, W, mode, padding_mode, align_corners, input_requires_grad):
def test_shape(N, C, IH, IW, H, W, mode, padding_mode, align_corners):