| # Owner(s): ["module: ProxyTensor"] |
| |
| from torch.testing._internal.common_utils import TestCase, run_tests, xfail_inherited_tests |
| import torch |
| import unittest |
| import warnings |
| import operator |
| from collections.abc import Iterable |
| from torch.testing._internal.common_device_type import instantiate_device_type_tests |
| from torch.testing._internal.common_methods_invocations import op_db, wrapper_set_seed, skip, xfail, skipOps |
| from torch._subclasses.fake_tensor import DynamicOutputShapeException, DataDependentOutputException |
| |
| from torch._decomp import decomposition_table |
| from torch.fx.experimental.symbolic_shapes import ( |
| sym_float, eval_guards, bind_symbols, fx_placeholder_vals, fx_placeholder_targets, |
| constrain_range, guard_int, GuardOnDataDependentSymNode |
| ) |
| from torch.testing._internal.custom_op_db import custom_op_db |
| from torch.testing._internal.common_device_type import ops |
| from torch._C import _disabled_torch_function_impl |
| from torch.fx.experimental.proxy_tensor import make_fx, DecompositionInterpreter, get_isolated_graphmodule |
| from torch.utils._pytree import tree_map |
| from torch import nn |
| import re |
| |
| import functools |
| import itertools |
| |
# Shorthand for the ATen operator namespace referenced throughout this file.
aten = torch.ops.aten

# Whether a CUDA device is available; used to gate the CUDA-only tests.
HAS_CUDA = torch.cuda.is_available()
| |
| |
def strip_end(s, suffix):
    """Return ``s`` without a trailing ``suffix``; ``s`` is returned as-is otherwise."""
    # A falsy suffix is special-cased: slicing with "-0" would yield "".
    has_suffix = bool(suffix) and s.endswith(suffix)
    return s[:len(s) - len(suffix)] if has_suffix else s
| |
| |
def show_guards(gm):
    """Return the simplified guards of ``gm.shape_env`` as newline-separated text."""
    # A trailing "_1" is stripped from every placeholder target to form the
    # names handed to produce_guards.
    source_names = []
    for target in fx_placeholder_targets(gm):
        source_names.append(strip_end(target, "_1"))
    guard_exprs = gm.shape_env.produce_guards(
        fx_placeholder_vals(gm),
        source_names,
        _simplified=True,
        constraint_inputs=None,
    )
    return "\n".join(guard_exprs)
| |
| |
def process_failures():
    """
    Takes file containing failures like

    FAILED test/test_proxy_tensor.py::TestProxyTensorOpInfoCPU::test_make_fx_symbolic_exhaustive___getitem___cpu_float32 - RuntimeError: aten.size.default - couldn't find symbolic meta function/decomposition # noqa: B950

    and processes them into a list of opinfo xfails

    Reads ``pytest_failures`` from the current working directory and prints
    the resulting ``symbolic_tensor_failures`` set literal to stdout.
    """
    # Use a context manager so the file handle is closed even if parsing
    # below raises; blank lines (e.g. a trailing newline) are skipped since
    # they can never match the failure pattern.
    with open('pytest_failures') as f:
        failures = [line.strip() for line in f]
    failures = [line for line in failures if line]

    def process_failure_string(s, matcher):
        out = re.search(matcher, s)
        return out.groups()

    # Group 1 is the normalized op name, group 2 the failure reason.
    SYMBOLIC_TRACE_MATCH = r'exhaustive_(.*)_cpu.*: (.*)'
    failures = [process_failure_string(s, SYMBOLIC_TRACE_MATCH) for s in failures]

    def create_normalized_name(op):
        if op.variant_test_name == '':
            s = op.name
        else:
            s = f"{op.name}.{op.variant_test_name}"
        return s.replace('.', '_')

    # Map the normalized (test-id) name back to the (name, variant) pair that xfail() takes.
    remap_opinfo = {create_normalized_name(op): (op.name, op.variant_test_name) for op in op_db}

    print("symbolic_tensor_failures = {")
    for failure, reason in failures:
        print(f" xfail{remap_opinfo[failure]}, # {reason}")
    print("}")
| |
| |
# Probe for torchvision once at import time; the flag is consulted by
# skipIf decorators on tests that need it.
try:
    import torchvision
except ImportError:
    USE_TORCHVISION = False
    warnings.warn(
        "Couldn't import torchvision. Some of our tests use it, try "
        "to install it with commands from pytorch.org, post-fixed with "
        "`--no-deps` to avoid overwriting the pytorch installation",
        UserWarning,
    )
else:
    USE_TORCHVISION = True
| |
| |
| def _create_new_input(x): |
| if not isinstance(x, torch.Tensor): |
| return x |
| if x.dtype != torch.float: |
| return x + 1 |
| if x.is_leaf: |
| return torch.rand_like(x, requires_grad=x.requires_grad) |
| else: |
| return torch.rand_like(x) |
| |
| """ |
| Delays a cos being executed on the UnwrapTensor until it is used. Simulates how a CommTensor is used. |
| """ |
class UnwrapTensor(torch.Tensor):
    """Wrapper tensor subclass that applies ``cos`` to its payload lazily.

    The wrapped tensor is stored in ``_tensor``; ``cos()`` is only applied
    when the wrapper is passed to an operator via ``__torch_dispatch__``.
    """

    @staticmethod
    def __new__(cls, tensor: torch.Tensor):
        # The wrapper copies size/dtype/device/layout/requires_grad from the
        # payload; the payload itself is kept on ``_tensor``.
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            tensor.size(),
            dtype=tensor.dtype,
            device=tensor.device,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
        )
        r._tensor = tensor
        return r

    def __repr__(self):
        # TODO: consider all_gather the local tensors for better debugging
        return f"UnwrapTensor({self._tensor})"

    __torch_function__ = _disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` with every UnwrapTensor argument replaced by ``payload.cos()``."""
        # ``kwargs`` defaults to None and ``func(**None)`` is a TypeError, so
        # normalise it to an empty dict before unpacking.
        if kwargs is None:
            kwargs = {}

        def unwrap(e):
            if isinstance(e, UnwrapTensor):
                return e._tensor.cos()
            return e

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs)
        return func(*args, **kwargs)
| class TestGenericProxyTensor(TestCase): |
# WARNING: if any of your inputs are index tensors, DO NOT use this
# function
def _test(self, f, inps):
    """Trace ``f`` on ``inps``, then compare the traced graph with ``f`` on fresh inputs."""
    traced = make_fx(f, tracing_mode=self.tracing_mode)(*inps)
    # Every leaf of ``inps`` is mapped through _create_new_input so the
    # comparison does not reuse the tensors seen during tracing.
    fresh_inps = tree_map(_create_new_input, inps)
    traced_out = traced(*fresh_inps)
    eager_out = f(*fresh_inps)
    self.assertEqual(traced_out, eager_out)
| |
def test_pre_autograd_mode_stack(self):
    """Check the graph produced by ``make_fx(..., pre_autograd=True)`` for a matmul."""
    def f(a):
        b = torch.ones(4, 4)
        return torch.matmul(a, b)
    # We expect to see matmul in the trace - it should NOT be decomposed into mm.
    # Also, torch.ones() doesn't show up in the trace.
    # This is annoying but expected: ones() never dispatches to the Autograd dispatch key,
    # so our mode never sees it - it goes directly to the BackendSelect key.
    # NOTE(review): the expected graph below does contain an aten.ones node, so the
    # remark about ones() not showing up looks stale - confirm.
    inp = torch.ones(4, 4)
    # Test that make_fx(pre_autograd=True) clears caches properly.
    from torch._dispatch.python import enable_python_dispatcher
    with enable_python_dispatcher():
        # Eager run under the python dispatcher before tracing; the result is unused.
        out1 = f(inp)
    fx_g = make_fx(f, pre_autograd=True)(inp)
    # NOTE(review): expected text copied verbatim from this view, which appears to
    # have lost indentation / repeated spaces - confirm against real FX output.
    self.assertExpectedInline(fx_g.code.strip(), """\
def forward(self, a_1):
ones = torch.ops.aten.ones.default([4, 4], device = device(type='cpu'), pin_memory = False)
matmul = torch.ops.aten.matmul.default(a_1, ones); a_1 = ones = None
return matmul""")
| |
| |
def test_make_fx_simple(self):
    """A single ``torch.sin`` call traces and replays correctly."""
    def sine(x):
        return torch.sin(x)

    example_inputs = (torch.randn(3),)
    self._test(sine, example_inputs)
| |
def test_scalar_device(self, device='cpu'):
    """Adding a 0-dim tensor to a tensor created on ``device`` traces correctly."""
    def add(lhs, rhs):
        return lhs + rhs

    operands = [torch.randn(3, device=device), torch.tensor(5)]
    self._test(add, operands)
| |
def test_isolated_graphmodule(self):
    """Check which ops leak into an outer ``make_fx`` trace from inner tracing calls.

    Each scenario traces an outer function that internally calls
    ``get_isolated_graphmodule`` or ``make_fx`` and then inspects the outer
    graph for sum / sigmoid / digamma nodes.
    """
    # Predicates over a graph module: does any node target the given aten op?
    def is_any_sum(gm):
        return any(node.target == torch.ops.aten.sum.default for node in gm.graph.nodes)

    def is_any_digamma(gm):
        return any(node.target == torch.ops.aten.digamma.default for node in gm.graph.nodes)

    def is_any_sigmoid(gm):
        return any(node.target == torch.ops.aten.sigmoid.default for node in gm.graph.nodes)

    def inner(x):
        return torch.sum(x)

    def f(x):
        gm = get_isolated_graphmodule(inner, (x,), {})
        self.assertTrue(is_any_sum(gm))
        return x + torch.randn(x.shape)

    # get_isolated_graphmodule uses make_fx internally that shouldn't be traced
    # by the outer make_fx call
    traced = make_fx(f)(torch.randn(3))
    self.assertFalse(is_any_sum(traced))

    # When factory functions are used, they should not be traced
    # by the outer make_fx call
    def inner_with_factory():
        val = torch.tensor(float(1))
        val.add_(2)
        return torch.full((10, 10), val).sum()

    def f1(x):
        gm = get_isolated_graphmodule(inner_with_factory, (), {})
        self.assertTrue(is_any_sum(gm))
        return torch.sigmoid(x)

    def f2(x):
        gm = get_isolated_graphmodule(f1, (x,), {})
        self.assertFalse(is_any_sum(gm))
        self.assertTrue(is_any_sigmoid(gm))
        return torch.digamma(x)

    traced = make_fx(f2)(torch.randn(3))
    self.assertFalse(is_any_sum(traced))
    self.assertFalse(is_any_sigmoid(traced))
    self.assertTrue(is_any_digamma(traced))

    # Verify nested make_fx calls don't make factory functions to be leaked
    # into the outer graph. Verify that `make_fx` itself does not leak its execution.
    # (f2 is deliberately redefined here to use make_fx instead of get_isolated_graphmodule.)
    def f2(x):
        gm = make_fx(f1)(x)
        self.assertFalse(is_any_sum(gm))
        self.assertTrue(is_any_sigmoid(gm))
        return torch.digamma(x)

    traced = make_fx(f2)(torch.randn(3))
    self.assertFalse(is_any_sum(traced))
    self.assertFalse(is_any_sigmoid(traced))
    self.assertTrue(is_any_digamma(traced))

    # Verify that the `forward` function of a graph module produced as a
    # side effect of an interior `make_fx` is still traced
    def f3(x):
        gm = make_fx(f1)(x)
        self.assertFalse(is_any_sum(gm))
        self.assertTrue(is_any_sigmoid(gm))
        # `gm.forward` is still traced
        return torch.digamma(gm(x))

    traced = make_fx(f3)(torch.randn(3))
    self.assertFalse(is_any_sum(traced))
    self.assertTrue(is_any_sigmoid(traced))
    self.assertTrue(is_any_digamma(traced))

    # Verify interaction with non-ProxyTensor modes
    from torch.testing._internal.logging_tensor import LoggingTensorMode

    def f1_logging(x):
        with LoggingTensorMode():
            gm = get_isolated_graphmodule(inner_with_factory, (), {})
        self.assertTrue(is_any_sum(gm))
        return torch.sigmoid(x)

    def f2_logging(x):
        with LoggingTensorMode(), LoggingTensorMode():
            gm = get_isolated_graphmodule(f1_logging, (x,), {})
        self.assertFalse(is_any_sum(gm))
        self.assertTrue(is_any_sigmoid(gm))
        return torch.digamma(x)

    traced = make_fx(f2_logging)(torch.randn(3))
    self.assertFalse(is_any_sum(traced))
    self.assertFalse(is_any_sigmoid(traced))
    self.assertTrue(is_any_digamma(traced))

    # Verify interaction with another tensor subclass
    # This case currently doesn't work and should raise an error
    # See: https://github.com/pytorch/pytorch/pull/81764#issuecomment-1200472068
    # NOTE(review): the comments in this section say the case fails, yet the
    # assertions below are unconditional - confirm the intended expectation.
    from torch.testing._internal.logging_tensor import LoggingTensor

    def f1_logging_tensor(x):
        gm = get_isolated_graphmodule(inner_with_factory, (), {})
        self.assertTrue(is_any_sum(gm))
        return torch.sigmoid(x)

    def f2_logging_tensor(x):
        x = LoggingTensor(x)
        gm = get_isolated_graphmodule(f1_logging_tensor, (x,), {})
        self.assertFalse(is_any_sum(gm))
        self.assertTrue(is_any_sigmoid(gm))
        return torch.digamma(x)

    traced = make_fx(f2_logging_tensor)(torch.randn(3))
    self.assertFalse(is_any_sum(traced))
    self.assertFalse(is_any_sigmoid(traced))  # this fails, sigmoid is traced with LoggingTensor
    self.assertTrue(is_any_digamma(traced))
| |
# See https://github.com/pytorch/pytorch/issues/97541
def test_empty_like_doesnt_burn_in_defaults(self):
    """Check the keyword arguments recorded for a traced ``torch.empty_like`` call."""
    def f(x):
        return torch.empty_like(x)
    out = make_fx(f)(torch.randn(3))
    # NOTE(review): expected text copied verbatim from this view, which appears to
    # have lost indentation / repeated spaces - confirm against real FX output.
    self.assertExpectedInline(out.code.strip(), """\
def forward(self, x_1):
empty_like = torch.ops.aten.empty_like.default(x_1, pin_memory = False); x_1 = None
return empty_like""")
| |
def test_proxy_tensor_mode_with_decomp_table_preserves_proxy(self):
    """Trace ``new_zeros`` through a decomposition that calls ``torch.zeros``."""
    def f(x):
        y = x.new_zeros(x.size())
        y.copy_(x)
        return y

    # Decomposition for aten.new_zeros: only the size plus the input's dtype
    # and device are forwarded; the other keyword arguments are ignored.
    def _new_zeros_decomp(inp, size, dtype=None, layout=None, device=None, pin_memory=None):
        return torch.zeros(size, dtype=inp.dtype, device=inp.device)

    factory_func_decomp = {torch.ops.aten.new_zeros.default: _new_zeros_decomp}

    # When new_zeros() decomposes into torch.zeros(), we expect ProxyTensorMode
    # to still be (re-entrantly) enabled, so that the `torch.zeros()` call
    # returns a ProxyTensor.
    out = make_fx(f, decomposition_table=factory_func_decomp)(torch.ones(2))
    # NOTE(review): expected text copied verbatim from this view, which appears to
    # have lost indentation / repeated spaces - confirm against real FX output.
    self.assertExpectedInline(out.code, """\



def forward(self, x_1):
zeros = torch.ops.aten.zeros.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
copy_ = torch.ops.aten.copy_.default(zeros, x_1); zeros = x_1 = None
return copy_
""")
| |
def test_make_fx_reentrant_dispatch(self):
    """After a norm decomposition, no node target mentions ``square`` or ``norm``."""
    def norm_decomp(x, p=2.0):
        if p != 2.0:
            raise RuntimeError("can't handle with p != 2")
        return torch.sqrt(torch.sum(torch.square(x)))

    def f(x):
        return torch.ops.aten.norm.Scalar(x, 2.0)

    table = {torch.ops.aten.norm.Scalar: norm_decomp}
    tracer = make_fx(f, decomposition_table=table, tracing_mode=self.tracing_mode)
    traced = tracer(torch.rand(3))

    for node in traced.graph.nodes:
        target_name = str(node.target)
        self.assertTrue("square" not in target_name)
        self.assertTrue("norm" not in target_name)
| |
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
def test_resnet18_backward_trace(self):
    """Trace a resnet18 forward + backward pass expressed via ``functional_call``."""
    mod = torchvision.models.resnet18()

    # Calling the module directly only works for tracing_mode == "real".
    # With fake tensors the parameters and buffers must be wrapped as well
    # (free fake tensors are unsupported), which functional_call takes care
    # of when they are passed in explicitly.
    def f(x, params, buffers):
        for param in params.values():
            param.grad = None
        state = {**params, **buffers}
        loss = torch.func.functional_call(mod, state, (x,)).sum()
        # The mutating backward() API is used on purpose (rather than a
        # functional gradient API) to show that it traces too.
        loss.backward()
        return [param.grad for param in params.values()]

    inp = torch.randn(3, 3, 250, 250)
    named_params = dict(mod.named_parameters())
    named_buffers = dict(mod.named_buffers())
    self._test(f, [inp, named_params, named_buffers])
| |
def test_varargs(self):
    """A function taking ``*args`` can be traced."""
    def total(*args):
        return sum(args)

    operands = [torch.randn(2), torch.randn(2)]
    self._test(total, operands)
| |
def test_proxy_tensor(self):
    """Gradients via ``autograd.grad`` and via ``.backward()`` both trace correctly."""
    def via_grad(x):
        loss = x.cos().cos().sum()
        return torch.autograd.grad(loss, x)

    def via_backward(x):
        loss = x.cos().cos().sum()
        loss.backward()
        return x.grad

    for fn in (via_grad, via_backward):
        self._test(fn, [torch.randn(3, requires_grad=True)])
| |
def test_pickle_issue89626(self):
    """A tensor that was used as a tracing input can still be pickled."""
    import pickle

    inp = torch.randn(2)
    tracer = make_fx(lambda x: x * 2, tracing_mode=self.tracing_mode)
    tracer(inp)
    pickle.dumps(inp)
| |
def test_inplace_metadata(self):
    """An in-place ``unsqueeze_`` is reflected in the shape seen while tracing."""
    def f(x):
        cloned = x.clone()
        cloned.unsqueeze_(-1)
        assert cloned.shape[-1] == 1
        return cloned

    self._test(f, [torch.randn(5)])
| |
def test_mode_tracing_factory_function(self):
    """Factory functions such as ``torch.randn`` appear in the trace by default."""
    def f(x):
        return x + torch.randn(x.shape)

    # default behavior should trace factory functions
    traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))
    targets = [node.target for node in traced.graph.nodes]
    self.assertTrue(any(target == aten.randn.default for target in targets))
| |
def test_val_metadata_mutation(self):
    """``meta['val']`` shapes of the traced nodes match the expected sequence."""
    def f(x):
        y = x.clone()
        y.unsqueeze_(0)
        return y

    example = torch.randn(3, requires_grad=True)
    traced = make_fx(f, tracing_mode=self.tracing_mode)(example)

    recorded_shapes = []
    for node in traced.graph.nodes:
        if 'val' in node.meta:
            recorded_shapes.append(tuple(node.meta['val'].shape))
    self.assertEqual(recorded_shapes, [(3,), (3,), (1, 3)])
| |
def test_make_fx_overloads(self):
    """Every ``call_function`` node in the trace targets an ``OpOverload``."""
    def f(x):
        return x.cos() + torch.randn(x.shape)

    traced = make_fx(f, tracing_mode=self.tracing_mode)(torch.randn(3))

    call_targets = [
        node.target for node in traced.graph.nodes if node.op == 'call_function'
    ]
    self.assertTrue(
        all(isinstance(target, torch._ops.OpOverload) for target in call_targets)
    )
| |
def test_tensor_constants(self):
    """A tensor constant created inside the traced function is handled."""
    def fill_with_inf():
        fill_value = torch.tensor(float('inf'))
        return torch.full((100, 100), fill_value)

    self._test(fill_with_inf, [])
| |
def test_allclose(self):
    """Tracing ``torch.allclose`` raises a data-dependent error in every mode."""
    def f(a, b):
        return torch.allclose(a, b)

    def trace():
        tracer = make_fx(f, tracing_mode=self.tracing_mode)
        tracer(torch.zeros(3), torch.zeros(3))

    # The exception type differs between "real" and the other tracing modes.
    if self.tracing_mode == "real":
        self.assertRaisesRegex(RuntimeError, "data-dependent", trace)
    else:
        self.assertRaises(DataDependentOutputException, trace)
| |
def test_constant_proxy_tensor_mut(self):
    """A constant mutated in place during tracing replays the same way every call."""
    def f():
        val = torch.tensor(float(1))
        val.add_(2)
        return torch.full((100, 100), val)

    g = make_fx(f, tracing_mode=self.tracing_mode)()
    # Compared twice, in case we mutated shared state in the g graph!
    for _ in range(2):
        self.assertEqual(g(), f())
| |
def test_constant_unbind(self):
    """``item()`` on an element unbound from a constant tensor traces correctly."""
    def f():
        constant = torch.tensor([2])
        (first,) = torch.unbind(constant, 0)
        return first.item()

    traced = make_fx(f, tracing_mode=self.tracing_mode)()
    self.assertEqual(traced(), f())
| |
def test_constant_blowup(self):
    """Reading ``item()`` from a constant repeated 1000 times raises a data-dependent error."""
    def f():
        val = torch.tensor([2])
        blowup = val.repeat(1000)
        return bool(blowup.sum().item() == 2)

    def trace():
        make_fx(f, tracing_mode=self.tracing_mode)()

    # The exception type differs between "fake" and the other tracing modes.
    if self.tracing_mode != "fake":
        self.assertRaisesRegex(RuntimeError, "data-dependent", trace)
    else:
        self.assertRaises(DataDependentOutputException, trace)
| |
def test_constant_random(self):
    """Reading ``item()`` from a constant after ``normal_()`` raises a data-dependent error."""
    def f():
        val = torch.tensor([2.0])
        val.normal_()
        return bool(val.item() == 2.1)

    def trace():
        make_fx(f, tracing_mode=self.tracing_mode)()

    # The exception type differs between "fake" and the other tracing modes.
    if self.tracing_mode != "fake":
        self.assertRaisesRegex(RuntimeError, "data-dependent", trace)
    else:
        self.assertRaises(DataDependentOutputException, trace)
| |
def test_decomposition_interpreter(self):
    """``DecompositionInterpreter`` removes silu from a graph without changing its output."""
    def fn(x):
        return torch.nn.functional.silu(x)

    silu_packet = torch.ops.aten.silu
    silu_overload = torch.ops.aten.silu.default

    x = torch.rand((4, 4))
    fx_module = make_fx(fn, tracing_mode=self.tracing_mode, decomposition_table=None)(x)

    # Without a decomposition table the silu call must be present in the graph.
    found_silu = any(
        n.target == silu_packet or n.target == silu_overload
        for n in fx_module.graph.nodes
    )
    self.assertTrue(found_silu)

    # Re-interpret the traced module into a new graph, decomposing only silu.
    new_graph = torch.fx.Graph()
    silu_decomp_table = {silu_overload: decomposition_table[silu_overload]}
    interpreter = DecompositionInterpreter(
        fx_module,
        new_graph=new_graph,
        decomposition_table=silu_decomp_table,
    )
    interpreter.run(x)

    decomposed_module = torch.fx.GraphModule(fx_module, new_graph)

    for n in decomposed_module.graph.nodes:
        self.assertTrue(n.target != silu_packet)
        self.assertTrue(n.target != silu_overload)

    self.assertEqual(fx_module(x), decomposed_module(x))
| |
def test_make_fx_model_fwd_bwd(self):
    """Trace a forward + backward pass of a small Linear/ReLU model."""
    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 5)

        def forward(self, x):
            return self.linear(x).relu()

    model = Foo()

    # Runs the model functionally, backpropagates, and returns the parameter
    # tensors themselves (not their gradients).
    def f(x, params):
        out = torch.func.functional_call(model, params, x).sum()
        out.backward()
        return list(params.values())
    input = torch.randn(3, 5, requires_grad=True)
    params = dict(model.named_parameters())
    fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params)
    # fx may change the order of parameters in list, so using set() to compare
    # NOTE(review): no set() is used below; each traced output is instead
    # accepted if it matches either eager output.
    self.assertTrue(
        torch.allclose(fx_f(input, params)[0], f(input, params)[0])
        or
        torch.allclose(fx_f(input, params)[0], f(input, params)[1])
    )
    self.assertTrue(
        torch.allclose(fx_f(input, params)[1], f(input, params)[0])
        or
        torch.allclose(fx_f(input, params)[1], f(input, params)[1])
    )
| |
def test_make_fx_model_double_param(self):
    """Using the same LayerNorm twice yields exactly two distinct call targets."""
    class Emformer(torch.nn.Module):
        def __init__(self, input_dim: int = 256) -> None:
            super().__init__()
            self.layer_norm = torch.nn.LayerNorm(input_dim)

        def forward(mod_self, x):  # noqa: B902
            # The weight must be a Tensor both before and after its first use.
            self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
            y = mod_self.layer_norm(x)
            self.assertTrue(isinstance(mod_self.layer_norm.weight, torch.Tensor))
            z = mod_self.layer_norm(y)
            return z

    gm = make_fx(Emformer())(torch.randn(16, 1, 256))
    call_targets = set()
    for node in gm.graph.nodes:
        if node.op == 'call_function':
            call_targets.add(node.target)
    self.assertEqual(len(call_targets), 2)
| |
| |
def test_make_fx_model_fwd_bwd_wgtupdate(self):
    """Trace forward, backward and an SGD-style weight update of a small model."""
    class Foo(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(5, 5)

        def forward(self, x):
            return self.linear(x).relu()

    model = Foo()

    def f(args, params, buffers):
        # Clear gradients so every call starts from a clean state.
        for p in params.values():
            p.grad = None
        if not isinstance(args, Iterable):
            args = [args]
        params_and_buffers = {**params, **buffers}
        out = torch.func.functional_call(model, params_and_buffers, args)
        out.sum().backward()
        # Updated weights: one gradient step with a fixed step size of 1e-4.
        return [p - 1e-4 * p.grad for p in params.values()]

    input = torch.randn(3, 5, requires_grad=True)
    params = dict(model.named_parameters())
    buffers = dict(model.named_buffers())
    fx_f = make_fx(f, tracing_mode=self.tracing_mode)(input, params, buffers)
    # fx may change the order of parameters in list, so using set() to compare
    # also there is a numerical difference in results so changing atol from 1e-08 to 1e-03
    # NOTE(review): no set() is used below; each traced output is instead
    # accepted if it matches either eager output.
    self.assertTrue(
        torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[0], atol=1e-03)
        or
        torch.allclose(fx_f(input, params, buffers)[0], f(input, params, buffers)[1], atol=1e-03)
    )
    self.assertTrue(
        torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[0], atol=1e-03)
        or
        torch.allclose(fx_f(input, params, buffers)[1], f(input, params, buffers)[1], atol=1e-03)
    )
| |
def test_trace_subclasses(self):
    """Functions that wrap their input in ``UnwrapTensor`` trace correctly."""
    def f1(x):
        x = UnwrapTensor(x)
        return x * 2

    def f2(x):
        wrapped = UnwrapTensor(x)
        return x * wrapped

    inp = [torch.randn(5)]
    for fn in (f1, f2):
        self._test(fn, inp)
| |
def test_partial_decomp(self):
    """A decomposition returning ``NotImplemented`` leaves that call undecomposed."""
    def f(a, b, c):
        x = torch.addmm(a, b, c)
        y = torch.addmm(a, b, c, beta=2, alpha=1)
        return x + y

    def addmm(a, b, c, beta=1, alpha=1):
        # Decline the default-scaled case; only decompose the scaled one.
        if beta == 1 and alpha == 1:
            return NotImplemented
        return beta * a + alpha * (b @ c)

    def count_addmm(gm):
        return len([n for n in gm.graph.nodes if n.target == aten.addmm.default])

    inps = [torch.randn(5, 5), torch.randn(5, 5), torch.randn(5, 5)]
    fx_g = make_fx(f)(*inps)
    decomposed_fx = make_fx(f, decomposition_table={aten.addmm.default: addmm})(*inps)

    self.assertEqual(fx_g(*inps), decomposed_fx(*inps))
    self.assertEqual(count_addmm(fx_g), 2)
    self.assertEqual(count_addmm(decomposed_fx), 1)
| |
def test_decomp_of_capture(self):
    """A decomposition for ``aten.t`` also applies to a captured (closed-over) tensor."""
    val = torch.randn(5)

    def f(x):
        return x.t() + val.t()

    def nop(x):
        return x.cos()

    table = {torch.ops.aten.t.default: nop}
    traced = make_fx(f, decomposition_table=table)(torch.randn(5))
    remaining_t = [n for n in traced.graph.nodes if n.target == torch.ops.aten.t.default]
    self.assertEqual(len(remaining_t), 0)
| |
| |
@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_amp_cache(self):
    """Tracing the same conv twice under autocast yields graphs with matching node ops."""
    layer = torch.nn.Conv2d(3, 3, 3).cuda()

    def f(x, w):
        return torch.nn.functional.conv2d(x, w, stride=layer.stride)

    inp = torch.randn(4, 3, 10, 10, device='cuda')
    with torch.autocast('cuda'):
        first_graph = make_fx(f)(inp, layer.weight).graph
        second_graph = make_fx(f)(inp, layer.weight).graph

    self.assertEqual(len(first_graph.nodes), len(second_graph.nodes))
    for first_node, second_node in zip(first_graph.nodes, second_graph.nodes):
        self.assertEqual(first_node.op, second_node.op)
| |
def test_strides(self):
    """Contiguity queries made while tracing reflect permutes and slices."""
    def check_permute(x):
        self.assertTrue(x.is_contiguous())
        self.assertFalse(x.is_contiguous(memory_format=torch.channels_last))
        x = x.permute(0, 3, 1, 2)
        self.assertFalse(x.is_contiguous())
        self.assertTrue(x.is_contiguous(memory_format=torch.channels_last))
        return x

    def check_slices(x):
        self.assertTrue(x.is_contiguous())
        y = x[:, 1]
        self.assertFalse(y.is_contiguous())
        y = x[:, ::2]
        self.assertFalse(y.is_contiguous())
        return x.cos()

    make_fx(check_permute)(torch.randn(2, 3, 4, 5))
    make_fx(check_slices)(torch.randn(2, 3, 4, 5))
| |
def test_pr_86917(self):
    """Trace ``aten.nll_loss_forward`` called with a ``None`` weight."""
    # Tests the issue brought up here https://github.com/pytorch/pytorch/pull/86917#issuecomment-1283155344
    def f(a, b):
        return torch.ops.aten.nll_loss_forward(a, b, None, 1, 10)

    logits = torch.randn(1, 10)
    target = torch.zeros(1, dtype=torch.long)
    self._test(f, [logits, target])
| |
class TestGenericProxyTensorReal(TestGenericProxyTensor):
    """Runs the generic proxy-tensor tests with ``tracing_mode="real"``."""
    tracing_mode = "real"
| |
| |
class TestGenericProxyTensorFake(TestGenericProxyTensor):
    """Runs the generic proxy-tensor tests with ``tracing_mode="fake"``."""
    tracing_mode = "fake"
| |
| |
# The listed inherited test is marked as an expected failure in symbolic mode.
@xfail_inherited_tests([
    "test_make_fx_overloads",
])
class TestGenericProxyTensorSymbolic(TestGenericProxyTensor):
    """Runs the generic proxy-tensor tests with ``tracing_mode="symbolic"``."""
    tracing_mode = "symbolic"


# Remove the base class name from the module namespace; presumably so the test
# loader does not collect it directly (it defines no ``tracing_mode``) - confirm.
del TestGenericProxyTensor
| |
| |
class TestRealProxyTensor(TestCase):
    """Placeholder test case; it currently defines no tests."""
    pass
| |
| class TestFakeProxyTensor(TestCase): |
def test_issue82547(self):
    """Fake tracing of a function that closes over a non-fake tensor raises."""
    x = nn.Parameter(torch.randn(3, 3))

    def f():
        return torch.ops.aten.t.default(x)

    def trace():
        return make_fx(f, tracing_mode="fake")()

    self.assertRaisesRegex(Exception, "Please convert all Tensors", trace)

    class A(torch.Tensor):
        pass

    # Rebinding ``x`` changes what ``f`` closes over for the second attempt.
    x = A(torch.randn(3, 3))
    self.assertRaisesRegex(TypeError, "no implementation found", trace)
| |
def test_use_fake_and_tensor(self):
    """A real tensor constant created inside a fake-traced function is usable."""
    def f(x, y):
        z = torch.tensor([2.0, 3.0])
        return x + y + z

    traced = make_fx(f, tracing_mode="fake")(torch.randn(2), torch.randn(2))
    lhs, rhs = torch.randn(2), torch.randn(2)
    self.assertEqual(traced(lhs, rhs), f(lhs, rhs))
| |
def test_fused_adam(self):
    """``_fused_adam`` and its ``getitem`` nodes carry ``val`` metadata after fake tracing."""
    # See https://github.com/pytorch/pytorch/issues/99356
    def random_state():
        return [torch.randn(10, 10) for _ in range(10)]

    params = random_state()
    grads = random_state()
    exp_avgs = random_state()
    exp_avg_sqs = random_state()
    max_exp_avg_sqs = random_state()
    state_steps = [torch.tensor(0) for _ in range(10)]

    def fused_adam(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps):
        outputs = aten._fused_adam.default(
            params,
            grads,
            exp_avgs,
            exp_avg_sqs,
            max_exp_avg_sqs,
            state_steps,
            lr=0.1,
            beta1=0.9,
            beta2=0.999,
            weight_decay=0.01,
            eps=1e-8,
            amsgrad=False,
            maximize=False,
        )
        # Only the first of the five outputs (the new parameters) is used.
        (new_params, _, _, _, _) = outputs

        for p, new_p in zip(params, new_params):
            p.copy_(new_p)

        return params

    gm = make_fx(fused_adam, tracing_mode='fake')(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
    )
    ensure_ops_have_val = [aten._fused_adam.default, operator.getitem]
    for n in gm.graph.nodes:
        if n.op != "call_function":
            continue
        if n.target in ensure_ops_have_val:
            self.assertIn('val', n.meta)
| |
def test_alias(self):
    """Fake-trace ``aten.alias`` and compare the generated code with the expected text."""
    def f(x):
        return torch.ops.aten.alias(x)

    r = str(make_fx(f, tracing_mode="fake")(torch.randn(2)).code).strip()
    # NB: this should not have a detach call
    # NOTE(review): expected text copied verbatim from this view, which appears to
    # have lost indentation / repeated spaces - confirm against real FX output.
    self.assertExpectedInline(r, """\
def forward(self, x_1):
alias = torch.ops.aten.alias.default(x_1); x_1 = None
return alias""")
| |
def test_meta(self):
    """Every non-output node of a fake trace has ``val`` metadata."""
    def f(x):
        a = x.cos()
        b = torch.var_mean(a, dim=0)
        c = b * 2
        return c

    out = make_fx(f, tracing_mode="fake")(torch.randn(5, 5))
    for n in out.graph.nodes:
        if n.op != 'output':
            self.assertTrue('val' in n.meta)
| |
| def _get_node(fx_g, cond): |
| for n in fx_g.graph.nodes: |
| if cond(n): |
| return n |
| raise AssertionError |
| |
| def _get_free_symbols(shape_env): |
| vars = tuple(shape_env.var_to_val.keys()) |
| return len([var for var in vars if var not in shape_env.replacements]) |
| |
| def _trace(f, *args): |
| inps = [torch.randn(arg) for arg in args] |
| return make_fx(f, tracing_mode="symbolic")(*inps) |
| |
| # TODO: Need to test the guards themselves specifically as well |
| class TestSymbolicTracing(TestCase): |
def _test_dynamic(self, fn, trace_inputs, test_inputs, assert_eq=True):
    """
    Traces fn symbolically on random tensors with the trace_inputs shapes,
    then runs both the traced graph and fn on random tensors for every
    list of shapes in test_inputs (comparing results when assert_eq is set).
    Returns the traced callable produced by make_fx.
    """
    example_args = [torch.randn(shape) for shape in trace_inputs]
    traced_f = make_fx(fn, tracing_mode="symbolic")(*example_args)
    for shapes in test_inputs:
        tensors = [torch.randn(shape) for shape in shapes]
        traced_out = traced_f(*tensors)
        eager_out = fn(*tensors)
        if assert_eq:
            self.assertEqual(traced_out, eager_out)
    return traced_f
| |
| |
def test_debug_interpreter(self):
    """``DebugInterpreter`` raises on stride mismatches in inputs and in op outputs."""
    import torch.library
    from torch.library import Library

    # Define a throwaway operator foo::foo for this test.
    foo = Library("foo", "DEF")
    foo.define("foo(Tensor self) -> Tensor")

    # Operator where meta and cpu disagree on strides
    @torch.library.impl(foo, "foo", "CPU")
    def foo_cpu(x):
        return x.clone().T

    @torch.library.impl(foo, "foo", "Meta")
    def foo_meta(x):
        return x.clone()

    def f(x):
        return torch.ops.foo.foo.default(x)

    gm = make_fx(f, tracing_mode="symbolic")(torch.randn(2, 2))
    from torch._functorch.compilers import DebugInterpreter

    interp = DebugInterpreter(gm)

    # input mismatch is caught (indicates guard problem)
    self.assertRaisesRegex(
        AssertionError, r"3 != 1",
        lambda: interp.run(torch.randn(3, 3).T),
    )

    # Catch the incorrect meta
    self.assertRaisesRegex(
        AssertionError, r"\(3, 1\) != \(1, 3\)",
        lambda: interp.run(torch.randn(3, 3))
    )
| |
def test_resize_from_zero(self):
    """Symbolically trace ``resize_`` of an empty tensor to another tensor's size."""
    def f(x, y):
        x.resize_(y.size(0))

    r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(0), torch.empty(2)).code).strip()
    # NOTE(review): expected text copied verbatim from this view, which appears to
    # have lost indentation / repeated spaces - confirm against real FX output.
    self.assertExpectedInline(r, """\
def forward(self, x_1, y_1):
sym_size = torch.ops.aten.sym_size(y_1, 0); y_1 = None
resize_ = torch.ops.aten.resize_.default(x_1, [sym_size]); x_1 = sym_size = None
return None""")
| |
| |
def test_unary(self):
    """A shape assertion inside the traced function shows up as a guard."""
    def f(x):
        # This assert is what produces the "< 20" guard checked below.
        assert x.shape[0] < 20
        return x.cos()

    test_inputs = [[(2, 5)], [(6, 8)]]
    gm = self._test_dynamic(f, [(3, 4)], test_inputs)
    self.assertTrue(eval_guards(gm, torch.randn(4, 5)))
    bound = bind_symbols(gm, torch.randn(4, 5))
    self.assertEqual(repr(bound), "{s0: 4, s1: 5}")
    self.assertFalse(eval_guards(gm, torch.randn(25, 5)))
    self.assertExpectedInline(show_guards(gm), """L['x'].size()[0] < 20""")
| |
@unittest.skipIf(not HAS_CUDA, 'CUDA-only test')
def test_cpu_scalar_cuda(self):
    """Symbolically trace a 0-dim CPU tensor multiplied with a CUDA matrix."""
    # Extracted from wave2vec2
    def f(a, b):
        return (a * b) @ b

    r = str(
        make_fx(f, tracing_mode="symbolic")(
            torch.tensor(1.0), torch.randn(2, 2, device='cuda')
        ).code
    ).strip()
    # NOTE(review): expected text copied verbatim from this view, which appears to
    # have lost indentation / repeated spaces - confirm against real FX output.
    self.assertExpectedInline(r, """\
def forward(self, a_1, b_1):
mul = torch.ops.aten.mul.Tensor(a_1, b_1); a_1 = None
mm = torch.ops.aten.mm.default(mul, b_1); mul = b_1 = None
return mm""")
| |
def test_binary_broadcast(self):
    """Broadcasting multiplication traces symbolically without adding any guards."""
    def f(a, b):
        c = a * b
        return c

    test_inputs = []
    test_inputs.append([(1, 5), (3, 1)])
    test_inputs.append([(1, 4), (4, 1)])
    shape_env = self._test_dynamic(f, [(1, 2), (3, 1)], test_inputs).shape_env
    # Use the unittest assertion rather than a bare ``assert``: it is not
    # stripped under ``python -O`` and reports the actual guard count on failure.
    self.assertEqual(len(shape_env.guards), 0)
| |
def test_multiply_shape(self):
    """Symbolically trace a factory call whose size is twice the input's first dimension."""
    def f(a):
        return torch.empty(a.shape[0] * 2)

    r = str(make_fx(f, tracing_mode="symbolic")(torch.empty(4)).code).strip()
    # NOTE(review): expected text copied verbatim from this view, which appears to
    # have lost indentation / repeated spaces - confirm against real FX output.
    self.assertExpectedInline(r, """\
def forward(self, a_1):
sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
mul = sym_size * 2; sym_size = None
empty = torch.ops.aten.empty.memory_format([mul], device = device(type='cpu'), pin_memory = False); mul = None
return empty""")
| |
def test_item(self):
    """Symbolically trace ``item()`` followed by a multiplication with its result."""
    def f(a):
        r = a.item()
        return r * a

    r = str(make_fx(f, tracing_mode="symbolic")(torch.randn(1)).code).strip()
    # NOTE(review): expected text copied verbatim from this view, which appears to
    # have lost indentation / repeated spaces - confirm against real FX output.
    self.assertExpectedInline(r, """\
def forward(self, a_1):
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1)
mul = torch.ops.aten.mul.Tensor(a_1, _local_scalar_dense); a_1 = _local_scalar_dense = None
return mul""")
| |
def test_item_to_constructor(self):
    """Symbolically trace ``item()`` used as the size of a factory call."""
    def f(a):
        r = a.item()
        # Constrain the extracted value to be at least 2 before using it as a size.
        constrain_range(r, min=2)
        return torch.empty(r)

    r = str(make_fx(f, tracing_mode="symbolic")(torch.randint(5, (1,))).code).strip()
    # NOTE(review): expected text copied verbatim from this view, which appears to
    # have lost indentation / repeated spaces - confirm against real FX output.
    self.assertExpectedInline(
        r, """\
def forward(self, a_1):
_local_scalar_dense = torch.ops.aten._local_scalar_dense.default(a_1); a_1 = None
empty = torch.ops.aten.empty.memory_format([_local_scalar_dense], device = device(type='cpu'), pin_memory = False); _local_scalar_dense = None
return empty"""  # noqa: B950
    )
| |
def test_dynamic_pointwise_scalar(self):
    """Symbolically trace a boolean-mask indexed read, scale by -1, and write back."""
    def f(gravity, mask):
        gravity[mask, 0] = gravity[mask, 0] * -1

    r = str(make_fx(f, tracing_mode="symbolic")(
        torch.randn((12, 4)),
        torch.randint(0, 2, (12,), dtype=torch.bool)
    ).code).strip()
    # NOTE(review): expected text copied verbatim from this view, which appears to
    # have lost indentation / repeated spaces - confirm against real FX output.
    self.assertExpectedInline(r, """\
def forward(self, gravity_1, mask_1):
select = torch.ops.aten.select.int(gravity_1, 1, 0)
index = torch.ops.aten.index.Tensor(select, [mask_1]); select = None
mul = torch.ops.aten.mul.Tensor(index, -1); index = None
select_1 = torch.ops.aten.select.int(gravity_1, 1, 0); gravity_1 = None
index_put_ = torch.ops.aten.index_put_.default(select_1, [mask_1], mul); select_1 = mask_1 = mul = None
return None""")
| |
def test_reflect_r_over_x(self):
    # Boolean-mask a batch of 3x3 matrices, conjugate the selected ones with
    # a reflection matrix (identity with [0, 0] set to -1) and write them
    # back.  The expected string pins the exact symbolic graph.
    def reflect_R_over_x(R):
        reflect = torch.eye(3, device=R.device)
        reflect[0, 0] = -1
        return reflect @ R @ reflect

    def f(crop_camera, mask):
        crop_camera[mask] = reflect_R_over_x(crop_camera[mask])

    cameras = torch.randn((12, 3, 3))
    selector = torch.randint(0, 2, (12,), dtype=torch.bool)
    graph_module = make_fx(f, tracing_mode="symbolic")(cameras, selector)
    code = str(graph_module.code).strip()
    self.assertExpectedInline(code, """\
def forward(self, crop_camera_1, mask_1):
    index = torch.ops.aten.index.Tensor(crop_camera_1, [mask_1])
    eye = torch.ops.aten.eye.default(3, device = device(type='cpu'), pin_memory = False)
    _tensor_constant0 = self._tensor_constant0
    lift_fresh_copy = torch.ops.aten.lift_fresh_copy.default(_tensor_constant0); _tensor_constant0 = None
    select = torch.ops.aten.select.int(eye, 0, 0)
    select_1 = torch.ops.aten.select.int(select, 0, 0); select = None
    copy_ = torch.ops.aten.copy_.default(select_1, lift_fresh_copy); select_1 = lift_fresh_copy = None
    transpose = torch.ops.aten.transpose.int(index, -2, -1)
    t = torch.ops.aten.t.default(eye)
    clone = torch.ops.aten.clone.default(transpose, memory_format = torch.contiguous_format); transpose = None
    sym_size = torch.ops.aten.sym_size(index, 0); index = None
    sym_size_1 = torch.ops.aten.sym_size(crop_camera_1, 2)
    mul = sym_size * sym_size_1
    sym_size_2 = torch.ops.aten.sym_size(crop_camera_1, 1)
    _unsafe_view = torch.ops.aten._unsafe_view.default(clone, [mul, sym_size_2]); clone = mul = sym_size_2 = None
    mm = torch.ops.aten.mm.default(_unsafe_view, t); _unsafe_view = t = None
    view = torch.ops.aten.view.default(mm, [sym_size, sym_size_1, 3]); mm = sym_size_1 = None
    transpose_1 = torch.ops.aten.transpose.int(view, -2, -1)
    clone_1 = torch.ops.aten.clone.default(transpose_1, memory_format = torch.contiguous_format); transpose_1 = None
    mul_1 = sym_size * 3
    sym_size_3 = torch.ops.aten.sym_size(view, 1); view = None
    view_1 = torch.ops.aten.view.default(clone_1, [mul_1, sym_size_3]); clone_1 = mul_1 = sym_size_3 = None
    mm_1 = torch.ops.aten.mm.default(view_1, eye); view_1 = eye = None
    view_2 = torch.ops.aten.view.default(mm_1, [sym_size, 3, 3]); mm_1 = sym_size = None
    index_put_ = torch.ops.aten.index_put_.default(crop_camera_1, [mask_1], view_2); crop_camera_1 = mask_1 = view_2 = None
    return None""")
| |
def test_unbacked_slice(self):
    # Smoke test: slicing the result of a boolean-mask index must trace
    # without raising.  Nothing about the produced graph is asserted.
    def f(x, m):
        x = x[m]
        # Same three slice objects as slice(None, None, None) twice followed
        # by slice(None, 2, None).
        return x[:, :, :2]

    values = torch.randn((12, 3, 3))
    mask = torch.randint(0, 2, (12,), dtype=torch.bool)
    make_fx(f, tracing_mode="symbolic")(values, mask)
| |
@unittest.skipIf(not USE_TORCHVISION, "test requires torchvision")
def test_unbacked_batch_resnet(self):
    # Smoke test: a resnet18 forward pass over a boolean-masked batch must
    # trace symbolically without raising.  No graph contents are asserted.
    mod = torchvision.models.resnet18()

    def f(x, mask, params, buffers):
        # Call guard_int on every dimension of every input, parameter and
        # buffer before the masked index is taken.
        every_tensor = itertools.chain([x, mask], params.values(), buffers.values())
        for tensor in every_tensor:
            for dim in tensor.shape:
                guard_int(dim)
        x = x[mask]
        constrain_range(x.shape[0], min=1)
        for param in params.values():
            param.grad = None
        state = {**params, **buffers}
        return torch.func.functional_call(mod, state, (x,)).sum()

    images = torch.randn(3, 3, 250, 250)
    keep = torch.randint(0, 2, (3,), dtype=torch.bool)
    make_fx(f, tracing_mode="symbolic")(
        images,
        keep,
        dict(mod.named_parameters()),
        dict(mod.named_buffers()),
    )
| |
def test_boolean_index(self):
    # Two rounds of boolean-mask indexing followed by a masked in-place
    # write of flipped images; the expected string pins the traced graph.
    def f(images, handedness, valid):
        images = images[valid]
        handedness = handedness[valid]
        right_hand_mask = handedness == 1
        images[right_hand_mask] = images[right_hand_mask].flip(-1)

    images = torch.randint(0, 256, (512, 1, 96, 96))
    handedness = torch.randint(0, 1, (512,))
    valid = torch.randint(0, 2, (512,), dtype=torch.bool)
    graph_module = make_fx(f, tracing_mode="symbolic")(images, handedness, valid)
    code = str(graph_module.code).strip()
    self.assertExpectedInline(code, """\
def forward(self, images_1, handedness_1, valid_1):
    index = torch.ops.aten.index.Tensor(images_1, [valid_1]); images_1 = None
    index_1 = torch.ops.aten.index.Tensor(handedness_1, [valid_1]); handedness_1 = valid_1 = None
    eq = torch.ops.aten.eq.Scalar(index_1, 1); index_1 = None
    index_2 = torch.ops.aten.index.Tensor(index, [eq])
    flip = torch.ops.aten.flip.default(index_2, [-1]); index_2 = None
    index_put_ = torch.ops.aten.index_put_.default(index, [eq], flip); index = eq = flip = None
    return None""")
| |
def test_neg_shape(self):
    # Unary negation of a symbolic size followed by an addition should be
    # traced as Python-level neg/add nodes feeding aten.empty.
    def f(a):
        return torch.empty(-a.shape[0] + 10)

    graph_module = make_fx(f, tracing_mode="symbolic")(torch.empty(2))
    code = str(graph_module.code).strip()
    self.assertExpectedInline(code, """\
def forward(self, a_1):
    sym_size = torch.ops.aten.sym_size(a_1, 0); a_1 = None
    neg = -sym_size; sym_size = None
    add = neg + 10; neg = None
    empty = torch.ops.aten.empty.memory_format([add], device = device(type='cpu'), pin_memory = False); add = None
    return empty""")
| |
def test_invalidate_nonzero(self):
    # The sizes of nonzero() results taken before any mutation must compare
    # equal.  After b.normal_() mutates the tensor in place, comparing an old
    # size against a freshly computed one must raise
    # GuardOnDataDependentSymNode instead of producing an answer.
    #
    # `ok` records that the traced function got past the first comparison;
    # it is asserted after tracing so the flag is no longer dead.
    ok = False

    def f(a):
        nonlocal ok
        b = a.clone()
        x = b.nonzero()
        x1 = b.nonzero()
        x2 = b.nonzero()
        assert x1.shape[0] == x2.shape[0]
        ok = True
        b.normal_()
        y = b.nonzero()
        # Only the comparison sits in the try body; the failure for the
        # "nothing was raised" case lives in the else branch.
        try:
            bool(x1.shape[0] == y.shape[0])
        except GuardOnDataDependentSymNode:
            pass
        else:
            self.fail("didn't raise exception")

    make_fx(f, tracing_mode="symbolic")(torch.randn(4))
    self.assertTrue(ok)
| |
def test_sqrt_size(self):
    # Raising a symbolic size to a float power should appear as a Python
    # `**` node whose result is passed to aten.div.Tensor.
    def f(a):
        return a / a.size(-1) ** 0.5

    graph_module = make_fx(f, tracing_mode="symbolic")(torch.empty(4))
    code = str(graph_module.code).strip()
    self.assertExpectedInline(code, """\
def forward(self, a_1):
    sym_size = torch.ops.aten.sym_size(a_1, 0)
    pow_1 = sym_size ** 0.5; sym_size = None
    div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None
    return div""")
| |
| |
def test_symint_to_tensor(self):
    # Dividing a tensor by one of its own symbolic sizes, traced twice:
    # first without decompositions, then with the global decomposition
    # table, where the expected graph contains a sym_float conversion and
    # prims.div.
    def f(a):
        return a / a.shape[0]

    plain = make_fx(f, tracing_mode="symbolic")(torch.empty(4))
    code = str(plain.code).strip()
    self.assertExpectedInline(code, """\
def forward(self, a_1):
    sym_size = torch.ops.aten.sym_size(a_1, 0)
    div = torch.ops.aten.div.Tensor(a_1, sym_size); a_1 = sym_size = None
    return div""")

    decomposed = make_fx(f, tracing_mode="symbolic", decomposition_table=decomposition_table)(torch.empty(4))
    code = str(decomposed.code).strip()
    self.assertExpectedInline(code, """\
def forward(self, a_1):
    sym_size = torch.ops.aten.sym_size(a_1, 0)
    sym_float = torch.sym_float(sym_size); sym_size = None
    div = torch.ops.prims.div.default(a_1, sym_float); a_1 = sym_float = None
    return div""")
| |
def test_cat(self):
    # The branch on the product of the concatenated output's dimensions is
    # expected to leave exactly one guard, checked both by evaluation on
    # concrete inputs and by its printed form.
    def f(a, b):
        product = torch.mul(a, b)
        out = torch.cat([product, product])
        if out.shape[0] * out.shape[1] > 20:
            out = out.cos()
        return out

    extra_shape_pairs = [
        [(1, 5), (6, 1)],
        [(1, 4), (3, 1)],
    ]
    gm = self._test_dynamic(f, [(1, 6), (8, 1)], extra_shape_pairs)
    # 2 * 10 * 6 > 20 satisfies the guard; 2 * 2 * 4 does not.
    self.assertTrue(eval_guards(gm, torch.randn(1, 10), torch.randn(6, 1)))
    self.assertFalse(eval_guards(gm, torch.randn(1, 2), torch.randn(4, 1)))
    self.assertExpectedInline(show_guards(gm), """2*L['a'].size()[1]*L['b'].size()[0] > 20""")
| |
def test_new_empty(self):
    # new_empty with sizes derived from another tensor's shape.  Output
    # values are not compared (assert_eq=False).
    def f(a, b):
        rows = b.shape[0]
        cols = b.shape[1] * 2
        return a.new_empty(rows, cols)

    extra_shape_pairs = [
        [(2, 3), (5, 7)],
        [(3, 7), (9, 3)],
    ]
    # The trailing attribute access is kept from the original: it only
    # requires that the returned object exposes `shape_env`.
    self._test_dynamic(f, [(2, 4), (4, 5)], extra_shape_pairs, assert_eq=False).shape_env
| |
def test_size_with_tensor(self):
    # Building a shape list out of the elements of a tensor must fail during
    # symbolic tracing with a RuntimeError mentioning "data-dependent".
    def f(tensor):
        max_size = torch.tensor([800, 1216], dtype=torch.int64)
        batch_shape = [2] + list(tensor.shape[:-2]) + list(max_size)
        return tensor.new_empty(batch_shape)

    a = torch.randn(3, 800, 1199)
    with self.assertRaisesRegex(RuntimeError, "data-dependent"):
        make_fx(f, tracing_mode="symbolic")(a)
| |
def test_expand(self):
    # expand() to the input's own shape, exercised for a 1-D input and for
    # a 2-D input with a trailing dimension of size 1.
    def f(a):
        squared = torch.mul(a, a)
        return squared.expand(a.shape)

    self._test_dynamic(f, [(3,)], [[(3,)], [(4,)], [(2,)]])
    self._test_dynamic(f, [(5, 1)], [[(4, 1)], [(3, 1)], [(6, 1)]])
| |
def test_metadata(self):
    # The symbolic expression stored on the new_empty node's output size
    # must equal the expression stored on the `+` node that computed it.
    def f(a, b):
        return a.new_empty(a.shape[0] + b.shape[0])

    fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5), torch.randn(4))
    new_empty_node = _get_node(fx_g, lambda x: x.target == aten.new_empty.default)
    add_node = _get_node(fx_g, lambda x: x.target == operator.add)
    out_dim_expr = new_empty_node.meta['val'].shape[0].node.expr
    add_expr = add_node.meta['val'].node.expr
    self.assertTrue(out_dim_expr == add_expr)
| |
def test_metadata_fresh(self):
    # An assertion inside the traced function pins the input's leading
    # dimension to 3; both the cos node and the placeholder must report the
    # expression 3 in their recorded metadata.
    def f(x):
        assert x.shape[0] == 3
        return x.cos()

    fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
    cos_node = _get_node(fx_g, lambda x: x.target == aten.cos.default)
    placeholder_node = _get_node(fx_g, lambda x: x.op == 'placeholder')
    self.assertTrue(cos_node.meta['val'].shape[0].node.expr == 3)
    # The placeholder's expression must have been updated too, even though
    # the constraint was only introduced after the placeholder was created.
    self.assertTrue(placeholder_node.meta['val'].shape[0].node.expr == 3)
| |
def test_elementwise_meta_with_sym_numbers(self):
    # torch.add on a symbolic size (optionally converted with sym_float)
    # and a Python number must record a 0-dim result whose dtype follows
    # the operand kinds.
    def f(x, offset, as_sym_float=False):
        x0 = x.size()[0]
        if as_sym_float:
            x0 = sym_float(x0)
        return torch.add(x0, offset)

    # (offset, as_sym_float, expected dtype of the add node's value)
    cases = (
        (2.0, False, torch.float32),
        (2, False, torch.int64),
        (2, True, torch.float32),
    )
    for offset, as_sym_float, expected_dtype in cases:
        fx_g = make_fx(f, tracing_mode="symbolic")(torch.rand(2, 3), offset, as_sym_float)
        add_node = _get_node(fx_g, lambda x: x.target == aten.add.Tensor)
        self.assertEqual(add_node.meta['val'].shape, ())
        self.assertEqual(add_node.meta['val'].dtype, expected_dtype)
| |
def test_return_symint(self):
    # Symbolic sizes may be returned directly from the traced function:
    # first mixed with a tensor and a derived float, then as a bare shape.
    def size_tensor_ratio(x):
        return x.shape[0], x.cos(), x.shape[0] / 5

    self._test_dynamic(size_tensor_ratio, [(5,)], [[(4,)], [(12,)]])

    def whole_shape(x):
        return x.shape

    self._test_dynamic(whole_shape, [(5, 3)], [[(4, 6)]])
| |
def test_rmethod(self):
    # The symbolic size is the left operand of `+` and the tensor is the
    # right operand.
    def f(x):
        return x.size(0) + x

    extra_shapes = [[(4,)], [(12,)]]
    self._test_dynamic(f, [(5,)], extra_shapes)
| |
def test_mega_guard(self):
    # Checks the printed guards produced for an equality constraint between
    # two inputs' sizes, with and without ignore_static.
    def f(a, b):
        assert a.shape[0] == b.shape[0] * 2
        return a.cos()

    fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(16), torch.randn(8))
    from torch._dynamo.source import LocalSource

    def guards_as_str(ignore_static):
        # Stringified guard list for inputs named "a" and "b".
        sources = [LocalSource("a"), LocalSource("b")]
        return str(fx_g.shape_env.produce_guards(fx_placeholder_vals(fx_g), sources, ignore_static=ignore_static))

    self.assertExpectedInline(
        guards_as_str(False),
        """["L['a'].size()[0] == 2*L['b'].size()[0]", "L['a'].stride()[0] == 1", "L['a'].storage_offset() == 0", "L['b'].stride()[0] == 1", "L['b'].storage_offset() == 0", "2 <= L['b'].size()[0]"]"""  # noqa: B950
    )
    self.assertExpectedInline(
        guards_as_str(True),
        """["L['a'].size()[0] == 2*L['b'].size()[0]", "2 <= L['b'].size()[0]"]"""  # noqa: B950
    )
| |
def test_sym_storage_offset(self):
    # Trace with a sliced first input (torch.randn(8)[3:]), then run the
    # traced graph on a fresh pair of the same form and compare against
    # eager execution.
    def f(x, y):
        return x + y

    def fresh_inputs():
        return (torch.randn(8)[3:], torch.randn(5))

    fx_g = make_fx(f, tracing_mode="symbolic")(*fresh_inputs())
    inp = fresh_inputs()
    traced_out = fx_g(*inp)
    eager_out = f(*inp)
    self.assertEqual(traced_out, eager_out)
| |
def _assert_no_guards(self, fx_g, free_symbols):
    # Assert that the traced graph's shape environment reports exactly
    # `free_symbols` from _get_free_symbols and has no nontrivial guards.
    # The assertion messages show the symbol table and the formatted guards.
    env = fx_g.shape_env
    assert _get_free_symbols(env) == free_symbols, env.var_to_val
    assert len(env.get_nontrivial_guards()) == 0, env.format_guards()
| |
def test_guards_equal(self):
    # For several small programs, tracing must leave no nontrivial guards
    # and the expected number of free symbols.
    #
    # NB: Numbers are carefully chosen to avoid duck shaping from applying

    def pointwise_mul(a, b):
        return a * b

    for shapes, expected_symbols in (
        (((5, 6), (5, 6)), 2),
        (((5, 6, 7), (5, 6, 7)), 3),
        (((5, 1), (1, 6)), 2),
    ):
        fx_g = _trace(pointwise_mul, *shapes)
        self._assert_no_guards(fx_g, expected_symbols)

    def add_then_cat(a, b, c, d):
        a = a + b
        joined = torch.cat([c, d])
        return a + joined

    fx_g = _trace(add_then_cat, 7, 7, 4, 3)
    self._assert_no_guards(fx_g, 2)

    def cat_chain(a, b, c, d, e):
        vals = [a, b, c, d, e]
        x = a
        # Pair every value with its successor: (a, b), (b, c), (c, d), (d, e).
        for current, following in zip(vals, vals[1:]):
            x = torch.cat([x, current]) + following
        return x

    fx_g = _trace(cat_chain, 2, 4, 8, 16, 32)
    self._assert_no_guards(fx_g, 1)

    def view_then_add(a, b):
        a = a.view(b.shape[0])
        return a + b.sum()

    for shapes, expected_symbols in (
        (((4, 2), 8), 2),
        (((4, 2), (8, 5)), 3),
        (((2, 3, 4), 24), 3),
    ):
        fx_g = _trace(view_then_add, *shapes)
        self._assert_no_guards(fx_g, expected_symbols)
| |
def test_nonidentity_transitive_guards(self):
    # Each input, concatenated with itself, is added to the next input; the
    # sums are built from the last pair to the first.  The printed guards
    # are expected to be empty.
    def f(a, b, c, d, e):
        vals = [a, b, c, d, e]
        doubled = [torch.cat([v, v]) for v in vals[:-1]]
        pairs = list(zip(doubled, vals[1:]))
        return [lhs + rhs for lhs, rhs in reversed(pairs)]

    fx_g = _trace(f, 2, 4, 8, 16, 32)
    self.assertExpectedInline(show_guards(fx_g), """""")
| |
| |
| |
| |
| |
# Ops that are skipped or expected to fail in every make_fx OpInfo test below.
# Entries are grouped by the reason recorded for them.
make_fx_failures = {
    # unknown
    xfail('allclose'),
    xfail('equal'),

    # empty
    skip('new_empty'),
    skip('empty_like'),
    skip('empty'),
    skip('empty_permuted'),

    # flaky
    skip('linalg.lstsq'),  # flaky, probably just a precision issue
    skip('linalg.lstsq', 'grad_oriented'),
    skip('nn.functional.max_unpool1d', '', device_type='cpu'),
    skip('nn.functional.max_unpool2d', '', device_type='cpu'),
    skip('nn.functional.max_unpool3d', '', device_type='cpu'),

    # data-dependent control flow
    skip('item'),
    xfail('cov'),
    xfail('istft'),
    xfail('nn.functional.gaussian_nll_loss'),
    xfail('tensor_split'),
    xfail('corrcoef'),
    xfail('quantile'),
    xfail('nanquantile'),
    xfail('narrow'),

    # many complex operators incorrect striding, metadata
    skip('fft.fft', ''),
    skip('fft.hfft2', ''),
    skip('fft.hfft', ''),
    skip('fft.hfftn', ''),
    skip('fft.ifft', ''),
    skip('fft.ihfft2', ''),
    skip('fft.ihfft', ''),
    skip('fft.ihfftn', ''),
    skip('fft.irfft2', ''),
    skip('fft.irfft', ''),
    skip('fft.irfftn', ''),
    skip('fft.rfft2', ''),
    skip('fft.rfft', ''),
    skip('fft.rfftn', ''),

    # Seems like it's creating a sparse tensor that isn't captured by tensor.is_sparse
    xfail('sparse.sampled_addmm'),
    xfail('sparse.mm', 'reduce'),

    # proxy tensor doesn't support sparse correctly right now
    skip('to_sparse'),

    # segfaults
    skip('block_diag'),
}
| |
# Additional skips/expected failures applied on top of make_fx_failures for
# the "fake" and "symbolic" tracing-mode tests.
fake_tensor_failures = {
    # FakeTensor fallback doesn't work
    xfail('_segment_reduce', 'lengths'),
    xfail('multinomial'),
    xfail('cholesky'),
    xfail('cholesky_inverse'),

    # cannot do these as they rely on tensor data
    xfail('repeat_interleave'),

    # ASAN failures due to divide by 0
    skip('nn.functional.nll_loss'),

    xfail("stft"),
}
| |
# Expected failures shared by both symbolic-mode OpInfo tests (out-of-place
# and in-place).  The trailing comment on each entry is the recorded failure
# reason; several of them are truncated with "...".
symbolic_tensor_failures = {
    xfail('linalg.eig'),
    xfail('linalg.eigvals'),
    xfail('cholesky_solve', ''),  # Could not run 'aten::_cholesky_solve_helper' with arguments from the 'Meta' back...
    xfail('combinations', ''),
    xfail('diff', ''),  # aten.empty_like.default - couldn't find symbolic meta function/decomposition
    xfail('frexp', ''),  # aten.frexp.Tensor - couldn't find symbolic meta function/decomposition
    xfail('geqrf', ''),  # aten.geqrf.default - couldn't find symbolic meta function/decomposition
    xfail('gradient', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('histc', ''),  # Could not run 'aten::histc' with arguments from the 'Meta' backend. This could be because...
    xfail('histogram', ''),  # Could not run 'aten::histogram.bin_ct' with arguments from the 'Meta' backend. This c...
    xfail('histogramdd', ''),  # aten._histogramdd_bin_edges.default - couldn't find symbolic meta function/decomposition
    xfail('index_reduce', ''),  # Float
    xfail('isin', ''),  # aten.isin.Tensor_Tensor - couldn't find symbolic meta function/decomposition
    xfail('kron', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('kthvalue', ''),  # aten.kthvalue.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.eigh', ''),  # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.eigvalsh', ''),  # aten._linalg_eigh.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.householder_product', ''),  # aten.linalg_householder_product.default - couldn't find symbolic meta funct...
    xfail('linalg.ldl_factor', ''),  # aten.linalg_ldl_factor_ex.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.ldl_factor_ex', ''),  # aten.linalg_ldl_factor_ex.default - couldn't find symbolic meta function/decompos...
    xfail('linalg.ldl_solve', ''),  # aten.linalg_ldl_solve.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.lu', ''),  # aten.linalg_lu.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.lu_factor', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.lu_factor_ex', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.lu_solve', ''),  # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.matrix_power'),  # RuntimeError: Trying to call aten.size on a tensor with symbolic shape
    xfail('linalg.matrix_rank', 'hermitian'),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.multi_dot', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.pinv', ''),  # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition
    xfail('linalg.pinv', 'singular'),  # aten.linalg_cholesky_ex.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.pinv', 'hermitian'),  # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decompo...
    xfail('linalg.slogdet', ''),  # aten._linalg_slogdet.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.solve', ''),  # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.solve_ex', ''),  # aten._linalg_solve_ex.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.tensorinv', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.tensorsolve', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('linalg.vander', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('logaddexp2', ''),  # aten.logaddexp2.default - couldn't find symbolic meta function/decomposition
    xfail('logdet', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('lu', ''),  # aten.linalg_lu_factor_ex.default - couldn't find symbolic meta function/decomposition
    xfail('lu_solve', ''),  # aten.linalg_lu_solve.default - couldn't find symbolic meta function/decomposition
    xfail('lu_unpack', ''),  # aten.lu_unpack.default - couldn't find symbolic meta function/decomposition
    xfail('masked_select', ''),  # aten.masked_select.default - couldn't find symbolic meta function/decomposition
    xfail('matrix_exp', ''),  # aten.linalg_matrix_exp.default - couldn't find symbolic meta function/decomposition
    xfail('median', ''),  # Could not run 'aten::median' with arguments from the 'Meta' backend. This could be becau...
    xfail('mode', ''),  # aten.mode.default - couldn't find symbolic meta function/decomposition
    xfail('nanquantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
    xfail('narrow', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.adaptive_max_pool1d', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.adaptive_max_pool2d', ''),  # aten.adaptive_max_pool2d.default - couldn't find symbolic meta funct...
    xfail('nn.functional.adaptive_max_pool3d', ''),  # argument 'output_size' (position 2) must be tupl...
    xfail('nn.functional.avg_pool3d', ''),  # aten.avg_pool3d.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.bilinear', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.binary_cross_entropy', ''),  # aten.new_empty.default - couldn't find symbolic meta function/decom...
    xfail('nn.functional.cosine_similarity', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.cross_entropy', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.ctc_loss'),  # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.embedding_bag', ''),  # aten._embedding_bag_forward_only.default - couldn't find symbolic meta fun...
    xfail('nn.functional.fractional_max_pool2d', ''),  # argument 'size' must be tuple of ints, but found element of t...
    xfail('nn.functional.fractional_max_pool3d', ''),  # argument 'size' must be tuple of ints, but found element of t...
    xfail('nn.functional.grid_sample', ''),  # aten.grid_sampler_2d.default - couldn't find symbolic meta function/decompos...
    xfail('nn.functional.interpolate', 'linear'),  # aten.upsample_linear1d.vec - couldn't find symbolic meta function/dec...
    xfail('nn.functional.interpolate', 'trilinear'),  # aten.upsample_trilinear3d.vec - couldn't find symbolic meta functi...
    xfail('nn.functional.max_pool1d', ''),  # Trying to call aten.size on a tensor with symbolic shapes.
    xfail('nn.functional.max_pool3d', ''),  # aten.max_pool3d_with_indices.default - couldn't find symbolic meta function/d...
    xfail('nn.functional.max_unpool1d', 'grad'),  # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
    xfail('nn.functional.max_unpool2d', 'grad'),  # aten.max_unpool2d.default - couldn't find symbolic meta function/decom...
    xfail('nn.functional.max_unpool3d', 'grad'),  # aten.max_unpool3d.default - couldn't find symbolic meta function/decom...
    xfail('nn.functional.multi_margin_loss', ''),  # Could not run 'aten::multi_margin_loss' with arguments from the...
    xfail('nn.functional.multilabel_margin_loss', ''),  # Could not run 'aten::multilabel_margin_loss_forward' with ...
    xfail('nn.functional.pad', 'reflect'),  # aten.reflection_pad1d.default - couldn't find symbolic meta function/decompo...
    xfail('nn.functional.pad', 'replicate'),  # aten.replication_pad1d.default - couldn't find symbolic meta function/deco...
    xfail('nn.functional.pdist', ''),  # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend...
    xfail('nn.functional.pixel_unshuffle', ''),  # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
    xfail('nn.functional.smooth_l1_loss', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('normal', 'number_mean'),  # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
    xfail('ormqr', ''),  # aten.ormqr.default - couldn't find symbolic meta function/decomposition
    xfail('pinverse', ''),  # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition
    xfail('polygamma', 'polygamma_n_0'),  # aten.polygamma.default - couldn't find symbolic meta function/decomposition
    xfail('polygamma', 'polygamma_n_1'),  # aten.polygamma.default - couldn't find symbolic meta function/decomposition
    xfail('polygamma', 'polygamma_n_2'),  # aten.polygamma.default - couldn't find symbolic meta function/decomposition
    xfail('polygamma', 'polygamma_n_3'),  # aten.polygamma.default - couldn't find symbolic meta function/decomposition
    xfail('polygamma', 'polygamma_n_4'),  # aten.polygamma.default - couldn't find symbolic meta function/decomposition
    xfail('quantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
    xfail('renorm', ''),  # aten.renorm.default - couldn't find symbolic meta function/decomposition
    xfail('repeat_interleave', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail('resize_', ''),  # aten.clone.default - couldn't find symbolic meta function/decomposition
    xfail('resize_as_', ''),  # aten.clone.default - couldn't find symbolic meta function/decomposition
    xfail('roll', ''),  # Tensors of type TensorImpl do not have numel
    xfail('searchsorted', ''),  # Could not run 'aten::searchsorted.Tensor' with arguments from the 'Meta' backend. ...
    xfail('_segment_reduce', 'offsets'),  # aten.segment_reduce.default - couldn't find symbolic meta function/decomposition
    xfail('special.airy_ai', ''),  # aten.special_airy_ai.default - couldn't find symbolic meta function/decomposition
    xfail('special.bessel_y0', ''),  # aten.special_bessel_y0.default - couldn't find symbolic meta function/decomposition
    xfail('special.bessel_y1', ''),  # aten.special_bessel_y1.default - couldn't find symbolic meta function/decomposition
    xfail('special.chebyshev_polynomial_t', ''),  # aten.special_chebyshev_polynomial_t.default - couldn't find symbolic me...
    xfail('special.chebyshev_polynomial_u', ''),  # aten.special_chebyshev_polynomial_u.default - couldn't find symbolic me...
    xfail('special.hermite_polynomial_h', ''),  # aten.special_hermite_polynomial_h.default - couldn't find symbolic meta f...
    xfail('special.hermite_polynomial_he', ''),  # aten.special_hermite_polynomial_he.default - couldn't find symbolic meta...
    xfail('special.laguerre_polynomial_l', ''),  # aten.special_laguerre_polynomial_l.default - couldn't find symbolic meta...
    xfail('special.modified_bessel_i0', ''),  # aten.special_modified_bessel_i0.default - couldn't find symbolic meta funct...
    xfail('special.modified_bessel_i1', ''),  # aten.special_modified_bessel_i1.default - couldn't find symbolic meta funct...
    xfail('special.modified_bessel_k0', ''),  # aten.special_modified_bessel_k0.default - couldn't find symbolic meta funct...
    xfail('special.modified_bessel_k1', ''),  # aten.special_modified_bessel_k1.default - couldn't find symbolic meta funct...
    xfail('special.polygamma', 'special_polygamma_n_0'),  # aten.polygamma.default - couldn't find symbolic meta function/...
    xfail('special.scaled_modified_bessel_k0', ''),  # aten.special_scaled_modified_bessel_k0.default - couldn't find symbo...
    xfail('special.scaled_modified_bessel_k1', ''),  # aten.special_scaled_modified_bessel_k1.default - couldn't find symbo...
    xfail('stft', ''),  # argument 'size' must be tuple of ints, but found element of type torch._C.SymIntNode at...
    xfail('take_along_dim', ''),  # dtype of indices should be Long but got Float
    xfail('triangular_solve', ''),  # aten.triangular_solve.default - couldn't find symbolic meta function/decomposition
    xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
    xfail('unique', ''),  # aten._unique2.default - couldn't find symbolic meta function/decomposition
}
# Ops skipped outright (rather than xfail-ed) in the symbolic tests.
symbolic_tensor_segfaults = {
    skip('nn.functional.batch_norm'),  # Segfault??
}

# Fold the segfault skips into the shared symbolic failure set, in place.
symbolic_tensor_failures |= symbolic_tensor_segfaults
| |
# Expected failures applied only to the out-of-place symbolic test
# (test_make_fx_symbolic_exhaustive).
outplace_symbolic_tensor_failures = {
    xfail('i0', ''),  # aten.i0.default - couldn't find symbolic meta function/decomposition
    xfail('masked_scatter', ''),  # aten.masked_scatter.default - couldn't find symbolic meta function/decomposition
    xfail('nn.functional.rrelu', ''),  # aten.empty_like.default - couldn't find symbolic meta function/decomposition
}
| |
# Expected failures applied only to the in-place symbolic test
# (test_make_fx_symbolic_exhaustive_inplace).
inplace_symbolic_tensor_failures = {
    # bugs
    xfail('float_power', ''),  # base given to float_power_ has dtype Float but the operation's result requires dtype Double

    # decomp not implemented
    xfail('unique', ''),

    # in-place has a different signature than out-of-place
    xfail('uniform', ''),
}
| |
| # Copies inputs to inplace operations to avoid inplace modifications |
| # to leaves requiring gradient |
| def _get_safe_inplace(inplace_variant): |
| @functools.wraps(inplace_variant) |
| def _fn(t, *args, **kwargs): |
| return inplace_variant(t.clone(), *args, **kwargs) |
| |
| return _fn |
| |
def _test_make_fx_helper(self, device, dtype, op, tracing_mode, inplace=False):
    """Trace ``op`` with make_fx on its sample inputs and compare with eager.

    For up to 100 sample inputs the op (or, with ``inplace=True``, its
    clone-protected in-place variant) is traced under ``tracing_mode``.
    Float32 tensor arguments are then refilled with uniform(0, 1) values and
    the traced callable's output is compared with a direct call.  The test is
    skipped if tracing raises DynamicOutputShapeException; a sample whose
    direct call raises is silently ignored.
    """
    def f(args, kwargs, extra_args, extra_kwargs):
        # Replace every torch.Size argument by the size of its stand-in
        # tensor before calling the op.
        if extra_args:
            for i, t in extra_args:
                args[i] = t.size()
        if extra_kwargs:
            for k, t in extra_kwargs.items():
                kwargs[k] = t.size()

        if inplace:
            fn = _get_safe_inplace(op.get_inplace())
        else:
            fn = op.op
        return fn(*args, **kwargs)

    samples = op.sample_inputs(device, dtype, requires_grad=False)
    new_f = None

    # Limit ourselves to first 100 inputs so symbolic tracing tests don't take too long
    for sample in itertools.islice(samples, 100):
        if inplace and sample.broadcasts_input:
            continue
        args = [sample.input, *sample.args]
        kwargs = sample.kwargs

        # If any argument is a torch.Size(), maybe get dynamic shapes for it by:
        # - Creating a temporary Tensor whose size is the torch.Size() we want.
        #   Note that we use an expanded Tensor as we cannot pass "meta"
        #   Tensors to make_fx.
        # - Passing it to make_fx such that it is converted to a proxy Tensor.
        # - Unpacking the size in the wrapper to get a torch.Size with dynamic
        #   shapes (in symbolic mode, a no-op otherwise).
        extra_args = [
            (position, torch.empty(value, device="cpu"))
            for position, value in enumerate(args)
            if isinstance(value, torch.Size)
        ]
        extra_kwargs = {
            name: torch.empty(value, device="cpu")
            for name, value in kwargs.items()
            if isinstance(value, torch.Size)
        }

        call_args = (args, kwargs, extra_args, extra_kwargs)
        try:
            new_f = make_fx(f, tracing_mode=tracing_mode)(*call_args)
        except DynamicOutputShapeException:
            self.skipTest("Dynamic output shape operation in trace")

        # Refill float32 tensor arguments after tracing, before the
        # comparison runs.
        for candidate in args:
            if isinstance(candidate, torch.Tensor) and candidate.dtype == torch.float:
                candidate.uniform_(0, 1)

        try:
            old_out = f(*call_args)
        except Exception:
            continue
        new_out = wrapper_set_seed(new_f, *call_args)
        self.assertEqual(new_out, old_out)
| |
class TestProxyTensorOpInfo(TestCase):
    # OpInfo-driven make_fx tests: one test per tracing mode ("real", "fake",
    # "symbolic"), plus a symbolic test for in-place variants.  Each test
    # applies the union of the failure sets relevant to its mode.

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_exhaustive', make_fx_failures)
    def test_make_fx_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "real")

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_fake_exhaustive',
             make_fx_failures | fake_tensor_failures)
    def test_make_fx_fake_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "fake")

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | outplace_symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive(self, device, dtype, op):
        _test_make_fx_helper(self, device, dtype, op, "symbolic")

    @ops(op_db + custom_op_db, allowed_dtypes=(torch.float,))
    @skipOps('TestProxyTensorOpInfo', 'test_make_fx_symbolic_exhaustive_inplace',
             make_fx_failures | fake_tensor_failures | symbolic_tensor_failures | inplace_symbolic_tensor_failures)
    def test_make_fx_symbolic_exhaustive_inplace(self, device, dtype, op):
        # Ops without an in-place variant have nothing to exercise here.
        inplace_variant = op.get_inplace()
        if not inplace_variant:
            self.skipTest("No inplace variable for this op")
        _test_make_fx_helper(self, device, dtype, op, "symbolic", inplace=True)
| |
| |
# Restrict device-type instantiation of the OpInfo tests to "cpu".
# The trailing comma is required: ("cpu") is just the string "cpu", so any
# membership test against it would be a substring check rather than a lookup
# in a collection of device names.
only_for = ("cpu",)
instantiate_device_type_tests(TestProxyTensorOpInfo, globals(), only_for=only_for)
| |
| |
# Run the test suite when this file is executed as a script.
if __name__ == "__main__":
    run_tests()