[MPS] Register norm_dtype_out_mps and cdist (#91643)
Add MPS support for the `norm_dtype_out` and `cdist` ops. The existing `norm_out_mps` implementation is generalized into a shared `impl_func_norm_mps` helper that accepts an optional output dtype and, for `cdist`, a block that builds the pairwise-difference tensor before the norm reduction.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91643
Approved by: https://github.com/razarmehr
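
A minimal usage sketch of the two newly registered kernels (assumes a PyTorch build with MPS enabled and an available MPS device; the shapes are illustrative):

```python
import torch

# Two small float tensors on the MPS device.
x = torch.randn(4, 3, device="mps")
y = torch.randn(5, 3, device="mps")

# _cdist_forward now dispatches to _cdist_forward_mps.
d = torch.cdist(x, y, p=2)                          # pairwise distances, shape (4, 5)

# norm.dtype_out now dispatches to norm_dtype_out_mps.
n = torch.norm(x, p=2, dim=1, dtype=torch.float32)  # per-row 2-norm in the requested dtype

print(d.shape, n.shape)
```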
diff --git a/aten/src/ATen/native/mps/OperationUtils.h b/aten/src/ATen/native/mps/OperationUtils.h
index b0a3389..3066671 100644
--- a/aten/src/ATen/native/mps/OperationUtils.h
+++ b/aten/src/ATen/native/mps/OperationUtils.h
@@ -43,6 +43,8 @@
MPSDataType getMPSScalarType(ScalarType scalar_type);
MPSScalar getMPSScalar(const Scalar& scalar, ScalarType type);
std::string getMPSTypeString(ScalarType scalar_type);
+NSArray<NSNumber*>* getTensorAxes(const Tensor& t);
+NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim);
std::string getMPSShapeString(MPSShape* shape);
std::string getTensorsStringKey(const TensorList& tensors, bool use_scalar_value = false);
std::string getArrayRefString(const IntArrayRef s);
@@ -127,6 +129,13 @@
MPSGraphTensor *outputTensor_ = nil;
};
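+// Cached graph with a second (`other`) input tensor, used by binary ops such as cdist.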
+struct MPSBinaryCachedGraph : public MPSCachedGraph
+{
+ MPSBinaryCachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+ MPSGraphTensor *inputTensor_ = nil;
+ MPSGraphTensor *otherTensor_ = nil;
+ MPSGraphTensor *outputTensor_ = nil;
+};
// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index 9da3b2e..8de0016 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -87,6 +87,30 @@
}
}
+NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
+ int64_t ndim = t.dim();
+ auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
+ for (const auto i: c10::irange(ndim)) {
+ axes[i] = [NSNumber numberWithInteger:i];
+ }
+ return axes;
+}
+
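+// Returns the reduction axes for `dim`; falls back to all axes of `t` when `dim` is absent or empty.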
+NSArray<NSNumber*>* getTensorAxes(const Tensor& t, at::OptionalIntArrayRef dim) {
+ if (dim.has_value() && dim.value().size() != 0) {
+ IntArrayRef dimValues = dim.value();
+ int ndim = dimValues.size();
+ auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
+ for (const auto i: c10::irange(ndim)) {
+ axes[i] = [NSNumber numberWithInteger:dimValues[i]];
+ }
+
+ return axes;
+ }
+
+ return getTensorAxes(t);
+}
+
std::string getMPSShapeString(MPSShape* shape) {
std::string str;
for(NSNumber *elem in shape) {
diff --git a/aten/src/ATen/native/mps/operations/ReduceOps.mm b/aten/src/ATen/native/mps/operations/ReduceOps.mm
index 43b3336..9ef87b4 100644
--- a/aten/src/ATen/native/mps/operations/ReduceOps.mm
+++ b/aten/src/ATen/native/mps/operations/ReduceOps.mm
@@ -15,6 +15,9 @@
namespace at {
namespace native {
+typedef MPSGraphTensor* (^NormOpBlock)(mps::MPSBinaryCachedGraph*, MPSGraphTensor*, MPSGraphTensor*);
+#define NormOpFn(graph, primary, secondary) MPSGraphTensor* (mps::MPSBinaryCachedGraph* graph, MPSGraphTensor* primary, MPSGraphTensor* secondary)
+
enum StdVarType {
STANDARD_VARIANCE,
STANDARD_DEVIATION
@@ -34,15 +37,6 @@
using namespace mps;
-NSArray<NSNumber*>* getTensorAxes(const Tensor& t) {
- int64_t ndim = t.dim();
- auto axes = [NSMutableArray<NSNumber*> arrayWithCapacity:ndim];
- for (const auto i: c10::irange(ndim)) {
- axes[i] = [NSNumber numberWithInteger:i];
- }
- return axes;
-}
-
void set_apparent_shapes(NSMutableArray<NSNumber*> * &apparent_out_shape,
NSMutableArray<NSNumber*> * &apparent_in_shape,
int64_t num_reduce_dims,
@@ -410,19 +404,28 @@
reduction_out_mps(input_t, opt_dim, keepdim, dtype, output_t, MPSReductionType::MEAN, "mean_out_mps");
}
-TORCH_IMPL_FUNC(norm_out_mps)
- (const Tensor& input_tensor,
- const OptionalScalarRef opt_p,
- IntArrayRef dim,
- bool keepdim,
- const Tensor& output_t) {
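+// Shared implementation behind norm_out_mps, norm_dtype_out_mps and _cdist_forward_mps. When `cdist`
+// is true, `normOpBlock` builds the pairwise-difference tensor from the two inputs, and the p-norm
+// reduction below is applied to its result.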
+void impl_func_norm_mps(
+ const Tensor& input_tensor,
+ const Tensor& other_tensor,
+ const OptionalScalarRef& opt_p,
+ IntArrayRef dim,
+ bool keepdim,
+ c10::optional<ScalarType> opt_dtype,
+ const Tensor& output_t,
+ bool cdist = false,
+ c10::optional<IntArrayRef> input_broadcasted_shape = c10::nullopt,
+ NormOpBlock normOpBlock = nullptr
+ ) {
+
if (input_tensor.numel() == 0) {
return;
}
auto input_t = (input_tensor.sizes().size() == 0) ? input_tensor.view({1}) : input_tensor;
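+ // Use the requested output dtype (norm.dtype_out overload) when given, otherwise the input's dtype.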
+ auto in_dtype = opt_dtype.value_or(input_tensor.scalar_type());
+ auto mps_input_dtype = getMPSDataType(in_dtype);
- IntArrayRef input_shape = input_t.sizes();
+ IntArrayRef input_shape = cdist ? input_broadcasted_shape.value() : input_t.sizes();
for (const auto dim_val: dim) {
auto wrap_dim = maybe_wrap_dim(dim_val, input_shape.size());
@@ -456,6 +459,13 @@
num_output_dims,
input_shape,
axes);
+
+ NSArray<NSNumber*>* wrappedAxes = mps::getTensorAxes(input_t, dim);
+ if (cdist) {
+ apparent_input_shape = [mps::getMPSShape(input_tensor.sizes()) mutableCopy];
+ apparent_output_shape = [mps::getMPSShape(output_t.sizes()) mutableCopy];
+ }
+
if (output_t.numel() == 0) {
return;
}
@@ -465,62 +475,76 @@
@autoreleasepool {
NSString* ns_key = [[axes valueForKey:@"description"] componentsJoinedByString:@","];
string keepdim_info = (keepdim) ? "keepdim=1" : "keepdim=0";
- string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + getMPSTypeString(input_t.scalar_type()) + ":p" + to_string(p) + ":" + keepdim_info;
+ string tensor_key = cdist ? getTensorsStringKey({input_tensor, other_tensor}) : mps::getTensorsStringKey({input_t});
+ string key = string("norm_out_mps:") + [ns_key UTF8String] + ":" + tensor_key + ":p" + to_string(p) + ":" + keepdim_info;
- auto cachedGraph = cache_->LookUpAs<MPSUnaryCachedGraph>(key);
+ auto cachedGraph = cache_->LookUpAs<MPSBinaryCachedGraph>(key);
- if (!cachedGraph) {
- cachedGraph = cache_->CreateCachedGraphAs<MPSUnaryCachedGraph>(key, ^ MPSCachedGraph * () {
+ if(!cachedGraph) {
+ cachedGraph = cache_->CreateCachedGraphAs<MPSBinaryCachedGraph>(key, ^ MPSCachedGraph * () {
- MPSUnaryCachedGraph *newCachedGraph = nil;
+ MPSBinaryCachedGraph *newCachedGraph = nil;
@autoreleasepool {
MPSGraph* mpsGraph = make_mps_graph();
- newCachedGraph = new MPSUnaryCachedGraph(mpsGraph);
+ newCachedGraph = new MPSBinaryCachedGraph(mpsGraph);
+ newCachedGraph->inputTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, input_tensor);
- MPSGraphTensor* inputTensor = mpsGraphUnrankedPlaceHolder(mpsGraph, getMPSDataType(input_t.scalar_type()));
+ if (cdist) {
+ newCachedGraph->otherTensor_ = mpsGraphRankedPlaceHolder(mpsGraph, other_tensor);
+ }
- MPSGraphTensor *outputTensor = nil;
+ MPSGraphTensor* inputTensor = cdist ? normOpBlock(newCachedGraph, newCachedGraph->inputTensor_, newCachedGraph->otherTensor_) :
+ newCachedGraph->inputTensor_;
+ if (opt_dtype.has_value()) {
+ inputTensor = [mpsGraph castTensor:inputTensor
+ toType:mps_input_dtype
+ name:@"castInputTensor"];
+ }
+
+ MPSGraphTensor *outputTensor;
if (pIsZero) {
MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor
name:nil];
MPSGraphTensor *powerValTensor = [mpsGraph constantWithScalar:p
- dataType:getMPSDataType(input_t.scalar_type())];
+ dataType:mps_input_dtype];
MPSGraphTensor *powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor
secondaryTensor:powerValTensor
name:nil];
outputTensor = [mpsGraph reductionSumWithTensor:powerTensor
- axes:axes
+ axes:wrappedAxes
name:nil];
- } else if (pIsPosInf) {
+ }
+ else if (pIsPosInf) {
MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor
name:nil];
outputTensor = [mpsGraph reductionMaximumWithTensor:absoluteTensor
- axes:axes
+ axes:wrappedAxes
name:nil];
- } else if (pIsNegInf) {
+ }
+ else if (pIsNegInf) {
MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor
name:nil];
outputTensor = [mpsGraph reductionMinimumWithTensor:absoluteTensor
- axes:axes
+ axes:wrappedAxes
name:nil];
} else {
MPSGraphTensor *absoluteTensor = [mpsGraph absoluteWithTensor:inputTensor
name:nil];
MPSGraphTensor *powerValTensor = [mpsGraph constantWithScalar:p
- dataType:getMPSDataType(input_t.scalar_type())];
+ dataType:mps_input_dtype];
MPSGraphTensor *reciprocalPowerValTensor = [mpsGraph constantWithScalar:reciprocal_p
- dataType:getMPSDataType(input_t.scalar_type())];
+ dataType:mps_input_dtype];
MPSGraphTensor *powerTensor = [mpsGraph powerWithPrimaryTensor:absoluteTensor
secondaryTensor:powerValTensor
name:nil];
MPSGraphTensor *reductionSumTensor = [mpsGraph reductionSumWithTensor:powerTensor
- axes:axes
+ axes:wrappedAxes
name:nil];
outputTensor = [mpsGraph powerWithPrimaryTensor:reductionSumTensor
@@ -528,37 +552,133 @@
name:nil];
}
- newCachedGraph->inputTensor_ = inputTensor;
+ if (cdist) {
+ outputTensor = [mpsGraph reshapeTensor:outputTensor withShape:mps::getMPSShape(output_t) name:nil];
+ }
+
newCachedGraph->outputTensor_ = outputTensor;
}
return newCachedGraph;
});
}
- auto inputPlaceholder = Placeholder();
-
- if (apparent_input_shape) {
- inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t, apparent_input_shape);
- } else {
- inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
- }
-
+ auto otherPlaceholder = Placeholder();
+ auto inputPlaceholder = Placeholder(cachedGraph->inputTensor_, input_t);
auto outputPlaceholder = Placeholder(cachedGraph->outputTensor_, output_t, apparent_output_shape);
+ NSMutableDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = [NSMutableDictionary dictionary];
+ feeds[inputPlaceholder.getMPSGraphTensor()] = inputPlaceholder.getMPSGraphTensorData();
- NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *feeds = @{
- inputPlaceholder.getMPSGraphTensor() : inputPlaceholder.getMPSGraphTensorData(),
- };
+ if (cdist) {
+ otherPlaceholder = Placeholder(cachedGraph->otherTensor_, other_tensor);
+ feeds[otherPlaceholder.getMPSGraphTensor()] = otherPlaceholder.getMPSGraphTensorData();
+ }
NSDictionary<MPSGraphTensor *, MPSGraphTensorData *> *results = @{
outputPlaceholder.getMPSGraphTensor() : outputPlaceholder.getMPSGraphTensorData()
};
runMPSGraph(stream, cachedGraph->graph(), feeds, results);
-
}
}
+TORCH_IMPL_FUNC(norm_out_mps)
+(const Tensor& self,
+ const OptionalScalarRef opt_p,
+ IntArrayRef dim,
+ bool keepdim,
+ const Tensor& result) {
+ impl_func_norm_mps(self, self, opt_p, dim, keepdim, c10::nullopt, result, /*cdist=*/false);
+}
+
+TORCH_IMPL_FUNC(norm_dtype_out_mps)
+(const Tensor& self,
+ const OptionalScalarRef opt_p,
+ IntArrayRef dim,
+ bool keepdim,
+ ScalarType dtype,
+ const Tensor& result) {
+ impl_func_norm_mps(self, self, opt_p, dim, keepdim, dtype, result, /*cdist=*/false);
+}
+
+
+Tensor _cdist_forward_mps(const Tensor& x1, const Tensor& x2, const double p, c10::optional<int64_t> compute_mode) {
+ using namespace mps;
+ TORCH_CHECK(x1.dim() >= 2, "cdist only supports at least 2D tensors, X1 got: ", x1.dim(), "D");
+ TORCH_CHECK(x2.dim() >= 2, "cdist only supports at least 2D tensors, X2 got: ", x2.dim(), "D");
+ TORCH_CHECK(x1.size(-1) == x2.size(-1), "X1 and X2 must have the same number of columns. X1: ", x1.size(-1), " X2: ", x2.size(-1));
+ TORCH_CHECK(at::isFloatingType(x1.scalar_type()), "cdist only supports floating-point dtypes, X1 got: ", x1.scalar_type());
+ auto device1 = x1.device().type();
+ TORCH_CHECK(at::isFloatingType(x2.scalar_type()), "cdist only supports floating-point dtypes, X2 got: ", x2.scalar_type());
+ auto device2 = x2.device().type();
+ TORCH_CHECK(p >= 0, "cdist only supports non-negative p values");
+ TORCH_CHECK(device1 == device2, "X1 and X2 must have the same device type. X1: ", device1, " X2: ", device2);
+ TORCH_CHECK(x1.is_mps() && (x1.get_device() == x2.get_device()), "device of X1 (", x1.get_device(), ") must match device of X2 (", x2.get_device(), ")");
+
+ int64_t c1 = x1.size(-1);
+ int64_t c2 = x2.size(-1);
+
+ auto dim1 = x1.dim();
+ auto dim2 = x2.dim();
+ int64_t mode = compute_mode.value_or(0);
+ TORCH_CHECK(mode >= 0 && mode <= 2, "possible modes: 0, 1, 2, but was: ", mode);
+
+ int64_t r1 = x1.size(-2);
+ int64_t r2 = x2.size(-2);
+
+ // For batch calculation we flatten all batch dimensions (all but the last two) into a single dimension whose size is their product.
+ // The last two dimensions stay the same.
+ IntArrayRef batch_tensor1(x1.sizes().data(), dim1 - 2);
+ IntArrayRef batch_tensor2(x2.sizes().data(), dim2 - 2);
+ std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);
+ std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
+ tensor1_expand_size.insert(tensor1_expand_size.end(), {r1, c1});
+ std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
+ tensor2_expand_size.insert(tensor2_expand_size.end(), {r2, c2});
+
+ const int64_t expand_batch_product = c10::multiply_integers(expand_batch_portion);
+ std::vector<int64_t> tensor1_view{expand_batch_product, r1, c1};
+ std::vector<int64_t> tensor2_view{expand_batch_product, r2, c2};
+
+ std::vector<int64_t> output_shape(expand_batch_portion);
+ output_shape.insert(output_shape.end(), {r1, r2});
+ Tensor result = at::empty(output_shape, x1.options());
+
+ NormOpBlock norm_op_block = ^NormOpFn(cachedGraph, x1Tensor, x2Tensor) {
+ MPSGraph* mpsGraph = cachedGraph->graph();
+
+ MPSGraphTensor* inputBroadcast = [mpsGraph broadcastTensor:x1Tensor toShape:getMPSShape(tensor1_expand_size) name:nil];
+ MPSGraphTensor* inputBroadcastReshape = [mpsGraph reshapeTensor:inputBroadcast withShape:getMPSShape(tensor1_view) name:nil];
+
+ MPSGraphTensor* otherBroadcast = [mpsGraph broadcastTensor:x2Tensor toShape:getMPSShape(tensor2_expand_size) name:nil];
+ MPSGraphTensor* otherBroadcastReshape = [mpsGraph reshapeTensor:otherBroadcast withShape:getMPSShape(tensor2_view) name:nil];
+
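+ // Repeat each row of x1 r2 times (interleaved concat) and tile the rows of x2 r1 times (blocked concat)
+ // so that a single elementwise subtraction yields every pairwise row difference.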
+ NSMutableArray<MPSGraphTensor*> *inputArray = [NSMutableArray arrayWithCapacity:tensor1_view[1]];
+ NSMutableArray<MPSGraphTensor*> *otherArray = [NSMutableArray arrayWithCapacity:tensor2_view[1]];
+
+ for (const auto i : c10::irange(tensor2_view[1])) {
+ inputArray[i] = inputBroadcastReshape;
+ }
+
+ for (const auto i : c10::irange(tensor1_view[1])) {
+ otherArray[i] = otherBroadcastReshape;
+ }
+
+ MPSGraphTensor *inputTensorReshaped = [mpsGraph concatTensors:inputArray dimension:1 interleave:YES name:nil];
+ MPSGraphTensor *otherTensorReshaped = [mpsGraph concatTensors:otherArray dimension:1 interleave:NO name:nil];
+
+ MPSGraphTensor *inputTensorPNorm = [mpsGraph subtractionWithPrimaryTensor: inputTensorReshaped
+ secondaryTensor: otherTensorReshaped
+ name: nil];
+ return inputTensorPNorm;
+ };
+
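+ // The norm reduction runs over axis 2 (the feature axis) of the pairwise-difference tensor; the result
+ // is then reshaped to the cdist output shape inside impl_func_norm_mps.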
+ c10::optional<IntArrayRef> inputBroadcastSize = c10::make_optional(makeArrayRef(tensor1_view.data(), tensor1_view.size()));
+ impl_func_norm_mps(x1, x2, OptionalScalarRef(p), makeArrayRef<int64_t>(2), false, c10::nullopt, result, /*cdist=*/true, inputBroadcastSize, norm_op_block);
+ return result;
+}
+
Tensor std_var_common_impl_mps(const Tensor & input_t,
at::OptionalIntArrayRef dim,
c10::optional<int64_t> correction,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 9b42c97..27677db 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4090,6 +4090,7 @@
- func: _cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor
dispatch:
CPU, CUDA: _cdist_forward
+ MPS: _cdist_forward_mps
autogen: _cdist_forward.out
- func: _cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor
@@ -6174,6 +6175,7 @@
device_check: NoCheck # TensorIterator
dispatch:
CPU, CUDA: norm_dtype_out
+ MPS: norm_dtype_out_mps
- func: norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)
structured: True
diff --git a/test/test_mps.py b/test/test_mps.py
index 00dadbe..2084f95 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -289,6 +289,151 @@
helper(0, [1024])
helper(0.2, [2, 3])
+ def test_cdist_large(self, device="mps"):
+ for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+ x = torch.randn(100, 10, device=device)
+ y = torch.randn(100, 10, device=device)
+ actual = torch.cdist(x, y, p=2, compute_mode=cm)
+ expected = self._brute_cdist(x, y, p=2)
+ self.assertEqual(expected, actual)
+
+ def test_cdist_large_batch(self, device="mps"):
+ for cm in ['use_mm_for_euclid_dist_if_necessary', 'use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+ x = torch.randn(4, 3, 100, 10, device=device)
+ y = torch.randn(4, 3, 100, 10, device=device)
+ actual = torch.cdist(x, y, p=2, compute_mode=cm)
+ expected = self._brute_cdist(x, y, p=2)
+ self.assertEqual(expected, actual)
+
+ def test_cdist_non_contiguous(self, device="mps"):
+ for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+ x = torch.randn(5, 7, device=device).mT
+ y = torch.randn(5, 3, device=device).mT
+ actual = torch.cdist(x, y, p=2, compute_mode=cm)
+ expected = self._brute_cdist(x, y, p=2)
+ self.assertFalse(x.is_contiguous())
+ self.assertFalse(y.is_contiguous())
+ self.assertEqual(expected, actual)
+
+ x = torch.randn(7, 5, device=device)
+ y = torch.randn(5, 3, device=device).t()
+ actual = torch.cdist(x, y, p=2, compute_mode=cm)
+ expected = self._brute_cdist(x, y, p=2)
+ self.assertTrue(x.is_contiguous())
+ self.assertFalse(y.is_contiguous())
+ self.assertEqual(expected, actual)
+
+ x = torch.randn(5, 7, device=device).t()
+ y = torch.randn(3, 5, device=device)
+ actual = torch.cdist(x, y, p=2, compute_mode=cm)
+ expected = self._brute_cdist(x, y, p=2)
+ self.assertFalse(x.is_contiguous())
+ self.assertTrue(y.is_contiguous())
+ self.assertEqual(expected, actual)
+
+ def test_cdist_non_contiguous_batch(self, device="mps"):
+ for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+ x = torch.randn(4, 3, 2, 5, 7, device=device).mT
+ y = torch.randn(4, 3, 2, 5, 3, device=device).mT
+ actual = torch.cdist(x, y, p=2, compute_mode=cm)
+ expected = self._brute_cdist(x, y, p=2)
+ self.assertFalse(x.is_contiguous())
+ self.assertFalse(y.is_contiguous())
+ self.assertEqual(expected, actual)
+
+ x = torch.randn(7, 2, 7, 5, device=device)
+ y = torch.randn(7, 2, 5, 3, device=device).mT
+ actual = torch.cdist(x, y, p=2, compute_mode=cm)
+ expected = self._brute_cdist(x, y, p=2)
+ self.assertTrue(x.is_contiguous())
+ self.assertFalse(y.is_contiguous())
+ self.assertEqual(expected, actual)
+
+ x = torch.randn(4, 5, 7, device=device).mT
+ y = torch.randn(4, 3, 5, device=device)
+ actual = torch.cdist(x, y, p=2, compute_mode=cm)
+ expected = self._brute_cdist(x, y, p=2)
+ self.assertFalse(x.is_contiguous())
+ self.assertTrue(y.is_contiguous())
+ self.assertEqual(expected, actual)
+
+ def test_cdist_euclidean_large(self, device="mps"):
+ def _test_euclidean_large_cdist(sizex, sizey=None):
+ if sizey is None:
+ sizey = sizex
+ x = torch.randn(sizex, device=device, dtype=torch.float)
+ y = torch.randn(sizey, device=device, dtype=torch.float)
+ eps = 1e-6
+ # perturb x away from y to avoid zero distances (a non-differentiable extremum)
+ x = x - (((x - y) < eps).float() * 2 * eps)
+ x.requires_grad = True
+ y.requires_grad = True
+ dist = torch.cdist(x, y, p=2)
+ # Do a backward pass to check that it is valid for large
+ # matrices
+ loss = dist.sum()
+ loss.backward()
+
+ _test_euclidean_large_cdist((2000, 5))
+
+ def test_cdist_same_inputs(self, device="mps"):
+ # Test to detect issues in the cdist gradient calculation
+ # when the distances are 0
+ sizex = (1, 27, 32)
+ for p in [0, 1, 2, 3, 1.5, 2.5, float('inf')]:
+ x = torch.randn(sizex, device=device, dtype=torch.float)
+ dist_grad = torch.randn((1, 27, 27), device=device, dtype=torch.float)
+ y = x.clone()
+ eps = 1e-6
+ x.requires_grad = True
+ d = torch.cdist(x, y)
+ d.backward(dist_grad)
+ # Check that the backward pass does not contain invalid
+ # values such as nan or inf
+ assert torch.isfinite(x.grad).all()
+
+ def _brute_cdist(self, x, y, p=2):
+ r1 = x.shape[-2]
+ r2 = y.shape[-2]
+ if r1 == 0 or r2 == 0:
+ return torch.empty(r1, r2, device=x.device)
+ return torch.norm(x[..., None, :] - y[..., None, :, :], p=p, dim=-1)
+
+ def test_cdist_norm(self, device="mps"):
+ for r1 in [3, 4]:
+ for m in [2, 3]:
+ for r2 in [4, 6]:
+ for p in [0, 1, 1.5, 2.5, float('inf')]:
+ x = torch.randn(r1, m, device=device)
+ y = torch.randn(r2, m, device=device)
+ if p == 2:
+ for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+ actual = torch.cdist(x, y, p=2, compute_mode=cm)
+ expected = self._brute_cdist(x, y, p=2)
+ self.assertEqual(expected, actual, rtol=0, atol=0.02)
+ else:
+ actual = torch.cdist(x, y, p=p)
+ expected = self._brute_cdist(x, y, p=p)
+ self.assertEqual(expected, actual)
+
+ def test_cdist_norm_batch(self, device="mps"):
+ for r1 in [3, 4]:
+ for m in [2, 3]:
+ for r2 in [4, 6]:
+ for p in [0, 3, 1.5, 2.5, float('inf')]:
+ x = torch.randn(2, 3, 6, r1, m, device=device)
+ y = torch.randn(2, 3, 6, r2, m, device=device)
+ if p == 2:
+ for cm in ['use_mm_for_euclid_dist', 'donot_use_mm_for_euclid_dist']:
+ actual = torch.cdist(x, y, p=2, compute_mode=cm)
+ expected = self._brute_cdist(x, y, p=2)
+ self.assertEqual(expected, actual, rtol=0, atol=0.02)
+ else:
+ actual = torch.cdist(x, y, p=p)
+ expected = self._brute_cdist(x, y, p=p)
+ self.assertEqual(expected, actual)
+
def test_mm(self):
B = torch.ones(5, 6).to("mps")
C = torch.ones(6, 5).to("mps")
@@ -809,6 +954,55 @@
helper(shape, eps=3, momentum=0.67, wts=True, training=True, channels_last=channels_last,
track_running_stats=track_running_stats, test_module=test_module)
+ def test_norm(self):
+ a = torch.arange(9, dtype=torch.float, device="mps") - 4
+ b = a.reshape((3, 3))
+
+ a_cpu = torch.arange(9, dtype=torch.float, device="cpu") - 4
+ b_cpu = a_cpu.reshape((3, 3))
+
+ res = torch.norm(a)
+ res_cpu = torch.norm(a_cpu)
+ self.assertEqual(res, res_cpu)
+
+ res = torch.norm(b)
+ res_cpu = torch.norm(b_cpu)
+ self.assertEqual(res, res_cpu)
+
+ res = torch.norm(a, float('inf'))
+ res_cpu = torch.norm(a_cpu, float('inf'))
+ self.assertEqual(res, res_cpu)
+
+ res = torch.norm(b, float('inf'))
+ res_cpu = torch.norm(b_cpu, float('inf'))
+ self.assertEqual(res, res_cpu)
+
+ c = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="mps")
+ c_cpu = torch.tensor([[1, 2, 3], [-1, 1, 4]], dtype=torch.float, device="cpu")
+
+ res = torch.norm(c, dim=0)
+ res_cpu = torch.norm(c_cpu, dim=0)
+ self.assertEqual(res, res_cpu)
+
+ res = torch.norm(c, dim=1)
+ res_cpu = torch.norm(c_cpu, dim=1)
+ self.assertEqual(res, res_cpu)
+
+ res = torch.norm(c, p=1, dim=1)
+ res_cpu = torch.norm(c_cpu, p=1, dim=1)
+ self.assertEqual(res, res_cpu)
+
+ d = torch.arange(8, dtype=torch.float, device="mps").reshape(2, 2, 2)
+ d_cpu = torch.arange(8, dtype=torch.float, device="cpu").reshape(2, 2, 2)
+
+ res = torch.norm(d, dim=(1, 2))
+ res_cpu = torch.norm(d_cpu, dim=(1, 2))
+ self.assertEqual(res, res_cpu)
+
+ res = torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
+ res_cpu = torch.norm(d_cpu[0, :, :]), torch.norm(d_cpu[1, :, :])
+ self.assertEqual(res, res_cpu)
+
def test_layer_norm(self):
# TODO: Test non-contiguous
def helper(input_shape, normalized_shape, eps=1e-05, elementwise_affine=True, dtype=torch.float32):