import torch
import torch.nn as nn
from torch.nn import Parameter
import torch.jit as jit
import warnings
from collections import namedtuple
from typing import List, Tuple
from torch import Tensor
import numbers

'''
Some helper classes for writing custom TorchScript LSTMs.

Goals:
- Classes are easy to read, use, and extend
- Performance of custom LSTMs approaches fused-kernel levels of speed.

A few notes about features we could add to clean up the code below:
- Support enumerate with nn.ModuleList:
  https://github.com/pytorch/pytorch/issues/14471
- Support enumerate/zip with lists:
  https://github.com/pytorch/pytorch/issues/15952
- Support overriding of class methods:
  https://github.com/pytorch/pytorch/issues/10733
- Support passing around user-defined namedtuple types for readability
- Support slicing w/ range. It enables reversing lists easily.
  https://github.com/pytorch/pytorch/issues/10774
- Multiline type annotations. List[List[Tuple[Tensor, Tensor]]] is verbose.
  https://github.com/pytorch/pytorch/pull/14922
'''


def script_lstm(input_size, hidden_size, num_layers, bias=True,
                batch_first=False, dropout=False, bidirectional=False):
    '''Returns a ScriptModule that mimics a PyTorch native LSTM.'''

    # The following are not implemented.
    assert bias
    assert not batch_first

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    elif dropout:
        stack_type = StackedLSTMWithDropout
        layer_type = LSTMLayer
        dirs = 1
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(num_layers, layer_type,
                      first_layer_args=[LSTMCell, input_size, hidden_size],
                      other_layer_args=[LSTMCell, hidden_size * dirs,
                                        hidden_size])


def script_lnlstm(input_size, hidden_size, num_layers, bias=True,
                  batch_first=False, dropout=False, bidirectional=False,
                  decompose_layernorm=False):
    '''Returns a ScriptModule that mimics a PyTorch native LSTM, but with
    layer normalization applied to the gate and cell computations.'''

    # The following are not implemented.
    assert bias
    assert not batch_first
    assert not dropout

    if bidirectional:
        stack_type = StackedLSTM2
        layer_type = BidirLSTMLayer
        dirs = 2
    else:
        stack_type = StackedLSTM
        layer_type = LSTMLayer
        dirs = 1

    return stack_type(num_layers, layer_type,
                      first_layer_args=[LayerNormLSTMCell, input_size,
                                        hidden_size, decompose_layernorm],
                      other_layer_args=[LayerNormLSTMCell, hidden_size * dirs,
                                        hidden_size, decompose_layernorm])


LSTMState = namedtuple('LSTMState', ['hx', 'cx'])


def reverse(lst: List[Tensor]) -> List[Tensor]:
    return lst[::-1]


class LSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size):
        super(LSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        self.bias_ih = Parameter(torch.randn(4 * hidden_size))
        self.bias_hh = Parameter(torch.randn(4 * hidden_size))

    @jit.script_method
    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
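        # One fused matmul per projection computes all four gate
        # pre-activations at once; chunk(4, 1) then splits them into the
        # input, forget, cell, and output gates.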
        gates = (torch.mm(input, self.weight_ih.t()) + self.bias_ih +
                 torch.mm(hx, self.weight_hh.t()) + self.bias_hh)
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = (forgetgate * cx) + (ingate * cellgate)
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


class LayerNorm(jit.ScriptModule):
    def __init__(self, normalized_shape):
        super(LayerNorm, self).__init__()
        if isinstance(normalized_shape, numbers.Integral):
            normalized_shape = (normalized_shape,)
        normalized_shape = torch.Size(normalized_shape)

        # XXX: This is true for our LSTM / NLP use case and helps simplify code
        assert len(normalized_shape) == 1

        self.weight = Parameter(torch.ones(normalized_shape))
        self.bias = Parameter(torch.zeros(normalized_shape))
        self.normalized_shape = normalized_shape

    @jit.script_method
    def compute_layernorm_stats(self, input):
        mu = input.mean(-1, keepdim=True)
        sigma = input.std(-1, keepdim=True, unbiased=False)
        return mu, sigma

    @jit.script_method
    def forward(self, input):
        mu, sigma = self.compute_layernorm_stats(input)
        return (input - mu) / sigma * self.weight + self.bias
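
# Note: unlike nn.LayerNorm, the decomposed LayerNorm above omits the eps
# term inside the normalization, so it can divide by zero on constant rows;
# that is acceptable for this benchmarking use case but worth keeping in mind.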


class LayerNormLSTMCell(jit.ScriptModule):
    def __init__(self, input_size, hidden_size, decompose_layernorm=False):
        super(LayerNormLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.weight_ih = Parameter(torch.randn(4 * hidden_size, input_size))
        self.weight_hh = Parameter(torch.randn(4 * hidden_size, hidden_size))
        # The layernorms provide learnable biases

        if decompose_layernorm:
            ln = LayerNorm
        else:
            ln = nn.LayerNorm

        self.layernorm_i = ln(4 * hidden_size)
        self.layernorm_h = ln(4 * hidden_size)
        self.layernorm_c = ln(hidden_size)

    @jit.script_method
    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = state
        igates = self.layernorm_i(torch.mm(input, self.weight_ih.t()))
        hgates = self.layernorm_h(torch.mm(hx, self.weight_hh.t()))
        gates = igates + hgates
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)

        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        cy = self.layernorm_c((forgetgate * cx) + (ingate * cellgate))
        hy = outgate * torch.tanh(cy)

        return hy, (hy, cy)


class LSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super(LSTMLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
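        # Unbind along the time dimension so each step is a
        # (batch, input_size) slice fed to the cell one tick at a time.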
        inputs = input.unbind(0)
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(outputs), state


class ReverseLSTMLayer(jit.ScriptModule):
    def __init__(self, cell, *cell_args):
        super(ReverseLSTMLayer, self).__init__()
        self.cell = cell(*cell_args)

    @jit.script_method
    def forward(self, input: Tensor, state: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        inputs = reverse(input.unbind(0))
        outputs = jit.annotate(List[Tensor], [])
        for i in range(len(inputs)):
            out, state = self.cell(inputs[i], state)
            outputs += [out]
        return torch.stack(reverse(outputs)), state


class BidirLSTMLayer(jit.ScriptModule):
    __constants__ = ['directions']

    def __init__(self, cell, *cell_args):
        super(BidirLSTMLayer, self).__init__()
        self.directions = nn.ModuleList([
            LSTMLayer(cell, *cell_args),
            ReverseLSTMLayer(cell, *cell_args),
        ])

    @jit.script_method
    def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: [forward LSTMState, backward LSTMState]
        outputs = jit.annotate(List[Tensor], [])
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for direction in self.directions:
            state = states[i]
            out, out_state = direction(input, state)
            outputs += [out]
            output_states += [out_state]
            i += 1
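        # Concatenate forward and reverse features along the last dim,
        # matching nn.LSTM's bidirectional output layout.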
        return torch.cat(outputs, -1), output_states


def init_stacked_lstm(num_layers, layer, first_layer_args, other_layer_args):
    layers = [layer(*first_layer_args)] + [layer(*other_layer_args)
                                           for _ in range(num_layers - 1)]
    return nn.ModuleList(layers)


class StackedLSTM(jit.ScriptModule):
    __constants__ = ['layers']  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super(StackedLSTM, self).__init__()
        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
                                        other_layer_args)

    @jit.script_method
    def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states


# Differs from StackedLSTM in that its forward method takes
# List[List[Tuple[Tensor, Tensor]]]. It would be nice to subclass StackedLSTM,
# except we don't support overriding script methods.
# https://github.com/pytorch/pytorch/issues/10733
class StackedLSTM2(jit.ScriptModule):
    __constants__ = ['layers']  # Necessary for iterating through self.layers

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super(StackedLSTM2, self).__init__()
        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
                                        other_layer_args)

    @jit.script_method
    def forward(self, input: Tensor, states: List[List[Tuple[Tensor, Tensor]]]) -> Tuple[Tensor, List[List[Tuple[Tensor, Tensor]]]]:
        # List[List[LSTMState]]: The outer list is for layers,
        # the inner list is for directions.
        output_states = jit.annotate(List[List[Tuple[Tensor, Tensor]]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            output_states += [out_state]
            i += 1
        return output, output_states


class StackedLSTMWithDropout(jit.ScriptModule):
    # Necessary for iterating through self.layers and dropout support
    __constants__ = ['layers', 'num_layers']

    def __init__(self, num_layers, layer, first_layer_args, other_layer_args):
        super(StackedLSTMWithDropout, self).__init__()
        self.layers = init_stacked_lstm(num_layers, layer, first_layer_args,
                                        other_layer_args)
        # Introduces a Dropout layer on the outputs of each LSTM layer except
        # the last layer, with dropout probability = 0.4.
        self.num_layers = num_layers

        if num_layers == 1:
            warnings.warn("dropout LSTM adds dropout layers after all but the "
                          "last recurrent layer, so it expects num_layers "
                          "greater than 1, but got num_layers = 1")

        self.dropout_layer = nn.Dropout(0.4)

    @jit.script_method
    def forward(self, input: Tensor, states: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tuple[Tensor, Tensor]]]:
        # List[LSTMState]: One state per layer
        output_states = jit.annotate(List[Tuple[Tensor, Tensor]], [])
        output = input
        # XXX: enumerate https://github.com/pytorch/pytorch/issues/14471
        i = 0
        for rnn_layer in self.layers:
            state = states[i]
            output, out_state = rnn_layer(output, state)
            # Apply the dropout layer except on the last layer
            if i < self.num_layers - 1:
                output = self.dropout_layer(output)
            output_states += [out_state]
            i += 1
        return output, output_states


def flatten_states(states):
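    # Converts a per-layer list of LSTMState tuples into the stacked
    # (num_layers, batch, hidden_size) hx/cx tensors that nn.LSTM expects.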
    states = list(zip(*states))
    assert len(states) == 2
    return [torch.stack(state) for state in states]


def double_flatten_states(states):
    # XXX: Can probably write this in a nicer way
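    # Same idea as flatten_states, but also folds the per-direction axis into
    # the leading dim to match nn.LSTM's (num_layers * num_directions, ...)
    # state layout.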
    states = flatten_states([flatten_states(inner) for inner in states])
    return [hidden.view([-1] + list(hidden.shape[2:])) for hidden in states]


def test_script_rnn_layer(seq_len, batch, input_size, hidden_size):
    inp = torch.randn(seq_len, batch, input_size)
    state = LSTMState(torch.randn(batch, hidden_size),
                      torch.randn(batch, hidden_size))
    rnn = LSTMLayer(LSTMCell, input_size, hidden_size)
    out, out_state = rnn(inp, state)

    # Control: PyTorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, 1)
    lstm_state = LSTMState(state.hx.unsqueeze(0), state.cx.unsqueeze(0))
    for lstm_param, custom_param in zip(lstm.all_weights[0], rnn.parameters()):
        assert lstm_param.shape == custom_param.shape
        with torch.no_grad():
            lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (out_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (out_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_rnn(seq_len, batch, input_size, hidden_size,
                            num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [LSTMState(torch.randn(batch, hidden_size),
                        torch.randn(batch, hidden_size))
              for _ in range(num_layers)]
    rnn = script_lstm(input_size, hidden_size, num_layers)
    out, out_state = rnn(inp, states)
    custom_state = flatten_states(out_state)

    # Control: PyTorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers)
    lstm_state = flatten_states(states)
    for layer in range(num_layers):
        custom_params = list(rnn.parameters())[4 * layer: 4 * (layer + 1)]
        for lstm_param, custom_param in zip(lstm.all_weights[layer],
                                            custom_params):
            assert lstm_param.shape == custom_param.shape
            with torch.no_grad():
                lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_bidir_rnn(seq_len, batch, input_size, hidden_size,
                                  num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [[LSTMState(torch.randn(batch, hidden_size),
                         torch.randn(batch, hidden_size))
               for _ in range(2)]
              for _ in range(num_layers)]
    rnn = script_lstm(input_size, hidden_size, num_layers, bidirectional=True)
    out, out_state = rnn(inp, states)
    custom_state = double_flatten_states(out_state)

    # Control: PyTorch native LSTM
    lstm = nn.LSTM(input_size, hidden_size, num_layers, bidirectional=True)
    lstm_state = double_flatten_states(states)
    for layer in range(num_layers):
        for direct in range(2):
            index = 2 * layer + direct
            custom_params = list(rnn.parameters())[4 * index: 4 * index + 4]
            for lstm_param, custom_param in zip(lstm.all_weights[index],
                                                custom_params):
                assert lstm_param.shape == custom_param.shape
                with torch.no_grad():
                    lstm_param.copy_(custom_param)
    lstm_out, lstm_out_state = lstm(inp, lstm_state)

    assert (out - lstm_out).abs().max() < 1e-5
    assert (custom_state[0] - lstm_out_state[0]).abs().max() < 1e-5
    assert (custom_state[1] - lstm_out_state[1]).abs().max() < 1e-5


def test_script_stacked_lstm_dropout(seq_len, batch, input_size, hidden_size,
                                     num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [LSTMState(torch.randn(batch, hidden_size),
                        torch.randn(batch, hidden_size))
              for _ in range(num_layers)]
    rnn = script_lstm(input_size, hidden_size, num_layers, dropout=True)

    # just a smoke test
    out, out_state = rnn(inp, states)


def test_script_stacked_lnlstm(seq_len, batch, input_size, hidden_size,
                               num_layers):
    inp = torch.randn(seq_len, batch, input_size)
    states = [LSTMState(torch.randn(batch, hidden_size),
                        torch.randn(batch, hidden_size))
              for _ in range(num_layers)]
    rnn = script_lnlstm(input_size, hidden_size, num_layers)

    # just a smoke test
    out, out_state = rnn(inp, states)


test_script_rnn_layer(5, 2, 3, 7)
test_script_stacked_rnn(5, 2, 3, 7, 4)
test_script_stacked_bidir_rnn(5, 2, 3, 7, 4)
test_script_stacked_lstm_dropout(5, 2, 3, 7, 4)
test_script_stacked_lnlstm(5, 2, 3, 7, 4)