[MPS] Add complex `add`/`sub` (#108394)
Using `view_as_real` and running elementwise ops in resulted tensors
Add `add` and `sub` to list of complex ops that should work on MPS
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108394
Approved by: https://github.com/albanD
ghstack dependencies: #108393
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 7c0105c..33b8862 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -35,6 +35,7 @@
#include <ATen/ops/remainder_native.h>
#include <ATen/ops/result_type.h>
#include <ATen/ops/sub_native.h>
+#include <ATen/ops/view_as_real.h>
#include <ATen/ops/xlogy_native.h>
#endif
@@ -411,10 +412,20 @@
}
TORCH_IMPL_FUNC(add_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
+ if (isComplexType(self.scalar_type()) && isComplexType(other.scalar_type()) && !alpha.isComplex()) {
+ // Complex add with non-complex alpha is just add over views
+ return mps::add_sub_lerp_template(
+ at::view_as_real(self), at::view_as_real(other), alpha, at::view_as_real(output), "add");
+ }
mps::add_sub_lerp_template(self, other, alpha, output, "add");
}
TORCH_IMPL_FUNC(sub_out_mps)(const Tensor& self, const Tensor& other, const Scalar& alpha, const Tensor& output) {
+ if (isComplexType(self.scalar_type()) && isComplexType(other.scalar_type()) && !alpha.isComplex()) {
+ // Complex sub with non-complex alpha is just add over views
+ return mps::add_sub_lerp_template(
+ at::view_as_real(self), at::view_as_real(other), alpha, at::view_as_real(output), "sub");
+ }
mps::add_sub_lerp_template(self, other, alpha, output, "sub");
}
diff --git a/test/test_mps.py b/test/test_mps.py
index 9d1197d..7442f06 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -209,6 +209,8 @@
def mps_ops_modifier(ops):
# Supported complex OPS
SUPPORTED_COMPLEX_OPS = [
+ '__radd__',
+ 'add',
'atleast_1d',
'atleast_2d',
'atleast_3d',
@@ -243,6 +245,7 @@
'split',
'squeeze',
'squeezemultiple',
+ 'sub',
't',
'unflatten',
'unsafe_split',