[MPS] Implement `mul` operation for complex types (#108395)

Implemented using the existing BinaryKernel template.

Add `mul`, as well as `kron` and `outer`, to the list of MPS ops that support complex types.

This should add all the missing ops mentioned in https://github.com/pytorch/pytorch/issues/105665
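
As a quick usage sketch (illustrative only, not part of the change; assumes an MPS-capable machine):

    import torch

    a = torch.randn(3, dtype=torch.complex64, device="mps")
    b = torch.randn(3, dtype=torch.complex64, device="mps")
    a * b              # mul, now backed by the new Metal kernel
    torch.outer(a, b)  # also enabled by this change
    torch.kron(a, b)   # also enabled by this change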
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108395
Approved by: https://github.com/albanD
ghstack dependencies: #108393, #108394
diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.h b/aten/src/ATen/native/mps/operations/BinaryKernel.h
new file mode 100644
index 0000000..d22c22a
--- /dev/null
+++ b/aten/src/ATen/native/mps/operations/BinaryKernel.h
@@ -0,0 +1,5 @@
+#pragma once
+
+namespace at::native::mps {
+void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& output);
+}
diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
index ddd4b4d..ba88f3f 100644
--- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -5,6 +5,7 @@
 #include <ATen/native/BinaryOps.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/mps/OperationUtils.h>
+#include <ATen/native/mps/operations/BinaryKernel.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -155,6 +156,31 @@
 REGISTER_POLAR_OP(float);
 REGISTER_POLAR_OP(half);
 
+template<typename T>
+kernel void complex_mul(constant void  * input_       [[buffer(0)]],
+                        constant void  * other_       [[buffer(1)]],
+                        device   void  * out_         [[buffer(2)]],
+                        constant uint3 * offsets      [[buffer(3)]],
+                        uint tid [[thread_position_in_grid]]) {
+  device   T* out   = (device   T*)((device uint8_t*)out_ + offsets[tid].x);
+  constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
+  constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
+  out[0] = input[0]*other[0] - input[1]*other[1];
+  out[1] = input[0]*other[1] + input[1]*other[0];
+}
+
+#define REGISTER_COMPLEX_MUL_OP(DTYPE)       \
+template                                     \
+[[host_name("complex_mul_" #DTYPE)]]         \
+kernel void complex_mul<DTYPE>(              \
+  constant void    * input,                  \
+  constant void    * other,                  \
+  device   void    * out,                    \
+  constant uint3   * offsets,                \
+  uint tid)
+
+REGISTER_COMPLEX_MUL_OP(float);
+REGISTER_COMPLEX_MUL_OP(half);
 )BINARY_METAL";
 
 using namespace mps;
@@ -269,6 +295,26 @@
     }
   });
 }
+
+void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& output) {
+  TORCH_INTERNAL_ASSERT(c10::isComplexType(input.scalar_type()) && c10::isComplexType(other.scalar_type()));
+  auto new_size = at::infer_size(input.sizes(), other.sizes());
+  if (!output.sizes().equals(new_size)) {
+    output.resize_(new_size);
+  }
+  uint32_t length = output.numel();
+  if (length == 0) {
+    return;
+  }
+  auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
+  auto input_as_real = at::view_as_real(input).select(input.dim(), 0);
+  auto other_as_real = at::view_as_real(other).select(other.dim(), 0);
+  auto iter =
+      TensorIteratorConfig().add_output(output_as_real).add_input(input_as_real).add_input(other_as_real).build();
+
+  mps::binary_mps_impl(iter, "complex_mul");
+}
+
 } // namespace mps
 
 static void fmax_mps_kernel(TensorIteratorBase& iter) {
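
The `complex_mul` kernel above computes the textbook product (x + yi)(u + vi) = (xu - yv) + (xv + yu)i on interleaved (real, imag) pairs; it is registered for float and half, i.e. complex64 and complex32 inputs. The `view_as_real(...).select(dim, 0)` step in `complex_mul_out` hands the TensorIterator a strided real view whose element offsets land on the start of each pair, so the kernel can read `[0]` (real) and `[1]` (imaginary) relative to each offset. A minimal Python sketch of the same arithmetic, for reference only:

    import torch

    def complex_mul_reference(a, b):
        # Same math as the Metal kernel, one (real, imag) pair at a time:
        # (x + yi) * (u + vi) = (x*u - y*v) + (x*v + y*u)i
        return torch.complex(a.real * b.real - a.imag * b.imag,
                             a.real * b.imag + a.imag * b.real)

    z = torch.tensor([1 + 2j, 3 + 4j])
    torch.view_as_real(z)                # [[1., 2.], [3., 4.]]: trailing dim of size 2
    torch.view_as_real(z).select(-1, 0)  # [1., 3.]: offsets point at each pair's start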
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 33b8862..17e2d23 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -4,6 +4,7 @@
 #include <ATen/ScalarOps.h>
 #include <ATen/native/BinaryOps.h>
 #include <ATen/native/mps/OperationUtils.h>
+#include <ATen/native/mps/operations/BinaryKernel.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
 #include <ATen/Functions.h>
@@ -386,13 +387,24 @@
 // Arithmetic Binary Ops
 CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(minimum_out_mps, minimum, Tensor);
 CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(maximum_out_mps, maximum, Tensor);
-CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(mul_out_mps, multiplication, Tensor);
 CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(pow_tensor_scalar_out_mps, power, Scalar);
 CREATE_MPS_STRUCTURED_BINARY_OP_FUNC(pow_tensor_tensor_out_mps, power, Tensor);
 CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_and_out_mps, logicalAND, Tensor);
 CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_or_out_mps, logicalOR, Tensor);
 CREATE_MPS_BINARY_COMPARISON_OP_FUNC(logical_xor_out_mps, logicalXOR, Tensor);
 
+TORCH_IMPL_FUNC(mul_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
+  if (c10::isComplexType(self.scalar_type()) || c10::isComplexType(other.scalar_type())) {
+    return mps::complex_mul_out(self, other, output);
+  }
+  mps::binaryOpTensor(
+      self, other, Scalar(1.0), output, "mul", ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
+        MPSGraph* mpsGraph = cachedGraph->graph();
+        return [mpsGraph multiplicationWithPrimaryTensor:primaryCastTensor
+                                         secondaryTensor:secondaryCastTensor
+                                                    name:nil];
+      });
+}
 TORCH_IMPL_FUNC(atan2_out_mps)(const Tensor& self, const Tensor& other, const Tensor& output) {
   TORCH_CHECK(self.scalar_type() != ScalarType::Long, "MPS does not support atan2 op with int64 input");
   mps::binaryOpTensor(
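
With this dispatch in place, complex inputs take the custom Metal kernel while all other dtypes keep the existing MPSGraph multiplication path. A hedged sanity check of the new path against the CPU result (assumes MPS is available):

    import torch

    if torch.backends.mps.is_available():
        a = torch.randn(8, dtype=torch.complex64)
        b = torch.randn(8, dtype=torch.complex64)
        torch.testing.assert_close((a.to("mps") * b.to("mps")).cpu(), a * b)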
diff --git a/test/test_mps.py b/test/test_mps.py
index 7442f06..452bed4 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -210,6 +210,7 @@
     # Supported complex OPS
     SUPPORTED_COMPLEX_OPS = [
         '__radd__',
+        '__rmul__',
         'add',
         'atleast_1d',
         'atleast_2d',
@@ -227,11 +228,14 @@
         'isinf',
         'isreal',
         'item',
+        'kron',
         'linspace',
         'logspace',
+        'mul',
         'nn.functional.feature_alpha_dropoutwithout_train',
         'nn.functional.unfold',
         'ones',
+        'outer',
         'positive',
         'randn',
         'ravel',