[MPS] Native implementation for addr (#94538)

```
add addr_out_mps to perform res = beta * input + alpha * (vec1 ⊗ vec2)
move addr for f16 to the low-precision list
move addr for non-float dtypes to the unsupported list
add test_addr tests
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94538
Approved by: https://github.com/razarmehr
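
For context, `torch.addr(input, vec1, vec2, beta=beta, alpha=alpha)` computes the rank-1 update `beta * input + alpha * (vec1 ⊗ vec2)`. A minimal sketch of the semantics the new kernel has to reproduce, in plain PyTorch on CPU (names are illustrative):

```python
import torch

# addr semantics: beta * input + alpha * outer(vec1, vec2)
inp = torch.zeros(3, 4)
vec1 = torch.arange(3, dtype=torch.float32)
vec2 = torch.arange(4, dtype=torch.float32)

expected = 0.5 * inp + 2.0 * torch.outer(vec1, vec2)
result = torch.addr(inp, vec1, vec2, beta=0.5, alpha=2.0)
assert torch.allclose(result, expected)
```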
diff --git a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
index 0cb6be7..6e3f1bc 100644
--- a/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
+++ b/aten/src/ATen/native/mps/operations/LinearAlgebra.mm
@@ -185,6 +185,152 @@
   return output;
 }
 
+
+Tensor addr_mps(const Tensor& self,
+                const Tensor& vec1, const Tensor& vec2,
+                const Scalar& beta, const Scalar& alpha) {
+  Tensor result = at::empty({0}, self.options());
+  addr_out_mps(self, vec1, vec2, beta, alpha, result);
+  return result;
+}
+
+Tensor& addr_out_mps(const Tensor& self,
+                     const Tensor& vec1, const Tensor& vec2,
+                     const Scalar& beta, const Scalar& alpha, Tensor& result) {
+  using namespace mps;
+
+  TORCH_CHECK(result.is_mps(), "addr: expected 'out' to be an MPS tensor");
+  TORCH_CHECK(vec1.dim() == 1 && vec2.dim() == 1, "tensors must be 1-D");
+  TORCH_CHECK(vec1.scalar_type() == ScalarType::Double
+              || vec1.scalar_type() == ScalarType::Float
+              || vec1.scalar_type() == ScalarType::Half, "MPS device does not support addr for non-float input");
+
+  TensorArg args[]{{result, "out", 0}, {self, "self", 1}, {vec1, "vec1", 2}, {vec2, "vec2", 3}};
+  checkAllSameGPU(__func__, args);
+
+  IntArrayRef vec1_sizes = vec1.sizes();
+  IntArrayRef vec2_sizes = vec2.sizes();
+  IntArrayRef self_sizes;
+
+  c10::MaybeOwned<Tensor> self_;
+  if (&result != &self) {
+    self_ = expand_size(self, {vec1_sizes[0], vec2_sizes[0]}, "addr");
+    self_sizes = self_->sizes();
+  } else {
+    self_ = c10::MaybeOwned<Tensor>::borrowed(self);
+    self_sizes = self_->sizes();
+    TORCH_CHECK(result.dim() == 2, "tensors must be 2-D");
+    TORCH_CHECK(self_sizes[0] == vec1_sizes[0], "self dim 0 must match vec1 dim 0");
+    TORCH_CHECK(self_sizes[1] == vec2_sizes[0], "self dim 1 must match vec2 dim 0");
+  }
+
+  // Unless 'result' aliases 'self', resize it to the broadcast shape and,
+  // when beta != 0, initialize it from self.
+  if (&result != &self) {
+    result.resize_(self_sizes);
+    if (beta.toComplexDouble() != 0.0) {
+      at::native::copy_(result, *self_);
+    }
+  }
+
+  IntArrayRef result_sizes = result.sizes();
+  if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
+    return result;
+  }
+
+  MPSStream* stream = getCurrentMPSStream();
+  bool is_beta_non_zero = beta.toDouble() != 0.0;
+  MPSShape* inputShape = @[@(vec1.numel()), @(1)];
+  MPSShape* otherShape = @[@(1), @(vec2.numel())];
+
+  struct CachedGraph : public mps::MPSCachedGraph
+  {
+    CachedGraph(MPSGraph *graph) : MPSCachedGraph(graph) {}
+    MPSGraphTensor *vec1Tensor_ = nil;
+    MPSGraphTensor *vec2Tensor_ = nil;
+    MPSGraphTensor *selfTensor_ = nil;
+    MPSGraphTensor *resultTensor_ = nil;
+  };
+
+  mps::MPSGraphCache *cache_ = mps::MPSGraphCache::getInstance();
+
+  @autoreleasepool {
+    string key = "addr_out_mps_impl" + getTensorsStringKey({vec1, vec2, *self_})
+                                       + ":" + to_string(beta.toDouble())
+                                       + ":" + to_string(alpha.toDouble());
+    CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
+    if (!cachedGraph) {
+      mps::MPSCachedGraph *tmpCachedGraph = cache_->CreateCachedGraph(key, ^ mps::MPSCachedGraph * () {
+        CachedGraph *newCachedGraph = nil;
+
+        @autoreleasepool{
+          MPSGraph *mpsGraph = mps::make_mps_graph();
+          newCachedGraph = new CachedGraph(mpsGraph);
+
+          MPSGraphTensor *t1 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec1.scalar_type()), inputShape);
+          MPSGraphTensor *t2 = mps::mpsGraphRankedPlaceHolder(mpsGraph, getMPSDataType(vec2.scalar_type()), otherShape);
+          MPSGraphTensor *selfTensor = mps::mpsGraphRankedPlaceHolder(mpsGraph, *self_);
+
+          // Outer product vec1 x vec2, expressed as an (n, 1) @ (1, m) matrix multiplication
+          MPSGraphTensor* productTensor = [mpsGraph matrixMultiplicationWithPrimaryTensor:t1
+                                                                          secondaryTensor:t2
+                                                                                     name:@"MM/(vec1Xvec2)"];
+
+          // Intermediates for beta and alpha
+          MPSGraphTensor* betaTensor = [mpsGraph constantWithScalar:beta.toDouble()
+                                                           dataType:getMPSScalarType((*self_).scalar_type())];
+          MPSGraphTensor* alphaTensor = [mpsGraph constantWithScalar:alpha.toDouble()
+                                                           dataType:getMPSScalarType(vec1.scalar_type())];
+
+          // Intermediates for multiplying by beta and alpha
+          MPSGraphTensor* productTimesAlphaTensor = [mpsGraph multiplicationWithPrimaryTensor:productTensor
+                                                                              secondaryTensor:alphaTensor
+                                                                                         name:@"MM/alpha*(vec1Xvec2)"];
+          // When beta == 0 the beta*input term is skipped entirely, so NaN/Inf in
+          // self is not propagated (0 * NaN would otherwise poison the result).
+          MPSGraphTensor* selfTimesBetaTensor = selfTensor;
+          if (is_beta_non_zero) {
+            selfTimesBetaTensor = [mpsGraph multiplicationWithPrimaryTensor:selfTensor
+                                                            secondaryTensor:betaTensor
+                                                                       name:@"MM/beta*input"];
+          }
+
+          MPSGraphTensor* resultTensor = productTimesAlphaTensor;
+          if (is_beta_non_zero) {
+            resultTensor = [mpsGraph additionWithPrimaryTensor:productTimesAlphaTensor
+                                               secondaryTensor:selfTimesBetaTensor
+                                                          name:@"MM/beta*input+alpha*(vec1@vec2)"];
+          }
+
+          newCachedGraph->vec1Tensor_ = t1;
+          newCachedGraph->vec2Tensor_ = t2;
+          newCachedGraph->selfTensor_ = selfTensor;
+          newCachedGraph->resultTensor_ = resultTensor;
+        }
+        return newCachedGraph;
+      });
+      cachedGraph = static_cast<CachedGraph *>(tmpCachedGraph);
+    }
+
+    Placeholder vec1Placeholder = Placeholder(cachedGraph->vec1Tensor_, vec1, inputShape);
+    Placeholder vec2Placeholder = Placeholder(cachedGraph->vec2Tensor_, vec2, otherShape);
+    Placeholder selfPlaceholder = Placeholder(cachedGraph->selfTensor_, *self_);
+    Placeholder resultPlaceholder = Placeholder(cachedGraph->resultTensor_, result);
+
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* feeds = @{
+      vec1Placeholder.getMPSGraphTensor() : vec1Placeholder.getMPSGraphTensorData(),
+      vec2Placeholder.getMPSGraphTensor() : vec2Placeholder.getMPSGraphTensorData(),
+      selfPlaceholder.getMPSGraphTensor() : selfPlaceholder.getMPSGraphTensorData()
+    };
+
+    NSDictionary<MPSGraphTensor*, MPSGraphTensorData*>* results = @{
+      resultPlaceholder.getMPSGraphTensor() : resultPlaceholder.getMPSGraphTensorData()
+    };
+
+    mps::runMPSGraph(stream, cachedGraph->graph(), feeds, results);
+  }
+
+  return result;
+}
+
 Tensor& addmm_out_mps_impl(
     const Tensor& bias,
     const Tensor& self,  // input
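
Note on the graph construction above: MPSGraph has no dedicated outer-product op, so the kernel reshapes the vectors to `(n, 1)` and `(1, m)` (see `inputShape`/`otherShape`) and uses `matrixMultiplicationWithPrimaryTensor:`. A rough PyTorch sketch of the same dataflow, with illustrative names that are not part of the diff:

```python
import torch

def addr_sketch(self_t, vec1, vec2, beta, alpha):
    # Outer product as an (n, 1) @ (1, m) matmul, mirroring inputShape/otherShape.
    product = vec1.reshape(-1, 1) @ vec2.reshape(1, -1)
    result = alpha * product
    if beta != 0:  # mirrors is_beta_non_zero: the beta term is skipped, not zero-multiplied
        result = result + beta * self_t
    return result
```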
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 2cb2b62..fc2c60c 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -596,6 +596,7 @@
   variants: function, method
   dispatch:
     CPU, CUDA: addr
+    MPS: addr_mps
     CompositeExplicitAutograd: math_addr
 
 - func: addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)
@@ -606,6 +607,7 @@
 - func: addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU, CUDA: addr_out
+    MPS: addr_out_mps
     CompositeExplicitAutograd: math_addr_out
 
 - func: affine_grid_generator(Tensor theta, int[] size, bool align_corners) -> Tensor
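
With these two entries, both the functional and the `out=` variants of `torch.addr` dispatch to the new kernels. A quick smoke test, assuming an MPS-capable machine:

```python
import torch

if torch.backends.mps.is_available():
    inp = torch.zeros(3, 4, device="mps")
    v1 = torch.ones(3, device="mps")
    v2 = torch.ones(4, device="mps")
    out = torch.empty(3, 4, device="mps")
    torch.addr(inp, v1, v2)           # routes to addr_mps
    torch.addr(inp, v1, v2, out=out)  # routes to addr_out_mps
```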
diff --git a/test/test_mps.py b/test/test_mps.py
index 4841e6a..e3329a4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -522,6 +522,13 @@
         self.assertEqual(output_cpu, output_mps)
         self.assertEqual(output_cpu.size(), output_mps.size())
 
+    def test_addr(self):
+        A = torch.ones(5, 10).to("mps")
+        B = torch.ones(5).to("mps")
+        C = torch.ones(10).to("mps")
+        D = torch.addr(A, B, C).to("cpu")
+        torch.testing.assert_close(D, torch.full((5, 10), 2.0))
+
     def test_trace(self):
         M_cpu = torch.randn(3, 3)
         M_mps = M_cpu.detach().clone().to("mps")
@@ -6422,6 +6429,30 @@
         m2 = maybe_transpose(t3, torch.randn(50, 25, device=device).to(dtype))
         self._test_addmm_addmv(torch.addmm, M, m1, m2, transpose_out=t4)
 
+    def _test_addr(self, f, t, m, v, alpha=None, beta=None):
+        dtype = t.dtype
+        numpy_dtype = dtype
+        alpha = 1.2 if alpha is None else alpha
+        beta = 0.8 if beta is None else beta
+        res1 = f(t, m, v, alpha=alpha, beta=beta)
+        res2 = alpha * np.outer(m.to(numpy_dtype).cpu().numpy(), v.to(numpy_dtype).cpu().numpy())
+        if beta != 0:
+            res2 += (torch.mul(t, beta)).to(numpy_dtype).cpu().numpy()
+        res2 = torch.from_numpy(res2).to(dtype)
+        self.assertEqual(res1, res2)
+
+    def test_addr(self, device="mps", dtype=torch.float32):
+        M = torch.randn(10, 25, device=device).to(dtype)
+        m1 = torch.randn(10, device=device).to(dtype)
+        m2 = torch.randn(25, device=device).to(dtype)
+        self._test_addr(torch.addr, M, m1, m2)
+
+        # Test beta=0, M=nan
+        M = torch.full((10, 25), math.nan, device=device).to(dtype)
+        m1 = torch.randn(10, device=device).to(dtype)
+        m2 = torch.randn(25, device=device).to(dtype)
+        self._test_addr(torch.addr, M, m1, m2, beta=0)
+
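+
The `beta=0, M=nan` case above is the contract behind the `is_beta_non_zero` branch in the kernel: `torch.addr` documents that with `beta=0` the input is ignored, so NaN/Inf in it must not leak into the output. A minimal sketch of that expectation (CPU shown; the MPS kernel must match):

```python
import math
import torch

inp = torch.full((2, 3), math.nan)
out = torch.addr(inp, torch.ones(2), torch.ones(3), beta=0, alpha=1.0)
# beta=0 drops the input term entirely; a naive 0 * nan would yield nan
assert not out.isnan().any()
```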
 class TestGatherScatter(TestCase):
     def test_slicing_with_step(self):
         # Slicing with step
@@ -8707,7 +8738,7 @@
         'addcmul': ['f32', 'i16', 'i32', 'i64', 'u8'],
         'addmm': ['f32'],
         'addmv': ['f32'],
-        'addr': ['b8', 'f32', 'i16', 'i32', 'i64', 'u8'],
+        'addr': ['f32'],
         'all': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'allclose': ['f16', 'f32'],
         'any': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],