[MPS] Enable `index_select` for complex types (#122590)
Surprisingly, as of MacOS-14.14 MPS `gatherWithUpdatesTensor:indicesTensor:axis:batchDimensions:name:` still does not support complex types, so emulate them by using `at::view_as_real` trick
Fixes https://github.com/pytorch/pytorch/issues/122427
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122590
Approved by: https://github.com/Skylion007
diff --git a/aten/src/ATen/native/mps/operations/Indexing.mm b/aten/src/ATen/native/mps/operations/Indexing.mm
index 3170169..2b0ab90 100644
--- a/aten/src/ATen/native/mps/operations/Indexing.mm
+++ b/aten/src/ATen/native/mps/operations/Indexing.mm
@@ -39,6 +39,7 @@
#include <ATen/ops/masked_select_native.h>
#include <ATen/ops/nonzero.h>
#include <ATen/ops/nonzero_native.h>
+#include <ATen/ops/view_as_real.h>
#endif
namespace at::native {
@@ -602,7 +603,6 @@
" and ",
output.size(dim),
".");
- TORCH_CHECK(!self.is_complex(), "index_select(): Yet not supported for complex");
for (const auto i : irange(self.dim())) {
if (i == dim)
@@ -628,6 +628,14 @@
return output;
}
+ // As of MacOS 14.4 gatherWithUpdatesTensor: still does not support complex
+ // So back to old view_as_real trick
+ if (self.is_complex()) {
+ auto out_view = at::view_as_real(output);
+ index_select_out_mps(at::view_as_real(self), dim, index, out_view);
+ return output;
+ }
+
// Derive from MPSCachedGraph
struct CachedGraph : public MPSCachedGraph {
CachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
diff --git a/test/test_mps.py b/test/test_mps.py
index 7e62b55..6f847e8 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -264,6 +264,7 @@
'H',
'hsplit',
'imag',
+ 'index_select',
'isfinite',
'isinf',
'isreal',
@@ -297,6 +298,7 @@
'randn',
'ravel',
'real',
+ 'repeat_interleave',
'reshape_as',
'reshape',
'resolve_conj',