Spaces:
Running
Running
File size: 3,614 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
# mypy: disable-error-code="method-assign"
import functools
import weakref
import torch.nn
from torch.nn import Module
from .utils import ExactWeakKeyDictionary, is_lazy_module
class MutationTracker:
    """Holds weak references to guarded code that depends on one object.

    A tracker is created per watched object (see ``MutationTracker.db``).
    When the object's patched ``__setattr__`` reports a write, every still-live
    watcher is asked to ``invalidate`` itself and the watcher list is emptied.
    """

    # Class-level registry: watched object -> its MutationTracker.
    # NOTE(review): ExactWeakKeyDictionary is presumably identity-keyed and
    # weak on the key (per its name) — confirm in .utils.
    db = ExactWeakKeyDictionary()

    def __init__(self):
        # Number of on_mutation() calls seen so far.
        self.mutation_count = 0
        # weakref.ref objects pointing at guarded code registered via track().
        self.watchers = []

    def on_mutation(self, name):
        """Invalidate all live watchers; ``name`` (the attribute) is unused."""
        self.mutation_count += 1
        # Swap in a fresh list *before* invalidating, so anything registered
        # via track() during invalidation lands in the new list and is not
        # visited in this pass.
        pending, self.watchers = self.watchers, []
        for weak_watcher in pending:
            live = weak_watcher()
            if live is None:
                # The guarded code was already garbage-collected.
                continue
            live.invalidate(weak_watcher)

    def track(self, guarded_code):
        """Register ``guarded_code`` (held weakly) for invalidation on mutation."""
        self.watchers.append(weakref.ref(guarded_code))
def watch(obj, guarded_code):
    """Arrange for ``guarded_code`` to be invalidated when ``obj`` is mutated.

    Ensures ``type(obj)`` has the mutation-reporting ``__setattr__`` installed,
    creates a ``MutationTracker`` for ``obj`` on first use, and registers
    ``guarded_code`` with it.
    """
    ensure_patched(type(obj))

    registry = MutationTracker.db
    if obj not in registry:
        registry[obj] = MutationTracker()
    registry[obj].track(guarded_code)
def ensure_patched(cls):
    """Wrap ``cls.__setattr__`` once so attribute writes notify mutation trackers.

    After patching, every ``setattr`` on an instance whose object is registered
    in ``MutationTracker.db`` calls that tracker's ``on_mutation(key)`` before
    performing the original assignment. Unregistered instances just perform
    the original assignment.

    The "already patched" flag is read with ``getattr``, so it is also found on
    a base class that was patched earlier. In that case ``cls`` is left alone.
    NOTE(review): a subclass defining its own ``__setattr__`` is therefore not
    re-patched once a base has been.

    Args:
        cls: The class whose instances should report attribute mutations.
    """
    if not getattr(cls, "___needs_mutation_patch", True):
        return
    cls.___needs_mutation_patch = False
    original_setattr = cls.__setattr__

    @functools.wraps(original_setattr)
    def custom_setattr(self, key, value):
        # Most instances are never watched. Use .get() instead of raising and
        # catching KeyError on every attribute write. Keep on_mutation() outside
        # any handler so a KeyError raised while invalidating watchers surfaces
        # instead of being swallowed, which would leave stale code valid.
        tracker = MutationTracker.db.get(self)
        if tracker is not None:
            tracker.on_mutation(key)
        return original_setattr(self, key, value)

    cls.__setattr__ = custom_setattr
class GenerationTracker:
    """Class-level bookkeeping of which generation each object was tagged in.

    ``install_generation_tagging_init`` bumps ``generation``. Tagged objects
    remember the generation current at tag time. Classes can also be flagged
    as always-dynamic.
    """

    # Current generation counter; tag() stamps objects with this value.
    generation = 0
    # Classes flagged via mark_class_dynamic() -> True.
    dynamic_classes = ExactWeakKeyDictionary()
    # Tagged object -> generation value it was tagged with.
    generation_values = ExactWeakKeyDictionary()

    @classmethod
    def tag(cls, obj):
        """Stamp ``obj`` with the current generation."""
        cls.generation_values[obj] = cls.generation

    @staticmethod
    def mark_class_dynamic(cls):
        """Flag the ``torch.nn.Module`` subclass ``cls`` as always dynamic."""
        assert issubclass(cls, torch.nn.Module)
        GenerationTracker.dynamic_classes[cls] = True

    @classmethod
    def get_generation_value(cls, obj):
        """Return the generation ``obj`` was tagged with, or -1 if untagged."""
        values = cls.generation_values
        return values[obj] if obj in values else -1

    @classmethod
    def check(cls, obj):
        """Return True iff ``obj`` was tagged during the current generation."""
        values = cls.generation_values
        if obj not in values:
            return False
        return values[obj] == cls.generation
def is_dynamic_nn_module(obj):
    """Report whether ``obj`` looks like an nn.Module created dynamically or mutated.

    Rules are applied in order; the first one that matches decides:

    1. An ``nn.Module`` instance with ``forward`` in its instance ``__dict__``
       (i.e. monkey-patched) -> ``True``.
    2. An object with a ``torchdynamo_force_dynamic`` attribute -> that
       attribute's value, returned as-is.
    3. A lazy module (per ``is_lazy_module``) -> ``False``.
    4. Otherwise: the ``GenerationTracker.dynamic_classes`` entry for
       ``type(obj)`` if truthy, else ``GenerationTracker.check(obj)``.
    """
    if isinstance(obj, torch.nn.Module) and "forward" in obj.__dict__:
        # A monkey-patched `.forward` means the module was altered at runtime.
        return True

    if hasattr(obj, "torchdynamo_force_dynamic"):
        # Explicit per-object override wins over every remaining heuristic.
        return obj.torchdynamo_force_dynamic

    if is_lazy_module(obj):
        return False

    class_flag = GenerationTracker.dynamic_classes.get(type(obj))
    if class_flag:
        return class_flag
    return GenerationTracker.check(obj)
def install_generation_tagging_init():
    """
    Monkey patch torch.nn.Module.__init__ and torch.nn.Module.__setstate__
    so we can detect nn.Module instances created dynamically inside forward methods.

    The patch is installed only once, guarded by a flag stored on ``Module``.
    After the original ``__init__`` / ``__setstate__`` runs, the instance is
    tagged with ``GenerationTracker.generation``. ``__setstate__`` is what
    restores state on unpickling/copying.

    Every call, including calls after the first, advances
    ``GenerationTracker.generation``. Modules built afterwards are therefore
    distinguishable from those built before.
    """
    needs_patch = getattr(Module, "___needs_generation_tag_patch", True)
    if needs_patch:
        # Capture the originals before replacing them; the wrappers delegate.
        original_init = Module.__init__
        original_setstate = Module.__setstate__

        def patched_init(self, *args, **kwargs):
            original_init(self, *args, **kwargs)
            GenerationTracker.tag(self)

        def patched_setstate(self, state):
            original_setstate(self, state)
            GenerationTracker.tag(self)

        Module.__init__ = patched_init
        Module.__setstate__ = patched_setstate
        # Mark as patched so repeated calls do not wrap the wrappers again.
        Module.___needs_generation_tag_patch = False  # type: ignore[attr-defined]

    GenerationTracker.generation += 1
|