[Dynamo] Symbolic shape guards (#87570)
**Introduces symbolic shape guards into dynamo.**
In this PR, we take the existing fake tensor infra and plumbing in dynamo and we start passing a shape_env around. This shape_env does not get plumbed down to middle layers / backend yet - it only collects expressions from frontend invocations at the moment. We then translate these expressions into guards at the point where we take other guards installed throughout dynamo - and add them to check_fn.
Part 1 of https://docs.google.com/document/d/1QJ-M4zfMkD-fjHIqW089RptjLl9EgozZGCceUbvmgfY/edit#
cc @jansel @lezcano @fdrocha @mlazos @soumith @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87570
Approved by: https://github.com/ezyang
diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py
index a2a94fc..a32825d 100644
--- a/test/dynamo/test_dynamic_shapes.py
+++ b/test/dynamo/test_dynamic_shapes.py
@@ -3,14 +3,26 @@
from torch._dynamo.testing import make_test_cls_with_patches
try:
- from . import test_functions, test_misc, test_modules, test_repros, test_unspec
+ from . import (
+ test_export,
+ test_functions,
+ test_misc,
+ test_modules,
+ test_repros,
+ test_subgraphs,
+ test_unspec,
+ )
except ImportError:
+ import test_export
import test_functions
import test_misc
import test_modules
import test_repros
+ import test_subgraphs
import test_unspec
+import unittest
+
def make_dynamic_cls(cls):
return make_test_cls_with_patches(
@@ -23,6 +35,145 @@
DynamicShapesReproTests = make_dynamic_cls(test_repros.ReproTests)
DynamicShapesNNModuleTests = make_dynamic_cls(test_modules.NNModuleTests)
DynamicShapesUnspecTests = make_dynamic_cls(test_unspec.UnspecTests)
+DynamicShapesExportTests = make_dynamic_cls(test_export.ExportTests)
+DynamicShapesSubGraphTests = make_dynamic_cls(test_subgraphs.SubGraphTests)
+
+
+# DynamicShapesFunctionTests
+unittest.expectedFailure(
+ DynamicShapesFunctionTests.test_len_tensor_dynamic_shapes
+ # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer
+)
+
+unittest.expectedFailure(
+ DynamicShapesFunctionTests.test_tensor_len_dynamic_shapes
+ # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer
+)
+
+
+# DynamicShapesReproTests
+unittest.expectedFailure(
+ DynamicShapesReproTests.test_reformer_eval_dynamic_shapes
+ # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer
+)
+
+unittest.expectedFailure(
+ DynamicShapesReproTests.test_reformer_train_dynamic_shapes
+ # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer
+)
+
+unittest.expectedFailure(
+ DynamicShapesReproTests.test_issue175_dynamic_shapes
+ # TypeError: 'torch._C.SymIntNode' object cannot be interpreted as an integer
+)
+
+unittest.expectedFailure(
+ DynamicShapesReproTests.test_do_paste_mask_dynamic_shapes
+ # aten.min.dim - couldn't find symbolic meta function/decomposition
+)
+
+unittest.expectedFailure(
+ DynamicShapesReproTests.test_convert_boxes_to_pooler_format_dynamic_shapes
+ # Could not infer dtype of torch._C.SymIntNode
+)
+
+unittest.expectedFailure(
+ DynamicShapesReproTests.test_ellipsis_dynamic_shapes
+ # Cannot call sizes() on tensor with symbolic sizes/strides
+)
+
+unittest.expectedFailure(
+ DynamicShapesReproTests.test_hf_t5_forward_dynamic_shapes
+ # Cannot call sizes() on tensor with symbolic sizes/strides
+)
+
+unittest.expectedFailure(
+ DynamicShapesReproTests.test_reformer_sorting_dynamic_shapes
+ # Unable to cast Python instance to C++ type
+)
+
+unittest.expectedFailure(
+ DynamicShapesReproTests.test_boxes_len_dynamic_shapes
+ # Unable to cast Python instance to C++ type
+)
+
+unittest.expectedFailure(
+ DynamicShapesReproTests.test_guard_fail_tensor_bool_dynamic_shapes
+ # RuntimeError: aten.allclose.default - couldn't find symbolic meta function/decomposition
+)
+
+# DynamicShapesMiscTests
+unittest.expectedFailure(
+ DynamicShapesMiscTests.test_unsupported_fake_tensor_dynamic_shapes
+ # aten.quantize_per_tensor.default - couldn't find symbolic meta function/decomposition
+)
+unittest.expectedFailure(
+ DynamicShapesMiscTests.test_module_deepcopy_dynamic_shapes
+ # aten.squeeze_.dim - couldn't find symbolic meta function/decompositio
+)
+
+# DynamicShapesUnspecTests
+unittest.expectedFailure(
+ DynamicShapesUnspecTests.test_unspec_float_precision_dynamic_shapes
+ # float() argument must be a string or a real number, not 'torch._C.SymIntNode'
+)
+
+
+# DynamicShapesNNModuleTests
+unittest.expectedFailure(
+ DynamicShapesNNModuleTests.test_unsupportedmethod_dynamic_shapes
+ # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition
+)
+
+unittest.expectedFailure(
+ DynamicShapesNNModuleTests.test_unsupportedmodule_dynamic_shapes
+ # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition
+)
+
+unittest.expectedFailure(
+ DynamicShapesNNModuleTests.test_self_mutating1_dynamic_shapes
+ # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition
+)
+
+unittest.expectedFailure(
+ DynamicShapesNNModuleTests.test_call_fn_with_non_const_inputs_safe_dynamic_shapes
+ # aten.squeeze_.dim - couldn't find symbolic meta function/decomposition
+)
+
+
+# DynamicShapesExportTests
+unittest.expectedFailure(
+ DynamicShapesExportTests.test_export_compare_optimize_with_make_fx_dynamic_shapes
+)
+unittest.expectedFailure(
+ DynamicShapesExportTests.test_export_with_constant_list_nonzero_dynamic_shapes
+)
+unittest.expectedFailure(
+ DynamicShapesExportTests.test_export_with_constant_list_nonzero_free_function_dynamic_shapes
+)
+unittest.expectedFailure(
+ DynamicShapesExportTests.test_export_with_constant_tuple_nonzero_dynamic_shapes
+)
+unittest.expectedFailure(
+ DynamicShapesExportTests.test_export_with_stack_trace_dynamic_shapes
+)
+unittest.expectedFailure(
+ DynamicShapesExportTests.test_zeroes_in_new_shape_scalar_out_dynamic_shapes
+)
+unittest.expectedFailure(
+ DynamicShapesExportTests.test_zeroes_in_new_shape_scalar_out_permute_dupe_and_bypass_dynamic_shapes
+)
+unittest.expectedFailure(
+ DynamicShapesExportTests.test_zeroes_in_new_shape_scalar_out_permute_dynamic_shapes
+)
+
+
+# DynamicShapesSubGraphTests
+unittest.expectedFailure(
+ DynamicShapesSubGraphTests.test_enumerate_not_break_graph_dynamic_shapes
+)
+unittest.expectedFailure(DynamicShapesSubGraphTests.test_restore_state_dynamic_shapes)
+
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py
index d18ef7e..d428a43 100644
--- a/test/dynamo/test_functions.py
+++ b/test/dynamo/test_functions.py
@@ -6,6 +6,7 @@
import itertools
import operator
from typing import Any
+from unittest.mock import patch
import torch
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index bbb8ba5..66fc198 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -872,8 +872,9 @@
self.assertTrue(same(opt_fn(input1), correct1))
self.assertTrue(same(opt_fn(input2), correct2))
- self.assertEqual(cnt.frame_count, ifdyn(1, 2))
- self.assertEqual(cnt.op_count, ifdyn(19, 4))
+ # Dyn recompiles are due to changes in hidden_state (Should we be guarding on this?)
+ self.assertEqual(cnt.frame_count, ifdyn(4, 2))
+ self.assertEqual(cnt.op_count, ifdyn(76, 4))
def test_hf_t5_forward(self):
input = torch.randn([1, 2048, 512])
diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py
index 5701363..d406f2e 100644
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -1174,6 +1174,7 @@
xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function...
xfail('nn.functional.smooth_l1_loss', ''), # could not find kernel
xfail('nn.functional.unfold', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
+ xfail('unfold', ''), # aten.squeeze_copy.dim - couldn't find symbolic meta function/decomposition
xfail('nn.functional.upsample_bilinear', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('nn.functional.upsample_nearest', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('norm', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py
index 3c2e818..1e72d5a 100644
--- a/test/test_proxy_tensor.py
+++ b/test/test_proxy_tensor.py
@@ -1288,6 +1288,7 @@
xfail('nn.functional.unfold', ''), # aten.im2col.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.upsample_bilinear', ''), # aten.upsample_bilinear2d.vec - couldn't find symbolic meta function/de...
xfail('nn.functional.upsample_nearest', ''), # aten.upsample_nearest1d.vec - couldn't find symbolic meta function/deco...
+ xfail('nonzero', ''), # aten.nonzero.default - couldn't find symbolic meta function/decomposition
xfail('norm', 'nuc'), # aten._linalg_svd.default - couldn't find symbolic meta function/decomposition
xfail('normal', ''), # aten.normal.Tensor_Tensor - couldn't find symbolic meta function/decomposition
xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
@@ -1305,6 +1306,7 @@
xfail('qr', ''), # aten.linalg_qr.default - couldn't find symbolic meta function/decomposition
xfail('rad2deg', ''), # aten.rad2deg.default - couldn't find symbolic meta function/decomposition
xfail('renorm', ''), # aten.renorm.default - couldn't find symbolic meta function/decomposition
+ xfail('repeat_interleave', ''), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('reshape_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('resize_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
xfail('resize_as_', ''), # aten.clone.default - couldn't find symbolic meta function/decomposition
@@ -1354,6 +1356,8 @@
xfail('view_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('vsplit', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('unbind', ''), # aten.unbind.int - couldn't find symbolic meta function/decomposition
+ xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
+ xfail('unique', ''), # aten._unique2.default - couldn't find symbolic meta function/decomposition
}
symbolic_tensor_segfaults = {
skip('nn.functional.batch_norm') # Segfault??
@@ -1454,6 +1458,7 @@
xfail('true_divide', ''), # aten.div_.Tensor - couldn't find symbolic meta function/decomposition
xfail('trunc', ''), # aten.trunc_.default - couldn't find symbolic meta function/decomposition
xfail('uniform', ''), # aten.uniform_.default - couldn't find symbolic meta function/decomposition
+ xfail('unique', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
xfail('unsqueeze', ''), # aten.unsqueeze_.default - couldn't find symbolic meta function/decomposition
xfail('xlogy', ''), # aten.xlogy_.Tensor - couldn't find symbolic meta function/decomposition
}
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 46a23b3..206cffb 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -417,7 +417,7 @@
assert output.guards is not None
CleanupManager.instance[out_code] = output.cleanups
- check_fn = CheckFunctionManager(output.guards, locals, globals)
+ check_fn = CheckFunctionManager(output, output.guards, locals, globals)
guarded_code = GuardedCode(out_code, check_fn.check_fn)
guard_str = "GUARDS:\n"
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 1f43ac6..9edd6f6 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -12,7 +12,10 @@
import numpy as np
+import sympy
+
import torch
+from torch.fx.experimental.symbolic_shapes import FloorDiv
from . import config, convert_frame, mutation_guard
from .eval_frame import set_guard_error_hook, set_guard_fail_hook
@@ -176,6 +179,7 @@
# Code is python expression strings generated for each guard
self.code: List[str] = []
self.tensor_check_names = []
+ self.tensor_check_ids = {}
self.tensor_check_examples = []
self.guarded_code = guarded_code
@@ -414,9 +418,13 @@
self.ID_MATCH(guard)
else:
value = self.get(guard.name)
- self.tensor_check_names.append(self.arg_ref(guard))
+ tensor_name = self.arg_ref(guard)
+ self.tensor_check_names.append(tensor_name)
self.tensor_check_examples.append(value)
+ # STOP - DO NOT USE id_ref FOR TENSORS - TENSOR INVALIDATION RULES DIFFER
+ self.tensor_check_ids[tensor_name] = id(value)
+
# Note: Guard code produced for tensor_match is a little different.
# We accumulate tensor names, then do a single install of `___check_tensors`.
# See _guards.cpp and TensorGuard for more information.
@@ -469,6 +477,62 @@
check_fn: Callable
+from sympy.printing.str import StrPrinter
+
+
[email protected]
+class TensorReference(object):
+ """
+ TensorReference objects are entirely optional. They are created to give us hints
+ into where the symbolic shape came from.
+
+ ref_id: The id of the tensor
+ kind: A string tracking where in the tensor this value came from ("size","stride", etc)
+ idx: An index in the structure
+
+ NOTE - A symbolic shape coming from tensor at id 12345's shape dim 2, would be
+ TensorReference(ref_id=12345, kind="size", idx=2)
+ """
+
+ ref_id: Optional[int] = None
+ kind: Optional[str] = None
+ idx: Optional[int] = None
+ # Note - this is untyped because of TypeError: '_SpecialForm' object does not support item assignment
+ # But it is a Optional[Union["sympy.Expr", int]]
+ expr: Optional[object] = None # Populated after association
+
+ def __hash__(self):
+ return hash((self.ref_id, self.kind, self.idx))
+
+
+class DynamoGuardPrinter(StrPrinter):
+ @staticmethod
+ def tensor_ref_as_str(tensor_ref, id_to_name_map):
+ if tensor_ref.kind in ("size", "stride"):
+ return f"{id_to_name_map[tensor_ref.ref_id]}.{tensor_ref.kind}()[{tensor_ref.idx}]"
+ return f"{id_to_name_map[tensor_ref.ref_id]}.{tensor_ref.kind}()"
+
+ def __init__(self, expr_to_tensor_ref, id_to_name_map):
+ super().__init__()
+ self.expr_to_tensor_ref = expr_to_tensor_ref
+ self.id_to_name_map = id_to_name_map
+
+ def _print_Symbol(self, expr) -> str:
+ assert isinstance(expr, sympy.core.symbol.Symbol)
+ if expr == 0:
+ return "0"
+ if expr == 1:
+ return "1"
+ assert expr in self.expr_to_tensor_ref, f"Unknown expression {expr}"
+ refs = self.expr_to_tensor_ref[expr]
+ if len(refs) == 0:
+ return super()._print_Symbol(expr)
+ tensor_ref = next(
+ iter(refs)
+ ) # Any is fine here, because we install equality guards later
+ return DynamoGuardPrinter.tensor_ref_as_str(tensor_ref, self.id_to_name_map)
+
+
# NB: Naively, you'd expect this to only be a function that produces
# the callable that consistutes the guard. However, there is some
# delicate handling for invalidating this check function when the
@@ -482,6 +546,7 @@
class CheckFunctionManager:
def __init__(
self,
+ output_graph=None,
guards: Optional[Set[Guard]] = None,
f_locals: Optional[Dict] = None,
f_globals: Optional[Dict] = None,
@@ -489,6 +554,7 @@
self.valid = True
self._weakrefs = []
self._seen_ids = set()
+ self.output_graph = output_graph
# Note: right overrides left
def combine_scopes(left, right):
@@ -511,6 +577,82 @@
self.check_fn = self.compile_check_fn(local_builder, global_builder)
self._seen_ids.clear()
+ """
+ This is a complex bit of logic. The outline here is brief. For a line by line breakdown, see
+ the code comments below.
+
+ The role of this function is to take the current state of symbolic shape guards, tensor ids in the
+ CURRENT dynamo frame, and tensor names (dynamo's frame agnostic tensor reference mechanism, see TensorCheck and
+ guards.cpp for more info) - and produce executable python expressions for addition to our guarded code components
+ that make their way into check_fn.
+
+ We DO NOT create guards based on ids. The IDs act as a lookup for the following mapping:
+
+ dynamo: tensor_name <> tensor_id
+ shape_env: tensor_id <> shape_expr
+
+ This allows us to then create a tensor_name <> shape_expr association for the current frames guards.
+ """
+
+ def _parse_symbolic_shape_expressions(self, tensor_check_names, tensor_check_ids):
+ # Pre join output
+ finished_expressions = []
+
+ # A mapping of tensor_ids to tensor names
+ id_to_name_map = {}
+
+ # We should not have a shape env, or guards if we are not in config.dynamic shapes
+ # But check it anyway.
+ if not config.dynamic_shapes:
+ return None
+
+ expr_to_tensor_ref = {}
+ guard_printer = DynamoGuardPrinter(expr_to_tensor_ref, id_to_name_map)
+
+ # tensor_check_names is the primary tensor association mechanism in dynamo.
+ # All other guards installations are driven off of it, so these ones will too.
+ for name in tensor_check_names:
+ tensor_id = tensor_check_ids[name]
+ id_to_name_map[tensor_id] = name
+
+ if tensor_id in self.output_graph.tensor_id_to_sym_shape_ref:
+ # If we made it here, this tensor_id is relevant to dynamo guard installation
+ # AND was found in the shape_env
+ tensor_ref_set = self.output_graph.tensor_id_to_sym_shape_ref[tensor_id]
+ for tensor_ref in tensor_ref_set:
+ obj_expr = tensor_ref.expr
+ if obj_expr not in expr_to_tensor_ref:
+ expr_to_tensor_ref[obj_expr] = {}
+ expr_to_tensor_ref[obj_expr][tensor_ref] = ""
+ finished_expressions.append(f"isinstance({name}, torch.Tensor)")
+
+ guard_expression = self.output_graph.shape_env.get_guard_expr()
+ expr_as_str = guard_printer.doprint(guard_expression)
+ # We may get into a state where symbolic shape keys (all should be found in replacements)
+ # Have not been removed from the expression. This is a serious enough error state that we need to assert.
+ for key in self.output_graph.shape_env.var_to_val.keys():
+ assert str(key) not in expr_as_str, f"Unknown shape symbol {key}. "
+ finished_expressions.append(expr_as_str)
+
+ for expr in expr_to_tensor_ref.keys():
+ tensor_refs = expr_to_tensor_ref[expr].keys()
+ equality_candidates = [
+ DynamoGuardPrinter.tensor_ref_as_str(x, id_to_name_map)
+ for x in tensor_refs
+ ]
+
+ if len(equality_candidates) > 1:
+ equality_expr = " == ".join(equality_candidates)
+ # breakpoint()
+ finished_expressions.append(equality_expr)
+
+ # Redundant with code_parts, but allows us to wrap it with parens nicely.
+ if len(finished_expressions) == 0:
+ return None
+
+ expression = " and ".join(finished_expressions)
+ return f"({expression})"
+
def compile_check_fn(self, local_builder, global_builder):
assert not (set(local_builder.argnames) & set(global_builder.argnames))
# see parallel handling of ".0" / "___implicit0" in _eval_frame.c
@@ -530,9 +672,20 @@
tensor_check_names = (
local_builder.tensor_check_names + global_builder.tensor_check_names
)
+
+ tensor_check_ids = local_builder.tensor_check_ids.copy()
+ tensor_check_ids.update(global_builder.tensor_check_ids)
+
check_tensors_fn = None
check_tensors_verbose_fn = None
if tensor_check_names:
+ symbolic_shape_expression = self._parse_symbolic_shape_expressions(
+ tensor_check_names, tensor_check_ids
+ )
+ if symbolic_shape_expression:
+ code_parts.append(symbolic_shape_expression)
+ verbose_code_parts.append(symbolic_shape_expression)
+
tensor_check_examples = (
local_builder.tensor_check_examples
+ global_builder.tensor_check_examples
@@ -548,14 +701,23 @@
)
verbose_code_parts.append(f"___check_tensors_verbose({verbose_args})")
- code = " and ".join(unique(code_parts))
+ def direct_equality(a, b):
+ return a == b
+ def direct_negation(a, b):
+ return not direct_equality(a, b)
+
+ code = " and ".join(unique(code_parts))
closure_vars = collections.OrderedDict(
[
("___guarded_code", self),
("___check_tensors", check_tensors_fn),
("___check_tensors_verbose", check_tensors_verbose_fn),
("tensor_check_names", tensor_check_names),
+ ("Eq", direct_equality),
+ ("Ne", direct_negation),
+ ("Mod", sympy.Mod),
+ ("FloorDiv", FloorDiv),
]
)
closure_vars.update(CLOSURE_VARS)
@@ -567,6 +729,7 @@
print("GUARDS", code)
set_guard_fail_hook(guard_fail_hook)
out = dict()
+ # print("RUNNING PY CODE", py_code)
exec(py_code, global_builder.scope, out)
guard_fn = out["___make_guard_fn"](*closure_vars.values())
guard_fn.closure_vars = closure_vars
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index f87b079..c23d4f6 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -10,6 +10,7 @@
import torch.nn
from torch import fx
+from torch.fx.experimental.symbolic_shapes import ShapeEnv
from . import config, logging as torchdynamo_logging, variables
from .bytecode_transformation import create_instruction, Instruction, unique_id
@@ -104,6 +105,8 @@
self.random_values_var = None
self.initial_random_state = ()
self.unspec_variable_map = {}
+ self.shape_env = ShapeEnv() if config.dynamic_shapes else None
+ self.tensor_id_to_sym_shape_ref = {}
@property
def output(self):
@@ -394,8 +397,10 @@
gm.recompile()
gm.compile_subgraph_reason = self.compile_subgraph_reason
name = unique_id("__compiled_fn")
+
compiled_fn = self.call_user_compiler(gm)
compiled_fn = disable(compiled_fn)
+
counters["stats"]["unique_graphs"] += 1
self.install_global(name, compiled_fn)
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 0b5cfae..4031a97 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -1340,7 +1340,8 @@
if fake_tensors_available:
with torch._subclasses.FakeTensorMode(
- throw_on_data_dependent_ops=True
+ throw_on_data_dependent_ops=True,
+ shape_env=output.shape_env,
) as fake_mode:
pass
self._fake_mode = fake_mode
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index aa64de0..1bc646b 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -25,6 +25,7 @@
from typing import Any, Dict
import numpy as np
+import sympy
import torch
from torch import fx
@@ -666,6 +667,43 @@
UnsupportedFakeTensorException,
)
+ def make_fake_tensor(e, fake_mode, tx=None):
+ fake_tensor = fake_mode.from_tensor(
+ e, static_shapes=config.dynamic_shapes is False
+ )
+ if tx is not None:
+ from torch._dynamo.guards import TensorReference
+
+ def _record(tensor_ref):
+ if tensor_ref.ref_id not in tx.output.tensor_id_to_sym_shape_ref:
+ tx.output.tensor_id_to_sym_shape_ref[tensor_ref.ref_id] = set()
+ tx.output.tensor_id_to_sym_shape_ref[tensor_ref.ref_id].add(tensor_ref)
+
+ def _extract(symbol):
+ if isinstance(symbol, int):
+ return None
+ sym_expr = symbol.get_pyobj().expr
+ if not isinstance(sym_expr, sympy.Symbol):
+ return None
+ return sym_expr
+
+ def _record_ref(e, index, symbol, kind):
+ sym_expr = _extract(symbol)
+ if sym_expr:
+ tensor_ref = TensorReference(id(e), kind, index, sym_expr)
+ _record(tensor_ref)
+
+ for index, symbol in enumerate(fake_tensor.size()):
+ _record_ref(e, index, symbol, "size")
+
+ for index, symbol in enumerate(fake_tensor.stride()):
+ _record_ref(e, index, symbol, "stride")
+
+ offset = fake_tensor.storage_offset()
+ _record_ref(e, None, offset, "storage_offset")
+
+ return fake_tensor
+
def wrap_fake_exception(fn):
try:
return fn()
@@ -678,7 +716,13 @@
def wrap_to_fake_tensor(e, fake_mode):
if type(e) in (torch.Tensor, torch.nn.Parameter):
- return wrap_fake_exception(lambda: fake_mode.from_tensor(e))
+ return wrap_fake_exception(lambda: make_fake_tensor(e, fake_mode))
+ else:
+ return e
+
+ def wrap_to_fake_tensor_and_record(e, tx):
+ if type(e) in (torch.Tensor, torch.nn.Parameter):
+ return wrap_fake_exception(lambda: make_fake_tensor(e, tx.fake_mode, tx))
else:
return e
diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py
index 53fdb95..cc64e00 100644
--- a/torch/_dynamo/variables/builtin.py
+++ b/torch/_dynamo/variables/builtin.py
@@ -359,11 +359,23 @@
a, b = b, a
assert isinstance(a, variables.TensorVariable)
- # 1. result of an item call is a scalar convert to a tensor
- # 2. dynamic shape should be resolved to tensor
- if isinstance(a, (FakeItemVariable, DynamicShapeVariable)):
+ # result of an item call is a scalar convert to a tensor
+ if isinstance(a, FakeItemVariable):
a = variables.TorchVariable(torch.tensor).call_function(tx, [a], {})
+ # Dynamic input does not get resolved, rather, gets stored as call_function
+ if isinstance(a, DynamicShapeVariable):
+ return variables.TensorVariable.create(
+ tx=tx,
+ proxy=tx.output.create_proxy(
+ "call_function",
+ self.fn,
+ *proxy_args_kwargs([a, b], {}),
+ current_tx=tx,
+ ),
+ **VariableTracker.propagate(self, [a, b]),
+ )
+
# convert min/max to torch ops
if b.is_python_constant():
kwargs = {"min": b} if (self.fn is max) else {"max": b}
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index a8db819..864d2c4 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -17,7 +17,7 @@
DataDependentOutputException,
DynamicOutputShapeException,
)
- from ..utils import deepcopy_to_fake_tensor, wrap_to_fake_tensor
+ from ..utils import deepcopy_to_fake_tensor, wrap_to_fake_tensor_and_record
import torch.utils._python_dispatch as py_dispatch
from torch.fx.immutable_collections import immutable_list
@@ -98,7 +98,7 @@
Run the computation represented by `node` using fake tensors and return the result.
"""
op = node.op
- fake_wrapper = functools.partial(wrap_to_fake_tensor, fake_mode=tx.fake_mode)
+ fake_wrapper = functools.partial(wrap_to_fake_tensor_and_record, tx=tx)
from ..utils import wrap_fake_exception
def visit(n: torch.fx.Node):
@@ -206,7 +206,7 @@
proxy.tracer.real_value_cache[proxy.node] = _clone_input(example_value)
if use_fake_tensors:
fake_wrapper = functools.partial(
- wrap_to_fake_tensor, fake_mode=tx.fake_mode
+ wrap_to_fake_tensor_and_record, tx=tx
)
example_value = fake_wrapper(example_value)
@@ -241,14 +241,14 @@
return TorchVariable(proxy.node.target)
elif istype(example_value, (int, bool, float)) and config.dynamic_shapes:
proxy.node.meta["example_value"] = example_value
- return DynamicShapeVariable(proxy, type(example_value), **options)
+ return DynamicShapeVariable(proxy, example_value, **options)
elif istype(example_value, torch.Size) and config.dynamic_shapes:
proxy.node.meta["example_value"] = example_value
sizes = []
for i, v in enumerate(example_value):
proxy_i = proxy[i]
proxy_i.node.meta["example_value"] = v
- sizes.append(DynamicShapeVariable(proxy_i, int))
+ sizes.append(DynamicShapeVariable(proxy_i, v))
return SizeVariable(sizes, proxy, **options)
elif istype(example_value, int) and proxy.node.target in (
torch.seed,
@@ -258,7 +258,7 @@
getattr(torch.distributed, "get_world_size", _missing),
):
proxy.node.meta["example_value"] = example_value
- return DynamicShapeVariable(proxy, type(example_value), **options)
+ return DynamicShapeVariable(proxy, example_value, **options)
elif istype(example_value, torch.Size) and all(
[isinstance(x, int) for x in example_value]
):
@@ -337,6 +337,9 @@
from . import UserDefinedObjectVariable
return UserDefinedObjectVariable(example_value)
+ elif isinstance(example_value, torch.SymIntNode):
+ proxy.node.meta["example_value"] = example_value
+ return cls(proxy, **options)
else:
raise AssertionError(
"torch.* op returned non-Tensor "
@@ -474,7 +477,6 @@
kwargs = dict(kwargs)
options = VariableTracker.propagate(self, args, kwargs.values())
-
if name == "stride" and self.stride is not None:
constant_result = ConstantVariable(self.stride, **options)
elif name == "size" and self.size is not None:
@@ -578,12 +580,12 @@
Represents a symbolic size, e.g., as returned by tensor.size(0)
"""
- def __init__(self, proxy, dyn_shape_cls, **kwargs):
+ def __init__(self, proxy, dyn_shape, **kwargs):
super(DynamicShapeVariable, self).__init__(proxy, **kwargs)
- self.dyn_shape_cls = dyn_shape_cls
+ self.dyn_shape = dyn_shape
def python_type(self):
- return self.dyn_shape_cls
+ return type(self.dyn_shape)
def unpack_var_sequence(self, tx):
super(DynamicShapeVariable, self).unpack_var_sequence(tx)
diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py
index 1ecfbe1..e0c88b2 100644
--- a/torch/_dynamo/variables/torch.py
+++ b/torch/_dynamo/variables/torch.py
@@ -344,6 +344,24 @@
example_value=example_value,
**options,
)
+ elif (
+ self.value == torch.numel
+ and len(args) == 1
+ and isinstance(args[0], TensorVariable)
+ and len(kwargs) == 0
+ ):
+ # TODO(voz): This is rewritten as a call_method because
+ # torch.numel(x) w/ sym shapes raises a RuntimeError and x.numel() does not
+ return TensorVariable.create(
+ tx=tx,
+ proxy=tx.output.create_proxy(
+ "call_method",
+ "numel",
+ *proxy_args_kwargs(args, kwargs),
+ current_tx=tx,
+ ),
+ **options,
+ )
else:
# Handle sth like torch.LongTensor(list(np.int64, np.int64, ...)),
# as FX symbolic trace doesn't support numpy int/float as base types.
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 2f2f07f..652c24c 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -234,7 +234,7 @@
warnings.filterwarnings("ignore", "The .grad attribute of a Tensor")
grad_not_none = t.grad is not None
if grad_not_none:
- out.grad = self.from_real_tensor(fake_mode, t.grad)
+ out.grad = self.from_real_tensor(fake_mode, t.grad, shape_env=shape_env)
self.set_tensor_memo(t, out)
return out
diff --git a/torch/_subclasses/meta_utils.py b/torch/_subclasses/meta_utils.py
index 80723f1..3e1040d 100644
--- a/torch/_subclasses/meta_utils.py
+++ b/torch/_subclasses/meta_utils.py
@@ -146,7 +146,7 @@
def sym(x):
if make_symbolic:
- return shape_env.create_symbol(x)
+ return shape_env.create_symintnode(shape_env.create_symbol(x))
else:
return x