Add ATen Op _chunk_cat and _chunk_cat.out (#121081)
# Motivation
In the backward pass of per-parameter sharding FSDP, each rank performs a reduce-scatter to sync gradients across ranks. A rank chunks each gradient tensor into `world_size` slices along dimension 0 and concatenates all slices along dimension 1. Gradient tensors are padded before concatenation when `tensor.size(0) % world_size != 0`.
### Example 1
Consider `world_size=3` and tensors A (2x4), B (3x3), C (1x2):
Input tensors:
```
AAAA BBB CC
AAAA BBB
     BBB
```
Reduce-scatter-copy-in Output:
```
AAAABBBCC
AAAABBB00
0000BBB00
```
### Example 2
Consider `world_size=2` and tensors A (2x4), B (3x3), C (1x2), D (4x2):
Input tensors:
```
AAAA BBB CC DD
AAAA BBB    DD
     BBB    DD
            DD
```
Reduce-scatter-copy-in first pads each tensor so that its size along dimension 0 is a multiple of `world_size`:
```
AAAA BBB CC DD
AAAA BBB 00 DD
     BBB    DD
     000    DD
```
It then chunks each padded tensor along dimension 0 and concatenates the flattened chunks along dimension 1 to produce the output:
```
AAAABBBBBBCCDDDD
AAAABBB00000DDDD
```
The performance of reduce-scatter-copy-in is critical to per-parameter sharding FSDP. However, implementing reduce-scatter-copy-in by composing existing ATen ops involves `cat` and irregular `pad`, which leads to redundant data copies and unsatisfactory performance.
# PR
We provide native ATen support for reduce-scatter-copy-in, namely `_chunk_cat()`:
```
_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor
```
This PR includes the registration of `_chunk_cat` and `_chunk_cat.out`, OpInfo tests, and a basic implementation that composes existing ATen ops.
In the next PR, we will add the CUDA implementation. Compared with the baseline of composing existing ATen ops, the `_chunk_cat()` CUDA implementation improves copy bandwidth from 498 GB/s to 966 GB/s on a production benchmark.
## Requirements on input
1. If the input tensors have different ndims, `dim` should be non-negative and less than the ndim of every input tensor. If all input tensors have the same ndim, both negative and non-negative `dim` are supported.
2. With `wrapped_dim` denoting `dim` after wrapping negative values, all tensors should have the same size in dimensions 0, ..., wrapped_dim-1. There are no requirements on dimensions wrapped_dim and beyond.
3. `num_chunks` should be positive.
4. The input tensor list should be non-empty, and each input tensor should have at least 1 element.
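The four requirements can be read as a precondition check on the input shapes. The following is a minimal sketch only: `check_chunk_cat_inputs` is a hypothetical helper that takes shape tuples rather than tensors, and it does not mirror the actual checks or error messages in ATen.

```python
from math import prod


def check_chunk_cat_inputs(shapes, dim, num_chunks):
    """Validate shape tuples against the requirements above.

    Returns the wrapped (non-negative) dim, or raises ValueError.
    """
    # Requirement 3: positive num_chunks.
    if num_chunks <= 0:
        raise ValueError("num_chunks must be positive")
    # Requirement 4: non-empty list, each tensor has at least 1 element.
    if not shapes:
        raise ValueError("expected a non-empty tensor list")
    if any(prod(shape) == 0 for shape in shapes):
        raise ValueError("each tensor needs at least 1 element")

    # Requirement 1: negative dim is only allowed when all ndims agree.
    ndims = [len(shape) for shape in shapes]
    if len(set(ndims)) == 1:
        ndim = ndims[0]
        if not -ndim <= dim < ndim:
            raise ValueError("dim is out of range")
        wrapped_dim = dim % ndim
    else:
        if not 0 <= dim < min(ndims):
            raise ValueError("mixed ndims require 0 <= dim < min(ndim)")
        wrapped_dim = dim

    # Requirement 2: sizes must match in dimensions 0..wrapped_dim-1.
    for d in range(wrapped_dim):
        if len({shape[d] for shape in shapes}) != 1:
            raise ValueError(f"sizes differ in leading dimension {d}")
    return wrapped_dim


# Shapes from Example 1: nothing precedes dim 0, so any sizes are accepted.
print(check_chunk_cat_inputs([(2, 4), (3, 3), (1, 2)], dim=0, num_chunks=3))
# Same ndim everywhere, so a negative dim is wrapped (here -2 becomes 1).
print(check_chunk_cat_inputs([(2, 3, 4), (2, 5, 1)], dim=-2, num_chunks=2))
```

Sizes at and after `wrapped_dim` are deliberately left unchecked, which is what allows gradient tensors of different shapes to share one call.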
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121081
Approved by: https://github.com/albanD
diff --git a/test/test_mps.py b/test/test_mps.py
index 3c5e47c..63bc604 100644
--- a/test/test_mps.py
+++ b/test/test_mps.py
@@ -71,6 +71,7 @@
# Unimplemented ops
'__getitem__': [torch.float16],
'_segment_reduce': [torch.float16, torch.float32],
+ '_chunk_cat': [torch.float16, torch.float32],
'unfold_copy': [torch.float16, torch.float32], # unfold_backward is not implemented
'unfold': [torch.float16, torch.float32],
'sparse.mmreduce': [torch.float32], # csr not supported
@@ -342,6 +343,7 @@
AFTER_MACOS_14_0_SUPPORTED_COMPLEX_OPS = {
'__rdiv__',
+ '_chunk_cat',
'acos',
'acosh',
'all',