blob: 67e45bf18e03b42505361f1e6c1ecf9e8e83b9af [file] [log] [blame]
Edward Z. Yang5b88a202022-07-20 18:12:25 -04001# Owner(s): ["module: ProxyTensor"]
Horace He4d88aff2022-06-07 00:28:53 +00002
Edward Z. Yang89e16c42023-02-15 17:57:21 -05003from torch.testing._internal.common_utils import TestCase, run_tests, xfail_inherited_tests
Horace He4d88aff2022-06-07 00:28:53 +00004import torch
5import unittest
6import warnings
Horace He6a3ecda2022-08-31 00:29:55 +00007import operator
Mostafa Elhoushi0894c492022-07-25 12:43:17 +00008from collections.abc import Iterable
Horace He4d88aff2022-06-07 00:28:53 +00009from torch.testing._internal.common_device_type import instantiate_device_type_tests
Fabio Rochab6525772023-02-16 15:34:34 +000010from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed, skip, xfail, skipOps
Edward Z. Yangf7365ec2022-12-10 20:29:21 -080011from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException
Horace He4d88aff2022-06-07 00:28:53 +000012
David Berard00f65182022-06-29 10:28:42 -070013from torch._decomp import decomposition_table
Edward Z. Yangf1f26fe2023-02-12 14:04:01 -080014from torch.fx.experimental.symbolic_shapes import (
Edward Z. Yang37585592023-02-21 06:45:00 -080015 sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets,
Edward Z. Yang8efe4fd2023-02-23 11:54:36 -080016 constrain_range, guard_int, GuardOnDataDependentSymNode
Edward Z. Yangf1f26fe2023-02-12 14:04:01 -080017)
Richard Zou44b09bf2023-04-18 06:51:23 -070018from torch.testing._internal.custom_op_db import custom_op_db
Horace He4d88aff2022-06-07 00:28:53 +000019from torch.testing._internal.common_device_type import ops
Horace He66396772022-08-10 22:31:38 +000020from torch._C import _disabled_torch_function_impl
Edward Z. Yang54563e62022-12-15 16:37:24 +080021from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule
Horace Heb7046e92022-07-07 04:54:31 +000022from torch.utils._pytree import tree_map
Edward Z. Yangbf387e82022-08-01 08:55:19 -070023from torch import nn
Horace He1a18ff32022-07-23 19:03:38 +000024import re
25
Edward Z. Yangd2472442022-08-03 14:25:16 -070026import functools
Horace He2f4a5172022-09-21 02:30:50 +000027import itertools
Edward Z. Yangd2472442022-08-03 14:25:16 -070028
# True when a CUDA device is usable; CUDA-only tests skip themselves on it
# via ``unittest.skipIf(not HAS_CUDA, ...)``.
HAS_CUDA: bool = torch.cuda.is_available()

# Shorthand for the ATen operator namespace (e.g. ``aten.addmm.default``),
# used when matching FX node targets.
aten = torch.ops.aten
Horace He1a18ff32022-07-23 19:03:38 +000032
33
def strip_end(s, suffix):
    """Return ``s`` with one trailing ``suffix`` removed.

    ``s`` comes back unchanged when ``suffix`` is falsy (empty or ``None``)
    or when ``s`` does not end with it. The falsy case must be special-cased
    because ``s[:-0]`` would be the empty string.
    """
    should_strip = bool(suffix) and s.endswith(suffix)
    return s[:-len(suffix)] if should_strip else s
39
40
def show_guards(gm):
    """Render the simplified shape guards recorded on ``gm.shape_env``.

    Each placeholder target has one trailing ``"_1"`` stripped before being
    passed as its source name (presumably so ``x_1`` reads as the user-facing
    ``x``). The guards are returned joined into one newline-separated string.
    """
    placeholder_names = [strip_end(target, "_1") for target in fx_placeholder_targets(gm)]
    guards = gm.shape_env.produce_guards(
        fx_placeholder_vals(gm),
        placeholder_names,
        _simplified=True,
        constraint_inputs=None,
    )
    return "\n".join(guards)
46
47
def process_failures(path='pytest_failures', ops=None):
    """Turn a pytest failure log into a ``symbolic_tensor_failures`` xfail listing.

    Reads lines like::

        FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition  # noqa: B950

    and prints them as a set literal of OpInfo xfails, one
    ``xfail(name, variant_test_name),  # reason`` entry per failing test, where
    ``reason`` is the text after the final ``": "`` on the line.

    Args:
        path: failure log to read. Defaults to ``'pytest_failures'`` in the
            current working directory (the previously hard-coded file).
        ops: OpInfo-like objects exposing ``name`` and ``variant_test_name``,
            used to map test-name fragments back to OpInfo names. Defaults
            to ``op_db``.

    Blank lines and lines that do not match the expected pattern are skipped.

    Raises:
        KeyError: if a failing test's op name does not correspond to any op
            in ``ops``.
    """
    if ops is None:
        ops = op_db

    # Group 1: the normalized op name between "exhaustive_" and "_cpu".
    # Group 2: everything after the final ": " (the preceding ``.*`` is greedy).
    symbolic_trace_match = re.compile(r'exhaustive_(.*)_cpu.*: (.*)')

    # ``with`` guarantees the log is closed (the old code leaked the handle).
    with open(path) as f:
        lines = [line.strip() for line in f]

    failures = []
    for line in lines:
        if not line:
            continue
        match = symbolic_trace_match.search(line)
        if match is None:
            # Not a symbolic-trace failure line; previously this crashed
            # with AttributeError on ``None.groups()``.
            continue
        failures.append(match.groups())

    def create_normalized_name(op):
        # Mirrors how OpInfo names appear in generated test names: "name" or
        # "name.variant", with every "." replaced by "_".
        if op.variant_test_name == '':
            s = op.name
        else:
            s = f"{op.name}.{op.variant_test_name}"
        return s.replace('.', '_')

    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in ops}

    print("symbolic_tensor_failures = {")
    for failure, reason in failures:
        print(f"    xfail{remap_opinfo[failure]},  # {reason}")
    print("}")
80
Horace He4d88aff2022-06-07 00:28:53 +000081
# torchvision is optional: tests that need it check USE_TORCHVISION and skip
# when it is False.
try:
    import torchvision
except ImportError:
    USE_TORCHVISION = False
    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
                  "to install it with commands from pytorch.org, post-fixed with "
                  "`--no-deps` to avoid overwriting the pytorch installation",
                  UserWarning)
else:
    USE_TORCHVISION = True
91
92
Horace Heb7046e92022-07-07 04:54:31 +000093def _create_new_input(x):
94 if not isinstance(x, torch.Tensor):
95 return x
96 if x.dtype != torch.float:
97 return x + 1
98 if x.is_leaf:
Edward Z. Yangd2472442022-08-03 14:25:16 -070099 return torch.rand_like(x, requires_grad=x.requires_grad)
Horace Heb7046e92022-07-07 04:54:31 +0000100 else:
101 return torch.rand_like(x)
102
class UnwrapTensor(torch.Tensor):
    """Wrapper subclass that delays a ``cos`` on the wrapped tensor until it is used.

    Whenever an op is dispatched with an ``UnwrapTensor`` argument, that
    argument is replaced by ``wrapped._tensor.cos()`` and the op runs on
    plain tensors, so results come back as ordinary (unwrapped) tensors.
    Used in these tests to simulate a CommTensor-like wrapper whose pending
    work is only materialized when the tensor is consumed.
    """

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        # A wrapper subclass owns no storage of its own; it mirrors the
        # metadata of ``tensor`` and keeps the real data in ``_tensor``.
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            dtype=tensor.dtype,
            device=tensor.device,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
        )
        r._tensor = tensor
        return r

    def __repr__(self):
        # TODO: consider all_gather the local tensors for better debugging
        return f"UnwrapTensor({self._tensor})"

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(e):
            # The deferred ``cos`` happens here, at the point of use.
            if isinstance(e, UnwrapTensor):
                return e._tensor.cos()
            return e

        # ``kwargs`` defaults to None; normalize it so ``**kwargs`` below is
        # always a mapping (``func(**None)`` raises TypeError).
        kwargs = {} if kwargs is None else kwargs
        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
        return func(*args, **kwargs)
138
class TestGenericProxyTensor(TestCase):
    """make_fx tests shared across tracing modes.

    This base class never sets ``tracing_mode``; subclasses supply it and
    most tests forward it to ``make_fx``. Some tests call ``make_fx`` with
    its default mode regardless of the subclass.
    """

    # Trace ``f`` on ``inps`` with this class's tracing mode, then check the
    # traced graph and eager ``f`` agree on fresh inputs (see _create_new_input).
    # WARNING: if any of your inputs are index tensors, DO NOT use this
    # function (_create_new_input adds 1 to non-float tensors, which can push
    # indices out of range).
    def _test(self, f, inps):
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
        new_inps = tree_map(_create_new_input, inps)
        r1 = fx_f(*new_inps)
        r2 = f(*new_inps)
        self.assertEqual(r1, r2)

    def test_pre_autograd_mode_stack(self):
        def f(a):
            b = torch.ones(4, 4)
            return torch.matmul(a, b)
        # We expect to see matmul in the trace - it should NOT be decomposed into mm.
        # NOTE(review): an older note here said torch.ones() does not show up
        # in the trace (ones() skips the Autograd key and goes straight to
        # BackendSelect), but the expected graph below does contain
        # aten.ones.default - that note looks stale; confirm.
        inp = torch.ones(4, 4)
        # Test that make_fx(pre_autograd=True) clears caches properly:
        # run f eagerly under the Python dispatcher first, then trace it.
        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            out1 = f(inp)
        fx_g = make_fx(f, pre_autograd=True)(inp)
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, a_1):
    ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
    matmul = torch.ops.aten.matmul.default(a_1, ones);  a_1 = ones = None
    return matmul""")


    def test_make_fx_simple(self):
        def f(x):
            return torch.sin(x)
        self._test(f, (torch.randn(3),))

    def test_scalar_device(self, device='cpu'):
        # Mixes a tensor on ``device`` with a 0-dim CPU tensor.
        def f(a, b):
            return a + b
        self._test(f, [torch.randn(3, device=device), torch.tensor(5)])

    def test_isolated_graphmodule(self):
        def is_any_sum(gm):
            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)

        def is_any_digamma(gm):
            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)

        def is_any_sigmoid(gm):
            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)

        def inner(x):
            return torch.sum(x)

        def f(x):
            gm = get_isolated_graphmodule(inner, (x,), {})
            self.assertTrue(is_any_sum(gm))
            return x + torch.randn(x.shape)

        # get_isolated_graphmodule uses make_fx internally that shouldn't be traced
        # by the outer make_fx call
        traced = make_fx(f)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))

        # When factory functions are used, they should not be traced
        # by the outer make_fx call
        def inner_with_factory():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((10, 10), val).sum()

        def f1(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2(x):
            gm = get_isolated_graphmodule(f1, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify nested make_fx calls don't leak factory functions into the
        # outer graph. Verify that `make_fx` itself does not leak its execution.
        def f2(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify that the `forward` function of a graph module produced as a
        # side effect of an interior `make_fx` is still traced
        def f3(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            # `gm.forward` is still traced
            return torch.digamma(gm(x))

        traced = make_fx(f3)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertTrue(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with non-ProxyTensor modes
        from torch.testing._internal.logging_tensor import LoggingTensorMode

        def f1_logging(x):
            with LoggingTensorMode():
                gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging(x):
            with LoggingTensorMode(), LoggingTensorMode():
                gm = get_isolated_graphmodule(f1_logging, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with another tensor subclass
        # This case currently doesn't work and should raise an error
        # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
        # NOTE(review): nothing below expects an exception and the
        # is_any_sigmoid assertion is active, so the "doesn't work" / "this
        # fails" notes may be stale - confirm against the linked discussion.
        from torch.testing._internal.logging_tensor import LoggingTensor

        def f1_logging_tensor(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging_tensor(x):
            x = LoggingTensor(x)
            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging_tensor)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
        self.assertTrue(is_any_digamma(traced))

    # See https://github.com/pytorch/pytorch/issues/97541
    # The traced empty_like call must not record explicit dtype/layout/device
    # defaults; only pin_memory appears in the expected graph.
    def test_empty_like_doesnt_burn_in_defaults(self):
        def f(x):
            return torch.empty_like(x)
        out = make_fx(f)(torch.randn(3))
        self.assertExpectedInline(out.code.strip(), """\
def forward(self, x_1):
    empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False);  x_1 = None
    return empty_like""")

    def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
        def f(x):
            y = x.new_zeros(x.size())
            y.copy_(x)
            return y

        def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
            return torch.zeros(size, dtype=inp.dtype, device=inp.device)

        factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}

        # When new_zeros() decomposes into torch.zeros(), we expect ProxyTensorMode
        # to still be (re-entrantly) enabled, so that the `torch.zeros()` call
        # returns a ProxyTensor (and so shows up as aten.zeros in the graph).
        out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2))
        self.assertExpectedInline(out.code, """\



def forward(self, x_1):
    zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    copy_ = torch.ops.aten.copy_.default(zeros, x_1);  zeros = x_1 = None
    return copy_
    """)

    def test_make_fx_reentrant_dispatch(self):
        def f(x):
            return torch.ops.aten.norm.Scalar(x, 2.0)

        def norm_decomp(x, p=2.0):
            if p != 2.0:
                raise RuntimeError("can't handle with p != 2")
            return torch.sqrt(torch.sum(torch.square(x)))

        decomp = {torch.ops.aten.norm.Scalar: norm_decomp}

        traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3))

        # Neither the decomposed op (norm) nor the op issued from inside the
        # decomposition under its own name (square) may appear as a target.
        for n in traced.graph.nodes:
            self.assertTrue("square" not in str(n.target))
            self.assertTrue("norm" not in str(n.target))

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_resnet18_backward_trace(self):
        mod = torchvision.models.resnet18()

        # An old version of this test called the module directly.  This works
        # for tracing_mode == "real", but for fake tensors, we also have to
        # ensure that the parameters and buffers get wrapped in fake tensors
        # because free fake tensors are not supported.  Fortunately functional_call
        # does precisely this for us.
        def f(x, params, buffers):
            for p in params.values():
                p.grad = None
            loss = torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()
            # I could have done this with the functional API, but there is
            # plenty of exercising this; I want to show mutating API still
            # works
            loss.backward()
            return [p.grad for p in params.values()]

        inp = torch.randn(3, 3, 250, 250)
        self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())])

    def test_varargs(self):
        def f(*args):
            return sum(args)

        self._test(f, [torch.randn(2), torch.randn(2)])

    def test_proxy_tensor(self):
        # Both the functional (autograd.grad) and mutating (.backward) APIs
        # should trace.
        def f_grad(x):
            val = x.cos().cos().sum()
            return torch.autograd.grad(val, x)

        def f_backward(x):
            val = x.cos().cos().sum()
            val.backward()
            return x.grad

        for f in [f_grad, f_backward]:
            self._test(f, [torch.randn(3, requires_grad=True)])

    def test_pickle_issue89626(self):
        # Tracing must leave the real input tensor picklable.
        import pickle
        x = torch.randn(2)
        make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x)
        pickle.dumps(x)

    def test_inplace_metadata(self):
        # Shape metadata must reflect an in-place view op during tracing.
        def f(x):
            x = x.clone()
            x.unsqueeze_(-1)
            assert x.shape[-1] == 1
            return x

        self._test(f, [torch.randn(5)])

    def test_mode_tracing_factory_function(self):
        def f(x):
            return x + torch.randn(x.shape)

        # default behavior should trace factory functions
        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
        self.assertTrue(
            any(
                node.target == aten.randn.default
                for node in traced.graph.nodes
            )
        )

    def test_val_metadata_mutation(self):
        # node.meta['val'] must record the shape at each point in the graph,
        # i.e. (1, 3) after unsqueeze_, not the pre-mutation (3,).
        def f(x):
            y = x.clone()
            y.unsqueeze_(0)
            return y

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True))
        self.assertEqual([
            tuple(node.meta['val'].shape)
            for node in traced.graph.nodes
            if 'val' in node.meta
        ], [(3,), (3,), (1, 3)])

    def test_make_fx_overloads(self):
        # Every call_function target should be a concrete OpOverload
        # (e.g. aten.cos.default), not an overload packet.
        def f(x):
            return x.cos() + torch.randn(x.shape)

        traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))

        self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload)
                            for node in traced.graph.nodes if node.op == 'call_function'))

    def test_tensor_constants(self):
        def f():
            val = torch.tensor(float('inf'))
            return torch.full((100, 100), val)

        self._test(f, [])

    def test_allclose(self):
        # allclose needs concrete tensor values, so tracing it is expected to
        # raise a data-dependent error.
        def f(a, b):
            return torch.allclose(a, b)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)(
                torch.zeros(3), torch.zeros(3)
            )

        if self.tracing_mode != "real":
            self.assertRaises(DataDependentOutputException, test_f)
        else:
            self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_constant_proxy_tensor_mut(self):
        def f():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((100, 100), val)

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())
        # In case we mutated shared state in the g graph!
        self.assertEqual(g(), f())

    def test_constant_unbind(self):
        def f():
            val = torch.tensor([2])
            r, = torch.unbind(val, 0)
            return r.item()

        g = make_fx(f, tracing_mode=self.tracing_mode)()
        self.assertEqual(g(), f())

    def test_constant_blowup(self):
        # .item() on a value derived from a repeated constant: tracing is
        # expected to raise a data-dependent error (see assertions below).
        def f():
            val = torch.tensor([2])
            blowup = val.repeat(1000)
            return bool(blowup.sum().item() == 2)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)()

        if self.tracing_mode == "fake":
            self.assertRaises(DataDependentOutputException, test_f)
        else:
            self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_constant_random(self):
        # .item() after an in-place random op on a constant: tracing is
        # expected to raise a data-dependent error (see assertions below).
        def f():
            val = torch.tensor([2.0])
            val.normal_()
            return bool(val.item() == 2.1)

        def test_f():
            make_fx(f, tracing_mode=self.tracing_mode)()

        if self.tracing_mode == "fake":
            self.assertRaises(DataDependentOutputException, test_f)
        else:
            self.assertRaisesRegex(RuntimeError, "data-dependent", test_f)

    def test_decomposition_interpreter(self):
        def fn(x):
            return torch.nn.functional.silu(x)

        x = torch.rand((4, 4))
        fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)

        # Without a decomposition table, silu must survive in the traced graph.
        found_silu = False
        for n in fx_module.graph.nodes:
            if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default:
                found_silu = True

        self.assertTrue(found_silu)

        # Re-run the graph through DecompositionInterpreter with only silu's
        # decomposition; the rewritten graph must no longer contain silu.
        new_graph = torch.fx.Graph()
        silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]}
        DecompositionInterpreter(
            fx_module,
            new_graph=new_graph,
            decomposition_table=silu_decomp_table,
        ).run(x)

        decomposed_module = torch.fx.GraphModule(fx_module, new_graph)

        for n in decomposed_module.graph.nodes:
            self.assertTrue(n.target != torch.ops.aten.silu)
            self.assertTrue(n.target != torch.ops.aten.silu.default)

        self.assertEqual(fx_module(x), decomposed_module(x))

    def test_make_fx_model_fwd_bwd(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(x, params):
            out = torch.func.functional_call(model, params, x).sum()
            out.backward()
            return list(params.values())
        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
        # fx may change the order of parameters in the returned list, so each
        # traced output is accepted if it matches either eager output
        # (order-insensitive comparison; no set() is actually used).
        self.assertTrue(
            torch.allclose(fx_f(input, params)[0], f(input, params)[0])
            or
            torch.allclose(fx_f(input, params)[0], f(input, params)[1])
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params)[1], f(input, params)[0])
            or
            torch.allclose(fx_f(input, params)[1], f(input, params)[1])
        )

    def test_make_fx_model_double_param(self):
        # The same submodule (and so the same parameters) is used twice in
        # forward; the traced graph should contain exactly two distinct
        # call_function targets.
        class Emformer(torch.nn.Module):
            def __init__(
                self,
                input_dim: int = 256,
            ) -> None:
                super().__init__()

                self.layer_norm = torch.nn.LayerNorm(input_dim)

            def forward(mod_self, x):  # noqa: B902
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                y = mod_self.layer_norm(x)
                self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
                z = mod_self.layer_norm(y)
                return z


        gm = make_fx(Emformer())(torch.randn(16, 1, 256))
        ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'}
        self.assertEqual(len(ops), 2)


    def test_make_fx_model_fwd_bwd_wgtupdate(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 5)

            def forward(self, x):
                return self.linear(x).relu()

        model = Foo()

        def f(args, params, buffers):
            for p in params.values():
                p.grad = None
            if not isinstance(args, Iterable):
                args = [args]
            params_and_buffers = {**params, **buffers}
            out = torch.func.functional_call(model, params_and_buffers, args)
            out.sum().backward()
            return [p - 1e-4 * p.grad for p in params.values()]

        input = torch.randn(3, 5, requires_grad=True)
        params = dict(model.named_parameters())
        buffers = dict(model.named_buffers())
        fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
        # fx may change the order of parameters in the returned list, so each
        # traced output is accepted if it matches either eager output
        # (no set() is actually used).
        # also there is a numerical difference in results so changing atol from 1e-08 to 1e-03
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
            or
            torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
        )
        self.assertTrue(
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
            or
            torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
        )

    def test_trace_subclasses(self):
        # f1: only the wrapper is used; f2: a plain tensor is mixed with its wrapper.
        def f1(x):
            x = UnwrapTensor(x)
            y = x * 2
            return y

        def f2(x):
            wrapped = UnwrapTensor(x)
            y = x * wrapped
            return y

        inp = [torch.randn(5)]
        self._test(f1, inp)
        self._test(f2, inp)

    def test_partial_decomp(self):
        def f(a, b, c):
            x = torch.addmm(a, b, c)
            y = torch.addmm(a, b, c, beta=2, alpha=1)
            return x + y
        inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)]
        fx_g = make_fx(f)(*inps)

        # A decomposition returning NotImplemented declines that call, so
        # only the beta=2 addmm should be decomposed.
        def addmm(a, b, c, beta=1, alpha=1):
            if beta == 1 and alpha == 1:
                return NotImplemented
            return beta * a + alpha * (b @ c)

        decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps)

        self.assertEqual(fx_g(*inps), decomposed_fx(*inps))
        self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2)
        self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1)

    def test_decomp_of_capture(self):
        # ``val`` is a captured tensor, not a graph input.
        val = torch.randn(5)

        def f(x):
            return x.t() + val.t()

        # Decompose aten.t into a visibly different op (cos) so the check
        # below can confirm no aten.t node survives, including the one
        # applied to the captured ``val``.
        def nop(x):
            return x.cos()

        traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5))
        self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0)


    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_amp_cache(self):
        # Tracing the same function twice under autocast must yield graphs
        # with matching node ops (presumably guarding against autocast's
        # cast cache leaking between traces).
        layer = torch.nn.Conv2d(3, 3, 3).cuda()

        def f(x, w):
            return torch.nn.functional.conv2d(x, w, stride=layer.stride)

        inp = torch.randn(4, 3, 10, 10, device='cuda')
        with torch.autocast('cuda'):
            out_graph = make_fx(f)(inp, layer.weight).graph
            out_graph2 = make_fx(f)(inp, layer.weight).graph

        self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes))
        for a, b in zip(out_graph.nodes, out_graph2.nodes):
            self.assertEqual(a.op, b.op)

    def test_strides(self):
        # Contiguity queries made during tracing must reflect the real
        # strides (permute -> channels_last; slicing -> non-contiguous).
        def f(x):
            self.assertTrue(x.is_contiguous())
            self.assertFalse(x.is_contiguous(memory_format=torch.channels_last))
            x = x.permute(0, 3, 1, 2)
            self.assertFalse(x.is_contiguous())
            self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
            return x
        make_fx(f)(torch.randn(2, 3, 4, 5))

        def f(x):
            self.assertTrue(x.is_contiguous())
            y = x[:, 1]
            self.assertFalse(y.is_contiguous())
            y = x[:, ::2]
            self.assertFalse(y.is_contiguous())
            return x.cos()

        make_fx(f)(torch.randn(2, 3, 4, 5))

    def test_pr_86917(self):
        # Tests the issue brought up here https://github.com/pytorch/pytorch/pull/86917#issuecomment-1283155344
        def f(a, b):
            return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10)

        self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)])
718
class TestGenericProxyTensorReal(TestGenericProxyTensor):
    """Runs the shared make_fx tests with ``tracing_mode="real"``."""
    tracing_mode = "real"


class TestGenericProxyTensorFake(TestGenericProxyTensor):
    """Runs the shared make_fx tests with ``tracing_mode="fake"``."""
    tracing_mode = "fake"


@xfail_inherited_tests(["test_make_fx_overloads"])
class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
    """Runs the shared make_fx tests with ``tracing_mode="symbolic"``.

    ``test_make_fx_overloads`` is marked as an expected failure in this mode.
    """
    tracing_mode = "symbolic"


# The base class never sets ``tracing_mode``; remove it from the module
# namespace so only the three concrete variants above are collected.
del TestGenericProxyTensor
735
736
Edward Z. Yangb361f702022-08-03 10:50:30 -0700737class TestRealProxyTensor(TestCase):
Horace Hec2808572022-08-13 00:37:28 +0000738 pass
Edward Z. Yangb361f702022-08-03 10:50:30 -0700739
class TestFakeProxyTensor(TestCase):
    """make_fx tests that exercise tracing_mode="fake" specifically."""

    def test_issue82547(self):
        x = nn.Parameter(torch.randn(3, 3))

        def f():
            return torch.ops.aten.t.default(x)
        # f closes over a real Parameter; fake tracing is expected to reject it.
        self.assertRaisesRegex(Exception, "Please convert all Tensors", lambda: make_fx(f, tracing_mode="fake")())

        class A(torch.Tensor):
            pass

        # Rebinding x updates the closure cell f reads, so f now sees an A.
        x = A(torch.randn(3, 3))
        self.assertRaisesRegex(TypeError, "no implementation found", lambda: make_fx(f, tracing_mode="fake")())

    def test_use_fake_and_tensor(self):
        # f builds a fresh tensor inside the traced region; the traced module
        # must still agree with eager execution on new inputs.
        def f(x, y):
            z = torch.tensor([2.0, 3.0])
            return x + y + z

        g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
        x, y = torch.randn(2), torch.randn(2)
        self.assertEqual(g(x, y), f(x, y))

    def test_fused_adam(self):
        # See https://github.com/pytorch/pytorch/issues/99356
        params = [torch.randn(10, 10) for _ in range(10)]
        grads = [torch.randn(10, 10) for _ in range(10)]
        exp_avgs = [torch.randn(10, 10) for _ in range(10)]
        exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
        max_exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)]
        state_steps = [torch.tensor(0) for _ in range(10)]

        def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
            (new_params, _, _, _, _) = aten._fused_adam.default(
                params,
                grads,
                exp_avgs,
                exp_avg_sqs,
                max_exp_avg_sqs,
                state_steps,
                lr=0.1,
                beta1=0.9,
                beta2=0.999,
                weight_decay=0.01,
                eps=1e-8,
                amsgrad=False,
                maximize=False,
            )

            for p, new_p in zip(params, new_params):
                p.copy_(new_p)

            return params

        gm = make_fx(fused_adam, tracing_mode='fake')(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
        )
        # The multi-output op and the getitems unpacking it must all carry 'val'.
        ensure_ops_have_val = [aten._fused_adam.default, operator.getitem]
        for n in gm.graph.nodes:
            if n.op == "call_function" and n.target in ensure_ops_have_val:
                self.assertIn('val', n.meta)

    def test_alias(self):
        def f(x):
            return torch.ops.aten.alias(x)

        r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
        # NB: this should not have a detach call
        self.assertExpectedInline(r, """\
def forward(self, x_1):
    alias = torch.ops.aten.alias.default(x_1);  x_1 = None
    return alias""")

    def test_meta(self):
        def f(x):
            a = x.cos()
            b = torch.var_mean(a, dim=0)
            # NOTE(review): var_mean returns a (var, mean) tuple, so `b * 2` is
            # tuple repetition (four outputs), not a tensor multiply.
            c = b * 2
            return c

        out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5))
        # Every non-output node of a fake-mode trace should record a 'val'.
        for n in out.graph.nodes:
            if n.op == 'output':
                continue
            self.assertIn('val', n.meta)
830
def _get_node(fx_g, cond):
    """Return the first node in ``fx_g.graph.nodes`` for which ``cond(node)`` is truthy.

    Raises:
        AssertionError: if no node satisfies ``cond``.
    """
    for n in fx_g.graph.nodes:
        if cond(n):
            return n
    raise AssertionError("no node in the graph satisfies the given condition")
836
def _get_free_symbols(shape_env):
    """Count the symbols in ``shape_env`` that have not been replaced.

    A symbol counts when it is a key of ``shape_env.var_to_val`` and is not
    a key of ``shape_env.replacements``.
    """
    replaced = shape_env.replacements
    return sum(1 for sym in shape_env.var_to_val if sym not in replaced)
840
def _trace(f, *args):
    """Symbolically trace ``f`` on random tensors, one per shape in ``args``."""
    tensors = []
    for shape in args:
        tensors.append(torch.randn(shape))
    tracer = make_fx(f, tracing_mode="symbolic")
    return tracer(*tensors)
844
# TODO: Need to test the guards themselves specifically as well
class TestSymbolicTracing(TestCase):
    """make_fx tests for tracing_mode="symbolic": symbolic sizes, guards and graph output."""

    def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
        """
        Trace ``fn`` symbolically, then run the trace against ``test_inputs``.

        Args:
            fn: function to trace.
            trace_inputs: list of shapes; one ``torch.randn`` tensor per shape
                is passed to ``fn`` while tracing.
            test_inputs: list of shape lists; for each, fresh random tensors
                are fed to both the traced module and ``fn``.
            assert_eq: when True, assert traced and eager outputs are equal.

        Returns:
            The traced GraphModule. Callers read its guards through the
            ``shape_env`` attribute or pass it to ``eval_guards``.
        """
        trace_inputs = [torch.randn(shape) for shape in trace_inputs]
        traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs)
        for shapes in test_inputs:
            args = [torch.randn(shape) for shape in shapes]
            rx, ry = traced_f(*args), fn(*args)
            if assert_eq:
                self.assertEqual(rx, ry)
        return traced_f


    def test_debug_interpreter(self):
        import torch.library
        from torch.library import Library

        foo = Library("foo", "DEF")
        foo.define("foo(Tensor self) -> Tensor")

        # Operator where meta and cpu disagree on strides
        @torch.library.impl(foo, "foo", "CPU")
        def foo_cpu(x):
            return x.clone().T

        @torch.library.impl(foo, "foo", "Meta")
        def foo_meta(x):
            return x.clone()

        def f(x):
            return torch.ops.foo.foo.default(x)

        gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2))
        from torch._functorch.compilers import DebugInterpreter

        interp = DebugInterpreter(gm)

        # input mismatch is caught (indicates guard problem)
        self.assertRaisesRegex(
            AssertionError, r"3 != 1",
            lambda: interp.run(torch.randn(3, 3).T),
        )

        # Catch the incorrect meta
        self.assertRaisesRegex(
            AssertionError, r"\(3, 1\) != \(1, 3\)",
            lambda: interp.run(torch.randn(3, 3))
        )

    def test_resize_from_zero(self):
        def f(x, y):
            x.resize_(y.size(0))

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
    sym_size = torch.ops.aten.sym_size(y_1, 0);  y_1 = None
    resize_ = torch.ops.aten.resize_.default(x_1, [sym_size]);  x_1 = sym_size = None
    return None""")


    def test_unary(self):
        def f(x):
            assert x.shape[0] < 20
            return x.cos()
        test_inputs = []
        test_inputs.append([(2, 5)])
        test_inputs.append([(6, 8)])
        gm = self._test_dynamic(f, [(3, 4)], test_inputs)
        self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
        self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}")
        self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
        self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] < 20""")

    @unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
    def test_cpu_scalar_cuda(self):
        # Extracted from wave2vec2
        def f(a, b):
            return (a * b) @ b

        r = str(
            make_fx(f, tracing_mode="symbolic")(
                torch.tensor(1.0), torch.randn(2, 2, device='cuda')
            ).code
        ).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1, b_1):
    mul = torch.ops.aten.mul.Tensor(a_1, b_1);  a_1 = None
    mm = torch.ops.aten.mm.default(mul, b_1);  mul = b_1 = None
    return mm""")

    def test_binary_broadcast(self):
        def f(a, b):
            c = a * b
            return c

        test_inputs = []
        test_inputs.append([(1, 5), (3, 1)])
        test_inputs.append([(1, 4), (4, 1)])
        shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs).shape_env
        assert len(shape_env.guards) == 0

    def test_multiply_shape(self):
        def f(a):
            return torch.empty(a.shape[0] * 2)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size = torch.ops.aten.sym_size(a_1, 0);  a_1 = None
    mul = sym_size * 2;  sym_size = None
    empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False);  mul = None
    return empty""")

    def test_item(self):
        def f(a):
            r = a.item()
            return r * a

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1)
    mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense);  a_1 = _local_scalar_dense = None
    return mul""")

    def test_item_to_constructor(self):
        def f(a):
            r = a.item()
            constrain_range(r, min=2)
            return torch.empty(r)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
        self.assertExpectedInline(
            r, """\
def forward(self, a_1):
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1);  a_1 = None
    empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False);  _local_scalar_dense = None
    return empty"""  # noqa: B950
        )

    def test_dynamic_pointwise_scalar(self):
        def f(gravity, mask):
            gravity[mask, 0] = gravity[mask, 0] * -1

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 4)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, gravity_1, mask_1):
    select = torch.ops.aten.select.int(gravity_1, 1, 0)
    index = torch.ops.aten.index.Tensor(select, [mask_1]);  select = None
    mul = torch.ops.aten.mul.Tensor(index, -1);  index = None
    select_1 = torch.ops.aten.select.int(gravity_1, 1, 0);  gravity_1 = None
    index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul);  select_1 = mask_1 = mul = None
    return None""")

    def test_reflect_r_over_x(self):
        def reflect_R_over_x(R):
            reflect = torch.eye(3, device=R.device)
            reflect[0, 0] = -1
            return reflect @ R @ reflect

        def f(crop_camera, mask):
            crop_camera[mask] = reflect_R_over_x(crop_camera[mask])

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 3, 3)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, crop_camera_1, mask_1):
    index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1])
    eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False)
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0);  _tensor_constant0 = None
    select = torch.ops.aten.select.int(eye, 0, 0)
    select_1 = torch.ops.aten.select.int(select, 0, 0);  select = None
    copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy);  select_1 = lift_fresh_copy = None
    transpose = torch.ops.aten.transpose.int(index, -2, -1)
    t = torch.ops.aten.t.default(eye)
    clone = torch.ops.aten.clone.default(transpose, memory_format = torch.contiguous_format);  transpose = None
    sym_size = torch.ops.aten.sym_size(index, 0);  index = None
    sym_size_1 = torch.ops.aten.sym_size(crop_camera_1, 2)
    mul = sym_size * sym_size_1
    sym_size_2 = torch.ops.aten.sym_size(crop_camera_1, 1)
    _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [mul, sym_size_2]);  clone = mul = sym_size_2 = None
    mm = torch.ops.aten.mm.default(_unsafe_view, t);  _unsafe_view = t = None
    view = torch.ops.aten.view.default(mm, [sym_size, sym_size_1, 3]);  mm = sym_size_1 = None
    transpose_1 = torch.ops.aten.transpose.int(view, -2, -1)
    clone_1 = torch.ops.aten.clone.default(transpose_1, memory_format = torch.contiguous_format);  transpose_1 = None
    mul_1 = sym_size * 3
    sym_size_3 = torch.ops.aten.sym_size(view, 1);  view = None
    view_1 = torch.ops.aten.view.default(clone_1, [mul_1, sym_size_3]);  clone_1 = mul_1 = sym_size_3 = None
    mm_1 = torch.ops.aten.mm.default(view_1, eye);  view_1 = eye = None
    view_2 = torch.ops.aten.view.default(mm_1, [sym_size, 3, 3]);  mm_1 = sym_size = None
    index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_2);  crop_camera_1 = mask_1 = view_2 = None
    return None""")

    def test_unbacked_slice(self):
        def f(x, m):
            x = x[m]
            return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)]

        make_fx(f, tracing_mode="symbolic")(
            torch.randn((12, 3, 3)),
            torch.randint(0, 2, (12,), dtype=torch.bool)
        )

    @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
    def test_unbacked_batch_resnet(self):
        mod = torchvision.models.resnet18()

        def f(x, mask, params, buffers):
            for p in itertools.chain([x, mask], params.values(), buffers.values()):
                for s in p.shape:
                    guard_int(s)
            x = x[mask]
            constrain_range(x.shape[0], min=1)
            for p in params.values():
                p.grad = None
            return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum()

        make_fx(f, tracing_mode="symbolic")(
            torch.randn(3, 3, 250, 250),
            torch.randint(0, 2, (3,), dtype=torch.bool),
            dict(mod.named_parameters()),
            dict(mod.named_buffers()),
        )

    def test_boolean_index(self):
        def f(images, handedness, valid):
            images = images[valid]
            handedness = handedness[valid]
            right_hand_mask = handedness == 1
            images[right_hand_mask] = images[right_hand_mask].flip(-1)

        r = str(make_fx(f, tracing_mode="symbolic")(
            torch.randint(0, 256, (512, 1, 96, 96)),
            torch.randint(0, 1, (512,)),
            torch.randint(0, 2, (512,), dtype=torch.bool)
        ).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, images_1, handedness_1, valid_1):
    index = torch.ops.aten.index.Tensor(images_1, [valid_1]);  images_1 = None
    index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]);  handedness_1 = valid_1 = None
    eq = torch.ops.aten.eq.Scalar(index_1, 1);  index_1 = None
    index_2 = torch.ops.aten.index.Tensor(index, [eq])
    flip = torch.ops.aten.flip.default(index_2, [-1]);  index_2 = None
    index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip);  index = eq = flip = None
    return None""")

    def test_neg_shape(self):
        def f(a):
            return torch.empty(-a.shape[0] + 10)

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size = torch.ops.aten.sym_size(a_1, 0);  a_1 = None
    neg = -sym_size;  sym_size = None
    add = neg + 10;  neg = None
    empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False);  add = None
    return empty""")

    def test_invalidate_nonzero(self):
        ok = False

        def f(a):
            nonlocal ok
            b = a.clone()
            x = b.nonzero()
            x1 = b.nonzero()
            x2 = b.nonzero()
            # Repeated nonzero() on an unmodified tensor must produce the same
            # data-dependent size, so this comparison can be decided.
            assert x1.shape[0] == x2.shape[0]
            ok = True
            # After mutating b, a fresh nonzero() gets an unrelated size, and
            # guarding on the comparison must raise.
            b.normal_()
            y = b.nonzero()
            try:
                bool(x1.shape[0] == y.shape[0])
                self.fail("didn't raise exception")
            except GuardOnDataDependentSymNode:
                pass

        make_fx(f, tracing_mode="symbolic")(torch.randn(4))
        # Make sure tracing actually reached the first check in f.
        self.assertTrue(ok)

    def test_sqrt_size(self):
        def f(a):
            return a / a.size(-1) ** 0.5

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size = torch.ops.aten.sym_size(a_1, 0)
    pow_1 = sym_size ** 0.5;  sym_size = None
    div = torch.ops.aten.div.Tensor(a_1, pow_1);  a_1 = pow_1 = None
    return div""")


    def test_symint_to_tensor(self):
        def f(a):
            return a / a.shape[0]

        r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size = torch.ops.aten.sym_size(a_1, 0)
    div = torch.ops.aten.div.Tensor(a_1, sym_size);  a_1 = sym_size = None
    return div""")

        r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip()
        self.assertExpectedInline(r, """\
def forward(self, a_1):
    sym_size = torch.ops.aten.sym_size(a_1, 0)
    sym_float = torch.sym_float(sym_size);  sym_size = None
    div = torch.ops.prims.div.default(a_1, sym_float);  a_1 = sym_float = None
    return div""")

    def test_cat(self):
        def f(a, b):
            val = torch.mul(a, b)
            out = torch.cat([val, val])
            if out.shape[0] * out.shape[1] > 20:
                out = out.cos()
            return out

        test_inputs = []
        test_inputs.append([(1, 5), (6, 1)])
        test_inputs.append([(1, 4), (3, 1)])
        gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs)
        self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
        self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
        self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""")

    def test_new_empty(self):
        def f(a, b):
            return a.new_empty(b.shape[0], b.shape[1] * 2)

        # new_empty contents are uninitialized, so outputs are not compared.
        self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False)

    def test_size_with_tensor(self):
        def f(tensor):
            max_size = torch.tensor([800, 1216], dtype=torch.int64)
            batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
            return tensor.new_empty(batch_shape)

        a = torch.randn(3, 800, 1199)
        self.assertRaisesRegex(
            RuntimeError, "data-dependent", lambda: make_fx(f, tracing_mode="symbolic")(a)
        )

    def test_expand(self):
        def f(a):
            b = torch.mul(a, a)
            c = b.expand(a.shape)
            return c

        self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
        self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])

    def test_metadata(self):
        def f(a, b):
            d = a.new_empty(a.shape[0] + b.shape[0])
            return d
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
        meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
        meta_d = _get_node(fx_g, lambda x: x.target == operator.add)
        self.assertTrue(meta_c.meta['val'].shape[0].node.expr == meta_d.meta['val'].node.expr)

    def test_metadata_fresh(self):
        def f(x):
            assert x.shape[0] == 3
            return x.cos()

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
        meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default)
        meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder')
        self.assertTrue(meta_cos.meta['val'].shape[0].node.expr == 3)
        # Checks if the input expr has been updated even though the constraint
        # happened afterwards
        self.assertTrue(meta_inp.meta['val'].shape[0].node.expr == 3)

    def test_elementwise_meta_with_sym_numbers(self):
        def f(x, offset, as_sym_float=False):
            x0 = x.size()[0]
            if as_sym_float:
                x0 = sym_float(x0)
            return torch.add(x0, offset)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.int64)

        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True)
        meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(meta_add.meta['val'].shape, ())
        self.assertEqual(meta_add.meta['val'].dtype, torch.float32)

    def test_return_symint(self):
        def f(x):
            return x.shape[0], x.cos(), x.shape[0] / 5
        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])

        def f(x):
            return x.shape
        self._test_dynamic(f, [(5, 3)], [[(4, 6)]])

    def test_rmethod(self):
        def f(x):
            return x.size(0) + x
        self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]])

    def test_mega_guard(self):
        def f(a, b):
            assert a.shape[0] == b.shape[0] * 2
            return a.cos()
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
        from torch._dynamo.source import LocalSource
        self.assertExpectedInline(
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)),  # noqa: B950
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]"""  # noqa: B950
        )
        self.assertExpectedInline(
            str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)),  # noqa: B950
            """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]"""  # noqa: B950
        )

    def test_sym_storage_offset(self):
        def f(x, y):
            return x + y

        inp = (torch.randn(8)[3:], torch.randn(5))
        fx_g = make_fx(f, tracing_mode="symbolic")(*inp)
        inp = (torch.randn(8)[3:], torch.randn(5))
        self.assertEqual(fx_g(*inp), f(*inp))

    def _assert_no_guards(self, fx_g, free_symbols):
        """Assert fx_g has exactly ``free_symbols`` unreplaced symbols and no nontrivial guards."""
        assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val
        assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards()

    def test_guards_equal(self):
        def f(a, b):
            return a * b

        # NB: Numbers are carefully chosen to avoid duck shaping from applying

        fx_g = _trace(f, (5, 6), (5, 6))
        self._assert_no_guards(fx_g, 2)

        fx_g = _trace(f, (5, 6, 7), (5, 6, 7))
        self._assert_no_guards(fx_g, 3)

        fx_g = _trace(f, (5, 1), (1, 6))
        self._assert_no_guards(fx_g, 2)

        def f(a, b, c, d):
            a = a + b
            cat = torch.cat([c, d])
            return a + cat

        fx_g = _trace(f, 7, 7, 4, 3)
        self._assert_no_guards(fx_g, 2)

        def f(a, b, c, d, e):
            vals = [a, b, c, d, e]
            x = a
            for idx in range(len(vals) - 1):
                x = torch.cat([x, vals[idx]]) + vals[idx + 1]
            return x

        fx_g = _trace(f, 2, 4, 8, 16, 32)
        self._assert_no_guards(fx_g, 1)

        def f(a, b):
            a = a.view(b.shape[0])
            return a + b.sum()

        fx_g = _trace(f, (4, 2), 8)
        self._assert_no_guards(fx_g, 2)

        fx_g = _trace(f, (4, 2), (8, 5))
        self._assert_no_guards(fx_g, 3)

        fx_g = _trace(f, (2, 3, 4), 24)
        self._assert_no_guards(fx_g, 3)

    def test_nonidentity_transitive_guards(self):
        def f(a, b, c, d, e):
            vals = [a, b, c, d, e]
            cat_vals = []
            for idx in range(len(vals) - 1):
                cat_vals.append(torch.cat([vals[idx], vals[idx]]))
            final_vals = []
            for a, b in reversed(list(zip(cat_vals, vals[1:]))):
                final_vals.append(a + b)
            return final_vals

        fx_g = _trace(f, 2, 4, 8, 16, 32)
        self.assertExpectedInline(show_guards(fx_g), """""")
Horace He377b5d62022-09-16 22:59:44 +00001355
1356
Horace He86de9e72022-08-13 19:03:13 +00001357
Horace He12a19a42022-09-17 18:11:51 +00001358
Horace He5e230742022-10-19 02:07:13 +00001359
# OpInfo entries marked as expected failures (xfail) or skipped when traced
# with make_fx. NOTE(review): presumably applied in every tracing mode, with
# fake_tensor_failures / symbolic_tensor_failures layered on top; confirm
# against the test that consumes this set.
make_fx_failures = {
    # unknown reason
    xfail('allclose'),
    xfail('equal'),
    # empty*: outputs are uninitialized memory (presumably why values can't be compared)
    skip('new_empty'),
    skip('empty_like'),
    skip('empty'),
    skip('empty_permuted'),
    # flaky
    skip('linalg.lstsq', 'grad_oriented'),
    skip('nn.functional.max_unpool1d', '', device_type='cpu'),
    skip('nn.functional.max_unpool2d', '', device_type='cpu'),
    skip('nn.functional.max_unpool3d', '', device_type='cpu'),
    skip('linalg.lstsq'),  # flaky, probably just a precision issue

    # data-dependent control flow
    xfail('cov'),
    xfail('istft'),
    xfail('nn.functional.gaussian_nll_loss'),
    xfail('tensor_split'),
    xfail('corrcoef'),
    xfail('quantile'),
    xfail('nanquantile'),
    xfail('narrow'),

    # many complex operators incorrect striding, metadata
    skip('fft.fft', ''),
    skip('fft.hfft2', ''),
    skip('fft.hfft', ''),
    skip('fft.hfftn', ''),
    skip('fft.ifft', ''),
    skip('fft.ihfft2', ''),
    skip('fft.ihfft', ''),
    skip('fft.ihfftn', ''),
    skip('fft.irfft2', ''),
    skip('fft.irfft', ''),
    skip('fft.irfftn', ''),
    skip('fft.rfft2', ''),
    skip('fft.rfft', ''),
    skip('fft.rfftn', ''),

    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
    xfail('sparse.sampled_addmm'),
    xfail('sparse.mm', 'reduce'),

    # proxy tensor doesn't support sparse correctly right now
    skip('to_sparse'),
    # segfaults
    skip('block_diag'),
}
1411
# Additional OpInfo entries that fail or are skipped when tracing with fake
# tensors. NOTE(review): presumably combined with make_fx_failures by the
# consuming test; confirm.
fake_tensor_failures = {
    # FakeTensor fallback doesn't work
    xfail('_segment_reduce', 'lengths'),
    xfail('multinomial'),
    xfail('cholesky'),
    xfail('cholesky_inverse'),
    # cannot do these as they rely on tensor data
    xfail('repeat_interleave'),
    # ASAN failures due to divide by 0
    skip('nn.functional.nll_loss'),

    # NOTE(review): no failure reason is recorded for this entry; confirm and document.
    xfail("stft"),
}
1425
Horace He1a18ff32022-07-23 19:03:38 +00001426symbolic_tensor_failures = {
Elias Ellison1c0f7bd2022-07-27 22:19:14 +00001427 xfail('linalg.eig'),
Richard Zoua47bc962022-09-06 10:20:07 -07001428 xfail('linalg.eigvals'),
Horace He1a18ff32022-07-23 19:03:38 +00001429 xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
Horace Heb3b97862022-10-13 20:19:16 +00001430 xfail('combinations', ''),
Horace He1a18ff32022-07-23 19:03:38 +00001431 xfail('diff', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001432 xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001433 xfail('geqrf', ''), # aten.geqrf.default - couldn't find symbolic meta function/decomposition
1434 xfail('gradient', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001435 xfail('histc', ''), # Could not run 'aten::histc' with arguments from the 'Meta' backend. This could be because...
1436 xfail('histogram', ''), # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...
1437 xfail('histogramdd', ''), # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001438 xfail('index_reduce', ''), # Float
Horace He1a18ff32022-07-23 19:03:38 +00001439 xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001440 xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
1441 xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001442 xfail('linalg.eigh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
1443 xfail('linalg.eigvalsh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
1444 xfail('linalg.householder_product', ''), # aten.linalg_householder_product.default - couldn't find symbolic meta funct...
Horace He1a18ff32022-07-23 19:03:38 +00001445 xfail('linalg.ldl_factor', ''), # aten.linalg_ldl_factor_ex.default - couldn't find symbolic meta function/decomposition
1446 xfail('linalg.ldl_factor_ex', ''), # aten.linalg_ldl_factor_ex.default - couldn't find symbolic meta function/decompos...
1447 xfail('linalg.ldl_solve', ''), # aten.linalg_ldl_solve.default - couldn't find symbolic meta function/decomposition
1448 xfail('linalg.lu', ''), # aten.linalg_lu.default - couldn't find symbolic meta function/decomposition
1449 xfail('linalg.lu_factor', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
1450 xfail('linalg.lu_factor_ex', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
1451 xfail('linalg.lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
Edward Z. Yang617e90d2022-07-27 11:07:50 -04001452 xfail('linalg.matrix_power'), # RuntimeError: Trying to call aten.size on a tensor with symbolic shape
Horace He1a18ff32022-07-23 19:03:38 +00001453 xfail('linalg.matrix_rank', 'hermitian'), # aten.size.default - couldn't find symbolic meta function/decomposition
1454 xfail('linalg.multi_dot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001455 xfail('linalg.pinv', ''), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition
1456 xfail('linalg.pinv', 'singular'), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition
1457 xfail('linalg.pinv', 'hermitian'), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decompo...
Horace He1a18ff32022-07-23 19:03:38 +00001458 xfail('linalg.slogdet', ''), # aten._linalg_slogdet.default - couldn't find symbolic meta function/decomposition
1459 xfail('linalg.solve', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
1460 xfail('linalg.solve_ex', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001461 xfail('linalg.tensorinv', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
1462 xfail('linalg.tensorsolve', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
1463 xfail('linalg.vander', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001464 xfail('logaddexp2', ''), # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001465 xfail('logdet', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001466 xfail('lu', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
1467 xfail('lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
1468 xfail('lu_unpack', ''), # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001469 xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001470 xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001471 xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau...
Horace He1a18ff32022-07-23 19:03:38 +00001472 xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition
kshitij12345a3d37f12022-08-08 14:42:51 +00001473 xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
Horace He1a18ff32022-07-23 19:03:38 +00001474 xfail('narrow', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001475 xfail('nn.functional.adaptive_max_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
1476 xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct...
1477 xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl...
Horace He1a18ff32022-07-23 19:03:38 +00001478 xfail('nn.functional.avg_pool3d', ''), # aten.avg_pool3d.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001479 xfail('nn.functional.bilinear', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
1480 xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom...
Horace He1a18ff32022-07-23 19:03:38 +00001481 xfail('nn.functional.cosine_similarity', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
1482 xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
kshitij1234556a41b52022-09-22 00:21:11 +00001483 xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001484 xfail('nn.functional.embedding_bag', ''), # aten._embedding_bag_forward_only.default - couldn't find symbolic meta fun...
Horace He1a18ff32022-07-23 19:03:38 +00001485 xfail('nn.functional.fractional_max_pool2d', ''), # argument 'size' must be tuple of ints, but found element of t...
1486 xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t...
Horace He1a18ff32022-07-23 19:03:38 +00001487 xfail('nn.functional.grid_sample', ''), # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
Horace He1a18ff32022-07-23 19:03:38 +00001488 xfail('nn.functional.interpolate', 'linear'), # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec...
Horace He1a18ff32022-07-23 19:03:38 +00001489 xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
Richard Zouc771d732022-09-06 07:10:33 -07001490 xfail('nn.functional.max_pool1d', ''), # Trying to call aten.size on a tensor with symbolic shapes.
Horace He1a18ff32022-07-23 19:03:38 +00001491 xfail('nn.functional.max_pool3d', ''), # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
1492 xfail('nn.functional.max_unpool1d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
1493 xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
1494 xfail('nn.functional.max_unpool3d', 'grad'), # aten.max_unpool3d.default - couldn't find symbolic meta function/decom...
Horace He1a18ff32022-07-23 19:03:38 +00001495 xfail('nn.functional.multi_margin_loss', ''), # Could not run 'aten::multi_margin_loss' with arguments from the...
1496 xfail('nn.functional.multilabel_margin_loss', ''), # Could not run 'aten::multilabel_margin_loss_forward' with ...
Horace He1a18ff32022-07-23 19:03:38 +00001497 xfail('nn.functional.pad', 'reflect'), # aten.reflection_pad1d.default - couldn't find symbolic meta function/decompo...
1498 xfail('nn.functional.pad', 'replicate'), # aten.replication_pad1d.default - couldn't find symbolic meta function/deco...
1499 xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend...
Horace He1a18ff32022-07-23 19:03:38 +00001500 xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
Horace He1a18ff32022-07-23 19:03:38 +00001501 xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001502 xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001503 xfail('ormqr', ''), # aten.ormqr.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001504 xfail('pinverse', ''), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition
1505 xfail('polygamma', 'polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
1506 xfail('polygamma', 'polygamma_n_1'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
1507 xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
1508 xfail('polygamma', 'polygamma_n_3'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
1509 xfail('polygamma', 'polygamma_n_4'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition
kshitij12345a3d37f12022-08-08 14:42:51 +00001510 xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
Horace He1a18ff32022-07-23 19:03:38 +00001511 xfail('renorm', ''), # aten.renorm.default - couldn't find symbolic meta function/decomposition
Michael Voznesenskybc194942022-10-25 21:15:40 +00001512 xfail('repeat_interleave', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
Horace He1a18ff32022-07-23 19:03:38 +00001513 xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
1514 xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
1515 xfail('roll', ''), # Tensors of type TensorImpl do not have numel
Horace He1a18ff32022-07-23 19:03:38 +00001516 xfail('searchsorted', ''), # Could not run 'aten::searchsorted.Tensor' with arguments from the 'Meta' backend. ...
albanD496c0a22023-02-06 18:32:23 +00001517 xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001518 xfail('special.airy_ai', ''), # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001519 xfail('special.bessel_y0', ''), # aten.special_bessel_y0.default - couldn't find symbolic meta function/decomposition
1520 xfail('special.bessel_y1', ''), # aten.special_bessel_y1.default - couldn't find symbolic meta function/decomposition
1521 xfail('special.chebyshev_polynomial_t', ''), # aten.special_chebyshev_polynomial_t.default - couldn't find symbolic me...
1522 xfail('special.chebyshev_polynomial_u', ''), # aten.special_chebyshev_polynomial_u.default - couldn't find symbolic me...
Horace He1a18ff32022-07-23 19:03:38 +00001523 xfail('special.hermite_polynomial_h', ''), # aten.special_hermite_polynomial_h.default - couldn't find symbolic meta f...
1524 xfail('special.hermite_polynomial_he', ''), # aten.special_hermite_polynomial_he.default - couldn't find symbolic meta...
1525 xfail('special.laguerre_polynomial_l', ''), # aten.special_laguerre_polynomial_l.default - couldn't find symbolic meta...
Horace He1a18ff32022-07-23 19:03:38 +00001526 xfail('special.modified_bessel_i0', ''), # aten.special_modified_bessel_i0.default - couldn't find symbolic meta funct...
1527 xfail('special.modified_bessel_i1', ''), # aten.special_modified_bessel_i1.default - couldn't find symbolic meta funct...
1528 xfail('special.modified_bessel_k0', ''), # aten.special_modified_bessel_k0.default - couldn't find symbolic meta funct...
1529 xfail('special.modified_bessel_k1', ''), # aten.special_modified_bessel_k1.default - couldn't find symbolic meta funct...
Horace He1a18ff32022-07-23 19:03:38 +00001530 xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/...
1531 xfail('special.scaled_modified_bessel_k0', ''), # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo...
1532 xfail('special.scaled_modified_bessel_k1', ''), # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo...
Nikolay Korovaikobfebf252022-08-05 03:36:09 +00001533 xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at...
Horace He1a18ff32022-07-23 19:03:38 +00001534 xfail('take_along_dim', ''), # dtype of indices should be Long but got Float
Horace He1a18ff32022-07-23 19:03:38 +00001535 xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition
Michael Voznesenskybc194942022-10-25 21:15:40 +00001536 xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
1537 xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition
Horace He1a18ff32022-07-23 19:03:38 +00001538}
# Ops that are skipped outright (``skip``) rather than xfailed under symbolic
# tracing, because running them was observed to crash ("Segfault??") instead
# of raising a catchable error.
symbolic_tensor_segfaults = {
    skip('nn.functional.batch_norm'),  # segfaults; root cause not confirmed
}

# Fold the crashers into the general symbolic failure set (in-place union).
symbolic_tensor_failures |= symbolic_tensor_segfaults
Horace He4d88aff2022-06-07 00:28:53 +00001544
# Expected failures that only affect the out-of-place symbolic tracing test
# (test_make_fx_symbolic_exhaustive).
outplace_symbolic_tensor_failures = {
    xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
    xfail('masked_scatter', ''),  # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.rrelu', ''),  # aten.empty_like.default - couldn't find symbolic meta function/decomposition
}

# Expected failures that only affect the in-place symbolic tracing test
# (test_make_fx_symbolic_exhaustive_inplace).
inplace_symbolic_tensor_failures = {
    # Known bugs:
    # base given to float_power_ has dtype Float but the operation's result requires dtype Double
    xfail('float_power', ''),
    # Decomposition not implemented:
    xfail('unique', ''),
    # The in-place variant has a different signature than the out-of-place op:
    xfail('uniform', ''),
}
1559
1560# Copies inputs to inplace operations to avoid inplace modifications
1561# to leaves requiring gradient
1562def _get_safe_inplace(inplace_variant):
1563 @functools.wraps(inplace_variant)
1564 def _fn(t, *args, **kwargs):
1565 return inplace_variant(t.clone(), *args, **kwargs)
1566
1567 return _fn
1568
def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False):
    """Check that ``make_fx`` traces of ``op`` reproduce the eager results.

    For at most the first 100 sample inputs of ``op``, this traces a wrapper
    around the op with ``make_fx(..., tracing_mode=tracing_mode)``, refills the
    float32 tensors among the positional arguments with uniform(0, 1) values,
    and asserts that the traced graph's output equals the eager output on those
    refilled inputs. Samples on which the eager call raises are skipped.

    Args:
        self: the ``TestCase`` driving the check (``skipTest``/``assertEqual``).
        device: device passed to ``op.sample_inputs``.
        dtype: dtype passed to ``op.sample_inputs``.
        op: the OpInfo under test.
        tracing_mode: forwarded to ``make_fx`` (this file uses "real", "fake"
            and "symbolic").
        inplace: if True, exercise ``op.get_inplace()`` applied to a clone of
            the first argument instead of ``op.op``; samples flagged with
            ``broadcasts_input`` are skipped.

    Raises:
        unittest.SkipTest: through ``self.skipTest`` as soon as tracing a
            sample raises ``DynamicOutputShapeException``.
    """
    def f(args, kwargs, extra_args, extra_kwargs):
        # Work on shallow copies so that swapping torch.Size values back in
        # never mutates the caller's containers (i.e. the sample input's kwargs).
        args = list(args)
        kwargs = dict(kwargs)
        if extra_args:
            for i, t in extra_args:
                args[i] = t.size()
        if extra_kwargs:
            for k, t in extra_kwargs.items():
                kwargs[k] = t.size()

        fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op
        return fn(*args, **kwargs)

    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)

    # Limit ourselves to first 100 inputs so symbolic tracing tests don't take too long
    for sample_input in itertools.islice(sample_inputs_itr, 100):
        # Samples whose input is broadcast are not valid for in-place variants.
        if inplace and sample_input.broadcasts_input:
            continue
        args = [sample_input.input] + list(sample_input.args)
        kwargs = sample_input.kwargs

        # If any argument is a torch.Size(), maybe get dynamic shapes for it by:
        # - Creating a temporary Tensor whose size is the torch.Size() we want. Note
        #   that we use a real (CPU) Tensor as we cannot pass "meta" Tensors to make_fx.
        # - Passing it to make_fx so that it is converted to a proxy Tensor.
        # - Unpacking the size in the wrapper to get a torch.Size with dynamic shapes
        #   (in symbolic mode; a no-op otherwise).
        extra_args = [
            (i, torch.empty(arg, device="cpu"))
            for i, arg in enumerate(args)
            if isinstance(arg, torch.Size)
        ]
        extra_kwargs = {
            key: torch.empty(value, device="cpu")
            for key, value in kwargs.items()
            if isinstance(value, torch.Size)
        }

        try:
            new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs, extra_args, extra_kwargs)
        except DynamicOutputShapeException:
            self.skipTest("Dynamic output shape operation in trace")

        # Refill float inputs in place so eager and traced runs are compared on
        # values other than those seen while tracing (presumably to catch
        # values baked into the graph at trace time).
        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                arg.uniform_(0, 1)
        try:
            old_out = f(args, kwargs, extra_args, extra_kwargs)
        except Exception:
            # The eager op rejects this sample; there is nothing to compare.
            continue
        new_out = wrapper_set_seed(new_f, args, kwargs, extra_args, extra_kwargs)
        self.assertEqual(new_out, old_out)
1618
class TestProxyTensorOpInfo(TestCase):
    """OpInfo-driven checks that ``make_fx`` traces reproduce eager results.

    Every op in ``op_db + custom_op_db`` is run (float32 only) through
    ``_test_make_fx_helper`` under a different tracing mode per test. Each test
    registers its known failures with ``skipOps``.
    """

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
    def test_make_fx_exhaustive(self, device, dtype, op):
        """Trace with real tensors."""
        _test_make_fx_helper(self, device, dtype, op, "real")

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive',
             make_fx_failures | fake_tensor_failures)
    def test_make_fx_fake_exhaustive(self, device, dtype, op):
        """Trace with fake tensors."""
        _test_make_fx_helper(self, device, dtype, op, "fake")

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
        """Trace the out-of-place op with symbolic shapes."""
        _test_make_fx_helper(self, device, dtype, op, "symbolic")

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive_inplace(self, device, dtype, op):
        """Trace the op's in-place variant with symbolic shapes, if it has one."""
        has_inplace = bool(op.get_inplace())
        if not has_inplace:
            self.skipTest("No inplace variable for this op")
        _test_make_fx_helper(self, device, dtype, op, "symbolic", inplace=True)
1643
Horace He4d88aff2022-06-07 00:28:53 +00001644
# Only instantiate the CPU variants of these tests. The trailing comma matters:
# ("cpu") is just the string "cpu", and membership checks against a string are
# substring checks rather than element checks.
only_for = ("cpu",)
instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)
1647
1648
if __name__ == "__main__":
    # Executed as a script: hand off to the PyTorch test runner.
    run_tests()