import torch

from collections import namedtuple
from typing import List, Tuple
from torch import Tensor

from .cells import lstm_cell, premul_lstm_cell, premul_lstm_cell_no_bias, flat_lstm_cell


# list[list[T]] -> list[T]
def flatten_list(lst):
    result = []
    for inner in lst:
        result.extend(inner)
    return result
'''
Define a creator as a function:
    (options) -> ModelDef(inputs, params, forward, backward_setup, backward)

inputs: the inputs to the returned 'forward'. One can call
    forward(*inputs) directly.
params: List[Tensor] of all parameters with requires_grad=True.
forward: a function / graph executor / module that runs the model;
    it is invoked as forward(*inputs).
backward_setup: maps the forward output to the inputs of 'backward':
    backward_inputs = backward_setup(forward(*inputs)). If None, it is
    assumed to be the identity function.
backward: backward(*backward_inputs) performs backpropagation.
    If None, nothing happens.

fastrnns.bench times the forward and backward invocations.
An illustrative usage sketch follows the ModelDef definition below.
'''


ModelDef = namedtuple('ModelDef', [
    'inputs', 'params', 'forward', 'backward_setup', 'backward'])
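
# Illustrative sketch of how a creator's ModelDef is consumed. The real timing
# loop lives in fastrnns.bench; the call sequence below is an approximation of
# that driver based on the contract above, not a copy of it:
#
#   modeldef = pytorch_lstm_creator(seqLength=100, numLayers=1, inputSize=512,
#                                   hiddenSize=512, miniBatch=64, device='cuda')
#   forward_output = modeldef.forward(*modeldef.inputs)
#   if modeldef.backward_setup is not None:
#       backward_input = modeldef.backward_setup(forward_output)
#   else:
#       backward_input = forward_output
#   if modeldef.backward is not None:
#       modeldef.backward(*backward_input)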


def lstm_backward_setup(lstm_outputs, seed=None):
    hx, _ = lstm_outputs
    return simple_backward_setup(hx, seed)


def simple_backward_setup(output, seed=None):
    assert isinstance(output, torch.Tensor)
    if seed:
        torch.manual_seed(seed)
    grad_output = torch.randn_like(output)
    return output, grad_output


def simple_backward(output, grad_output, **kwargs):
    return output.backward(grad_output, **kwargs)


def pytorch_lstm_creator(**kwargs):
    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
    return ModelDef(
        inputs=[input, hidden],
        params=flatten_list(module.all_weights),
        forward=module,
        backward_setup=lstm_backward_setup,
        backward=simple_backward)


def lstm_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory(lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward)


def lnlstm_creator(script=True, decompose_layernorm=False, **kwargs):
    assert script is True
    from .custom_lstms import script_lnlstm
    input_size = kwargs['inputSize']
    hidden_size = kwargs['hiddenSize']
    seq_len = kwargs['seqLength']
    batch_size = kwargs['miniBatch']
    ge = script_lnlstm(input_size, hidden_size, 1,
                       decompose_layernorm=decompose_layernorm).cuda()

    input = torch.randn(seq_len, batch_size, input_size, device='cuda')
    states = [(torch.randn(batch_size, hidden_size, device='cuda'),
               torch.randn(batch_size, hidden_size, device='cuda'))]

    return ModelDef(
        inputs=[input, states],
        params=ge.parameters(),
        forward=ge,
        backward_setup=lstm_backward_setup,
        backward=simple_backward)


def dropoutlstm_creator(script=True, **kwargs):
    assert script is True
    from .custom_lstms import script_lstm, LSTMState
    input_size = kwargs['inputSize']
    hidden_size = kwargs['hiddenSize']
    seq_len = kwargs['seqLength']
    batch_size = kwargs['miniBatch']
    num_layers = kwargs['numLayers']
    ge = script_lstm(input_size, hidden_size, num_layers, dropout=True).cuda()

    input = torch.randn(seq_len, batch_size, input_size, device='cuda')
    states = [LSTMState(torch.randn(batch_size, hidden_size, device='cuda'),
                        torch.randn(batch_size, hidden_size, device='cuda'))
              for _ in range(num_layers)]
    return ModelDef(
        inputs=[input, states],
        params=ge.parameters(),
        forward=ge,
        backward_setup=lstm_backward_setup,
        backward=simple_backward)


def lstm_premul_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_premul(premul_lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward)


def lstm_premul_bias_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_premul_bias(premul_lstm_cell_no_bias, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward)


def lstm_simple_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input] + [h[0] for h in hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_simple(flat_lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward)


def lstm_multilayer_creator(script=True, **kwargs):
    input, hidden, params, _ = lstm_inputs(return_module=False, **kwargs)
    inputs = [input, hidden, flatten_list(params)]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=lstm_factory_multilayer(lstm_cell, script),
        backward_setup=lstm_backward_setup,
        backward=simple_backward)


def imagenet_cnn_creator(arch, jit=True):
    def creator(device='cuda', **kwargs):
        model = arch().to(device)
        x = torch.randn(32, 3, 224, 224, device=device)
        if jit:
            model = torch.jit.trace(model, x)
        return ModelDef(
            inputs=(x,),
            params=list(model.parameters()),
            forward=model,
            backward_setup=simple_backward_setup,
            backward=simple_backward)

    return creator


def varlen_lstm_inputs(minlen=30, maxlen=100,
                       numLayers=1, inputSize=512, hiddenSize=512,
                       miniBatch=64, return_module=False, device='cuda',
                       seed=None, **kwargs):
    if seed is not None:
        torch.manual_seed(seed)
    lengths = torch.randint(
        low=minlen, high=maxlen, size=[miniBatch],
        dtype=torch.long, device=device)
    x = [torch.randn(length, inputSize, device=device)
         for length in lengths]
    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers).to(device)

    if return_module:
        return x, lengths, (hx, cx), lstm.all_weights, lstm
    else:
        # NB: lstm.all_weights format:
        # wih, whh, bih, bhh = lstm.all_weights[layer]
        return x, lengths, (hx, cx), lstm.all_weights, None


def varlen_lstm_backward_setup(forward_output, seed=None):
    if seed:
        torch.manual_seed(seed)
    rnn_utils = torch.nn.utils.rnn
    sequences = forward_output[0]
    padded = rnn_utils.pad_sequence(sequences)
    grad = torch.randn_like(padded)
    return padded, grad


def varlen_pytorch_lstm_creator(**kwargs):
    rnn_utils = torch.nn.utils.rnn
    sequences, _, hidden, _, module = varlen_lstm_inputs(
        return_module=True, **kwargs)

    def forward(sequences, hidden):
        packed = rnn_utils.pack_sequence(sequences, enforce_sorted=False)
        out, new_hidden = module(packed, hidden)
        padded, lengths = rnn_utils.pad_packed_sequence(out)
        # XXX: It's more efficient to store the output in its padded form,
        # but that might not be conducive to loss computation.
        # Un-padding the output also makes the backward pass 2x slower...
        # return [padded[:lengths[i], i, :] for i in range(lengths.size(0))]
        return padded, new_hidden

    return ModelDef(
        inputs=[sequences, hidden],
        params=flatten_list(module.all_weights),
        forward=forward,
        backward_setup=lstm_backward_setup,
        backward=simple_backward)


def varlen_lstm_factory(cell, script):
    def dynamic_rnn(sequences: List[Tensor], hiddens: Tuple[Tensor, Tensor], wih: Tensor,
                    whh: Tensor, bih: Tensor, bhh: Tensor
                    ) -> Tuple[List[Tensor], Tuple[List[Tensor], List[Tensor]]]:
        hx, cx = hiddens
        hxs = hx.unbind(1)
        cxs = cx.unbind(1)
        # List of: (output, hx, cx)
        outputs = []
        hx_outs = []
        cx_outs = []

        for batch in range(len(sequences)):
            output = []
            hy, cy = hxs[batch], cxs[batch]
            inputs = sequences[batch].unbind(0)

            for seq_idx in range(len(inputs)):
                hy, cy = cell(
                    inputs[seq_idx].unsqueeze(0), (hy, cy), wih, whh, bih, bhh)
                output += [hy]
            outputs += [torch.stack(output)]
            hx_outs += [hy.unsqueeze(0)]
            cx_outs += [cy.unsqueeze(0)]

        return outputs, (hx_outs, cx_outs)

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


def varlen_lstm_creator(script=False, **kwargs):
    sequences, _, hidden, params, _ = varlen_lstm_inputs(
        return_module=False, **kwargs)
    inputs = [sequences, hidden] + params[0]
    return ModelDef(
        inputs=inputs,
        params=flatten_list(params),
        forward=varlen_lstm_factory(lstm_cell, script),
        backward_setup=varlen_lstm_backward_setup,
        backward=simple_backward)


# cudnn_layernorm_lstm: since cuDNN does not provide a LayerNorm LSTM, we cannot
# benchmark that lower bound directly. Instead, we benchmark only the forward
# pass by mimicking the computation of a cuDNN LSTM plus seq_len * 3 LayerNorm
# computations. This should serve as a performance lower bound for the LayerNorm
# LSTM forward pass (given that LayerNorm itself is invariant). The lower bound
# of the backward pass is hard to get since we lose the intermediate results,
# but we can still optimize the LayerNorm implementation to obtain a faster
# forward lower bound.
def layernorm_pytorch_lstm_creator(**kwargs):
    input, hidden, _, module = lstm_inputs(return_module=True, **kwargs)
    batch_size = kwargs['miniBatch']
    hidden_size = kwargs['hiddenSize']
    ln_i = torch.nn.LayerNorm(4 * hidden_size).cuda()
    ln_h = torch.nn.LayerNorm(4 * hidden_size).cuda()
    ln_c = torch.nn.LayerNorm(hidden_size).cuda()
    ln_input1 = torch.randn(batch_size, 4 * hidden_size, device='cuda')

    def forward(input, hidden):
        out, new_hidden = module(input, hidden)
        # plus (seq_len * three LayerNorm cell computations) to mimic the
        # lower bound of a LayerNorm cuDNN LSTM forward pass
        seq_len = len(input.unbind(0))
        hy, cy = new_hidden
        for i in range(seq_len):
            ln_i_output = ln_i(ln_input1)
            ln_h_output = ln_h(ln_input1)
            cy = ln_c(cy)

        return out, (hy, cy)

    return ModelDef(
        inputs=[input, hidden],
        params=flatten_list(module.all_weights),
        forward=forward,
        backward_setup=lstm_backward_setup,
        backward=None)


# input: lstm.all_weights format (wih, whh, bih, bhh = lstm.all_weights[layer])
# output: packed_weights with format
#   packed_weights[0] is wih with size (layer, 4*hiddenSize, inputSize)
#   packed_weights[1] is whh with size (layer, 4*hiddenSize, hiddenSize)
#   packed_weights[2] is bih with size (layer, 4*hiddenSize)
#   packed_weights[3] is bhh with size (layer, 4*hiddenSize)
def stack_weights(weights):
    def unzip_columns(mat):
        assert isinstance(mat, list)
        assert isinstance(mat[0], list)
        layers = len(mat)
        columns = len(mat[0])
        return [[mat[layer][col] for layer in range(layers)]
                for col in range(columns)]

    # XXX: script fns have problems indexing multidim lists, so we try to
    # avoid them by stacking tensors
    all_weights = weights
    packed_weights = [torch.stack(param)
                      for param in unzip_columns(all_weights)]
    return packed_weights


# returns: x, (hx, cx), all_weights, lstm module with all_weights as params
def lstm_inputs(seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
                miniBatch=64, dropout=0.0, return_module=False, device='cuda', seed=None):
    if seed is not None:
        torch.manual_seed(seed)
    x = torch.randn(seqLength, miniBatch, inputSize, device=device)
    hx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    cx = torch.randn(numLayers, miniBatch, hiddenSize, device=device)
    lstm = torch.nn.LSTM(inputSize, hiddenSize, numLayers, dropout=dropout)
    if 'cuda' in device:
        lstm = lstm.cuda()

    if return_module:
        return x, (hx, cx), lstm.all_weights, lstm
    else:
        # NB: lstm.all_weights format:
        # wih, whh, bih, bhh = lstm.all_weights[layer]
        return x, (hx, cx), lstm.all_weights, None


def lstm_factory(cell, script):
    def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], wih: Tensor, whh: Tensor,
                    bih: Tensor, bhh: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        inputs = input.unbind(0)
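        # hx and cx have shape (num_layers, batch, hidden); this factory only
        # drives a single layer, so take the layer-0 initial states below.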
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


# premul: we're going to premultiply the inputs & weights
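# That is, the input-to-hidden matmul for every timestep is hoisted out of the
# time loop as one large torch.matmul(input, wih.t()); each cell invocation then
# only needs the recurrent hidden-to-hidden matmul, the biases, and the
# pointwise gate math.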
def lstm_factory_premul(premul_cell, script):
    def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], wih: Tensor, whh: Tensor,
                    bih: Tensor, bhh: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        inputs = torch.matmul(input, wih.t()).unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bih, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        premul_cell = torch.jit.script(premul_cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


# premul: we're going to premultiply the inputs & weights, and add bias
def lstm_factory_premul_bias(premul_cell, script):
    def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor], wih: Tensor, whh: Tensor,
                    bih: Tensor, bhh: Tensor) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        hx, cx = hidden
        outputs = []
        # add bias for all timesteps instead of going step-by-step,
        # which results in a single reduction kernel in the backward
        # FIXME: matmul(x, y) + bias currently goes through jit AD, and the
        # backward formula in AD is not optimized for this case.
        # Workaround with mm and views.
        inpSize = input.size()
        inputs = torch.mm(input.view(-1, inpSize[2]), wih.t()) + bih
        inputs = inputs.view(inpSize[0], inpSize[1], -1).unbind(0)
        hy, cy = hx[0], cx[0]
        for seq_idx in range(len(inputs)):
            hy, cy = premul_cell(inputs[seq_idx], (hy, cy), whh, bhh)
            outputs += [hy]
        return torch.stack(outputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        premul_cell = torch.jit.script(premul_cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


# simple: flat inputs (no tuples), no list to accumulate outputs
# useful mostly for benchmarking older JIT versions
def lstm_factory_simple(cell, script):
    def dynamic_rnn(input, hx, cx, wih, whh, bih, bhh):
        hy = hx  # for scoping
        cy = cx  # for scoping
        inputs = input.unbind(0)
        for seq_idx in range(len(inputs)):
            hy, cy = cell(inputs[seq_idx], hy, cy, wih, whh, bih, bhh)
        return hy, cy

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn


def lstm_factory_multilayer(cell, script):
    def dynamic_rnn(input: Tensor, hidden: Tuple[Tensor, Tensor],
                    params: List[Tensor]) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
        params_stride = 4  # NB: this assumes that biases are there
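        # params is the flattened per-layer weight list, i.e.
        #   [wih_0, whh_0, bih_0, bhh_0, wih_1, whh_1, bih_1, bhh_1, ...]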
        hx, cx = hidden
        hy, cy = hidden  # for scoping...
        inputs, outputs = input.unbind(0), []
        for layer in range(hx.size(0)):
            hy = hx[layer]
            cy = cx[layer]
            base_idx = layer * params_stride
            wih = params[base_idx]
            whh = params[base_idx + 1]
            bih = params[base_idx + 2]
            bhh = params[base_idx + 3]
            for seq_idx in range(len(inputs)):
                hy, cy = cell(inputs[seq_idx], (hy, cy), wih, whh, bih, bhh)
                outputs += [hy]
            inputs, outputs = outputs, []
        return torch.stack(inputs), (hy.unsqueeze(0), cy.unsqueeze(0))

    if script:
        cell = torch.jit.script(cell)
        dynamic_rnn = torch.jit.script(dynamic_rnn)

    return dynamic_rnn