| # Owner(s): ["module: dynamo"] |
| import torch |
| |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| from torch._dynamo import eval_frame |
| |
| c = 10 |
| |
| |
| def fn1(a, b): |
| return a + b - c |
| |
| |
| def fn2(a, b): |
| x = 0 |
| y = 1 |
| |
| def modify(): |
| nonlocal x |
| x += a + b + c |
| |
| for _ in range(2): |
| modify() |
| |
| return x + y |
| |
| |
| def fn3(): |
| yield 1 |
| yield 2 |
| |
| |
| with_debug_nops = eval_frame._optimize_catch_errors( |
| torch._dynamo.testing.debug_insert_nops |
| ) |
| |
| |
| class NopTests(torch._dynamo.test_case.TestCase): |
| @with_debug_nops |
| def test1(self): |
| self.assertEqual(fn1(1, 2), -7) |
| self.assertEqual(fn1(1, 2), -7) |
| |
| @with_debug_nops |
| def test2(self): |
| self.assertEqual(fn2(1, 2), 27) |
| self.assertEqual(fn2(1, 2), 27) |
| |
| @with_debug_nops |
| def test3(self): |
| t = fn3() |
| self.assertEqual(next(t), 1) |
| self.assertEqual(next(t), 2) |
| self.assertRaises(StopIteration, lambda: next(t)) |
| |
| def test_extended_args(self): |
| too_many_adds = "+".join(["a", "b"] * 256) |
| source = ( |
| f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)" |
| ) |
| fn = eval(source) |
| a = torch.ones(1) |
| b = torch.ones(1) |
| fn = with_debug_nops(fn) |
| self.assertEqual(fn(a, b).sum(), 513) |
| |
| |
| if __name__ == "__main__": |
| from torch._dynamo.test_case import run_tests |
| |
| run_tests() |