import torch
import torch.distributed as dist

from torch.autograd.function import Function

class SyncBatchNorm(Function):

    @staticmethod
    def forward(self, input, weight, bias, running_mean, running_var, eps, momentum, process_group, world_size):
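        """Normalize ``input`` using batch statistics synchronized across ranks.

        Each rank computes its local mean/invstd and element count, all-gathers
        these per-rank statistics from every process in ``process_group``, and
        combines them into global statistics with
        ``torch.batch_norm_gather_stats_with_counts`` (which is also handed
        ``running_mean``/``running_var`` and ``momentum`` so the running
        statistics get updated). The element-wise normalization itself is done
        by ``torch.batch_norm_elemt``.
        """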
        if not (
            input.is_contiguous(memory_format=torch.channels_last) or
            input.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            input = input.contiguous()
        if weight is not None:
            weight = weight.contiguous()

        size = int(input.numel() // input.size(1))
        if size == 1 and world_size < 2:
            raise ValueError('Expected more than 1 value per channel when training, got input size {}'.format(size))

        num_channels = input.shape[1]
        if input.numel() > 0:
            # calculate mean/invstd for input.
            mean, invstd = torch.batch_norm_stats(input, eps)

            count = torch.full(
                (1,),
                input.numel() // input.size(1),
                dtype=mean.dtype,
                device=mean.device
            )

            # C, C, 1 -> (2C + 1)
            combined = torch.cat([mean, invstd, count], dim=0)
        else:
            # for empty input, set stats and the count to zero. The stats with
            # zero count will be filtered out later when computing the global
            # mean & invstd, but they still need to participate in the
            # all_gather collective communication to unblock other peer
            # processes.
            combined = torch.zeros(
                2 * num_channels + 1,
                dtype=input.dtype,
                device=input.device
            )

        # Use all_gather instead of all_reduce because count could be different
        # across ranks; a simple all_reduce op cannot give correct results.
        # batch_norm_gather_stats_with_counts calculates the global mean & invstd
        # based on all gathered mean, invstd and count.
        # For the NCCL backend, use the optimized version of all_gather.
        if process_group._get_backend_name() == 'nccl':
            # world_size * (2C + 1)
            combined_size = combined.numel()
            combined_flat = torch.empty(1,
                                        combined_size * world_size,
                                        dtype=combined.dtype,
                                        device=combined.device)
            dist.all_gather_into_tensor(combined_flat, combined, process_group, async_op=False)
            combined = torch.reshape(combined_flat, (world_size, combined_size))
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)
        else:
            # world_size * (2C + 1)
            combined_list = [
                torch.empty_like(combined) for _ in range(world_size)
            ]
            dist.all_gather(combined_list, combined, process_group, async_op=False)
            combined = torch.stack(combined_list, dim=0)
            # world_size * (2C + 1) -> world_size * C, world_size * C, world_size * 1
            mean_all, invstd_all, count_all = torch.split(combined, num_channels, dim=1)

        if not torch.cuda.is_current_stream_capturing():
            # The lines below force a synchronization between CUDA and CPU, because
            # the shape of the result count_all depends on the values in mask tensor.
            # Such synchronizations break CUDA Graph capturing.
            # See https://github.com/pytorch/pytorch/issues/78549
            # FIXME: https://github.com/pytorch/pytorch/issues/78656 describes
            # a better longer-term solution.

            # remove stats from empty inputs
            mask = count_all.squeeze(-1) >= 1
            count_all = count_all[mask]
            mean_all = mean_all[mask]
            invstd_all = invstd_all[mask]

        # calculate global mean & invstd
        mean, invstd = torch.batch_norm_gather_stats_with_counts(
            input,
            mean_all,
            invstd_all,
            running_mean,
            running_var,
            momentum,
            eps,
            count_all.view(-1)
        )

        self.save_for_backward(input, weight, mean, invstd, count_all.to(torch.int32))
        self.process_group = process_group

        # apply element-wise normalization
        if input.numel() > 0:
            return torch.batch_norm_elemt(input, weight, bias, mean, invstd, eps)
        else:
            return torch.empty_like(input)

    @staticmethod
    def backward(self, grad_output):
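        """Compute gradients for synchronized batch normalization.

        The local per-channel reductions ``sum_dy`` and ``sum_dy_xmu`` from
        ``torch.batch_norm_backward_reduce`` are summed across ranks with an
        all_reduce before ``torch.batch_norm_backward_elemt`` produces
        ``grad_input``. ``grad_weight`` and ``grad_bias`` are left local, since
        the distributed training wrapper reduces parameter gradients itself.
        """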
        if not (
            grad_output.is_contiguous(memory_format=torch.channels_last) or
            grad_output.is_contiguous(memory_format=torch.channels_last_3d)
        ):
            grad_output = grad_output.contiguous()
        saved_input, weight, mean, invstd, count_tensor = self.saved_tensors
        grad_input = grad_weight = grad_bias = None
        process_group = self.process_group

        if saved_input.numel() > 0:
            # calculate local stats as well as grad_weight / grad_bias
            sum_dy, sum_dy_xmu, grad_weight, grad_bias = torch.batch_norm_backward_reduce(
                grad_output,
                saved_input,
                mean,
                invstd,
                weight,
                self.needs_input_grad[0],
                self.needs_input_grad[1],
                self.needs_input_grad[2]
            )

            if self.needs_input_grad[0]:
                # synchronizing stats used to calculate input gradient.
                num_channels = sum_dy.shape[0]
                combined = torch.cat([sum_dy, sum_dy_xmu], dim=0)
                torch.distributed.all_reduce(
                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)
                sum_dy, sum_dy_xmu = torch.split(combined, num_channels)

                # backward pass for gradient calculation
                grad_input = torch.batch_norm_backward_elemt(
                    grad_output,
                    saved_input,
                    mean,
                    invstd,
                    weight,
                    sum_dy,
                    sum_dy_xmu,
                    count_tensor
                )
            # synchronizing of grad_weight / grad_bias is not needed as distributed
            # training would handle all reduce.
            if weight is None or not self.needs_input_grad[1]:
                grad_weight = None

            if weight is None or not self.needs_input_grad[2]:
                grad_bias = None
        else:
            # This process got an empty input tensor in the forward pass.
            # Although this process can directly set grad_input as an empty
            # tensor of zeros, it still needs to participate in the collective
            # communication to unblock its peers, as other peer processes might
            # have received non-empty inputs.
            num_channels = saved_input.shape[1]
            if self.needs_input_grad[0]:
                # launch all_reduce to unblock other peer processes
                combined = torch.zeros(
                    2 * num_channels,
                    dtype=saved_input.dtype,
                    device=saved_input.device
                )
                torch.distributed.all_reduce(
                    combined, torch.distributed.ReduceOp.SUM, process_group, async_op=False)

            # Leave grad_input, grad_weight and grad_bias as None, which will be
            # interpreted by the autograd engine as Tensors full of zeros.

        return grad_input, grad_weight, grad_bias, None, None, None, None, None, None
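
# A minimal sketch of how SyncBatchNorm is typically invoked; the module that
# wraps it (e.g. ``torch.nn.SyncBatchNorm``) is responsible for setting up the
# running buffers and the process group. The shapes and hyperparameters below
# are illustrative assumptions, not part of this file:
#
#     process_group = dist.group.WORLD
#     world_size = dist.get_world_size(process_group)
#     out = SyncBatchNorm.apply(
#         input,                      # (N, C, H, W) CUDA tensor on this rank
#         weight, bias,               # per-channel affine parameters (may be None)
#         running_mean, running_var,  # running statistics buffers
#         1e-5,                       # eps
#         0.1,                        # momentum
#         process_group,
#         world_size,
#     )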

class CrossMapLRN2d(Function):

    @staticmethod
    def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
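        """Apply local response normalization across channels.

        For every spatial position, ``scale = k + alpha / size * sum(x_c ** 2)``
        with the sum taken over a window of ``size`` neighboring channels, and
        the output is ``input * scale ** (-beta)``. The per-channel window sum
        is updated incrementally rather than recomputed from scratch.
        """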
        ctx.size = size
        ctx.alpha = alpha
        ctx.beta = beta
        ctx.k = k
        ctx.scale = None

        assert input.dim() == 4

        ctx.scale = ctx.scale or input.new()
        output = input.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        output.resize_as_(input)
        ctx.scale.resize_as_(input)

        # use output storage as temporary buffer
        input_square = output
        torch.pow(input, 2, out=input_square)

        pre_pad = int((ctx.size - 1) / 2 + 1)
        pre_pad_crop = channels if pre_pad > channels else pre_pad

        scale_first = ctx.scale.select(1, 0)
        scale_first.zero_()
        # compute first feature map normalization
        for c in range(pre_pad_crop):
            scale_first.add_(input_square.select(1, c))

        # reuse computations for next feature maps normalization
        # by adding the next feature map and removing the previous
        for c in range(1, channels):
            scale_previous = ctx.scale.select(1, c - 1)
            scale_current = ctx.scale.select(1, c)
            scale_current.copy_(scale_previous)
            if c < channels - pre_pad + 1:
                square_next = input_square.select(1, c + pre_pad - 1)
                scale_current.add_(square_next, alpha=1)

            if c > pre_pad:
                square_previous = input_square.select(1, c - pre_pad)
                scale_current.add_(square_previous, alpha=-1)

        ctx.scale.mul_(ctx.alpha / ctx.size).add_(ctx.k)

        torch.pow(ctx.scale, -ctx.beta, out=output)
        output.mul_(input)

        ctx.save_for_backward(input, output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
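        """Compute the gradient of the cross-map LRN forward pass.

        The first term, ``scale ** (-beta) * grad_output``, differentiates the
        output with the scale held fixed; the correction term then subtracts
        ``(2 * alpha * beta / size) * input * accum_ratio``, where
        ``accum_ratio`` accumulates ``grad_output * output / scale`` over the
        same channel window used in the forward pass.
        """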
        input, output = ctx.saved_tensors
        grad_input = grad_output.new()

        batch_size = input.size(0)
        channels = input.size(1)
        input_height = input.size(2)
        input_width = input.size(3)

        padded_ratio = input.new(channels + ctx.size - 1, input_height,
                                 input_width)
        accum_ratio = input.new(input_height, input_width)

        cache_ratio_value = 2 * ctx.alpha * ctx.beta / ctx.size
        inversePrePad = int(ctx.size - (ctx.size - 1) / 2)

        grad_input.resize_as_(input)
        torch.pow(ctx.scale, -ctx.beta, out=grad_input).mul_(grad_output)

        padded_ratio.zero_()
        padded_ratio_center = padded_ratio.narrow(0, inversePrePad,
                                                  channels)
        for n in range(batch_size):
            torch.mul(grad_output[n], output[n], out=padded_ratio_center)
            padded_ratio_center.div_(ctx.scale[n])
            torch.sum(
                padded_ratio.narrow(0, 0, ctx.size - 1), 0, keepdim=False, out=accum_ratio)
            for c in range(channels):
                accum_ratio.add_(padded_ratio[c + ctx.size - 1])
                grad_input[n][c].addcmul_(input[n][c], accum_ratio, value=-cache_ratio_value)
                accum_ratio.add_(padded_ratio[c], alpha=-1)

        return grad_input, None, None, None, None
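
# A minimal sketch of direct use; the input shape is an illustrative
# assumption. ``CrossMapLRN2d.apply`` takes the same positional arguments as
# ``forward`` above, minus ``ctx``:
#
#     x = torch.randn(8, 16, 32, 32, requires_grad=True)
#     y = CrossMapLRN2d.apply(x, 5, 1e-4, 0.75, 1)  # size, alpha, beta, k
#     y.sum().backward()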

class BackwardHookFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, *args):
        ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
        return args

    @staticmethod
    def backward(ctx, *args):
        return args
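
# ``BackwardHookFunction`` is an identity at the autograd level: forward marks
# every argument that does not require grad as non-differentiable and returns
# the arguments unchanged, and backward passes the incoming gradients straight
# through. A hedged sketch of the effect (the values below are illustrative):
#
#     a = torch.randn(3, requires_grad=True)
#     b = torch.randn(3)                       # does not require grad
#     out_a, out_b = BackwardHookFunction.apply(a, b)
#     # out_a carries a grad_fn created by this Function, giving backward
#     # hooks a node to attach to; out_b is marked non-differentiable.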