| # Owner(s): ["module: dynamo"] |
| import unittest.mock |
| |
| import torch |
| |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| from torch._dynamo.testing import same |
| |
| try: |
| from diffusers.models import unet_2d |
| except ImportError: |
| unet_2d = None |
| |
| |
| def maybe_skip(fn): |
| if unet_2d is None: |
| return unittest.skip("requires diffusers")(fn) |
| return fn |
| |
| |
| class TestBaseOutput(torch._dynamo.test_case.TestCase): |
| @maybe_skip |
| def test_create(self): |
| def fn(a): |
| tmp = unet_2d.UNet2DOutput(a + 1) |
| return tmp |
| |
| torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=1) |
| |
| @maybe_skip |
| def test_assign(self): |
| def fn(a): |
| tmp = unet_2d.UNet2DOutput(a + 1) |
| tmp.sample = a + 2 |
| return tmp |
| |
| args = [torch.randn(10)] |
| obj1 = fn(*args) |
| |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize_assert(cnts)(fn) |
| obj2 = opt_fn(*args) |
| self.assertTrue(same(obj1.sample, obj2.sample)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, 2) |
| |
| def _common(self, fn, op_count): |
| args = [ |
| unet_2d.UNet2DOutput( |
| sample=torch.randn(10), |
| ) |
| ] |
| obj1 = fn(*args) |
| cnts = torch._dynamo.testing.CompileCounter() |
| opt_fn = torch._dynamo.optimize_assert(cnts)(fn) |
| obj2 = opt_fn(*args) |
| self.assertTrue(same(obj1, obj2)) |
| self.assertEqual(cnts.frame_count, 1) |
| self.assertEqual(cnts.op_count, op_count) |
| |
| @maybe_skip |
| def test_getattr(self): |
| def fn(obj: unet_2d.UNet2DOutput): |
| x = obj.sample * 10 |
| return x |
| |
| self._common(fn, 1) |
| |
| @maybe_skip |
| def test_getitem(self): |
| def fn(obj: unet_2d.UNet2DOutput): |
| x = obj["sample"] * 10 |
| return x |
| |
| self._common(fn, 1) |
| |
| @maybe_skip |
| def test_tuple(self): |
| def fn(obj: unet_2d.UNet2DOutput): |
| a = obj.to_tuple() |
| return a[0] * 10 |
| |
| self._common(fn, 1) |
| |
| @maybe_skip |
| def test_index(self): |
| def fn(obj: unet_2d.UNet2DOutput): |
| return obj[0] * 10 |
| |
| self._common(fn, 1) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |