Reland "Add torch.empty_permuted (#95069)" (#95208)

This reverts commit 92e03cd583c027a4100a13682cf65771b80569da.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95208
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 4c0ba04..cf98348 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -46,6 +46,7 @@
 #include <ATen/ops/empty_like.h>
 #include <ATen/ops/empty_like_native.h>
 #include <ATen/ops/empty_native.h>
+#include <ATen/ops/empty_permuted_native.h>
 #include <ATen/ops/empty_strided.h>
 #include <ATen/ops/empty_strided_native.h>
 #include <ATen/ops/eye.h>
@@ -278,6 +279,45 @@
   return result;
 }
 
+Tensor empty_permuted_symint(SymIntArrayRef size, IntArrayRef physical_layout, c10::optional<ScalarType> dtype_opt,
+  c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt
+) {
+  // size is logical; aka, the output size you'll get from the operation overall
+  //
+  // physical_layout follows NCHW/NHWC convention:
+  // contiguous is [0,1,2,3], channels last is [0,2,3,1]
+  //
+  // this means if i is physical index, physical_layout[i] is logical index;
+  // e.g., to find what is innermost physical dim (3), query NHWC[3] == 1
+  // (aka it is channels)
+  int64_t dim = static_cast<int64_t>(size.size());
+  SymDimVector phys_size(dim);
+  TORCH_CHECK(static_cast<int64_t>(physical_layout.size()) == dim,
+    "Number of dimensions in size does not match the "
+    "length of the physical_layout; i.e. len(size) = ", dim,
+    " is not equal to len(physical_layout) = ", physical_layout.size());
+  std::vector<bool> seen_dims(dim);
+  for (const auto i : c10::irange(dim)) {
+    TORCH_CHECK(physical_layout[i] >= 0 && physical_layout[i] < dim,
+      "Dimension out of range (expected to be between 0 and ", dim - 1, ", but got ",
+      physical_layout[i], " at index ", i, ").  NB: negative dims "
+      "not currently supported; file an issue if you want it.");
+    TORCH_CHECK(!seen_dims[physical_layout[i]], "Duplicate dim not allowed");
+    phys_size[i] = size[physical_layout[i]];
+    seen_dims[physical_layout[i]] = true;
+  }
+  // do a contiguous allocation
+  Tensor phys_tensor = at::empty_symint(phys_size, dtype_opt, layout_opt, device_opt, pin_memory_opt, c10::nullopt);
+  SymIntArrayRef phys_strides = phys_tensor.sym_strides();
+  // permute the strides (inverse permutation!  This is why this is
+  // empty_permute*d*, not empty_permute; it's not an empty + permute)
+  SymDimVector strides(dim);
+  for (const auto i : c10::irange(dim)) {
+    strides[physical_layout[i]] = phys_strides[i];
+  }
+  return phys_tensor.as_strided_symint(size, strides);
+}
+
 Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt,
                          c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
   return at::detail::empty_strided_cpu(size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt);
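
Note: to make the stride logic above concrete, here is a minimal pure-Python model
of it (an illustrative sketch, not PyTorch source; `empty_permuted_strides` is a
hypothetical helper name, and real ATen clamps zero-sized dims to 1 when computing
contiguous strides, which this naive version does not):

    def empty_permuted_strides(size, physical_layout):
        # physical_layout must be a permutation of range(len(size))
        dim = len(size)
        assert sorted(physical_layout) == list(range(dim))
        # physical sizes, outermost to innermost (NHWC order for (0, 2, 3, 1))
        phys_size = [size[l] for l in physical_layout]
        # contiguous strides for the physical allocation
        phys_strides, acc = [0] * dim, 1
        for i in reversed(range(dim)):
            phys_strides[i] = acc
            acc *= phys_size[i]
        # inverse permutation: scatter physical strides back to logical dims
        strides = [0] * dim
        for i, l in enumerate(physical_layout):
            strides[l] = phys_strides[i]
        return strides

    assert empty_permuted_strides((2, 3, 5, 7), (0, 1, 2, 3)) == [105, 35, 7, 1]
    assert empty_permuted_strides((2, 3, 5, 7), (0, 2, 3, 1)) == [105, 1, 21, 3]
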
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 013f62d..2f7a1a8 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2241,6 +2241,11 @@
     SparseCsrCPU, SparseCsrCUDA: empty_sparse_compressed
     QuantizedCPU, QuantizedCUDA, QuantizedMeta: empty_unknown_quantized
 
+- func: empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: empty_permuted_symint
+  autogen: empty_permuted.out
+
 # We do not make new_empty a composite that calls into new_empty_strided, as the strided version
 # is significantly more difficult to implement by different backends
 - func: new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index a3bb816..daf0178 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -719,6 +719,8 @@
 aten::empty.memory_format
 aten::empty.names
 aten::empty.names_out
+aten::empty_permuted
+aten::empty_permuted.out
 aten::empty_quantized
 aten::empty_quantized.out
 aten::equal
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index cb5c78d..8d9dff2 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -429,6 +429,7 @@
 inductor_override_kwargs = {
     # the return value of empty is undefined
     "empty": {"assert_equal": False},
+    "empty_permuted": {"assert_equal": False},
     "empty_like": {"assert_equal": False},
     "new_empty": {"assert_equal": False},
     "new_empty_strided": {"assert_equal": False},
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index aa896c7..013eaa9 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1154,6 +1154,7 @@
     skip('new_empty'),
     skip('empty_like'),
     skip('empty'),
+    skip('empty_permuted'),
     # flaky
     skip('linalg.lstsq', 'grad_oriented'),
     skip('nn.functional.max_unpool1d', '', device_type='cpu'),
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index fa67156..dbd100d 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -61,6 +61,18 @@
     return aten.div.Tensor_mode(a, b, rounding_mode="floor")
 
 
+# Not really sure how to put this into the main library.  PrimTorch wants
+# empty_permuted to go to the prim, and typically users don't really want
+# to decompose to empty_strided (but inductor is OK with it, because we are
+# cool with strides and everything goes to empty_strided)
+@register_decomposition([aten.empty_permuted.default])
+def empty_permuted(size, physical_layout, **kwargs):
+    perm = [0] * len(size)
+    for p, l in enumerate(physical_layout):
+        perm[l] = p
+    return torch.empty([size[l] for l in physical_layout], **kwargs).permute(perm)
+
+
 def get_alignment_size(x):
     if x.dtype == torch.float16 or x.dtype == torch.half or x.dtype == torch.bfloat16:
         return 8
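
Note: a quick sanity check that the decomposition above is layout-equivalent to the
primitive (assumes a build where torch.empty_permuted exists; shapes taken from the
NHWC example used elsewhere in this PR):

    import torch

    size, physical_layout = (2, 3, 5, 7), (0, 2, 3, 1)
    perm = [0] * len(size)
    for p, l in enumerate(physical_layout):
        perm[l] = p  # inverse permutation; here perm == [0, 3, 1, 2]
    decomposed = torch.empty([size[l] for l in physical_layout]).permute(perm)
    primitive = torch.empty_permuted(size, physical_layout)
    assert decomposed.shape == primitive.shape == (2, 3, 5, 7)
    assert decomposed.stride() == primitive.stride() == (105, 1, 21, 3)
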
diff --git a/torch/_prims/__init__.py b/torch/_prims/__init__.py
index 8434933..652f283 100644
--- a/torch/_prims/__init__.py
+++ b/torch/_prims/__init__.py
@@ -193,6 +193,7 @@
     # Tensor Creation Prims
     #
     "empty_strided",
+    "empty_permuted",
     "scalar_tensor",
     "iota",
     #
@@ -2466,6 +2467,61 @@
 )
 
 
+def _empty_permuted_meta(
+    shape: ShapeType,
+    physical_layout: DimsSequenceType,
+    *,
+    dtype: torch.dtype,
+    device: torch.device,
+    requires_grad: bool,
+) -> TensorLikeType:
+    dim = len(shape)
+    utils.check(
+        len(physical_layout) == dim,
+        lambda: (
+            "Number of dimensions in the tensor input does not match the "
+            f"length of the physical layout; i.e. len(size) = {dim} "
+            f"is not equal to len(physical_layout) = {len(physical_layout)}"
+        ),
+    )
+    p_strides = utils.make_contiguous_strides_for([shape[l] for l in physical_layout])
+    strides = [0] * len(shape)
+    seen_dims = set()
+    for p, l in enumerate(physical_layout):
+        utils.check(
+            0 <= l < dim,
+            lambda: (
+                f"Dimension out of range (expected to be between 0 and {dim - 1}, but got "
+                f"{l} at index {p}).  NB: negative dims "
+                "not currently supported; file an issue if you want it."
+            ),
+        )
+        utils.check(l not in seen_dims, lambda: "Duplicate dim not allowed")
+        strides[l] = p_strides[p]
+        seen_dims.add(l)
+    return TensorMeta(
+        shape=shape,
+        strides=strides,
+        dtype=dtype,
+        device=device,
+    )
+
+
+_empty_permuted_doc = """
+    Creates a tensor with uninitialized values according to some physical layout,
+    that is guaranteed to be non-overlapping and dense.
+"""
+
+# TODO: add layout, pin_memory
+empty_permuted = _make_prim(
+    schema="empty_permuted(SymInt[] shape, int[] physical_layout, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",  # noqa: B950
+    return_type=RETURN_TYPE.NEW,
+    meta=_empty_permuted_meta,
+    impl_aten=torch.empty_permuted,
+    doc=_empty_permuted_doc,
+)
+
+
 def _full_meta(
     shape: ShapeType,
     fill_value: NumberType,
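
Note: the prim's meta function mirrors what the composite ATen op already produces;
a hedged spot-check is to run the op on the meta device, which exercises the same
stride computation without allocating storage:

    import torch

    m = torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1), device="meta")
    assert m.stride() == (105, 1, 21, 3)
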
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index 73977d9..7608bda 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -4042,6 +4042,27 @@
     )
 
 
+@out_wrapper()
+def empty_permuted(
+    shape,
+    physical_layout,
+    dtype: Optional[torch.dtype] = None,
+    layout: torch.layout = torch.strided,
+    device: Optional[torch.device] = None,
+    requires_grad: bool = False,
+    pin_memory: bool = False,
+) -> TensorLikeType:
+    # layout and pin_memory are accepted for schema compatibility only; the
+    # prim does not take them yet (see the TODO on the prim definition)
+    return prims.empty_permuted(
+        shape,
+        physical_layout,
+        dtype=dtype,
+        device=device,
+        requires_grad=requires_grad,
+    )
+
+
 @register_decomposition(aten.new_empty)
 def new_empty(
     a: TensorLikeType,
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 77404e2..e44456e2 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -12354,6 +12354,51 @@
 )
 
 add_docstr(
+    torch.empty_permuted,
+    r"""
+empty_permuted(size, physical_layout, *, dtype=None, layout=None, device=None, requires_grad=False, pin_memory=False) -> Tensor
+
+Creates an uninitialized, non-overlapping and dense tensor with the
+specified :attr:`size`, with :attr:`physical_layout` specifying how the
+dimensions are physically laid out in memory (each logical dimension is listed
+from outermost to innermost).  :attr:`physical_layout` is a generalization
+of NCHW/NHWC notation: if each dimension is assigned a number according to
+what order they occur in size (N=0, C=1, H=2, W=3), then NCHW is ``(0, 1, 2, 3)``
+while NHWC is ``(0, 2, 3, 1)``.  Equivalently, the strides of the output
+tensor ``t`` are such that ``t.stride(physical_layout[i]) == contiguous_strides[i]``
+(notably, this function is *not* equivalent to ``torch.empty(size).permute(physical_layout)``).
+
+Unlike :func:`torch.empty_strided`, this is guaranteed to produce a dense
+tensor with no overlaps.  If possible, prefer using this function over
+:func:`torch.empty_strided` or manual use of :func:`torch.as_strided`.
+
+Args:
+    size (tuple of int): the shape of the output tensor
+    physical_layout (tuple of int): the ordering of dimensions physically in memory
+
+Keyword args:
+    {dtype}
+    {layout}
+    {device}
+    {requires_grad}
+    {pin_memory}
+
+Examples:
+
+    >>> torch.empty((2, 3, 5, 7)).stride()
+    (105, 35, 7, 1)
+    >>> torch.empty_permuted((2, 3, 5, 7), (0, 1, 2, 3)).stride()
+    (105, 35, 7, 1)
+    >>> torch.empty((2, 3, 5, 7), memory_format=torch.channels_last).stride()
+    (105, 1, 21, 3)
+    >>> torch.empty_permuted((2, 3, 5, 7), (0, 2, 3, 1)).stride()
+    (105, 1, 21, 3)
+""".format(
+        **factory_common_args
+    ),
+)
+
+add_docstr(
     torch.full,
     r"""
 full(size, fill_value, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor
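
Note: beyond the doctest in the docstring, the intended use is allocating a fresh
dense buffer that matches another tensor's physical layout without reaching for
empty_strided; a hedged sketch:

    import torch

    x = torch.empty((8, 3, 32, 32), memory_format=torch.channels_last)
    # same logical shape, same NHWC physical layout, guaranteed dense
    buf = torch.empty_permuted(x.shape, (0, 2, 3, 1))
    assert buf.stride() == x.stride()  # (3072, 1, 96, 3)
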
diff --git a/torch/overrides.py b/torch/overrides.py
index f84d89e..6637045 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -144,6 +144,7 @@
         torch.cudnn_grid_sampler,
         torch.cudnn_is_acceptable,
         torch.empty,
+        torch.empty_permuted,
         torch.empty_strided,
         torch.empty_quantized,
         torch.eye,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 54678e8..cd561e4 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1567,6 +1567,33 @@
     for case in cases:
         yield SampleInput(case, device=device, dtype=dtype, requires_grad=requires_grad)
 
+def sample_inputs_empty_permuted(op, device, dtype, requires_grad, **kwargs):
+    # shape
+    cases = (
+        (), (0,), (1,), (1, 3, 5), (5, 3, 1), (1, 0, 5, 1),
+    )
+
+    for case in cases:
+        for layout in itertools.permutations(range(len(case))):
+            yield SampleInput(case, layout, device=device, dtype=dtype, requires_grad=requires_grad)
+
+def error_inputs_empty_permuted(op_info, device, **kwargs):
+    yield ErrorInput(
+        SampleInput((2,), args=((0, 1),)),
+        error_type=RuntimeError,
+        error_regex="Number of dimensions in size does not match the length of the physical_layout"
+    )
+    yield ErrorInput(
+        SampleInput((2,), args=((3,),)),
+        error_type=RuntimeError,
+        error_regex="Dimension out of range"
+    )
+    yield ErrorInput(
+        SampleInput((2, 3), args=((0, 0),)),
+        error_type=RuntimeError,
+        error_regex="Duplicate dim not allowed"
+    )
+
 def sample_inputs_scalar_tensor(op, device, dtype, requires_grad, **kwargs):
     # Not including a scalar tensor in vals because meta tests start failing due to
     # lack of meta support for _local_scalar_dense
@@ -15772,6 +15799,48 @@
                # UserWarning not triggered : Resized a non-empty tensor but did not warn about it.
                DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out_warning'),
            )),
+    OpInfo('empty_permuted',
+           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
+           sample_inputs_func=sample_inputs_empty_permuted,
+           error_inputs_func=error_inputs_empty_permuted,
+           supports_out=False,
+           supports_autograd=False,
+           skips=(
+               DecorateInfo(unittest.expectedFailure, "TestNormalizeOperators", "test_normalize_operator_exhaustive"),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_variant_consistency_eager'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_noncontiguous_samples'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_conj_view'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestNNCOpInfo', 'test_nnc_correctness'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCudaFuserOpInfo'),
+               # Empty tensor data is garbage so it's hard to make comparisons with it.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestCommon', 'test_non_standard_bool_values'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"), 'TestCompositeCompliance',
+                            'test_operator'),
+               # requires_grad doesn't exist in the jit schema
+               DecorateInfo(unittest.expectedFailure, 'TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
+                            'TestCommon',
+                            'test_out'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
+                            'TestCommon',
+                            'test_out_warning'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
+                            'TestLazyOpInfo'),
+               DecorateInfo(unittest.skip("Expected: empty_permuted is not comparable"),
+                            'TestCommon', 'test_complex_half_reference_testing'),
+               DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
+           )),
     OpInfo('scalar_tensor',
            dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16, torch.chalf),
            sample_inputs_func=sample_inputs_scalar_tensor,
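
Note: the three error inputs above correspond to eager calls like the following;
each raises a RuntimeError carrying the message the regex targets:

    import torch

    for size, layout in [((2,), (0, 1)),     # len(size) != len(physical_layout)
                         ((2,), (3,)),       # dimension out of range
                         ((2, 3), (0, 0))]:  # duplicate dim
        try:
            torch.empty_permuted(size, layout)
        except RuntimeError as e:
            print(e)
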
diff --git a/torch/utils/_device.py b/torch/utils/_device.py
index 54fb15d..12e9da7 100644
--- a/torch/utils/_device.py
+++ b/torch/utils/_device.py
@@ -8,6 +8,7 @@
     return {
         # standard ones
         torch.empty,
+        torch.empty_permuted,
         torch.empty_strided,
         torch.empty_quantized,
         torch.ones,