[MPS] Native implementation for addr (#94538)
```
addr_out_mps computes res = beta * input + alpha * (vec1 ⊗ vec2)
move addr with f16 inputs to the low-precision list
move addr with non-float inputs to the unsupported list
add test_addr tests
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94538
Approved by: https://github.com/razarmehr
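
For reference, the new kernel implements the standard `addr` semantics, `res = beta * input + alpha * (vec1 ⊗ vec2)`. Below is a minimal sketch of exercising the op from Python against an explicit outer-product reference; it assumes an MPS-enabled build, and the tensor names are illustrative, not part of this PR:

```python
import torch

# Sketch only: compare torch.addr on the MPS device against the
# definition beta * input + alpha * outer(vec1, vec2) computed on CPU.
if torch.backends.mps.is_available():
    beta, alpha = 0.8, 1.2
    inp = torch.randn(5, 10)
    v1 = torch.randn(5)
    v2 = torch.randn(10)

    out_mps = torch.addr(inp.to("mps"), v1.to("mps"), v2.to("mps"),
                         beta=beta, alpha=alpha).cpu()
    ref = beta * inp + alpha * torch.outer(v1, v2)
    torch.testing.assert_close(out_mps, ref, atol=1e-5, rtol=1e-5)
```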
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 0cb6be7..6e3f1bc 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -185,6 +185,152 @@
return output;
}
+
+Tensor addr_mps(const Tensor& self,
+ const Tensor& vec1, const Tensor& vec2,
+ const Scalar& beta, const Scalar& alpha) {
+ Tensor result = at::empty({0}, self.options());
+ addr_out_mps(self, vec1, vec2, beta, alpha, result);
+ return result;
+}
+
+
+Tensor& addr_out_mps(const Tensor& self,
+ const Tensor& vec1, const Tensor& vec2,
+ const Scalar& beta, const Scalar& alpha, Tensor &result) {
+ using namespace mps;
+
+ TORCH_CHECK(result.is_mps());
+ TORCH_CHECK(vec1.dim() == 1 && vec2.dim() == 1, "tensors must be 1-D");
+ TORCH_CHECK(vec1.scalar_type() == ScalarType::Double
+ || vec1.scalar_type() == ScalarType::Float
+ || vec1.scalar_type() == ScalarType::Half, "MPS device does not support addr for non-float input");
+
+ TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {vec1, "vec1", 2}, {vec2, "vec2", 3}};
+ checkAllSameGPU(__func__, args);
+
+ IntArrayRef vec1_sizes = vec1.sizes();
+ IntArrayRef vec2_sizes = vec2.sizes();
+ IntArrayRef self_sizes;
+
+ c10::MaybeOwned<Tensor> self_;
+ if (&result != &self) {
+ self_ = expand_size(self, {vec1_sizes[0], vec2_sizes[0]}, "addr");
+ self_sizes = self_->sizes();
+ } else {
+ self_ = c10::MaybeOwned<Tensor>::borrowed(self);
+ self_sizes = self_->sizes();
+ TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
+ TORCH_CHECK(self_sizes[0] == vec1_sizes[0], "self dim 0 must match vec1 dim 0");
+ TORCH_CHECK(self_sizes[1] == vec2_sizes[0], "self dim 1 must match vec2 dim 0");
+ }
+
+ if (&result != &vec1) {
+ result.resize_(self_sizes);
+ if (beta.toComplexDouble() != 0.0) {
+ at::native::copy_(result, *self_);
+ }
+ }
+
+ IntArrayRef result_sizes = result.sizes();
+ if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
+ return result;
+ }
+
+ MPSStream* stream = getCurrentMPSStream();
+ bool is_beta_non_zero = beta.toDouble() != 0.0;
+ MPSShape* inputShape = @[@(vec1.numel()), @(1)];
+ MPSShape* otherShape = @[@(1), @(vec2.numel())];
+
+ struct CachedGraph : public mps::MPSCachedGraph
+ {
+ CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+ MPSGraphTensor *vec1Tensor_ = nil;
+ MPSGraphTensor *vec2Tensor_ = nil;
+ MPSGraphTensor *selfTensor_ = nil;
+ MPSGraphTensor *resultTensor_ = nil;
+ };
+
+ mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
+
+ @autoreleasepool {
+ string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_})
+ + ":" + to_string(beta.toDouble())
+ + ":" + to_string(alpha.toDouble());
+ CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+ if(!cachedGraph) {
+
+ mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
+ CachedGraph *newCachedGraph = nil;
+
+ @autoreleasepool{
+ MPSGraph *mpsGraph = mps::make_mps_graph();
+ newCachedGraph = new CachedGraph(mpsGraph);
+
+ MPSGraphTensor *t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1.scalar_type()), inputShape);
+ MPSGraphTensor *t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2.scalar_type()), otherShape);
+ MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *self_);
+
+ // Intermediate as placeholder
+ MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1
+ secondaryTensor:t2
+ name:@"MM/(vec1Xvec2)"];
+
+ // Intermediates for beta and alpha
+ MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble()
+ dataType:getMPSScalarType((*self_).scalar_type())];
+ MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble()
+ dataType:getMPSScalarType(vec1.scalar_type())];
+
+ // Intermediates for multiplying by beta and alpha
+ MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:productTensor
+ secondaryTensor:alphaTensor
+ name:@"MM/alpha*(vec1Xvec2)"];
+ MPSGraphTensor* selfTimesBetaTensor = selfTensor;
+ if (is_beta_non_zero) {
+ selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:selfTensor
+ secondaryTensor:betaTensor
+ name:@"MM/beta*input"];
+ }
+
+ MPSGraphTensor* resultTensor = productTimesAlphaTensor;
+ if (is_beta_non_zero) {
+ resultTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor
+ secondaryTensor:selfTimesBetaTensor
+ name:@"MM/beta*input+alpha*(vec1@vec2)"];
+ }
+
+ newCachedGraph->vec1Tensor_ = t1;
+ newCachedGraph->vec2Tensor_ = t2;
+ newCachedGraph->selfTensor_ = selfTensor;
+ newCachedGraph->resultTensor_ = resultTensor;
+ }
+ return newCachedGraph;
+ });
+ cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+ }
+
+ Placeholder vec1Placeholder = Placeholder(cachedGraph->vec1Tensor_, vec1, inputShape);
+ Placeholder vec2Placeholder = Placeholder(cachedGraph->vec2Tensor_, vec2, otherShape);
+ Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, *self_);
+ Placeholder resultPlaceholder = Placeholder(cachedGraph->resultTensor_, result);
+
+ NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
+ vec1Placeholder.getMPSGraphTensor() : vec1Placeholder.getMPSGraphTensorData(),
+ vec2Placeholder.getMPSGraphTensor() : vec2Placeholder.getMPSGraphTensorData(),
+ selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
+ };
+
+ NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
+ resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()
+ };
+
+ mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+ }
+
+ return result;
+}
+
Tensor& addmm_out_mps_impl(
const Tensor& bias,
const Tensor& self, // input
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2cb2b62..fc2c60c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -596,6 +596,7 @@
variants: function, method
dispatch:
CPU, CUDA: addr
+ MPS: addr_mps
CompositeExplicitAutograd: math_addr
- func: addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
@@ -606,6 +607,7 @@
- func: addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
dispatch:
CPU, CUDA: addr_out
+ MPS: addr_out_mps
CompositeExplicitAutograd: math_addr_out
- func: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor
diff --git a/test/test_mps.py b/test/test_mps.py
index 4841e6a..e3329a4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -522,6 +522,13 @@
self.assertEqual(output_cpu, output_mps)
self.assertEqual(output_cpu.size(), output_mps.size())
+ def test_addr(self):
+ A = torch.ones(5, 10).to("mps")
+ B = torch.ones(5).to("mps")
+ C = torch.ones(10).to("mps")
+ D = torch.addr(A, B, C).to("cpu")
+ torch.testing.assert_close(D, torch.full((5, 10), 2.0))
+
def test_trace(self):
M_cpu = torch.randn(3, 3)
M_mps = M_cpu.detach().clone().to("mps")
@@ -6422,6 +6429,30 @@
m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)
+ def _test_addr(self, f, t, m, v, alpha=None, beta=None):
+ dtype = t.dtype
+ numpy_dtype = dtype
+ alpha = 1.2 if alpha is None else alpha
+ beta = 0.8 if beta is None else beta
+ res1 = f(t, m, v, alpha=alpha, beta=beta)
+ res2 = alpha * np.outer(m.to(numpy_dtype).cpu().numpy(), v.to(numpy_dtype).cpu().numpy())
+ if beta != 0:
+ res2 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
+ res2 = torch.from_numpy(res2).to(dtype)
+ self.assertEqual(res1, res2)
+
+ def test_addr(self, device="mps", dtype=torch.float32):
+ M = torch.randn(10, 25, device=device).to(dtype)
+ m1 = torch.randn(10, device=device).to(dtype)
+ m2 = torch.randn(25, device=device).to(dtype)
+ self._test_addr(torch.addr, M, m1, m2)
+
+ # Test beta=0, M=nan
+ M = torch.full((10, 25), math.nan, device=device).to(dtype)
+ m1 = torch.randn(10, device=device).to(dtype)
+ m2 = torch.randn(25, device=device).to(dtype)
+ self._test_addr(torch.addr, M, m1, m2, beta=0)
+
class TestGatherScatter(TestCase):
def test_slicing_with_step(self):
# Slicing with step
@@ -8707,7 +8738,7 @@
'addcmul': ['f32', 'i16', 'i32', 'i64', 'u8'],
'addmm': ['f32'],
'addmv': ['f32'],
- 'addr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+ 'addr': ['f32'],
'all': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
'allclose': ['f16', 'f32'],
'any': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],