[MPS] Register norm_dtype_out_mps and cdist (#91643)

Add support for the `norm_dtype_out` and `cdist` ops on the MPS backend
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91643
Approved by: https://github.com/razarmehr
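
For context, a minimal usage sketch of what these registrations enable (not part of the patch; assumes a Metal-capable machine where `torch.backends.mps.is_available()` returns True):

```python
import torch

# Hypothetical smoke test, not taken from the PR: compare MPS results against CPU.
if torch.backends.mps.is_available():
    x = torch.randn(100, 10, device="mps")
    y = torch.randn(100, 10, device="mps")

    # cdist: pairwise p-norm distances, now dispatched to the MPS backend
    d_mps = torch.cdist(x, y, p=2)
    d_cpu = torch.cdist(x.cpu(), y.cpu(), p=2)
    torch.testing.assert_close(d_mps.cpu(), d_cpu, rtol=1e-4, atol=1e-4)

    # norm with an explicit result dtype (exercises the norm dtype overloads)
    n_mps = torch.norm(x, p=2, dim=1, dtype=torch.float32)
    n_cpu = torch.norm(x.cpu(), p=2, dim=1, dtype=torch.float32)
    torch.testing.assert_close(n_mps.cpu(), n_cpu, rtol=1e-4, atol=1e-4)
```
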
diff --git a/test/test_mps.py b/test/test_mps.py
index 00dadbe..2084f95 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -289,6 +289,151 @@
         helper(0, [1024])
         helper(0.2, [2, 3])
 
+    def test_cdist_large(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(100, 10, device=device)
+            y = torch.randn(100, 10, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertEqual(expected, actual)
+
+    def test_cdist_large_batch(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(4, 3, 100, 10, device=device)
+            y = torch.randn(4, 3, 100, 10, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertEqual(expected, actual)
+
+    def test_cdist_non_contiguous(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(5, 7, device=device).mT
+            y = torch.randn(5, 3, device=device).mT
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(7, 5, device=device)
+            y = torch.randn(5, 3, device=device).t()
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertTrue(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(5, 7, device=device).t()
+            y = torch.randn(3, 5, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertTrue(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+    def test_cdist_non_contiguous_batch(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(4, 3, 2, 5, 7, device=device).mT
+            y = torch.randn(4, 3, 2, 5, 3, device=device).mT
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(7, 2, 7, 5, device=device)
+            y = torch.randn(7, 2, 5, 3, device=device).mT
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertTrue(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(4, 5, 7, device=device).mT
+            y = torch.randn(4, 3, 5, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertTrue(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+    def test_cdist_euclidean_large(self, device="mps"):
+        def _test_euclidean_large_cdist(sizex, sizey=None):
+            if sizey is None:
+                sizey = sizex
+            x = torch.randn(sizex, device=device, dtype=torch.float)
+            y = torch.randn(sizey, device=device, dtype=torch.float)
+            eps = 1e-6
+            # perturb x to avoid x == y, where the Euclidean distance gradient is ill-defined
+            x = x - (((x - y) < eps).float() * 2 * eps)
+            x.requires_grad = True
+            y.requires_grad = True
+            dist = torch.cdist(x, y, p=2)
+            # Do a backward pass to check that it is valid for large
+            # matrices
+            loss = dist.sum()
+            loss.backward()
+
+        _test_euclidean_large_cdist((2000, 5))
+
+    def test_cdist_same_inputs(self, device="mps"):
+        # Test to detect issues in the cdist gradient calculation
+        # when the distances are 0
+        sizex = (1, 27, 32)
+        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+            x = torch.randn(sizex, device=device, dtype=torch.float)
+            dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
+            y = x.clone()
+            eps = 1e-6
+            x.requires_grad = True
+            d = torch.cdist(x, y, p=p)
+            d.backward(dist_grad)
+            # Check that the backward pass does not contain invalid
+            # values such as nan or inf
+            assert torch.isfinite(x.grad).all()
+
+
+    def _brute_cdist(self, x, y, p=2):
+        r1 = x.shape[-2]
+        r2 = y.shape[-2]
+        if r1 == 0 or r2 == 0:
+            return torch.empty(r1, r2, device=x.device)
+        return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
+
+    def test_cdist_norm(self, device="mps"):
+        for r1 in [3, 4]:
+            for m in [2, 3]:
+                for r2 in [4, 6]:
+                    for p in [0, 1, 1.5, 2, 2.5, float('inf')]:
+                        x = torch.randn(r1, m, device=device)
+                        y = torch.randn(r2, m, device=device)
+                        if p == 2:
+                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
+                                expected = self._brute_cdist(x, y, p=2)
+                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
+                        else:
+                            actual = torch.cdist(x, y, p=p)
+                            expected = self._brute_cdist(x, y, p=p)
+                            self.assertEqual(expected, actual)
+
+    def test_cdist_norm_batch(self, device="mps"):
+        for r1 in [3, 4]:
+            for m in [2, 3]:
+                for r2 in [4, 6]:
+                    for p in [0, 2, 3, 1.5, 2.5, float('inf')]:
+                        x = torch.randn(2, 3, 6, r1, m, device=device)
+                        y = torch.randn(2, 3, 6, r2, m, device=device)
+                        if p == 2:
+                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
+                                expected = self._brute_cdist(x, y, p=2)
+                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
+                        else:
+                            actual = torch.cdist(x, y, p=p)
+                            expected = self._brute_cdist(x, y, p=p)
+                            self.assertEqual(expected, actual)
+
     def test_mm(self):
         B = torch.ones(5, 6).to("mps")
         C = torch.ones(6, 5).to("mps")
@@ -809,6 +954,55 @@
                         helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
                                track_running_stats=track_running_stats, test_module=test_module)
 
+    def test_norm(self):
+        a = torch.arange(9, dtype=torch.float, device="mps") - 4
+        b = a.reshape((3, 3))
+
+        a_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4
+        b_cpu = a_cpu.reshape((3, 3))
+
+        res = torch.norm(a)
+        res_cpu = torch.norm(a_cpu)
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(b)
+        res_cpu = torch.norm(b_cpu)
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(a, float('inf'))
+        res_cpu = torch.norm(a_cpu, float('inf'))
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(b, float('inf'))
+        res_cpu = torch.norm(b_cpu, float('inf'))
+        self.assertEqual(res, res_cpu)
+
+        c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="mps")
+        c_cpu = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="cpu")
+
+        res = torch.norm(c, dim=0)
+        res_cpu = torch.norm(c_cpu, dim=0)
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(c, dim=1)
+        res_cpu = torch.norm(c_cpu, dim=1)
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(c, p=1, dim=1)
+        res_cpu = torch.norm(c_cpu, p=1, dim=1)
+        self.assertEqual(res, res_cpu)
+
+        d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
+        d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2)
+
+        res = torch.norm(d, dim=(1, 2))
+        res_cpu = torch.norm(d_cpu, dim=(1, 2))
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
+        res_cpu = torch.norm(d_cpu[0, :, :]), torch.norm(d_cpu[1, :, :])
+        self.assertEqual(res, res_cpu)
+
     def test_layer_norm(self):
         # TODO: Test non-contiguous
         def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32):