| import functools |
| import weakref |
| |
| import torch.nn |
| from torch.nn import Module |
| |
| from .utils import ExactWeakKeyDictionary |
| |
| |
| class MutationTracker: |
| db = ExactWeakKeyDictionary() |
| |
| def __init__(self): |
| self.mutation_count = 0 |
| self.watchers = [] |
| |
| def on_mutation(self, name): |
| self.mutation_count += 1 |
| tmp = self.watchers |
| self.watchers = [] |
| for ref in tmp: |
| guarded = ref() |
| if guarded is not None: |
| guarded.invalidate(ref) |
| |
| def track(self, guarded_code): |
| self.watchers.append(weakref.ref(guarded_code)) |
| |
| |
| def watch(obj, guarded_code): |
| """invalidate guarded_code when obj is mutated""" |
| ensure_patched(type(obj)) |
| |
| if obj not in MutationTracker.db: |
| MutationTracker.db[obj] = MutationTracker() |
| tracker = MutationTracker.db[obj] |
| tracker.track(guarded_code) |
| |
| |
| def ensure_patched(cls): |
| if getattr(cls, "___needs_mutation_patch", True): |
| cls.___needs_mutation_patch = False |
| original_setattr = cls.__setattr__ |
| |
| @functools.wraps(original_setattr) |
| def custom_setattr(self, key, value): |
| try: |
| MutationTracker.db[self].on_mutation(key) |
| except KeyError: |
| pass |
| return original_setattr(self, key, value) |
| |
| cls.__setattr__ = custom_setattr |
| |
| |
| class GenerationTracker: |
| generation = 0 |
| dynamic_classes = ExactWeakKeyDictionary() |
| generation_values = ExactWeakKeyDictionary() |
| |
| @classmethod |
| def tag(cls, obj): |
| cls.generation_values[obj] = cls.generation |
| |
| @staticmethod |
| def mark_class_dynamic(cls): |
| assert issubclass(cls, torch.nn.Module) |
| GenerationTracker.dynamic_classes[cls] = True |
| |
| @classmethod |
| def get_generation_value(cls, obj): |
| if obj not in cls.generation_values: |
| return -1 |
| return cls.generation_values[obj] |
| |
| @classmethod |
| def check(cls, obj): |
| return ( |
| obj in cls.generation_values |
| and cls.generation_values[obj] == cls.generation |
| ) |
| |
| |
| def is_dynamic_nn_module(obj): |
| """Check for nn.Modules() created dynamically or mutated""" |
| if hasattr(obj, "torchdynamo_force_dynamic"): |
| return obj.torchdynamo_force_dynamic |
| dyn = GenerationTracker.dynamic_classes.get(type(obj)) or GenerationTracker.check( |
| obj |
| ) |
| return dyn |
| |
| |
| def install_generation_tagging_init(): |
| """ |
| Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__ |
| so we can detect nn.Module instances created dynamically inside forward methods. |
| """ |
| |
| if getattr(Module, "___needs_generation_tag_patch", True): |
| init = Module.__init__ |
| |
| def patched_init(self, *args, **kwargs): |
| init(self, *args, **kwargs) |
| GenerationTracker.tag(self) |
| |
| Module.__init__ = patched_init |
| |
| setstate = Module.__setstate__ |
| |
| def patched_setstate(self, state): |
| setstate(self, state) |
| GenerationTracker.tag(self) |
| |
| Module.__setstate__ = patched_setstate |
| |
| Module.___needs_generation_tag_patch = False |
| |
| GenerationTracker.generation += 1 |