| # Owner(s): ["module: dynamo"] |
| |
| import torch |
| import torch._dynamo.test_case |
| |
| |
def target_with_varkwargs(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs):
    # Target function whose code object the test replaces with
    # varkwargs_code1 / varkwargs_code2. Apart from its parameters it has
    # exactly one local (``local``); the replacements differ from it only in
    # how many locals they use, so keep this body's local layout unchanged.
    # NOTE: only ``arg1`` is positional-only (it precedes ``/``); despite its
    # name, ``positional_only_arg`` is positional-or-keyword.
    # Returns a dict echoing the local value and every argument so the test
    # can compare results of the original and substituted code.
    local = 1
    return {
        "local": local,
        "arg1": arg1,
        "positional_only_arg": positional_only_arg,
        "keyword_only_arg": keyword_only_arg,
        "kwargs": kwargs,
    }
| |
| |
def varkwargs_code1(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs):
    # Replacement code for target_with_varkwargs with one FEWER local: the
    # target's ``local = 1`` is removed and the constant 1 is inlined, so the
    # returned dict is identical to the target's for the same arguments.
    # The signature must match the target's exactly.
    return {
        "local": 1,
        "arg1": arg1,
        "positional_only_arg": positional_only_arg,
        "keyword_only_arg": keyword_only_arg,
        "kwargs": kwargs,
    }
| |
| |
def varkwargs_code2(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs):
    # Replacement code for target_with_varkwargs with one MORE local: the
    # target's single ``local`` is split into ``local1`` and ``local2``, whose
    # sum (0 + 1) reproduces the target's value, so the returned dict is
    # identical to the target's for the same arguments.
    # The signature must match the target's exactly.
    local1 = 0
    local2 = 1
    return {
        "local": local1 + local2,
        "arg1": arg1,
        "positional_only_arg": positional_only_arg,
        "keyword_only_arg": keyword_only_arg,
        "kwargs": kwargs,
    }
| |
| |
def target_with_varargs(arg1, /, positional_only_arg, *varargs, **kwargs):
    # Target function whose code object the test replaces with
    # varargs_code1 / varargs_code2. Apart from its parameters (including the
    # ``*varargs`` tuple and ``**kwargs`` dict) it has exactly one local
    # (``local``); keep this body's local layout unchanged.
    # NOTE: only ``arg1`` is positional-only (it precedes ``/``); despite its
    # name, ``positional_only_arg`` is positional-or-keyword.
    # Returns a dict echoing the local value and every argument so the test
    # can compare results of the original and substituted code.
    local = 1
    return {
        "local": local,
        "arg1": arg1,
        "positional_only_arg": positional_only_arg,
        "varargs": varargs,
        "kwargs": kwargs,
    }
| |
| |
def varargs_code1(arg1, /, positional_only_arg, *varargs, **kwargs):
    # Replacement code for target_with_varargs with one FEWER local: the
    # target's ``local = 1`` is removed and the constant 1 is inlined, so the
    # returned dict is identical to the target's for the same arguments.
    # The signature must match the target's exactly.
    return {
        "local": 1,
        "arg1": arg1,
        "positional_only_arg": positional_only_arg,
        "varargs": varargs,
        "kwargs": kwargs,
    }
| |
| |
def varargs_code2(arg1, /, positional_only_arg, *varargs, **kwargs):
    # Replacement code for target_with_varargs with one MORE local: the
    # target's single ``local`` is split into ``local1`` and ``local2``, whose
    # sum (0 + 1) reproduces the target's value, so the returned dict is
    # identical to the target's for the same arguments.
    # The signature must match the target's exactly.
    local1 = 0
    local2 = 1
    return {
        "local": local1 + local2,
        "arg1": arg1,
        "positional_only_arg": positional_only_arg,
        "varargs": varargs,
        "kwargs": kwargs,
    }
| |
| |
class FrameInitTests(torch._dynamo.test_case.TestCase):
    """Frame-initialization tests for code substituted by an eval-frame callback."""

    def test_frame_init(self):
        """Swap each target's code object and check the results are unchanged.

        ``code_map1`` maps each target to a variant with one fewer local, and
        ``code_map2`` maps it to a variant with one more local. Under each
        mapping, calling the targets must return exactly what the unmodified
        functions return for the same arguments.
        """
        code_map1 = {
            target_with_varargs.__code__: varargs_code1.__code__,
            target_with_varkwargs.__code__: varkwargs_code1.__code__,
        }
        code_map2 = {
            target_with_varargs.__code__: varargs_code2.__code__,
            target_with_varkwargs.__code__: varkwargs_code2.__code__,
        }

        def make_callback(code_map):
            # Build an eval-frame callback that substitutes the mapped code for
            # any frame whose code is a key of ``code_map``; for every other
            # frame it returns None.
            def callback(frame, cache_entry, frame_state):
                transformed_code = code_map.get(frame.f_code)
                if transformed_code is None:
                    return None
                # The check function unconditionally returns True, i.e. the
                # guard never rejects the transformed code.
                return torch._dynamo.types.GuardedCode(
                    transformed_code, lambda f_locals: True
                )

            return callback

        for code_map in (code_map1, code_map2):
            torch._dynamo.reset()
            # Reference results, computed before any callback is installed.
            expected_varargs_output = target_with_varargs(
                1, 2, 3, 4, name1=1, name2=2, name3=3
            )
            expected_kwargs_output = target_with_varkwargs(
                1, 2, keyword_only_arg=1, name2=2, name3=3
            )
            # Keep a local reference to the callback while it is installed.
            callback = make_callback(code_map)
            original = torch._dynamo.eval_frame.set_eval_frame(callback)
            try:
                real_varargs_output = target_with_varargs(
                    1, 2, 3, 4, name1=1, name2=2, name3=3
                )
                real_kwargs_output = target_with_varkwargs(
                    1, 2, keyword_only_arg=1, name2=2, name3=3
                )
            finally:
                # Always restore the previous callback so a failure here cannot
                # leak the custom eval frame into later tests.
                torch._dynamo.eval_frame.set_eval_frame(original)
            self.assertEqual(real_varargs_output, expected_varargs_output)
            self.assertEqual(real_kwargs_output, expected_kwargs_output)
| |
| |
if __name__ == "__main__":
    # Run this module's tests through the Dynamo test-case runner, reached via
    # the torch._dynamo.test_case module already imported at the top of the file.
    torch._dynamo.test_case.run_tests()