| # Owner(s): ["module: dynamo"] |
| import torch |
| import torch._dynamo |
| import torch._dynamo.test_case |
| import torch._dynamo.testing |
| from torch._dynamo.testing import CompileCounter |
| |
| |
| class TestInputAttrTracking(torch._dynamo.test_case.TestCase): |
| def test_tensor_property_on_tensor(self): |
| def fn(x): |
| return x * x.y |
| |
| x_ = torch.randn([2, 2]) |
| y_ = torch.randn([2, 2]) |
| x_.y = y_ |
| |
| eager_result = fn(x_) |
| |
| graph = None |
| |
| def grab_graph_backend(gm, inps): |
| nonlocal graph |
| graph = gm |
| return gm |
| |
| fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn) |
| compile_result = fn(x_) |
| self.assertEqual(eager_result, compile_result) |
| |
| placeholder_cnt = 0 |
| for node in graph.graph.nodes: |
| if node.op == "placeholder": |
| placeholder_cnt += 1 |
| |
| # We want to be very sure that this lifts y to inputs! |
| self.assertEqual(placeholder_cnt, 2) |
| |
| def test_tensor_property_assigned_on_tensor(self): |
| def fn(x, y): |
| x.y = y |
| return x * x.y |
| |
| x_ = torch.randn([2, 2]) |
| y_ = torch.randn([2, 2]) |
| |
| eager_result = fn(x_, y_) |
| |
| graph = None |
| |
| def grab_graph_backend(gm, inps): |
| nonlocal graph |
| graph = gm |
| return gm |
| |
| fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn) |
| compile_result = fn(x_, y_) |
| self.assertEqual(eager_result, compile_result) |
| |
| placeholder_cnt = 0 |
| for node in graph.graph.nodes: |
| if node.op == "placeholder": |
| placeholder_cnt += 1 |
| |
| # y is already an input |
| self.assertEqual(placeholder_cnt, 2) |
| |
| def test_const_property_on_tensor(self): |
| def fn(x): |
| return x * x.y |
| |
| x_ = torch.randn([2, 2]) |
| y_ = 4 |
| x_.y = y_ |
| |
| eager_result = fn(x_) |
| |
| graph = None |
| |
| def grab_graph_backend(gm, inps): |
| nonlocal graph |
| graph = gm |
| return gm |
| |
| fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn) |
| compile_result = fn(x_) |
| self.assertEqual(eager_result, compile_result) |
| |
| placeholder_cnt = 0 |
| for node in graph.graph.nodes: |
| if node.op == "placeholder": |
| placeholder_cnt += 1 |
| |
| # We want to be very sure that this does not lifts y to inputs, as its a const |
| self.assertEqual(placeholder_cnt, 1) |
| |
| def test_const_property_assigned_on_tensor(self): |
| def fn(x, y): |
| x.y = y |
| return x * x.y |
| |
| x_ = torch.randn([2, 2]) |
| y_ = 4 |
| |
| eager_result = fn(x_, y_) |
| |
| fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| compile_result = fn(x_, y_) |
| self.assertEqual(eager_result, compile_result) |
| |
| def test_guards_correctly_property_assigned_on_tensor_type_change(self): |
| def fn(x, y): |
| x.y = y |
| return x * x.y |
| |
| x_ = torch.randn([2, 2]) |
| |
| fn = torch._dynamo.optimize("eager", nopython=True)(fn) |
| compile_result_const = fn(x_, 4) |
| self.assertEqual(compile_result_const, x_ * 4) |
| |
| y = torch.randn([2, 2]) |
| compile_result_tensor = fn(x_, y) |
| self.assertEqual(compile_result_tensor, x_ * y) |
| |
| def test_guards_correctly_property_assigned_on_tensor_type_change_inductor(self): |
| def fn(x, y): |
| x.y = y |
| return x * x.y |
| |
| x_ = torch.randn([2, 2]) |
| |
| fn = torch._dynamo.optimize("inductor", nopython=True)(fn) |
| compile_result_const = fn(x_, 4) |
| self.assertEqual(compile_result_const, x_ * 4) |
| |
| y = torch.randn([2, 2]) |
| compile_result_tensor = fn(x_, y) |
| self.assertEqual(compile_result_tensor, x_ * y) |
| |
    def test_complex_attr_access_without_graph_breaks(self):
        # Attributes written onto every tensor in a list and then read back in
        # the same frame should compile as a single graph under nopython=True.
        def fn(x, y, z):
            # Attach y and the derived value y * z to each tensor in x.
            for t in x:
                t.y = y
                t.z = y * z

            new_y = 1
            new_z = 1
            # Read the attributes back, folding them into running products.
            for t in x:
                new_y = t.y * new_y
                new_z = t.z * new_z

            return new_y, new_z

        x_0 = torch.randn([2, 2])
        x_1 = torch.randn([2, 2])
        x_2 = torch.randn([2, 2])
        x = [x_0, x_1, x_2]

        y = torch.randn([2, 2])
        # z is a Python int; the reference graph below shows it as the constant 5.
        z = 5

        # Note: the eager call already sets .y/.z on x's elements; the compiled
        # call below assigns them again.
        eager_result = fn(x, y, z)

        counter = CompileCounter()
        fn = torch._dynamo.optimize(counter, nopython=True)(fn)

        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        # 9 ops = three `y * z` products + three multiplies per running product.
        self.assertEqual(counter.op_count, 9)
        # Graph for reference
        # (the t.z values mul, mul_1, mul_2 are also graph outputs — presumably
        # so the attribute writes can be applied after the graph runs; confirm)
        # ------------- ------ ----------------------- ------------------------------------ --------
        # placeholder l_y_ L_y_ () {}
        # call_function mul <built-in function mul> (l_y_, 5) {}
        # call_function mul_1 <built-in function mul> (l_y_, 5) {}
        # call_function mul_2 <built-in function mul> (l_y_, 5) {}
        # call_function mul_3 <built-in function mul> (l_y_, 1) {}
        # call_function mul_4 <built-in function mul> (mul, 1) {}
        # call_function mul_5 <built-in function mul> (l_y_, mul_3) {}
        # call_function mul_6 <built-in function mul> (mul_1, mul_4) {}
        # call_function mul_7 <built-in function mul> (l_y_, mul_5) {}
        # call_function mul_8 <built-in function mul> (mul_2, mul_6) {}
        # output output output ((mul_7, mul_8, mul, mul_1, mul_2),) {}
| |
    def test_complex_attr_access_with_graph_breaks(self):
        # Same attribute write/read pattern, but split by a print() so that the
        # attributes set before the break must be fed into the graph after it.
        def fn(x, y, z):
            # Attach y and the derived value y * z to each tensor in x.
            for t in x:
                t.y = y
                t.z = y * z

            # Deliberate graph break (see the "[GRAPH BREAK!]" reference below).
            print("Break!")

            new_y = 1
            new_z = 1
            # Read the attributes back, folding them into running products.
            for t in x:
                new_y = t.y * new_y
                new_z = t.z * new_z

            return new_y, new_z

        x_0 = torch.randn([2, 2])
        x_1 = torch.randn([2, 2])
        x_2 = torch.randn([2, 2])
        x = [x_0, x_1, x_2]

        y = torch.randn([2, 2])
        z = 5

        eager_result = fn(x, y, z)

        counter = CompileCounter()
        # nopython=False so the print() can break the graph instead of erroring.
        fn = torch._dynamo.optimize(counter, nopython=False)(fn)

        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        # Two graphs: 3 ops before the break + 6 ops after it = 9.
        self.assertEqual(counter.frame_count, 2)
        self.assertEqual(counter.op_count, 9)
        # Graph for reference
        # ------------- ------ ----------------------- ---------------------- --------
        # placeholder l_y_ L_y_ () {}
        # call_function mul <built-in function mul> (l_y_, 5) {}
        # call_function mul_1 <built-in function mul> (l_y_, 5) {}
        # call_function mul_2 <built-in function mul> (l_y_, 5) {}
        # output output output ((mul, mul_1, mul_2),) {}
        # [GRAPH BREAK!]
        # ------------- ------- ----------------------- ----------------- --------
        # placeholder l_x_0_y L_x_0_y () {}
        # placeholder l_x_0_z L_x_0_z () {}
        # placeholder l_x_1_y L_x_1_y () {}
        # placeholder l_x_1_z L_x_1_z () {}
        # placeholder l_x_2_y L_x_2_y () {}
        # placeholder l_x_2_z L_x_2_z () {}
        # call_function mul <built-in function mul> (l_x_0_y, 1) {}
        # call_function mul_1 <built-in function mul> (l_x_0_z, 1) {}
        # call_function mul_2 <built-in function mul> (l_x_1_y, mul) {}
        # call_function mul_3 <built-in function mul> (l_x_1_z, mul_1) {}
        # call_function mul_4 <built-in function mul> (l_x_2_y, mul_2) {}
        # call_function mul_5 <built-in function mul> (l_x_2_z, mul_3) {}
        # output output output ((mul_4, mul_5),) {}
| |
    def test_complex_attr_access_with_inline_reconstruct(self):
        # Int attributes set on tensors in the outer frame are read inside a
        # called helper that contains a print(); the helper's result must still
        # be usable as a constant multiplier in the compiled graph.
        def inline_test_fn(x, y, z):
            # print() is not traceable by Dynamo.
            print("f")
            return x.a + y.a + z.a

        def fn(x, y, z):
            x.a = 1
            y.a = 2
            z.a = 3

            # 1 + 2 + 3 == 6, the constant seen in the reference graph below.
            mult = inline_test_fn(x, y, z)
            y = y * mult
            x = x * mult
            return x, y

        x = torch.randn([2, 2])
        y = torch.randn([2, 2])
        z = torch.randn([2, 2])

        eager_result = fn(x, y, z)

        counter = CompileCounter()

        fn = torch._dynamo.optimize(counter, nopython=False)(fn)

        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        # Only one graph reaches the backend, containing the two multiplies.
        # NOTE(review): presumably the code before the print has no tensor ops
        # and so produces no graph — confirm against Dynamo's break handling.
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 2)
        # Graph for reference
        # __compiled_fn_2 <eval_with_key>.0 opcode name target args kwargs
        # ------------- ------ ----------------------- --------------- --------
        # placeholder l_x_ L_x_ () {}
        # placeholder l_y_ L_y_ () {}
        # call_function mul <built-in function mul> (l_y_, 6) {}
        # call_function mul_1 <built-in function mul> (l_x_, 6) {}
        # output output output ((mul_1, mul),) {}
| |
    def test_set_data_on_input_tensor(self):
        # Assigning ``.data`` of one input tensor from another must be traced in
        # a single graph, and later shape checks must see the new data.
        def fn(x, y):
            x.data = y.data
            # x now carries y's data, so this comparison is True and x * y is returned.
            if x.size() == y.size():
                return x * y
            else:
                return y * y

        x = torch.randn([5, 5])
        y = torch.randn([2, 2])

        # The eager call already replaces x's data with y's, so the compiled
        # call below starts from an x that is no longer 5x5.
        eager_result = fn(x, y)

        counter = CompileCounter()

        fn = torch._dynamo.optimize(counter, nopython=True)(fn)

        compile_result = fn(x, y)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        # 2 ops: the detach recorded for y.data, plus the multiply.
        self.assertEqual(counter.op_count, 2)
        # Graph for reference
        # __compiled_fn_0 <eval_with_key>.0 opcode name target args kwargs
        # ------------- ------ ----------------------- ------------ --------
        # placeholder l_x_ L_x_ () {}
        # placeholder l_y_ L_y_ () {}
        # call_method detach detach (l_y_,) {}
        # call_function mul <built-in function mul> (l_x_, l_y_) {}
        # output output output ((mul,),) {}
| |
    # Note - setting .data on a tensor created inside the compiled function does not
    # actually get captured in the graph yet. The plan of record is to introduce a
    # set_data op, entirely subsume the operation into a call_function in the fx graph,
    # and let aot_autograd handle it.
    def test_set_data_on_scoped_tensor(self):
        # ``.data`` is assigned on a tensor created inside the compiled
        # function (not an input); results must still match eager.
        def fn(x):
            z = torch.zeros([4, 4])
            z.data = x.data
            # z now carries x's 5x5 data, so the sizes match and z * x is returned.
            if x.size() == z.size():
                return z * x
            else:
                return x

        x = torch.randn([5, 5])

        eager_result = fn(x)

        counter = CompileCounter()

        # nopython=False: this pattern is expected to split into two frames.
        fn = torch._dynamo.optimize(counter, nopython=False)(fn)

        compile_result = fn(x)
        self.assertEqual(compile_result, eager_result)
        # NOTE(review): how the 3 ops are split across the 2 frames is not shown here.
        self.assertEqual(counter.frame_count, 2)
        self.assertEqual(counter.op_count, 3)
| |
| def test_set_data_on_user_defined_class_input_tensor(self): |
| class MyUserDefinedClass: |
| def __init__(self, x, y): |
| self.x = x |
| self.y = y |
| |
| def do_some_setattr_stuff(self): |
| self.z = x * y |
| self.a = x + x |
| return self.z * self.a |
| |
| x = torch.randn([5, 5]) |
| y = torch.randn([5, 5]) |
| mudc_1 = MyUserDefinedClass(x, y) |
| |
| eager_result = mudc_1.do_some_setattr_stuff() |
| |
| counter = CompileCounter() |
| |
| mudc_2 = MyUserDefinedClass(x, y) |
| do_some_setattr_stuff = torch._dynamo.optimize(counter, nopython=True)( |
| mudc_2.do_some_setattr_stuff |
| ) |
| |
| compile_result = do_some_setattr_stuff() |
| self.assertEqual(compile_result, eager_result) |
| self.assertEqual(counter.frame_count, 1) |
| self.assertEqual(counter.op_count, 3) |
| # Graph for reference |
| # __compiled_fn_0 <eval_with_key>.0 opcode name target args kwargs |
| # ------------- ------ ----------------------- -------------------- -------- |
| # placeholder l_x_ L_x_ () {} |
| # placeholder l_y_ L_y_ () {} |
| # call_function mul <built-in function mul> (l_x_, l_y_) {} |
| # call_function add <built-in function add> (l_x_, l_x_) {} |
| # call_function mul_1 <built-in function mul> (mul, add) {} |
| # output output output ((mul_1, mul, add),) {} |