import inspect
from typing import Any, Dict, List, Optional, Union

import torch.nn

from . import utils, variables
from .bytecode_transformation import (
    create_call_function,
    create_call_method,
    create_instruction,
)
from .codegen import PyCodegen
from .exc import unimplemented
from .source import LocalSource, Source
from .utils import nn_module_new, object_new
from .variables.base import (
    is_side_effect_safe,
    MutableLocalBase,
    MutableLocalSource,
    VariableTracker,
)


class MutableSideEffects(MutableLocalBase):
| """ |
| VariableTracker.mutable_local marker to indicate a list passed as |
| an input that if we mutate we need to re-apply those mutations after |
| the graph runs. |
| """ |

    def __init__(self, source: Source, is_modified: bool = False):
        super().__init__(MutableLocalSource.Existing)
        self.source = source
        self.is_modified = is_modified


class AttributeMutation(MutableLocalBase):
    """
    VariableTracker.mutable_local marker to track changes to attributes
    """

    def __init__(self, typ: MutableLocalSource, source: Optional[Source]):
        super().__init__(typ)
        self.source = source


class AttributeMutationExisting(AttributeMutation):
    def __init__(self, source: Source):
        super().__init__(MutableLocalSource.Existing, source)
        self.source = source


class AttributeMutationNew(AttributeMutation):
    def __init__(self, source: Optional[Source], cls_source: Optional[Source]):
        super().__init__(MutableLocalSource.Local, source)
        self.cls_source = cls_source


class SideEffects:
    """
    Track side effects (list mutation, setattr, etc) that need to be
    applied after an FX graph is run.
    """

    id_to_variable: Dict[int, VariableTracker]
    store_attr_mutations: Dict[MutableLocalBase, Dict[str, VariableTracker]]
    keepalive: List[Any]

    def __init__(
        self,
        id_to_variable=None,
        store_attr_mutations=None,
        keepalive=None,
        save_for_backward=None,
        tensor_hooks=None,
    ):
        super().__init__()
        self.id_to_variable = id_to_variable or {}
        self.store_attr_mutations = store_attr_mutations or {}
        self.keepalive = keepalive or []
        self.save_for_backward = save_for_backward or []
        self.tensor_hooks = tensor_hooks or {}

    def __eq__(self, other: object) -> bool:
        assert isinstance(other, SideEffects)
        # NB: do NOT test keepalive
        return (
            self.id_to_variable == other.id_to_variable
            and self.store_attr_mutations == other.store_attr_mutations
            and self.save_for_backward == other.save_for_backward
            and self.tensor_hooks == other.tensor_hooks
        )

    def diff(self, other: "SideEffects") -> Optional[str]:
        if self.id_to_variable != other.id_to_variable:
            sk_itv = self.id_to_variable.keys()
            ok_itv = other.id_to_variable.keys()
            if sk_itv != ok_itv:
                return f"id_to_variable keys: {sk_itv} != {ok_itv}"
            # Feel free to augment this with more fancy diffing logic
            # if needed for debugging
            return "id_to_variable: unknown diff"
        elif self.store_attr_mutations != other.store_attr_mutations:
            sk_sam = self.store_attr_mutations.keys()
            ok_sam = other.store_attr_mutations.keys()
            if sk_sam != ok_sam:
                return f"store_attr_mutations keys: {sk_sam} != {ok_sam}"
            return "store_attr_mutations: unknown diff"
        elif self.save_for_backward != other.save_for_backward:
            return "save_for_backward"
        else:
            return None

    def clone(self):
        """Create a shallow copy"""
        return self.__class__(
            id_to_variable=dict(self.id_to_variable),
            store_attr_mutations={
                k: dict(v) for k, v in self.store_attr_mutations.items()
            },
            keepalive=list(self.keepalive),
            save_for_backward=self.save_for_backward,
            tensor_hooks=self.tensor_hooks,
        )

    def apply(self, fn, cache=None, skip_fn=lambda _: False):
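        """Map ``fn`` over every VariableTracker recorded in these side effects."""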
        if cache is None:
            cache = dict()

        self.id_to_variable = {
            k: VariableTracker.apply(fn, v, cache, skip_fn)
            for k, v in self.id_to_variable.items()
        }
        self.store_attr_mutations = {
            k: VariableTracker.apply(fn, v, cache, skip_fn)
            for k, v in self.store_attr_mutations.items()
        }
        self.save_for_backward = VariableTracker.apply(
            fn, self.save_for_backward, cache, skip_fn
        )
        self.tensor_hooks = VariableTracker.apply(fn, self.tensor_hooks, cache, skip_fn)

    def __contains__(self, item):
        return id(item) in self.id_to_variable

    def __getitem__(self, item):
        return self.id_to_variable[id(item)]

    def check_allowed_side_effect(self, item):
        from torch._dynamo.variables.misc import AutogradFunctionContextVariable

        # People do things like self.dim = dim inside autograd.Function.
        # These are benign.
        if isinstance(item, AutogradFunctionContextVariable):
            return True
        if not is_side_effect_safe(item.mutable_local):
            unimplemented(
                "HigherOrderOperator: Mutating a variable not in the current scope (SideEffects)"
            )

    def store_attr(self, item: VariableTracker, name: str, value: VariableTracker):
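        """Record a pending ``setattr(item, name, value)`` to be replayed after the graph runs."""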
        assert self.is_attribute_mutation(item)
        self.check_allowed_side_effect(item)
        if item.mutable_local not in self.store_attr_mutations:
            self.store_attr_mutations[item.mutable_local] = {}
        self.store_attr_mutations[item.mutable_local][name] = value

    def load_attr(self, item, name, deleted_ok=False):
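        """Read back the pending value recorded for ``name`` on ``item``;
        reading a deleted attribute is unsupported unless ``deleted_ok``."""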
        assert self.is_attribute_mutation(item)
        result = self.store_attr_mutations[item.mutable_local][name]
        if not deleted_ok and isinstance(result, variables.DeletedVariable):
            unimplemented("read deleted attribute")
        return result

    def store_cell(self, cellvar, value):
        assert isinstance(cellvar, variables.NewCellVariable)
        assert isinstance(value, variables.VariableTracker)
        self.store_attr(cellvar, "cell_contents", value)

    def load_cell(self, cellvar):
        assert isinstance(cellvar, variables.NewCellVariable)
        return self.load_attr(cellvar, "cell_contents")

    def load_global(self, gvar: VariableTracker, name: str):
        assert isinstance(gvar, variables.VariableTracker)
        return self.load_attr(gvar, name)

    def store_global(self, gvar: VariableTracker, name: str, value: VariableTracker):
        assert isinstance(gvar, variables.VariableTracker)
        assert isinstance(value, variables.VariableTracker)
        self.store_attr(gvar, name, value)

    @staticmethod
    def cls_supports_mutation_side_effects(cls):
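        """Attribute mutation is only tracked for classes using the stock
        ``object.__setattr__`` or ``torch.nn.Module.__setattr__``, since a
        custom ``__setattr__`` may not be reproduced faithfully by replaying a
        plain attribute store."""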
        return inspect.getattr_static(cls, "__setattr__", None) in (
            object.__setattr__,
            torch.nn.Module.__setattr__,
        )

    def is_attribute_mutation(self, item):
        return isinstance(item.mutable_local, AttributeMutation)

    def has_pending_mutation(self, item):
        return self.is_attribute_mutation(item) and bool(
            self.store_attr_mutations.get(item.mutable_local)
        )

    def is_modified(self, item):
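        """Newly created objects always count as modified; otherwise check for
        recorded attribute stores or the container's ``is_modified`` flag."""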
        if isinstance(item.mutable_local, AttributeMutationNew):
            return True
        if self.is_attribute_mutation(item):
            return item.mutable_local in self.store_attr_mutations
        return item.mutable_local.is_modified

    def _track_obj(
        self,
        item: Any,
        variable: VariableTracker,
        mutable_cls=MutableSideEffects,
    ):
        """Start tracking a new variable for mutation"""
        assert variable.source is not None
        variable.mutable_local = mutable_cls(variable.source)
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    track_mutable = _track_obj

    def track_object_existing(
        self,
        item: Any,
        variable: VariableTracker,
    ):
        return self._track_obj(item, variable, mutable_cls=AttributeMutationExisting)

    def track_object_new(
        self,
        cls_source: Source,
        user_cls: Any,
        variable_cls: Any,
        options,
    ):
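        """Create a real Python object to stand in for an instance of
        ``user_cls`` constructed inside the compiled region, and start
        tracking it as a new object."""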
        if user_cls is torch.autograd.function.FunctionCtx:
            obj = torch.autograd.Function()
        elif issubclass(user_cls, torch.nn.Module):
            obj = nn_module_new(user_cls)
        else:
            obj = object_new(user_cls)
        variable = variable_cls(
            obj,
            mutable_local=AttributeMutationNew(None, cls_source),
            **options,
        )
        self.id_to_variable[id(obj)] = variable
        self.keepalive.append(obj)
        return variable

    def track_cell_new(
        self,
    ):
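        """Track a closure cell created inside the compiled region; a fresh
        ``object()`` serves as the placeholder key in ``id_to_variable``."""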
        obj = object()
        variable = variables.NewCellVariable(
            mutable_local=AttributeMutationNew(None, None),
        )
        self.id_to_variable[id(obj)] = variable
        self.keepalive.append(obj)
        return variable

    def track_cell_existing(self, source: Source, item: Any):
        variable = variables.NewCellVariable(
            mutable_local=AttributeMutationExisting(source),
        )
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    def track_global_existing(self, source: Source, item: Any):
        variable = variables.NewGlobalVariable(
            mutable_local=AttributeMutationExisting(source),
        )
        self.id_to_variable[id(item)] = variable
        self.keepalive.append(item)
        return variable

    def track_save_for_backward(self, ctx, args):
        assert isinstance(ctx, variables.AutogradFunctionContextVariable)
        self.save_for_backward.append((ctx, args))

    def track_tensor_variables_from_runahead_side_effects(self, other):
        # In higher order ops we want to keep track of tensors seen in the
        # speculate_subgraph so that we don't lift them again as a new input in
        # other speculate_subgraph or in the root tracer.
        for other_item in other.keepalive:
            other_id = id(other_item)
            other_variable = other.id_to_variable[other_id]
            if other_id not in self.id_to_variable and isinstance(
                other_variable, variables.TensorVariable
            ):
                self.track_object_existing(other_item, other_variable)

    def prune_dead_object_new(self, tx):
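        """Stop tracking objects created during tracing that are no longer
        reachable from the stack, the symbolic locals, pre-existing tracked
        objects, or another live object's pending attribute stores."""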
        live_new_objects = set()
        skip_obj = None

        def visit(var: VariableTracker):
            if (
                isinstance(var.mutable_local, AttributeMutationNew)
                and var.mutable_local is not skip_obj
            ):
                live_new_objects.add(var.mutable_local)
            return var

        def is_live(var: Union[MutableLocalBase, VariableTracker]):
            if isinstance(var, AttributeMutationNew):
                return var in live_new_objects
            if isinstance(var, VariableTracker):
                return is_live(var.mutable_local)
            return True

        VariableTracker.apply(visit, (tx.stack, tx.symbolic_locals))
        for var in self.id_to_variable.values():
            if not isinstance(var.mutable_local, AttributeMutationNew):
                VariableTracker.apply(visit, var)

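        # NB: an object is not kept alive merely by its own pending attribute
        # stores; ``skip_obj`` makes ``visit`` ignore the object whose setattrs
        # are currently being walked.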
        for skip_obj, setattrs in self.store_attr_mutations.items():
            VariableTracker.apply(visit, setattrs)

        self.id_to_variable = {
            k: v for k, v in self.id_to_variable.items() if is_live(v)
        }
        self.store_attr_mutations = {
            k: v for k, v in self.store_attr_mutations.items() if is_live(k)
        }

    def mutation(self, var):
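        """Mark a pre-existing mutable container as modified so codegen will
        replay its new contents onto the original object."""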
        self.check_allowed_side_effect(var)
        if isinstance(var.mutable_local, MutableSideEffects):
            var.mutable_local = MutableSideEffects(var.mutable_local.source, True)

    def _get_modified_vars(self):
        return [var for var in self.id_to_variable.values() if self.is_modified(var)]

    def codegen_save_tempvars(self, cg: PyCodegen):
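        """Emit bytecode that materializes newly created cells/objects into
        temporaries so that later attribute stores and reconstruction can refer
        to them through a LocalSource."""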
        for var in self._get_modified_vars():
            if isinstance(
                var.mutable_local, (AttributeMutationExisting, AttributeMutationNew)
            ) and isinstance(var, variables.NewCellVariable):
                cg.load_import_from(utils.__name__, "make_cell")
                cg.extend_output(create_call_function(0, True))
                cg.add_cache(var)
                if isinstance(var.mutable_local, AttributeMutationNew):
                    var.mutable_local.source = LocalSource(cg.tempvars[var])  # type: ignore[attr-defined]
            elif isinstance(var.mutable_local, AttributeMutationNew):
                if isinstance(var, variables.AutogradFunctionContextVariable):
                    unimplemented("AutogradFunctionContextVariable escaped")
                if "__call_nn_module_init" in self.store_attr_mutations.get(
                    var.mutable_local, {}
                ):
                    assert isinstance(var, variables.UnspecializedNNModuleVariable)
                    cg.load_import_from(utils.__name__, "nn_module_new")
                else:
                    cg.load_import_from(utils.__name__, "object_new")
                cg(var.mutable_local.cls_source)
                cg.extend_output(create_call_function(1, True))
                cg.add_cache(var)
                var.mutable_local.source = LocalSource(cg.tempvars[var])
            elif var in cg.tempvars:
                assert cg.tempvars.get(var) is None
                # subsequent usage should point to the original variable
                cg(var.mutable_local.source)
                cg.add_cache(var)

        for ctx, args in self.save_for_backward:
            cg(ctx.source)
            cg.extend_output(
                [create_instruction("LOAD_METHOD", argval="save_for_backward")]
            )
            for arg in args:
                cg(arg)
            cg.extend_output(
                [
                    *create_call_method(len(args)),
                    create_instruction("POP_TOP"),
                ]
            )

    def register_hook(self, tensor, hook, handle, name):
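        """Record a pending ``tensor.register_hook``-style call under a fresh
        index; the index is stored on ``handle`` so ``remove_hook`` can later
        drop it."""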
        assert isinstance(tensor, variables.TensorVariable)
        assert isinstance(hook, variables.VariableTracker)
        assert (
            isinstance(handle, variables.RemovableHandleVariable)
            and handle.mutable_local
        )
        assert hasattr(torch.Tensor, name)
        idx = len(self.tensor_hooks.keys())
        # duplicate index possible because of self.remove_hook()
        while idx in self.tensor_hooks:
            idx += 1
        self.tensor_hooks[idx] = (tensor, hook, handle, name)
        assert not handle.idx
        handle.idx = idx

    def remove_hook(self, idx):
        del self.tensor_hooks[idx]

    def codegen_hooks(self, cg):
        for (
            tensor,
            hook,
            handle,
            name,
        ) in self.tensor_hooks.values():
            # Note: [On tensor.register_hook]
            #
            # Hooks registered via register_hook on a tensor (AKA backward hooks) are implemented
            # slightly differently depending on whether the tensor has a source (inputs, params)
            # or does not (intermediaries).
            #
            # For tensors with a source, we bypass direct inclusion of register_hook calls in the graph.
            # Instead, these are tracked and stashed as a global variable, enabling their association with
            # tensors in the residuals. During dynamo's frame creation, these hooks are invoked seamlessly
            # on known reconstructible/fetch-able tensors. Because a source indicates knowledge of this
            # object outside the torch compile region, and because the residuals run strictly before
            # .backward() can be called, it is sound to invoke `register_hook` on a known tensor.
            #
            # For tensors without a source, we support a limited subset of hooks: global functions only,
            # and compiled_autograd must be enabled or we will graph break.
            #
            # Handling the Handle: When a user retains the register_hook result in a handle, we intercept the
            # STORE_FAST operation to record the user-designated local variable name. This ensures the reconstructed
            # bytecode retains this name. If no handle is defined, we simply pop the generated value to keep the
            # stack intact.
            #
            # Dynamo Tensor Hooks Workflow:
            # - Functions passed to register_hook are lifted globally.
            # - For tensors with sources:
            #   - In the "side_effects" phase of codegen, we iterate over tensors with hooks to:
            #     - Generate the tensor.
            #     - Issue a register_hook call on the tensor, linking to the globally stored function.
            #     - Incorporate a handle if one was established in the eager phase.
            # - For tensors without sources:
            #   - We don't generate any instructions for registering a hook.
            #   - Handles from intermediary hooks are NYI.
            #   - We produce a call function that uses the trace_wrapped higher order op, closing over it.
            #   - We then manually insert the call function above into the graph.
            #   - The handle's exact user-specified name, "user_code_variable_name", is discerned and
            #     associated during STORE_FAST.
            assert tensor.source, "Hooks on non input tensors NYI - should not get here"
            cg(tensor)
            cg.extend_output([cg.create_load_attr(name)])
            cg(hook)
            cg.extend_output(create_call_function(1, True))

            # Adding the handle to the cache means RemovableHandleVariable().reconstruct() will
            # be associated with the return value of register_hook(). This consumes the top of stack.
            cg.add_cache(handle)

    def codegen_update_mutated(self, cg: PyCodegen):
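        """Emit bytecode that replays tracked mutations onto the original
        objects: ``old[:] = new`` for lists, ``clear()`` + ``update(new)`` for
        dicts, attribute stores/deletes, global stores, and advancing tuple
        iterators.  The mutating instructions themselves are emitted last (the
        ``suffixes``) to handle dependencies between mutated values."""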
        suffixes = []
        for var in self._get_modified_vars():
            if isinstance(var, variables.ListVariable):
                # old[:] = new
                cg(var, allow_cache=False)
                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output(
                    [
                        cg.create_load_const(None),
                        cg.create_load_const(None),
                        create_instruction("BUILD_SLICE", arg=2),
                    ]
                )
                suffixes.append([create_instruction("STORE_SUBSCR")])
            elif isinstance(var, variables.ConstDictVariable):
                cg.tx.output.update_co_names("clear")
                cg.tx.output.update_co_names("update")

                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output([create_instruction("LOAD_METHOD", argval="update")])
                cg(var, allow_cache=False)

                cg(var.mutable_local.source)  # type: ignore[attr-defined]
                cg.extend_output([create_instruction("LOAD_METHOD", argval="clear")])

                suffixes.append(
                    [
                        *create_call_method(0),  # clear
                        create_instruction("POP_TOP"),
                        *create_call_method(1),  # update
                        create_instruction("POP_TOP"),
                    ]
                )
            elif self.is_attribute_mutation(var):
                for name, value in self.store_attr_mutations.get(
                    var.mutable_local, {}
                ).items():
                    if isinstance(var, variables.NewGlobalVariable):
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        suffixes.append(
                            [create_instruction("STORE_GLOBAL", argval=name)]
                        )
                    elif name == "__call_nn_module_init":
                        pass  # handled in codegen_save_tempvars
                    elif isinstance(value, variables.DeletedVariable):
                        if isinstance(
                            var.mutable_local, AttributeMutationExisting
                        ) and hasattr(getattr(var, "value", None), name):
                            cg.tx.output.update_co_names(name)
                            cg(var.mutable_local.source)
                            suffixes.append(
                                [create_instruction("DELETE_ATTR", argval=name)]
                            )
                    else:
                        cg.tx.output.update_co_names(name)
                        cg(value)
                        cg(var.mutable_local.source)
                        suffixes.append([create_instruction("STORE_ATTR", argval=name)])
            elif isinstance(var, variables.TupleIteratorVariable):
                for _ in range(var.index):
                    cg.load_import_from(utils.__name__, "iter_next")
                    cg(var.mutable_local.source)  # type: ignore[attr-defined]
                    cg.extend_output(create_call_function(1, True))
                    cg.append_output(create_instruction("POP_TOP"))
            else:
                raise AssertionError(type(var))

        # do all the actual mutations at the very end to handle dependencies
        for suffix in reversed(suffixes):
            cg.extend_output(suffix)

    def is_empty(self):
        return not (
            any(map(self.is_modified, self.id_to_variable.values()))
            or self.tensor_hooks
            or self.save_for_backward
        )

    def clear(self):
        self.keepalive.clear()
        self.id_to_variable.clear()