Edward Z. Yang | 5b88a20 | 2022-07-20 18:12:25 -0400 | [diff] [blame] | 1 | # Owner(s): ["module: ProxyTensor"] |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 2 | |
Edward Z. Yang | 89e16c4 | 2023-02-15 17:57:21 -0500 | [diff] [blame] | 3 | from torch.testing._internal.common_utils import TestCase, run_tests, xfail_inherited_tests |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 4 | import torch |
| 5 | import unittest |
| 6 | import warnings |
Horace He | 6a3ecda | 2022-08-31 00:29:55 +0000 | [diff] [blame] | 7 | import operator |
Mostafa Elhoushi | 0894c49 | 2022-07-25 12:43:17 +0000 | [diff] [blame] | 8 | from collections.abc import Iterable |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 9 | from torch.testing._internal.common_device_type import instantiate_device_type_tests |
Fabio Rocha | b652577 | 2023-02-16 15:34:34 +0000 | [diff] [blame] | 10 | from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed, skip, xfail, skipOps |
Edward Z. Yang | f7365ec | 2022-12-10 20:29:21 -0800 | [diff] [blame] | 11 | from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 12 | |
David Berard | 00f6518 | 2022-06-29 10:28:42 -0700 | [diff] [blame] | 13 | from torch._decomp import decomposition_table |
Edward Z. Yang | f1f26fe | 2023-02-12 14:04:01 -0800 | [diff] [blame] | 14 | from torch.fx.experimental.symbolic_shapes import ( |
Edward Z. Yang | 3758559 | 2023-02-21 06:45:00 -0800 | [diff] [blame] | 15 | sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets, |
Edward Z. Yang | 8efe4fd | 2023-02-23 11:54:36 -0800 | [diff] [blame] | 16 | constrain_range, guard_int, GuardOnDataDependentSymNode |
Edward Z. Yang | f1f26fe | 2023-02-12 14:04:01 -0800 | [diff] [blame] | 17 | ) |
Richard Zou | 44b09bf | 2023-04-18 06:51:23 -0700 | [diff] [blame] | 18 | from torch.testing._internal.custom_op_db import custom_op_db |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 19 | from torch.testing._internal.common_device_type import ops |
Horace He | 6639677 | 2022-08-10 22:31:38 +0000 | [diff] [blame] | 20 | from torch._C import _disabled_torch_function_impl |
Edward Z. Yang | 54563e6 | 2022-12-15 16:37:24 +0800 | [diff] [blame] | 21 | from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule |
Horace He | b7046e9 | 2022-07-07 04:54:31 +0000 | [diff] [blame] | 22 | from torch.utils._pytree import tree_map |
Edward Z. Yang | bf387e8 | 2022-08-01 08:55:19 -0700 | [diff] [blame] | 23 | from torch import nn |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 24 | import re |
| 25 | |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 26 | import functools |
Horace He | 2f4a517 | 2022-09-21 02:30:50 +0000 | [diff] [blame] | 27 | import itertools |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 28 | |
# Shorthand for the ATen operator namespace (e.g. ``aten.randn.default``).
aten = torch.ops.aten

# True when a CUDA device is available in this process.
HAS_CUDA = torch.cuda.is_available()
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 32 | |
| 33 | |
def strip_end(s, suffix):
    """Return ``s`` with one trailing ``suffix`` removed, if present.

    An empty ``suffix`` means "nothing to strip": ``s`` is returned unchanged
    (a naive ``s[:-0]`` would return the empty string instead).
    """
    has_suffix = bool(suffix) and s.endswith(suffix)
    return s[: len(s) - len(suffix)] if has_suffix else s
| 39 | |
| 40 | |
def show_guards(gm):
    """Return the simplified shape guards of ``gm``'s shape env, one per line.

    Placeholder targets have a trailing ``_1`` stripped before being passed
    to ``produce_guards`` as source names (presumably so guards refer to the
    traced function's original argument names).
    """
    placeholder_names = []
    for target in fx_placeholder_targets(gm):
        placeholder_names.append(strip_end(target, "_1"))
    guards = gm.shape_env.produce_guards(
        fx_placeholder_vals(gm), placeholder_names, _simplified=True, constraint_inputs=None
    )
    return "\n".join(guards)
| 46 | |
| 47 | |
def process_failures(path='pytest_failures'):
    """
    Takes a file containing failures like

    FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition # noqa: B950

    and prints them as a ``symbolic_tensor_failures`` set of opinfo xfails,
    one entry (annotated with the failure reason) per matching line.

    Args:
        path: file to read; defaults to ``pytest_failures`` in the current
            working directory (the previously hard-coded name).

    Lines that do not match the expected pattern (blank lines, pytest summary
    output, ...) are skipped instead of crashing on ``None.groups()``.

    Raises:
        KeyError: if a failing test name does not correspond to any ``op_db``
            entry.
    """
    # Captures (normalized opinfo name, failure reason) from a failure line.
    symbolic_trace_match = re.compile(r'exhaustive_(.*)_cpu.*: (.*)')

    # Context manager so the file handle is closed deterministically.
    with open(path) as f:
        lines = [line.strip() for line in f]

    failures = []
    for line in lines:
        match = symbolic_trace_match.search(line)
        if match is not None:
            failures.append(match.groups())

    def create_normalized_name(op):
        # "<name>" or "<name>.<variant_test_name>", with dots replaced by
        # underscores to match the fragment captured from the test name.
        if op.variant_test_name == '':
            s = op.name
        else:
            s = f"{op.name}.{op.variant_test_name}"
        return s.replace('.', '_')

    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db}

    print("symbolic_tensor_failures = {")
    for failure, reason in failures:
        print(f" xfail{remap_opinfo[failure]}, # {reason}")
    print("}")
| 80 | |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 81 | |
# torchvision is optional: tests that need it are skipped when
# USE_TORCHVISION is False.
try:
    import torchvision
except ImportError:
    USE_TORCHVISION = False
    warnings.warn("Couldn't import torchvision. Some of our tests use it, try "
                  "to install it with commands from pytorch.org, post-fixed with "
                  "`--no-deps` to avoid overwriting the pytorch installation",
                  UserWarning)
else:
    USE_TORCHVISION = True
| 91 | |
| 92 | |
Horace He | b7046e9 | 2022-07-07 04:54:31 +0000 | [diff] [blame] | 93 | def _create_new_input(x): |
| 94 | if not isinstance(x, torch.Tensor): |
| 95 | return x |
| 96 | if x.dtype != torch.float: |
| 97 | return x + 1 |
| 98 | if x.is_leaf: |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 99 | return torch.rand_like(x, requires_grad=x.requires_grad) |
Horace He | b7046e9 | 2022-07-07 04:54:31 +0000 | [diff] [blame] | 100 | else: |
| 101 | return torch.rand_like(x) |
| 102 | |
class UnwrapTensor(torch.Tensor):
    """Tensor wrapper that delays a ``cos`` on the wrapped tensor until it is used.

    Whenever an ``UnwrapTensor`` is passed to an operator, ``__torch_dispatch__``
    replaces it with ``wrapped.cos()`` right before the operator runs, so the
    ``cos`` only happens when the value is consumed. This simulates how a
    CommTensor is used.
    """

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        # The wrapper mirrors size/dtype/device/layout/requires_grad of
        # ``tensor``; the wrapped tensor itself is kept on ``_tensor``.
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            dtype=tensor.dtype,
            device=tensor.device,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
        )
        r._tensor = tensor
        return r

    def __repr__(self):
        return f"UnwrapTensor({self._tensor})"

    # Disable __torch_function__ handling for this subclass; operators are
    # intercepted in __torch_dispatch__ instead.
    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # ``kwargs`` defaults to None; normalize it so ``func(**kwargs)`` is
        # valid when the hook is invoked without keyword arguments.
        kwargs = kwargs or {}

        def unwrap(e):
            if isinstance(e, UnwrapTensor):
                return e._tensor.cos()
            return e

        return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
| 138 | |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 139 | class TestGenericProxyTensor(TestCase): |
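    # WARNING: do not use this helper if any input is an index tensor:
    # _create_new_input replaces every tensor whose dtype is not torch.float
    # with ``x + 1``, which can push index values out of bounds.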
Horace He | b7046e9 | 2022-07-07 04:54:31 +0000 | [diff] [blame] | 142 | def _test(self, f, inps): |
Edward Z. Yang | b361f70 | 2022-08-03 10:50:30 -0700 | [diff] [blame] | 143 | fx_f = make_fx(f, tracing_mode=self.tracing_mode)(*inps) |
Horace He | b7046e9 | 2022-07-07 04:54:31 +0000 | [diff] [blame] | 144 | new_inps = tree_map(_create_new_input, inps) |
Edward Z. Yang | 817a827 | 2022-08-16 13:37:29 -0700 | [diff] [blame] | 145 | r1 = fx_f(*new_inps) |
| 146 | r2 = f(*new_inps) |
| 147 | self.assertEqual(r1, r2) |
Horace He | b7046e9 | 2022-07-07 04:54:31 +0000 | [diff] [blame] | 148 | |
    def test_pre_autograd_mode_stack(self):
        """make_fx(pre_autograd=True) keeps matmul un-decomposed, even after an
        earlier run of ``f`` under the python dispatcher."""
        def f(a):
            b = torch.ones(4, 4)
            return torch.matmul(a, b)
        # We expect to see matmul in the trace - it should NOT be decomposed into mm.
        # The torch.ones() factory call is recorded as well (aten.ones.default
        # appears in the expected graph below).
        inp = torch.ones(4, 4)
        # Test that make_fx(pre_autograd=True) clears caches properly: run f
        # eagerly under the python dispatcher first (the result is unused),
        # then trace it.
        from torch._dispatch.python import enable_python_dispatcher
        with enable_python_dispatcher():
            out1 = f(inp)
        fx_g = make_fx(f, pre_autograd=True)(inp)
        self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, a_1):
    ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
    matmul = torch.ops.aten.matmul.default(a_1, ones); a_1 = ones = None
    return matmul""")
| 168 | |
| 169 | |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 170 | def test_make_fx_simple(self): |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 171 | def f(x): |
| 172 | return torch.sin(x) |
Horace He | b7046e9 | 2022-07-07 04:54:31 +0000 | [diff] [blame] | 173 | self._test(f, (torch.randn(3),)) |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 174 | |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 175 | def test_scalar_device(self, device='cpu'): |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 176 | def f(a, b): |
| 177 | return a + b |
Horace He | b7046e9 | 2022-07-07 04:54:31 +0000 | [diff] [blame] | 178 | self._test(f, [torch.randn(3, device=device), torch.tensor(5)]) |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 179 | |
    def test_isolated_graphmodule(self):
        """Graphs built inside a trace must not leak ops into the outer trace.

        Each scenario traces a function that internally builds a graph with
        get_isolated_graphmodule (or an inner make_fx call), then checks which
        ops ended up in the OUTER traced graph.
        """
        # Helpers: does the graph contain a call to the given ATen overload?
        def is_any_sum(gm):
            return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)

        def is_any_digamma(gm):
            return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)

        def is_any_sigmoid(gm):
            return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)

        def inner(x):
            return torch.sum(x)

        def f(x):
            gm = get_isolated_graphmodule(inner, (x,), {})
            self.assertTrue(is_any_sum(gm))
            return x + torch.randn(x.shape)

        # get_isolated_graphmodule uses make_fx internally that shouldn't be traced
        # by the outer make_fx call
        traced = make_fx(f)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))

        # When factory functions are used, they should not be traced
        # by the outer make_fx call
        def inner_with_factory():
            val = torch.tensor(float(1))
            val.add_(2)
            return torch.full((10, 10), val).sum()

        def f1(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2(x):
            gm = get_isolated_graphmodule(f1, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify nested make_fx calls don't make factory functions to be leaked
        # into the outer graph. Verify that `make_fx` itself does not leak its execution.
        # NOTE: f2 is redefined here; each definition is traced right after it
        # is defined.
        def f2(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify that the `forward` function of a graph module produced as a
        # side effect of an interior `make_fx` is still traced
        def f3(x):
            gm = make_fx(f1)(x)
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            # `gm.forward` is still traced
            return torch.digamma(gm(x))

        traced = make_fx(f3)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertTrue(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with non-ProxyTensor modes
        from torch.testing._internal.logging_tensor import LoggingTensorMode

        def f1_logging(x):
            with LoggingTensorMode():
                gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging(x):
            with LoggingTensorMode(), LoggingTensorMode():
                gm = get_isolated_graphmodule(f1_logging, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))

        # Verify interaction with another tensor subclass
        # This case currently doesn't work and should raise an error
        # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
        # NOTE(review): nothing below asserts that an error is raised; confirm
        # whether this comment still reflects current behavior.
        from torch.testing._internal.logging_tensor import LoggingTensor

        def f1_logging_tensor(x):
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
            self.assertTrue(is_any_sum(gm))
            return torch.sigmoid(x)

        def f2_logging_tensor(x):
            x = LoggingTensor(x)
            gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
            self.assertFalse(is_any_sum(gm))
            self.assertTrue(is_any_sigmoid(gm))
            return torch.digamma(x)

        traced = make_fx(f2_logging_tensor)(torch.randn(3))
        self.assertFalse(is_any_sum(traced))
        # NOTE(review): an older comment here said "this fails, sigmoid is traced
        # with LoggingTensor", yet the check is assertFalse - confirm which holds.
        self.assertFalse(is_any_sigmoid(traced))
        self.assertTrue(is_any_digamma(traced))
Ivan Yashchuk | 900e93d | 2022-08-02 11:02:10 +0000 | [diff] [blame] | 295 | |
Brian Hirsh | 35c9ea8 | 2023-03-27 15:08:41 +0000 | [diff] [blame] | 296 | # See https://github.com/pytorch/pytorch/issues/97541 |
    def test_empty_like_doesnt_burn_in_defaults(self):
        """Tracing empty_like must not bake default keyword arguments into the call.

        The expected graph passes only ``pin_memory = False`` besides the input.
        """
        def f(x):
            return torch.empty_like(x)
        out = make_fx(f)(torch.randn(3))
        self.assertExpectedInline(out.code.strip(), """\
def forward(self, x_1):
    empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False); x_1 = None
    return empty_like""")
| 305 | |
    def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
        """Factory calls made inside a user decomposition are still traced."""
        def f(x):
            y = x.new_zeros(x.size())
            y.copy_(x)
            return y

        # Replacement for aten.new_zeros that calls the torch.zeros factory.
        def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
            return torch.zeros(size, dtype=inp.dtype, device=inp.device)

        factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}

        # When new_zeros() decomposes into torch.zeros(), we expect ProxyTensorMode
        # to still be (re-entrantly) enabled, so that the `torch.zeros()` call
        # returns a ProxyTensor.
        out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2))
        self.assertExpectedInline(out.code, """\



def forward(self, x_1):
    zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
    copy_ = torch.ops.aten.copy_.default(zeros, x_1); zeros = x_1 = None
    return copy_
    """)
| 330 | |
Edward Z. Yang | 988bd01 | 2022-08-09 08:35:50 -0700 | [diff] [blame] | 331 | def test_make_fx_reentrant_dispatch(self): |
| 332 | def f(x): |
| 333 | return torch.ops.aten.norm.Scalar(x, 2.0) |
| 334 | |
| 335 | def norm_decomp(x, p=2.0): |
| 336 | if p != 2.0: |
| 337 | raise RuntimeError("can't handle with p != 2") |
| 338 | return torch.sqrt(torch.sum(torch.square(x))) |
| 339 | |
| 340 | decomp = {torch.ops.aten.norm.Scalar: norm_decomp} |
| 341 | |
| 342 | traced = make_fx(f, decomposition_table=decomp, tracing_mode=self.tracing_mode)(torch.rand(3)) |
| 343 | |
| 344 | for n in traced.graph.nodes: |
| 345 | self.assertTrue("square" not in str(n.target)) |
| 346 | self.assertTrue("norm" not in str(n.target)) |
| 347 | |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 348 | @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 349 | def test_resnet18_backward_trace(self): |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 350 | mod = torchvision.models.resnet18() |
| 351 | |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 352 | # An old version of this test called the module directly. This works |
| 353 | # for tracing_mode == "real", but for fake tensors, we also have to |
| 354 | # ensure that the parameters and buffers get wrapped in fake tensors |
Richard Zou | 5d01277 | 2023-01-17 21:49:58 -0500 | [diff] [blame] | 355 | # because free fake tensors are not supported. Fortunately functional_call |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 356 | # does precisely this for us. |
| 357 | def f(x, params, buffers): |
| 358 | for p in params.values(): |
| 359 | p.grad = None |
Richard Zou | 5d01277 | 2023-01-17 21:49:58 -0500 | [diff] [blame] | 360 | loss = torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum() |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 361 | # I could have done this with the functional API, but there is |
| 362 | # plenty of exercising this; I want to show mutating API still |
| 363 | # works |
| 364 | loss.backward() |
| 365 | return [p.grad for p in params.values()] |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 366 | |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 367 | inp = torch.randn(3, 3, 250, 250) |
| 368 | self._test(f, [inp, dict(mod.named_parameters()), dict(mod.named_buffers())]) |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 369 | |
Edward Z. Yang | 63f35f1 | 2022-08-10 15:33:10 -0700 | [diff] [blame] | 370 | def test_varargs(self): |
| 371 | def f(*args): |
| 372 | return sum(args) |
| 373 | |
| 374 | self._test(f, [torch.randn(2), torch.randn(2)]) |
| 375 | |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 376 | def test_proxy_tensor(self): |
| 377 | def f_grad(x): |
| 378 | val = x.cos().cos().sum() |
| 379 | return torch.autograd.grad(val, x) |
| 380 | |
| 381 | def f_backward(x): |
| 382 | val = x.cos().cos().sum() |
| 383 | val.backward() |
| 384 | return x.grad |
| 385 | |
| 386 | for f in [f_grad, f_backward]: |
Horace He | b7046e9 | 2022-07-07 04:54:31 +0000 | [diff] [blame] | 387 | self._test(f, [torch.randn(3, requires_grad=True)]) |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 388 | |
Edward Z. Yang | 54563e6 | 2022-12-15 16:37:24 +0800 | [diff] [blame] | 389 | def test_pickle_issue89626(self): |
| 390 | import pickle |
| 391 | x = torch.randn(2) |
| 392 | make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)(x) |
| 393 | pickle.dumps(x) |
| 394 | |
PyTorch MergeBot | 4e33c8c | 2022-06-27 12:06:49 +0000 | [diff] [blame] | 395 | def test_inplace_metadata(self): |
| 396 | def f(x): |
| 397 | x = x.clone() |
| 398 | x.unsqueeze_(-1) |
| 399 | assert x.shape[-1] == 1 |
| 400 | return x |
| 401 | |
Horace He | b7046e9 | 2022-07-07 04:54:31 +0000 | [diff] [blame] | 402 | self._test(f, [torch.randn(5)]) |
PyTorch MergeBot | 4e33c8c | 2022-06-27 12:06:49 +0000 | [diff] [blame] | 403 | |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 404 | def test_mode_tracing_factory_function(self): |
| 405 | def f(x): |
| 406 | return x + torch.randn(x.shape) |
| 407 | |
Horace He | f5d7e5a | 2022-06-16 22:04:10 +0000 | [diff] [blame] | 408 | # default behavior should trace factory functions |
Edward Z. Yang | b361f70 | 2022-08-03 10:50:30 -0700 | [diff] [blame] | 409 | traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3)) |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 410 | self.assertTrue( |
| 411 | any( |
Horace He | 91b4648 | 2022-07-26 20:21:16 +0000 | [diff] [blame] | 412 | node.target == aten.randn.default |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 413 | for node in traced.graph.nodes |
| 414 | ) |
| 415 | ) |
| 416 | |
Edward Z. Yang | 94b5c80 | 2022-11-18 13:14:40 -0800 | [diff] [blame] | 417 | def test_val_metadata_mutation(self): |
| 418 | def f(x): |
| 419 | y = x.clone() |
| 420 | y.unsqueeze_(0) |
| 421 | return y |
| 422 | |
| 423 | traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3, requires_grad=True)) |
| 424 | self.assertEqual([ |
| 425 | tuple(node.meta['val'].shape) |
| 426 | for node in traced.graph.nodes |
| 427 | if 'val' in node.meta |
| 428 | ], [(3,), (3,), (1, 3)]) |
| 429 | |
Horace He | 615dd25 | 2022-06-28 00:20:22 +0000 | [diff] [blame] | 430 | def test_make_fx_overloads(self): |
| 431 | def f(x): |
| 432 | return x.cos() + torch.randn(x.shape) |
| 433 | |
Edward Z. Yang | b361f70 | 2022-08-03 10:50:30 -0700 | [diff] [blame] | 434 | traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3)) |
Horace He | 615dd25 | 2022-06-28 00:20:22 +0000 | [diff] [blame] | 435 | |
Aaron Gokaslan | e2a3817 | 2023-04-25 15:02:13 +0000 | [diff] [blame] | 436 | self.assertTrue(all(isinstance(node.target, torch._ops.OpOverload) |
| 437 | for node in traced.graph.nodes if node.op == 'call_function')) |
Horace He | 615dd25 | 2022-06-28 00:20:22 +0000 | [diff] [blame] | 438 | |
Horace He | b7046e9 | 2022-07-07 04:54:31 +0000 | [diff] [blame] | 439 | def test_tensor_constants(self): |
| 440 | def f(): |
| 441 | val = torch.tensor(float('inf')) |
| 442 | return torch.full((100, 100), val) |
| 443 | |
| 444 | self._test(f, []) |
| 445 | |
Edward Z. Yang | d423722 | 2022-08-12 06:17:53 -0700 | [diff] [blame] | 446 | def test_allclose(self): |
| 447 | def f(a, b): |
| 448 | return torch.allclose(a, b) |
Edward Z. Yang | fca03ee | 2022-07-13 21:11:10 -0700 | [diff] [blame] | 449 | |
Edward Z. Yang | f7365ec | 2022-12-10 20:29:21 -0800 | [diff] [blame] | 450 | def test_f(): |
| 451 | make_fx(f, tracing_mode=self.tracing_mode)( |
Edward Z. Yang | d423722 | 2022-08-12 06:17:53 -0700 | [diff] [blame] | 452 | torch.zeros(3), torch.zeros(3) |
| 453 | ) |
Edward Z. Yang | f7365ec | 2022-12-10 20:29:21 -0800 | [diff] [blame] | 454 | |
Edward Z. Yang | ec2461b | 2023-01-30 05:57:30 -0800 | [diff] [blame] | 455 | if self.tracing_mode != "real": |
Edward Z. Yang | f7365ec | 2022-12-10 20:29:21 -0800 | [diff] [blame] | 456 | self.assertRaises(DataDependentOutputException, test_f) |
| 457 | else: |
| 458 | self.assertRaisesRegex(RuntimeError, "data-dependent", test_f) |
Edward Z. Yang | fca03ee | 2022-07-13 21:11:10 -0700 | [diff] [blame] | 459 | |
| 460 | def test_constant_proxy_tensor_mut(self): |
Edward Z. Yang | fca03ee | 2022-07-13 21:11:10 -0700 | [diff] [blame] | 461 | def f(): |
| 462 | val = torch.tensor(float(1)) |
| 463 | val.add_(2) |
| 464 | return torch.full((100, 100), val) |
| 465 | |
Edward Z. Yang | b361f70 | 2022-08-03 10:50:30 -0700 | [diff] [blame] | 466 | g = make_fx(f, tracing_mode=self.tracing_mode)() |
Edward Z. Yang | fca03ee | 2022-07-13 21:11:10 -0700 | [diff] [blame] | 467 | self.assertEqual(g(), f()) |
| 468 | # In case we mutated shared state in the g graph! |
| 469 | self.assertEqual(g(), f()) |
| 470 | |
Edward Z. Yang | 9821592 | 2022-08-01 07:02:58 -0700 | [diff] [blame] | 471 | def test_constant_unbind(self): |
| 472 | def f(): |
| 473 | val = torch.tensor([2]) |
| 474 | r, = torch.unbind(val, 0) |
| 475 | return r.item() |
| 476 | |
Edward Z. Yang | b361f70 | 2022-08-03 10:50:30 -0700 | [diff] [blame] | 477 | g = make_fx(f, tracing_mode=self.tracing_mode)() |
Edward Z. Yang | 9821592 | 2022-08-01 07:02:58 -0700 | [diff] [blame] | 478 | self.assertEqual(g(), f()) |
| 479 | |
Edward Z. Yang | 24acc31 | 2022-08-17 20:30:13 -0700 | [diff] [blame] | 480 | def test_constant_blowup(self): |
| 481 | def f(): |
| 482 | val = torch.tensor([2]) |
| 483 | blowup = val.repeat(1000) |
Edward Z. Yang | f7365ec | 2022-12-10 20:29:21 -0800 | [diff] [blame] | 484 | return bool(blowup.sum().item() == 2) |
Edward Z. Yang | 24acc31 | 2022-08-17 20:30:13 -0700 | [diff] [blame] | 485 | |
Edward Z. Yang | ec2461b | 2023-01-30 05:57:30 -0800 | [diff] [blame] | 486 | def test_f(): |
| 487 | make_fx(f, tracing_mode=self.tracing_mode)() |
| 488 | |
| 489 | if self.tracing_mode == "fake": |
| 490 | self.assertRaises(DataDependentOutputException, test_f) |
| 491 | else: |
| 492 | self.assertRaisesRegex(RuntimeError, "data-dependent", test_f) |
Edward Z. Yang | 24acc31 | 2022-08-17 20:30:13 -0700 | [diff] [blame] | 493 | |
Edward Z. Yang | 9152144 | 2022-08-17 20:30:41 -0700 | [diff] [blame] | 494 | def test_constant_random(self): |
| 495 | def f(): |
| 496 | val = torch.tensor([2.0]) |
| 497 | val.normal_() |
Edward Z. Yang | f7365ec | 2022-12-10 20:29:21 -0800 | [diff] [blame] | 498 | return bool(val.item() == 2.1) |
Edward Z. Yang | 9152144 | 2022-08-17 20:30:41 -0700 | [diff] [blame] | 499 | |
Edward Z. Yang | ec2461b | 2023-01-30 05:57:30 -0800 | [diff] [blame] | 500 | def test_f(): |
| 501 | make_fx(f, tracing_mode=self.tracing_mode)() |
| 502 | |
| 503 | if self.tracing_mode == "fake": |
| 504 | self.assertRaises(DataDependentOutputException, test_f) |
| 505 | else: |
| 506 | self.assertRaisesRegex(RuntimeError, "data-dependent", test_f) |
Edward Z. Yang | 9152144 | 2022-08-17 20:30:41 -0700 | [diff] [blame] | 507 | |
David Berard | 00f6518 | 2022-06-29 10:28:42 -0700 | [diff] [blame] | 508 | def test_decomposition_interpreter(self): |
| 509 | def fn(x): |
| 510 | return torch.nn.functional.silu(x) |
| 511 | |
| 512 | x = torch.rand((4, 4)) |
Edward Z. Yang | b361f70 | 2022-08-03 10:50:30 -0700 | [diff] [blame] | 513 | fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x) |
David Berard | 00f6518 | 2022-06-29 10:28:42 -0700 | [diff] [blame] | 514 | |
| 515 | found_silu = False |
| 516 | for n in fx_module.graph.nodes: |
| 517 | if n.target == torch.ops.aten.silu or n.target == torch.ops.aten.silu.default: |
| 518 | found_silu = True |
| 519 | |
| 520 | self.assertTrue(found_silu) |
| 521 | |
| 522 | new_graph = torch.fx.Graph() |
| 523 | silu_decomp_table = {torch.ops.aten.silu.default: decomposition_table[torch.ops.aten.silu.default]} |
| 524 | DecompositionInterpreter( |
| 525 | fx_module, |
| 526 | new_graph=new_graph, |
| 527 | decomposition_table=silu_decomp_table, |
| 528 | ).run(x) |
| 529 | |
| 530 | decomposed_module = torch.fx.GraphModule(fx_module, new_graph) |
| 531 | |
| 532 | for n in decomposed_module.graph.nodes: |
| 533 | self.assertTrue(n.target != torch.ops.aten.silu) |
| 534 | self.assertTrue(n.target != torch.ops.aten.silu.default) |
| 535 | |
| 536 | self.assertEqual(fx_module(x), decomposed_module(x)) |
Horace He | 615dd25 | 2022-06-28 00:20:22 +0000 | [diff] [blame] | 537 | |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 538 | def test_make_fx_model_fwd_bwd(self): |
Mostafa Elhoushi | 0894c49 | 2022-07-25 12:43:17 +0000 | [diff] [blame] | 539 | class Foo(torch.nn.Module): |
| 540 | def __init__(self): |
| 541 | super().__init__() |
| 542 | self.linear = torch.nn.Linear(5, 5) |
| 543 | |
| 544 | def forward(self, x): |
| 545 | return self.linear(x).relu() |
| 546 | |
| 547 | model = Foo() |
| 548 | |
| 549 | def f(x, params): |
Richard Zou | 5d01277 | 2023-01-17 21:49:58 -0500 | [diff] [blame] | 550 | out = torch.func.functional_call(model, params, x).sum() |
Mostafa Elhoushi | 0894c49 | 2022-07-25 12:43:17 +0000 | [diff] [blame] | 551 | out.backward() |
| 552 | return list(params.values()) |
| 553 | input = torch.randn(3, 5, requires_grad=True) |
| 554 | params = dict(model.named_parameters()) |
Edward Z. Yang | b361f70 | 2022-08-03 10:50:30 -0700 | [diff] [blame] | 555 | fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params) |
Mostafa Elhoushi | 0894c49 | 2022-07-25 12:43:17 +0000 | [diff] [blame] | 556 | # fx may change the order of parameters in list, so using set() to compare |
| 557 | self.assertTrue( |
| 558 | torch.allclose(fx_f(input, params)[0], f(input, params)[0]) |
| 559 | or |
| 560 | torch.allclose(fx_f(input, params)[0], f(input, params)[1]) |
| 561 | ) |
| 562 | self.assertTrue( |
| 563 | torch.allclose(fx_f(input, params)[1], f(input, params)[0]) |
| 564 | or |
| 565 | torch.allclose(fx_f(input, params)[1], f(input, params)[1]) |
| 566 | ) |
| 567 | |
Horace He | 8b8942b | 2022-08-25 01:53:33 +0000 | [diff] [blame] | 568 | def test_make_fx_model_double_param(self): |
| 569 | class Emformer(torch.nn.Module): |
| 570 | def __init__( |
| 571 | self, |
| 572 | input_dim: int = 256, |
| 573 | ) -> None: |
| 574 | super().__init__() |
| 575 | |
| 576 | self.layer_norm = torch.nn.LayerNorm(input_dim) |
| 577 | |
| 578 | def forward(mod_self, x): # noqa: B902 |
| 579 | self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor)) |
| 580 | y = mod_self.layer_norm(x) |
| 581 | self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor)) |
| 582 | z = mod_self.layer_norm(y) |
| 583 | return z |
| 584 | |
| 585 | |
| 586 | gm = make_fx(Emformer())(torch.randn(16, 1, 256)) |
Aaron Gokaslan | 3d82d8d | 2023-02-10 23:40:26 +0000 | [diff] [blame] | 587 | ops = {n.target for n in gm.graph.nodes if n.op == 'call_function'} |
Horace He | 8b8942b | 2022-08-25 01:53:33 +0000 | [diff] [blame] | 588 | self.assertEqual(len(ops), 2) |
| 589 | |
| 590 | |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 591 | def test_make_fx_model_fwd_bwd_wgtupdate(self): |
Mostafa Elhoushi | 0894c49 | 2022-07-25 12:43:17 +0000 | [diff] [blame] | 592 | class Foo(torch.nn.Module): |
| 593 | def __init__(self): |
| 594 | super().__init__() |
| 595 | self.linear = torch.nn.Linear(5, 5) |
| 596 | |
| 597 | def forward(self, x): |
| 598 | return self.linear(x).relu() |
| 599 | |
| 600 | model = Foo() |
| 601 | |
| 602 | def f(args, params, buffers): |
Edward Z. Yang | 817a827 | 2022-08-16 13:37:29 -0700 | [diff] [blame] | 603 | for p in params.values(): |
| 604 | p.grad = None |
Mostafa Elhoushi | 0894c49 | 2022-07-25 12:43:17 +0000 | [diff] [blame] | 605 | if not isinstance(args, Iterable): |
| 606 | args = [args] |
| 607 | params_and_buffers = {**params, **buffers} |
Richard Zou | 5d01277 | 2023-01-17 21:49:58 -0500 | [diff] [blame] | 608 | out = torch.func.functional_call(model, params_and_buffers, args) |
Mostafa Elhoushi | 0894c49 | 2022-07-25 12:43:17 +0000 | [diff] [blame] | 609 | out.sum().backward() |
| 610 | return [p - 1e-4 * p.grad for p in params.values()] |
| 611 | |
| 612 | input = torch.randn(3, 5, requires_grad=True) |
| 613 | params = dict(model.named_parameters()) |
| 614 | buffers = dict(model.named_buffers()) |
Edward Z. Yang | b361f70 | 2022-08-03 10:50:30 -0700 | [diff] [blame] | 615 | fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers) |
Mostafa Elhoushi | 0894c49 | 2022-07-25 12:43:17 +0000 | [diff] [blame] | 616 | # fx may change the order of parameters in list, so using set() to compare |
| 617 | # also there is a numerical difference in results so changing atol from 1e-08 to 1e-03 |
| 618 | self.assertTrue( |
| 619 | torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03) |
| 620 | or |
| 621 | torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03) |
| 622 | ) |
| 623 | self.assertTrue( |
| 624 | torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03) |
| 625 | or |
| 626 | torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03) |
| 627 | ) |
| 628 | |
Horace He | 6639677 | 2022-08-10 22:31:38 +0000 | [diff] [blame] | 629 | def test_trace_subclasses(self): |
Horace He | a27a4a0 | 2022-08-31 07:01:37 +0000 | [diff] [blame] | 630 | def f1(x): |
Horace He | 6639677 | 2022-08-10 22:31:38 +0000 | [diff] [blame] | 631 | x = UnwrapTensor(x) |
| 632 | y = x * 2 |
| 633 | return y |
| 634 | |
Horace He | a27a4a0 | 2022-08-31 07:01:37 +0000 | [diff] [blame] | 635 | def f2(x): |
| 636 | wrapped = UnwrapTensor(x) |
| 637 | y = x * wrapped |
| 638 | return y |
| 639 | |
Horace He | 6639677 | 2022-08-10 22:31:38 +0000 | [diff] [blame] | 640 | inp = [torch.randn(5)] |
Horace He | a27a4a0 | 2022-08-31 07:01:37 +0000 | [diff] [blame] | 641 | self._test(f1, inp) |
| 642 | self._test(f2, inp) |
Horace He | 0e0af73 | 2022-08-20 00:47:11 +0000 | [diff] [blame] | 643 | |
| 644 | def test_partial_decomp(self): |
| 645 | def f(a, b, c): |
| 646 | x = torch.addmm(a, b, c) |
| 647 | y = torch.addmm(a, b, c, beta=2, alpha=1) |
| 648 | return x + y |
| 649 | inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)] |
| 650 | fx_g = make_fx(f)(*inps) |
| 651 | |
| 652 | def addmm(a, b, c, beta=1, alpha=1): |
| 653 | if beta == 1 and alpha == 1: |
| 654 | return NotImplemented |
| 655 | return beta * a + alpha * (b @ c) |
| 656 | |
Brian Hirsh | af440c4 | 2023-03-21 20:15:23 +0000 | [diff] [blame] | 657 | decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps) |
Horace He | 0e0af73 | 2022-08-20 00:47:11 +0000 | [diff] [blame] | 658 | |
| 659 | self.assertEqual(fx_g(*inps), decomposed_fx(*inps)) |
| 660 | self.assertEqual(len([n for n in fx_g.graph.nodes if n.target == aten.addmm.default]), 2) |
| 661 | self.assertEqual(len([n for n in decomposed_fx.graph.nodes if n.target == aten.addmm.default]), 1) |
| 662 | |
Horace He | a27a4a0 | 2022-08-31 07:01:37 +0000 | [diff] [blame] | 663 | def test_decomp_of_capture(self): |
| 664 | val = torch.randn(5) |
| 665 | |
| 666 | def f(x): |
| 667 | return x.t() + val.t() |
| 668 | |
| 669 | def nop(x): |
| 670 | return x.cos() |
| 671 | |
| 672 | traced = make_fx(f, decomposition_table={torch.ops.aten.t.default: nop})(torch.randn(5)) |
| 673 | self.assertEqual(len([n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]), 0) |
| 674 | |
| 675 | |
Horace He | e3c89d0 | 2022-08-25 06:59:37 +0000 | [diff] [blame] | 676 | @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') |
| 677 | def test_amp_cache(self): |
| 678 | layer = torch.nn.Conv2d(3, 3, 3).cuda() |
Horace He | 0e0af73 | 2022-08-20 00:47:11 +0000 | [diff] [blame] | 679 | |
Horace He | e3c89d0 | 2022-08-25 06:59:37 +0000 | [diff] [blame] | 680 | def f(x, w): |
| 681 | return torch.nn.functional.conv2d(x, w, stride=layer.stride) |
Horace He | 6639677 | 2022-08-10 22:31:38 +0000 | [diff] [blame] | 682 | |
Horace He | e3c89d0 | 2022-08-25 06:59:37 +0000 | [diff] [blame] | 683 | inp = torch.randn(4, 3, 10, 10, device='cuda') |
| 684 | with torch.autocast('cuda'): |
| 685 | out_graph = make_fx(f)(inp, layer.weight).graph |
| 686 | out_graph2 = make_fx(f)(inp, layer.weight).graph |
| 687 | |
| 688 | self.assertEqual(len(out_graph.nodes), len(out_graph2.nodes)) |
| 689 | for a, b in zip(out_graph.nodes, out_graph2.nodes): |
| 690 | self.assertEqual(a.op, b.op) |
| 691 | |
Horace He | 4bdc0af | 2022-09-16 02:29:13 +0000 | [diff] [blame] | 692 | def test_strides(self): |
| 693 | def f(x): |
| 694 | self.assertTrue(x.is_contiguous()) |
| 695 | self.assertFalse(x.is_contiguous(memory_format=torch.channels_last)) |
| 696 | x = x.permute(0, 3, 1, 2) |
| 697 | self.assertFalse(x.is_contiguous()) |
| 698 | self.assertTrue(x.is_contiguous(memory_format=torch.channels_last)) |
| 699 | return x |
| 700 | make_fx(f)(torch.randn(2, 3, 4, 5)) |
| 701 | |
| 702 | def f(x): |
| 703 | self.assertTrue(x.is_contiguous()) |
| 704 | y = x[:, 1] |
| 705 | self.assertFalse(y.is_contiguous()) |
| 706 | y = x[:, ::2] |
| 707 | self.assertFalse(y.is_contiguous()) |
| 708 | return x.cos() |
| 709 | |
| 710 | make_fx(f)(torch.randn(2, 3, 4, 5)) |
| 711 | |
Horace He | 5e23074 | 2022-10-19 02:07:13 +0000 | [diff] [blame] | 712 | def test_pr_86917(self): |
| 713 | # Tests the issue brought up here https://github.com/pytorch/pytorch/pull/86917#issuecomment-1283155344 |
| 714 | def f(a, b): |
| 715 | return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10) |
| 716 | |
| 717 | self._test(f, [torch.randn(1, 10), torch.zeros(1, dtype=torch.long)]) |
| 718 | |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 719 | class TestGenericProxyTensorReal(TestGenericProxyTensor): |
| 720 | tracing_mode = "real" |
| 721 | |
| 722 | |
| 723 | class TestGenericProxyTensorFake(TestGenericProxyTensor): |
| 724 | tracing_mode = "fake" |
| 725 | |
| 726 | |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 727 | @xfail_inherited_tests([ |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 728 | "test_make_fx_overloads", |
Edward Z. Yang | d247244 | 2022-08-03 14:25:16 -0700 | [diff] [blame] | 729 | ]) |
| 730 | class TestGenericProxyTensorSymbolic(TestGenericProxyTensor): |
| 731 | tracing_mode = "symbolic" |
| 732 | |
| 733 | |
| 734 | del TestGenericProxyTensor |
| 735 | |
| 736 | |
Edward Z. Yang | b361f70 | 2022-08-03 10:50:30 -0700 | [diff] [blame] | 737 | class TestRealProxyTensor(TestCase): |
Horace He | c280857 | 2022-08-13 00:37:28 +0000 | [diff] [blame] | 738 | pass |
Edward Z. Yang | b361f70 | 2022-08-03 10:50:30 -0700 | [diff] [blame] | 739 | |
| 740 | class TestFakeProxyTensor(TestCase): |
| 741 | def test_issue82547(self): |
| 742 | x = nn.Parameter(torch.randn(3, 3)) |
| 743 | |
| 744 | def f(): |
| 745 | return torch.ops.aten.t.default(x) |
Tugsbayasgalan (Tugsuu) Manlaibaatar | 1aab755 | 2022-12-12 18:53:08 +0000 | [diff] [blame] | 746 | self.assertRaisesRegex(Exception, "Please convert all Tensors", lambda: make_fx(f, tracing_mode="fake")()) |
Edward Z. Yang | b361f70 | 2022-08-03 10:50:30 -0700 | [diff] [blame] | 747 | |
| 748 | class A(torch.Tensor): |
| 749 | pass |
| 750 | |
| 751 | x = A(torch.randn(3, 3)) |
| 752 | self.assertRaisesRegex(TypeError, "no implementation found", lambda: make_fx(f, tracing_mode="fake")()) |
| 753 | |
| 754 | def test_use_fake_and_tensor(self): |
| 755 | def f(x, y): |
| 756 | z = torch.tensor([2.0, 3.0]) |
| 757 | return x + y + z |
| 758 | |
| 759 | g = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2)) |
| 760 | x, y = torch.randn(2), torch.randn(2) |
| 761 | self.assertEqual(g(x, y), f(x, y)) |
| 762 | |
Edward Z. Yang | 10c938a | 2023-04-21 16:02:40 -0400 | [diff] [blame] | 763 | def test_fused_adam(self): |
| 764 | # See https://github.com/pytorch/pytorch/issues/99356 |
Wanchao Liang | ff7d5b6 | 2023-04-24 17:25:53 +0000 | [diff] [blame] | 765 | params = [torch.randn(10, 10) for _ in range(10)] |
Edward Z. Yang | 10c938a | 2023-04-21 16:02:40 -0400 | [diff] [blame] | 766 | grads = [torch.randn(10, 10) for _ in range(10)] |
| 767 | exp_avgs = [torch.randn(10, 10) for _ in range(10)] |
| 768 | exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)] |
| 769 | max_exp_avg_sqs = [torch.randn(10, 10) for _ in range(10)] |
| 770 | state_steps = [torch.tensor(0) for _ in range(10)] |
| 771 | |
| 772 | def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps): |
Wanchao Liang | ff7d5b6 | 2023-04-24 17:25:53 +0000 | [diff] [blame] | 773 | (new_params, _, _, _, _) = aten._fused_adam.default( |
Edward Z. Yang | 10c938a | 2023-04-21 16:02:40 -0400 | [diff] [blame] | 774 | params, |
| 775 | grads, |
| 776 | exp_avgs, |
| 777 | exp_avg_sqs, |
| 778 | max_exp_avg_sqs, |
| 779 | state_steps, |
| 780 | lr=0.1, |
| 781 | beta1=0.9, |
| 782 | beta2=0.999, |
| 783 | weight_decay=0.01, |
| 784 | eps=1e-8, |
| 785 | amsgrad=False, |
| 786 | maximize=False, |
| 787 | ) |
| 788 | |
Wanchao Liang | ff7d5b6 | 2023-04-24 17:25:53 +0000 | [diff] [blame] | 789 | for p, new_p in zip(params, new_params): |
| 790 | p.copy_(new_p) |
| 791 | |
| 792 | return params |
| 793 | |
Edward Z. Yang | 10c938a | 2023-04-21 16:02:40 -0400 | [diff] [blame] | 794 | gm = make_fx(fused_adam, tracing_mode='fake')( |
| 795 | params, |
| 796 | grads, |
| 797 | exp_avgs, |
| 798 | exp_avg_sqs, |
| 799 | max_exp_avg_sqs, |
| 800 | state_steps, |
| 801 | ) |
Wanchao Liang | ff7d5b6 | 2023-04-24 17:25:53 +0000 | [diff] [blame] | 802 | ensure_ops_have_val = [aten._fused_adam.default, operator.getitem] |
Edward Z. Yang | 10c938a | 2023-04-21 16:02:40 -0400 | [diff] [blame] | 803 | for n in gm.graph.nodes: |
Wanchao Liang | ff7d5b6 | 2023-04-24 17:25:53 +0000 | [diff] [blame] | 804 | if n.op == "call_function" and n.target in ensure_ops_have_val: |
Edward Z. Yang | 10c938a | 2023-04-21 16:02:40 -0400 | [diff] [blame] | 805 | self.assertIn('val', n.meta) |
| 806 | |
Edward Z. Yang | ccade94 | 2022-09-14 10:51:36 -0700 | [diff] [blame] | 807 | def test_alias(self): |
| 808 | def f(x): |
| 809 | return torch.ops.aten.alias(x) |
| 810 | |
| 811 | r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip() |
| 812 | # NB: this should not have a detach call |
| 813 | self.assertExpectedInline(r, """\ |
| 814 | def forward(self, x_1): |
| 815 | alias = torch.ops.aten.alias.default(x_1); x_1 = None |
| 816 | return alias""") |
| 817 | |
Horace He | 2c1bc21 | 2022-10-15 04:10:47 +0000 | [diff] [blame] | 818 | def test_meta(self): |
| 819 | def f(x): |
| 820 | a = x.cos() |
| 821 | b = torch.var_mean(a, dim=0) |
| 822 | c = b * 2 |
| 823 | return c |
| 824 | |
| 825 | out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5)) |
| 826 | for n in out.graph.nodes: |
| 827 | if n.op == 'output': |
| 828 | continue |
| 829 | self.assertTrue('val' in n.meta) |
| 830 | |
Horace He | 6a3ecda | 2022-08-31 00:29:55 +0000 | [diff] [blame] | 831 | def _get_node(fx_g, cond): |
| 832 | for n in fx_g.graph.nodes: |
| 833 | if cond(n): |
| 834 | return n |
| 835 | raise AssertionError |
| 836 | |
Horace He | 377b5d6 | 2022-09-16 22:59:44 +0000 | [diff] [blame] | 837 | def _get_free_symbols(shape_env): |
| 838 | vars = tuple(shape_env.var_to_val.keys()) |
| 839 | return len([var for var in vars if var not in shape_env.replacements]) |
| 840 | |
| 841 | def _trace(f, *args): |
| 842 | inps = [torch.randn(arg) for arg in args] |
| 843 | return make_fx(f, tracing_mode="symbolic")(*inps) |
| 844 | |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 845 | # TODO: Need to test the guards themselves specifically as well |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 846 | class TestSymbolicTracing(TestCase): |
Edward Z. Yang | 4c8cfb5 | 2022-08-15 20:03:13 -0700 | [diff] [blame] | 847 | def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True): |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 848 | """ |
| 849 | Tests fn traced with trace_inputs against test_inputs |
| 850 | Also returns shape env |
| 851 | """ |
| 852 | trace_inputs = [torch.randn(shape) for shape in trace_inputs] |
| 853 | traced_f = make_fx(fn, tracing_mode="symbolic")(*trace_inputs) |
| 854 | for input in test_inputs: |
| 855 | input = [torch.randn(shape) for shape in input] |
Edward Z. Yang | 4c8cfb5 | 2022-08-15 20:03:13 -0700 | [diff] [blame] | 856 | rx, ry = traced_f(*input), fn(*input) |
| 857 | if assert_eq: |
| 858 | self.assertEqual(rx, ry) |
Edward Z. Yang | 45109ec | 2022-12-10 05:19:57 -0800 | [diff] [blame] | 859 | return traced_f |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 860 | |
| 861 | |
Edward Z. Yang | e48c916 | 2022-12-16 09:02:37 +0800 | [diff] [blame] | 862 | def test_debug_interpreter(self): |
| 863 | import torch.library |
| 864 | from torch.library import Library |
| 865 | |
| 866 | foo = Library("foo", "DEF") |
| 867 | foo.define("foo(Tensor self) -> Tensor") |
| 868 | |
| 869 | # Operator where meta and cpu disagree on strides |
| 870 | @torch.library.impl(foo, "foo", "CPU") |
| 871 | def foo_cpu(x): |
| 872 | return x.clone().T |
| 873 | |
| 874 | @torch.library.impl(foo, "foo", "Meta") |
| 875 | def foo_meta(x): |
| 876 | return x.clone() |
| 877 | |
| 878 | def f(x): |
| 879 | return torch.ops.foo.foo.default(x) |
| 880 | |
| 881 | gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2)) |
| 882 | from torch._functorch.compilers import DebugInterpreter |
| 883 | |
| 884 | interp = DebugInterpreter(gm) |
| 885 | |
| 886 | # input mismatch is caught (indicates guard problem) |
| 887 | self.assertRaisesRegex( |
| 888 | AssertionError, r"3 != 1", |
| 889 | lambda: interp.run(torch.randn(3, 3).T), |
| 890 | ) |
| 891 | |
| 892 | # Catch the incorrect meta |
| 893 | self.assertRaisesRegex( |
| 894 | AssertionError, r"\(3, 1\) != \(1, 3\)", |
| 895 | lambda: interp.run(torch.randn(3, 3)) |
| 896 | ) |
| 897 | |
Edward Z. Yang | e33f1ee | 2022-12-10 20:23:17 -0800 | [diff] [blame] | 898 | def test_resize_from_zero(self): |
| 899 | def f(x, y): |
| 900 | x.resize_(y.size(0)) |
| 901 | |
| 902 | r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip() |
| 903 | self.assertExpectedInline(r, """\ |
| 904 | def forward(self, x_1, y_1): |
| 905 | sym_size = torch.ops.aten.sym_size(y_1, 0); y_1 = None |
| 906 | resize_ = torch.ops.aten.resize_.default(x_1, [sym_size]); x_1 = sym_size = None |
| 907 | return None""") |
| 908 | |
| 909 | |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 910 | def test_unary(self): |
| 911 | def f(x): |
| 912 | assert x.shape[0] < 20 |
| 913 | return x.cos() |
| 914 | test_inputs = [] |
| 915 | test_inputs.append([(2, 5)]) |
| 916 | test_inputs.append([(6, 8)]) |
Edward Z. Yang | 45109ec | 2022-12-10 05:19:57 -0800 | [diff] [blame] | 917 | gm = self._test_dynamic(f, [(3, 4)], test_inputs) |
| 918 | self.assertTrue(eval_guards(gm, torch.randn(4, 5))) |
Edward Z. Yang | 67436f6 | 2022-12-16 09:02:35 +0800 | [diff] [blame] | 919 | self.assertEqual(repr(bind_symbols(gm, torch.randn(4, 5))), "{s0: 4, s1: 5}") |
Edward Z. Yang | 45109ec | 2022-12-10 05:19:57 -0800 | [diff] [blame] | 920 | self.assertFalse(eval_guards(gm, torch.randn(25, 5))) |
Michael Voznesensky | b1e60bf | 2023-04-03 20:11:34 +0000 | [diff] [blame] | 921 | self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] < 20""") |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 922 | |
Edward Z. Yang | 1b5bfe9 | 2023-01-26 09:42:30 -0800 | [diff] [blame] | 923 | @unittest.skipIf(not HAS_CUDA, 'CUDA-only test') |
| 924 | def test_cpu_scalar_cuda(self): |
| 925 | # Extracted from wave2vec2 |
| 926 | def f(a, b): |
| 927 | return (a * b) @ b |
| 928 | |
| 929 | r = str( |
| 930 | make_fx(f, tracing_mode="symbolic")( |
| 931 | torch.tensor(1.0), torch.randn(2, 2, device='cuda') |
| 932 | ).code |
| 933 | ).strip() |
| 934 | self.assertExpectedInline(r, """\ |
| 935 | def forward(self, a_1, b_1): |
| 936 | mul = torch.ops.aten.mul.Tensor(a_1, b_1); a_1 = None |
| 937 | mm = torch.ops.aten.mm.default(mul, b_1); mul = b_1 = None |
| 938 | return mm""") |
| 939 | |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 940 | def test_binary_broadcast(self): |
| 941 | def f(a, b): |
| 942 | c = a * b |
| 943 | return c |
| 944 | |
| 945 | test_inputs = [] |
| 946 | test_inputs.append([(1, 5), (3, 1)]) |
| 947 | test_inputs.append([(1, 4), (4, 1)]) |
Edward Z. Yang | 45109ec | 2022-12-10 05:19:57 -0800 | [diff] [blame] | 948 | shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs).shape_env |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 949 | assert len(shape_env.guards) == 0 |
| 950 | |
Edward Z. Yang | 4c8cfb5 | 2022-08-15 20:03:13 -0700 | [diff] [blame] | 951 | def test_multiply_shape(self): |
| 952 | def f(a): |
| 953 | return torch.empty(a.shape[0] * 2) |
| 954 | |
| 955 | r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip() |
| 956 | self.assertExpectedInline(r, """\ |
| 957 | def forward(self, a_1): |
Horace He | 7ebdb4c | 2022-08-23 05:11:03 +0000 | [diff] [blame] | 958 | sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None |
| 959 | mul = sym_size * 2; sym_size = None |
Edward Z. Yang | ad44670 | 2022-08-29 06:08:43 -0700 | [diff] [blame] | 960 | empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None |
Edward Z. Yang | 94b5c80 | 2022-11-18 13:14:40 -0800 | [diff] [blame] | 961 | return empty""") |
Edward Z. Yang | 4c8cfb5 | 2022-08-15 20:03:13 -0700 | [diff] [blame] | 962 | |
Edward Z. Yang | f7365ec | 2022-12-10 20:29:21 -0800 | [diff] [blame] | 963 | def test_item(self): |
| 964 | def f(a): |
| 965 | r = a.item() |
| 966 | return r * a |
| 967 | |
| 968 | r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip() |
| 969 | self.assertExpectedInline(r, """\ |
| 970 | def forward(self, a_1): |
| 971 | _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1) |
| 972 | mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense); a_1 = _local_scalar_dense = None |
| 973 | return mul""") |
| 974 | |
Edward Z. Yang | 2f32fd7 | 2023-02-15 06:37:23 -0800 | [diff] [blame] | 975 | def test_item_to_constructor(self): |
| 976 | def f(a): |
| 977 | r = a.item() |
Edward Z. Yang | 027ebca | 2023-03-01 10:51:12 -0800 | [diff] [blame] | 978 | constrain_range(r, min=2) |
Edward Z. Yang | 2f32fd7 | 2023-02-15 06:37:23 -0800 | [diff] [blame] | 979 | return torch.empty(r) |
| 980 | |
| 981 | r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip() |
| 982 | self.assertExpectedInline( |
| 983 | r, """\ |
| 984 | def forward(self, a_1): |
| 985 | _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1); a_1 = None |
| 986 | empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None |
| 987 | return empty""" # noqa: B950 |
| 988 | ) |
albanD | 12b2f70 | 2022-10-19 11:27:42 -0400 | [diff] [blame] | 989 | |
Edward Z. Yang | 4833e47 | 2023-02-23 11:51:25 -0800 | [diff] [blame] | 990 | def test_dynamic_pointwise_scalar(self): |
| 991 | def f(gravity, mask): |
| 992 | gravity[mask, 0] = gravity[mask, 0] * -1 |
| 993 | |
| 994 | r = str(make_fx(f, tracing_mode="symbolic")( |
| 995 | torch.randn((12, 4)), |
| 996 | torch.randint(0, 2, (12,), dtype=torch.bool) |
| 997 | ).code).strip() |
| 998 | self.assertExpectedInline(r, """\ |
| 999 | def forward(self, gravity_1, mask_1): |
| 1000 | select = torch.ops.aten.select.int(gravity_1, 1, 0) |
| 1001 | index = torch.ops.aten.index.Tensor(select, [mask_1]); select = None |
| 1002 | mul = torch.ops.aten.mul.Tensor(index, -1); index = None |
| 1003 | select_1 = torch.ops.aten.select.int(gravity_1, 1, 0); gravity_1 = None |
| 1004 | index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul); select_1 = mask_1 = mul = None |
| 1005 | return None""") |
| 1006 | |
| 1007 | def test_reflect_r_over_x(self): |
| 1008 | def reflect_R_over_x(R): |
| 1009 | reflect = torch.eye(3, device=R.device) |
| 1010 | reflect[0, 0] = -1 |
| 1011 | return reflect @ R @ reflect |
| 1012 | |
| 1013 | def f(crop_camera, mask): |
| 1014 | crop_camera[mask] = reflect_R_over_x(crop_camera[mask]) |
| 1015 | |
| 1016 | r = str(make_fx(f, tracing_mode="symbolic")( |
| 1017 | torch.randn((12, 3, 3)), |
| 1018 | torch.randint(0, 2, (12,), dtype=torch.bool) |
| 1019 | ).code).strip() |
| 1020 | self.assertExpectedInline(r, """\ |
| 1021 | def forward(self, crop_camera_1, mask_1): |
| 1022 | index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1]) |
| 1023 | eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False) |
| 1024 | _tensor_constant0 = self._tensor_constant0 |
| 1025 | lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None |
| 1026 | select = torch.ops.aten.select.int(eye, 0, 0) |
| 1027 | select_1 = torch.ops.aten.select.int(select, 0, 0); select = None |
| 1028 | copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = None |
| 1029 | transpose = torch.ops.aten.transpose.int(index, -2, -1) |
| 1030 | t = torch.ops.aten.t.default(eye) |
| 1031 | clone = torch.ops.aten.clone.default(transpose, memory_format = torch.contiguous_format); transpose = None |
| 1032 | sym_size = torch.ops.aten.sym_size(index, 0); index = None |
| 1033 | sym_size_1 = torch.ops.aten.sym_size(crop_camera_1, 2) |
| 1034 | mul = sym_size * sym_size_1 |
| 1035 | sym_size_2 = torch.ops.aten.sym_size(crop_camera_1, 1) |
| 1036 | _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [mul, sym_size_2]); clone = mul = sym_size_2 = None |
| 1037 | mm = torch.ops.aten.mm.default(_unsafe_view, t); _unsafe_view = t = None |
| 1038 | view = torch.ops.aten.view.default(mm, [sym_size, sym_size_1, 3]); mm = sym_size_1 = None |
| 1039 | transpose_1 = torch.ops.aten.transpose.int(view, -2, -1) |
| 1040 | clone_1 = torch.ops.aten.clone.default(transpose_1, memory_format = torch.contiguous_format); transpose_1 = None |
| 1041 | mul_1 = sym_size * 3 |
| 1042 | sym_size_3 = torch.ops.aten.sym_size(view, 1); view = None |
| 1043 | view_1 = torch.ops.aten.view.default(clone_1, [mul_1, sym_size_3]); clone_1 = mul_1 = sym_size_3 = None |
| 1044 | mm_1 = torch.ops.aten.mm.default(view_1, eye); view_1 = eye = None |
| 1045 | view_2 = torch.ops.aten.view.default(mm_1, [sym_size, 3, 3]); mm_1 = sym_size = None |
| 1046 | index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_2); crop_camera_1 = mask_1 = view_2 = None |
| 1047 | return None""") |
| 1048 | |
Edward Z. Yang | 98ff841 | 2023-03-06 18:23:35 -0800 | [diff] [blame] | 1049 | def test_unbacked_slice(self): |
| 1050 | def f(x, m): |
| 1051 | x = x[m] |
| 1052 | return x[slice(None, None, None), slice(None, None, None), slice(None, 2, None)] |
| 1053 | |
| 1054 | make_fx(f, tracing_mode="symbolic")( |
| 1055 | torch.randn((12, 3, 3)), |
| 1056 | torch.randint(0, 2, (12,), dtype=torch.bool) |
| 1057 | ) |
| 1058 | |
Edward Z. Yang | 4833e47 | 2023-02-23 11:51:25 -0800 | [diff] [blame] | 1059 | @unittest.skipIf(not USE_TORCHVISION, "test requires torchvision") |
| 1060 | def test_unbacked_batch_resnet(self): |
| 1061 | mod = torchvision.models.resnet18() |
| 1062 | |
| 1063 | def f(x, mask, params, buffers): |
| 1064 | for p in itertools.chain([x, mask], params.values(), buffers.values()): |
| 1065 | for s in p.shape: |
| 1066 | guard_int(s) |
| 1067 | x = x[mask] |
| 1068 | constrain_range(x.shape[0], min=1) |
| 1069 | for p in params.values(): |
| 1070 | p.grad = None |
| 1071 | return torch.func.functional_call(mod, {**params, **buffers}, (x,)).sum() |
| 1072 | |
| 1073 | make_fx(f, tracing_mode="symbolic")( |
| 1074 | torch.randn(3, 3, 250, 250), |
| 1075 | torch.randint(0, 2, (3,), dtype=torch.bool), |
| 1076 | dict(mod.named_parameters()), |
| 1077 | dict(mod.named_buffers()), |
| 1078 | ) |
| 1079 | |
| 1080 | def test_boolean_index(self): |
| 1081 | def f(images, handedness, valid): |
| 1082 | images = images[valid] |
| 1083 | handedness = handedness[valid] |
Edward Z. Yang | 4833e47 | 2023-02-23 11:51:25 -0800 | [diff] [blame] | 1084 | right_hand_mask = handedness == 1 |
| 1085 | images[right_hand_mask] = images[right_hand_mask].flip(-1) |
| 1086 | |
| 1087 | r = str(make_fx(f, tracing_mode="symbolic")( |
| 1088 | torch.randint(0, 256, (512, 1, 96, 96)), |
| 1089 | torch.randint(0, 1, (512,)), |
| 1090 | torch.randint(0, 2, (512,), dtype=torch.bool) |
| 1091 | ).code).strip() |
| 1092 | self.assertExpectedInline(r, """\ |
| 1093 | def forward(self, images_1, handedness_1, valid_1): |
| 1094 | index = torch.ops.aten.index.Tensor(images_1, [valid_1]); images_1 = None |
| 1095 | index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]); handedness_1 = valid_1 = None |
| 1096 | eq = torch.ops.aten.eq.Scalar(index_1, 1); index_1 = None |
| 1097 | index_2 = torch.ops.aten.index.Tensor(index, [eq]) |
| 1098 | flip = torch.ops.aten.flip.default(index_2, [-1]); index_2 = None |
| 1099 | index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip); index = eq = flip = None |
| 1100 | return None""") |
| 1101 | |
albanD | 12b2f70 | 2022-10-19 11:27:42 -0400 | [diff] [blame] | 1102 | def test_neg_shape(self): |
| 1103 | def f(a): |
| 1104 | return torch.empty(-a.shape[0] + 10) |
| 1105 | |
Horace He | 21bef8e | 2022-10-26 16:37:10 +0000 | [diff] [blame] | 1106 | r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(2)).code).strip() |
albanD | 12b2f70 | 2022-10-19 11:27:42 -0400 | [diff] [blame] | 1107 | self.assertExpectedInline(r, """\ |
| 1108 | def forward(self, a_1): |
| 1109 | sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None |
| 1110 | neg = -sym_size; sym_size = None |
| 1111 | add = neg + 10; neg = None |
| 1112 | empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None |
Edward Z. Yang | 94b5c80 | 2022-11-18 13:14:40 -0800 | [diff] [blame] | 1113 | return empty""") |
albanD | 12b2f70 | 2022-10-19 11:27:42 -0400 | [diff] [blame] | 1114 | |
Edward Z. Yang | 8efe4fd | 2023-02-23 11:54:36 -0800 | [diff] [blame] | 1115 | def test_invalidate_nonzero(self): |
| 1116 | ok = False |
| 1117 | |
| 1118 | def f(a): |
| 1119 | nonlocal ok |
| 1120 | b = a.clone() |
| 1121 | x = b.nonzero() |
| 1122 | x1 = b.nonzero() |
| 1123 | x2 = b.nonzero() |
| 1124 | assert x1.shape[0] == x2.shape[0] |
| 1125 | ok = True |
| 1126 | b.normal_() |
| 1127 | y = b.nonzero() |
| 1128 | try: |
| 1129 | bool(x1.shape[0] == y.shape[0]) |
| 1130 | self.fail("didn't raise exception") |
| 1131 | except GuardOnDataDependentSymNode: |
| 1132 | pass |
| 1133 | |
| 1134 | make_fx(f, tracing_mode="symbolic")(torch.randn(4)) |
| 1135 | |
albanD | c21dcff | 2022-10-16 22:16:14 -0400 | [diff] [blame] | 1136 | def test_sqrt_size(self): |
| 1137 | def f(a): |
| 1138 | return a / a.size(-1) ** 0.5 |
| 1139 | |
| 1140 | r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip() |
| 1141 | self.assertExpectedInline(r, """\ |
| 1142 | def forward(self, a_1): |
| 1143 | sym_size = torch.ops.aten.sym_size(a_1, 0) |
Edward Z. Yang | 1ff5222 | 2022-10-27 13:49:11 -0700 | [diff] [blame] | 1144 | pow_1 = sym_size ** 0.5; sym_size = None |
albanD | c21dcff | 2022-10-16 22:16:14 -0400 | [diff] [blame] | 1145 | div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None |
| 1146 | return div""") |
| 1147 | |
| 1148 | |
Edward Z. Yang | 2a332af | 2022-09-02 08:53:59 -0700 | [diff] [blame] | 1149 | def test_symint_to_tensor(self): |
| 1150 | def f(a): |
| 1151 | return a / a.shape[0] |
| 1152 | |
| 1153 | r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip() |
| 1154 | self.assertExpectedInline(r, """\ |
| 1155 | def forward(self, a_1): |
| 1156 | sym_size = torch.ops.aten.sym_size(a_1, 0) |
| 1157 | div = torch.ops.aten.div.Tensor(a_1, sym_size); a_1 = sym_size = None |
Edward Z. Yang | 2a332af | 2022-09-02 08:53:59 -0700 | [diff] [blame] | 1158 | return div""") |
| 1159 | |
| 1160 | r = str(make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4)).code).strip() |
| 1161 | self.assertExpectedInline(r, """\ |
| 1162 | def forward(self, a_1): |
| 1163 | sym_size = torch.ops.aten.sym_size(a_1, 0) |
Joel Schlosser | 8b55b86 | 2022-12-27 16:59:38 -0500 | [diff] [blame] | 1164 | sym_float = torch.sym_float(sym_size); sym_size = None |
Edward Z. Yang | 2a332af | 2022-09-02 08:53:59 -0700 | [diff] [blame] | 1165 | div = torch.ops.prims.div.default(a_1, sym_float); a_1 = sym_float = None |
Edward Z. Yang | 2a332af | 2022-09-02 08:53:59 -0700 | [diff] [blame] | 1166 | return div""") |
| 1167 | |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1168 | def test_cat(self): |
| 1169 | def f(a, b): |
| 1170 | val = torch.mul(a, b) |
| 1171 | out = torch.cat([val, val]) |
| 1172 | if out.shape[0] * out.shape[1] > 20: |
| 1173 | out = out.cos() |
| 1174 | return out |
| 1175 | |
| 1176 | test_inputs = [] |
| 1177 | test_inputs.append([(1, 5), (6, 1)]) |
| 1178 | test_inputs.append([(1, 4), (3, 1)]) |
Edward Z. Yang | 45109ec | 2022-12-10 05:19:57 -0800 | [diff] [blame] | 1179 | gm = self._test_dynamic(f, [(1, 6), (8, 1)], test_inputs) |
| 1180 | self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1))) |
| 1181 | self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1))) |
Michael Voznesensky | b1e60bf | 2023-04-03 20:11:34 +0000 | [diff] [blame] | 1182 | self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""") |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1183 | |
Horace He | 86de9e7 | 2022-08-13 19:03:13 +0000 | [diff] [blame] | 1184 | def test_new_empty(self): |
| 1185 | def f(a, b): |
Edward Z. Yang | 4c8cfb5 | 2022-08-15 20:03:13 -0700 | [diff] [blame] | 1186 | return a.new_empty(b.shape[0], b.shape[1] * 2) |
Horace He | 86de9e7 | 2022-08-13 19:03:13 +0000 | [diff] [blame] | 1187 | |
Edward Z. Yang | 45109ec | 2022-12-10 05:19:57 -0800 | [diff] [blame] | 1188 | self._test_dynamic(f, [(2, 4), (4, 5)], [[(2, 3), (5, 7)], [(3, 7), (9, 3)]], assert_eq=False).shape_env |
Horace He | 86de9e7 | 2022-08-13 19:03:13 +0000 | [diff] [blame] | 1189 | |
Edward Z. Yang | 954660a | 2022-10-03 09:29:49 -0700 | [diff] [blame] | 1190 | def test_size_with_tensor(self): |
| 1191 | def f(tensor): |
| 1192 | max_size = torch.tensor([800, 1216], dtype=torch.int64) |
| 1193 | batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size) |
| 1194 | return tensor.new_empty(batch_shape) |
| 1195 | |
| 1196 | a = torch.randn(3, 800, 1199) |
| 1197 | self.assertRaisesRegex( |
| 1198 | RuntimeError, "data-dependent", lambda: make_fx(f, tracing_mode="symbolic")(a) |
| 1199 | ) |
| 1200 | |
Horace He | 86de9e7 | 2022-08-13 19:03:13 +0000 | [diff] [blame] | 1201 | def test_expand(self): |
| 1202 | def f(a): |
| 1203 | b = torch.mul(a, a) |
| 1204 | c = b.expand(a.shape) |
| 1205 | return c |
| 1206 | |
| 1207 | self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]]) |
| 1208 | self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]]) |
| 1209 | |
Horace He | bd757b3 | 2022-10-19 03:19:22 +0000 | [diff] [blame] | 1210 | def test_metadata(self): |
Horace He | 6a3ecda | 2022-08-31 00:29:55 +0000 | [diff] [blame] | 1211 | def f(a, b): |
| 1212 | d = a.new_empty(a.shape[0] + b.shape[0]) |
| 1213 | return d |
| 1214 | fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4)) |
Horace He | 6a3ecda | 2022-08-31 00:29:55 +0000 | [diff] [blame] | 1215 | meta_c = _get_node(fx_g, lambda x: x.target == aten.new_empty.default) |
| 1216 | meta_d = _get_node(fx_g, lambda x: x.target == operator.add) |
Edward Z. Yang | c450159 | 2023-01-19 21:16:12 +0000 | [diff] [blame] | 1217 | self.assertTrue(meta_c.meta['val'].shape[0].node.expr == meta_d.meta['val'].node.expr) |
Horace He | bd757b3 | 2022-10-19 03:19:22 +0000 | [diff] [blame] | 1218 | |
| 1219 | def test_metadata_fresh(self): |
| 1220 | def f(x): |
| 1221 | assert x.shape[0] == 3 |
| 1222 | return x.cos() |
| 1223 | |
| 1224 | fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3)) |
| 1225 | meta_cos = _get_node(fx_g, lambda x: x.target == aten.cos.default) |
| 1226 | meta_inp = _get_node(fx_g, lambda x: x.op == 'placeholder') |
Edward Z. Yang | c450159 | 2023-01-19 21:16:12 +0000 | [diff] [blame] | 1227 | self.assertTrue(meta_cos.meta['val'].shape[0].node.expr == 3) |
Horace He | bd757b3 | 2022-10-19 03:19:22 +0000 | [diff] [blame] | 1228 | # Checks if the input expr has been updated even though the constraint |
| 1229 | # happened afterwards |
Edward Z. Yang | c450159 | 2023-01-19 21:16:12 +0000 | [diff] [blame] | 1230 | self.assertTrue(meta_inp.meta['val'].shape[0].node.expr == 3) |
Horace He | bd757b3 | 2022-10-19 03:19:22 +0000 | [diff] [blame] | 1231 | |
Sherlock Huang | caf3d53 | 2022-11-19 23:10:34 +0000 | [diff] [blame] | 1232 | def test_elementwise_meta_with_sym_numbers(self): |
| 1233 | def f(x, offset, as_sym_float=False): |
| 1234 | x0 = x.size()[0] |
| 1235 | if as_sym_float: |
| 1236 | x0 = sym_float(x0) |
| 1237 | return torch.add(x0, offset) |
Horace He | bd757b3 | 2022-10-19 03:19:22 +0000 | [diff] [blame] | 1238 | |
Sherlock Huang | caf3d53 | 2022-11-19 23:10:34 +0000 | [diff] [blame] | 1239 | fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2.0, False) |
| 1240 | meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) |
| 1241 | self.assertEqual(meta_add.meta['val'].shape, ()) |
| 1242 | self.assertEqual(meta_add.meta['val'].dtype, torch.float32) |
Horace He | bd757b3 | 2022-10-19 03:19:22 +0000 | [diff] [blame] | 1243 | |
Sherlock Huang | caf3d53 | 2022-11-19 23:10:34 +0000 | [diff] [blame] | 1244 | fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, False) |
| 1245 | meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) |
| 1246 | self.assertEqual(meta_add.meta['val'].shape, ()) |
| 1247 | self.assertEqual(meta_add.meta['val'].dtype, torch.int64) |
| 1248 | |
| 1249 | fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), 2, True) |
| 1250 | meta_add = _get_node(fx_g, lambda x: x.target == aten.add.Tensor) |
| 1251 | self.assertEqual(meta_add.meta['val'].shape, ()) |
| 1252 | self.assertEqual(meta_add.meta['val'].dtype, torch.float32) |
Horace He | 86de9e7 | 2022-08-13 19:03:13 +0000 | [diff] [blame] | 1253 | |
Horace He | bc993e3 | 2022-10-03 07:11:53 +0000 | [diff] [blame] | 1254 | def test_return_symint(self): |
| 1255 | def f(x): |
| 1256 | return x.shape[0], x.cos(), x.shape[0] / 5 |
| 1257 | self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]]) |
| 1258 | |
| 1259 | def f(x): |
| 1260 | return x.shape |
| 1261 | self._test_dynamic(f, [(5, 3)], [[(4, 6)]]) |
| 1262 | |
Edward Z. Yang | 2a47b10 | 2022-10-29 08:45:32 -0700 | [diff] [blame] | 1263 | def test_rmethod(self): |
| 1264 | def f(x): |
| 1265 | return x.size(0) + x |
| 1266 | self._test_dynamic(f, [(5,)], [[(4,)], [(12,)]]) |
| 1267 | |
Horace He | 569eebb | 2022-10-25 04:04:16 +0000 | [diff] [blame] | 1268 | def test_mega_guard(self): |
| 1269 | def f(a, b): |
| 1270 | assert a.shape[0] == b.shape[0] * 2 |
Horace He | 569eebb | 2022-10-25 04:04:16 +0000 | [diff] [blame] | 1271 | return a.cos() |
| 1272 | fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8)) |
Edward Z. Yang | bcf15cd | 2022-12-29 13:32:31 +0800 | [diff] [blame] | 1273 | from torch._dynamo.source import LocalSource |
Edward Z. Yang | 45109ec | 2022-12-10 05:19:57 -0800 | [diff] [blame] | 1274 | self.assertExpectedInline( |
Michael Voznesensky | 4c28929 | 2023-04-22 07:33:12 +0000 | [diff] [blame] | 1275 | str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=False)), # noqa: B950 |
Michael Voznesensky | b1e60bf | 2023-04-03 20:11:34 +0000 | [diff] [blame] | 1276 | """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]""" # noqa: B950 |
Edward Z. Yang | 45109ec | 2022-12-10 05:19:57 -0800 | [diff] [blame] | 1277 | ) |
Michael Voznesensky | 4c28929 | 2023-04-22 07:33:12 +0000 | [diff] [blame] | 1278 | self.assertExpectedInline( |
| 1279 | str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), [LocalSource("a"), LocalSource("b")], ignore_static=True)), # noqa: B950 |
| 1280 | """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]""" # noqa: B950 |
| 1281 | ) |
Horace He | 569eebb | 2022-10-25 04:04:16 +0000 | [diff] [blame] | 1282 | |
Horace He | 21bef8e | 2022-10-26 16:37:10 +0000 | [diff] [blame] | 1283 | def test_sym_storage_offset(self): |
| 1284 | def f(x, y): |
| 1285 | return x + y |
| 1286 | |
| 1287 | inp = (torch.randn(8)[3:], torch.randn(5)) |
| 1288 | fx_g = make_fx(f, tracing_mode="symbolic")(*inp) |
| 1289 | inp = (torch.randn(8)[3:], torch.randn(5)) |
| 1290 | self.assertEqual(fx_g(*inp), f(*inp)) |
Horace He | 569eebb | 2022-10-25 04:04:16 +0000 | [diff] [blame] | 1291 | |
Horace He | 377b5d6 | 2022-09-16 22:59:44 +0000 | [diff] [blame] | 1292 | def _assert_no_guards(self, fx_g, free_symbols): |
Edward Z. Yang | 9baf677 | 2022-09-21 07:00:52 -0700 | [diff] [blame] | 1293 | assert _get_free_symbols(fx_g.shape_env) == free_symbols, fx_g.shape_env.var_to_val |
| 1294 | assert len(fx_g.shape_env.get_nontrivial_guards()) == 0, fx_g.shape_env.format_guards() |
Horace He | 377b5d6 | 2022-09-16 22:59:44 +0000 | [diff] [blame] | 1295 | |
| 1296 | def test_guards_equal(self): |
| 1297 | def f(a, b): |
| 1298 | return a * b |
| 1299 | |
Edward Z. Yang | ada6e5b | 2022-09-28 17:28:26 -0400 | [diff] [blame] | 1300 | # NB: Numbers are carefully chosen to avoid duck shaping from applying |
| 1301 | |
| 1302 | fx_g = _trace(f, (5, 6), (5, 6)) |
Horace He | 377b5d6 | 2022-09-16 22:59:44 +0000 | [diff] [blame] | 1303 | self._assert_no_guards(fx_g, 2) |
| 1304 | |
Edward Z. Yang | ada6e5b | 2022-09-28 17:28:26 -0400 | [diff] [blame] | 1305 | fx_g = _trace(f, (5, 6, 7), (5, 6, 7)) |
Horace He | 377b5d6 | 2022-09-16 22:59:44 +0000 | [diff] [blame] | 1306 | self._assert_no_guards(fx_g, 3) |
| 1307 | |
Edward Z. Yang | ada6e5b | 2022-09-28 17:28:26 -0400 | [diff] [blame] | 1308 | fx_g = _trace(f, (5, 1), (1, 6)) |
| 1309 | self._assert_no_guards(fx_g, 2) |
Horace He | 377b5d6 | 2022-09-16 22:59:44 +0000 | [diff] [blame] | 1310 | |
| 1311 | def f(a, b, c, d): |
| 1312 | a = a + b |
| 1313 | cat = torch.cat([c, d]) |
| 1314 | return a + cat |
| 1315 | |
| 1316 | fx_g = _trace(f, 7, 7, 4, 3) |
| 1317 | self._assert_no_guards(fx_g, 2) |
| 1318 | |
Horace He | 12a19a4 | 2022-09-17 18:11:51 +0000 | [diff] [blame] | 1319 | def f(a, b, c, d, e): |
| 1320 | vals = [a, b, c, d, e] |
| 1321 | x = a |
| 1322 | for idx in range(len(vals) - 1): |
| 1323 | x = torch.cat([x, vals[idx]]) + vals[idx + 1] |
| 1324 | return x |
| 1325 | |
| 1326 | fx_g = _trace(f, 2, 4, 8, 16, 32) |
| 1327 | self._assert_no_guards(fx_g, 1) |
| 1328 | |
Horace He | 377b5d6 | 2022-09-16 22:59:44 +0000 | [diff] [blame] | 1329 | def f(a, b): |
| 1330 | a = a.view(b.shape[0]) |
| 1331 | return a + b.sum() |
| 1332 | |
| 1333 | fx_g = _trace(f, (4, 2), 8) |
| 1334 | self._assert_no_guards(fx_g, 2) |
| 1335 | |
Edward Z. Yang | ada6e5b | 2022-09-28 17:28:26 -0400 | [diff] [blame] | 1336 | fx_g = _trace(f, (4, 2), (8, 5)) |
Horace He | 377b5d6 | 2022-09-16 22:59:44 +0000 | [diff] [blame] | 1337 | self._assert_no_guards(fx_g, 3) |
| 1338 | |
| 1339 | fx_g = _trace(f, (2, 3, 4), 24) |
| 1340 | self._assert_no_guards(fx_g, 3) |
| 1341 | |
| 1342 | def test_nonidentity_transitive_guards(self): |
| 1343 | def f(a, b, c, d, e): |
| 1344 | vals = [a, b, c, d, e] |
| 1345 | cat_vals = [] |
| 1346 | for idx in range(len(vals) - 1): |
| 1347 | cat_vals.append(torch.cat([vals[idx], vals[idx]])) |
| 1348 | final_vals = [] |
| 1349 | for a, b in reversed(list(zip(cat_vals, vals[1:]))): |
| 1350 | final_vals.append(a + b) |
| 1351 | return final_vals |
| 1352 | |
| 1353 | fx_g = _trace(f, 2, 4, 8, 16, 32) |
Edward Z. Yang | f1f26fe | 2023-02-12 14:04:01 -0800 | [diff] [blame] | 1354 | self.assertExpectedInline(show_guards(fx_g), """""") |
Horace He | 377b5d6 | 2022-09-16 22:59:44 +0000 | [diff] [blame] | 1355 | |
| 1356 | |
Horace He | 86de9e7 | 2022-08-13 19:03:13 +0000 | [diff] [blame] | 1357 | |
Horace He | 12a19a4 | 2022-09-17 18:11:51 +0000 | [diff] [blame] | 1358 | |
Horace He | 5e23074 | 2022-10-19 02:07:13 +0000 | [diff] [blame] | 1359 | |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 1360 | make_fx_failures = { |
PyTorch MergeBot | 4e33c8c | 2022-06-27 12:06:49 +0000 | [diff] [blame] | 1361 | # unknown |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 1362 | xfail('allclose'), |
Peter Bell | 9bf52f4 | 2022-06-20 19:58:41 +0100 | [diff] [blame] | 1363 | xfail('equal'), |
Horace He | f5d7e5a | 2022-06-16 22:04:10 +0000 | [diff] [blame] | 1364 | # empty |
| 1365 | skip('new_empty'), |
| 1366 | skip('empty_like'), |
| 1367 | skip('empty'), |
Edward Z. Yang | ce950b4 | 2023-02-21 09:13:06 -0500 | [diff] [blame] | 1368 | skip('empty_permuted'), |
Horace He | f5d7e5a | 2022-06-16 22:04:10 +0000 | [diff] [blame] | 1369 | # flaky |
| 1370 | skip('linalg.lstsq', 'grad_oriented'), |
| 1371 | skip('nn.functional.max_unpool1d', '', device_type='cpu'), |
| 1372 | skip('nn.functional.max_unpool2d', '', device_type='cpu'), |
| 1373 | skip('nn.functional.max_unpool3d', '', device_type='cpu'), |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 1374 | skip('linalg.lstsq'), # flaky, probably just a precision issue |
PyTorch MergeBot | 4e33c8c | 2022-06-27 12:06:49 +0000 | [diff] [blame] | 1375 | |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 1376 | # data-dependent control flow |
| 1377 | xfail('cov'), |
| 1378 | xfail('istft'), |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 1379 | xfail('nn.functional.gaussian_nll_loss'), |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 1380 | xfail('tensor_split'), |
Horace He | f5d7e5a | 2022-06-16 22:04:10 +0000 | [diff] [blame] | 1381 | xfail('corrcoef'), |
Edward Z. Yang | d423722 | 2022-08-12 06:17:53 -0700 | [diff] [blame] | 1382 | xfail('quantile'), |
| 1383 | xfail('nanquantile'), |
kshitij12345 | 4f6027b | 2022-09-12 16:59:05 +0000 | [diff] [blame] | 1384 | xfail('narrow'), |
Horace He | f5d7e5a | 2022-06-16 22:04:10 +0000 | [diff] [blame] | 1385 | |
Elias Ellison | 638feec | 2023-04-19 01:01:15 +0000 | [diff] [blame] | 1386 | # many complex operators incorrect striding, metadata |
| 1387 | skip('fft.fft', ''), |
| 1388 | skip('fft.hfft2', ''), |
| 1389 | skip('fft.hfft', ''), |
| 1390 | skip('fft.hfftn', ''), |
| 1391 | skip('fft.ifft', ''), |
| 1392 | skip('fft.ihfft2', ''), |
| 1393 | skip('fft.ihfft', ''), |
| 1394 | skip('fft.ihfftn', ''), |
| 1395 | skip('fft.irfft2', ''), |
| 1396 | skip('fft.irfft', ''), |
| 1397 | skip('fft.irfftn', ''), |
| 1398 | skip('fft.rfft2', ''), |
| 1399 | skip('fft.rfft', ''), |
| 1400 | skip('fft.rfftn', ''), |
| 1401 | |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 1402 | # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse |
| 1403 | xfail('sparse.sampled_addmm'), |
mingfeima | c620ece | 2023-02-10 11:12:35 +0800 | [diff] [blame] | 1404 | xfail('sparse.mm', 'reduce'), |
Natalia Gimelshein | 6162a04 | 2022-09-16 15:54:50 +0000 | [diff] [blame] | 1405 | |
Edward Z. Yang | 42fefd4 | 2022-08-02 13:22:19 -0700 | [diff] [blame] | 1406 | # proxy tensor doesn't support sparse correctly right now |
| 1407 | skip('to_sparse'), |
PyTorch MergeBot | 4e33c8c | 2022-06-27 12:06:49 +0000 | [diff] [blame] | 1408 | # segfaults |
| 1409 | skip('block_diag'), |
| 1410 | } |
| 1411 | |
| 1412 | fake_tensor_failures = { |
PyTorch MergeBot | 4e33c8c | 2022-06-27 12:06:49 +0000 | [diff] [blame] | 1413 | # FakeTensor fallback doesn't work |
albanD | 496c0a2 | 2023-02-06 18:32:23 +0000 | [diff] [blame] | 1414 | xfail('_segment_reduce', 'lengths'), |
PyTorch MergeBot | 4e33c8c | 2022-06-27 12:06:49 +0000 | [diff] [blame] | 1415 | xfail('multinomial'), |
PyTorch MergeBot | 4e33c8c | 2022-06-27 12:06:49 +0000 | [diff] [blame] | 1416 | xfail('cholesky'), |
| 1417 | xfail('cholesky_inverse'), |
Sherlock Huang | f1fb586 | 2022-11-17 18:50:33 +0000 | [diff] [blame] | 1418 | # cannot do these as they rely on tensor data |
| 1419 | xfail('repeat_interleave'), |
PyTorch MergeBot | 4e33c8c | 2022-06-27 12:06:49 +0000 | [diff] [blame] | 1420 | # ASAN failures due to divide by 0 |
| 1421 | skip('nn.functional.nll_loss'), |
Elias Ellison | 638feec | 2023-04-19 01:01:15 +0000 | [diff] [blame] | 1422 | |
Elias Ellison | 638feec | 2023-04-19 01:01:15 +0000 | [diff] [blame] | 1423 | xfail("stft"), |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 1424 | } |
| 1425 | |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1426 | symbolic_tensor_failures = { |
Elias Ellison | 1c0f7bd | 2022-07-27 22:19:14 +0000 | [diff] [blame] | 1427 | xfail('linalg.eig'), |
Richard Zou | a47bc96 | 2022-09-06 10:20:07 -0700 | [diff] [blame] | 1428 | xfail('linalg.eigvals'), |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1429 | xfail('cholesky_solve', ''), # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back... |
Horace He | b3b9786 | 2022-10-13 20:19:16 +0000 | [diff] [blame] | 1430 | xfail('combinations', ''), |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1431 | xfail('diff', ''), # aten.empty_like.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1432 | xfail('frexp', ''), # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1433 | xfail('geqrf', ''), # aten.geqrf.default - couldn't find symbolic meta function/decomposition |
| 1434 | xfail('gradient', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1435 | xfail('histc', ''), # Could not run 'aten::histc' with arguments from the 'Meta' backend. This could be because... |
| 1436 | xfail('histogram', ''), # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c... |
| 1437 | xfail('histogramdd', ''), # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1438 | xfail('index_reduce', ''), # Float |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1439 | xfail('isin', ''), # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1440 | xfail('kron', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
| 1441 | xfail('kthvalue', ''), # aten.kthvalue.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1442 | xfail('linalg.eigh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition |
| 1443 | xfail('linalg.eigvalsh', ''), # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition |
| 1444 | xfail('linalg.householder_product', ''), # aten.linalg_householder_product.default - couldn't find symbolic meta funct... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1445 | xfail('linalg.ldl_factor', ''), # aten.linalg_ldl_factor_ex.default - couldn't find symbolic meta function/decomposition |
| 1446 | xfail('linalg.ldl_factor_ex', ''), # aten.linalg_ldl_factor_ex.default - couldn't find symbolic meta function/decompos... |
| 1447 | xfail('linalg.ldl_solve', ''), # aten.linalg_ldl_solve.default - couldn't find symbolic meta function/decomposition |
| 1448 | xfail('linalg.lu', ''), # aten.linalg_lu.default - couldn't find symbolic meta function/decomposition |
| 1449 | xfail('linalg.lu_factor', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition |
| 1450 | xfail('linalg.lu_factor_ex', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition |
| 1451 | xfail('linalg.lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition |
Edward Z. Yang | 617e90d | 2022-07-27 11:07:50 -0400 | [diff] [blame] | 1452 | xfail('linalg.matrix_power'), # RuntimeError: Trying to call aten.size on a tensor with symbolic shape |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1453 | xfail('linalg.matrix_rank', 'hermitian'), # aten.size.default - couldn't find symbolic meta function/decomposition |
| 1454 | xfail('linalg.multi_dot', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1455 | xfail('linalg.pinv', ''), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition |
| 1456 | xfail('linalg.pinv', 'singular'), # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition |
| 1457 | xfail('linalg.pinv', 'hermitian'), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decompo... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1458 | xfail('linalg.slogdet', ''), # aten._linalg_slogdet.default - couldn't find symbolic meta function/decomposition |
| 1459 | xfail('linalg.solve', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition |
| 1460 | xfail('linalg.solve_ex', ''), # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1461 | xfail('linalg.tensorinv', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
| 1462 | xfail('linalg.tensorsolve', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
| 1463 | xfail('linalg.vander', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1464 | xfail('logaddexp2', ''), # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1465 | xfail('logdet', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1466 | xfail('lu', ''), # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition |
| 1467 | xfail('lu_solve', ''), # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition |
| 1468 | xfail('lu_unpack', ''), # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1469 | xfail('masked_select', ''), # aten.masked_select.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1470 | xfail('matrix_exp', ''), # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1471 | xfail('median', ''), # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1472 | xfail('mode', ''), # aten.mode.default - couldn't find symbolic meta function/decomposition |
kshitij12345 | a3d37f1 | 2022-08-08 14:42:51 +0000 | [diff] [blame] | 1473 | xfail('nanquantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1474 | xfail('narrow', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1475 | xfail('nn.functional.adaptive_max_pool1d', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
| 1476 | xfail('nn.functional.adaptive_max_pool2d', ''), # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct... |
| 1477 | xfail('nn.functional.adaptive_max_pool3d', ''), # argument 'output_size' (position 2) must be tupl... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1478 | xfail('nn.functional.avg_pool3d', ''), # aten.avg_pool3d.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1479 | xfail('nn.functional.bilinear', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
| 1480 | xfail('nn.functional.binary_cross_entropy', ''), # aten.new_empty.default - couldn't find symbolic meta function/decom... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1481 | xfail('nn.functional.cosine_similarity', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
| 1482 | xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
kshitij12345 | 56a41b5 | 2022-09-22 00:21:11 +0000 | [diff] [blame] | 1483 | xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1484 | xfail('nn.functional.embedding_bag', ''), # aten._embedding_bag_forward_only.default - couldn't find symbolic meta fun... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1485 | xfail('nn.functional.fractional_max_pool2d', ''), # argument 'size' must be tuple of ints, but found element of t... |
| 1486 | xfail('nn.functional.fractional_max_pool3d', ''), # argument 'size' must be tuple of ints, but found element of t... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1487 | xfail('nn.functional.grid_sample', ''), # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1488 | xfail('nn.functional.interpolate', 'linear'), # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1489 | xfail('nn.functional.interpolate', 'trilinear'), # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi... |
Richard Zou | c771d73 | 2022-09-06 07:10:33 -0700 | [diff] [blame] | 1490 | xfail('nn.functional.max_pool1d', ''), # Trying to call aten.size on a tensor with symbolic shapes. |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1491 | xfail('nn.functional.max_pool3d', ''), # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d... |
| 1492 | xfail('nn.functional.max_unpool1d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom... |
| 1493 | xfail('nn.functional.max_unpool2d', 'grad'), # aten.max_unpool2d.default - couldn't find symbolic meta function/decom... |
| 1494 | xfail('nn.functional.max_unpool3d', 'grad'), # aten.max_unpool3d.default - couldn't find symbolic meta function/decom... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1495 | xfail('nn.functional.multi_margin_loss', ''), # Could not run 'aten::multi_margin_loss' with arguments from the... |
| 1496 | xfail('nn.functional.multilabel_margin_loss', ''), # Could not run 'aten::multilabel_margin_loss_forward' with ... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1497 | xfail('nn.functional.pad', 'reflect'), # aten.reflection_pad1d.default - couldn't find symbolic meta function/decompo... |
| 1498 | xfail('nn.functional.pad', 'replicate'), # aten.replication_pad1d.default - couldn't find symbolic meta function/deco... |
| 1499 | xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1500 | xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1501 | xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1502 | xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1503 | xfail('ormqr', ''), # aten.ormqr.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1504 | xfail('pinverse', ''), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition |
| 1505 | xfail('polygamma', 'polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition |
| 1506 | xfail('polygamma', 'polygamma_n_1'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition |
| 1507 | xfail('polygamma', 'polygamma_n_2'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition |
| 1508 | xfail('polygamma', 'polygamma_n_3'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition |
| 1509 | xfail('polygamma', 'polygamma_n_4'), # aten.polygamma.default - couldn't find symbolic meta function/decomposition |
kshitij12345 | a3d37f1 | 2022-08-08 14:42:51 +0000 | [diff] [blame] | 1510 | xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1511 | xfail('renorm', ''), # aten.renorm.default - couldn't find symbolic meta function/decomposition |
Michael Voznesensky | bc19494 | 2022-10-25 21:15:40 +0000 | [diff] [blame] | 1512 | xfail('repeat_interleave', ''), # Cannot call sizes() on tensor with symbolic sizes/strides |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1513 | xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition |
| 1514 | xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition |
| 1515 | xfail('roll', ''), # Tensors of type TensorImpl do not have numel |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1516 | xfail('searchsorted', ''), # Could not run 'aten::searchsorted.Tensor' with arguments from the 'Meta' backend. ... |
albanD | 496c0a2 | 2023-02-06 18:32:23 +0000 | [diff] [blame] | 1517 | xfail('_segment_reduce', 'offsets'), # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1518 | xfail('special.airy_ai', ''), # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1519 | xfail('special.bessel_y0', ''), # aten.special_bessel_y0.default - couldn't find symbolic meta function/decomposition |
| 1520 | xfail('special.bessel_y1', ''), # aten.special_bessel_y1.default - couldn't find symbolic meta function/decomposition |
| 1521 | xfail('special.chebyshev_polynomial_t', ''), # aten.special_chebyshev_polynomial_t.default - couldn't find symbolic me... |
| 1522 | xfail('special.chebyshev_polynomial_u', ''), # aten.special_chebyshev_polynomial_u.default - couldn't find symbolic me... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1523 | xfail('special.hermite_polynomial_h', ''), # aten.special_hermite_polynomial_h.default - couldn't find symbolic meta f... |
| 1524 | xfail('special.hermite_polynomial_he', ''), # aten.special_hermite_polynomial_he.default - couldn't find symbolic meta... |
| 1525 | xfail('special.laguerre_polynomial_l', ''), # aten.special_laguerre_polynomial_l.default - couldn't find symbolic meta... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1526 | xfail('special.modified_bessel_i0', ''), # aten.special_modified_bessel_i0.default - couldn't find symbolic meta funct... |
| 1527 | xfail('special.modified_bessel_i1', ''), # aten.special_modified_bessel_i1.default - couldn't find symbolic meta funct... |
| 1528 | xfail('special.modified_bessel_k0', ''), # aten.special_modified_bessel_k0.default - couldn't find symbolic meta funct... |
| 1529 | xfail('special.modified_bessel_k1', ''), # aten.special_modified_bessel_k1.default - couldn't find symbolic meta funct... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1530 | xfail('special.polygamma', 'special_polygamma_n_0'), # aten.polygamma.default - couldn't find symbolic meta function/... |
| 1531 | xfail('special.scaled_modified_bessel_k0', ''), # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo... |
| 1532 | xfail('special.scaled_modified_bessel_k1', ''), # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo... |
Nikolay Korovaiko | bfebf25 | 2022-08-05 03:36:09 +0000 | [diff] [blame] | 1533 | xfail('stft', ''), # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at... |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1534 | xfail('take_along_dim', ''), # dtype of indices should be Long but got Float |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1535 | xfail('triangular_solve', ''), # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition |
Michael Voznesensky | bc19494 | 2022-10-25 21:15:40 +0000 | [diff] [blame] | 1536 | xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition |
| 1537 | xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition |
Horace He | 1a18ff3 | 2022-07-23 19:03:38 +0000 | [diff] [blame] | 1538 | } |
Horace He | 7ebdb4c | 2022-08-23 05:11:03 +0000 | [diff] [blame] | 1539 | symbolic_tensor_segfaults = { |
Horace He | b3b9786 | 2022-10-13 20:19:16 +0000 | [diff] [blame] | 1540 | skip('nn.functional.batch_norm') # Segfault?? |
Horace He | 7ebdb4c | 2022-08-23 05:11:03 +0000 | [diff] [blame] | 1541 | } |
| 1542 | |
| 1543 | symbolic_tensor_failures.update(symbolic_tensor_segfaults) |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 1544 | |
# Expected failures used only by the out-of-place symbolic test
# (test_make_fx_symbolic_exhaustive); the in-place symbolic test does not
# include this set.
outplace_symbolic_tensor_failures = {
    # aten.i0.default - couldn't find symbolic meta function/decomposition
    xfail('i0', ''),
    # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition
    xfail('masked_scatter', ''),
    # aten.empty_like.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.rrelu', ''),
}
| 1550 | |
# Expected failures used only by the in-place symbolic test
# (test_make_fx_symbolic_exhaustive_inplace), grouped by cause.
inplace_symbolic_tensor_failures = {
    # Bug: base given to float_power_ has dtype Float but the operation's
    # result requires dtype Double.
    xfail('float_power', ''),
    # Decomposition not implemented.
    xfail('unique', ''),
    # The in-place variant's signature differs from the out-of-place one.
    xfail('uniform', ''),
}
| 1559 | |
| 1560 | # Copies inputs to inplace operations to avoid inplace modifications |
| 1561 | # to leaves requiring gradient |
| 1562 | def _get_safe_inplace(inplace_variant): |
| 1563 | @functools.wraps(inplace_variant) |
| 1564 | def _fn(t, *args, **kwargs): |
| 1565 | return inplace_variant(t.clone(), *args, **kwargs) |
| 1566 | |
| 1567 | return _fn |
| 1568 | |
def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False):
    """Check that ``make_fx`` traces ``op`` faithfully on its sample inputs.

    For each of the first 100 sample inputs of ``op``:

    1. Trace a wrapper around the op with
       ``make_fx(f, tracing_mode=tracing_mode)``. If tracing raises
       ``DynamicOutputShapeException``, the whole test is skipped.
    2. Refill every float32 tensor among the positional arguments with
       ``uniform_(0, 1)``, so the traced graph is compared on values it was
       not traced with.
    3. Run the wrapper eagerly. If that raises, the sample is ignored.
    4. Run the traced graph through ``wrapper_set_seed`` and assert that its
       output equals the eager output.

    Args:
        self: the ``TestCase`` running the check (used for ``skipTest`` and
            ``assertEqual``).
        device: forwarded to ``op.sample_inputs``.
        dtype: forwarded to ``op.sample_inputs``.
        op: the OpInfo under test.
        tracing_mode: forwarded to ``make_fx``. Callers in this file pass
            ``"real"``, ``"fake"`` or ``"symbolic"``.
        inplace: if True, call ``op.get_inplace()`` on a clone of the first
            input instead of ``op.op``, and skip samples flagged
            ``broadcasts_input``.
    """
    def f(args, kwargs, extra_args, extra_kwargs):
        # Replace each torch.Size argument with the size of its placeholder
        # tensor. Under symbolic tracing this yields a Size with dynamic
        # shapes; when called eagerly it writes back a Size equal to the
        # original (the placeholder was created from that Size), so mutating
        # args/kwargs here is harmless.
        for i, t in extra_args:
            args[i] = t.size()
        for k, t in extra_kwargs.items():
            kwargs[k] = t.size()

        fn = _get_safe_inplace(op.get_inplace()) if inplace else op.op
        return fn(*args, **kwargs)

    sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)

    # Limit ourselves to first 100 inputs so symbolic tracing tests don't take too long
    for sample_input in itertools.islice(sample_inputs_itr, 100):
        if inplace and sample_input.broadcasts_input:
            continue
        args = [sample_input.input] + list(sample_input.args)
        kwargs = sample_input.kwargs

        # If any argument is a torch.Size(), maybe get dynamic shapes for it by:
        # - Creating a temporary Tensor whose size is the torch.Size() we want.
        #   We use a real (uninitialized) CPU tensor because "meta" Tensors
        #   cannot be passed to make_fx.
        # - Passing it to make_fx so that it is converted to a proxy Tensor.
        # - Unpacking the size in the wrapper to get a torch.Size with dynamic
        #   shapes (in symbolic mode; a no-op otherwise).
        extra_args = [
            (i, torch.empty(arg, device="cpu"))
            for i, arg in enumerate(args)
            if isinstance(arg, torch.Size)
        ]
        extra_kwargs = {
            key: torch.empty(value, device="cpu")
            for key, value in kwargs.items()
            if isinstance(value, torch.Size)
        }

        try:
            new_f = make_fx(f, tracing_mode=tracing_mode)(args, kwargs, extra_args, extra_kwargs)
        except DynamicOutputShapeException:
            self.skipTest("Dynamic output shape operation in trace")

        # Re-randomize float inputs after tracing so the comparison below runs
        # on values the graph was not traced with.
        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.dtype == torch.float:
                arg.uniform_(0, 1)

        try:
            old_out = f(args, kwargs, extra_args, extra_kwargs)
        except Exception:
            # Eager run failed on this sample; nothing to compare against.
            continue
        new_out = wrapper_set_seed(new_f, args, kwargs, extra_args, extra_kwargs)
        self.assertEqual(new_out, old_out)
| 1618 | |
class TestProxyTensorOpInfo(TestCase):
    """OpInfo-driven ``make_fx`` tests over ``op_db + custom_op_db`` (float32 only).

    Each test delegates to ``_test_make_fx_helper`` with a different tracing
    mode ("real", "fake", "symbolic") and expected-failure set. The last test
    exercises the symbolic tracing of in-place variants.
    """

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
    def test_make_fx_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "real")

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive',
             make_fx_failures | fake_tensor_failures)
    def test_make_fx_fake_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "fake")

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "symbolic")

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive_inplace(self, device, dtype, op):
        # Ops without an in-place variant have nothing to test here.
        if not op.get_inplace():
            self.skipTest("No inplace variant for this op")
        _test_make_fx_helper(self, device, dtype, op, "symbolic", inplace=True)
| 1643 | |
Horace He | 4d88aff | 2022-06-07 00:28:53 +0000 | [diff] [blame] | 1644 | |
# Generate the per-device test classes, restricted to CPU. The trailing comma
# matters: ("cpu") is just the string "cpu", and filtering device types with
# `in` against a string performs substring matching instead of membership.
only_for = ("cpu",)
instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)


if __name__ == '__main__':
    run_tests()