Reorganize test_proxy_tensor.py per tracing mode (#82739)
I intend to run all of the proxy tensor tests on each of our tracing
modes, but to do this I first have to parametrize the tests on
tracing mode. This PR does that refactor without adding any new
tests.
Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82739
Approved by: https://github.com/eellison
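The refactor hoists the mode into a `tracing_mode` class attribute and
threads `self.tracing_mode` through every `make_fx` call, while tests
that only make sense under a single mode move into dedicated
TestRealProxyTensor / TestFakeProxyTensor classes. As a minimal sketch
of the pattern this enables (the class names below are illustrative,
not part of this PR; it assumes only the `tracing_mode` keyword of
`make_fx` already used in the diff), a mode variant can be produced by
subclassing and overriding the single attribute:

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch.testing._internal.common_utils import TestCase, run_tests

    class TestProxyTensorSketch(TestCase):  # hypothetical name
        tracing_mode = "real"  # default: trace with real tensors

        def test_cos(self):
            def f(x):
                return x.cos()
            # The shared test body threads the class attribute into
            # make_fx, so each subclass reruns it under its own mode.
            traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
            x = torch.randn(3)
            self.assertEqual(traced(x), f(x))

    class TestProxyTensorSketchFake(TestProxyTensorSketch):  # hypothetical
        tracing_mode = "fake"  # rerun the whole suite with fake tensors

    if __name__ == "__main__":
        run_tests()

Overriding one attribute reports each mode's failures under a separate
class without duplicating any test bodies.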
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 263a3b1..19fdc8c 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -120,8 +120,10 @@
return torch.rand_like(x)
class TestProxyTensor(TestCase):
+ tracing_mode = "real"
+
def _test(self, f, inps):
- fx_f = make_fx(f)(*inps)
+ fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
new_inps = tree_map(_create_new_input, inps)
self.assertEqual(fx_f(*new_inps), f(*new_inps))
@@ -279,7 +281,7 @@
return x + torch.randn(x.shape)
# default behavior should trace factory functions
- traced = make_fx(f)(torch.randn(3))
+ traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
self.assertTrue(
any(
node.target == aten.randn.default
@@ -287,23 +289,11 @@
)
)
- def test_mode_tracing_factory_function_no_factory_function(self):
- def f(x):
- return x + torch.randn(x.shape)
- # setting the flag to false should not trace factory functions
- traced = make_fx(f, trace_factory_functions=False)(torch.randn(3))
- self.assertFalse(
- any(
- node.target == aten.randn.default
- for node in traced.graph.nodes
- )
- )
-
def test_make_fx_overloads(self):
def f(x):
return x.cos() + torch.randn(x.shape)
- traced = make_fx(f)(torch.randn(3))
+ traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
self.assertTrue(all([isinstance(node.target, torch._ops.OpOverload)
for node in traced.graph.nodes if node.op == 'call_function']))
@@ -316,29 +306,20 @@
self._test(f, [])
def test_constant_proxy_tensor(self):
- from torch.fx.experimental.proxy_tensor import make_fx
-
def f():
val = torch.tensor(float('inf'))
return torch.full((100, 100), val)
- g = make_fx(f)()
+ g = make_fx(f, tracing_mode=self.tracing_mode)()
self.assertEqual(g(), f())
def test_constant_proxy_tensor_mut(self):
- from torch.fx.experimental.proxy_tensor import make_fx
-
def f():
val = torch.tensor(float(1))
val.add_(2)
return torch.full((100, 100), val)
- g = make_fx(f)()
- self.assertEqual(g(), f())
- # In case we mutated shared state in the g graph!
- self.assertEqual(g(), f())
-
- g = make_fx(f, tracing_mode="fake")()
+ g = make_fx(f, tracing_mode=self.tracing_mode)()
self.assertEqual(g(), f())
# In case we mutated shared state in the g graph!
self.assertEqual(g(), f())
@@ -349,37 +330,15 @@
r, = torch.unbind(val, 0)
return r.item()
- g = make_fx(f)()
+ g = make_fx(f, tracing_mode=self.tracing_mode)()
self.assertEqual(g(), f())
- def test_issue82547(self):
- x = nn.Parameter(torch.randn(3, 3))
-
- def f():
- return torch.ops.aten.t.default(x)
- self.assertRaisesRegex(Exception, "non-Fake Tensor", lambda: make_fx(f, tracing_mode="fake")())
-
- class A(torch.Tensor):
- pass
-
- x = A(torch.randn(3, 3))
- self.assertRaisesRegex(TypeError, "no implementation found", lambda: make_fx(f, tracing_mode="fake")())
-
- def test_use_fake_and_tensor(self):
- def f(x, y):
- z = torch.tensor([2.0, 3.0])
- return x + y + z
-
- g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
- x, y = torch.randn(2), torch.randn(2)
- self.assertEqual(g(x, y), f(x, y))
-
def test_decomposition_interpreter(self):
def fn(x):
return torch.nn.functional.silu(x)
x = torch.rand((4, 4))
- fx_module = make_fx(fn, decomposition_table=None)(x)
+ fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)
found_silu = False
for n in fx_module.graph.nodes:
@@ -421,7 +380,7 @@
return list(params.values())
input = torch.randn(3, 5, requires_grad=True)
params = dict(model.named_parameters())
- fx_f = make_fx(f)(input, params)
+ fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
# fx may change the order of parameters in list, so using set() to compare
self.assertTrue(
torch.allclose(fx_f(input, params)[0], f(input, params)[0])
@@ -456,7 +415,7 @@
input = torch.randn(3, 5, requires_grad=True)
params = dict(model.named_parameters())
buffers = dict(model.named_buffers())
- fx_f = make_fx(f)(input, params, buffers)
+ fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
# fx may change the order of parameters in list, so using set() to compare
# also there is a numerical difference in results so changing atol from 1e-08 to 1e-03
self.assertTrue(
@@ -470,6 +429,43 @@
torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
)
+
+class TestRealProxyTensor(TestCase):
+ def test_mode_tracing_factory_function_no_factory_function(self):
+ def f(x):
+ return x + torch.randn(x.shape)
+ # setting the flag to false should not trace factory functions
+ traced = make_fx(f, trace_factory_functions=False)(torch.randn(3))
+ self.assertFalse(
+ any(
+ node.target == aten.randn.default
+ for node in traced.graph.nodes
+ )
+ )
+
+class TestFakeProxyTensor(TestCase):
+ def test_issue82547(self):
+ x = nn.Parameter(torch.randn(3, 3))
+
+ def f():
+ return torch.ops.aten.t.default(x)
+ self.assertRaisesRegex(Exception, "non-Fake Tensor", lambda: make_fx(f, tracing_mode="fake")())
+
+ class A(torch.Tensor):
+ pass
+
+ x = A(torch.randn(3, 3))
+ self.assertRaisesRegex(TypeError, "no implementation found", lambda: make_fx(f, tracing_mode="fake")())
+
+ def test_use_fake_and_tensor(self):
+ def f(x, y):
+ z = torch.tensor([2.0, 3.0])
+ return x + y + z
+
+ g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
+ x, y = torch.randn(2), torch.randn(2)
+ self.assertEqual(g(x, y), f(x, y))
+
# TODO: Need to test the guards themselves specifically as well
@skipIfNoSympy
class TestSymbolicTracing(TestCase):