| import math |
| import warnings |
| |
| from torch import Tensor |
| import torch |
| |
| |
| # These no_grad_* functions are necessary as wrappers around the parts of these |
| # functions that use `with torch.no_grad()`. The JIT doesn't support context |
| # managers, so these need to be implemented as builtins. Using these wrappers |
| # lets us keep those builtins small and re-usable. |
def _no_grad_uniform_(tensor, a, b):
    """Fill ``tensor`` in place with samples from U(a, b) without autograd tracking.

    Returns whatever the in-place ``uniform_`` call returns (the tensor itself).
    """
    with torch.no_grad():
        result = tensor.uniform_(a, b)
    return result
| |
| |
def _no_grad_normal_(tensor, mean, std):
    """Fill ``tensor`` in place with samples from N(mean, std**2) without autograd tracking.

    Returns whatever the in-place ``normal_`` call returns (the tensor itself).
    """
    with torch.no_grad():
        result = tensor.normal_(mean, std)
    return result
| |
| |
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place with a normal(mean, std) distribution truncated to [a, b].

    Uses the inverse-CDF method described in
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf :
    draw uniformly between the CDF values of ``a`` and ``b``, then map back
    through the inverse normal CDF (expressed via ``erfinv``).

    Emits a warning when ``mean`` lies more than two ``std`` outside [a, b],
    where this method becomes inaccurate. Returns ``tensor``.
    """
    sqrt_two = math.sqrt(2.)

    def norm_cdf(x):
        # Standard normal CDF: Phi(x) = (1 + erf(x / sqrt(2))) / 2
        return (1. + math.erf(x / sqrt_two)) / 2.

    if mean < a - 2 * std or mean > b + 2 * std:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        cdf_lo = norm_cdf((a - mean) / std)
        cdf_hi = norm_cdf((b - mean) / std)

        # Sample U(2*cdf_lo - 1, 2*cdf_hi - 1): the argument range of erfinv
        # that corresponds to the CDF interval [cdf_lo, cdf_hi].
        tensor.uniform_(2 * cdf_lo - 1, 2 * cdf_hi - 1)
        # erfinv gives a truncated standard normal scaled by 1/sqrt(2).
        tensor.erfinv_()
        # Rescale to the requested std and shift to the requested mean.
        tensor.mul_(std * sqrt_two)
        tensor.add_(mean)
        # Guard against numerical drift just outside the bounds.
        tensor.clamp_(min=a, max=b)
    return tensor
| |
| |
def _no_grad_fill_(tensor, val):
    """Set every element of ``tensor`` to ``val`` in place without autograd tracking."""
    with torch.no_grad():
        filled = tensor.fill_(val)
    return filled
| |
| |
def _no_grad_zero_(tensor):
    """Zero ``tensor`` in place without autograd tracking."""
    with torch.no_grad():
        cleared = tensor.zero_()
    return cleared
| |
| |
def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ===================== ====================================================
    nonlinearity          gain
    ===================== ====================================================
    Linear / Identity     :math:`1`
    Conv{1,2,3}D          :math:`1`
    ConvTranspose{1,2,3}D :math:`1`
    Sigmoid               :math:`1`
    Tanh                  :math:`\frac{5}{3}`
    ReLU                  :math:`\sqrt{2}`
    Leaky Relu            :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU                  :math:`\frac{3}{4}`
    ===================== ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalisation
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function; for
            ``'leaky_relu'`` it is the negative slope (``0.01`` when omitted)
            and must be an ``int`` (not ``bool``) or ``float``.

    Raises:
        ValueError: if ``nonlinearity`` is unsupported, or ``param`` is not a
            valid negative slope for ``'leaky_relu'``.

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    unit_gain = ('linear', 'conv1d', 'conv2d', 'conv3d',
                 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
                 'sigmoid')
    if nonlinearity in unit_gain:
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            slope = 0.01
        elif isinstance(param, float) or (isinstance(param, int) and not isinstance(param, bool)):
            # bool is a subclass of int, so True/False must be rejected explicitly.
            slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + slope ** 2))
    if nonlinearity == 'selu':
        return 3.0 / 4  # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
    raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
| |
| |
def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
    r"""Fill the input Tensor in place with samples from :math:`\mathcal{U}(a, b)`.

    Tensor subclasses that implement ``__torch_function__`` may intercept the
    whole call.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution

    Returns:
        ``tensor`` after filling.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_uniform_(tensor, a, b)
    return torch.overrides.handle_torch_function(
        uniform_, (tensor,), tensor=tensor, a=a, b=b)
| |
| |
def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
    r"""Fill the input Tensor in place with samples from
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Tensor subclasses that implement ``__torch_function__`` may intercept the
    whole call.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Returns:
        ``tensor`` after filling.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_normal_(tensor, mean, std)
    return torch.overrides.handle_torch_function(
        normal_, (tensor,), tensor=tensor, mean=mean, std=std)
| |
def trunc_normal_(tensor: Tensor, mean: float = 0., std: float = 1., a: float = -2., b: float = 2.) -> Tensor:
    r"""Fill the input Tensor in place with values from a truncated normal distribution.

    The result is distributed like :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    restricted to :math:`[a, b]` — i.e. as if values outside the bounds were
    redrawn until they fell inside. The generation method works best when
    :math:`a \leq \text{mean} \leq b`; a warning is emitted when the mean is
    more than two standard deviations outside the interval.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        ``tensor`` after filling.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    filled = _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
    return filled
| |
| |
def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fill the input Tensor in place with the value :math:`\text{val}`.

    Tensor subclasses that implement ``__torch_function__`` may intercept the
    whole call.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Returns:
        ``tensor`` after filling.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_fill_(tensor, val)
    return torch.overrides.handle_torch_function(
        constant_, (tensor,), tensor=tensor, val=val)
| |
| |
def ones_(tensor: Tensor) -> Tensor:
    r"""Set every element of the input Tensor to the scalar `1`, in place.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        ``tensor`` after filling.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    return _no_grad_fill_(tensor, val=1.)
| |
| |
def zeros_(tensor: Tensor) -> Tensor:
    r"""Set every element of the input Tensor to the scalar `0`, in place.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        ``tensor`` after zeroing.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    cleared = _no_grad_zero_(tensor)
    return cleared
| |
| |
def eye_(tensor):
    r"""Overwrite a 2-D input `Tensor` with the (possibly rectangular) identity matrix.

    In `Linear` layers this preserves the identity of as many inputs as the
    weight shape allows.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Returns:
        ``tensor`` after filling.

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    n_rows, n_cols = tensor.shape
    with torch.no_grad():
        # Write directly into the existing storage via ``out=``.
        torch.eye(n_rows, n_cols, out=tensor, requires_grad=tensor.requires_grad)
    return tensor
| |
| |
def dirac_(tensor, groups=1):
    r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.

    Preserves the identity of the inputs in `Convolutional` layers, where as
    many input channels are preserved as possible. With ``groups > 1``, each
    group of output channels preserves the identity of its own inputs.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (int, optional): number of groups in the conv layer (default: 1).
            Must be positive and divide ``tensor.size(0)``.

    Returns:
        ``tensor`` after filling.

    Raises:
        ValueError: if ``tensor`` is not 3-, 4- or 5-dimensional, if
            ``groups < 1``, or if dim 0 is not divisible by ``groups``.

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    dimensions = tensor.ndimension()
    if dimensions not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    # groups == 0 would otherwise raise ZeroDivisionError below, and a negative
    # value would silently leave the tensor all zeros.
    if groups < 1:
        raise ValueError("groups must be a positive integer, got {}".format(groups))

    sizes = tensor.size()

    if sizes[0] % groups != 0:
        raise ValueError('dim 0 must be divisible by groups')

    out_chans_per_grp = sizes[0] // groups
    min_dim = min(out_chans_per_grp, sizes[1])
    # Index of the central tap along every kernel dim (temporal, spatial or
    # volumetric convolution alike).
    center = tuple(size // 2 for size in sizes[2:])

    with torch.no_grad():
        tensor.zero_()

        for g in range(groups):
            for d in range(min_dim):
                tensor[(g * out_chans_per_grp + d, d) + center] = 1
    return tensor
| |
| |
def _calculate_fan_in_and_fan_out(tensor):
    """Return ``(fan_in, fan_out)`` for a weight tensor with at least 2 dims.

    ``tensor.size(1)`` is taken as the number of input feature maps and
    ``tensor.size(0)`` as the number of output feature maps. Both are scaled by
    the product of any remaining dims (the receptive-field size).

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    if tensor.dim() < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    # math.prod is not always available and functools.reduce is not supported
    # by TorchScript, so accumulate the product by hand. For 2-D tensors the
    # slice is empty and the size stays 1.
    receptive_field_size = 1
    for extent in tensor.shape[2:]:
        receptive_field_size *= extent

    fan_in = tensor.size(1) * receptive_field_size
    fan_out = tensor.size(0) * receptive_field_size
    return fan_in, fan_out
| |
| |
def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fill the input `Tensor` using Glorot (Xavier) uniform initialization.

    Implements the method of `Understanding the difficulty of training deep
    feedforward neural networks` - Glorot, X. & Bengio, Y. (2010). Values are
    sampled from :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Returns:
        ``tensor`` after filling.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    fan_sum = float(fan_in + fan_out)
    std = gain * math.sqrt(2.0 / fan_sum)
    # U(-bound, bound) has variance bound**2 / 3, so bound = sqrt(3) * std.
    bound = math.sqrt(3.0) * std
    return _no_grad_uniform_(tensor, -bound, bound)
| |
| |
def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fill the input `Tensor` using Glorot (Xavier) normal initialization.

    Implements the method of `Understanding the difficulty of training deep
    feedforward neural networks` - Glorot, X. & Bengio, Y. (2010). Values are
    sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Returns:
        ``tensor`` after filling.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    fan_sum = float(fan_in + fan_out)
    std = gain * math.sqrt(2.0 / fan_sum)
    return _no_grad_normal_(tensor, mean=0., std=std)
| |
| |
def _calculate_correct_fan(tensor, mode):
    """Return ``fan_in`` or ``fan_out`` of ``tensor`` depending on ``mode``.

    ``mode`` is matched case-insensitively against ``'fan_in'`` / ``'fan_out'``.

    Raises:
        ValueError: if ``mode`` is neither of the supported values.
    """
    normalized = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if normalized not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(normalized, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if normalized == 'fan_in':
        return fan_in
    return fan_out
| |
| |
def kaiming_uniform_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    r"""Fill the input `Tensor` using He (Kaiming) uniform initialization.

    Implements the method of `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    Values are sampled from :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization. Zero-element tensors are returned
    unchanged with a warning. Tensor subclasses that implement
    ``__torch_function__`` may intercept the whole call.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        ``tensor`` after filling.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    overrides = torch.overrides
    if overrides.has_torch_function_variadic(tensor):
        return overrides.handle_torch_function(
            kaiming_uniform_, (tensor,),
            tensor=tensor, a=a, mode=mode, nonlinearity=nonlinearity,
        )

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor

    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    # U(-bound, bound) has variance bound**2 / 3, so bound = sqrt(3) * std.
    bound = math.sqrt(3.0) * (gain / math.sqrt(fan))
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)
| |
| |
def kaiming_normal_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
) -> Tensor:
    r"""Fill the input `Tensor` using He (Kaiming) normal initialization.

    Implements the method of `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    Values are sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization. Zero-element tensors are returned
    unchanged with a warning. Tensor subclasses that implement
    ``__torch_function__`` may intercept the whole call.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        ``tensor`` after filling.

    Raises:
        ValueError: for an unsupported ``mode`` or ``nonlinearity``.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    # Dispatch to __torch_function__ overrides exactly like kaiming_uniform_,
    # uniform_, normal_ and constant_ do, so subclasses can intercept this
    # initializer as a whole.
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_normal_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity)

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std)
| |
| |
def orthogonal_(tensor, gain=1):
    r"""Fill the input `Tensor` with a (semi) orthogonal matrix.

    Implements the method described in `Exact solutions to the nonlinear
    dynamics of learning in deep linear neural networks` - Saxe, A. et al.
    (2013). The input tensor must have at least 2 dimensions; for tensors with
    more than 2 dimensions the trailing dimensions are flattened, i.e. the
    matrix has ``tensor.size(0)`` rows and ``tensor.numel() // tensor.size(0)``
    columns. Zero-element tensors are returned unchanged. Non-contiguous
    tensors (e.g. transposed views) are supported.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor

    Returns:
        ``tensor`` after filling.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")

    if tensor.numel() == 0:
        # no-op
        return tensor
    rows = tensor.size(0)
    cols = tensor.numel() // rows
    # new_empty keeps tensor's dtype/device (replaces the legacy Tensor.new).
    flattened = tensor.new_empty((rows, cols)).normal_(0, 1)

    # QR of a tall matrix yields orthonormal columns; transpose wide inputs so
    # the result has orthonormal rows instead.
    if rows < cols:
        flattened.t_()

    # Compute the qr factorization
    q, r = torch.linalg.qr(flattened)
    # Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
    d = torch.diag(r, 0)
    ph = d.sign()
    q *= ph

    if rows < cols:
        q.t_()

    with torch.no_grad():
        # Reshape q to tensor's shape and copy_ into it: copy_ writes through
        # any stride layout, whereas tensor.view_as(q) raises for
        # non-contiguous tensors whose trailing dims cannot be merged.
        tensor.copy_(q.reshape(tensor.shape))
        tensor.mul_(gain)
    return tensor
| |
| |
def sparse_(tensor, sparsity, std=0.01):
    r"""Fill the 2-D input `Tensor` as a sparse matrix.

    Every element is first drawn from the normal distribution
    :math:`\mathcal{N}(0, \text{std}^2)` (``std`` defaults to ``0.01``), as
    described in `Deep learning via Hessian-free optimization` - Martens, J.
    (2010). Then, in each column independently, ``ceil(sparsity * rows)``
    randomly chosen rows are set to zero.

    Args:
        tensor: a 2-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Returns:
        ``tensor`` after filling.

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        # A fresh permutation per column, so the zeroed rows are chosen
        # independently for each column.
        for col_idx in range(cols):
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor
| |
| |
| # for backward compatibility |
def _make_deprecate(meth):
    """Build a deprecated alias for the in-place initializer ``meth``.

    The alias is named after ``meth`` with its trailing underscore removed.
    Calling it emits a warning (attributed to the caller via ``stacklevel=2``)
    and then forwards all positional and keyword arguments to ``meth``.
    """
    new_name = meth.__name__
    old_name = new_name[:-1]  # drop the trailing "_"
    message = "nn.init.{} is now deprecated in favor of nn.init.{}.".format(old_name, new_name)

    def deprecated_init(*args, **kwargs):
        warnings.warn(message, stacklevel=2)
        return meth(*args, **kwargs)

    deprecated_init.__name__ = old_name
    deprecated_init.__doc__ = r"""
    {old_name}(...)

    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.

    See :func:`~torch.nn.init.{new_name}` for details.""".format(
        old_name=old_name, new_name=new_name)
    return deprecated_init
| |
| |
# Deprecated public aliases without the trailing underscore. Each is built by
# _make_deprecate: calling it emits a deprecation warning and then forwards
# all arguments to the corresponding in-place (underscore-suffixed) function.
uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)