# blob: 26f81ee0182d98c70ab6f62785d029e26ae7f204 [file] [log] [blame]
import argparse
import torch
import torch.nn as nn
from .factory import pytorch_lstm_creator, varlen_pytorch_lstm_creator
from .runner import get_nn_runners
def barf():
    """Pause execution and open an interactive ``pdb`` session here."""
    from pdb import set_trace

    set_trace()
def assertEqual(tensor, expected, threshold=0.001):
    """Call ``barf()`` if ``tensor`` and ``expected`` differ by more than ``threshold``.

    Lists and tuples are compared element by element, recursively; any other
    value is compared via ``(tensor - expected).abs().max()``.

    Args:
        tensor: Value under test, or a (possibly nested) list/tuple of values.
        expected: Reference value(s), structured like ``tensor``.
        threshold: Largest absolute element-wise difference that is tolerated.

    Note:
        A mismatch calls ``barf()`` instead of raising. Sequences are paired
        with ``zip``, so surplus elements of the longer one are not compared.
    """
    if isinstance(tensor, (list, tuple)):
        for actual_item, expected_item in zip(tensor, expected):
            # Forward the caller's threshold so nested comparisons use the
            # same tolerance as top-level ones.
            assertEqual(actual_item, expected_item, threshold)
        return

    if (tensor - expected).abs().max() > threshold:
        barf()
def filter_requires_grad(tensors):
    """Return a list of the items in ``tensors`` whose ``requires_grad`` is truthy."""
    kept = []
    for candidate in tensors:
        if candidate.requires_grad:
            kept.append(candidate)
    return kept
def test_rnns(
    experim_creator,
    control_creator,
    check_grad=True,
    verbose=False,
    seqLength=100,
    numLayers=1,
    inputSize=512,
    hiddenSize=512,
    miniBatch=64,
    device="cuda",
    seed=17,
):
    """Compare an experimental RNN runner against a control runner.

    Both creators are called with the same keyword arguments, then their
    inputs, params, forward outputs and (optionally) parameter grads are
    compared with ``assertEqual``.

    Args:
        experim_creator: Callable building the runner under test. It is called
            with ``seqLength``, ``numLayers``, ``inputSize``, ``hiddenSize``,
            ``miniBatch``, ``device`` and ``seed`` keywords, and must return an
            object exposing ``inputs``, ``params``, ``forward``,
            ``backward_setup`` and ``backward``.
        control_creator: Callable building the reference runner; same contract.
        check_grad: When false, skip the backward pass and grad comparison.
        verbose: When true, print ``experim.forward.graph_for(*experim.inputs)``.
        seqLength, numLayers, inputSize, hiddenSize, miniBatch, device, seed:
            Forwarded unchanged to both creators; ``seed`` is also passed to
            each runner's ``backward_setup``.
    """
    creator_args = dict(
        seqLength=seqLength,
        numLayers=numLayers,
        inputSize=inputSize,
        hiddenSize=hiddenSize,
        miniBatch=miniBatch,
        device=device,
        seed=seed,
    )

    print("Setting up...")
    control = control_creator(**creator_args)
    experim = experim_creator(**creator_args)

    # Precondition: both runners must start from matching inputs and params,
    # otherwise differences in the results below would be meaningless.
    assertEqual(experim.inputs, control.inputs)
    assertEqual(experim.params, control.params)

    print("Checking outputs...")
    control_outputs = control.forward(*control.inputs)
    experim_outputs = experim.forward(*experim.inputs)
    assertEqual(experim_outputs, control_outputs)

    if check_grad:
        print("Checking grads...")
        assert control.backward_setup is not None
        assert experim.backward_setup is not None
        assert control.backward is not None
        assert experim.backward is not None
        control_backward_inputs = control.backward_setup(control_outputs, seed)
        experim_backward_inputs = experim.backward_setup(experim_outputs, seed)

        control.backward(*control_backward_inputs)
        experim.backward(*experim_backward_inputs)

        control_grads = [p.grad for p in control.params]
        experim_grads = [p.grad for p in experim.params]
        assertEqual(experim_grads, control_grads)

    if verbose:
        print(experim.forward.graph_for(*experim.inputs))
    print("")
def test_vl_py(**test_args):
    """Compare the ``vl_py`` runner against ``varlen_pytorch_lstm_creator``.

    Args:
        **test_args: Must contain ``seqLength``, ``numLayers``, ``inputSize``,
            ``hiddenSize``, ``miniBatch``, ``device`` and ``seed`` (forwarded
            to both creators) plus ``verbose``. An optional ``check_grad``
            entry (default ``True``) controls whether the backward pass and
            grad comparison run. Other keys are ignored.
    """
    # XXX: This compares vl_py with vl_lstm.
    # It's done this way because those two don't give the same outputs so
    # the result isn't an apples-to-apples comparison right now.
    control_creator = varlen_pytorch_lstm_creator
    name, experim_creator, context = get_nn_runners("vl_py")[0]
    with context():
        print(f"testing {name}...")
        creator_keys = [
            "seqLength",
            "numLayers",
            "inputSize",
            "hiddenSize",
            "miniBatch",
            "device",
            "seed",
        ]
        creator_args = {key: test_args[key] for key in creator_keys}

        print("Setting up...")
        control = control_creator(**creator_args)
        experim = experim_creator(**creator_args)

        # Precondition: only the first two control inputs are compared with
        # the experimental runner's inputs.
        assertEqual(experim.inputs, control.inputs[:2])
        assertEqual(experim.params, control.params)

        print("Checking outputs...")
        control_out, control_hiddens = control.forward(*control.inputs)
        control_hx, control_cx = control_hiddens
        experim_out, experim_hiddens = experim.forward(*experim.inputs)
        experim_hx, experim_cx = experim_hiddens

        # Bring the experimental results into the control's layout before
        # comparing (presumably per-sequence tensors -- TODO confirm shapes).
        experim_padded = nn.utils.rnn.pad_sequence(experim_out).squeeze(-2)
        assertEqual(experim_padded, control_out)
        assertEqual(torch.cat(experim_hx, dim=1), control_hx)
        assertEqual(torch.cat(experim_cx, dim=1), control_cx)

        # Honour an optional check_grad flag, consistent with test_rnns.
        if test_args.get("check_grad", True):
            print("Checking grads...")
            assert control.backward_setup is not None
            assert experim.backward_setup is not None
            assert control.backward is not None
            assert experim.backward is not None
            control_backward_inputs = control.backward_setup(
                (control_out, control_hiddens), test_args["seed"]
            )
            experim_backward_inputs = experim.backward_setup(
                (experim_out, experim_hiddens), test_args["seed"]
            )

            control.backward(*control_backward_inputs)
            experim.backward(*experim_backward_inputs)

            control_grads = [p.grad for p in control.params]
            experim_grads = [p.grad for p in experim.params]
            assertEqual(experim_grads, control_grads)

        if test_args["verbose"]:
            print(experim.forward.graph_for(*experim.inputs))
        print("")
if __name__ == "__main__":

    def _str2bool(value):
        """Parse a command-line boolean such as ``True``/``false`` or ``1``/``0``."""
        text = value.strip().lower()
        if text in ("true", "t", "yes", "y", "1"):
            return True
        if text in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")

    # Command-line entry point: parse options, then run the requested
    # correctness checks against the stock PyTorch LSTM.
    parser = argparse.ArgumentParser(description="Test lstm correctness")

    parser.add_argument("--seqLength", default=100, type=int)
    parser.add_argument("--numLayers", default=1, type=int)
    parser.add_argument("--inputSize", default=512, type=int)
    parser.add_argument("--hiddenSize", default=512, type=int)
    parser.add_argument("--miniBatch", default=64, type=int)
    parser.add_argument("--device", default="cuda", type=str)
    # type=bool would turn any non-empty string (including "False") into
    # True, so the value is parsed explicitly instead.
    parser.add_argument(
        "--check-grad", "--check_grad", default=True, type=_str2bool
    )
    parser.add_argument("--variable-lstms", "--variable_lstms", action="store_true")
    parser.add_argument("--seed", default=17, type=int)
    parser.add_argument("--verbose", action="store_true")
    parser.add_argument("--rnns", nargs="*", help="What to run. jit_premul, jit, etc")
    args = parser.parse_args()
    if args.rnns is None:
        args.rnns = ["jit_premul", "jit"]
    print(args)

    if "cuda" in args.device:
        assert torch.cuda.is_available(), "a CUDA device was requested but none is available"

    rnn_runners = get_nn_runners(*args.rnns)

    should_test_varlen_lstms = args.variable_lstms
    # Everything except the runner selection flags is forwarded to the tests.
    test_args = {
        key: value
        for key, value in vars(args).items()
        if key not in ("rnns", "variable_lstms")
    }

    if should_test_varlen_lstms:
        test_vl_py(**test_args)

    for name, creator, context in rnn_runners:
        with context():
            print(f"testing {name}...")
            test_rnns(creator, pytorch_lstm_creator, **test_args)