[MPS] Add complex support for `fill` (#111885)

Fixes #110537
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111885
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm
index 471188b..4e7ff92 100644
--- a/aten/src/ATen/native/mps/operations/ConstantOps.mm
+++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm
@@ -1,8 +1,15 @@
 //  Copyright © 2022 Apple Inc.
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/native/mps/OperationUtils.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
 #include <ATen/ops/fill_native.h>
+#include <ATen/ops/view_as_real.h>
 #include <ATen/ops/zero_native.h>
+#endif
 
 namespace at::native {
 
@@ -78,16 +85,25 @@
   return false;
 }
 
-Tensor& zero_mps_(Tensor& self) {
-  // check if it's possible to use fillBuffer() to fill the Tensor's storage
-  if (fill_mps_tensor_(self, 0) == true)
-    return self;
-  return fill_scalar_mps_impl(self, 0.0f);
-}
-
 Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) {
+  // check if it's possible to use fillBuffer() to fill the Tensor's storage
   if (value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true)
     return self;
+
+  if (isComplexType(self.scalar_type())) {
+    auto self_as_real = at::view_as_real(self);
+    auto self_as_real_real = self_as_real.select(self.dim(), 0);
+    auto self_as_real_imag = self_as_real.select(self.dim(), 1);
+    if (value.isComplex()) {
+      auto value_cdouble = value.to<c10::complex<double>>();
+      fill_scalar_mps_impl(self_as_real_real, value_cdouble.real());
+      fill_scalar_mps_impl(self_as_real_imag, value_cdouble.imag());
+      return self;
+    }
+    fill_scalar_mps_impl(self_as_real_real, value);
+    fill_scalar_mps_impl(self_as_real_imag, 0.0f);
+    return self;
+  }
   return fill_scalar_mps_impl(self, value);
 }
 
@@ -99,7 +115,11 @@
   Scalar scalar_value = value.item();
   if (scalar_value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true)
     return self;
-  return fill_scalar_mps_impl(self, scalar_value);
+  return fill_scalar_mps(self, scalar_value);
+}
+
+Tensor& zero_mps_(Tensor& self) {
+  return fill_scalar_mps(self, 0.0f);
 }
 
 } // namespace at::native
diff --git a/test/test_mps.py b/test/test_mps.py
index 5d3a19e..5d09584 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -234,6 +234,7 @@
         'empty_strided',
         'eye',
         'flatten',
+        'fill',
         'full',
         'imag',
         'isfinite',
@@ -11171,7 +11172,7 @@
     def test_tensor_creation(self, device, dtype):
         def ones(device):
             return torch.ones((2, 2), dtype=dtype, device=device)
-        if dtype not in MPS_DTYPES:
+        if dtype not in MPS_DTYPES + [torch.complex64]:
             with self.assertRaises(TypeError):
                 ones(device)
         else: