[MPS] Register norm_dtype_out_mps and cdist (#91643)
Add support for the `norm_dtype_out` and `cdist` ops on the MPS backend
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91643
Approved by: https://github.com/razarmehr
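
For reference, roughly the kind of usage this enables on an MPS device (a minimal sketch; the device check, shapes, and variable names are illustrative and not part of this PR):

```python
import torch

# Illustrative only: fall back to CPU when no MPS device is present.
device = "mps" if torch.backends.mps.is_available() else "cpu"

x = torch.randn(100, 10, device=device)
y = torch.randn(100, 10, device=device)

# cdist: pairwise p-norm distances between the rows of x and y.
d = torch.cdist(x, y, p=2)  # shape (100, 100)

# norm with an explicit dtype and an out= tensor; this combination
# should be served by the norm.dtype_out overload registered here.
out = torch.empty(100, device=device, dtype=torch.float32)
torch.norm(x, p=2, dim=1, dtype=torch.float32, out=out)
```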
diff --git a/test/test_mps.py b/test/test_mps.py
index 00dadbe..2084f95 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -289,6 +289,151 @@
         helper(0, [1024])
         helper(0.2, [2, 3])
+    def test_cdist_large(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(100, 10, device=device)
+            y = torch.randn(100, 10, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertEqual(expected, actual)
+
+    def test_cdist_large_batch(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(4, 3, 100, 10, device=device)
+            y = torch.randn(4, 3, 100, 10, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertEqual(expected, actual)
+
+    def test_cdist_non_contiguous(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(5, 7, device=device).mT
+            y = torch.randn(5, 3, device=device).mT
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(7, 5, device=device)
+            y = torch.randn(5, 3, device=device).t()
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertTrue(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(5, 7, device=device).t()
+            y = torch.randn(3, 5, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertTrue(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+    def test_cdist_non_contiguous_batch(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(4, 3, 2, 5, 7, device=device).mT
+            y = torch.randn(4, 3, 2, 5, 3, device=device).mT
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(7, 2, 7, 5, device=device)
+            y = torch.randn(7, 2, 5, 3, device=device).mT
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertTrue(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(4, 5, 7, device=device).mT
+            y = torch.randn(4, 3, 5, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertTrue(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+    def test_cdist_euclidean_large(self, device="mps"):
+        def _test_euclidean_large_cdist(sizex, sizey=None):
+            if sizey is None:
+                sizey = sizex
+            x = torch.randn(sizex, device=device, dtype=torch.float)
+            y = torch.randn(sizey, device=device, dtype=torch.float)
+            eps = 1e-6
+            # perturb x away from y so that no pairwise distance sits exactly at an extremum
+            x = x - (((x - y) < eps).float() * 2 * eps)
+            x.requires_grad = True
+            y.requires_grad = True
+            dist = torch.cdist(x, y, p=2)
+            # Do a backward pass to check that it is valid for large
+            # matrices
+            loss = dist.sum()
+            loss.backward()
+
+        _test_euclidean_large_cdist((2000, 5))
+
+    def test_cdist_same_inputs(self, device="mps"):
+        # Test to detect issues in the cdist gradient calculation
+        # when the distances are 0
+        sizex = (1, 27, 32)
+        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+            x = torch.randn(sizex, device=device, dtype=torch.float)
+            dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
+            y = x.clone()
+            eps = 1e-6
+            x.requires_grad = True
+            d = torch.cdist(x, y)
+            d.backward(dist_grad)
+            # Check that the backward pass does not contain invalid
+            # values such as nan or inf
+            assert torch.isfinite(x.grad).all()
+
+
+    def _brute_cdist(self, x, y, p=2):
+        r1 = x.shape[-2]
+        r2 = y.shape[-2]
+        if r1 == 0 or r2 == 0:
+            return torch.empty(r1, r2, device=x.device)
+        return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
+
+    def test_cdist_norm(self, device="mps"):
+        for r1 in [3, 4]:
+            for m in [2, 3]:
+                for r2 in [4, 6]:
+                    for p in [0, 1, 1.5, 2.5, float('inf')]:
+                        x = torch.randn(r1, m, device=device)
+                        y = torch.randn(r2, m, device=device)
+                        if p == 2:
+                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
+                                expected = self._brute_cdist(x, y, p=2)
+                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
+                        else:
+                            actual = torch.cdist(x, y, p=p)
+                            expected = self._brute_cdist(x, y, p=p)
+                            self.assertEqual(expected, actual)
+
+    def test_cdist_norm_batch(self, device="mps"):
+        for r1 in [3, 4]:
+            for m in [2, 3]:
+                for r2 in [4, 6]:
+                    for p in [0, 3, 1.5, 2.5, float('inf')]:
+                        x = torch.randn(2, 3, 6, r1, m, device=device)
+                        y = torch.randn(2, 3, 6, r2, m, device=device)
+                        if p == 2:
+                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
+                                expected = self._brute_cdist(x, y, p=2)
+                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
+                        else:
+                            actual = torch.cdist(x, y, p=p)
+                            expected = self._brute_cdist(x, y, p=p)
+                            self.assertEqual(expected, actual)
+
     def test_mm(self):
         B = torch.ones(5, 6).to("mps")
         C = torch.ones(6, 5).to("mps")
@@ -809,6 +954,55 @@
                         helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
                                track_running_stats=track_running_stats, test_module=test_module)
+    def test_norm(self):
+        a = torch.arange(9, dtype=torch.float, device="mps") - 4
+        b = a.reshape((3, 3))
+
+        a_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4
+        b_cpu = a_cpu.reshape((3, 3))
+
+        res = torch.norm(a)
+        res_cpu = torch.norm(a_cpu)
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(b)
+        res_cpu = torch.norm(b_cpu)
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(a, float('inf'))
+        res_cpu = torch.norm(a_cpu, float('inf'))
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(b, float('inf'))
+        res_cpu = torch.norm(b_cpu, float('inf'))
+        self.assertEqual(res, res_cpu)
+
+        c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="mps")
+        c_cpu = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="cpu")
+
+        res = torch.norm(c, dim=0)
+        res_cpu = torch.norm(c_cpu, dim=0)
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(c, dim=1)
+        res_cpu = torch.norm(c_cpu, dim=1)
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(c, p=1, dim=1)
+        res_cpu = torch.norm(c_cpu, p=1, dim=1)
+        self.assertEqual(res, res_cpu)
+
+        d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
+        d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2)
+
+        res = torch.norm(d, dim=(1, 2))
+        res_cpu = torch.norm(d_cpu, dim=(1, 2))
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
+        res_cpu = torch.norm(d_cpu[0, :, :]), torch.norm(d_cpu[1, :, :])
+        self.assertEqual(res, res_cpu)
+
     def test_layer_norm(self):
         # TODO: Test non-contiguous
         def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32):