blob: 6b8b7dad9cee0938cb78921266fa577c50132f00 [file] [log] [blame]
# Owner(s): ["module: dynamo"]
import dataclasses
import unittest.mock
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import same
try:
from transformers import modeling_outputs
from transformers.configuration_utils import PretrainedConfig
from transformers.file_utils import ModelOutput
from transformers.modeling_outputs import BaseModelOutput
except ImportError:
modeling_outputs = None
def maybe_skip(fn):
    """Mark *fn* as skipped when HuggingFace ``transformers`` is unavailable.

    The guarded import at the top of this file sets ``modeling_outputs`` to
    ``None`` on ``ImportError``; in that case *fn* is wrapped with
    ``unittest.skip``. Otherwise *fn* is returned unchanged.
    """
    have_hf = modeling_outputs is not None
    return fn if have_hf else unittest.skip("requires HuggingFace")(fn)
class TestHFPretrained(torch._dynamo.test_case.TestCase):
    """Dynamo tracing of code that reads attributes of a ``PretrainedConfig``."""

    @maybe_skip
    def test_pretrained(self):
        def fn(a, tmp):
            # Probe for an attribute the config does not define, then branch
            # on config values; nopython=True below forbids graph breaks here.
            if hasattr(tmp, "somekey"):
                a = a + 1
            if not tmp.return_dict:
                return a
            return a + torch.ones(2) * tmp.max_length

        inp = torch.randn(2)
        config = PretrainedConfig(return_dict=True, max_length=20)
        expected = fn(inp, config)
        compiled = torch._dynamo.optimize("eager", nopython=True)(fn)
        actual = compiled(inp, config)
        self.assertTrue(same(expected, actual))
class TestModelOutput(torch._dynamo.test_case.TestCase):
    """Dynamo tracing of HuggingFace ``ModelOutput`` creation and field access."""

    @maybe_skip
    def test_mo_create(self):
        # Build a BaseModelOutput inside the traced function (two adds).
        def fn(a, b):
            tmp = BaseModelOutput(a + 1, attentions=b + 3)
            return tmp

        torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=2)

    @maybe_skip
    def test_mo_assign(self):
        # Populate fields via both attribute assignment and item assignment.
        def fn(a, b):
            tmp = BaseModelOutput(last_hidden_state=b + 3)
            tmp.hidden_states = a + 7
            tmp["attentions"] = a + b + 6
            return tmp

        args = [torch.randn(10), torch.randn(10)]
        obj1 = fn(*args)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1.last_hidden_state, obj2.last_hidden_state))
        self.assertTrue(same(obj1.hidden_states, obj2.hidden_states))
        self.assertTrue(same(obj1.attentions, obj2.attentions))
        self.assertEqual(cnts.frame_count, 1)
        # b + 3, a + 7, a + b, (a + b) + 6
        self.assertEqual(cnts.op_count, 4)

    def _common(self, fn, op_count):
        """Check that eager and compiled ``fn`` agree on a BaseModelOutput input.

        The single argument has ``last_hidden_state`` and ``attentions`` set;
        ``hidden_states`` is not passed. Asserts matching results, exactly one
        compiled frame, and ``op_count`` ops in the captured graph.
        """
        args = [
            BaseModelOutput(
                last_hidden_state=torch.randn(10), attentions=torch.randn(10)
            )
        ]
        obj1 = fn(*args)
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
        obj2 = opt_fn(*args)
        self.assertTrue(same(obj1, obj2))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, op_count)

    @maybe_skip
    def test_mo_getattr(self):
        # Attribute access with None checks on optional fields.
        def fn(obj: BaseModelOutput):
            x = obj.last_hidden_state * 10
            if obj.hidden_states is not None:
                x += obj.hidden_states
            if obj.attentions is not None:
                x += obj.attentions
            return x

        self._common(fn, 2)

    @maybe_skip
    def test_mo_getitem(self):
        # Item access with membership checks on optional fields.
        def fn(obj: BaseModelOutput):
            x = obj["last_hidden_state"] * 10
            # Key was previously misspelled "hidden_stats", so this branch
            # could never be exercised regardless of the input.
            if "hidden_states" in obj:
                x += obj["hidden_states"]
            if "attentions" in obj:
                x += obj["attentions"]
            return x

        # Same expected op count as test_mo_getattr: the _common input does not
        # set hidden_states, so only the mul and the attentions add remain.
        self._common(fn, 2)

    @maybe_skip
    def test_mo_tuple(self):
        # Unpack the populated fields via to_tuple().
        def fn(obj: BaseModelOutput):
            a, b = obj.to_tuple()
            return a + b * 10

        self._common(fn, 2)

    @maybe_skip
    def test_mo_index(self):
        # Positional integer indexing into the output.
        def fn(obj: BaseModelOutput):
            return obj[0] * 10 + obj[1]

        self._common(fn, 2)

    @maybe_skip
    def test_mo_init(self):
        @dataclasses.dataclass
        class MyDataClass(ModelOutput):
            a: torch.Tensor
            b: torch.Tensor = None
            c: torch.Tensor = None
            d: torch.Tensor = None
            e: torch.Tensor = None

        # Walk the dataclass fields and sum the non-None tensor fields.
        def fn(obj):
            class_fields = dataclasses.fields(obj)
            assert len(class_fields)
            assert all(field.default is None for field in class_fields[1:])
            other_fields_are_none = all(
                getattr(obj, field.name) is None for field in class_fields[1:]
            )
            assert not other_fields_are_none
            # `total` aliases the first field, so `+=` mutates it in place and
            # the returned tensor is that same field object.
            total = getattr(obj, class_fields[0].name)
            for field in class_fields[1:]:
                v = getattr(obj, field.name)
                if v is not None:
                    total += v
            return total

        tensors = [torch.randn(10), torch.randn(10), torch.randn(10)]
        # Give the eager and compiled runs independent copies. Sharing `tensors`
        # would make both calls mutate and return the very same tensor object,
        # so the comparison below would be trivially true.
        obj1 = MyDataClass(*[t.clone() for t in tensors])
        correct1 = fn(obj1)
        obj2 = MyDataClass(*[t.clone() for t in tensors])
        cnts = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnts)(fn)
        self.assertTrue(same(opt_fn(obj2), correct1))
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)
if __name__ == "__main__":
    # Run this file's tests through the dynamo test-case runner.
    from torch._dynamo.test_case import run_tests as _run_tests

    _run_tests()