[MPS] Add pow.Scalar (#95201)

1. Adds `pow.Scalar` for the MPS backend.
2. Adjusts the test `atol` and `rtol` so that the pow output-match tests pass.
3. Xfails numerically incorrect dtypes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95201
Approved by: https://github.com/kulinseth
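For reference, the overload being added takes a Python scalar base and a tensor exponent. A minimal usage sketch on an MPS-capable build (device availability assumed):

```python
import torch

# pow.Scalar: scalar base, tensor exponent; newly supported on MPS.
exp = torch.tensor([0.5, 1.0, 2.0], device="mps")
out = torch.pow(3.0, exp)         # routed to pow_Scalar_out_mps
ref = torch.pow(3.0, exp.cpu())   # CPU reference
torch.testing.assert_close(out.cpu(), ref)
```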
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 6a34d60..4569add 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -347,6 +347,28 @@
   mps::add_sub_template(self, other, alpha, output, "sub");
 }
 
+TORCH_IMPL_FUNC(pow_Scalar_out_mps) (const Scalar& base, const Tensor& exp, const Tensor& out) {
+  if (base.equal(1.0)) {
+    out.fill_(1);
+  } else {
+    // Copied and modified from aten/src/ATen/ScalarOps.h,
+    // as MPS doesn't support float64 tensors.
+    Tensor base_tensor;
+    if (base.isFloatingPoint()) {
+      base_tensor = at::scalar_tensor(base, at::device(exp.device()).dtype(at::kFloat));
+    } else if (base.isBoolean()) {
+      base_tensor = at::scalar_tensor(base, at::device(exp.device()).dtype(at::kBool));
+    } else if (base.isComplex()) {
+      base_tensor = at::scalar_tensor(base, at::device(exp.device()).dtype(at::kComplexDouble));
+    } else {
+      AT_ASSERT(base.isIntegral(false));
+      base_tensor = at::scalar_tensor(base, at::device(exp.device()).dtype(at::kLong));
+    }
+    base_tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
+    at::pow_out(const_cast<Tensor&>(out), base_tensor, exp); // redispatch!
+  }
+}
+
 Tensor& floor_divide_out_mps(const Tensor& self, const Tensor& other, Tensor& result) {
   mps::div_mode_template(self, other, "floor", result, "floor_divide_out");
   return result;
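The kernel above short-circuits a base of 1.0 and otherwise wraps the scalar in a 0-dim tensor (float64 is downcast to float32, since MPS has no float64) before redispatching to the tensor-tensor pow. A hypothetical Python mirror of that control flow, for illustration only:

```python
import torch

def pow_scalar_sketch(base: float, exp: torch.Tensor) -> torch.Tensor:
    # Hypothetical Python mirror of pow_Scalar_out_mps (float branch only).
    if base == 1.0:
        return torch.ones_like(exp)  # 1 ** x == 1, so skip the kernel launch
    # MPS has no float64, so the scalar is materialized as a float32
    # 0-dim tensor and the call is redispatched to pow.Tensor_Tensor.
    base_t = torch.scalar_tensor(base, dtype=torch.float, device=exp.device)
    return torch.pow(base_t, exp)
```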
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e878e1d..4721285 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9563,6 +9563,7 @@
   structured: True
   dispatch:
     CPU, CUDA: pow_Scalar_out
+    MPS: pow_Scalar_out_mps
   tags: pointwise
 
 - func: pow.Scalar(Scalar self, Tensor exponent) -> Tensor
diff --git a/test/test_mps.py b/test/test_mps.py
index 95ba3a3..b404bf8 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -69,6 +69,7 @@
         '__radd__': [torch.uint8],
         '__rdiv__': [torch.uint8],
         '__rmul__': [torch.uint8],
+        '__rpow__': [torch.uint8],
         'abs': [torch.uint8],
         'acos': [torch.uint8],
         'acosh': [torch.uint8],
@@ -108,6 +109,7 @@
         'nn.functional.poisson_nll_loss': [torch.uint8],
         'nn.functional.softsign': [torch.uint8],
         'nn.functional.tanhshrink': [torch.uint8],
+        'pow': [torch.int16, torch.int64, torch.uint8],
         'rad2deg': [torch.uint8],
         'reciprocal': [torch.uint8],
         'remainder': [torch.uint8],
@@ -130,6 +132,7 @@
 
     # Those ops are not expected to work
     XFAILLIST = {
+        '__rpow__': [torch.int16, torch.int32, torch.int64],
         'chalf': None,
         # Unsupported dtypes
         'dot': [torch.int64],
@@ -140,8 +143,6 @@
         'nn.functional.conv_transpose2d': [torch.int64],
         'remainder': [torch.int64],
         'sigmoid': [torch.int64],
-        # Accuracy problems
-        'pow': [torch.float32],
         # failures due to lack of op implementation on MPS backend
         'put': None,
         # Weird
@@ -1792,6 +1793,7 @@
     # Test pow
     def test_pow(self):
         def helper(shape):
+            # aten::pow.Tensor_Tensor
             cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
             x = cpu_x.detach().clone().to('mps')
             cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
@@ -1801,6 +1803,7 @@
 
             self.assertEqual(z, ref_z)
 
+            # aten::pow.Tensor_Scalar
             cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
             x = cpu_x.detach().clone().to('mps')
             exp = random.random()
@@ -1809,6 +1812,15 @@
 
             self.assertEqual(z, ref_z)
 
+            # aten::pow.Scalar
+            x = random.random()
+            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
+            y = cpu_y.detach().clone().to('mps')
+            z = torch.pow(x, y)
+            ref_z = torch.pow(x, cpu_y)
+
+            self.assertEqual(z, ref_z)
+
         helper((2, 8, 4, 5))
 
     # Test addcmul
@@ -9438,7 +9450,7 @@
         '__rmatmul__': ['f32'],
         '__rmul__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         '__ror__': ['b8', 'i16', 'i32', 'i64', 'u8'],
-        '__rpow__': ['f16'],
+        '__rpow__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         '__rxor__': ['b8', 'i16', 'i32', 'i64', 'u8'],
         'masked.argmax': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'masked.argmin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -9640,7 +9652,7 @@
         'nn.functional.upsample_nearest': ['f32'],
         'norm': ['f32', 'f16'],
         'positive': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'pow': ['f16', 'f32'],
+        'pow': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'put': None,
         'rad2deg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'real': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -9743,6 +9755,7 @@
         '__rdiv__': ['f16', 'f32'],
         '__rmatmul__': ['f32'],
         '__rmul__': ['f16', 'f32'],
+        '__rpow__': ['f32'],
         'masked.log_softmax': ['f32'],
         'masked.logaddexp': ['f32'],
         'masked.softmax': ['f32'],
@@ -9885,6 +9898,7 @@
         'nn.functional.upsample_bilinear': ['f32'],
         'norm': ['f32', 'f16'],
         'positive': ['f16', 'f32'],
+        'pow': ['f32'],
         'rad2deg': ['f16', 'f32'],
         'real': ['f16', 'f32'],
         'reciprocal': ['f16', 'f32'],
@@ -10115,15 +10129,18 @@
                 if op.name == "nn.functional.conv2d" and dtype == torch.float32:
                     atol = 1e-4
                     rtol = 3e-5
-                elif (op.name in self.FP16_LOW_PRECISION_LIST) and dtype == torch.float16:
+                elif op.name in self.FP16_LOW_PRECISION_LIST and dtype == torch.float16:
                     atol = 1e-2
                     rtol = 1e-2
-                elif (op.name == "masked.mean"):
+                elif op.name == "masked.mean":
                     atol = 7e-4
                     rtol = 2e-3
-                elif (op.name == "native_layer_norm"):
+                elif op.name == "native_layer_norm":
                     atol = 1e-4
                     rtol = 1.3e-5
+                elif op.name in ["pow", "__rpow__"]:
+                    atol = 1e-6
+                    rtol = 4e-6
                 else:
                     atol = None
                     rtol = None
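The per-op tolerances above feed into the harness's `assertEqual` comparison. An equivalent standalone check, using the values this hunk selects for `pow` and `__rpow__` (atol=1e-6, rtol=4e-6), might look like:

```python
import torch

# Illustrative only: compare an MPS result against its CPU reference with
# the tolerances chosen above for 'pow' / '__rpow__'.
cpu_exp = torch.randn(2, 8, 4, 5)
mps_out = torch.pow(2.0, cpu_exp.to("mps")).cpu()
cpu_out = torch.pow(2.0, cpu_exp)
torch.testing.assert_close(mps_out, cpu_out, atol=1e-6, rtol=4e-6)
```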