# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import CompileCounter


class TestInputAttrTracking(torch._dynamo.test_case.TestCase):
    def test_tensor_property_on_tensor(self):
        def fn(x):
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = torch.randn([2, 2])
        x_.y = y_

        eager_result = fn(x_)

        graph = None

        # A pass-through backend that stashes the captured fx.GraphModule so
        # the test can inspect its nodes.
        def grab_graph_backend(gm, inps):
            nonlocal graph
            graph = gm
            return gm

        fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn)
        compile_result = fn(x_)
        self.assertEqual(eager_result, compile_result)

        placeholder_cnt = 0
        for node in graph.graph.nodes:
            if node.op == "placeholder":
                placeholder_cnt += 1

        # We want to be very sure that this lifts y to inputs!
        self.assertEqual(placeholder_cnt, 2)
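
        # For quick inspection while developing, the captured module can be
        # dumped directly (a usage sketch, not part of the test's assertions):
        #
        #   graph.graph.print_tabular()
        #   print(graph.code)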

    def test_tensor_property_assigned_on_tensor(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = torch.randn([2, 2])

        eager_result = fn(x_, y_)

        graph = None

        def grab_graph_backend(gm, inps):
            nonlocal graph
            graph = gm
            return gm

        fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn)
        compile_result = fn(x_, y_)
        self.assertEqual(eager_result, compile_result)

        placeholder_cnt = 0
        for node in graph.graph.nodes:
            if node.op == "placeholder":
                placeholder_cnt += 1

        # y is already an input
        self.assertEqual(placeholder_cnt, 2)

    def test_const_property_on_tensor(self):
        def fn(x):
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = 4
        x_.y = y_

        eager_result = fn(x_)

        graph = None

        def grab_graph_backend(gm, inps):
            nonlocal graph
            graph = gm
            return gm

        fn = torch._dynamo.optimize(grab_graph_backend, nopython=True)(fn)
        compile_result = fn(x_)
        self.assertEqual(eager_result, compile_result)

        placeholder_cnt = 0
        for node in graph.graph.nodes:
            if node.op == "placeholder":
                placeholder_cnt += 1

        # We want to be very sure that this does not lift y to inputs,
        # as it's a const.
        self.assertEqual(placeholder_cnt, 1)

    def test_const_property_assigned_on_tensor(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])
        y_ = 4

        eager_result = fn(x_, y_)

        fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        compile_result = fn(x_, y_)
        self.assertEqual(eager_result, compile_result)

    def test_guards_correctly_property_assigned_on_tensor_type_change(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])

        fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        compile_result_const = fn(x_, 4)
        self.assertEqual(compile_result_const, x_ * 4)

        y = torch.randn([2, 2])
        compile_result_tensor = fn(x_, y)
        self.assertEqual(compile_result_tensor, x_ * y)
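
        # A minimal sketch (not part of the original test) of how the
        # type-change recompile could be observed with the CompileCounter
        # backend imported above:
        #
        #   counter = CompileCounter()
        #   fn = torch._dynamo.optimize(counter)(fn)
        #   fn(x_, 4)                  # first compile: y guarded as an int const
        #   fn(x_, torch.randn(2, 2))  # guard on y's type fails -> recompile
        #   assert counter.frame_count == 2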

    def test_guards_correctly_property_assigned_on_tensor_type_change_inductor(self):
        def fn(x, y):
            x.y = y
            return x * x.y

        x_ = torch.randn([2, 2])

        fn = torch._dynamo.optimize("inductor", nopython=True)(fn)
        compile_result_const = fn(x_, 4)
        self.assertEqual(compile_result_const, x_ * 4)

        y = torch.randn([2, 2])
        compile_result_tensor = fn(x_, y)
        self.assertEqual(compile_result_tensor, x_ * y)

    def test_complex_attr_access_without_graph_breaks(self):
        def fn(x, y, z):
            for t in x:
                t.y = y
                t.z = y * z

            new_y = 1
            new_z = 1
            for t in x:
                new_y = t.y * new_y
                new_z = t.z * new_z

            return new_y, new_z

        x_0 = torch.randn([2, 2])
        x_1 = torch.randn([2, 2])
        x_2 = torch.randn([2, 2])
        x = [x_0, x_1, x_2]

        y = torch.randn([2, 2])
        z = 5

        eager_result = fn(x, y, z)

        counter = CompileCounter()
        fn = torch._dynamo.optimize(counter, nopython=True)(fn)
        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 9)
        # Graph for reference
        # opcode         name    target                   args                                  kwargs
        # -------------  ------  -----------------------  ------------------------------------  --------
        # placeholder    l_y_    L_y_                     ()                                    {}
        # call_function  mul     <built-in function mul>  (l_y_, 5)                             {}
        # call_function  mul_1   <built-in function mul>  (l_y_, 5)                             {}
        # call_function  mul_2   <built-in function mul>  (l_y_, 5)                             {}
        # call_function  mul_3   <built-in function mul>  (l_y_, 1)                             {}
        # call_function  mul_4   <built-in function mul>  (mul, 1)                              {}
        # call_function  mul_5   <built-in function mul>  (l_y_, mul_3)                         {}
        # call_function  mul_6   <built-in function mul>  (mul_1, mul_4)                        {}
        # call_function  mul_7   <built-in function mul>  (l_y_, mul_5)                         {}
        # call_function  mul_8   <built-in function mul>  (mul_2, mul_6)                        {}
        # output         output  output                   ((mul_7, mul_8, mul, mul_1, mul_2),)  {}

    def test_complex_attr_access_with_graph_breaks(self):
        def fn(x, y, z):
            for t in x:
                t.y = y
                t.z = y * z

            print("Break!")

            new_y = 1
            new_z = 1
            for t in x:
                new_y = t.y * new_y
                new_z = t.z * new_z

            return new_y, new_z

        x_0 = torch.randn([2, 2])
        x_1 = torch.randn([2, 2])
        x_2 = torch.randn([2, 2])
        x = [x_0, x_1, x_2]

        y = torch.randn([2, 2])
        z = 5

        eager_result = fn(x, y, z)

        counter = CompileCounter()
        fn = torch._dynamo.optimize(counter, nopython=False)(fn)
        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 2)
        self.assertEqual(counter.op_count, 9)
        # Graph for reference
        # opcode         name    target                   args                    kwargs
        # -------------  ------  -----------------------  ----------------------  --------
        # placeholder    l_y_    L_y_                     ()                      {}
        # call_function  mul     <built-in function mul>  (l_y_, 5)               {}
        # call_function  mul_1   <built-in function mul>  (l_y_, 5)               {}
        # call_function  mul_2   <built-in function mul>  (l_y_, 5)               {}
        # output         output  output                   ((mul, mul_1, mul_2),)  {}
        # [GRAPH BREAK!]
        # opcode         name     target                   args               kwargs
        # -------------  -------  -----------------------  -----------------  --------
        # placeholder    l_x_0_y  L_x_0_y                  ()                 {}
        # placeholder    l_x_0_z  L_x_0_z                  ()                 {}
        # placeholder    l_x_1_y  L_x_1_y                  ()                 {}
        # placeholder    l_x_1_z  L_x_1_z                  ()                 {}
        # placeholder    l_x_2_y  L_x_2_y                  ()                 {}
        # placeholder    l_x_2_z  L_x_2_z                  ()                 {}
        # call_function  mul      <built-in function mul>  (l_x_0_y, 1)       {}
        # call_function  mul_1    <built-in function mul>  (l_x_0_z, 1)       {}
        # call_function  mul_2    <built-in function mul>  (l_x_1_y, mul)     {}
        # call_function  mul_3    <built-in function mul>  (l_x_1_z, mul_1)   {}
        # call_function  mul_4    <built-in function mul>  (l_x_2_y, mul_2)   {}
        # call_function  mul_5    <built-in function mul>  (l_x_2_z, mul_3)   {}
        # output         output   output                   ((mul_4, mul_5),)  {}
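
        # To pinpoint where the break happens, torch._dynamo.explain can be
        # used (a sketch; explain's exact calling convention and return type
        # have varied across PyTorch releases):
        #
        #   explanation = torch._dynamo.explain(fn)(x, y, z)
        #   print(explanation)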

    def test_complex_attr_access_with_inline_reconstruct(self):
        def inline_test_fn(x, y, z):
            print("f")
            return x.a + y.a + z.a

        def fn(x, y, z):
            x.a = 1
            y.a = 2
            z.a = 3

            mult = inline_test_fn(x, y, z)
            y = y * mult
            x = x * mult

            return x, y

        x = torch.randn([2, 2])
        y = torch.randn([2, 2])
        z = torch.randn([2, 2])

        eager_result = fn(x, y, z)

        counter = CompileCounter()
        fn = torch._dynamo.optimize(counter, nopython=False)(fn)
        compile_result = fn(x, y, z)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 2)
        # Graph for reference
        # __compiled_fn_2 <eval_with_key>.0
        # opcode         name    target                   args             kwargs
        # -------------  ------  -----------------------  ---------------  --------
        # placeholder    l_x_    L_x_                     ()               {}
        # placeholder    l_y_    L_y_                     ()               {}
        # call_function  mul     <built-in function mul>  (l_y_, 6)        {}
        # call_function  mul_1   <built-in function mul>  (l_x_, 6)        {}
        # output         output  output                   ((mul_1, mul),)  {}

    def test_set_data_on_input_tensor(self):
        def fn(x, y):
            x.data = y.data
            if x.size() == y.size():
                return x * y
            else:
                return y * y

        x = torch.randn([5, 5])
        y = torch.randn([2, 2])

        eager_result = fn(x, y)

        counter = CompileCounter()
        fn = torch._dynamo.optimize(counter, nopython=True)(fn)
        compile_result = fn(x, y)
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 2)
        # Graph for reference
        # __compiled_fn_0 <eval_with_key>.0
        # opcode         name    target                   args          kwargs
        # -------------  ------  -----------------------  ------------  --------
        # placeholder    l_x_    L_x_                     ()            {}
        # placeholder    l_y_    L_y_                     ()            {}
        # call_method    detach  detach                   (l_y_,)       {}
        # call_function  mul     <built-in function mul>  (l_x_, l_y_)  {}
        # output         output  output                   ((mul,),)     {}

        # Note - the set_data operation does not actually get captured in the
        # graph yet. The plan of record is to introduce a set_data op, entirely
        # subsume the operation into a call_function in the fx graph, and let
        # aot_autograd handle it.
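
        # For reference, the eager semantics of `x.data = y.data` are roughly
        # equivalent to swapping in a detached view of y without autograd
        # tracking (a sketch, not the exact autograd-internal implementation):
        #
        #   with torch.no_grad():
        #       x.set_(y.detach())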

    def test_set_data_on_scoped_tensor(self):
        def fn(x):
            z = torch.zeros([4, 4])
            z.data = x.data
            if x.size() == z.size():
                return z * x
            else:
                return x

        x = torch.randn([5, 5])

        eager_result = fn(x)

        counter = CompileCounter()
        fn = torch._dynamo.optimize(counter, nopython=False)(fn)
        compile_result = fn(x)
        self.assertEqual(compile_result, eager_result)
        # Setting .data on a tensor created inside the compiled region causes
        # a graph break, so the function compiles as two frames.
        self.assertEqual(counter.frame_count, 2)
        self.assertEqual(counter.op_count, 3)

    def test_set_data_on_user_defined_class_input_tensor(self):
        class MyUserDefinedClass:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            def do_some_setattr_stuff(self):
                # Note: x and y here resolve to the test-local tensors defined
                # below via closure, not to self.x / self.y.
                self.z = x * y
                self.a = x + x
                return self.z * self.a

        x = torch.randn([5, 5])
        y = torch.randn([5, 5])

        mudc_1 = MyUserDefinedClass(x, y)

        eager_result = mudc_1.do_some_setattr_stuff()

        counter = CompileCounter()

        mudc_2 = MyUserDefinedClass(x, y)
        do_some_setattr_stuff = torch._dynamo.optimize(counter, nopython=True)(
            mudc_2.do_some_setattr_stuff
        )

        compile_result = do_some_setattr_stuff()
        self.assertEqual(compile_result, eager_result)
        self.assertEqual(counter.frame_count, 1)
        self.assertEqual(counter.op_count, 3)
        # Graph for reference
        # __compiled_fn_0 <eval_with_key>.0
        # opcode         name    target                   args                  kwargs
        # -------------  ------  -----------------------  --------------------  --------
        # placeholder    l_x_    L_x_                     ()                    {}
        # placeholder    l_y_    L_y_                     ()                    {}
        # call_function  mul     <built-in function mul>  (l_x_, l_y_)          {}
        # call_function  add     <built-in function add>  (l_x_, l_x_)          {}
        # call_function  mul_1   <built-in function mul>  (mul, add)            {}
        # output         output  output                   ((mul_1, mul, add),)  {}
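

# Standard entry point for files in the PyTorch dynamo test suite; assumes the
# run_tests helper from torch._dynamo.test_case, which the imports above use.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()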