blob: 01a03d43ab7d26c21b12691a37b0df9e03826994 [file] [log] [blame]
#!/usr/bin/env python3
# Owner(s): ["oncall: mobile"]
import functools
import io
import os
import tempfile
import unittest
import urllib
import urllib.parse

import torch
import torch.backends.xnnpack
import torch.utils.mobile_optimizer
import torch.utils.model_dump
from torch.testing._internal.common_quantized import supported_qengines
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    TestCase,
    run_tests,
    skipIfNoXNNPACK,
)
class SimpleModel(torch.nn.Module):
    """Small MLP used as a dump target: Linear(16->64), ReLU, Linear(64->8), ReLU."""

    def __init__(self) -> None:
        super().__init__()
        # Creation order matters: it fixes the order parameters are initialized
        # (RNG draws) and the order submodules are registered.
        self.layer1 = torch.nn.Linear(16, 64)
        self.relu1 = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(64, 8)
        self.relu2 = torch.nn.ReLU()

    def forward(self, features):
        # First hidden stage, then the output stage.
        hidden = self.relu1(self.layer1(features))
        return self.relu2(self.layer2(hidden))
class QuantModel(torch.nn.Module):
    """Wrap ``SimpleModel`` between quant/dequant stubs for eager-mode quantization."""

    def __init__(self) -> None:
        super().__init__()
        # Stubs mark where tensors enter and leave the quantized region.
        self.quant = torch.ao.quantization.QuantStub()
        self.dequant = torch.ao.quantization.DeQuantStub()
        self.core = SimpleModel()

    def forward(self, x):
        quantized_in = self.quant(x)
        return self.dequant(self.core(quantized_in))
class ModelWithLists(torch.nn.Module):
    """Module whose tensor state lives in plain Python list attributes.

    Used to exercise model_dump on list-typed attributes, including a list
    that contains a ``None`` entry.
    """

    def __init__(self) -> None:
        super().__init__()
        self.rt = [torch.zeros(1)]
        self.ot = [torch.zeros(1), None]

    def forward(self, arg):
        total = arg + self.rt[0]
        # Bind to a local so TorchScript can refine the Optional[Tensor] type.
        maybe_extra = self.ot[0]
        if maybe_extra is not None:
            total = total + maybe_extra
        return total
def webdriver_test(testfunc):
    """Decorate a test method so it runs once per browser under Selenium.

    The wrapped test receives a live webdriver instance as its first argument
    after ``self``. It is skipped unless the ``RUN_WEBDRIVER`` environment
    variable is ``"1"``. Each browser runs inside its own ``subTest``.
    """
    @functools.wraps(testfunc)
    def wrapper(self, *args, **kwds):
        self.needs_resources()
        if os.environ.get("RUN_WEBDRIVER") != "1":
            self.skipTest("Webdriver not requested")
        # Imported lazily so selenium is only required when webdriver tests run.
        from selenium import webdriver

        for driver in ("Firefox", "Chrome"):
            with self.subTest(driver=driver):
                wd = getattr(webdriver, driver)()
                try:
                    testfunc(self, wd, *args, **kwds)
                finally:
                    # Always tear down the browser, even when the test fails.
                    # quit() ends the whole driver session; close() would only
                    # close the current window and leak the driver process.
                    wd.quit()
    return wrapper
class TestModelDump(TestCase):
    """Tests for ``torch.utils.model_dump``.

    Most tests only check that a saved model can be summarized without error.
    Tests decorated with ``webdriver_test`` also render the generated HTML page
    in a real browser and inspect it; those run only when ``RUN_WEBDRIVER=1``.
    """

    def needs_resources(self):
        # Hook called at the start of resource-using tests; a no-op here
        # (presumably overridable where such tests must be gated).
        pass

    def test_inline_skeleton(self):
        self.needs_resources()
        skel = torch.utils.model_dump.get_inline_skeleton()
        # The inlined page must not reference the CDN or any `src=` resource.
        self.assertNotIn("unpkg.org", skel)
        self.assertNotIn("src=", skel)

    def do_dump_model(self, model, extra_files=None):
        """Save ``model`` to memory and check that model info can be extracted."""
        # Just check that we're able to run successfully.
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        info = torch.utils.model_dump.get_model_info(buf)
        self.assertIsNotNone(info)

    def open_html_model(self, wd, model, extra_files=None):
        """Render the HTML dump of ``model`` in webdriver ``wd`` via a data: URL."""
        buf = io.BytesIO()
        torch.jit.save(model, buf, _extra_files=extra_files)
        page = torch.utils.model_dump.get_info_and_burn_skeleton(buf)
        wd.get("data:text/html;charset=utf-8," + urllib.parse.quote(page))

    def open_section_and_get_body(self, wd, name):
        """Expand the page section titled ``name`` and return its body element."""
        from selenium.webdriver.common.by import By

        container = wd.find_element(By.XPATH, f"//div[@data-hider-title='{name}']")
        caret = container.find_element(By.CLASS_NAME, "caret")
        # Only click when not already shown; the caret presumably toggles.
        if container.get_attribute("data-shown") != "true":
            caret.click()
        return container.find_element(By.TAG_NAME, "div")

    def test_scripted_model(self):
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model)

    def test_traced_model(self):
        model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16))
        self.do_dump_model(model)

    def test_main(self):
        self.needs_resources()
        if IS_WINDOWS:
            # I was getting tempfile errors in CI. Just skip it.
            self.skipTest("Disabled on Windows.")
        with tempfile.NamedTemporaryFile() as tf:
            torch.jit.save(torch.jit.script(SimpleModel()), tf)
            # Actually write contents to disk so we can read it below
            tf.flush()

            # JSON output should be an object mentioning the model class.
            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=json",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(stdout.getvalue(), r'\A{.*SimpleModel')

            # HTML output should be a full page; newlines are flattened so
            # `.` in the regex can span the whole document.
            stdout = io.StringIO()
            torch.utils.model_dump.main(
                [
                    None,
                    "--style=html",
                    tf.name,
                ],
                stdout=stdout)
            self.assertRegex(
                stdout.getvalue().replace("\n", " "),
                r'\A<!DOCTYPE.*SimpleModel.*componentDidMount')

    def get_quant_model(self):
        """Build a post-training-quantized ``QuantModel`` using the qnnpack qconfig."""
        fmodel = QuantModel().eval()
        # Fuse each Linear with its following ReLU before quantizing.
        fmodel = torch.ao.quantization.fuse_modules(fmodel, [
            ["core.layer1", "core.relu1"],
            ["core.layer2", "core.relu2"],
        ])
        fmodel.qconfig = torch.ao.quantization.get_default_qconfig("qnnpack")
        prepped = torch.ao.quantization.prepare(fmodel)
        # One calibration pass so observers record activation ranges.
        prepped(torch.randn(2, 16))
        qmodel = torch.ao.quantization.convert(prepped)
        return qmodel

    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_quantized_model(self):
        qmodel = self.get_quant_model()
        self.do_dump_model(torch.jit.script(qmodel))

    @skipIfNoXNNPACK
    @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available")
    def test_optimized_quantized_model(self):
        qmodel = self.get_quant_model()
        smodel = torch.jit.trace(qmodel, torch.zeros(2, 16))
        omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel)
        self.do_dump_model(omodel)

    def test_model_with_lists(self):
        model = torch.jit.script(ModelWithLists())
        self.do_dump_model(model)

    def test_invalid_json(self):
        # An unparseable extra file must not break the dump.
        model = torch.jit.script(SimpleModel())
        self.do_dump_model(model, extra_files={"foo.json": "{"})

    @webdriver_test
    def test_memory_computation(self, wd):
        from selenium.webdriver.common.by import By

        def check_memory(model, expected):
            self.open_html_model(wd, model)
            memory_table = self.open_section_and_get_body(wd, "Tensor Memory")
            # The leading "." keeps the XPath relative to the section body;
            # a bare "//" would search the whole document.
            device = memory_table.find_element(
                By.XPATH, ".//table/tbody/tr[1]/td[1]").text
            self.assertEqual("cpu", device)
            memory_usage_str = memory_table.find_element(
                By.XPATH, ".//table/tbody/tr[1]/td[2]").text
            self.assertEqual(expected, int(memory_usage_str))

        simple_model_memory = (
            # First layer, including bias.
            64 * (16 + 1) +
            # Second layer, including bias.
            8 * (64 + 1)
            # 32-bit float
        ) * 4

        check_memory(torch.jit.script(SimpleModel()), simple_model_memory)

        # The same SimpleModel instance appears twice in this model.
        # The tensors will be shared, so ensure no double-counting.
        a_simple_model = SimpleModel()
        check_memory(
            torch.jit.script(
                torch.nn.Sequential(a_simple_model, a_simple_model)),
            simple_model_memory)

        # The freezing process will move the weight and bias
        # from data to constants. Ensure they are still counted.
        check_memory(
            torch.jit.freeze(torch.jit.script(SimpleModel()).eval()),
            simple_model_memory)

        # Make sure we can handle a model with both constants and data tensors.
        class ComposedModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.w1 = torch.zeros(1, 2)
                self.w2 = torch.ones(2, 2)

            def forward(self, arg):
                return arg * self.w2 + self.w1

        # w1 (2 elements) stays as data, w2 (4 elements) becomes a constant;
        # both are 4-byte floats.
        check_memory(
            torch.jit.freeze(
                torch.jit.script(ComposedModule()).eval(),
                preserved_attrs=["w1"]),
            4 * (2 + 4))
# Allow running this test file directly as a script.
if __name__ == "__main__":
    run_tests()