| import numbers |
| import warnings |
| from collections import namedtuple |
| from typing import List, Tuple |
| |
| import torch |
| import torch.jit as jit |
| import torch.nn as nn |
| from torch import Tensor |
| from torch.nn import Parameter |
| |
| """ |
| Some helper classes for writing custom TorchScript LSTMs. |
| |
| Goals: |
| - Classes are easy to read, use, and extend |
- Performance of custom LSTMs approaches fused-kernel levels of speed.
| |
A few notes about features we could add to clean up the code below:
| - Support enumerate with nn.ModuleList: |
| https://github.com/pytorch/pytorch/issues/14471 |
| - Support enumerate/zip with lists: |
| https://github.com/pytorch/pytorch/issues/15952 |
| - Support overriding of class methods: |
| https://github.com/pytorch/pytorch/issues/10733 |
| - Support passing around user-defined namedtuple types for readability |
- Support slicing with range, which would make reversing lists easy.
| https://github.com/pytorch/pytorch/issues/10774 |
- Multiline type annotations, since List[List[Tuple[Tensor, Tensor]]] is verbose:
| https://github.com/pytorch/pytorch/pull/14922 |
| """ |
| |
| |
| def script_lstm( |
| input_size, |
| hidden_size, |
| num_layers, |
| bias=True, |
| batch_first=False, |
| dropout=False, |
| bidirectional=False, |
| ): |
| """Returns a ScriptModule that mimics a PyTorch native LSTM.""" |
| |
| # The following are not implemented. |
| assert bias |
| assert not batch_first |
| |
| if bidirectional: |
| stack_type = StackedLSTM2 |
| layer_type = BidirLSTMLayer |
| dirs = 2 |
| elif dropout: |
| stack_type = StackedLSTMWithDropout |
| layer_type = LSTMLayer |
| dirs = 1 |
| else: |
| stack_type = StackedLSTM |
| layer_type = LSTMLayer |
| dirs = 1 |
| |
| return stack_type( |
| num_layers, |
| layer_type, |
| first_layer_args=[LSTMCell, input_size, hidden_size], |
| other_layer_args=[LSTMCell, hidden_size * dirs, hidden_size], |
| ) |
| |
| |
| def script_lnlstm( |
| input_size, |
| hidden_size, |
| num_layers, |
| bias=True, |
| batch_first=False, |
| dropout=False, |
| bidirectional=False, |
| decompose_layernorm=False, |
| ): |
| """Returns a ScriptModule that mimics a PyTorch native LSTM.""" |
| |
| # The following are not implemented. |
| assert bias |
| assert not batch_first |
| assert not dropout |
| |
| if bidirectional: |
| stack_type = StackedLSTM2 |
| layer_type = BidirLSTMLayer |
| dirs = 2 |
| else: |
| stack_type = StackedLSTM |
| layer_type = LSTMLayer |
| dirs = 1 |
| |
| return stack_type( |
| num_layers, |
| layer_type, |
| first_layer_args=[ |
| LayerNormLSTMCell, |
| input_size, |
| hidden_size, |
| decompose_layernorm, |
| ], |
| other_layer_args=[ |
| LayerNormLSTMCell, |
| hidden_size * dirs, |
| hidden_size, |
| decompose_layernorm, |
| ], |
| ) |
| |
| |
| LSTMState = namedtuple("LSTMState", ["hx", "cx"]) |
| |
| |
| def reverse(lst: List[Tensor]) -> List[Tensor]: |
| return lst[::-1] |
| |
| |
| class LSTMCell(jit.ScriptModule): |
| def __init__(self, input_size, hidden_size): |
| super().__init__() |
| self.input_size = input_size |
| self.hidden_size = hidden_size |
| self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size)) |
| self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size)) |
| self.bias_ih = Parameter(torch.randn(4 * hidden_size)) |
| self.bias_hh = Parameter(torch.randn(4 * hidden_size)) |
| |
| @jit.script_method |
| def forward( |
| self, input: Tensor, state: Tuple[Tensor, Tensor] |
| ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: |
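        # Standard LSTM cell equations (same gate ordering as nn.LSTMCell):
        #   i = sigmoid(W_ii x + b_ii + W_hi h + b_hi)
        #   f = sigmoid(W_if x + b_if + W_hf h + b_hf)
        #   g = tanh(W_ig x + b_ig + W_hg h + b_hg)
        #   o = sigmoid(W_io x + b_io + W_ho h + b_ho)
        #   c' = f * c + i * g
        #   h' = o * tanh(c')
        # All four gates are computed with one fused matmul per input and then
        # split with chunk().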
| hx, cx = state |
| gates = ( |
| torch.mm(input, self.weight_ih.t()) |
| + self.bias_ih |
| + torch.mm(hx, self.weight_hh.t()) |
| + self.bias_hh |
| ) |
| ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) |
| |
| ingate = torch.sigmoid(ingate) |
| forgetgate = torch.sigmoid(forgetgate) |
| cellgate = torch.tanh(cellgate) |
| outgate = torch.sigmoid(outgate) |
| |
| cy = (forgetgate * cx) + (ingate * cellgate) |
| hy = outgate * torch.tanh(cy) |
| |
| return hy, (hy, cy) |
| |
| |
| class LayerNorm(jit.ScriptModule): |
| def __init__(self, normalized_shape): |
| super().__init__() |
| if isinstance(normalized_shape, numbers.Integral): |
| normalized_shape = (normalized_shape,) |
| normalized_shape = torch.Size(normalized_shape) |
| |
| # XXX: This is true for our LSTM / NLP use case and helps simplify code |
| assert len(normalized_shape) == 1 |
| |
| self.weight = Parameter(torch.ones(normalized_shape)) |
| self.bias = Parameter(torch.zeros(normalized_shape)) |
| self.normalized_shape = normalized_shape |
| |
| @jit.script_method |
| def compute_layernorm_stats(self, input): |
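        # Mean and std over the last dimension only (normalized_shape is 1-D).
        # Note: unlike nn.LayerNorm, no eps is added to sigma before dividing,
        # which keeps the example simple but can blow up on constant inputs.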
| mu = input.mean(-1, keepdim=True) |
| sigma = input.std(-1, keepdim=True, unbiased=False) |
| return mu, sigma |
| |
| @jit.script_method |
| def forward(self, input): |
| mu, sigma = self.compute_layernorm_stats(input) |
| return (input - mu) / sigma * self.weight + self.bias |
| |
| |
| class LayerNormLSTMCell(jit.ScriptModule): |
| def __init__(self, input_size, hidden_size, decompose_layernorm=False): |
| super().__init__() |
| self.input_size = input_size |
| self.hidden_size = hidden_size |
| self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size)) |
| self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size)) |
| # The layernorms provide learnable biases |
| |
| if decompose_layernorm: |
| ln = LayerNorm |
| else: |
| ln = nn.LayerNorm |
| |
| self.layernorm_i = ln(4 * hidden_size) |
| self.layernorm_h = ln(4 * hidden_size) |
| self.layernorm_c = ln(hidden_size) |
| |
| @jit.script_method |
| def forward( |
| self, input: Tensor, state: Tuple[Tensor, Tensor] |
| ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: |
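        # Same gate math as LSTMCell, but layer norm is applied to the
        # input-to-hidden and hidden-to-hidden projections and to the new cell
        # state (following "Layer Normalization", Ba et al., 2016).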
| hx, cx = state |
| igates = self.layernorm_i(torch.mm(input, self.weight_ih.t())) |
| hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t())) |
| gates = igates + hgates |
| ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1) |
| |
| ingate = torch.sigmoid(ingate) |
| forgetgate = torch.sigmoid(forgetgate) |
| cellgate = torch.tanh(cellgate) |
| outgate = torch.sigmoid(outgate) |
| |
| cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate)) |
| hy = outgate * torch.tanh(cy) |
| |
| return hy, (hy, cy) |
| |
| |
| class LSTMLayer(jit.ScriptModule): |
| def __init__(self, cell, *cell_args): |
| super().__init__() |
| self.cell = cell(*cell_args) |
| |
| @jit.script_method |
| def forward( |
| self, input: Tensor, state: Tuple[Tensor, Tensor] |
| ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: |
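        # Unroll the cell over the time dimension (dim 0 of `input`) and stack
        # the per-step outputs into a (seq_len, batch, hidden_size) tensor.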
| inputs = input.unbind(0) |
        outputs = jit.annotate(List[Tensor], [])
| for i in range(len(inputs)): |
| out, state = self.cell(inputs[i], state) |
| outputs += [out] |
| return torch.stack(outputs), state |
| |
| |
| class ReverseLSTMLayer(jit.ScriptModule): |
| def __init__(self, cell, *cell_args): |
| super().__init__() |
| self.cell = cell(*cell_args) |
| |
| @jit.script_method |
| def forward( |
| self, input: Tensor, state: Tuple[Tensor, Tensor] |
| ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: |
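        # Run the cell over the sequence back-to-front, then flip the outputs
        # so they line up with the original time order of `input`.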
| inputs = reverse(input.unbind(0)) |
| outputs = jit.annotate(List[Tensor], []) |
| for i in range(len(inputs)): |
| out, state = self.cell(inputs[i], state) |
| outputs += [out] |
| return torch.stack(reverse(outputs)), state |
| |
| |
| class BidirLSTMLayer(jit.ScriptModule): |
| __constants__ = ["directions"] |
| |
| def __init__(self, cell, *cell_args): |
| super().__init__() |
| self.directions = nn.ModuleList( |
| [ |
| LSTMLayer(cell, *cell_args), |
| ReverseLSTMLayer(cell, *cell_args), |
| ] |
| ) |
| |
| @jit.script_method |
| def forward( |
| self, input: Tensor, states: List[Tuple[Tensor, Tensor]] |
| ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: |
| # List[LSTMState]: [forward LSTMState, backward LSTMState] |
| outputs = jit.annotate(List[Tensor], []) |
| output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) |
| # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 |
| i = 0 |
| for direction in self.directions: |
| state = states[i] |
| out, out_state = direction(input, state) |
| outputs += [out] |
| output_states += [out_state] |
| i += 1 |
| return torch.cat(outputs, -1), output_states |
| |
| |
| def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args): |
| layers = [layer(*first_layer_args)] + [ |
| layer(*other_layer_args) for _ in range(num_layers - 1) |
| ] |
| return nn.ModuleList(layers) |
| |
| |
| class StackedLSTM(jit.ScriptModule): |
| __constants__ = ["layers"] # Necessary for iterating through self.layers |
| |
| def __init__(self, num_layers, layer, first_layer_args, other_layer_args): |
| super().__init__() |
| self.layers = init_stacked_lstm( |
| num_layers, layer, first_layer_args, other_layer_args |
| ) |
| |
| @jit.script_method |
| def forward( |
| self, input: Tensor, states: List[Tuple[Tensor, Tensor]] |
| ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: |
| # List[LSTMState]: One state per layer |
| output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) |
| output = input |
| # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 |
| i = 0 |
| for rnn_layer in self.layers: |
| state = states[i] |
| output, out_state = rnn_layer(output, state) |
| output_states += [out_state] |
| i += 1 |
| return output, output_states |
| |
| |
| # Differs from StackedLSTM in that its forward method takes |
| # List[List[Tuple[Tensor,Tensor]]]. It would be nice to subclass StackedLSTM |
| # except we don't support overriding script methods. |
| # https://github.com/pytorch/pytorch/issues/10733 |
| class StackedLSTM2(jit.ScriptModule): |
| __constants__ = ["layers"] # Necessary for iterating through self.layers |
| |
| def __init__(self, num_layers, layer, first_layer_args, other_layer_args): |
| super().__init__() |
| self.layers = init_stacked_lstm( |
| num_layers, layer, first_layer_args, other_layer_args |
| ) |
| |
| @jit.script_method |
| def forward( |
| self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]] |
| ) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]: |
| # List[List[LSTMState]]: The outer list is for layers, |
| # inner list is for directions. |
| output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], []) |
| output = input |
| # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 |
| i = 0 |
| for rnn_layer in self.layers: |
| state = states[i] |
| output, out_state = rnn_layer(output, state) |
| output_states += [out_state] |
| i += 1 |
| return output, output_states |
| |
| |
| class StackedLSTMWithDropout(jit.ScriptModule): |
| # Necessary for iterating through self.layers and dropout support |
| __constants__ = ["layers", "num_layers"] |
| |
| def __init__(self, num_layers, layer, first_layer_args, other_layer_args): |
| super().__init__() |
| self.layers = init_stacked_lstm( |
| num_layers, layer, first_layer_args, other_layer_args |
| ) |
| # Introduces a Dropout layer on the outputs of each LSTM layer except |
| # the last layer, with dropout probability = 0.4. |
| self.num_layers = num_layers |
| |
| if num_layers == 1: |
| warnings.warn( |
| "dropout lstm adds dropout layers after all but last " |
| "recurrent layer, it expects num_layers greater than " |
| "1, but got num_layers = 1" |
| ) |
| |
| self.dropout_layer = nn.Dropout(0.4) |
| |
| @jit.script_method |
| def forward( |
| self, input: Tensor, states: List[Tuple[Tensor, Tensor]] |
| ) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]: |
| # List[LSTMState]: One state per layer |
| output_states = jit.annotate(List[Tuple[Tensor, Tensor]], []) |
| output = input |
| # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471 |
| i = 0 |
| for rnn_layer in self.layers: |
| state = states[i] |
| output, out_state = rnn_layer(output, state) |
            # Apply dropout to the output of every layer except the last
| if i < self.num_layers - 1: |
| output = self.dropout_layer(output) |
| output_states += [out_state] |
| i += 1 |
| return output, output_states |
| |
| |
| def flatten_states(states): |
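    # Convert a per-layer list of LSTMState(hx, cx) into the [h, c] pair of
    # stacked tensors that nn.LSTM expects, each with a leading layer dimension.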
| states = list(zip(*states)) |
| assert len(states) == 2 |
| return [torch.stack(state) for state in states] |
| |
| |
| def double_flatten_states(states): |
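    # Same idea as flatten_states, but for bidirectional stacks: the nested
    # (layer, direction) structure is flattened into the leading dimension of
    # h and c, matching nn.LSTM's (num_layers * num_directions, batch, hidden).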
| # XXX: Can probably write this in a nicer way |
| states = flatten_states([flatten_states(inner) for inner in states]) |
| return [hidden.view([-1] + list(hidden.shape[2:])) for hidden in states] |
| |
| |
| def test_script_rnn_layer(seq_len, batch, input_size, hidden_size): |
| inp = torch.randn(seq_len, batch, input_size) |
| state = LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size)) |
| rnn = LSTMLayer(LSTMCell, input_size, hidden_size) |
| out, out_state = rnn(inp, state) |
| |
| # Control: pytorch native LSTM |
| lstm = nn.LSTM(input_size, hidden_size, 1) |
| lstm_state = LSTMState(state.hx.unsqueeze(0), state.cx.unsqueeze(0)) |
| for lstm_param, custom_param in zip(lstm.all_weights[0], rnn.parameters()): |
| assert lstm_param.shape == custom_param.shape |
| with torch.no_grad(): |
| lstm_param.copy_(custom_param) |
| lstm_out, lstm_out_state = lstm(inp, lstm_state) |
| |
| assert (out - lstm_out).abs().max() < 1e-5 |
| assert (out_state[0] - lstm_out_state[0]).abs().max() < 1e-5 |
| assert (out_state[1] - lstm_out_state[1]).abs().max() < 1e-5 |
| |
| |
| def test_script_stacked_rnn(seq_len, batch, input_size, hidden_size, num_layers): |
| inp = torch.randn(seq_len, batch, input_size) |
| states = [ |
| LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size)) |
| for _ in range(num_layers) |
| ] |
| rnn = script_lstm(input_size, hidden_size, num_layers) |
| out, out_state = rnn(inp, states) |
| custom_state = flatten_states(out_state) |
| |
| # Control: pytorch native LSTM |
| lstm = nn.LSTM(input_size, hidden_size, num_layers) |
| lstm_state = flatten_states(states) |
| for layer in range(num_layers): |
| custom_params = list(rnn.parameters())[4 * layer : 4 * (layer + 1)] |
| for lstm_param, custom_param in zip(lstm.all_weights[layer], custom_params): |
| assert lstm_param.shape == custom_param.shape |
| with torch.no_grad(): |
| lstm_param.copy_(custom_param) |
| lstm_out, lstm_out_state = lstm(inp, lstm_state) |
| |
| assert (out - lstm_out).abs().max() < 1e-5 |
| assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5 |
| assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5 |
| |
| |
| def test_script_stacked_bidir_rnn(seq_len, batch, input_size, hidden_size, num_layers): |
| inp = torch.randn(seq_len, batch, input_size) |
| states = [ |
| [ |
| LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size)) |
| for _ in range(2) |
| ] |
| for _ in range(num_layers) |
| ] |
| rnn = script_lstm(input_size, hidden_size, num_layers, bidirectional=True) |
| out, out_state = rnn(inp, states) |
| custom_state = double_flatten_states(out_state) |
| |
| # Control: pytorch native LSTM |
| lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True) |
| lstm_state = double_flatten_states(states) |
| for layer in range(num_layers): |
| for direct in range(2): |
| index = 2 * layer + direct |
| custom_params = list(rnn.parameters())[4 * index : 4 * index + 4] |
| for lstm_param, custom_param in zip(lstm.all_weights[index], custom_params): |
| assert lstm_param.shape == custom_param.shape |
| with torch.no_grad(): |
| lstm_param.copy_(custom_param) |
| lstm_out, lstm_out_state = lstm(inp, lstm_state) |
| |
| assert (out - lstm_out).abs().max() < 1e-5 |
| assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5 |
| assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5 |
| |
| |
| def test_script_stacked_lstm_dropout( |
| seq_len, batch, input_size, hidden_size, num_layers |
| ): |
| inp = torch.randn(seq_len, batch, input_size) |
| states = [ |
| LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size)) |
| for _ in range(num_layers) |
| ] |
| rnn = script_lstm(input_size, hidden_size, num_layers, dropout=True) |
| |
| # just a smoke test |
| out, out_state = rnn(inp, states) |
| |
| |
| def test_script_stacked_lnlstm(seq_len, batch, input_size, hidden_size, num_layers): |
| inp = torch.randn(seq_len, batch, input_size) |
| states = [ |
| LSTMState(torch.randn(batch, hidden_size), torch.randn(batch, hidden_size)) |
| for _ in range(num_layers) |
| ] |
| rnn = script_lnlstm(input_size, hidden_size, num_layers) |
| |
| # just a smoke test |
| out, out_state = rnn(inp, states) |
| |
| |
| test_script_rnn_layer(5, 2, 3, 7) |
| test_script_stacked_rnn(5, 2, 3, 7, 4) |
| test_script_stacked_bidir_rnn(5, 2, 3, 7, 4) |
| test_script_stacked_lstm_dropout(5, 2, 3, 7, 4) |
| test_script_stacked_lnlstm(5, 2, 3, 7, 4) |