import argparse
import torch
import torch.nn as nn

from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator
from .runner import get_nn_runners


def barf():
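    """Drop into pdb so a failing comparison can be inspected interactively."""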
    import pdb
    pdb.set_trace()


def assertEqual(tensor, expected, threshold=0.001):
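    """Recursively compare tensors (or lists/tuples of tensors) and drop into
    the debugger via barf() if any element-wise difference exceeds threshold."""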
    if isinstance(tensor, (list, tuple)):
        for t, e in zip(tensor, expected):
            assertEqual(t, e, threshold)
    else:
        if (tensor - expected).abs().max() > threshold:
            barf()


def filter_requires_grad(tensors):
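    """Keep only the tensors that require grad (and so will have .grad populated)."""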
    return [t for t in tensors if t.requires_grad]


def test_rnns(experim_creator, control_creator, check_grad=True, verbose=False,
              seqLength=100, numLayers=1, inputSize=512, hiddenSize=512,
              miniBatch=64, device='cuda', seed=17):
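    """Build a control model and an experimental model from identically seeded
    creators, then check that inputs, params, outputs and (if check_grad)
    gradients agree within assertEqual's tolerance."""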
    creator_args = dict(seqLength=seqLength, numLayers=numLayers,
                        inputSize=inputSize, hiddenSize=hiddenSize,
                        miniBatch=miniBatch, device=device, seed=seed)

    print("Setting up...")
    control = control_creator(**creator_args)
    experim = experim_creator(**creator_args)

    # Precondition
    assertEqual(experim.inputs, control.inputs)
    assertEqual(experim.params, control.params)

    print("Checking outputs...")
    control_outputs = control.forward(*control.inputs)
    experim_outputs = experim.forward(*experim.inputs)
    assertEqual(experim_outputs, control_outputs)

| print("Checking grads...") |
| assert control.backward_setup is not None |
| assert experim.backward_setup is not None |
| assert control.backward is not None |
| assert experim.backward is not None |
| control_backward_inputs = control.backward_setup(control_outputs, seed) |
| experim_backward_inputs = experim.backward_setup(experim_outputs, seed) |
| |
| control.backward(*control_backward_inputs) |
| experim.backward(*experim_backward_inputs) |
| |
| control_grads = [p.grad for p in control.params] |
| experim_grads = [p.grad for p in experim.params] |
| assertEqual(experim_grads, control_grads) |
| |
| if verbose: |
| print(experim.forward.graph_for(*experim.inputs)) |
| print('') |


def test_vl_py(**test_args):
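    """Check the variable-length vl_py runner against the variable-length
    PyTorch LSTM control."""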
    # XXX: This compares vl_py with vl_lstm. The two don't produce their
    # outputs in the same form, so the outputs are padded/concatenated below
    # and the result isn't an apples-to-apples comparison right now.
    control_creator = varlen_pytorch_lstm_creator
    name, experim_creator, context = get_nn_runners('vl_py')[0]
    with context():
        print('testing {}...'.format(name))
        creator_keys = [
            'seqLength', 'numLayers', 'inputSize',
            'hiddenSize', 'miniBatch', 'device', 'seed'
        ]
        creator_args = {key: test_args[key] for key in creator_keys}

        print("Setting up...")
        control = control_creator(**creator_args)
        experim = experim_creator(**creator_args)

        # Precondition
        assertEqual(experim.inputs, control.inputs[:2])
        assertEqual(experim.params, control.params)

        print("Checking outputs...")
        control_out, control_hiddens = control.forward(*control.inputs)
        control_hx, control_cx = control_hiddens
        experim_out, experim_hiddens = experim.forward(*experim.inputs)
        experim_hx, experim_cx = experim_hiddens

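        # vl_py returns one output tensor per sequence; pad and concatenate
        # them so their shapes line up with the control's batched outputs.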
        experim_padded = nn.utils.rnn.pad_sequence(experim_out).squeeze(-2)
        assertEqual(experim_padded, control_out)
        assertEqual(torch.cat(experim_hx, dim=1), control_hx)
        assertEqual(torch.cat(experim_cx, dim=1), control_cx)

        print("Checking grads...")
        assert control.backward_setup is not None
        assert experim.backward_setup is not None
        assert control.backward is not None
        assert experim.backward is not None
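        # backward_setup builds the arguments for backward() from the forward
        # outputs, seeded identically so both models receive the same gradients.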
        control_backward_inputs = control.backward_setup(
            (control_out, control_hiddens), test_args['seed'])
        experim_backward_inputs = experim.backward_setup(
            (experim_out, experim_hiddens), test_args['seed'])

        control.backward(*control_backward_inputs)
        experim.backward(*experim_backward_inputs)

        control_grads = [p.grad for p in control.params]
        experim_grads = [p.grad for p in experim.params]
        assertEqual(experim_grads, control_grads)

        if test_args['verbose']:
            print(experim.forward.graph_for(*experim.inputs))
            print('')


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test lstm correctness')

    parser.add_argument('--seqLength', default=100, type=int)
    parser.add_argument('--numLayers', default=1, type=int)
    parser.add_argument('--inputSize', default=512, type=int)
    parser.add_argument('--hiddenSize', default=512, type=int)
    parser.add_argument('--miniBatch', default=64, type=int)
    parser.add_argument('--device', default='cuda', type=str)
    # argparse's type=bool treats any non-empty string (including 'False') as
    # True, so parse the flag value explicitly.
    parser.add_argument('--check_grad', default=True,
                        type=lambda x: x.lower() in ('true', '1', 'yes'))
    parser.add_argument('--variable_lstms', action='store_true')
    parser.add_argument('--seed', default=17, type=int)
    parser.add_argument('--verbose', action='store_true')
    parser.add_argument('--rnns', nargs='*',
                        help='What to run. jit_premul, jit, etc')
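    # Example invocation (assuming this module lives in a package named
    # fastrnns; adjust the module path to your checkout):
    #   python -m fastrnns.test --rnns jit_premul jit --device cuda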
    args = parser.parse_args()
    if args.rnns is None:
        args.rnns = ['jit_premul', 'jit']
    print(args)

    if 'cuda' in args.device:
        assert torch.cuda.is_available()

    rnn_runners = get_nn_runners(*args.rnns)

    should_test_varlen_lstms = args.variable_lstms
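    # test_rnns / test_vl_py don't take these two flags, so strip them before
    # forwarding the remaining parsed args as keyword arguments.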
    test_args = vars(args)
    del test_args['rnns']
    del test_args['variable_lstms']

    if should_test_varlen_lstms:
        test_vl_py(**test_args)

    for name, creator, context in rnn_runners:
        with context():
            print('testing {}...'.format(name))
            test_rnns(creator, pytorch_lstm_creator, **test_args)