[MPS] Add pow.Scalar (#95201)

1. Adds `pow.Scalar`.
2. Adjusts the test `atol` and `rtol` so the `pow` output-match tests pass.
3. Xfails numerically incorrect dtypes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95201
Approved by: https://github.com/kulinseth
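For orientation, a minimal usage sketch of the overload this PR enables on the MPS backend (assumes an MPS-capable machine; the values are illustrative):

```python
import torch

# pow.Scalar takes a Python scalar base and a tensor exponent;
# with this PR it dispatches to the new MPS kernel.
exp = torch.arange(1, 5, dtype=torch.float32, device="mps")
out = torch.pow(2.0, exp)  # 2 ** [1, 2, 3, 4] -> [2., 4., 8., 16.]

# Sanity-check against the CPU reference implementation.
ref = torch.pow(2.0, exp.cpu())
torch.testing.assert_close(out.cpu(), ref)
```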
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 6a34d60..4569add 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -347,6 +347,28 @@
   mps::add_sub_template(self, other, alpha, output, "sub");
 }
 
+TORCH_IMPL_FUNC(pow_Scalar_out_mps) (const Scalar& base, const Tensor& exp, const Tensor& out) {
+  if (base.equal(1.0)) {
+    out.fill_(1);
+  } else {
+    // Copied and modified from aten/src/ATen/ScalarOps.h
+    // as MPS doesn't support float64 tensor.
+    Tensor base_tensor;
+    if (base.isFloatingPoint()) {
+      base_tensor = at::scalar_tensor(base, at::device(exp.device()).dtype(at::kFloat));
+    } else if (base.isBoolean()) {
+      base_tensor = at::scalar_tensor(base, at::device(exp.device()).dtype(at::kBool));
+    } else if (base.isComplex()) {
+      base_tensor = at::scalar_tensor(base, at::device(exp.device()).dtype(at::kComplexDouble));
+    } else {
+      AT_ASSERT(base.isIntegral(false));
+      base_tensor = at::scalar_tensor(base, at::device(exp.device()).dtype(at::kLong));
+    }
+    base_tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
+    at::pow_out(const_cast<Tensor&>(out), base_tensor, exp); // redispatch!
+  }
+}
+
 Tensor& floor_divide_out_mps(const Tensor& self, const Tensor& other, Tensor& result) {
   mps::div_mode_template(self, other, "floor", result, "floor_divide_out");
   return result;
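The kernel above handles the trivial `base == 1` case directly; otherwise it wraps the scalar base in a 0-dim "wrapped number" tensor (using float32/bool/long rather than float64, which the MPS backend does not support) and redispatches to the Tensor-Tensor `pow` overload. As a rough illustration only, a Python-level analogue might look like the sketch below; the helper name is hypothetical, and it glosses over the wrapped-number type-promotion rules the real kernel relies on.

```python
import torch

def pow_scalar_sketch(base: float, exp: torch.Tensor) -> torch.Tensor:
    # Hypothetical Python analogue of pow_Scalar_out_mps (illustration only).
    if base == 1.0:
        # 1 ** x == 1 for any exponent, so skip the kernel entirely.
        return torch.ones_like(exp)
    # Wrap the scalar as a 0-dim tensor on the same device; float32 is used
    # here because the MPS backend has no float64 support.
    base_tensor = torch.scalar_tensor(base, dtype=torch.float32, device=exp.device)
    # Redispatch to the tensor-tensor overload (aten::pow.Tensor_Tensor).
    return torch.pow(base_tensor, exp)
```

The real kernel writes into the structured-kernel `out` tensor via `at::pow_out` rather than returning a new tensor.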
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index e878e1d..4721285 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -9563,6 +9563,7 @@
   structured: True
   dispatch:
     CPU, CUDA: pow_Scalar_out
+    MPS: pow_Scalar_out_mps
   tags: pointwise
 
 - func: pow.Scalar(Scalar self, Tensor exponent) -> Tensor
diff --git a/test/test_mps.py b/test/test_mps.py
index 95ba3a3..b404bf8 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -69,6 +69,7 @@
         '__radd__': [torch.uint8],
         '__rdiv__': [torch.uint8],
         '__rmul__': [torch.uint8],
+        '__rpow__': [torch.uint8],
         'abs': [torch.uint8],
         'acos': [torch.uint8],
         'acosh': [torch.uint8],
@@ -108,6 +109,7 @@
         'nn.functional.poisson_nll_loss': [torch.uint8],
         'nn.functional.softsign': [torch.uint8],
         'nn.functional.tanhshrink': [torch.uint8],
+        'pow': [torch.int16, torch.int64, torch.uint8],
         'rad2deg': [torch.uint8],
         'reciprocal': [torch.uint8],
         'remainder': [torch.uint8],
@@ -130,6 +132,7 @@
 
     # Those ops are not expected to work
     XFAILLIST = {
+        '__rpow__': [torch.int16, torch.int32, torch.int64],
         'chalf': None,
         # Unsupported dtypes
         'dot': [torch.int64],
@@ -140,8 +143,6 @@
         'nn.functional.conv_transpose2d': [torch.int64],
         'remainder': [torch.int64],
         'sigmoid': [torch.int64],
-        # Accuracy problems
-        'pow': [torch.float32],
         # failures due to lack of op implementation on MPS backend
         'put': None,
         # Weird
@@ -1792,6 +1793,7 @@
     # Test pow
     def test_pow(self):
         def helper(shape):
+            # aten::pow.Tensor_Tensor
             cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
             x = cpu_x.detach().clone().to('mps')
             cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
@@ -1801,6 +1803,7 @@
 
             self.assertEqual(z, ref_z)
 
+            # aten::pow.Tensor_Scalar
             cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
             x = cpu_x.detach().clone().to('mps')
             exp = random.random()
@@ -1809,6 +1812,15 @@
 
             self.assertEqual(z, ref_z)
 
+            # aten::pow.Scalar
+            x = random.random()
+            cpu_y = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=False)
+            y = cpu_y.detach().clone().to('mps')
+            z = torch.pow(x, y)
+            ref_z = torch.pow(x, cpu_y)
+
+            self.assertEqual(z, ref_z)
+
         helper((2, 8, 4, 5))
 
     # Test addcmul
@@ -9438,7 +9450,7 @@
         '__rmatmul__': ['f32'],
         '__rmul__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         '__ror__': ['b8', 'i16', 'i32', 'i64', 'u8'],
-        '__rpow__': ['f16'],
+        '__rpow__': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         '__rxor__': ['b8', 'i16', 'i32', 'i64', 'u8'],
         'masked.argmax': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'masked.argmin': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -9640,7 +9652,7 @@
         'nn.functional.upsample_nearest': ['f32'],
         'norm': ['f32', 'f16'],
         'positive': ['f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
-        'pow': ['f16', 'f32'],
+        'pow': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'put': None,
         'rad2deg': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
         'real': ['b8', 'f16', 'f32', 'i16', 'i32', 'i64', 'u8'],
@@ -9743,6 +9755,7 @@
         '__rdiv__': ['f16', 'f32'],
         '__rmatmul__': ['f32'],
         '__rmul__': ['f16', 'f32'],
+        '__rpow__': ['f32'],
         'masked.log_softmax': ['f32'],
         'masked.logaddexp': ['f32'],
         'masked.softmax': ['f32'],
@@ -9885,6 +9898,7 @@
         'nn.functional.upsample_bilinear': ['f32'],
         'norm': ['f32', 'f16'],
         'positive': ['f16', 'f32'],
+        'pow': ['f32'],
         'rad2deg': ['f16', 'f32'],
         'real': ['f16', 'f32'],
         'reciprocal': ['f16', 'f32'],
@@ -10115,15 +10129,18 @@
             if op.name == "nn.functional.conv2d" and dtype == torch.float32:
                 atol = 1e-4
                 rtol = 3e-5
-            elif (op.name in self.FP16_LOW_PRECISION_LIST) and dtype == torch.float16:
+            elif op.name in self.FP16_LOW_PRECISION_LIST and dtype == torch.float16:
                 atol = 1e-2
                 rtol = 1e-2
-            elif (op.name == "masked.mean"):
+            elif op.name == "masked.mean":
                 atol = 7e-4
                 rtol = 2e-3
-            elif (op.name == "native_layer_norm"):
+            elif op.name == "native_layer_norm":
                 atol = 1e-4
                 rtol = 1.3e-5
+            elif op.name in ["pow", "__rpow__"]:
+                atol = 1e-6
+                rtol = 4e-6
             else:
                 atol = None
                 rtol = None