Add ATen Op _chunk_cat and _chunk_cat.out (#121081)

# Motivation

In the backward pass of per-parameter-sharding FSDP, each rank performs a reduce-scatter to sync gradients across ranks. A rank chunks each gradient tensor into `world_size` slices along the 0-th dimension and concatenates all slices along the 1st dimension. Gradient tensors are zero-padded before chunking when `tensor.size(0) % world_size != 0`.

### Example 1
Consider `world_size=3` and tensors A (2x4), B (3x3), C (1x2):

Input tensors:
```
AAAA   BBB   CC
AAAA   BBB
       BBB
```

Reduce-scatter-copy-in output:
```
AAAABBBCC
AAAABBB00
0000BBB00
```

### Example 2
Consider `world_size=2` and tensors A (2x4), B (3x3), C (1x2), D (4x2):

Input tensors:
```
AAAA   BBB   CC   DD
AAAA   BBB        DD
       BBB        DD
                  DD
```

Reduce-scatter-copy-in first pads B and C so that each tensor's size along dim 0 is divisible by `world_size`:
```
AAAA   BBB   CC   DD
AAAA   BBB   00   DD
       BBB        DD
       000        DD
```

Then each padded tensor is viewed as `world_size` chunks along dim 0, and the flattened chunks are concatenated along dim 1 to form the output:
```
AAAABBBBBBCCDDDD
AAAABBB00000DDDD
```

The performance of reduce-scatter-copy-in is critical to per-parameter-sharding FSDP. However, implementing reduce-scatter-copy-in by composing existing ATen ops involves `cat` and irregular `pad`, leading to redundant data copies and unsatisfactory performance.
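For reference, a minimal Python sketch of such a baseline composition (the helper name `chunk_cat_reference` is ours, for illustration only; it assumes a non-negative `dim` and contiguous inputs):

```
import torch
import torch.nn.functional as F

def chunk_cat_reference(tensors, dim, num_chunks):
    # Pad each tensor along `dim` so its size is divisible by num_chunks,
    # view it as [*leading_dims, num_chunks, -1], then cat along dim + 1.
    padded = []
    for t in tensors:
        remainder = t.size(dim) % num_chunks
        if remainder != 0:
            # F.pad lists pads from the last dimension backwards; this pads
            # only the tail of `dim` and materializes an extra data copy.
            pad = [0, 0] * (t.ndim - dim - 1) + [0, num_chunks - remainder]
            t = F.pad(t, pad)
        padded.append(t.view(*t.shape[:dim], num_chunks, -1))
    return torch.cat(padded, dim=dim + 1)  # a second copy into the output
```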

# PR
We add native ATen support for reduce-scatter-copy-in, namely `_chunk_cat()`:

```
_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor
```
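
For instance, the two examples above correspond to the following calls (random values; only the shapes matter):

```
import torch

# Example 1: world_size=3, A (2x4), B (3x3), C (1x2)
A, B, C = torch.rand(2, 4), torch.rand(3, 3), torch.rand(1, 2)
out = torch._chunk_cat([A, B, C], dim=0, num_chunks=3)
print(out.shape)  # torch.Size([3, 9]): 3 chunks of 4 + 3 + 2 columns

# Example 2: world_size=2, A (2x4), B (3x3), C (1x2), D (4x2)
D = torch.rand(4, 2)
out = torch._chunk_cat([A, B, C, D], dim=0, num_chunks=2)
print(out.shape)  # torch.Size([2, 16]): 4 + 6 + 2 + 4 columns per chunk
```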

This PR includes the registration of `_chunk_cat` and `_chunk_cat.out`, OpInfo tests, and a basic implementation that composes existing ATen ops.
The next PR will add the CUDA implementation. Compared with baselines that compose existing ATen ops, the `_chunk_cat()` CUDA implementation improves copy bandwidth from 498 GB/s to 966 GB/s on a production benchmark.

## Requirements on input

1. If the input tensors have different ndims, `dim` must be non-negative and less than the ndim of every input tensor. If all input tensors have the same ndim, both negative and non-negative `dim` are supported (see the sketch below).
2. All tensors must have the same sizes on dimensions `0, ..., wrapped_dim-1`. There is no requirement on the sizes of dimension `wrapped_dim` and beyond.
3. `num_chunks` must be positive.
4. The input tensor list must be non-empty, and each input tensor must have at least 1 element.
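
A short illustration of the `dim` rules (a sketch; the error message quoted at the end is the one added by this PR's checks):

```
import torch

# Same ndim: negative dim is wrapped, so dim=-1 on 3-D inputs means dim=2.
same_ndim = [torch.rand(1, 2, 3), torch.rand(1, 2, 3)]
out = torch._chunk_cat(same_ndim, dim=-1, num_chunks=5)  # shape [1, 2, 5, 2]

# Different ndims: dim must be non-negative and < ndim of every tensor.
mixed_ndim = [torch.rand(2, 3, 3), torch.rand(2, 3)]
out = torch._chunk_cat(mixed_ndim, dim=1, num_chunks=3)  # shape [2, 3, 4]

try:
    torch._chunk_cat(mixed_ndim, dim=-1, num_chunks=3)
except RuntimeError as e:
    print(e)  # _chunk_cat expects non-negative dim when input tensors have different ndims
```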

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121081
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index e492bd6..1873201 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -36,9 +36,11 @@
 #include <ATen/Functions.h>
 #include <ATen/NativeFunctions.h>
 #else
+#include <ATen/ops/_chunk_cat_native.h>
 #include <ATen/ops/_conj_copy_native.h>
 #include <ATen/ops/_convert_indices_from_coo_to_csr.h>
 #include <ATen/ops/_convert_indices_from_csr_to_coo.h>
+#include <ATen/ops/_foreach_copy.h>
 #include <ATen/ops/_fw_primal_copy_native.h>
 #include <ATen/ops/_indices_copy_native.h>
 #include <ATen/ops/_make_dual.h>
@@ -2723,6 +2725,38 @@
   }
 }
 
+// Pads each tensor on `dim`-th dimension such that padded_dim % num_chunks == 0.
+static std::vector<Tensor> _pad_chunk(TensorList tensors, int64_t dim, int64_t num_chunks) {
+  auto num_tensors = tensors.size();
+  std::vector<Tensor> padded_tensors;
+  padded_tensors.reserve(num_tensors);
+  for (const auto & tensor : tensors) {
+    auto tensor_size = tensor.sizes();
+    std::vector<int64_t> padded_size(tensor_size.vec());
+    padded_size[dim] = (tensor_size[dim] + num_chunks - 1) / num_chunks * num_chunks;
+    Tensor padded_tensor = tensor;
+    if (padded_size != tensor_size) {
+      padded_tensor = tensor.new_zeros(padded_size);
+      padded_tensor.narrow(dim, 0, tensor_size[dim]).copy_(tensor);
+    }
+    std::vector<int64_t> view_sizes(tensor_size.begin(), tensor_size.begin()+dim);
+    view_sizes.insert(view_sizes.end(), {num_chunks, -1});
+    padded_tensors.push_back(padded_tensor.view(view_sizes));
+  }
+  return padded_tensors;
+}
+
+Tensor _chunk_cat(TensorList tensors, int64_t dim, int64_t num_chunks) {
+  auto wrapped_dim = at::native::preprocess_chunk_cat_inputs(tensors, dim, num_chunks);
+  return at::cat(_pad_chunk(tensors, wrapped_dim, num_chunks), wrapped_dim+1);
+}
+
+Tensor& _chunk_cat_out(TensorList tensors, int64_t dim, int64_t num_chunks, Tensor& out) {
+  auto wrapped_dim = at::native::preprocess_chunk_cat_inputs(tensors, dim, num_chunks);
+  at::cat_out(out, _pad_chunk(tensors, wrapped_dim, num_chunks), wrapped_dim+1);
+  return out;
+}
+
 // TODO(msubkhankulov): refactor to use _stack
 Tensor stack(TensorList tensors, int64_t dim) {
   TORCH_CHECK(!tensors.empty(),
diff --git a/aten/src/ATen/native/TensorShape.h b/aten/src/ATen/native/TensorShape.h
index 1c84abb..638ddba 100644
--- a/aten/src/ATen/native/TensorShape.h
+++ b/aten/src/ATen/native/TensorShape.h
@@ -55,4 +55,51 @@
   return num_splits;
 }
 
+inline bool have_same_ndims(TensorList tensors) {
+  auto ndim = tensors[0].dim();
+  for (const auto tensor_idx : c10::irange(tensors.size())) {
+    if(tensors[tensor_idx].dim() != ndim) {
+      return false;
+    }
+  }
+  return true;
+}
+
+inline void leading_dimension_matches(TensorList tensors, int64_t dim) {
+  auto tensor_zero_size = tensors[0].sizes();
+  std::vector<c10::SymInt> leading_dim_sizes(tensor_zero_size.begin(), tensor_zero_size.begin() + dim);
+  for (const auto i : c10::irange(tensors.size())) {
+    at::Tensor tensor = tensors[i];
+    for(const auto j : c10::irange(dim)) {
+      TORCH_CHECK(
+        tensor.size(j) == leading_dim_sizes[j],
+        "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors"
+      );
+    }
+  }
+}
+
+inline int64_t preprocess_chunk_cat_inputs(TensorList tensors, int64_t dim, int64_t num_chunks) {
+  TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks");
+  TORCH_CHECK(!tensors.empty(),
+           "_chunk_cat expects a non-empty input tensor list");
+  auto expected_dtype = tensors[0].dtype();
+  auto expected_device = tensors[0].device();
+  for(const auto i : c10::irange(tensors.size())) {
+    TORCH_CHECK(tensors[i].numel() > 0, "_chunk_cat expects non-empty tensor");
+    TORCH_CHECK(tensors[i].dtype() == expected_dtype, "_chunk_cat expects all input tensors with the same dtype");
+    TORCH_CHECK(tensors[i].device() == expected_device, "_chunk_cat expects all inputs tensors on the same device");
+  }
+  if (have_same_ndims(tensors)) {
+    dim = maybe_wrap_dim(dim, tensors[0].dim());
+  } else {
+    TORCH_CHECK(dim >= 0, "_chunk_cat expects non-negative dim when input tensors have different ndims")
+    for(const auto i : c10::irange(tensors.size())) {
+      TORCH_CHECK(dim < tensors[i].ndimension(), "_chunk_cat expects dim < ndim for all input tensors");
+    }
+  }
+  leading_dimension_matches(tensors, dim);
+  return dim;
+}
+
 } // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 99f6a3a..5d90f64 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5617,6 +5617,14 @@
     SparseCPU: _sspaddmm_out_cpu
     SparseCUDA: _sspaddmm_out_cuda
 
+- func: _chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor
+  dispatch:
+    CompositeExplicitAutograd: _chunk_cat
+
+- func: _chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    CompositeExplicitAutograd: _chunk_cat_out
+
 - func: stack(Tensor[] tensors, int dim=0) -> Tensor
   dispatch:
     CompositeExplicitAutograd: stack
diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py
index 66b78f0..afc6c17 100644
--- a/test/distributed/_tensor/test_dtensor_ops.py
+++ b/test/distributed/_tensor/test_dtensor_ops.py
@@ -94,6 +94,7 @@
     # get full support with varying sharding specs
     xfail("__getitem__"),
     xfail("__rsub__"),
+    xfail("_chunk_cat"),
     xfail("_native_batch_norm_legit"),
     xfail("_upsample_bilinear2d_aa"),
     xfail("addbmm"),
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 3e7120a..d3231df 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -1564,7 +1564,7 @@
     @ops(op_db, allowed_dtypes=(torch.float,))
     def test_normalize_operator_exhaustive(self, device, dtype, op):
         # These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors)
-        fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot", "_upsample_bilinear2d_aa"}
+        fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot", "_upsample_bilinear2d_aa", "_chunk_cat"}
         sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
         if isinstance(op.op, torch._ops.OpOverload):
             self.skipTest("normalize operator doesn't work on torch.ops")
diff --git a/test/test_mps.py b/test/test_mps.py
index 3c5e47c..63bc604 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -71,6 +71,7 @@
         # Unimplemented ops
         '__getitem__': [torch.float16],
         '_segment_reduce': [torch.float16, torch.float32],
+        '_chunk_cat': [torch.float16, torch.float32],
         'unfold_copy': [torch.float16, torch.float32],  # unfold_backward is not implemented
         'unfold': [torch.float16, torch.float32],
         'sparse.mmreduce': [torch.float32],  # csr not supported
@@ -342,6 +343,7 @@
 
     AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = {
         '__rdiv__',
+        '_chunk_cat',
         'acos',
         'acosh',
         'all',
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 9146f40..24092c4 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -457,6 +457,7 @@
             aten.zero_,
             aten.zeros,
             aten.zeros_like,
+            aten._chunk_cat,
             aten._weight_norm_interface,
         ]
     )
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 5c66749..bacdab3 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1212,6 +1212,97 @@
     return r
 
 
+def _pad_chunk(
+    tensors: List[Tensor],
+    dim: int,
+    num_chunks: int,
+) -> List[Tensor]:
+    padded_tensors = []
+    for tensor in tensors:
+        tensor_size = tensor.size()
+        pad_along_dim = (tensor_size[dim] + num_chunks - 1) // num_chunks * num_chunks
+        if pad_along_dim != tensor_size[dim]:
+            # Use aten.constant_pad_nd instead of copy_ for functionalization
+            pad = [0] * 2 * (tensor.ndim - dim - 1) + [
+                0,
+                pad_along_dim - tensor_size[dim],
+            ]
+            tensor = aten.constant_pad_nd(tensor, pad, 0)
+        view_size = tensor_size[:dim] + torch.Size([num_chunks, -1])
+        padded_tensors.append(tensor.view(view_size))
+    return padded_tensors
+
+
+def have_same_ndims(tensors: List[Tensor]):
+    ndim = tensors[0].ndim
+    for tensor in tensors:
+        if tensor.ndim != ndim:
+            return False
+    return True
+
+
+def leading_dimension_matches(tensors: List[Tensor], dim: int):
+    leading_dim_sizes = tensors[0].size()[:dim]
+    for tensor in tensors:
+        torch._check(
+            tensor.size()[:dim] == leading_dim_sizes,
+            lambda: "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors",
+        )
+
+
+def _preprocess_chunk_cat_inputs(
+    tensors: List[Tensor],
+    dim: int,
+    num_chunks: int,
+):
+    torch._check(num_chunks >= 1, lambda: "_chunk_cat expects positive num_chunks")
+    torch._check(
+        len(tensors) > 0, lambda: "_chunk_cat expects a non-empty input tensor list"
+    )
+    expected_dtype = tensors[0].dtype
+    expected_device = tensors[0].device
+    for tensor in tensors:
+        torch._check(tensor.numel() > 0, lambda: "_chunk_cat expects non-empty tensor")
+        torch._check(
+            tensor.dtype == expected_dtype,
+            lambda: "_chunk_cat expects all input tensors with the same dtype",
+        )
+        torch._check(
+            tensor.device == expected_device,
+            lambda: "_chunk_cat expects all inputs tensors on the same device",
+        )
+    if have_same_ndims(tensors):
+        dim = utils.canonicalize_dim(tensors[0].dim(), dim)
+    else:
+        torch._check(
+            dim >= 0,
+            lambda: "_chunk_cat expects non-negative dim when input tensors have different ndims",
+        )
+        for tensor in tensors:
+            torch._check(
+                dim < tensor.ndim,
+                lambda: "_chunk_cat expects dim < ndim for all input tensors",
+            )
+    leading_dimension_matches(tensors, dim)
+    return dim
+
+
+@register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out])
+def _chunk_cat(
+    tensors: List[Tensor],
+    dim: int,
+    num_chunks: int,
+    out: Optional[Tensor] = None,
+) -> Tensor:
+    dim = _preprocess_chunk_cat_inputs(tensors, dim, num_chunks)
+    padded_tensors = _pad_chunk(tensors, dim, num_chunks)
+    if out is None:
+        return torch.cat(padded_tensors, dim + 1)
+    else:
+        torch.cat(padded_tensors, dim + 1, out=out)
+        return out
+
+
 @register_decomposition(aten.split_with_sizes)
 def split_with_sizes(
     self: Tensor, split_sizes: List[int], dim: int = 0
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index c893554..67dbe2b 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -1229,6 +1229,7 @@
         "torch._cast_Long",
         "torch._cast_Short",
         "torch._choose_qparams_per_tensor",
+        "torch._chunk_cat",
         "torch._coalesce",
         "torch._compute_linear_combination",
         "torch._conj_copy",
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4ff7266..59cfee4 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2298,6 +2298,115 @@
         for dim in range(-1, len(shape) - 1):
             yield SampleInput(tensors, args=(dim,))
 
+
+def sample_inputs_chunk_cat(op_info, device, dtype, requires_grad, **kwargs):
+    # 1. If input tensors have different ndims, dim should be non-negative and be less than the ndims of every input tensors.
+    #    If all input tensors have the same ndims, we support both negative and non-negative dim.
+    # 2. For wrapped_dim, all tensors should have the same size for 0,...,wrapped_dim-1 dimensions.
+    #        No requirements for (wrapped_dim, ...)-th dimension.
+    # 3. Expect positive num_chunks
+    # 4. Expect non-empty input tensor list and each input tensor should have at least 1 element
+    make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+    same_ndim_cases = (
+        (
+            [
+                torch.Size([1, 2, 3]),
+                torch.Size([1, 2, 3]),
+            ], -1, 5
+        ),
+        (
+            [
+                torch.Size([1, 2, 3]),
+                torch.Size([1, 2, 3]),
+            ], 1, 5
+        ),
+        (
+            [
+                torch.Size([3, 3, 2, 1]),
+                torch.Size([1, 4, 2, 2]),
+                torch.Size([2, 1, 3, 3]),
+            ], 0, 2
+        ),
+    )
+    for sizes, dim, num_chunks in same_ndim_cases:
+        tensors = []
+        for size in sizes:
+            tensors.append(make_arg(size))
+        yield SampleInput(tensors, args=(dim, num_chunks))
+
+    different_ndim_case = [
+        torch.Size([2, 3, 3]),
+        torch.Size([2, 3, 1, 2]),
+        torch.Size([2, 3]),
+        torch.Size([2, 3, 2]),
+    ]
+    max_dim, num_chunks = 2, 3
+    for dim in range(max_dim):
+        tensors = []
+        for size in different_ndim_case:
+            tensors.append(make_arg(size))
+        yield SampleInput(tensors, args=(dim, num_chunks))
+
+
+def error_inputs_chunk_cat(op_info, device, **kwargs):
+    make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+    # input tensors have different ndims but dim is negative
+    sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], -1, 3
+    tensors = [make_arg(size) for size in sizes]
+    yield ErrorInput(
+        SampleInput(tensors, args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects non-negative dim when input tensors have different ndims',
+    )
+
+    # input tensors have different ndims but dim >= ndim of some input tensors
+    sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], 1, 3
+    tensors = [make_arg(size) for size in sizes]
+    yield ErrorInput(
+        SampleInput(tensors, args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects dim < ndim for all input tensors',
+    )
+
+    # some tensors have different sizes for 0, ..., dim-1 dimensions.
+    sizes, dim, num_chunks = [torch.Size([2, 3, 4]), torch.Size([4, 3])], 1, 3
+    tensors = [make_arg(size) for size in sizes]
+    yield ErrorInput(
+        SampleInput(tensors, args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors',
+    )
+
+    # negative num_chunks
+    sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, -1
+    tensors = [make_arg(size) for size in sizes]
+    yield ErrorInput(
+        SampleInput(tensors, args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects positive num_chunks',
+    )
+
+    # zero as num_chunks
+    sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, 0
+    tensors = [make_arg(size) for size in sizes]
+    yield ErrorInput(
+        SampleInput(tensors, args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects positive num_chunks',
+    )
+
+    # empty input tensor list
+    dim, num_chunks = 0, 1
+    yield ErrorInput(
+        SampleInput([], args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects a non-empty input tensor list',
+    )
+
+    # empty input tensor with 0 elements
+    sizes, dim, num_chunks = [torch.Size([0,]), torch.Size([3,])], 0, 1
+    tensors = [make_arg(size) for size in sizes]
+    yield ErrorInput(
+        SampleInput(tensors, args=(dim, num_chunks)),
+        error_regex='_chunk_cat expects non-empty tensor',
+    )
+
+
 def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
 
@@ -17482,6 +17591,13 @@
                DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
            ),
            ),
+    OpInfo('_chunk_cat',
+           dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+           sample_inputs_func=sample_inputs_chunk_cat,
+           error_inputs_func=error_inputs_chunk_cat,
+           supports_autograd=False,
+           supports_out=True,
+           ),
     OpInfo('hstack',
            dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_hstack_dstack_vstack,