Add ATen Op _chunk_cat and _chunk_cat.out (#121081)
# Motivation
In the backward pass of per-parameter-sharding FSDP, each rank performs a reduce-scatter to sync gradients across ranks. A rank chunks each gradient tensor into `world_size` slices along dimension 0 and concatenates all slices along dimension 1. A gradient tensor is zero-padded before concatenation when `tensor.size(0) % world_size != 0`.
### Example 1
Consider `world_size=3` and tensors A (2x4), B (3x3), C (1x2):
Input tensors:
```
AAAA  BBB  CC
AAAA  BBB
      BBB
```
Reduce-scatter-copy-in Output:
```
AAAABBBCC
AAAABBB00
0000BBB00
```
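For reference, the picture above can be reproduced with existing ops. The sketch below mirrors the composite implementation added in this PR (zero-pad along `dim`, view each tensor as `[..., num_chunks, -1]`, then `cat` along `dim + 1`); shapes follow Example 1.
```
import torch

def chunk_cat_reference(tensors, dim, num_chunks):
    # Zero-pad each tensor along `dim` to a multiple of num_chunks,
    # view it as [..., num_chunks, -1], then concatenate along dim + 1.
    padded = []
    for t in tensors:
        size = list(t.size())
        padded_len = (size[dim] + num_chunks - 1) // num_chunks * num_chunks
        if padded_len != size[dim]:
            size[dim] = padded_len
            p = t.new_zeros(size)
            p.narrow(dim, 0, t.size(dim)).copy_(t)
            t = p
        padded.append(t.view(*t.size()[:dim], num_chunks, -1))
    return torch.cat(padded, dim + 1)

A, B, C = torch.ones(2, 4), torch.ones(3, 3), torch.ones(1, 2)
out = chunk_cat_reference([A, B, C], dim=0, num_chunks=3)
print(out.shape)  # torch.Size([3, 9]) -- the 3x9 grid drawn above
```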
### Example 2
Consider `world_size=2` and tensors A (2x4), B (3x3), C(1x2), D(4x2):
Input tensors:
```
AAAA  BBB  CC  DD
AAAA  BBB      DD
      BBB      DD
               DD
```
Reduce-scatter-copy-in first pads the inputs:
```
AAAA  BBB  CC  DD
AAAA  BBB  00  DD
      BBB      DD
      000      DD
```
Then each padded tensor is chunked along dim 0 and the chunks are concatenated along dim 1 to form the output:
```
AAAABBBBBBCCDDDD
AAAABBB00000DDDD
```
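Once this PR lands, the same result comes from a single call to the new op (exposed in Python as `torch._chunk_cat`). A quick check using the Example 2 shapes:
```
import torch

A, B, C, D = torch.randn(2, 4), torch.randn(3, 3), torch.randn(1, 2), torch.randn(4, 2)
# world_size = num_chunks = 2, chunking along dim 0
out = torch._chunk_cat([A, B, C, D], dim=0, num_chunks=2)
print(out.shape)  # torch.Size([2, 16]) -- the two 16-wide rows drawn above
```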
The performance of reduce-scatter-copy-in is critical to per-parameter-sharding FSDP. However, implementing it by composing existing ATen ops requires `cat` and irregular `pad` calls, leading to redundant data copies and unsatisfactory performance.
# PR
We add native ATen support for reduce-scatter-copy-in, namely `_chunk_cat()`:
```
_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor
```
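A minimal usage sketch of the two overloads (assuming the Python binding is `torch._chunk_cat`, as registered in this PR); the `out=` form is the one reduce-scatter-copy-in could use to write into a preallocated staging buffer:
```
import torch

A, B = torch.randn(2, 4), torch.randn(3, 3)

# Functional overload: allocates and returns the result.
y = torch._chunk_cat([A, B], dim=0, num_chunks=2)   # shape (2, 10)

# out= overload (_chunk_cat.out): writes into an existing buffer,
# avoiding an extra allocation at the call site.
buf = torch.empty_like(y)
torch._chunk_cat([A, B], dim=0, num_chunks=2, out=buf)
torch.testing.assert_close(buf, y)
```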
This PR includes the registration of `_chunk_cat` and `_chunk_cat.out`, OpInfo tests, and a basic implementation composed of existing ATen ops.
A follow-up PR will add the CUDA implementation. Compared with a baseline that composes existing ATen ops, the `_chunk_cat()` CUDA implementation improves copy bandwidth from 498 GB/s to 966 GB/s on a production benchmark.
## Requirements on input
1. If input tensors have different ndims, `dim` must be non-negative and less than the ndim of every input tensor. If all input tensors have the same ndim, both negative and non-negative `dim` are supported (see the sketch after this list).
2. All tensors must have the same sizes along dimensions `0, ..., wrapped_dim-1`. There is no requirement on dimension `wrapped_dim` and beyond.
3. `num_chunks` must be positive.
4. The input tensor list must be non-empty, and each input tensor must have at least one element.
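For instance, requirement 1 rejects a negative `dim` when the inputs have different ndims; a small sketch of that failure mode (the error text matches the checks added in this PR):
```
import torch

a, b = torch.randn(2, 3), torch.randn(4)  # different ndims (2 vs. 1)
try:
    torch._chunk_cat([a, b], dim=-1, num_chunks=3)
except RuntimeError as e:
    # "_chunk_cat expects non-negative dim when input tensors have different ndims"
    print(e)
```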
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121081
Approved by: https://github.com/albanD
diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp
index e492bd6..1873201 100644
--- a/aten/src/ATen/native/TensorShape.cpp
+++ b/aten/src/ATen/native/TensorShape.cpp
@@ -36,9 +36,11 @@
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
+#include <ATen/ops/_chunk_cat_native.h>
#include <ATen/ops/_conj_copy_native.h>
#include <ATen/ops/_convert_indices_from_coo_to_csr.h>
#include <ATen/ops/_convert_indices_from_csr_to_coo.h>
+#include <ATen/ops/_foreach_copy.h>
#include <ATen/ops/_fw_primal_copy_native.h>
#include <ATen/ops/_indices_copy_native.h>
#include <ATen/ops/_make_dual.h>
@@ -2723,6 +2725,38 @@
}
}
+// Pads each tensor on `dim`-th dimension such that padded_dim % num_chunks == 0.
+static std::vector<Tensor> _pad_chunk(TensorList tensors, int64_t dim, int64_t num_chunks) {
+ auto num_tensors = tensors.size();
+ std::vector<Tensor> padded_tensors;
+ padded_tensors.reserve(num_tensors);
+ for (const auto & tensor : tensors) {
+ auto tensor_size = tensor.sizes();
+ std::vector<int64_t> padded_size(tensor_size.vec());
+ padded_size[dim] = (tensor_size[dim] + num_chunks - 1) / num_chunks * num_chunks;
+ Tensor padded_tensor = tensor;
+ if (padded_size != tensor_size) {
+ padded_tensor = tensor.new_zeros(padded_size);
+ padded_tensor.narrow(dim, 0, tensor_size[dim]).copy_(tensor);
+ }
+ std::vector<int64_t> view_sizes(tensor_size.begin(), tensor_size.begin()+dim);
+ view_sizes.insert(view_sizes.end(), {num_chunks, -1});
+ padded_tensors.push_back(padded_tensor.view(view_sizes));
+ }
+ return padded_tensors;
+}
+
+Tensor _chunk_cat(TensorList tensors, int64_t dim, int64_t num_chunks) {
+ auto wrapped_dim = at::native::preprocess_chunk_cat_inputs(tensors, dim, num_chunks);
+ return at::cat(_pad_chunk(tensors, wrapped_dim, num_chunks), wrapped_dim+1);
+}
+
+Tensor& _chunk_cat_out(TensorList tensors, int64_t dim, int64_t num_chunks, Tensor& out) {
+ auto wrapped_dim = at::native::preprocess_chunk_cat_inputs(tensors, dim, num_chunks);
+ at::cat_out(out, _pad_chunk(tensors, wrapped_dim, num_chunks), wrapped_dim+1);
+ return out;
+}
+
// TODO(msubkhankulov): refactor to use _stack
Tensor stack(TensorList tensors, int64_t dim) {
TORCH_CHECK(!tensors.empty(),
diff --git a/aten/src/ATen/native/TensorShape.h b/aten/src/ATen/native/TensorShape.h
index 1c84abb..638ddba 100644
--- a/aten/src/ATen/native/TensorShape.h
+++ b/aten/src/ATen/native/TensorShape.h
@@ -55,4 +55,51 @@
return num_splits;
}
+inline bool have_same_ndims(TensorList tensors) {
+ auto ndim = tensors[0].dim();
+ for (const auto tensor_idx : c10::irange(tensors.size())) {
+ if(tensors[tensor_idx].dim() != ndim) {
+ return false;
+ }
+ }
+ return true;
+}
+
+inline void leading_dimension_matches(TensorList tensors, int64_t dim) {
+ auto tensor_zero_size = tensors[0].sizes();
+ std::vector<c10::SymInt> leading_dim_sizes(tensor_zero_size.begin(), tensor_zero_size.begin() + dim);
+ for (const auto i : c10::irange(tensors.size())) {
+ at::Tensor tensor = tensors[i];
+ for(const auto j : c10::irange(dim)) {
+ TORCH_CHECK(
+ tensor.size(j) == leading_dim_sizes[j],
+ "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors"
+ );
+ }
+ }
+}
+
+inline int64_t preprocess_chunk_cat_inputs(TensorList tensors, int64_t dim, int64_t num_chunks) {
+ TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks");
+ TORCH_CHECK(!tensors.empty(),
+ "_chunk_cat expects a non-empty input tensor list");
+ auto expected_dtype = tensors[0].dtype();
+ auto expected_device = tensors[0].device();
+ for(const auto i : c10::irange(tensors.size())) {
+ TORCH_CHECK(tensors[i].numel() > 0, "_chunk_cat expects non-empty tensor");
+ TORCH_CHECK(tensors[i].dtype() == expected_dtype, "_chunk_cat expects all input tensors with the same dtype");
+ TORCH_CHECK(tensors[i].device() == expected_device, "_chunk_cat expects all inputs tensors on the same device");
+ }
+ if (have_same_ndims(tensors)) {
+ dim = maybe_wrap_dim(dim, tensors[0].dim());
+ } else {
+ TORCH_CHECK(dim >= 0, "_chunk_cat expects non-negative dim when input tensors have different ndims")
+ for(const auto i : c10::irange(tensors.size())) {
+ TORCH_CHECK(dim < tensors[i].ndimension(), "_chunk_cat expects dim < ndim for all input tensors");
+ }
+ }
+ leading_dimension_matches(tensors, dim);
+ return dim;
+}
+
} // namespace at::native
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 99f6a3a..5d90f64 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -5617,6 +5617,14 @@
SparseCPU: _sspaddmm_out_cpu
SparseCUDA: _sspaddmm_out_cuda
+- func: _chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor
+ dispatch:
+ CompositeExplicitAutograd: _chunk_cat
+
+- func: _chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!)
+ dispatch:
+ CompositeExplicitAutograd: _chunk_cat_out
+
- func: stack(Tensor[] tensors, int dim=0) -> Tensor
dispatch:
CompositeExplicitAutograd: stack
diff --git a/test/distributed/_tensor/test_dtensor_ops.py b/test/distributed/_tensor/test_dtensor_ops.py
index 66b78f0..afc6c17 100644
--- a/test/distributed/_tensor/test_dtensor_ops.py
+++ b/test/distributed/_tensor/test_dtensor_ops.py
@@ -94,6 +94,7 @@
# get full support with varying sharding specs
xfail("__getitem__"),
xfail("__rsub__"),
+ xfail("_chunk_cat"),
xfail("_native_batch_norm_legit"),
xfail("_upsample_bilinear2d_aa"),
xfail("addbmm"),
diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py
index 3e7120a..d3231df 100644
--- a/test/test_fx_experimental.py
+++ b/test/test_fx_experimental.py
@@ -1564,7 +1564,7 @@
@ops(op_db, allowed_dtypes=(torch.float,))
def test_normalize_operator_exhaustive(self, device, dtype, op):
# These ops currently don't trace in FX for various reasons (i.e. they take a list of tensors)
- fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot", "_upsample_bilinear2d_aa"}
+ fx_fail = {"cat", "stack", "hstack", "vstack", "dstack", "linalg.multi_dot", "_upsample_bilinear2d_aa", "_chunk_cat"}
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
if isinstance(op.op, torch._ops.OpOverload):
self.skipTest("normalize operator doesn't work on torch.ops")
diff --git a/test/test_mps.py b/test/test_mps.py
index 3c5e47c..63bc604 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -71,6 +71,7 @@
# Unimplemented ops
'__getitem__': [torch.float16],
'_segment_reduce': [torch.float16, torch.float32],
+ '_chunk_cat': [torch.float16, torch.float32],
'unfold_copy': [torch.float16, torch.float32], # unfold_backward is not implemented
'unfold': [torch.float16, torch.float32],
'sparse.mmreduce': [torch.float32], # csr not supported
@@ -342,6 +343,7 @@
AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = {
'__rdiv__',
+ '_chunk_cat',
'acos',
'acosh',
'all',
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 9146f40..24092c4 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -457,6 +457,7 @@
aten.zero_,
aten.zeros,
aten.zeros_like,
+ aten._chunk_cat,
aten._weight_norm_interface,
]
)
diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py
index 5c66749..bacdab3 100644
--- a/torch/_decomp/decompositions.py
+++ b/torch/_decomp/decompositions.py
@@ -1212,6 +1212,97 @@
return r
+def _pad_chunk(
+ tensors: List[Tensor],
+ dim: int,
+ num_chunks: int,
+) -> List[Tensor]:
+ padded_tensors = []
+ for tensor in tensors:
+ tensor_size = tensor.size()
+ pad_along_dim = (tensor_size[dim] + num_chunks - 1) // num_chunks * num_chunks
+ if pad_along_dim != tensor_size[dim]:
+ # Use aten.constant_pad_nd instead of copy_ for functionalization
+ pad = [0] * 2 * (tensor.ndim - dim - 1) + [
+ 0,
+ pad_along_dim - tensor_size[dim],
+ ]
+ tensor = aten.constant_pad_nd(tensor, pad, 0)
+ view_size = tensor_size[:dim] + torch.Size([num_chunks, -1])
+ padded_tensors.append(tensor.view(view_size))
+ return padded_tensors
+
+
+def have_same_ndims(tensors: List[Tensor]):
+ ndim = tensors[0].ndim
+ for tensor in tensors:
+ if tensor.ndim != ndim:
+ return False
+ return True
+
+
+def leading_dimension_matches(tensors: List[Tensor], dim: int):
+ leading_dim_sizes = tensors[0].size()[:dim]
+ for tensor in tensors:
+ torch._check(
+ tensor.size()[:dim] == leading_dim_sizes,
+ lambda: "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors",
+ )
+
+
+def _preprocess_chunk_cat_inputs(
+ tensors: List[Tensor],
+ dim: int,
+ num_chunks: int,
+):
+ torch._check(num_chunks >= 1, lambda: "_chunk_cat expects positive num_chunks")
+ torch._check(
+ len(tensors) > 0, lambda: "_chunk_cat expects a non-empty input tensor list"
+ )
+ expected_dtype = tensors[0].dtype
+ expected_device = tensors[0].device
+ for tensor in tensors:
+ torch._check(tensor.numel() > 0, lambda: "_chunk_cat expects non-empty tensor")
+ torch._check(
+ tensor.dtype == expected_dtype,
+ lambda: "_chunk_cat expects all input tensors with the same dtype",
+ )
+ torch._check(
+ tensor.device == expected_device,
+ lambda: "_chunk_cat expects all inputs tensors on the same device",
+ )
+ if have_same_ndims(tensors):
+ dim = utils.canonicalize_dim(tensors[0].dim(), dim)
+ else:
+ torch._check(
+ dim >= 0,
+ lambda: "_chunk_cat expects non-negative dim when input tensors have different ndims",
+ )
+ for tensor in tensors:
+ torch._check(
+ dim < tensor.ndim,
+ lambda: "_chunk_cat expects dim < ndim for all input tensors",
+ )
+ leading_dimension_matches(tensors, dim)
+ return dim
+
+
+@register_decomposition([aten._chunk_cat.default, aten._chunk_cat.out])
+def _chunk_cat(
+ tensors: List[Tensor],
+ dim: int,
+ num_chunks: int,
+ out: Optional[Tensor] = None,
+) -> Tensor:
+ dim = _preprocess_chunk_cat_inputs(tensors, dim, num_chunks)
+ padded_tensors = _pad_chunk(tensors, dim, num_chunks)
+ if out is None:
+ return torch.cat(padded_tensors, dim + 1)
+ else:
+ torch.cat(padded_tensors, dim + 1, out=out)
+ return out
+
+
@register_decomposition(aten.split_with_sizes)
def split_with_sizes(
self: Tensor, split_sizes: List[int], dim: int = 0
diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py
index c893554..67dbe2b 100644
--- a/torch/_dynamo/trace_rules.py
+++ b/torch/_dynamo/trace_rules.py
@@ -1229,6 +1229,7 @@
"torch._cast_Long",
"torch._cast_Short",
"torch._choose_qparams_per_tensor",
+ "torch._chunk_cat",
"torch._coalesce",
"torch._compute_linear_combination",
"torch._conj_copy",
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 4ff7266..59cfee4 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -2298,6 +2298,115 @@
for dim in range(-1, len(shape) - 1):
yield SampleInput(tensors, args=(dim,))
+
+def sample_inputs_chunk_cat(op_info, device, dtype, requires_grad, **kwargs):
+ # 1. If input tensors have different ndims, dim should be non-negative and be less than the ndims of every input tensors.
+ # If all input tensors have the same ndims, we support both negative and non-negative dim.
+ # 2. For wrapped_dim, all tensors should have the same size for 0,...,wrapped_dim-1 dimensions.
+ # No requirements for (wrapped_dim, ...)-th dimension.
+ # 3. Expect positive num_chunks
+ # 4. Expect non-empty input tensor list and each input tensor should have at least 1 element
+ make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+ same_ndim_cases = (
+ (
+ [
+ torch.Size([1, 2, 3]),
+ torch.Size([1, 2, 3]),
+ ], -1, 5
+ ),
+ (
+ [
+ torch.Size([1, 2, 3]),
+ torch.Size([1, 2, 3]),
+ ], 1, 5
+ ),
+ (
+ [
+ torch.Size([3, 3, 2, 1]),
+ torch.Size([1, 4, 2, 2]),
+ torch.Size([2, 1, 3, 3]),
+ ], 0, 2
+ ),
+ )
+ for sizes, dim, num_chunks in same_ndim_cases:
+ tensors = []
+ for size in sizes:
+ tensors.append(make_arg(size))
+ yield SampleInput(tensors, args=(dim, num_chunks))
+
+ different_ndim_case = [
+ torch.Size([2, 3, 3]),
+ torch.Size([2, 3, 1, 2]),
+ torch.Size([2, 3]),
+ torch.Size([2, 3, 2]),
+ ]
+ max_dim, num_chunks = 2, 3
+ for dim in range(max_dim):
+ tensors = []
+ for size in different_ndim_case:
+ tensors.append(make_arg(size))
+ yield SampleInput(tensors, args=(dim, num_chunks))
+
+
+def error_inputs_chunk_cat(op_info, device, **kwargs):
+ make_arg = partial(make_tensor, device=device, dtype=torch.float32)
+
+ # input tensors have different ndims but dim is negative
+ sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], -1, 3
+ tensors = [make_arg(size) for size in sizes]
+ yield ErrorInput(
+ SampleInput(tensors, args=(dim, num_chunks)),
+ error_regex='_chunk_cat expects non-negative dim when input tensors have different ndims',
+ )
+
+ # input tensors have different ndims but dim >= ndim of some input tensors
+ sizes, dim, num_chunks = [torch.Size([2, 3]), torch.Size([4,])], 1, 3
+ tensors = [make_arg(size) for size in sizes]
+ yield ErrorInput(
+ SampleInput(tensors, args=(dim, num_chunks)),
+ error_regex='_chunk_cat expects dim < ndim for all input tensors',
+ )
+
+ # some tensors have different sizes for 0, ..., dim-1 dimensions.
+ sizes, dim, num_chunks = [torch.Size([2, 3, 4]), torch.Size([4, 3])], 1, 3
+ tensors = [make_arg(size) for size in sizes]
+ yield ErrorInput(
+ SampleInput(tensors, args=(dim, num_chunks)),
+ error_regex='_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors',
+ )
+
+ # negative num_chunks
+ sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, -1
+ tensors = [make_arg(size) for size in sizes]
+ yield ErrorInput(
+ SampleInput(tensors, args=(dim, num_chunks)),
+ error_regex='_chunk_cat expects positive num_chunks',
+ )
+
+ # zero as num_chunks
+ sizes, dim, num_chunks = [torch.Size([2,]), torch.Size([3,])], 0, 0
+ tensors = [make_arg(size) for size in sizes]
+ yield ErrorInput(
+ SampleInput(tensors, args=(dim, num_chunks)),
+ error_regex='_chunk_cat expects positive num_chunks',
+ )
+
+ # empty input tensor list
+ dim, num_chunks = 0, 1
+ yield ErrorInput(
+ SampleInput([], args=(dim, num_chunks)),
+ error_regex='_chunk_cat expects a non-empty input tensor list',
+ )
+
+ # empty input tensor with 0 elements
+ sizes, dim, num_chunks = [torch.Size([0,]), torch.Size([3,])], 0, 1
+ tensors = [make_arg(size) for size in sizes]
+ yield ErrorInput(
+ SampleInput(tensors, args=(dim, num_chunks)),
+ error_regex='_chunk_cat expects non-empty tensor',
+ )
+
+
def sample_inputs_cat_concat(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -17482,6 +17591,13 @@
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
),
),
+ OpInfo('_chunk_cat',
+ dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
+ sample_inputs_func=sample_inputs_chunk_cat,
+ error_inputs_func=error_inputs_chunk_cat,
+ supports_autograd=False,
+ supports_out=True,
+ ),
OpInfo('hstack',
dtypes=all_types_and_complex_and(torch.complex32, torch.bool, torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_hstack_dstack_vstack,