[MPS] Register norm_dtype_out_mps and cdist (#91643)

Add support for `norm_dtype_out` and `cdist` ops
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91643
Approved by: https://github.com/razarmehr
diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index b0a3389..3066671 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -43,6 +43,8 @@
 MPSDataType getMPSScalarType(ScalarType scalar_type);
 MPSScalar   getMPSScalar(const Scalar& scalar, ScalarType type);
 std::string getMPSTypeString(ScalarType scalar_type);
+NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
+NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim);
 std::string getMPSShapeString(MPSShape* shape);
 std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value = false);
 std::string getArrayRefString(const IntArrayRef s);
@@ -127,6 +129,13 @@
   MPSGraphTensor *outputTensor_ = nil;
 };
 
+struct MPSBinaryCachedGraph : public MPSCachedGraph
+{
+  MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+  MPSGraphTensor *inputTensor_ = nil;
+  MPSGraphTensor *otherTensor_ = nil;
+  MPSGraphTensor *outputTensor_ = nil;
+};
 
 // TODO: Improve the overall design of MPSGraphCache.
 // https://github.com/pytorch/pytorch/issues/77176
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index 9da3b2e..8de0016 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -87,6 +87,30 @@
   }
 }
 
+NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
+  int64_t ndim = t.dim();
+  auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
+  for (const auto i: c10::irange(ndim)) {
+    axes[i] = [NSNumber numberWithInteger:i];
+  }
+  return axes;
+}
+
+NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim) {
+  if (dim.has_value() && dim.value().size() != 0) {
+    IntArrayRef dimValues = dim.value();
+    int ndim = dimValues.size();
+    auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
+    for (const auto i: c10::irange(ndim)) {
+      axes[i] = [NSNumber numberWithInteger:dimValues[i]];
+    }
+
+    return axes;
+  }
+
+  return getTensorAxes(t);
+}
+
 std::string getMPSShapeString(MPSShape* shape) {
     std::string str;
     for(NSNumber *elem in shape) {
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 43b3336..9ef87b4 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -15,6 +15,9 @@
 namespace at {
 namespace native {
 
+typedef MPSGraphTensor* (^NormOpBlock)(mps::MPSBinaryCachedGraph*, MPSGraphTensor*, MPSGraphTensor*);
+#define NormOpFn(graph, primary, secondary) MPSGraphTensor* (mps::MPSBinaryCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary)
+
 enum StdVarType {
   STANDARD_VARIANCE,
   STANDARD_DEVIATION
@@ -34,15 +37,6 @@
 
 using namespace mps;
 
-NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
-  int64_t ndim = t.dim();
-  auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
-  for (const auto i: c10::irange(ndim)) {
-    axes[i] = [NSNumber numberWithInteger:i];
-  }
-  return axes;
-}
-
 void set_apparent_shapes(NSMutableArray<NSNumber*> * &apparent_out_shape,
                          NSMutableArray<NSNumber*> * &apparent_in_shape,
                          int64_t num_reduce_dims,
@@ -410,19 +404,28 @@
   reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::MEAN, "mean_out_mps");
 }
 
-TORCH_IMPL_FUNC(norm_out_mps)
-    (const Tensor& input_tensor,
-     const OptionalScalarRef opt_p,
-     IntArrayRef dim,
-     bool keepdim,
-     const Tensor& output_t) {
+void impl_func_norm_mps(
+  const Tensor& input_tensor,
+  const Tensor& other_tensor,
+  const OptionalScalarRef& opt_p,
+  IntArrayRef dim,
+  bool keepdim,
+  c10::optional<ScalarType> opt_dtype,
+  const Tensor& output_t,
+  bool cdist = false,
+  c10::optional<IntArrayRef> input_broadcasted_shape = c10::nullopt,
+  NormOpBlock normOpBlock = nullptr
+  ) {
+
   if (input_tensor.numel() == 0) {
     return;
   }
 
   auto input_t = (input_tensor.sizes().size() == 0) ? input_tensor.view({1}) : input_tensor;
+  auto in_dtype = opt_dtype.value_or(input_tensor.scalar_type());
+  auto mps_input_dtype = getMPSDataType(in_dtype);
 
-  IntArrayRef input_shape = input_t.sizes();
+  IntArrayRef input_shape = cdist ? input_broadcasted_shape.value() : input_t.sizes();
 
   for (const auto dim_val: dim) {
     auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size());
@@ -456,6 +459,13 @@
                       num_output_dims,
                       input_shape,
                       axes);
+
+  NSArray<NSNumber*>* wrappedAxes = mps::getTensorAxes(input_t, dim);
+  if (cdist) {
+    apparent_input_shape  = [mps::getMPSShape(input_tensor.sizes()) mutableCopy];
+    apparent_output_shape = [mps::getMPSShape(output_t.sizes()) mutableCopy];
+  }
+
   if (output_t.numel() == 0) {
     return;
   }
@@ -465,62 +475,76 @@
   @autoreleasepool {
     NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","];
       string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
-      string key =  string("norm_out_mps:") + [ns_key UTF8String] + ":" + getMPSTypeString(input_t.scalar_type()) + ":p" + to_string(p) + ":" + keepdim_info;
+      string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : mps::getTensorsStringKey({input_t});
+      string key =  string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" + keepdim_info;
 
-    auto cachedGraph = cache_->LookUpAs<MPSUnaryCachedGraph>(key);
+    auto cachedGraph = cache_->LookUpAs<MPSBinaryCachedGraph>(key);
 
-    if (!cachedGraph) {
-      cachedGraph = cache_->CreateCachedGraphAs<MPSUnaryCachedGraph>(key, ^ MPSCachedGraph * () {
+    if(!cachedGraph) {
+      cachedGraph = cache_->CreateCachedGraphAs<MPSBinaryCachedGraph>(key, ^ MPSCachedGraph * () {
 
-        MPSUnaryCachedGraph *newCachedGraph = nil;
+        MPSBinaryCachedGraph *newCachedGraph = nil;
 
         @autoreleasepool {
           MPSGraph* mpsGraph = make_mps_graph();
-          newCachedGraph = new MPSUnaryCachedGraph(mpsGraph);
+          newCachedGraph = new MPSBinaryCachedGraph(mpsGraph);
+          newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
 
-          MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()));
+          if (cdist) {
+            newCachedGraph->otherTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, other_tensor);
+          }
 
-          MPSGraphTensor *outputTensor = nil;
+          MPSGraphTensor* inputTensor = cdist ? normOpBlock(newCachedGraph, newCachedGraph->inputTensor_, newCachedGraph->otherTensor_) :
+                                                newCachedGraph->inputTensor_;
+          if (opt_dtype.has_value()) {
+            inputTensor = [mpsGraph castTensor:inputTensor
+                                        toType:mps_input_dtype
+                                          name:@"castInputTensor"];
+          }
+
+          MPSGraphTensor *outputTensor;
 
           if (pIsZero) {
               MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor
                                                                        name:nil];
               MPSGraphTensor *powerValTensor = [mpsGraph constantWithScalar:p
-                                                                   dataType:getMPSDataType(input_t.scalar_type())];
+                                                                   dataType:mps_input_dtype];
               MPSGraphTensor *powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor
                                                              secondaryTensor:powerValTensor
                                                                         name:nil];
               outputTensor = [mpsGraph reductionSumWithTensor:powerTensor
-                                                         axes:axes
+                                                         axes:wrappedAxes
                                                          name:nil];
-          } else if (pIsPosInf) {
+          }
+          else if (pIsPosInf) {
               MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor
                                                                        name:nil];
               outputTensor = [mpsGraph reductionMaximumWithTensor:absoluteTensor
-                                                             axes:axes
+                                                             axes:wrappedAxes
                                                              name:nil];
-          } else if (pIsNegInf) {
+          }
+          else if (pIsNegInf) {
               MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor
                                                                        name:nil];
               outputTensor = [mpsGraph reductionMinimumWithTensor:absoluteTensor
-                                                             axes:axes
+                                                             axes:wrappedAxes
                                                              name:nil];
           } else {
               MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor
                                                                        name:nil];
 
               MPSGraphTensor *powerValTensor = [mpsGraph constantWithScalar:p
-                                                                   dataType:getMPSDataType(input_t.scalar_type())];
+                                                                   dataType:mps_input_dtype];
 
               MPSGraphTensor *reciprocalPowerValTensor = [mpsGraph constantWithScalar:reciprocal_p
-                                                                             dataType:getMPSDataType(input_t.scalar_type())];
+                                                                             dataType:mps_input_dtype];
 
               MPSGraphTensor *powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor
                                                              secondaryTensor:powerValTensor
                                                                         name:nil];
 
               MPSGraphTensor *reductionSumTensor = [mpsGraph reductionSumWithTensor:powerTensor
-                                                                         axes:axes
+                                                                         axes:wrappedAxes
                                                                          name:nil];
 
               outputTensor = [mpsGraph powerWithPrimaryTensor:reductionSumTensor
@@ -528,37 +552,133 @@
                                                          name:nil];
           }
 
-          newCachedGraph->inputTensor_ = inputTensor;
+          if (cdist) {
+            outputTensor= [mpsGraph reshapeTensor:outputTensor withShape:mps::getMPSShape(output_t) name: nil];
+          }
+
           newCachedGraph->outputTensor_ = outputTensor;
         }
         return newCachedGraph;
       });
     }
 
-    auto inputPlaceholder = Placeholder();
-
-    if (apparent_input_shape) {
-      inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape);
-    } else {
-      inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
-    }
-
+    auto otherPlaceholder = Placeholder();
+    auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
     auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape);
 
+    NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds =[NSMutableDictionary dictionary];
+    feeds[inputPlaceholder.getMPSGraphTensor()]   = inputPlaceholder.getMPSGraphTensorData();
 
-    NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
-      inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
-    };
+    if (cdist) {
+      otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other_tensor);
+      feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
+    }
 
     NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
       outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
     };
 
     runMPSGraph(stream, cachedGraph->graph(), feeds, results);
-
   }
 }
 
+TORCH_IMPL_FUNC(norm_out_mps)
+(const Tensor& self,
+ const OptionalScalarRef opt_p,
+ IntArrayRef dim,
+ bool keepdim,
+ const Tensor& result) {
+  impl_func_norm_mps(self, self, opt_p, dim, keepdim, c10::nullopt, result, /*cdist=*/false);
+}
+
+TORCH_IMPL_FUNC(norm_dtype_out_mps)
+(const Tensor& self,
+ const OptionalScalarRef opt_p,
+ IntArrayRef dim,
+ bool keepdim,
+ ScalarType dtype,
+ const Tensor& result) {
+  impl_func_norm_mps(self, self, opt_p, dim, keepdim, dtype, result, /*cdist=*/false);
+}
+
+
+Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
+  using namespace mps;
+  TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
+  TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
+  TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1));
+  TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type());
+  auto device1 = x1.device().type();
+  TORCH_CHECK(at::isFloatingType(x2.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type());
+  auto device2 = x2.device().type();
+  TORCH_CHECK(p >= 0, "cdist only supports non-negative p values");
+  TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
+  TORCH_CHECK(x1.is_mps() && (x1.get_device() == x2.get_device()), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")");
+
+  int64_t c1 = x1.size(-1);
+  int64_t c2 = x2.size(-1);
+
+  auto dim1 = x1.dim();
+  auto dim2 = x2.dim();
+  int64_t mode = compute_mode.value_or(0);
+  TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode);
+
+  int64_t r1 = x1.size(-2);
+  int64_t r2 = x2.size(-2);
+
+  //For batch calculation we expand all dimensions(except the last two) to one, with size that equals to product of them.
+  //The last two dimensions will stay the same
+  IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2);
+  IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2);
+  std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
+  std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
+  tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1});
+  std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
+  tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});
+
+  const int64_t expand_batch_product = c10::multiply_integers(expand_batch_portion);
+  std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
+  std::vector<int64_t> tensor2_view{expand_batch_product, r2, c2};
+
+  std::vector<int64_t> output_shape(expand_batch_portion);
+  output_shape.insert(output_shape.end(), {r1, r2});
+  Tensor result = at::empty(output_shape, x1.options());
+
+  NormOpBlock norm_op_block = ^NormOpFn(cachedGraph, x1Tensor, x2Tensor) {
+    MPSGraph* mpsGraph = cachedGraph->graph();
+
+    MPSGraphTensor* inputBroadcast = [mpsGraph broadcastTensor:x1Tensor toShape:getMPSShape(tensor1_expand_size) name:nil];
+    MPSGraphTensor* inputBroadcastReshape = [mpsGraph reshapeTensor:inputBroadcast withShape:getMPSShape(tensor1_view) name:nil];
+
+    MPSGraphTensor* otherBroadcast = [mpsGraph broadcastTensor:x2Tensor toShape:getMPSShape(tensor2_expand_size) name:nil];
+    MPSGraphTensor* otherBroadcastReshape = [mpsGraph reshapeTensor:otherBroadcast withShape:getMPSShape(tensor2_view) name:nil];
+
+    NSMutableArray<MPSGraphTensor*> *inputArray = [NSMutableArray arrayWithCapacity:tensor1_view[1]];
+    NSMutableArray<MPSGraphTensor*> *otherArray = [NSMutableArray arrayWithCapacity:tensor2_view[1]];
+
+    for (const auto i : c10::irange(tensor2_view[1])) {
+      inputArray[i] = inputBroadcastReshape;
+    }
+
+    for (const auto i : c10::irange(tensor1_view[1])) {
+      otherArray[i] = otherBroadcastReshape;
+    }
+
+    MPSGraphTensor *inputTensorReshaped = [mpsGraph concatTensors:inputArray dimension:1 interleave:YES name:nil];
+    MPSGraphTensor *otherTensorReshaped = [mpsGraph concatTensors:otherArray dimension:1 interleave:NO name:nil];
+
+
+    MPSGraphTensor *inputTensorPNorm = [mpsGraph subtractionWithPrimaryTensor: inputTensorReshaped
+                                                              secondaryTensor: otherTensorReshaped
+                                                                         name: nil];
+    return inputTensorPNorm;
+  };
+
+  c10::optional<IntArrayRef> inputBroadcastSize = c10::make_optional(makeArrayRef(tensor1_view.data(), tensor1_view.size()));
+  impl_func_norm_mps(x1, x2, OptionalScalarRef(p), makeArrayRef<int64_t>(2), false, c10::nullopt, result, /*cdist=*/true, inputBroadcastSize, norm_op_block);
+  return result;
+}
+
 Tensor std_var_common_impl_mps(const Tensor & input_t,
                                at::OptionalIntArrayRef dim,
                                c10::optional<int64_t> correction,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9b42c97..27677db 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4090,6 +4090,7 @@
 - func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
   dispatch:
     CPU, CUDA: _cdist_forward
+    MPS: _cdist_forward_mps
   autogen: _cdist_forward.out
 
 - func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor
@@ -6174,6 +6175,7 @@
   device_check: NoCheck   # TensorIterator
   dispatch:
     CPU, CUDA: norm_dtype_out
+    MPS: norm_dtype_out_mps
 
 - func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
   structured: True
diff --git a/test/test_mps.py b/test/test_mps.py
index 00dadbe..2084f95 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -289,6 +289,151 @@
         helper(0, [1024])
         helper(0.2, [2, 3])
 
+    def test_cdist_large(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(100, 10, device=device)
+            y = torch.randn(100, 10, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertEqual(expected, actual)
+
+    def test_cdist_large_batch(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(4, 3, 100, 10, device=device)
+            y = torch.randn(4, 3, 100, 10, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertEqual(expected, actual)
+
+    def test_cdist_non_contiguous(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(5, 7, device=device).mT
+            y = torch.randn(5, 3, device=device).mT
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(7, 5, device=device)
+            y = torch.randn(5, 3, device=device).t()
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertTrue(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(5, 7, device=device).t()
+            y = torch.randn(3, 5, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertTrue(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+    def test_cdist_non_contiguous_batch(self, device="mps"):
+        for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+            x = torch.randn(4, 3, 2, 5, 7, device=device).mT
+            y = torch.randn(4, 3, 2, 5, 3, device=device).mT
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(7, 2, 7, 5, device=device)
+            y = torch.randn(7, 2, 5, 3, device=device).mT
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertTrue(x.is_contiguous())
+            self.assertFalse(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+            x = torch.randn(4, 5, 7, device=device).mT
+            y = torch.randn(4, 3, 5, device=device)
+            actual = torch.cdist(x, y, p=2, compute_mode=cm)
+            expected = self._brute_cdist(x, y, p=2)
+            self.assertFalse(x.is_contiguous())
+            self.assertTrue(y.is_contiguous())
+            self.assertEqual(expected, actual)
+
+    def test_cdist_euclidean_large(self, device="mps"):
+        def _test_euclidean_large_cdist(sizex, sizey=None):
+            if sizey is None:
+                sizey = sizex
+            x = torch.randn(sizex, device=device, dtype=torch.float)
+            y = torch.randn(sizey, device=device, dtype=torch.float)
+            eps = 1e-6
+            # to avoid extremum
+            x = x - (((x - y) < eps).float() * 2 * eps)
+            x.requires_grad = True
+            y.requires_grad = True
+            dist = torch.cdist(x, y, p=2)
+            # Do a backward pass to check that it is valid for large
+            # matrices
+            loss = dist.sum()
+            loss.backward()
+
+        _test_euclidean_large_cdist((2000, 5))
+
+    def test_cdist_same_inputs(self, device="mps"):
+        # Test to detect issues in cdist gradient calculation
+        # When the distances are 0
+        sizex = (1, 27, 32)
+        for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+            x = torch.randn(sizex, device=device, dtype=torch.float)
+            dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
+            y = x.clone()
+            eps = 1e-6
+            x.requires_grad = True
+            d = torch.cdist(x, y)
+            d.backward(dist_grad)
+            # Check that the backward passs does not contain invalid
+            # values such as nan or inf
+            assert torch.isfinite(x.grad).all()
+
+
+    def _brute_cdist(self, x, y, p=2):
+        r1 = x.shape[-2]
+        r2 = y.shape[-2]
+        if r1 == 0 or r2 == 0:
+            return torch.empty(r1, r2, device=x.device)
+        return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
+
+    def test_cdist_norm(self, device="mps"):
+        for r1 in [3, 4]:
+            for m in [2, 3]:
+                for r2 in [4, 6]:
+                    for p in [0, 1, 1.5, 2.5, float('inf')]:
+                        x = torch.randn(r1, m, device=device)
+                        y = torch.randn(r2, m, device=device)
+                        if p == 2:
+                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
+                                expected = self._brute_cdist(x, y, p=2)
+                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
+                        else:
+                            actual = torch.cdist(x, y, p=p)
+                            expected = self._brute_cdist(x, y, p=p)
+                            self.assertEqual(expected, actual)
+
+    def test_cdist_norm_batch(self, device="mps"):
+        for r1 in [3, 4]:
+            for m in [2, 3]:
+                for r2 in [4, 6]:
+                    for p in [0, 3, 1.5, 2.5, float('inf')]:
+                        x = torch.randn(2, 3, 6, r1, m, device=device)
+                        y = torch.randn(2, 3, 6, r2, m, device=device)
+                        if p == 2:
+                            for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+                                actual = torch.cdist(x, y, p=2, compute_mode=cm)
+                                expected = self._brute_cdist(x, y, p=2)
+                                self.assertEqual(expected, actual, rtol=0, atol=0.02)
+                        else:
+                            actual = torch.cdist(x, y, p=p)
+                            expected = self._brute_cdist(x, y, p=p)
+                            self.assertEqual(expected, actual)
+
     def test_mm(self):
         B = torch.ones(5, 6).to("mps")
         C = torch.ones(6, 5).to("mps")
@@ -809,6 +954,55 @@
                         helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
                                track_running_stats=track_running_stats, test_module=test_module)
 
+    def test_norm(self):
+        a = torch.arange(9, dtype=torch.float, device="mps") - 4
+        b = a.reshape((3, 3))
+
+        a_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4
+        b_cpu = a_cpu.reshape((3, 3))
+
+        res = torch.norm(a)
+        res_cpu = torch.norm(a_cpu)
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(b)
+        res_cpu = torch.norm(b_cpu)
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(a, float('inf'))
+        res_cpu = torch.norm(a_cpu, float('inf'))
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(b, float('inf'))
+        res_cpu = torch.norm(b_cpu, float('inf'))
+        self.assertEqual(res, res_cpu)
+
+        c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="mps")
+        c_cpu = torch.tensor([[1, 2, 3], [-1, 1, 4]] , dtype=torch.float, device="cpu")
+
+        res = torch.norm(c, dim=0)
+        res_cpu = torch.norm(c_cpu, dim=0)
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(c, dim=1)
+        res_cpu = torch.norm(c_cpu, dim=1)
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(c, p=1, dim=1)
+        res_cpu = torch.norm(c_cpu, p=1, dim=1)
+        self.assertEqual(res, res_cpu)
+
+        d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
+        d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2)
+
+        res = torch.norm(d, dim=(1, 2))
+        res_cpu = torch.norm(d_cpu, dim=(1, 2))
+        self.assertEqual(res, res_cpu)
+
+        res = torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
+        res_cpu = torch.norm(d_cpu[0, :, :]), torch.norm(d_cpu[1, :, :])
+        self.assertEqual(res, res_cpu)
+
     def test_layer_norm(self):
         # TODO: Test non-contiguous
         def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32):