| import warnings |
| from typing import List, Optional |
| |
| import torch |
| from torch._utils import _get_device_index |
| from torch.autograd import Function |
| |
| from . import comm |
| |
| |
| class Broadcast(Function): |
| @staticmethod |
| def forward(ctx, target_gpus, *inputs): |
| assert all( |
| i.device.type != "cpu" for i in inputs |
| ), "Broadcast function not implemented for CPU tensors" |
        # Normalize device specifiers (ints, strings, torch.device objects)
        # to plain integer device indices.
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.target_gpus = target_gpus
| if len(inputs) == 0: |
| return tuple() |
| ctx.num_inputs = len(inputs) |
| ctx.input_device = inputs[0].get_device() |
| outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus) |
        # needs_input_grad[0] refers to the non-tensor `target_gpus` argument,
        # so the flag for inputs[idx] lives at needs_input_grad[idx + 1]. Mark
        # every broadcast copy of a no-grad input as non-differentiable.
        non_differentiables = []
        for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
            if not input_requires_grad:
                for output in outputs:
                    non_differentiables.append(output[idx])
        ctx.mark_non_differentiable(*non_differentiables)
        # Flatten GPU-major: (gpu0_input0, ..., gpu0_inputN, gpu1_input0, ...).
        return tuple(t for tensors in outputs for t in tensors)
| |
| @staticmethod |
| def backward(ctx, *grad_outputs): |
        # The leading None is the gradient for the non-tensor `target_gpus`
        # argument; the per-GPU grads are reduce-added back onto the source.
        return (None,) + ReduceAddCoalesced.apply(
            ctx.input_device, ctx.num_inputs, *grad_outputs
        )
| |
| |
| class ReduceAddCoalesced(Function): |
| @staticmethod |
| def forward(ctx, destination, num_inputs, *grads): |
        # `grads` is flat and GPU-major: num_inputs entries per source GPU.
        # Record each group's device so backward can broadcast to the same GPUs.
        ctx.target_gpus = [
            grads[i].get_device() for i in range(0, len(grads), num_inputs)
        ]
        # Regroup into per-GPU chunks, then sum corresponding tensors across
        # GPUs onto the destination device.
        grads_ = [grads[i : i + num_inputs] for i in range(0, len(grads), num_inputs)]
        return comm.reduce_add_coalesced(grads_, destination)
| |
| @staticmethod |
| def backward(ctx, *grad_outputs): |
        # The two Nones are gradients for the non-tensor `destination` and
        # `num_inputs` arguments.
        return (
            None,
            None,
        ) + Broadcast.apply(ctx.target_gpus, *grad_outputs)
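
# Illustrative sketch, not executed on import, assuming two visible CUDA
# devices and two tensors per device (num_inputs=2), flattened GPU-major:
#
#     a0, b0 = torch.ones(2, device="cuda:0"), torch.ones(3, device="cuda:0")
#     a1, b1 = torch.ones(2, device="cuda:1"), torch.ones(3, device="cuda:1")
#     a, b = ReduceAddCoalesced.apply(0, 2, a0, b0, a1, b1)
#     # a == 2 * ones(2) and b == 2 * ones(3), both on cuda:0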
| |
| |
| class Gather(Function): |
| @staticmethod |
| def forward(ctx, target_device, dim, *inputs): |
| assert all( |
| i.device.type != "cpu" for i in inputs |
| ), "Gather function not implemented for CPU tensors" |
        if target_device == "cpu":
            # Keep the "cpu" sentinel so comm.gather produces a CPU tensor.
            ctx.target_device = "cpu"
        else:
            # Otherwise normalize the specifier to an integer device index.
            target_device = _get_device_index(target_device, True)
            ctx.target_device = target_device
| ctx.dim = dim |
| ctx.input_gpus = tuple(i.get_device() for i in inputs) |
| if all(t.dim() == 0 for t in inputs) and dim == 0: |
| inputs = tuple(t.view(1) for t in inputs) |
| warnings.warn( |
| "Was asked to gather along dimension 0, but all " |
| "input tensors were scalars; will instead unsqueeze " |
| "and return a vector." |
| ) |
| ctx.unsqueezed_scalar = True |
| else: |
| ctx.unsqueezed_scalar = False |
        # Sizes along the gather dimension, needed to split grad_output back
        # into per-GPU chunks in backward.
        ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
| return comm.gather(inputs, ctx.dim, ctx.target_device) |
| |
| @staticmethod |
| def backward(ctx, grad_output): |
| scattered_grads = Scatter.apply( |
| ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output |
| ) |
        if ctx.unsqueezed_scalar:
            # Undo the view(1) applied in forward so each grad is a scalar again.
            scattered_grads = tuple(g[0] for g in scattered_grads)
| return (None, None) + scattered_grads |
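
# Illustrative sketch, not executed on import: Gather is normally invoked via
# torch.nn.parallel.scatter_gather.gather / DataParallel. Assuming two visible
# CUDA devices:
#
#     y0 = torch.randn(2, 4, device="cuda:0")
#     y1 = torch.randn(3, 4, device="cuda:1")
#     y = Gather.apply(0, 0, y0, y1)  # concat along dim 0 -> (5, 4) on cuda:0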
| |
| |
| class Scatter(Function): |
| @staticmethod |
| def forward(ctx, target_gpus, chunk_sizes, dim, input): |
| target_gpus = [_get_device_index(x, True) for x in target_gpus] |
| ctx.dim = dim |
        # Use -1 as a sentinel for a CPU input (no source GPU).
        ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
| streams = None |
| if torch.cuda.is_available() and ctx.input_device == -1: |
| # Perform CPU to GPU copies in a background stream |
| streams = [ |
| _get_stream(torch.device("cuda", device)) for device in target_gpus |
| ] |
| outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams) |
| # Synchronize with the copy stream |
| if streams is not None: |
| for i, output in enumerate(outputs): |
| with torch.cuda.device(target_gpus[i]): |
| main_stream = torch.cuda.current_stream() |
| main_stream.wait_stream(streams[i]) |
| output.record_stream(main_stream) |
| return outputs |
| |
| @staticmethod |
| def backward(ctx, *grad_output): |
        # The three Nones are gradients for the non-tensor `target_gpus`,
        # `chunk_sizes`, and `dim` arguments.
        return None, None, None, Gather.apply(ctx.input_device, ctx.dim, *grad_output)
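
# Illustrative sketch, not executed on import: Scatter is normally invoked via
# torch.nn.parallel.scatter_gather.scatter / DataParallel. Assuming two
# visible CUDA devices and a CPU input (which exercises the copy streams):
#
#     x = torch.randn(5, 4)
#     x0, x1 = Scatter.apply([0, 1], (2, 3), 0, x)
#     # x0 has shape (2, 4) on cuda:0; x1 has shape (3, 4) on cuda:1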
| |
| |
# Background copy streams, created lazily per device index by _get_stream.
_streams: Optional[List[Optional[torch.Stream]]] = None
| |
| |
| def _get_stream(device: torch.device): |
| """Get a background stream for copying between CPU and target device.""" |
| global _streams |
    if device.type == "cpu":
        return None
    # Resolve the backend module for this device type (e.g. torch.cuda).
    device_mod = getattr(torch, device.type, None)
    if device_mod is None:
        return None
    # Lazily size the cache to the device count, then create and cache one
    # stream per device index on first use.
    if _streams is None:
        _streams = [None] * device_mod.device_count()
    if _streams[device.index] is None:
        _streams[device.index] = device_mod.Stream(device.index)
| return _streams[device.index] |
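
# Illustrative sketch, not executed on import: streams are created lazily and
# cached per device index; CPU devices have no copy stream.
#
#     s = _get_stream(torch.device("cuda", 0))  # a torch.cuda.Stream
#     assert _get_stream(torch.device("cpu")) is None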