[MPS] Implement `mul` operation for complex types (#108395)
Implemented using the existing BinaryKernel template.
Adds `mul`, as well as `kron` and `outer`, to the list of MPS ops that support complex types.
This should add all the missing ops mentioned in https://github.com/pytorch/pytorch/issues/105665
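
A minimal sanity check, as a sketch (assumes a build that includes this change; tensor values are arbitrary):

```python
import torch

a = torch.randn(4, dtype=torch.complex64, device="mps")
b = torch.randn(4, dtype=torch.complex64, device="mps")

# Elementwise complex multiply, plus the ops that build on it
torch.testing.assert_close((a * b).cpu(), a.cpu() * b.cpu())
torch.testing.assert_close(torch.kron(a, b).cpu(), torch.kron(a.cpu(), b.cpu()))
torch.testing.assert_close(torch.outer(a, b).cpu(), torch.outer(a.cpu(), b.cpu()))
```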
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108395
Approved by: https://github.com/albanD
ghstack dependencies: #108393, #108394
diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.h b/aten/src/ATen/native/mps/operations/BinaryKernel.h
new file mode 100644
index 0000000..d22c22a
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/BinaryKernel.h
@@ -0,0 +1,11 @@
+#pragma once
+
+namespace at {
+class Tensor;
+}
+
+namespace at::native::mps {
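+// Computes the elementwise product of two complex tensors into output,
+// resizing output to the broadcast shape when needed.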
+void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& output);
+}
diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
index ddd4b4d..ba88f3f 100644
--- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -5,6 +5,7 @@
#include <ATen/native/BinaryOps.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/mps/OperationUtils.h>
+#include <ATen/native/mps/operations/BinaryKernel.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -155,6 +156,38 @@
REGISTER_POLAR_OP(float);
REGISTER_POLAR_OP(half);
+template<typename T>
+kernel void complex_mul(constant void * input_ [[buffer(0)]],
+ constant void * other_ [[buffer(1)]],
+ device void * out_ [[buffer(2)]],
+ constant uint3 * offsets [[buffer(3)]],
+ uint tid [[thread_position_in_grid]]) {
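+ // offsets[tid] carries per-element byte offsets into out_ (.x), input_ (.y)
+ // and other_ (.z), precomputed from the TensorIterator's strides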
+ device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
+ constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
+ constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
+ // (a+bi)(c+di) = (ac - bd) + (ad + bc)i; compute into temporaries first so
+ // in-place calls, where out aliases input, do not read an overwritten value
+ T real_part = input[0]*other[0] - input[1]*other[1];
+ T imag_part = input[0]*other[1] + input[1]*other[0];
+ out[0] = real_part;
+ out[1] = imag_part;
+}
+
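+// Instantiate complex_mul for DTYPE under the host-visible name "complex_mul_<DTYPE>"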
+#define REGISTER_COMPLEX_MUL_OP(DTYPE) \
+template \
+[[host_name("complex_mul_" #DTYPE)]] \
+kernel void complex_mul<DTYPE>( \
+ constant void * input, \
+ constant void * other, \
+ device void * out, \
+ constant uint3 * offsets, \
+ uint tid)
+
+REGISTER_COMPLEX_MUL_OP(float);
+REGISTER_COMPLEX_MUL_OP(half);
)BINARY_METAL";
using namespace mps;
@@ -269,6 +295,28 @@
}
});
}
+
+void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& output) {
+ TORCH_INTERNAL_ASSERT(c10::isComplexType(input.scalar_type()) && c10::isComplexType(other.scalar_type()));
+ auto new_size = at::infer_size(input.sizes(), other.sizes());
+ if (!output.sizes().equals(new_size)) {
+ output.resize_(new_size);
+ }
+ if (output.numel() == 0) {
+   return;
+ }
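+ // View each complex tensor as real ([..., 2]) and select component 0 so the
+ // iterator's element offsets point at the start of each complex value; the
+ // kernel then accesses real and imaginary parts at slots 0 and 1.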
+ auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
+ auto input_as_real = at::view_as_real(input).select(input.dim(), 0);
+ auto other_as_real = at::view_as_real(other).select(other.dim(), 0);
+ auto iter =
+ TensorIteratorConfig().add_output(output_as_real).add_input(input_as_real).add_input(other_as_real).build();
+
+ mps::binary_mps_impl(iter, "complex_mul");
+}
+
} // namespace mps
static void fmax_mps_kernel(TensorIteratorBase& iter) {
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 33b8862..17e2d23 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -4,6 +4,7 @@
#include <ATen/ScalarOps.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/native/mps/OperationUtils.h>
+#include <ATen/native/mps/operations/BinaryKernel.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -386,13 +387,26 @@
// Arithmetic Binary Ops
CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(minimum_out_mps, minimum, Tensor);
CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(maximum_out_mps, maximum, Tensor);
-CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(mul_out_mps, multiplication, Tensor);
CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(pow_tensor_scalar_out_mps, power, Scalar);
CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(pow_tensor_tensor_out_mps, power, Tensor);
CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_and_out_mps, logicalAND, Tensor);
CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_or_out_mps, logicalOR, Tensor);
CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_xor_out_mps, logicalXOR, Tensor);
+TORCH_IMPL_FUNC(mul_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
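+ // Complex inputs take the custom Metal kernel, as the MPSGraph multiplication
+ // op used below does not handle complex dtypes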
+ if (c10::isComplexType(self.scalar_type()) || c10::isComplexType(other.scalar_type())) {
+ return mps::complex_mul_out(self, other, output);
+ }
+ mps::binaryOpTensor(
+ self, other, Scalar(1.0), output, "mul", ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
+ MPSGraph* mpsGraph = cachedGraph->graph();
+ return [mpsGraph multiplicationWithPrimaryTensor:primaryCastTensor
+ secondaryTensor:secondaryCastTensor
+ name:nil];
+ });
+}
TORCH_IMPL_FUNC(atan2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support atan2 op with int64 input");
mps::binaryOpTensor(
diff --git a/test/test_mps.py b/test/test_mps.py
index 7442f06..452bed4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -210,6 +210,7 @@
# Supported complex OPS
SUPPORTED_COMPLEX_OPS = [
'__radd__',
+ '__rmul__',
'add',
'atleast_1d',
'atleast_2d',
@@ -227,11 +228,14 @@
'isinf',
'isreal',
'item',
+ 'kron',
'linspace',
'logspace',
+ 'mul',
'nn.functional.feature_alpha_dropoutwithout_train',
'nn.functional.unfold',
'ones',
+ 'outer',
'positive',
'randn',
'ravel',