[MPS] Add complex support for `fill` (#111885)
Fixes #110537
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111885
Approved by: https://github.com/malfet
diff --git a/aten/src/ATen/native/mps/operations/ConstantOps.mm b/aten/src/ATen/native/mps/operations/ConstantOps.mm
index 471188b..4e7ff92 100644
--- a/aten/src/ATen/native/mps/operations/ConstantOps.mm
+++ b/aten/src/ATen/native/mps/operations/ConstantOps.mm
@@ -1,8 +1,15 @@
// Copyright © 2022 Apple Inc.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/mps/OperationUtils.h>
+
+#ifndef AT_PER_OPERATOR_HEADERS
+#include <ATen/Functions.h>
+#include <ATen/NativeFunctions.h>
+#else
#include <ATen/ops/fill_native.h>
+#include <ATen/ops/view_as_real.h>
#include <ATen/ops/zero_native.h>
+#endif
namespace at::native {
@@ -78,16 +85,25 @@
return false;
}
-Tensor& zero_mps_(Tensor& self) {
- // check if it's possible to use fillBuffer() to fill the Tensor's storage
- if (fill_mps_tensor_(self, 0) == true)
- return self;
- return fill_scalar_mps_impl(self, 0.0f);
-}
-
Tensor& fill_scalar_mps(Tensor& self, const Scalar& value) {
+ // check if it's possible to use fillBuffer() to fill the Tensor's storage
if (value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true)
return self;
+
+ if (isComplexType(self.scalar_type())) {
+ auto self_as_real = at::view_as_real(self);
+ auto self_as_real_real = self_as_real.select(self.dim(), 0);
+ auto self_as_real_imag = self_as_real.select(self.dim(), 1);
+ if (value.isComplex()) {
+ auto value_cdouble = value.to<c10::complex<double>>();
+ fill_scalar_mps_impl(self_as_real_real, value_cdouble.real());
+ fill_scalar_mps_impl(self_as_real_imag, value_cdouble.imag());
+ return self;
+ }
+ fill_scalar_mps_impl(self_as_real_real, value);
+ fill_scalar_mps_impl(self_as_real_imag, 0.0f);
+ return self;
+ }
return fill_scalar_mps_impl(self, value);
}
@@ -99,7 +115,11 @@
Scalar scalar_value = value.item();
if (scalar_value.toDouble() == 0.0 && fill_mps_tensor_(self, 0) == true)
return self;
- return fill_scalar_mps_impl(self, scalar_value);
+ return fill_scalar_mps(self, scalar_value);
+}
+
+Tensor& zero_mps_(Tensor& self) {
+ return fill_scalar_mps(self, 0.0f);
}
} // namespace at::native
diff --git a/test/test_mps.py b/test/test_mps.py
index 5d3a19e..5d09584 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -234,6 +234,7 @@
'empty_strided',
'eye',
'flatten',
+ 'fill',
'full',
'imag',
'isfinite',
@@ -11171,7 +11172,7 @@
def test_tensor_creation(self, device, dtype):
def ones(device):
return torch.ones((2, 2), dtype=dtype, device=device)
- if dtype not in MPS_DTYPES:
+ if dtype not in MPS_DTYPES + [torch.complex64]:
with self.assertRaises(TypeError):
ones(device)
else: