import contextlib
import dis
import functools
import logging
import os.path
import types
import unittest
from unittest.mock import patch
import torch
from torch import fx
from . import config, eval_frame, optimize_assert, reset
from .bytecode_transformation import (
create_instruction,
debug_checks,
is_generator,
transform_code_object,
)
from .guards import CheckFunctionManager, GuardedCode
from .utils import same
unsupported = eval_frame.unsupported
three = 3  # simple module-level global, handy for tests that exercise global-variable handling
log = logging.getLogger(__name__)
def clone_me(x):
if x is None:
return None
return x.detach().clone().requires_grad_(x.requires_grad)
def collect_results(model, prediction, loss, example_inputs):
results = []
results.append(prediction)
results.append(loss)
    if isinstance(loss, torch.Tensor) and loss.item() > 1:
        log.warning(
            f"High loss value ({loss.item():.2f}) detected; this can result in unstable gradients."
        )
    grads = {}
    params = {}
for name, param in model.named_parameters():
param_copy = param
grad = param.grad
        # Treat a missing (None) grad the same as a zero grad
if param.grad is None:
grad = torch.zeros_like(param)
grads[name + ".grad"] = grad
params[name] = param_copy
results.append(grads)
results.append(params)
for example in example_inputs:
if isinstance(example, (tuple, list)):
for inp in example:
if isinstance(inp, torch.Tensor):
results.append(inp.grad)
else:
if isinstance(example, torch.Tensor):
results.append(example.grad)
return results
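# Example flow (a sketch of how the helpers in this module are typically
# combined; the tiny Linear model here is purely illustrative):
#
#     model = torch.nn.Linear(4, 4)
#     example_inputs = [torch.randn(2, 4, requires_grad=True)]
#     pred = model(*example_inputs)
#     loss = reduce_to_scalar_loss(pred)
#     loss.backward()
#     results = collect_results(model, pred, loss, example_inputs)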
def requires_bwd_pass(out):
if isinstance(out, torch.Tensor):
return out.requires_grad
elif isinstance(out, (list, tuple)):
        return any(requires_bwd_pass(x) for x in out)
elif out is None:
return False
raise NotImplementedError("Don't know how to reduce", type(out))
def reduce_to_scalar_loss(out):
"""Reduce the output of a model to get scalar loss"""
if isinstance(out, torch.Tensor):
# Mean does not work on integer tensors
return out.sum() / out.numel()
elif isinstance(out, (list, tuple)):
        return sum(reduce_to_scalar_loss(x) for x in out) / len(out)
elif type(out).__name__ in (
"MaskedLMOutput",
"Seq2SeqLMOutput",
"CausalLMOutputWithCrossAttentions",
):
return reduce_to_scalar_loss(out.logits)
elif type(out).__name__ == "SquashedNormal":
return out.mean.sum()
    elif isinstance(out, dict):
        return sum(reduce_to_scalar_loss(value) for value in out.values()) / len(out)
raise NotImplementedError("Don't know how to reduce", type(out))
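# Example (a sketch): containers reduce recursively, so a list of tensors
# averages the per-tensor means:
#
#     out = [torch.ones(2, 2), torch.zeros(2, 2)]
#     reduce_to_scalar_loss(out)  # tensor(0.5) == (1.0 + 0.0) / 2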
def debug_dir():
path = os.path.join(os.path.dirname(__file__), "../debug")
if not os.path.exists(path):
os.mkdir(path)
return path
def debug_dump(name, code: types.CodeType, extra=""):
with open(os.path.join(debug_dir(), name), "w") as fd:
fd.write(
f"{dis.Bytecode(code).info()}\n\n{dis.Bytecode(code).dis()}\n\n{extra}\n"
)
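# Example use (a sketch): write the disassembly of a function's code object to
# <package>/../debug/my_fn.txt for inspection:
#
#     def my_fn(x):
#         return x + 1
#     debug_dump("my_fn.txt", my_fn.__code__)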
def debug_insert_nops(frame, cache_size):
"""used to debug jump updates"""
def insert_nops(instructions, code_options):
instructions.insert(0, create_instruction("NOP"))
instructions.insert(0, create_instruction("NOP"))
if is_generator(frame.f_code):
return None
debug_checks(frame.f_code)
code = transform_code_object(frame.f_code, insert_nops)
return GuardedCode(code, CheckFunctionManager().check_fn)
class CompileCounter:
def __init__(self):
self.frame_count = 0
self.op_count = 0
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
self.frame_count += 1
for node in gm.graph.nodes:
if "call" in node.op:
self.op_count += 1
return gm.forward
def clear(self):
self.frame_count = 0
self.op_count = 0
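# Example use (a sketch; mirrors how standard_test below drives the counter):
#
#     cnt = CompileCounter()
#     opt_fn = optimize_assert(cnt)(lambda x: x.sin() + x.cos())
#     opt_fn(torch.randn(4))
#     assert cnt.frame_count == 1  # one frame was compiled
#     assert cnt.op_count == 3     # sin, cos, and add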
class CompileCounterWithBackend:
def __init__(self, backend):
self.frame_count = 0
self.op_count = 0
self.backend = backend
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
from torch._dynamo.eval_frame import lookup_backend
self.frame_count += 1
for node in gm.graph.nodes:
if "call" in node.op:
self.op_count += 1
return lookup_backend(self.backend)(gm, example_inputs)
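# Example use (a sketch; assumes "eager" is a registered backend name):
#
#     cnt = CompileCounterWithBackend("eager")
#     opt_fn = optimize_assert(cnt)(lambda x: x * 2)
#     opt_fn(torch.randn(4))  # counts the frame/ops, then delegates to "eager"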
def standard_test(self, fn, nargs, expected_ops=None, expected_ops_dynamic=None):
if config.dynamic_shapes and expected_ops_dynamic is not None:
expected_ops = expected_ops_dynamic
actual = CompileCounter()
if expected_ops is None:
expected = CompileCounter()
try:
gm = torch.fx.symbolic_trace(fn)
            expected(gm, [])  # example_inputs are unused by CompileCounter
print("\nfx.symbolic_trace graph:")
gm.graph.print_tabular()
expected_ops = expected.op_count
except Exception:
pass # Silently ignore FX errors (not our issue)
args1 = [torch.randn(10, 10) for _ in range(nargs)]
args2 = [torch.randn(10, 10) for _ in range(nargs)]
correct1 = fn(*args1)
correct2 = fn(*args2)
reset()
opt_fn = optimize_assert(actual)(fn)
val1a = opt_fn(*args1)
val2a = opt_fn(*args2)
val1b = opt_fn(*args1)
val2b = opt_fn(*args2)
reset()
self.assertTrue(same(val1a, correct1))
self.assertTrue(same(val1b, correct1))
self.assertTrue(same(val2a, correct2))
self.assertTrue(same(val2b, correct2))
self.assertEqual(actual.frame_count, 1)
if expected_ops is not None:
self.assertEqual(actual.op_count, expected_ops)
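# Example use (a sketch; standard_test is written as a mixin-style helper that
# expects a TestCase-like `self`):
#
#     class MyTests(unittest.TestCase):
#         def test_add(self):
#             standard_test(self, lambda a, b: a + b, nargs=2, expected_ops=1)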
def dummy_fx_compile(gm: fx.GraphModule, example_inputs):
return gm.forward
def format_speedup(speedup, pvalue, is_correct=True, pvalue_threshold=0.1):
if not is_correct:
return "ERROR"
if pvalue > pvalue_threshold:
return f"{speedup:.3f}x SAME"
return f"{speedup:.3f}x p={pvalue:.2f}"
def requires_static_shapes(fn):
@functools.wraps(fn)
def _fn(*args, **kwargs):
if config.dynamic_shapes:
raise unittest.SkipTest("requires static shapes")
return fn(*args, **kwargs)
return _fn
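# Example use (a sketch): skip a test whenever config.dynamic_shapes is set:
#
#     class MyTests(unittest.TestCase):
#         @requires_static_shapes
#         def test_fixed_size_only(self):
#             ...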
def rand_strided(size, stride, dtype=torch.float32, device="cpu"):
    needed_size = sum((dim - 1) * st for dim, st in zip(size, stride)) + 1
if dtype.is_floating_point:
buffer = torch.randn(needed_size, dtype=dtype, device=device)
else:
buffer = torch.ones(size=[needed_size], dtype=dtype, device=device)
return torch.as_strided(buffer, size, stride)
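# Example (a sketch): a 2x3 column-major (Fortran-order) tensor needs
# (2 - 1) * 1 + (3 - 1) * 2 + 1 = 6 elements of backing storage:
#
#     t = rand_strided((2, 3), (1, 2))
#     assert t.shape == (2, 3) and t.stride() == (1, 2)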
def _make_fn_with_patches(fn, *patches):
@functools.wraps(fn)
def _fn(*args, **kwargs):
with contextlib.ExitStack() as stack:
for attr, val in patches:
stack.enter_context(patch.object(config, attr, val))
return fn(*args, **kwargs)
return _fn
def make_test_cls_with_patches(cls, cls_prefix, fn_suffix, *patches):
class DummyTestClass(cls):
pass
DummyTestClass.__name__ = f"{cls_prefix}{cls.__name__}"
for name in dir(cls):
if name.startswith("test_"):
fn = getattr(cls, name)
if not callable(fn):
continue
new_name = f"{name}{fn_suffix}"
fn = _make_fn_with_patches(fn, *patches)
fn.__name__ = new_name
            # Remove the original test so only the patched copy runs
            setattr(DummyTestClass, name, None)
setattr(DummyTestClass, new_name, fn)
return DummyTestClass
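# Example use (a sketch mirroring how dynamo test suites clone themselves
# under different configs; MyTests is hypothetical):
#
#     DynamicShapesMyTests = make_test_cls_with_patches(
#         MyTests, "DynamicShapes", "_dynamic_shapes", ("dynamic_shapes", True)
#     )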