import functools
import importlib
import itertools
import json
import logging
import math
import operator
import os

import torch

from .. import config
from ..utils import check_is_cuda, checkpoint_params, is_jit_model, torchscript

log = logging.getLogger(__name__)


def cached(fn):
    """Memoize a zero-argument method: cache its result on the instance."""
    cached_name = f"_{fn.__name__}"

    @functools.wraps(fn)
    def inner(self):
        if hasattr(self, cached_name):
            return getattr(self, cached_name)
        result = fn(self)
        setattr(self, cached_name, result)
        return result

    return inner
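
# A minimal usage sketch for the decorator above (`Example` and
# `expensive_computation` are hypothetical names, not part of this module).
# The wrapped method runs once per instance; later accesses return the value
# stashed on the instance under the "_<name>" attribute:
#
#     class Example:
#         @property
#         @cached
#         def value(self):
#             return expensive_computation()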


def load_module_fx(name):
    """Import an fx module saved via GraphModule.to_folder(), patching in
    names that to_folder() currently fails to emit."""
    pymod = importlib.import_module(f"subgraphs.{name}")
    # TODO(jansel): upstream these fixes to to_folder()
    pymod.module._operator_iadd = operator.iadd
    pymod.module._operator_imul = operator.imul
    pymod.module._operator_itruediv = operator.itruediv
    pymod.module._operator_setitem = operator.setitem
    pymod.module.math_sqrt = math.sqrt
    pymod.module.device = torch.device
    pymod.module.inf = float("inf")
    return pymod.FxModule()


def load_module_jit(name):
    """Load the TorchScript copy of a subgraph, or None if it was never saved."""
    filename = os.path.join(config.base_dir, "subgraphs", name, "model.ts")
    if not os.path.exists(filename):
        return None
    model = torch.jit.load(filename)
    assert is_jit_model(model)
    return model
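
# On-disk layout that SubGraph.load() below expects, reconstructed from the
# reads in this file (paths relative to config.base_dir; the exact contents of
# the Python package are whatever fx's to_folder() writes):
#
#     subgraphs/<name>/            # Python package, imported as subgraphs.<name>
#     subgraphs/<name>/model.ts    # TorchScript copy, read by load_module_jit()
#     subgraphs/<name>/example_inputs.pt
#     subgraphs/<name>/example_outputs.pt
#     subgraphs/<name>/metadata.json   # e.g. {"is_cuda": false}
#     subgraphs/<name>/rng_state.pt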


class SubGraph:
    """An extracted subgraph plus example inputs/outputs, backed by files
    under a model_dir."""

    @classmethod
    def load(cls, name):
        """Load a saved subgraph from config.base_dir/subgraphs/<name> and
        attach a restore() hook that rolls checkpointed parameters back."""
        model_dir = os.path.join(config.base_dir, "subgraphs", name)
        example_inputs = torch.load(os.path.join(model_dir, "example_inputs.pt"))
        example_outputs = torch.load(os.path.join(model_dir, "example_outputs.pt"))
        with open(os.path.join(model_dir, "metadata.json")) as fd:
            metadata = json.load(fd)
        model_fx = load_module_fx(name)
        model_jit = load_module_jit(name)
        is_cuda = metadata["is_cuda"]
        assert model_jit is not None
        torch.set_rng_state(torch.load(os.path.join(model_dir, "rng_state.pt")))
        if is_cuda:
            model_jit = model_jit.cuda()
        restore_jit = checkpoint_params(model_jit)
        if model_fx is not None:
            if is_cuda:
                model_fx = model_fx.cuda()
            restore_fx = checkpoint_params(model_fx)
        else:
            model_fx = model_jit
            restore_fx = restore_jit

        def restore():
            restore_fx()
            restore_jit()

        subgraph = cls(model_fx, example_inputs, model_dir)
        subgraph._scripted = model_jit
        subgraph._example_outputs = example_outputs
        subgraph._is_cuda = is_cuda
        subgraph.restore = restore
        return subgraph

    def __init__(self, model, example_inputs, model_dir):
        super().__init__()
        self.model = model
        self.example_inputs = example_inputs
        self.model_dir = model_dir

    def filename(self, name):
        return os.path.join(self.model_dir, name)

    @property
    @cached
    def scripted(self):
        return torchscript(self.model, self.example_inputs)

    @property
    @cached
    def example_outputs(self):
        filename = self.filename("example_outputs.pt")
        if os.path.exists(filename):
            return torch.load(filename)
        result = self.model(*self.example_inputs)
        torch.save(result, filename)
        return result

    @property
    def example_outputs_list(self):
        if self.is_tensor_output:
            return [self.example_outputs]
        return self.example_outputs

    @property
    def input_names(self):
        return [f"i{i}" for i in range(len(self.example_inputs))]

    @property
    def is_tensor_output(self):
        return not isinstance(self.example_outputs, (list, tuple))

    @property
    def output_names(self):
        return [f"o{x}" for x in range(len(self.example_outputs_list))]

    @property
    def device_index(self):
        return 0

    @property
    @cached
    def onnx_filename(self):
        filename = self.filename("onnx")
        if os.path.exists(filename):
            return filename
        try:
            torch.onnx.export(
                self.scripted,
                self.example_inputs,
                filename,
                input_names=self.input_names,
                output_names=self.output_names,
                do_constant_folding=True,
                opset_version=14,
            )
        except IndexError:
            # work around bug in constant folding pass
            torch.onnx.export(
                self.scripted,
                self.example_inputs,
                filename,
                input_names=self.input_names,
                output_names=self.output_names,
                do_constant_folding=False,
                opset_version=14,
            )
        return filename
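
    # Hedged sketch of consuming the exported file (onnxruntime is an
    # assumption here, not something this module imports):
    #
    #     import onnxruntime
    #     sess = onnxruntime.InferenceSession(subgraph.onnx_filename)
    #     outs = sess.run(subgraph.output_names,
    #                     dict(zip(subgraph.input_names, numpy_inputs)))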

    @property
    def is_cpu(self):
        return not self.is_cuda

    @property
    @cached
    def is_cuda(self):
        return check_is_cuda(self.model, self.example_inputs)

    @property
    def output_specs(self):
        return [
            (o.shape, o.dtype, o.layout, o.device, o.requires_grad)
            for o in self.example_outputs_list
        ]

    def empty_outputs_factory(self):
        specs = self.output_specs

        def create():
            return [
                torch.empty(
                    shape,
                    dtype=dtype,
                    layout=layout,
                    device=device,
                    requires_grad=requires_grad,
                )
                for shape, dtype, layout, device, requires_grad in specs
            ]

        return create
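
    # Hedged usage sketch for the factory above (variable names illustrative):
    # capture output_specs once, then allocate fresh uninitialized buffers of
    # the right shape/dtype/device on every call.
    #
    #     create = subgraph.empty_outputs_factory()
    #     buffers = create()  # list of torch.empty tensors matching output_specs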

    def wrap_returns(self, fn):
        """Fix [Tensor()] vs Tensor() return type issues"""
        expected = self.example_outputs
        actual = fn(*self.example_inputs)
        if isinstance(expected, (list, tuple)) and not isinstance(
            actual, (list, tuple)
        ):
            assert len(expected) == 1
            if isinstance(expected, tuple):
                return lambda *args: (fn(*args),)
            else:
                return lambda *args: [fn(*args)]
        elif not isinstance(expected, (list, tuple)) and isinstance(
            actual, (list, tuple)
        ):
            assert len(actual) == 1
            return lambda *args: fn(*args)[0]
        elif isinstance(expected, (list, tuple)) and isinstance(actual, (list, tuple)):
            assert len(actual) == len(expected)
            return fn
        else:
            return fn
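
    # Hedged sketch: wrap_returns() makes a backend's calling convention match
    # example_outputs. E.g. if the original model returns a bare Tensor but a
    # compiled function returns [Tensor], the wrapper unpacks the singleton
    # (compiled_fn is a hypothetical callable):
    #
    #     fn = subgraph.wrap_returns(compiled_fn)
    #     out = fn(*subgraph.example_inputs)  # same structure as example_outputs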

    def has_dtype(self, dtype):
        for x in itertools.chain(
            self.example_inputs, self.scripted.parameters(), self.scripted.buffers()
        ):
            if x.dtype == dtype:
                return True
        return False

    def will_tensorrt_barf(self):
        return False
        # Disabled heuristic kept for reference:
        # code = torch.jit.freeze(self.scripted).code
        # TODO(jansel): submit a bug report for this one, issue is in opacus_cifar10
        # if "group_norm" in code or "einsum" in code:
        #     return True
        # return self.has_dtype(torch.int64)
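
# Hedged end-to-end usage sketch (assumes a subgraph named "example_000" was
# previously saved under config.base_dir/subgraphs; the name is hypothetical):
#
#     subgraph = SubGraph.load("example_000")
#     fn = subgraph.wrap_returns(subgraph.scripted)
#     outputs = fn(*subgraph.example_inputs)
#     subgraph.restore()  # roll checkpointed parameters back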