| import torch |
| import torch.distributed as dist |
| from torch.autograd import Function |
# The two imports below are not always available, depending on the
# USE_DISTRIBUTED compile flag. Importing them here ensures that an
# ImportError is raised early if we try to use them without support.
| from torch.distributed import group, ReduceOp |
| |
| def broadcast(tensor, src, group=group.WORLD): |
| """ |
| Broadcasts the tensor to the whole group. |
| |
| ``tensor`` must have the same number of elements in all processes |
| participating in the collective. |
| |
| Arguments: |
| tensor (Tensor): Data to be sent if ``src`` is the rank of current |
| process. |
| src (int): Source rank. |
| group (ProcessGroup, optional): The process group to work on. |
| |
| Returns: |
| Tensor: Received tensor from the broadcast op. |
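
    Example:
        An illustrative sketch, not a verified doctest; it assumes an
        initialized default process group with 2 ranks, where ``rank`` is the
        current rank.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> tensor = torch.arange(2, dtype=torch.float32) + 1 + 2 * rank
        >>> tensor
        tensor([1., 2.]) # Rank 0
        tensor([3., 4.]) # Rank 1
        >>> broadcast(tensor, src=0)
        tensor([1., 2.]) # Both ranks receive rank 0's data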
| |
| """ |
| return _Broadcast.apply(src, group, tensor) |
| |
| |
| def gather(tensor, dst=0, group=group.WORLD): |
| """ |
| Gathers a list of tensors in a single process. |
| |
| Arguments: |
| tensor (Tensor): Input tensor. |
| dst (int, optional): Destination rank (default is 0). |
| group (ProcessGroup, optional): The process group to work on. |
| |
| Returns: |
        tuple[Tensor]: Tuple of appropriately-sized tensors with the gathered
            data. Ranks other than ``dst`` receive zero-filled placeholder
            tensors.
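
    Example:
        A minimal sketch (assumes an initialized 2-rank default process group;
        ``rank`` is the current rank):

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> tensor = torch.tensor([rank + 1.0])
        >>> gather(tensor, dst=0)
        (tensor([1.]), tensor([2.])) # Rank 0 (the destination)
        (tensor([0.]), tensor([0.])) # Rank 1 gets zero-filled placeholders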
| """ |
| return _Gather.apply(dst, group, tensor) |
| |
| |
| def scatter(tensors, src=0, group=group.WORLD): |
| """ |
| Scatters a list of tensors to all processes in a group. |
| |
    Each process will receive exactly one tensor, which is returned by this
    function.
| |
| Arguments: |
        tensors (list[Tensor]): List of same-sized tensors to scatter. Every
            rank must pass the list; the data is read only on the source
            rank, while other ranks use it just to infer the output shape
            and dtype.
| src (int, optional): Source rank (default is 0). |
| group (ProcessGroup, optional): The process group to work on. |
| |
| Returns: |
| Tensor: Output tensor from the scatter operation. |
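
    Example:
        An illustrative sketch (assumes an initialized 2-rank default process
        group; every rank constructs the same list here):

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> tensors = [torch.tensor([1.0]), torch.tensor([2.0])]
        >>> scatter(tensors, src=0)
        tensor([1.]) # Rank 0
        tensor([2.]) # Rank 1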
| |
| """ |
| return _Scatter.apply(src, group, *tensors) |
| |
| |
| def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD): |
| """ |
| Reduces the tensor data across all machines. |
| |
| Only the process with rank ``dst`` is going to receive the final result. |
| |
| Arguments: |
| tensor (Tensor): Input of the collective. |
| dst (int): Destination rank. |
| op (optional): One of the values from |
| ``torch.distributed.ReduceOp`` |
| enum. Specifies an operation used for element-wise reductions. |
| group (ProcessGroup, optional): The process group to work on. |
| |
| Returns: |
| Tensor: Output of the collective. |
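
    Example:
        An illustrative sketch (assumes an initialized 2-rank default process
        group; ``rank`` is the current rank). Only the destination rank's
        result is shown; the value on other ranks is unspecified.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> tensor = torch.tensor([rank + 1.0])
        >>> reduce(tensor, dst=0)
        tensor([3.]) # Rank 0 receives 1 + 2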
| |
| """ |
| return _Reduce.apply(dst, op, group, tensor) |
| |
| |
| def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD): |
| """ |
| Reduces, then scatters a list of tensors to all processes in a group. |
| |
| Arguments: |
| output (Tensor): Output tensor. |
| input_list (list[Tensor]): List of tensors to reduce and scatter. |
| op (optional): One of the values from |
| ``torch.distributed.ReduceOp`` |
| enum. Specifies an operation used for element-wise reductions. |
| group (ProcessGroup, optional): The process group to work on. |
| |
| Returns: |
| Tensor: Output of the collective. |
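
    Example:
        An illustrative sketch (assumes an initialized 2-rank default process
        group whose backend supports reduce_scatter; every rank supplies the
        same input list here):

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> input_list = [torch.tensor([1.0]), torch.tensor([2.0])]
        >>> output = torch.empty(1)
        >>> reduce_scatter(output, input_list)
        tensor([2.]) # Rank 0: sum of every rank's input_list[0]
        tensor([4.]) # Rank 1: sum of every rank's input_list[1]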
| |
| """ |
| return _Reduce_Scatter.apply(op, group, output, *input_list) |
| |
| |
| def all_gather(tensor, group=group.WORLD): |
| """ |
| Gathers tensors from the whole group in a list. |
| |
| Arguments: |
| tensor (Tensor): Tensor to be broadcast from current process. |
| group (ProcessGroup, optional): The process group to work on. |
| |
| Returns: |
        tuple[Tensor]: Output of the collective.
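
    Example:
        A minimal sketch (assumes an initialized 2-rank default process group;
        ``rank`` is the current rank):

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> tensor = torch.tensor([rank + 1.0])
        >>> all_gather(tensor)
        (tensor([1.]), tensor([2.])) # Same result on every rank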
| |
| """ |
| return _AllGather.apply(group, tensor) |
| |
| def _all_gather_base(output_tensor, input_tensor, group=group.WORLD): |
| """ |
| Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. |
| |
| Args: |
        output_tensor (Tensor): Output tensor. It should be correctly sized,
            i.e. ``world_size`` times the size of ``input_tensor`` along the
            first dimension, to hold the output of the collective.
        input_tensor (Tensor): Tensor to be gathered from the current process.
| group (ProcessGroup, optional): The process group to work on. If None, |
| the default process group will be used. |
| |
| Examples: |
| >>> # All tensors below are of torch.int64 dtype. |
        >>> # We have one process group with 2 ranks.
        >>> # xdoctest: +SKIP("requires an initialized process group")
| >>> output_tensor = torch.zeros(2, dtype=torch.int64) |
| >>> output_tensor |
        tensor([0, 0]) # Rank 0 and 1
| >>> tensor = torch.arange(1, dtype=torch.int64) + 1 + rank |
| >>> tensor |
| tensor([1]) # Rank 0 |
| tensor([2]) # Rank 1 |
        >>> output_tensor = _all_gather_base(output_tensor, tensor)
| >>> output_tensor |
        tensor([1, 2]) # Rank 0
        tensor([1, 2]) # Rank 1
| |
| .. warning:: |
| `_all_gather_base` is experimental and subject to change. |
| It is the caller's responsibility to ensure the output_tensor |
| is correctly sized. |
| |
| """ |
| return _AllGatherBase.apply(output_tensor, input_tensor, group) |
| |
| |
| def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD): |
| """ |
    Each process scatters a list of input tensors to all processes in the
    group and returns the gathered list of tensors in the output list.
| |
| Arguments: |
        output_tensor_list (list[Tensor]): List of tensors to gather, one per rank.
| input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. |
| group (ProcessGroup, optional): The process group to work on. |
| |
| Returns: |
        tuple[Tensor]: Output of the collective.
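
    Example:
        An illustrative sketch (assumes an initialized 2-rank default process
        group; ``rank`` is the current rank):

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> inputs = [torch.tensor([2.0 * rank + i]) for i in range(2)]
        >>> outputs = [torch.empty(1) for _ in range(2)]
        >>> all_to_all(outputs, inputs)
        (tensor([0.]), tensor([2.])) # Rank 0
        (tensor([1.]), tensor([3.])) # Rank 1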
| |
| """ |
| return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list) |
| |
| |
| def all_to_all_single( |
| output, |
| input, |
| output_split_sizes=None, |
| input_split_sizes=None, |
| group=group.WORLD, |
| ): |
| """ |
    Each process splits the input tensor and then scatters the split list to
    all processes in the group. It then concatenates the received tensors
    from all processes in the group and returns a single output tensor.
| |
| Arguments: |
| output (Tensor): Gathered concatenated output tensor. |
| input (Tensor): Input tensor to scatter. |
        output_split_sizes (list[int], optional): Output split sizes for dim 0.
            If None or empty, dim 0 of the ``output`` tensor must be evenly
            divisible by ``world_size``.
        input_split_sizes (list[int], optional): Input split sizes for dim 0.
            If None or empty, dim 0 of the ``input`` tensor must be evenly
            divisible by ``world_size``.
        group (ProcessGroup, optional): The process group to work on.
| |
| Returns: |
| Tensor: Output of the collective. |
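
    Example:
        An illustrative sketch (assumes an initialized 2-rank default process
        group; ``rank`` is the current rank). No split sizes are passed, so
        dim 0 is split evenly.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> input = torch.arange(2, dtype=torch.float32) + 2 * rank
        >>> input
        tensor([0., 1.]) # Rank 0
        tensor([2., 3.]) # Rank 1
        >>> output = torch.empty(2)
        >>> all_to_all_single(output, input)
        tensor([0., 2.]) # Rank 0
        tensor([1., 3.]) # Rank 1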
| |
| """ |
| return _AlltoAllSingle.apply( |
| group, output, output_split_sizes, input_split_sizes, input |
| ) |
| |
| |
| def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD): |
| """ |
| Reduces the tensor data across all machines in such a way that all get |
| the final result. |
| |
| After the call the returned tensor is going to be bitwise |
| identical in all processes. |
| |
| Arguments: |
| tensor (Tensor): Input of the collective. |
| op (optional): One of the values from |
| ``torch.distributed.ReduceOp`` |
| enum. Specifies an operation used for element-wise reductions. |
| group (ProcessGroup, optional): The process group to work on. |
| |
| Returns: |
| Tensor: Output of the collective |
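
    Example:
        An illustrative sketch (assumes an initialized 2-rank default process
        group; ``rank`` is the current rank). It also shows gradients flowing
        back through the collective, which is the point of this
        autograd-enabled variant.

        >>> # xdoctest: +SKIP("requires an initialized process group")
        >>> tensor = torch.tensor([rank + 1.0], requires_grad=True)
        >>> out = all_reduce(tensor)  # out is tensor([3.]) on both ranks
        >>> out.sum().backward()      # the backward pass is itself an all_reduce
        >>> tensor.grad
        tensor([2.]) # grad of the local loss, all-reduced across both ranks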
| |
| """ |
| return _AllReduce.apply(op, group, tensor) |
| |
| |
| class _Broadcast(Function): |
| @staticmethod |
| def forward(ctx, src, group, tensor): |
| ctx.src = src |
| ctx.group = group |
| ctx.rank = dist.get_rank() |
| # torch.distributed makes all the calls in place |
| # we allocate new tensors to avoid this |
| tensor = tensor.clone() |
| dist.broadcast(tensor, src, group=group) |
| return tensor |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output) |
| if ctx.src != ctx.rank: |
| gx.zero_() |
| return (None, None, gx) |
| |
| |
| class _Gather(Function): |
| @staticmethod |
| def forward(ctx, dst, group, tensor): |
| ctx.dst = dst |
| ctx.group = group |
        # Need to create a list of tensors for the aggregation; its length
        # comes from the group size. The input tensor is assumed to be
        # correctly (identically) sized on every rank for the gather.
| tensor_list = [ |
| torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group)) |
| ] |
| |
| tensor = tensor.contiguous() |
| if dist.get_rank(group=group) == dst: |
| dist.gather(tensor, tensor_list, dst, group=group) |
| else: |
| dist.gather(tensor, None, dst, group=group) |
| return tuple(tensor_list) |
| |
| @staticmethod |
| def backward(ctx, *grad_outputs): |
| return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),) |
| |
| |
| class _Scatter(Function): |
| @staticmethod |
| def forward(ctx, src, group, *tensors): |
| ctx.src = src |
| ctx.group = group |
| assert all(t.size() == tensors[0].size() for t in tensors) |
| output = torch.zeros_like(tensors[0]) |
| if dist.get_rank(group=group) == src: |
| dist.scatter(output, list(tensors), src, group=group) |
| else: |
| dist.scatter(output, None, src, group=group) |
| return output |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output) |
| |
| |
| class _Reduce(Function): |
| @staticmethod |
| def forward(ctx, src, op, group, tensor): |
| ctx.src = src |
| ctx.group = group |
| tensor = tensor.clone() |
| dist.reduce(tensor, src, op=op, group=group) |
| return tensor |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),) |
| |
| |
| class _Reduce_Scatter(Function): |
| @staticmethod |
| def forward(ctx, op, group, tensor, *input_tensor_list): |
| ctx.group = group |
| input_tensor_list = tuple(t.contiguous() for t in input_tensor_list) |
| dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group) |
| return tensor |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return (None, None, None) + _AllGather.apply(ctx.group, grad_output) |
| |
| |
| class _AllGather(Function): |
| @staticmethod |
| def forward(ctx, group, tensor): |
| # Need contiguous tensors for collectives. |
| tensor = tensor.contiguous() |
| |
| ctx.group = group |
| out_tensor_list = [ |
| torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group)) |
| ] |
| |
| dist.all_gather(out_tensor_list, tensor, group=group) |
| return tuple(out_tensor_list) |
| |
| @staticmethod |
| def backward(ctx, *grad_outputs): |
        if dist.get_backend(group=ctx.group) == dist.Backend.NCCL:
| rank = dist.get_rank() |
| gx = torch.empty_like(grad_outputs[rank]) |
| _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs) |
| else: |
            # Since many backends don't support ReduceScatter, we use AlltoAll
            # followed by a sum to emulate the ReduceScatter behavior.
| tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs] |
| gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) |
| gx = torch.sum(torch.stack(gxs), dim=0) |
| return (None, gx) |
| |
| class _AllGatherBase(Function): |
| @staticmethod |
| def forward(ctx, output_tensor, input_tensor, group): |
| ctx.group = group |
| dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group) |
| return output_tensor |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
        if dist.get_backend(group=ctx.group) == dist.Backend.NCCL:
| world_size = dist.get_world_size(group=ctx.group) |
| out_size = list(grad_output.size()) |
| if out_size[0] % world_size != 0: |
| raise RuntimeError( |
| f'Tensor with dimensions: {out_size} does ' |
| f'not have first dimension divisible by world_size: {world_size}' |
| ) |
| out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group) |
| gx = torch.empty(out_size, device=grad_output.device, dtype=grad_output.dtype) |
| dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group) |
| else: |
| raise RuntimeError("Backend not supported!") |
| return (None, gx, None) |
| |
| class _AlltoAll(Function): |
| @staticmethod |
| def forward(ctx, group, out_tensor_list, *tensors): |
| ctx.group = group |
| ctx.input_tensor_size_list = [ |
| tensors[i].size() for i in range(dist.get_world_size(group=group)) |
| ] |
| my_rank = dist.get_rank(group=group) |
| tensors = tuple(t.contiguous() for t in tensors) |
        # Implement it by means of scatter/gather on Gloo, since its async
        # send/recv operations have issues.
        if dist.get_backend(group=group) == dist.Backend.GLOO:
| for i in range(dist.get_world_size(group=group)): |
| to_send = None |
| if i == my_rank: |
| to_send = list(tensors) |
| dist.scatter(out_tensor_list[i], to_send, i, group=group) |
| else: |
| dist.all_to_all( |
| out_tensor_list, |
| list(tensors), |
| group=group, |
| ) |
| return tuple(out_tensor_list) |
| |
| @staticmethod |
| def backward(ctx, *grad_outputs): |
| tensor_list = [ |
| torch.empty(size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype) |
| for size in ctx.input_tensor_size_list |
| ] |
| return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) |
| |
| |
| class _AlltoAllSingle(Function): |
| @staticmethod |
| def forward(ctx, group, output, output_split_sizes, input_split_sizes, input): |
| ctx.group = group |
| ctx.input_size = input.size() |
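        # Note: the split sizes are intentionally swapped when saved for
        # backward. The gradient all_to_all_single runs in the reverse
        # direction, so the forward input splits become the backward output
        # splits and vice versa.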
| ctx.output_split_sizes = input_split_sizes |
| ctx.input_split_sizes = output_split_sizes |
| dist.all_to_all_single( |
| output, |
| input, |
| output_split_sizes=output_split_sizes, |
| input_split_sizes=input_split_sizes, |
| group=group, |
| ) |
| return output |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| tensor = torch.empty(ctx.input_size, device=grad_output.device, dtype=grad_output.dtype) |
| return (None, None, None, None) + ( |
| _AlltoAllSingle.apply( |
| ctx.group, |
| tensor, |
| ctx.output_split_sizes, |
| ctx.input_split_sizes, |
| grad_output.contiguous(), |
| ), |
| ) |
| |
| |
| class _AllReduce(Function): |
| @staticmethod |
| def forward(ctx, op, group, tensor): |
| ctx.group = group |
| ctx.op = op |
| tensor = tensor.clone() |
| dist.all_reduce(tensor, op=op, group=group) |
| return tensor |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),) |