| import builtins |
| import collections |
| import logging |
| import math |
| import os |
| import re |
| import types |
| import weakref |
| from inspect import currentframe, getframeinfo |
| from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union |
| from weakref import ReferenceType |
| |
| import torch |
| |
| from torch._guards import ( |
| DuplicateInputs, |
| Guard, |
| GuardBuilderBase, |
| GuardEnvExpr, |
| GuardSource, |
| Source, |
| ) |
| from torch.fx.experimental.symbolic_shapes import SYMPY_INTERP |
| |
| from . import config, convert_frame, mutation_guard |
| from .eval_frame import set_guard_error_hook, set_guard_fail_hook |
| from .exc import unimplemented |
| from .types import GuardedCode, GuardFail, GuardFn # noqa: F401 |
| from .utils import ( |
| dict_const_keys, |
| dict_const_keys_repr, |
| dict_param_key_ids, |
| guard_failures, |
| HAS_NUMPY, |
| istype, |
| np, |
| orig_code_map, |
| rename_implicit, |
| tensor_always_has_static_shape, |
| tensor_static_reason_to_message, |
| tuple_iterator_getitem, |
| tuple_iterator_len, |
| ) |
| |
log = logging.getLogger(__name__)
# C++ fast paths for guard evaluation (torch._C._dynamo.guards):
# TensorGuards batches all tensor property checks into one native call;
# check_obj_id / check_type_id are the native forms of `id(x) == y` /
# `id(type(x)) == y` used by the generated guard expressions.
TensorGuards = torch._C._dynamo.guards.TensorGuards
check_obj_id = torch._C._dynamo.guards.check_obj_id
check_type_id = torch._C._dynamo.guards.check_type_id


# Names injected into the scope of every compiled guard function.  The guard
# code strings produced by GuardBuilder (e.g. "___check_type_id(x, 123)")
# resolve these helpers at eval time.
CLOSURE_VARS = collections.OrderedDict(
    [
        ("___check_type_id", check_type_id),
        ("___check_obj_id", check_obj_id),
        ("___is_grad_enabled", torch.is_grad_enabled),
        ("___odict_getitem", collections.OrderedDict.__getitem__),
        ("___dict_param_key_ids", dict_param_key_ids),
        ("___dict_const_keys", dict_const_keys),
        ("___tuple_iterator_len", tuple_iterator_len),
        ("___tuple_iterator_getitem", tuple_iterator_getitem),
        ("__math_isnan", math.isnan),
        # "inf" appears literally in repr() of float("inf"); binding it here
        # lets generated equality guards eval cleanly.
        ("inf", float("inf")),
    ]
)
| |
| |
def strip_function_call(name):
    """
    "___odict_getitem(a, 1)" => "a"
    """
    # Peel off wrapping helper calls (recursively taking the first argument),
    # except for "slice(...)", then strip attribute/subscript suffixes.
    match = re.search(r"([a-z0-9_]+)\(([^(),]+)[^()]*\)", name)
    if match is None or match.group(1) == "slice":
        return strip_getattr_getitem(name)
    return strip_function_call(match.group(2))


def strip_getattr_getitem(name):
    """
    "a[1]" => "a"
    "a.foo" => "a"
    """
    # Everything before the first "." or "[" is the base variable name.
    base, *_rest = re.split(r"[.\[]", name)
    return base
| |
| |
class GuardBuilder(GuardBuilderBase):
    """Accumulates guard code for one scope (locals or globals).

    Each ``SOMETHING_MATCH`` method is invoked via ``Guard.create`` and
    appends Python expression strings to ``self.code`` (or
    ``self.shape_env_code``); tensor guards instead accumulate into
    ``tensor_check_names``/``tensor_check_examples`` for a single batched
    C++ check.  ``CheckFunctionManager`` later compiles everything into one
    check function.
    """

    def __init__(
        self,
        id_ref: Callable[[Type[object]], str],
        source_ref: Callable[[Source], str],
        scope: Optional[Dict[str, object]],
        check_fn_manager: "CheckFunctionManager",
        renames=True,
    ):
        # id_ref registers a weakref on the object (for invalidation) and
        # returns its id; source_ref resolves a Source to a guarded arg name.
        self.id_ref = id_ref
        self.source_ref = source_ref
        if scope:
            if renames:
                scope = {rename_implicit(k): v for k, v in scope.items()}
        else:
            scope = dict()
        self.scope: Dict[str, object] = scope
        self.scope["__builtins__"] = builtins.__dict__.copy()
        for (
            name,
            package_module,
        ) in torch.package.package_importer._package_imported_modules.items():
            name = name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
            # Write the package module into the scope so that we can import it
            self.scope["__builtins__"][name] = package_module  # type: ignore[index]
            # Write the demangled name to the scope so that we can use it
            self.scope[name] = package_module

        # Variable names referenced by the generated guard code; these become
        # the parameters of the compiled guard lambda.
        self.argnames: List[str] = []
        # Code is python expression strings generated for each guard
        self.code: List[str] = []
        # shape_env_code is only used by local_builder and is used for
        # shape env code. This exists only because we need to make sure
        # shape env guards get run after tensor match guards (since the
        # tensor match guards make sure we actually have tensors)
        self.shape_env_code: List[str] = []

        # [Note - On Eager Tensor Guards]
        # Most of the time, we generate Python code in a guard to directly
        # check various properties.  However, tensors are a bit special;
        # it is too slow to check their properties one-by-one in Python.
        # Instead, there is a C++ function TensorGuards.check which takes
        # all of the tensor arguments and checks them all against compile-time
        # examples entirely in C++.  Thus, every time we process a
        # TENSOR_MATCH guard, we just add another entry to
        # tensor_check_names/tensor_check_examples, saying "for this local,
        # check it against this example", and it all ends up getting
        # swept up into a single call to ___check_tensors.  Invariant:
        # len(tensor_check_names) == len(tensor_check_examples).
        self.tensor_check_names: List[str] = []
        self.tensor_check_examples: List[torch.Tensor] = []

        self.check_fn_manager: CheckFunctionManager = check_fn_manager

    # Warning: use this with care!  This lets you access what the current
    # value of the value you are guarding on is.  You probably don't want
    # to actually durably save this value though (because it's specific
    # to this frame!)  Instead, you should be reading out some property
    # (like its type) which is what you permanently install into the
    # guard code.
    def get(self, name: str) -> Any:
        return eval(name, self.scope, CLOSURE_VARS)

    # Registers the usage of the source name referenced by the
    # string (or stored in the Guard) as being guarded upon.  It's important
    # to call this before generating some code that makes use of 'guard',
    # because without this call, we won't actually bind the variable
    # you reference in the actual guard closure (oops!)
    def arg_ref(self, guard: Union[str, Guard]) -> str:
        name: str
        if isinstance(guard, str):
            name = guard
        else:
            name = guard.name
        base = strip_getattr_getitem(strip_function_call(name))
        if base not in self.argnames:
            if re.match(r"^\d+$", base):
                # Lazy %-formatting: only build the message if the warning
                # is actually emitted.
                log.warning("invalid var name: %s", guard)
            self.argnames.append(base)

        return name

    def TYPE_MATCH(self, guard: Guard):
        # ___check_type_id is same as `id(type(x)) == y`
        t = type(self.get(guard.name))
        obj_id = self.id_ref(t)
        code = f"___check_type_id({self.arg_ref(guard)}, {obj_id})"
        self._produce_guard_code(guard, [code])

    def BOOL_FALSE(self, guard: Guard):
        # Guard on the runtime value being 'False',
        # can be faster than seemingly equivalent checks like DICT_KEYS for empty dict
        #
        # WARNING: this guard is not safe to use generally.  It only works if the runtime
        # value is of a type that supports bool(), and some types e.g. Tensor do not.
        # Only use this guard in cases you can guarantee the runtime type will be friendly.
        # (e.g. Specialized NNModule with mutation protection via setattr)
        #
        # Why not simply check the runtime type inside this guard?  It's slow enough to defeat
        # the purpose of using this guard, which itself is supposed to be a faster alternative
        # to DICT_KEYS.
        ref = self.arg_ref(guard)
        code = f"not {ref}"
        self._produce_guard_code(guard, [code])

    def ID_MATCH(self, guard: Guard):
        # ___check_obj_id is same as `id(x) == y`
        m = re.match(r"^type\((.+)\)$", guard.name)
        if m:
            # optional optimization to produce cleaner/faster guard code
            return self.TYPE_MATCH(
                Guard(m.group(1), guard.source, GuardBuilder.TYPE_MATCH)
            )

        code = f"___check_obj_id({self.arg_ref(guard)}, {self.id_ref(self.get(guard.name))})"
        self._produce_guard_code(guard, [code])

    def NAME_MATCH(self, guard: Guard):
        """Guard that the object's __name__ still equals its compile-time name."""
        obj = self.get(guard.name)
        # BUGFIX: the name must be embedded as a *string literal* (!r).
        # Without the quotes the generated guard reads e.g.
        # `x.__name__ == Foo`, which evaluates `Foo` as a variable in the
        # guard scope (NameError, or comparison against an unrelated object)
        # instead of comparing against the string "Foo".
        code = f"{self.arg_ref(guard)}.__name__ == {obj.__name__!r}"
        self._produce_guard_code(guard, [code])

    def HASATTR(self, guard: Guard):
        """Guard that hasattr(base, attr) has the same truth value as now."""
        m = re.match(r"^(.*)[.]([a-zA-Z0-9_]+)$", guard.name)
        assert m, f"invalid hasattr check {guard.name}"
        base, attr = m.group(1, 2)
        ref = self.arg_ref(base)
        val = hasattr(self.get(base), attr)
        code = None
        if val:
            code = f"hasattr({ref}, {attr!r})"
        else:
            code = f"not hasattr({ref}, {attr!r})"

        self._produce_guard_code(guard, [code], provided_guarded_object=self.get(base))

    def EQUALS_MATCH(self, guard: Guard):
        """Guard on equality with the compile-time constant value."""
        ref = self.arg_ref(guard)
        val = self.get(guard.name)
        t = type(val)
        np_types = (
            (
                np.int8,
                np.int16,
                np.int32,
                np.int64,
                np.uint8,
                np.uint16,
                np.uint32,
                np.uint64,
                np.float16,
                np.float32,
                np.float64,
            )
            if HAS_NUMPY
            else ()
        )
        # Only constant-like types whose repr() round-trips may be guarded
        # by equality.
        assert istype(
            val,
            (
                int,
                float,
                bool,
                type(None),
                str,
                type,
                list,
                tuple,
                set,
                slice,
                frozenset,
                range,
                torch.Size,
                torch.device,
                torch.dtype,
            )
            + np_types,
        ), t.__name__

        if istype(val, (torch.device, torch.dtype)):
            # TODO(jansel): is this slow?  perhaps optimize it
            code = [f"str({ref}) == {str(val)!r}"]
            self._produce_guard_code(guard, code)
            return

        # Special case for nan because float("nan") == float("nan") evaluates to False
        if istype(val, float) and math.isnan(val):
            code = list()
            code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
            code.append(f"__math_isnan({ref})")
            self._produce_guard_code(guard, code)
            return

        # Add type check to prevent equality check between tensor and non-tensor.
        code = list()
        if istype(val, (list, tuple)):
            # Also pins the container length and each element's type.
            self.LIST_LENGTH(guard)

            for idx, elem in enumerate(val):
                code.append(
                    f"___check_type_id({ref}[{idx}], {self.id_ref(type(elem))})"
                )

        elif not istype(val, torch.Size):
            code.append(f"___check_type_id({ref}, {self.id_ref(t)})")

        if istype(val, torch.Size):
            # torch.Size reprs as torch.Size([...]); compare as a plain tuple.
            val = tuple(val)

        code.append(f"{ref} == {val!r}")
        self._produce_guard_code(guard, code)

    def CONSTANT_MATCH(self, guard: Guard):
        # bool / None are singletons, so an identity check is both cheaper
        # and stricter than equality.
        val = self.get(guard.name)
        if istype(val, (bool, type(None))):
            self.ID_MATCH(guard)
        else:
            self.EQUALS_MATCH(guard)

    def NN_MODULE(self, guard: Guard):
        """Identity-guard an nn.Module and pin its .training flag."""
        self.ID_MATCH(guard)
        ref = self.arg_ref(guard)
        val = self.get(guard.name)

        def setup_guard():
            assert istype(val.training, bool)
            self.code.append(f"{ref}.training == {val.training}")

        if hasattr(val, "training"):
            # There are cases where a monkeypatched object has a guard made between __new__ and __init__
            setup_guard()
        else:
            unimplemented(f"Guard setup for uninitialized class {type(val)}")

    def FUNCTION_MATCH(self, guard: Guard):
        """things like torch.add and user defined functions"""
        # Globals are assumed stable; only local function objects are
        # identity-guarded.
        if guard.is_local():
            return self.ID_MATCH(guard)

    def BUILTIN_MATCH(self, guard: Guard):
        return self.FUNCTION_MATCH(guard)

    def PYMODULE_MATCH(self, guard: Guard):
        return self.FUNCTION_MATCH(guard)

    def LIST_LENGTH(self, guard):
        """Guard on exact type and len() of a sequence."""
        ref = self.arg_ref(guard)
        value = self.get(guard.name)
        t = type(value)

        code = list()
        code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
        code.append(f"len({ref}) == {len(value)}")

        self._produce_guard_code(guard, code)

    def TUPLE_ITERATOR_LEN(self, guard):
        """Guard on the number of remaining items in a tuple iterator."""
        ref = self.arg_ref(guard)
        value = self.get(guard.name)
        t = type(value)

        code = list()
        code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
        code.append(f"___tuple_iterator_len({ref}) == {tuple_iterator_len(value)}")

        self._produce_guard_code(guard, code)

    def DICT_KEYS(self, guard):
        """Guard on the exact key set of a dict (Parameter keys by id)."""
        ref = self.arg_ref(guard)
        value = self.get(guard.name)
        t = type(value)

        code = list()
        code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
        # Parameter keys are compared by identity (they are not
        # repr-round-trippable); constant keys by equality.
        param_key_ids = set(dict_param_key_ids(value))
        const_keys = set(dict_const_keys(value))
        const_keys_repr = dict_const_keys_repr(const_keys)
        if param_key_ids:
            code.append(f"___dict_param_key_ids({ref}) == {param_key_ids!r}")
            code.append(f"___dict_const_keys({ref}) == {const_keys_repr}")
        else:
            code.append(f"set({ref}.keys()) == {const_keys_repr}")

        self._produce_guard_code(guard, code)

    def WEAKREF_ALIVE(self, guard):
        # The guarded expression dereferences a weakref; None means dead.
        self._produce_guard_code(guard, [f"{self.arg_ref(guard)} is not None"])

    def NN_MODULE_PARAM_NAMES(self, guard):
        """Guard on the set of parameter names of an nn.Module."""
        ref = self.arg_ref(guard)
        value = self.get(guard.name)
        t = type(value)
        keys = {k for k, v in value.named_parameters()}

        code = list()
        code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
        code.append(f"{{k for k, v in {ref}.named_parameters()}} == {keys!r}")

        self._produce_guard_code(guard, code)

    def ODICT_KEYS(self, guard):
        """OrderedDict keys match"""
        ref = self.arg_ref(guard)
        value = self.get(guard.name)
        t = type(value)

        code = list()
        code.append(f"___check_type_id({ref}, {self.id_ref(t)})")
        # str() of odict_keys captures both the keys and their order.
        code.append(f"str({ref}.keys()) == {str(value.keys())!r}")

        self._produce_guard_code(guard, code)

    def OBJECT_MUTATION(self, guard: Guard):
        # No inline code; mutation_guard invalidates the check_fn_manager
        # when the object is mutated.
        mutation_guard.watch(self.get(guard.name), self.check_fn_manager)

    def GRAD_MODE(self, guard: Guard):
        """Guard on the initial grad state"""
        assert guard.name == ""
        assert guard.source is GuardSource.GLOBAL
        code = None
        if convert_frame.initial_grad_state:
            code = "___is_grad_enabled()"
        else:
            code = "not ___is_grad_enabled()"
        self._produce_guard_code(guard, [code])

    def SHAPE_ENV(self, guard: Guard):
        # Let's handle ShapeEnv guards.  To do this, we will resolve
        # shape variables to sources from tracked_fakes.  This must happen after
        # tensor checks.
        assert guard.name == ""
        output_graph = self.check_fn_manager.output_graph
        # NB: self.output_graph can be None in the debug_nops tests
        fs = output_graph.tracked_fakes
        guards = output_graph.shape_env.produce_guards(
            [a.fake for a in fs],
            [a.source for a in fs],
            source_ref=self.source_ref,
        )
        for shape_guard in guards:
            # shape_env=True defers these after the tensor guards.
            self._produce_guard_code(guard, [shape_guard], shape_env=True)

    def TENSOR_MATCH(self, guard: Guard):
        if guard.is_nn_module():
            self.ID_MATCH(guard)
        else:
            value = self.get(guard.name)
            assert isinstance(value, torch.Tensor)
            tensor_name = self.arg_ref(guard)
            # [Note - On Export Tensor Guards]
            #
            # In eager mode, tensor guards are evaluated through C++, in guards.cpp
            # see [Note - On Eager Tensor Guards] for more info.
            #
            # In export mode, we instead maintain parallel logic between C++ and python
            # here, with an exception of checking the dispatch key - with the idea that a dispatch key
            # is an entirely runtime notion that would make no sense to keep in an exported graph.
            #
            # Now, this idea is okay, but to paraphrase @ezyang, this mental model is sufficient for now, although
            # not entirely true.
            # For example, suppose one of the input tensors had the negative dispatch key.
            # You should end up with a graph that is specialized for tensors that have a negative dispatch key.
            # If you allow a Tensor that does NOT have this bit set, you will accidentally run it "as if" it were negated.
            # Now, negative key only shows up for complex numbers, and most likely, the exported to target doesn't
            # support this feature at all, but the point stands that :some: tensor state only shows up on dispatch key.
            # TODO(voz): Either populate a dispatch_key check into the guards, or error on users passing in an unsupported
            # subset of keys during export.
            #
            # The list of tensor fields and calls we care about can be found in `terms` below.
            # TODO(voz): We are missing storage offset in all our tensor guards?
            code: List[str] = list()
            if self.check_fn_manager.output_graph.export:
                self.TYPE_MATCH(guard)
                terms = [
                    "dtype",
                    "device.type",
                    "device.index",
                    "requires_grad",
                    "ndimension()",
                ]
                if not config.dynamic_shapes:
                    terms.append("stride()")
                    # We need to do this to avoid the torch.Size type in guards
                    code.append(f"{tensor_name}.shape == {tuple(value.shape)}")

                for term in terms:
                    real_value = self.get(tensor_name + "." + term)
                    # NOTE(review): string-valued terms (e.g. device.type)
                    # are embedded unquoted here — confirm export guards are
                    # never eval'ed as-is, otherwise these need !r.
                    code.append(f"{tensor_name}.{term} == {real_value}")
            else:
                self.tensor_check_names.append(tensor_name)
                self.tensor_check_examples.append(value)

            # A frame is valid for reuse with dynamic dimensions if the new dynamic dimensions are a
            # strict subset of the old.
            #
            # The logic here is as follows:
            #
            # Every mark_dynamic directive is a user-knows-best command, which can incur a raise at tracing
            # time if we find guards that run counter to the user directive.
            # If compiling a frame with explicit dynamic dims X could cause an exception, we MUST NOT skip compiling.
            #
            # If the frame is compiled with any marked dynamic indices, let's call that set of indices X.
            # When we evaluated inputs against the guards, given the same tensor with potentially new dynamic indices,
            # let's call that set Y.
            #
            # When X is a strict subset of Y, the potential new raises introduced during compilation are a strict subset
            # of the raises we
            # could have encountered. The frame compiled under Y is safe to reuse with X.
            # When X is not a strict subset of Y, the non-overlapping new elements of X may cause new raises, and the
            # frame is no longer fit for reuse.
            #
            # This is the case because any newly introduced mark_dynamic directives have a chance of
            # raising, failing compilation. Any existing mark_dynamic indices that we lost are safe to lose
            # as all it means is that we have gotten rid of a user directive which could incur a raise at compile time.
            # In the case of when there is no Y, that is, there are no dynamic indices marked at all, the frame is safe
            # to reuse
            # as an empty set is a safe degeneration - that is, a strictly static tensor is always valid for a frame
            # compiled with that same
            # tensor + more onerous user directives.
            static, reason = tensor_always_has_static_shape(
                value, guard.source, is_tensor=True
            )
            if not static:
                if hasattr(value, "_dynamo_dynamic_indices"):
                    code.append(
                        f"({tensor_name}._dynamo_dynamic_indices.issubset({value._dynamo_dynamic_indices})) if hasattr({tensor_name}, '_dynamo_dynamic_indices') else True"  # noqa: B950
                    )
                # In the case of us not having any dynamic dimension indices, we compiled the frame with no chance of
                # raising for this specific tensor - and any inputs with more dynamic user directives specified must be recompiled.
                else:
                    code.append(
                        f"hasattr({tensor_name}, '_dynamo_dynamic_indices') == False"
                    )
            else:
                assert not hasattr(
                    value, "_dynamo_dynamic_indices"
                ), f"Illegal Unreachable state, guard accumulation for dynamic tensor that should have been static. Initial static message: {tensor_static_reason_to_message(reason)}"  # noqa: B950

            if len(code) > 0:
                self._produce_guard_code(guard, code)

    # A util that appends guarded code, or, in the case of export, adds data onto guards
    def _produce_guard_code(
        self, guard, code_list, provided_guarded_object=None, shape_env=False
    ):
        # WARNING: It is important that cur_frame/caller do NOT stay in
        # the current frame, because they will keep things live longer
        # than they should.  See TestMisc.test_release_module_memory
        cur_frame = currentframe()
        assert cur_frame is not None
        caller = cur_frame.f_back
        del cur_frame
        assert caller is not None
        func_name = getframeinfo(caller)[2]
        del caller
        # We use func_name for export, so might as well get a nice defensive check out of it
        assert func_name in dir(
            self.__class__
        ), f"_produce_guard_code must be called from inside GuardedCode. Called from {func_name}"

        if shape_env:
            self.shape_env_code.extend(code_list)
        else:
            self.code.extend(code_list)

        # Not all guards have names, some can be installed globally (see asserts on HAS_GRAD)
        if provided_guarded_object is None:
            name_valid = guard.name is not None and guard.name != ""

            guarded_object = self.get(guard.name) if name_valid else None
        else:
            guarded_object = provided_guarded_object

        guarded_object_type = (
            weakref.ref(type(guarded_object)) if guarded_object is not None else None
        )
        obj_ref = None
        # Not every type supports weak references (e.g. int/str).
        if hasattr(guarded_object.__class__, "__weakref__"):
            obj_ref = weakref.ref(guarded_object)

        guard.set_export_info(
            func_name,
            guarded_object_type,
            code_list,
            obj_ref,
        )
| |
| |
| # NB: Naively, you'd expect this to only be a function that produces |
| # the callable that constitutes the guard. However, there is some |
| # delicate handling for invalidating this check function when the |
| # locals/globals get invalidated, so there's some extra state |
| # we have to hold in this manager class. |
| # |
| # TODO: this object has reference cycle with itself, via check_fn which |
| # references back to CheckFunction via ___guarded_code in closure_vars. |
| # Ideally, there shouldn't be any ref cycle so that guards are |
| # promptly disposed of. |
class CheckFunctionManager:
    """Builds the guard check function for a frame.

    Runs every Guard through a local and a global GuardBuilder, then compiles
    the accumulated code strings into a single lambda (``self.check_fn``).
    Also owns the weakrefs whose death invalidates the check function
    (``self.valid``).
    """

    def __init__(
        self,
        output_graph=None,
        f_locals: Optional[Dict[str, object]] = None,
        f_globals: Optional[Dict[str, object]] = None,
        guard_fail_fn: Optional[Callable[[Tuple[str, str]], None]] = None,
    ):
        guards = output_graph.guards if output_graph else None
        # Flipped to False by self.invalidate when any guarded object dies;
        # the compiled guard checks "___guarded_code.valid" first.
        self.valid = True
        self._weakrefs: List["ReferenceType[object]"] = []
        self._seen_ids: Set[int] = set()
        self.output_graph = output_graph

        # Note: right overrides left
        def combine_scopes(left, right):
            if left is None:
                return right

            if right is None:
                return left

            return {**left, **right}

        def source_ref(source):
            guard_source = source.guard_source()
            if guard_source is GuardSource.CONSTANT:
                # No need to track constants
                return source.name()
            # w_local/w_global are assigned below, before any guard is
            # created; this closure is only invoked during guard.create.
            builder = guard_source.select(w_local(), w_global())
            assert builder is not None
            return builder.arg_ref(source.name())

        local_builder = GuardBuilder(
            self.id_ref,
            source_ref,
            combine_scopes(f_globals, f_locals),
            self,
            renames=True,
        )
        global_builder = GuardBuilder(
            self.id_ref, source_ref, f_globals, self, renames=False
        )
        # source_ref can cause a cycle, make sure we break it with weakref
        w_local = weakref.ref(local_builder)
        w_global = weakref.ref(global_builder)
        # Sorted for deterministic guard ordering.
        for guard in sorted(guards or [], key=Guard.sort_key):
            if (
                not config.guard_nn_modules
                and guard.is_nn_module()
                # Default func args must be guarded on.
                # TODO: we could make use of 'DefaultsSource' and offer a .guard.is_defaults() API
                and "__defaults__" not in guard.name
                and "__kwdefaults__" not in guard.name
                and "hooks" not in guard.name
            ):
                continue
            guard.create(local_builder, global_builder)
        self.check_fn = self.compile_check_fn(
            local_builder, global_builder, guards, guard_fail_fn
        )
        # _seen_ids only dedupes weakref registration during construction.
        self._seen_ids.clear()

    def compile_check_fn(
        self, local_builder, global_builder, guards_out, guard_fail_fn
    ):
        """Compile the builders' code strings into a single guard lambda."""
        assert not (set(local_builder.argnames) & set(global_builder.argnames))
        # see parallel handling of ".0" / "___implicit0" in _eval_frame.c
        largs = [a for a in local_builder.scope.keys() if a == "___implicit0"]
        largs += [a for a in local_builder.argnames if a != "___implicit0"]
        largs += ["**___kwargs_ignored"]
        args = ",".join(largs)

        # "___guarded_code.valid" comes first so an invalidated frame fails fast.
        code_parts = (
            ["___guarded_code.valid"] + local_builder.code + global_builder.code
        )
        # TODO(whc) maybe only the 'check_tensors' one is ambiguous? if so we can be less general..
        # verbose parts mirror code_parts but use the verbose tensor check,
        # which returns a failure-reason string (see guard_fail_hook).
        verbose_code_parts = (
            ["___guarded_code.valid"] + local_builder.code + global_builder.code
        )

        tensor_check_names = (
            local_builder.tensor_check_names + global_builder.tensor_check_names
        )

        check_tensors_fn = None
        check_tensors_verbose_fn = None
        if tensor_check_names:
            assert (
                not self.output_graph.export
            ), "Illegal to set tensor_check_names in export."
            tensor_check_examples = (
                local_builder.tensor_check_examples
                + global_builder.tensor_check_examples
            )
            # Single batched C++ check for all tensors (see
            # [Note - On Eager Tensor Guards]).
            tensor_guards = TensorGuards(
                *tensor_check_examples, dynamic_shapes=config.dynamic_shapes
            )
            check_tensors_fn = tensor_guards.check
            check_tensors_verbose_fn = tensor_guards.check_verbose
            code_parts.append(f"___check_tensors({', '.join(tensor_check_names)})")
            verbose_args = ", ".join(
                tensor_check_names + ["tensor_check_names=tensor_check_names"]
            )
            verbose_code_parts.append(f"___check_tensors_verbose({verbose_args})")

        aotautograd_guards: List[GuardEnvExpr] = (
            self.output_graph.tracing_context.guards_context.aotautograd_guards
            if self.output_graph
            else []
        )
        # Translate AOTAutograd guard expressions (currently only input
        # aliasing) into identity checks on the original argument sources.
        for guard in aotautograd_guards:
            if isinstance(guard, DuplicateInputs):
                pos_a = self.output_graph.pos_to_arg[guard.input_pos_a]
                pos_b = self.output_graph.pos_to_arg[guard.input_pos_b]
                assert (
                    pos_b >= 0 and pos_a >= 0
                ), "Deduped args out of bounds, cannot be negative"

                assert self.output_graph.graphargs[
                    pos_a
                ].is_tensor, "Deduped arg must be a tensor"
                assert self.output_graph.graphargs[
                    pos_b
                ].is_tensor, "Deduped arg must be a tensor"
                code_part = f"{self.output_graph.graphargs[pos_a].source.name()} is {self.output_graph.graphargs[pos_b].source.name()}"  # noqa: B950
                code_parts.append(code_part)
                verbose_code_parts.append(code_part)
            else:
                raise RuntimeError(f"Unknown GuardEnvExpr: {guard}")

        # Shape env guards must run after tensor guards (which establish
        # that the arguments really are tensors).
        code_parts.extend(local_builder.shape_env_code)
        verbose_code_parts.extend(local_builder.shape_env_code)
        assert not global_builder.shape_env_code

        code = " and ".join(unique(code_parts))
        closure_vars = collections.OrderedDict(
            [
                ("___guarded_code", self),
                ("___check_tensors", check_tensors_fn),
                ("___check_tensors_verbose", check_tensors_verbose_fn),
                ("tensor_check_names", tensor_check_names),
            ]
            + list(SYMPY_INTERP.items())
        )
        closure_vars.update(CLOSURE_VARS)
        # The guard is compiled as: a factory taking all closure vars, which
        # returns the actual guard lambda over the frame's arguments.
        py_code = f"""\
def ___make_guard_fn({','.join(closure_vars.keys())}):
    return lambda {args}: {code}
"""
        if os.environ.get("TORCHDYNAMO_PRINT_GUARDS", None) == "1":
            print("GUARDS", code)
        set_guard_fail_hook(guard_fail_hook)
        out: Dict[str, Any] = dict()
        # print("RUNNING PY CODE", py_code)
        exec(py_code, global_builder.scope, out)
        guard_fn = out["___make_guard_fn"](*closure_vars.values())
        # Attributes below are consumed by guard_fail_hook / guard_error_hook.
        guard_fn.closure_vars = closure_vars
        # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both
        guard_fn.args = largs
        guard_fn.code_parts = code_parts
        guard_fn.verbose_code_parts = verbose_code_parts
        guard_fn.global_scope = global_builder.scope
        guard_fn.guard_fail_fn = guard_fail_fn
        return guard_fn

    def invalidate(self, ref):
        # A weakref is no longer valid, self.check_fn should return false
        # (``ref`` is the dead weakref passed by the weakref callback.)
        self.valid = False

    def id_ref(self, obj):
        """add a weakref, return the id"""
        try:
            if id(obj) not in self._seen_ids:
                # The callback invalidates this check_fn when obj dies.
                self._weakrefs.append(weakref.ref(obj, self.invalidate))
                self._seen_ids.add(id(obj))
        except TypeError:
            pass  # cannot weakref bool object
        return id(obj)
| |
| |
def guard_fail_hook(
    guard_fn: GuardFn, code: types.CodeType, f_locals: Dict[str, object], last: bool
) -> None:
    """
    called whenever a guard fails.

    Re-evaluates each verbose code part in the frame's scope to find the first
    failing guard, reports it via ``guard_fn.guard_fail_fn`` (if set), and on
    the last failing frame records the reason in ``guard_failures``.
    """
    if not guard_fn.guard_fail_fn and not last:
        return
    # Rebuild the scope the guard lambda ran under (renamed implicit locals
    # plus the compiled closure vars) so each part can be eval'ed alone.
    scope = {rename_implicit(k): v for k, v in f_locals.items()}
    scope.update(guard_fn.closure_vars)
    reason = None
    for part in guard_fn.verbose_code_parts:
        fail_reason = eval(part, guard_fn.global_scope, scope)
        # TODO(whc) hacky for now as not every 'part' in guard_fn.verbose_code_parts
        # is updated to return a string explaining the failure.
        if isinstance(fail_reason, str):
            reason = fail_reason
            break
        elif isinstance(fail_reason, bool) and not fail_reason:
            # No explanation string available; report the failing expression.
            reason = part
            break
    try:
        if guard_fn.guard_fail_fn is not None:
            guard_fn.guard_fail_fn(
                GuardFail(reason or "unknown reason", orig_code_map[code])
            )
    except Exception:
        # Deliberately swallow user-callback errors: raising from this hook
        # would surface as a NULL error inside guard evaluation.
        # (log.exception == log.error with exc_info=True; the previously
        # bound exception name was unused.)
        log.exception(
            "Failure in guard_fail_fn callback - raising here will cause a NULL Error on guard eval",
        )

    if last:
        guard_failures[orig_code_map[code]].append(reason)
| |
| |
def guard_error_hook(
    guard_fn: GuardFn, code: types.CodeType, f_locals: Dict[str, object], last: bool
):
    """Print a best-effort reconstruction of the guard lambda whose
    evaluation raised, to aid debugging."""
    header = (
        f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}"
    )
    print(header)
    # TODO: If we passed in the exception here, we could get a precise
    # column number of which subexpression failed. But that would also
    # require us to have the TRUE code that was eval'ed, not a shoddy
    # reconstruction (like is done here)
    signature = "lambda " + ", ".join(guard_fn.args) + ":"
    body = " and\n ".join(guard_fn.code_parts)
    print(signature)
    print(" ", body)
| |
| |
# Install the error hook at import time so failures raised while evaluating
# guards are reported through guard_error_hook.
set_guard_error_hook(guard_error_hook)
| |
| |
def unique(seq):
    """Lazily yield the items of *seq* in order, skipping duplicates."""
    emitted = set()
    for item in seq:
        if item in emitted:
            continue
        emitted.add(item)
        yield item