# blob: f350a49adf2d429d278c9c98eabb08f1d4a44e12 [file] [log] [blame]
# Owner(s): ["oncall: jit"]
import os
import sys
import torch
from torch import nn
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
# Refuse direct execution: the message below tells the user to run these
# tests through test/test_jit.py instead.
if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
class Sequence(nn.Module):
    """Two stacked LSTM cells followed by a linear read-out, applied per step.

    The input is consumed one slice at a time along dim 1. Each slice is fed
    through ``lstm1`` -> ``lstm2`` -> ``linear`` and the per-step results are
    concatenated back along dim 1.

    Args:
        hidden_size: Width of the hidden and cell state of both LSTM cells.
            Defaults to 51, the value that used to be hard-coded.
    """

    def __init__(self, hidden_size: int = 51):
        super().__init__()
        self.hidden_size = hidden_size
        self.lstm1 = nn.LSTMCell(1, hidden_size)
        self.lstm2 = nn.LSTMCell(hidden_size, hidden_size)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, input):
        """Run the stacked cells over every width-1 slice of ``input`` along dim 1.

        Args:
            input: Tensor whose dim 0 is used as the batch size and whose
                dim 1 is split into width-1 steps (``lstm1`` takes 1 input
                feature, so a 2-D ``(batch, steps)`` tensor is expected).

        Returns:
            The per-step outputs of ``linear`` concatenated along dim 1.
        """
        batch_size = input.size(0)
        # Create the initial states with the input's dtype and device so the
        # module also works when it is cast (e.g. ``.double()``) or moved off
        # the CPU; default-constructed zeros would mismatch the cell weights.
        h_t = torch.zeros(
            [batch_size, self.hidden_size], dtype=input.dtype, device=input.device
        )
        c_t = torch.zeros_like(h_t)
        h_t2 = torch.zeros_like(h_t)
        c_t2 = torch.zeros_like(h_t)

        outputs = []
        for input_t in input.split(1, dim=1):
            # The first cell sees the raw step; the second cell sees the first
            # cell's hidden state.
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            outputs.append(self.linear(h_t2))
        return torch.cat(outputs, dim=1)
class TestScriptProfile(JitTestCase):
    """Checks ``torch.jit._ScriptProfile`` when driven from Python and from TorchScript."""

    def test_basic(self):
        # A profile enabled around a scripted call must dump something.
        scripted_seq = torch.jit.script(Sequence())
        profile = torch.jit._ScriptProfile()

        profile.enable()
        scripted_seq(torch.rand((10, 100)))
        profile.disable()

        self.assertNotEqual(profile.dump_string(), "")

    def test_script(self):
        # Same as test_basic, but the profile is created and toggled inside
        # a scripted function and handed back to Python.
        model = Sequence()

        @torch.jit.script
        def fn():
            profile = torch.jit._ScriptProfile()
            profile.enable()
            _ = model(torch.rand((10, 100)))
            profile.disable()
            return profile

        returned_profile = fn()
        self.assertNotEqual(returned_profile.dump_string(), "")

    def test_multi(self):
        # Five profiles are enabled together and then disabled one at a time,
        # with one extra model run before each disable; every dump must be
        # non-empty and differ from the dump taken just before it.
        scripted_seq = torch.jit.script(Sequence())
        active = [torch.jit._ScriptProfile() for _ in range(5)]
        for profile in active:
            profile.enable()

        previous = None
        while active:
            scripted_seq(torch.rand((10, 10)))
            profile = active.pop()
            profile.disable()
            dump = profile.dump_string()
            self.assertNotEqual(dump, "")
            if previous:
                self.assertNotEqual(dump, previous)
            previous = dump

    def test_section(self):
        # Takes three dumps from one profile: after an enabled run, after a
        # run made while disabled, and after a second enabled run.
        model = Sequence()

        @torch.jit.script
        def fn():
            profile = torch.jit._ScriptProfile()

            profile.enable()
            _ = model(torch.rand((10, 100)))
            profile.disable()
            first = profile.dump_string()

            # This run happens while the profile is disabled.
            _ = model(torch.rand((10, 10)))
            second = profile.dump_string()

            profile.enable()
            _ = model(torch.rand((10, 10)))
            profile.disable()
            third = profile.dump_string()

            # The profile is enabled once more here with no matching disable.
            profile.enable()
            return first, second, third

        first, second, third = fn()
        self.assertEqual(first, second)
        self.assertNotEqual(second, third)

    def test_empty(self):
        # Enabling and disabling with nothing run in between dumps "".
        profile = torch.jit._ScriptProfile()
        profile.enable()
        profile.disable()
        self.assertEqual(profile.dump_string(), "")