|
|
|
|
|
import collections |
|
import dataclasses |
|
import functools |
|
import inspect |
|
import sys |
|
from typing import Dict, List, Optional |
|
|
|
from torch._subclasses.fake_tensor import is_fake |
|
|
|
from .. import polyfill, variables |
|
from ..bytecode_transformation import ( |
|
create_call_function, |
|
create_call_method, |
|
create_instruction, |
|
create_load_method, |
|
) |
|
from ..eval_frame import skip_code |
|
from ..exc import unimplemented |
|
from ..guards import GuardBuilder, install_guard |
|
from ..source import AttrSource, GetItemSource |
|
from ..utils import dict_keys, dict_values, istype, specialize_symnode |
|
from .base import MutableLocal, VariableTracker |
|
from .constant import ConstantVariable |
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_hashable(x):
    """Report whether the tracker ``x`` may be used as a traced dict/set key.

    - TensorVariable: only when its FX node has an ``"example_value"`` in
      ``meta``, because that example value is what gets hashed.
    - TupleVariable: only when every element is itself hashable.
    - Anything else: only when it is an instance of one of the allow-listed
      VariableTracker types below.
    """
    if isinstance(x, variables.TensorVariable):
        example_value = x.as_proxy().node.meta.get("example_value")
        return example_value is not None

    if isinstance(x, variables.TupleVariable):
        return all(map(is_hashable, x.items))

    hashable_tracker_types = (
        variables.BuiltinVariable,
        variables.SymNodeVariable,
        variables.ConstantVariable,
        variables.EnumVariable,
        variables.user_defined.UserDefinedClassVariable,
        variables.UserFunctionVariable,
        variables.SkipFunctionVariable,
        variables.misc.NumpyVariable,
        variables.NNModuleVariable,
        variables.UnspecializedNNModuleVariable,
        variables.MethodWrapperVariable,
        variables.TorchInGraphFunctionVariable,
        variables.TypingVariable,
        variables.FunctoolsPartialVariable,
    )
    return isinstance(x, hashable_tracker_types)
|
|
|
|
|
class ConstDictVariable(VariableTracker):
    """Tracks a ``dict`` (or ``collections.OrderedDict``) whose keys are known at trace time.

    ``self.items`` maps :class:`ConstDictVariable._HashableTracker` wrappers
    (around key trackers) to value trackers. ``user_cls`` is the Python class
    reported by ``python_type`` and rebuilt by ``reconstruct``.
    """

    # ``user_cls`` holds a Python class, not a VariableTracker.
    _nonvar_fields = {
        "user_cls",
        *VariableTracker._nonvar_fields,
    }

    class _HashableTracker:
        """
        Auxiliary opaque internal class that wraps a VariableTracker and makes it hashable

        This should not be seen or touched by anything outside of ConstDictVariable and its children

        Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing
        """

        def __init__(self, vt):
            # SymNode keys are passed through ``specialize_symnode`` before the
            # hashability check (presumably turning them into constants).
            vt = specialize_symnode(vt)
            if not is_hashable(vt):
                unimplemented(f"Dict key of type {type(vt)}. Key: {vt}")
            self.vt = vt

        @property
        def underlying_value(self):
            """The Python object used for hashing and equality of this key."""
            if isinstance(self.vt, variables.TensorVariable):
                # Hash the fake/example tensor recorded on the FX node.
                x = self.vt.as_proxy().node.meta["example_value"]
            elif isinstance(self.vt, variables.TupleVariable):
                Hashable = ConstDictVariable._HashableTracker
                x = tuple(Hashable(e).underlying_value for e in self.vt.items)
            elif isinstance(self.vt, variables.NNModuleVariable):
                return self.vt.module
            elif isinstance(self.vt, variables.UnspecializedNNModuleVariable):
                return self.vt.value
            elif isinstance(self.vt, variables.UserFunctionVariable):
                return self.vt.get_function()
            else:
                x = self.vt.as_python_constant()
            return x

        def __hash__(self):
            return hash(self.underlying_value)

        @staticmethod
        def _eq_impl(a, b):
            """Compare two underlying values.

            Types must match exactly; tuples are compared element-wise
            recursively; fake tensors compare by identity (``==`` on tensors
            would be elementwise); everything else uses ``==``.
            """
            if type(a) != type(b):
                return False
            elif isinstance(a, tuple):
                Hashable = ConstDictVariable._HashableTracker
                return len(a) == len(b) and all(
                    Hashable._eq_impl(u, v) for u, v in zip(a, b)
                )
            elif is_fake(a):
                return a is b
            else:
                return a == b

        def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool:
            # ``other`` may also be a raw Python literal (e.g. a str attribute
            # name looked up directly in ``self.items``).
            Hashable = ConstDictVariable._HashableTracker
            assert isinstance(other, Hashable) or ConstantVariable.is_literal(
                other
            ), type(other)
            if isinstance(other, Hashable):
                return Hashable._eq_impl(self.underlying_value, other.underlying_value)

            return Hashable._eq_impl(self.underlying_value, other)

    def __init__(
        self, items: Dict[VariableTracker, VariableTracker], user_cls=dict, **kwargs
    ):
        """Build from ``items``, whose keys may be raw trackers or already-wrapped keys.

        A fresh dict is always created, so the caller's mapping is not aliased.
        """
        super().__init__(**kwargs)

        Hashable = ConstDictVariable._HashableTracker

        assert all(
            isinstance(x, (VariableTracker, Hashable))
            and isinstance(v, VariableTracker)
            for x, v in items.items()
        )

        def make_hashable(key):
            return key if isinstance(key, Hashable) else Hashable(key)

        self.items = {make_hashable(x): v for x, v in items.items()}
        self.user_cls = user_cls

    def as_proxy(self):
        return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()}

    def debug_repr(self):
        return (
            "{"
            + ", ".join(
                f"{k.vt.debug_repr()}: {v.debug_repr()}" for k, v in self.items.items()
            )
            + "}"
        )

    def as_python_constant(self):
        return {
            k.vt.as_python_constant(): v.as_python_constant()
            for k, v in self.items.items()
        }

    def keys_as_python_constant(self):
        """Map each key's Python constant to its (still traced) value tracker."""
        return {k.vt.as_python_constant(): v for k, v in self.items.items()}

    def python_type(self):
        return self.user_cls

    def __contains__(self, vt):
        """Trace-time membership test; unhashable trackers are reported as absent."""
        assert isinstance(vt, VariableTracker)
        Hashable = ConstDictVariable._HashableTracker
        return is_hashable(vt) and Hashable(vt) in self.items

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds the dict (wrapped in OrderedDict if needed)."""
        # For OrderedDict, load ``collections.OrderedDict`` first so the
        # BUILD_MAP result can be passed to it as the single argument.
        if self.user_cls is collections.OrderedDict:
            codegen.extend_output(
                [
                    codegen.create_load_python_module(collections, True),
                    codegen.create_load_attr("OrderedDict"),
                ]
            )
        # Push key/value pairs in insertion order.
        for key, value in self.items.items():
            codegen(key.vt)
            codegen(value)
        if self.user_cls is collections.OrderedDict:
            codegen.extend_output(
                [
                    create_instruction("BUILD_MAP", arg=len(self.items)),
                    *create_call_function(1, False),
                ]
            )
        else:
            codegen.append_output(create_instruction("BUILD_MAP", arg=len(self.items)))

    def getitem_const(self, arg: VariableTracker):
        """Return the value tracker stored under ``arg``; graph-break if absent."""
        key = ConstDictVariable._HashableTracker(arg)
        if key not in self.items:
            # Not every key tracker has a ``.value`` attribute (e.g. tensors,
            # SymNodes, functions), so use ``debug_repr`` for the message.
            unimplemented(f"dict KeyError: {arg.debug_repr()}")
        return self.items[key]

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Trace a dict method call ``self.<name>(*args, **kwargs)``.

        Mutating methods only apply when ``self.mutable_local`` is set and
        record the mutation with ``tx.output.side_effects``. Anything not
        handled here is delegated to ``VariableTracker.call_method``.
        """
        from . import (
            BuiltinVariable,
            ConstantVariable,
            ListIteratorVariable,
            ListVariable,
            TupleVariable,
        )

        Hashable = ConstDictVariable._HashableTracker

        arg_hashable = args and is_hashable(args[0])

        if name == "__getitem__":
            assert len(args) == 1
            return self.getitem_const(args[0])
        elif name == "items":
            assert not (args or kwargs)
            # The result depends on iteration order, so guard on key order.
            if self.source:
                tx.output.guard_on_key_order.add(self.source.name())
            return TupleVariable(
                [TupleVariable([k.vt, v]) for k, v in self.items.items()]
            )
        elif name == "keys":
            if self.source:
                tx.output.guard_on_key_order.add(self.source.name())
            assert not (args or kwargs)
            return DictKeys(self)
        elif name == "values":
            if self.source:
                tx.output.guard_on_key_order.add(self.source.name())
            assert not (args or kwargs)
            return DictValues(self)
        elif name == "copy":
            assert not (args or kwargs)
            # Shallow copy: new container, same key/value trackers.
            return self.clone(items=self.items.copy(), mutable_local=MutableLocal())
        elif name == "__len__":
            assert not (args or kwargs)
            return ConstantVariable.create(len(self.items))
        elif name == "__setitem__" and arg_hashable and self.mutable_local:
            assert not kwargs and len(args) == 2
            tx.output.side_effects.mutation(self)
            self.items[Hashable(args[0])] = args[1]
            return ConstantVariable.create(None)
        elif name == "__delitem__" and arg_hashable and self.mutable_local:
            tx.output.side_effects.mutation(self)
            self.items.__delitem__(Hashable(args[0]))
            return ConstantVariable.create(None)
        elif name in ("pop", "get") and len(args) in (1, 2) and args[0] not in self:
            # Missing key.
            if len(args) == 2:
                # Both ``get`` and ``pop`` return the supplied default.
                return args[1]
            if name == "pop":
                # ``dict.pop(missing_key)`` raises KeyError in Python; returning
                # None would silently diverge from eager behavior, so
                # graph-break the same way ``getitem_const`` does.
                unimplemented(f"dict KeyError: {args[0].debug_repr()}")
            return ConstantVariable.create(None)
        elif name == "pop" and arg_hashable and self.mutable_local:
            tx.output.side_effects.mutation(self)
            return self.items.pop(Hashable(args[0]))
        elif name == "clear":
            tx.output.side_effects.mutation(self)
            self.items.clear()
            return ConstantVariable.create(None)
        elif (
            name == "update"
            and len(args) == 1
            and isinstance(
                args[0],
                (
                    ConstDictVariable,
                    ListVariable,
                    TupleVariable,
                    ListIteratorVariable,
                ),
            )
            and self.mutable_local
        ):
            tx.output.side_effects.mutation(self)
            # Non-dict iterables of pairs are first converted to a dict tracker.
            if isinstance(args[0], ConstDictVariable):
                dict_vt = args[0]
            else:
                dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0])
            self.items.update(dict_vt.items)
            # Keyword arguments become string-keyed entries, applied last.
            kwargs = {
                Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items()
            }
            self.items.update(kwargs)
            return ConstantVariable.create(None)
        elif name in ("get", "__getattr__") and args and args[0] in self:
            return self.getitem_const(args[0])
        elif name == "__contains__" and len(args) == 1:
            return ConstantVariable.create(args[0] in self)
        else:
            return super().call_method(tx, name, args, kwargs)

    def unpack_var_sequence(self, tx):
        """Iterating a dict yields its key trackers in insertion order."""
        return [x.vt for x in self.items.keys()]
|
|
|
|
|
class DefaultDictVariable(ConstDictVariable):
    """Tracks a ``collections.defaultdict``.

    ``default_factory`` is the tracker for the factory callable, or None when
    the defaultdict has no factory. A missing key read through
    ``__getitem__`` is filled by calling the factory with no arguments.
    """

    def __init__(self, items, user_cls, default_factory=None, **kwargs):
        super().__init__(items, user_cls, **kwargs)
        assert user_cls is collections.defaultdict
        self.default_factory = default_factory

    def is_python_constant(self):
        # Report "not a constant" for an empty defaultdict with an unsupported
        # factory (presumably so callers don't treat it like a plain constant).
        # NOTE(review): ``default_factory`` is a VariableTracker (see
        # ``is_supported_arg`` and ``call_function`` below), so comparing it
        # with the raw ``list``/``tuple``/``dict`` types likely never matches;
        # kept as-is to preserve existing behavior — confirm intent.
        if self.default_factory not in [list, tuple, dict] and not self.items:
            return False
        return super().is_python_constant()

    def debug_repr(self):
        # ``default_factory`` may be None (a defaultdict with no factory), in
        # which case there is no tracker to call ``debug_repr`` on.
        if self.default_factory is None:
            factory_repr = "None"
        else:
            factory_repr = self.default_factory.debug_repr()
        return f"defaultdict({factory_repr}, {super().debug_repr()})"

    @staticmethod
    def is_supported_arg(arg):
        """Return True if ``arg`` is an accepted ``default_factory`` tracker.

        Accepted factories are the builtins ``list``/``tuple``/``dict`` or any
        user-defined function tracker.
        """
        if isinstance(arg, variables.BuiltinVariable):
            return arg.fn in [list, tuple, dict]
        else:
            return isinstance(arg, variables.functions.BaseUserFunctionVariable)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Add ``defaultdict.__missing__`` semantics to ``__getitem__``.

        A missing key gets ``default_factory()`` (traced), stored via
        ``__setitem__`` and returned. Without a factory a ``KeyError`` is
        raised. All other methods defer to ConstDictVariable.
        """
        if name == "__getitem__":
            assert len(args) == 1

            if args[0] in self:
                return self.getitem_const(args[0])
            else:
                if self.default_factory is None:
                    raise KeyError(f"{args[0]}")
                else:
                    default_var = self.default_factory.call_function(tx, [], {})
                    super().call_method(
                        tx, "__setitem__", (args[0], default_var), kwargs
                    )
                    return default_var
        else:
            return super().call_method(tx, name, args, kwargs)
|
|
|
|
|
class SetVariable(ConstDictVariable):
    """We model a set as a dictionary with None values"""

    def __init__(
        self,
        items: List[VariableTracker],
        **kwargs,
    ):
        items = dict.fromkeys(items, SetVariable._default_value())
        super().__init__(items, **kwargs)

    def debug_repr(self):
        if not self.items:
            return "set()"
        else:
            return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}"

    @property
    def set_items(self):
        """The wrapped (hashable) elements as a fresh Python ``set``."""
        return set(self.items.keys())

    @staticmethod
    def _default_value():
        # Every element maps to a None constant in the underlying dict.
        return ConstantVariable.create(None)

    def as_proxy(self):
        return {k.vt.as_proxy() for k in self.set_items}

    def python_type(self):
        return set

    def as_python_constant(self):
        return {k.vt.as_python_constant() for k in self.set_items}

    def reconstruct(self, codegen):
        # Build the element list once and use it for both the pushes and the
        # BUILD_SET count.
        elements = [x.vt for x in self.set_items]
        codegen.foreach(elements)
        codegen.append_output(create_instruction("BUILD_SET", arg=len(elements)))

    def call_method(
        self,
        tx,
        name,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ) -> "VariableTracker":
        """Trace a set method by translating it onto the dict machinery.

        ``add`` becomes ``__setitem__`` with a None value. ``pop`` removes an
        arbitrary element. ``isdisjoint`` is inlined from a polyfill.
        ``update`` converts list/tuple arguments to sets first.
        """
        from . import ListVariable, TupleVariable

        if name == "add":
            assert not kwargs
            assert len(args) == 1
            name = "__setitem__"
            args = (args[0], SetVariable._default_value())
        elif name == "pop":
            assert not kwargs
            assert not args
            if not self.items:
                # set().pop() raises KeyError in Python; graph-break instead of
                # letting a raw KeyError escape from the tracer.
                unimplemented("pop from an empty set")
            # Pick an arbitrary element, then remove it through the dict
            # ``pop`` path so the mutation is recorded.
            result = self.set_items.pop().vt
            super().call_method(tx, name, (result,), kwargs)
            return result
        elif name == "isdisjoint":
            assert not kwargs
            assert len(args) == 1
            return variables.UserFunctionVariable(
                polyfill.set_isdisjoint
            ).call_function(tx, [self, args[0]], {})
        elif (
            name == "update"
            and len(args) == 1
            and isinstance(
                args[0],
                (
                    SetVariable,
                    ListVariable,
                    TupleVariable,
                ),
            )
            and self.mutable_local
        ):
            if isinstance(args[0], (ListVariable, TupleVariable)):
                arg = SetVariable(args[0].unpack_var_sequence(tx))
            else:
                arg = args[0]
            return super().call_method(tx, "update", (arg,), kwargs)
        return super().call_method(tx, name, args, kwargs)

    def getitem_const(self, arg: VariableTracker):
        raise RuntimeError("Illegal to getitem on a set")
|
|
|
|
|
class DictView(VariableTracker):
    """
    Models _PyDictViewObject, the object returned by ``dict.keys()`` /
    ``dict.values()``.

    Abstract base: subclasses set ``kv`` to the name of the dict method that
    produced the view and override ``view_items_vt``.
    """

    # Name of the dict method ("keys" or "values") this view mirrors.
    kv: Optional[str] = None

    def __init__(self, dv_dict: ConstDictVariable, **kwargs):
        super().__init__(**kwargs)
        assert self.kv in ("keys", "values")
        assert isinstance(dv_dict, ConstDictVariable)
        self.dv_dict = dv_dict

    @property
    def view_items(self):
        """Live view over the backing dict's ``self.items`` (keys or values)."""
        view_method = getattr(self.dv_dict.items, self.kv)
        return view_method()

    @property
    def view_items_vt(self):
        # Subclasses return the view's contents as a list of VariableTrackers.
        raise NotImplementedError

    def unpack_var_sequence(self, tx):
        # Keys are stored as _HashableTracker wrappers; values are plain trackers.
        if self.kv == "keys":
            return [wrapped.vt for wrapped in self.view_items]
        return list(self.view_items)

    def reconstruct(self, codegen):
        # Rebuild the dict, then call ``.keys()`` / ``.values()`` on it.
        codegen(self.dv_dict)
        load_and_call = [
            create_load_method(self.kv),
            *create_call_method(0),
        ]
        codegen.extend_output(load_and_call)

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        # len(view) == len(dict); everything else goes to the generic handler.
        if name != "__len__":
            return super().call_method(tx, name, args, kwargs)
        return self.dv_dict.call_method(tx, name, args, kwargs)
|
|
|
|
|
class DictKeys(DictView):
    """Tracker for the result of ``dict.keys()``."""

    kv = "keys"

    @property
    def set_items(self):
        """The view's wrapped keys as a Python ``set``."""
        return {wrapped_key for wrapped_key in self.view_items}

    @property
    def view_items_vt(self):
        # Unwrap the _HashableTracker keys back into VariableTrackers.
        unwrapped = []
        for wrapped_key in self.view_items:
            unwrapped.append(wrapped_key.vt)
        return unwrapped

    def python_type(self):
        return dict_keys

    def call_method(
        self,
        tx,
        name,
        args: List["VariableTracker"],
        kwargs: Dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        # ``key in d.keys()`` is the same as ``key in d``.
        if name != "__contains__":
            return super().call_method(tx, name, args, kwargs)
        return self.dv_dict.call_method(tx, name, args, kwargs)
|
|
|
|
|
class DictValues(DictView):
    """Tracker for the result of ``dict.values()``."""

    kv = "values"

    @property
    def view_items_vt(self):
        # Values are stored unwrapped, so a plain list copy suffices.
        return [value_vt for value_vt in self.view_items]

    def python_type(self):
        return dict_values
|
|
|
|
|
def _is_matching_transformers_cls(cls) -> bool: |
|
mod = sys.modules.get("transformers.file_utils") |
|
return mod is not None and issubclass(cls, mod.ModelOutput) |
|
|
|
|
|
def _is_matching_diffusers_cls(cls) -> bool: |
|
mod = sys.modules.get("diffusers.utils") |
|
return mod is not None and issubclass(cls, mod.BaseOutput) |
|
|
|
|
|
def _call_hasattr_customobj(self, tx, name: str) -> "VariableTracker":
    """Trace ``hasattr(obj, name)`` for dict-backed objects whose items double as attributes.

    Shared by DataClassVariable and CustomizedDictVariable. Resolution order:

    1. ``name`` is a stored item or an attribute of ``user_cls``: True.
    2. The tracker has a plain ``MutableLocal`` and no source: False.
    3. The tracker has no ``mutable_local`` but has a source: evaluate
       ``hasattr`` on the recorded example value and install a HASATTR guard.

    Anything else (including a missing example value) graph-breaks.
    """
    if name in self.items or hasattr(self.user_cls, name):
        return ConstantVariable(True)

    created_without_source = (
        istype(self.mutable_local, MutableLocal) and self.source is None
    )
    if created_without_source:
        return ConstantVariable(False)

    if self.mutable_local is None and self.source:
        try:
            example = tx.output.root_tx.get_example_value(self.source)
            install_guard(
                AttrSource(self.source, name).make_guard(GuardBuilder.HASATTR)
            )
            return ConstantVariable(hasattr(example, name))
        except KeyError:
            # No example value is available for this source; graph-break below.
            pass

    unimplemented(
        f"hasattr({self.__class__.__name__}, {name}) {self.mutable_local} {self.source}"
    )
|
|
|
|
|
class DataClassVariable(ConstDictVariable):
    """
    This is a bit of a hack to deal with
    transformers.file_utils.ModelOutput() from huggingface.

    ModelOutput causes trouble because it is a mix of a dataclass and an
    OrderedDict and it calls super() methods implemented in C.
    """

    # When False, fields whose value is None are left out of ``items`` and
    # ``var_getattr`` falls back to the dataclass field default instead.
    include_none = False

    @staticmethod
    @functools.lru_cache(None)
    def _patch_once():
        """Mark every callable in ModelOutput / BaseOutput ``__dict__`` with ``skip_code``.

        Cached, so this runs at most once. Missing libraries are ignored.
        """
        try:
            from transformers.file_utils import ModelOutput

            for obj in ModelOutput.__dict__.values():
                if callable(obj):
                    skip_code(obj.__code__)
        except ImportError:
            pass

        try:
            from diffusers.utils import BaseOutput

            for obj in BaseOutput.__dict__.values():
                if callable(obj):
                    skip_code(obj.__code__)
        except ImportError:
            pass

    @staticmethod
    def is_matching_cls(cls):
        return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)

    @classmethod
    def is_matching_object(cls, obj):
        return cls.is_matching_cls(type(obj))

    @classmethod
    def create(cls, user_cls, args, kwargs, options):
        """Trace ``user_cls(*args, **kwargs)`` by binding arguments to dataclass fields."""
        DataClassVariable._patch_once()

        skip_code(user_cls.__init__.__code__)
        keys = [f.name for f in dataclasses.fields(user_cls)]
        bound = inspect.signature(user_cls).bind(*args, **kwargs)
        bound.apply_defaults()
        assert set(bound.arguments.keys()) == set(keys)
        items = {}
        for field_name in keys:
            val = bound.arguments[field_name]
            key_vt = ConstantVariable.create(field_name)
            if isinstance(val, VariableTracker):
                items[key_vt] = val
            else:
                if cls.include_none:
                    assert variables.ConstantVariable.is_literal(val)
                    items[key_vt] = variables.ConstantVariable.create(val)
                else:
                    assert val is None, f"unexpected {val}"

        if len(items) == 1:
            # ``items`` is keyed by ConstantVariable trackers, not field-name
            # strings, so ``items[keys[0]]`` could never succeed. Fetch the
            # single entry directly and check it is the first field.
            only_key, only_val = next(iter(items.items()))
            if only_key.as_python_constant() == keys[0] and not isinstance(
                only_val, variables.TensorVariable
            ):
                # Only the first field is set and it is not a tensor: this
                # presumably hits ModelOutput's "construct from an iterator"
                # path, which is not modeled here.
                unimplemented("DataClassVariable iterator constructor")

        return cls(items, user_cls, **options)

    @classmethod
    def wrap(cls, builder, obj):
        """Wrap an existing instance, building a tracker for each present field."""
        user_cls = type(obj)
        keys = [f.name for f in dataclasses.fields(user_cls)]

        items = {}
        for key in keys:
            # Fields that are absent on the instance are skipped entirely.
            if hasattr(obj, key):
                val = getattr(obj, key)
                # The tracker is built even for None-valued fields that are
                # then left out of ``items``; kept as-is in case wrapping has
                # side effects (e.g. guards) — NOTE(review): confirm.
                var = builder.__class__(
                    tx=builder.tx, source=AttrSource(builder.source, key)
                )(val)
                if val is not None or cls.include_none:
                    items[ConstantVariable.create(key)] = var
        return cls(items, user_cls)

    def __init__(self, items, user_cls, **options):
        super().__init__(items, user_cls, **options)
        assert self.is_matching_cls(user_cls)

    def as_proxy(self):
        raise NotImplementedError

    def reconstruct(self, codegen):
        # Emit ``user_cls(<field>=<value>, ...)`` as a keyword call.
        codegen.extend_output([codegen._create_load_const(self.user_cls)])
        d = self.keys_as_python_constant()
        codegen.foreach(d.values())
        keys = tuple(d.keys())
        codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True))

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Support ModelOutput-style indexing by field name or by position.

        String subscripts look up the field. Any other subscript indexes into
        ``to_tuple()``. ``__setattr__`` is treated as ``__setitem__``.
        """
        if name == "__getitem__":
            assert not kwargs and len(args) == 1
            val = args[0]
            if val.python_type() == str:
                return self.getitem_const(val)
            else:
                return self.call_method(tx, "to_tuple", [], {}).call_method(
                    tx, "__getitem__", args, kwargs
                )
        elif name == "to_tuple":
            assert not (args or kwargs)
            return variables.TupleVariable(list(self.items.values()))
        elif name == "__setattr__":
            name = "__setitem__"
        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Attribute access reads the field item, then the dataclass default."""
        name_vt = ConstantVariable.create(name)
        if name_vt in self:
            return self.call_method(tx, "__getitem__", [name_vt], {})
        elif not self.include_none:
            # None-valued fields were dropped from ``items``; report the default.
            defaults = {f.name: f.default for f in dataclasses.fields(self.user_cls)}
            if name in defaults:
                assert variables.ConstantVariable.is_literal(defaults[name])
                return variables.ConstantVariable.create(defaults[name])
        # Previously the result was discarded, so attribute reads returned None.
        return super().var_getattr(tx, name)

    call_hasattr = _call_hasattr_customobj
|
|
|
|
|
class CustomizedDictVariable(ConstDictVariable):
    """Tracks instances of dict subclasses that can be modeled as a ConstDictVariable.

    Matches two cases:

    - ``OrderedDict`` subclasses that keep the stock ``__init__`` and define
      no ``__post_init__``.
    - HuggingFace ``ModelOutput`` / diffusers ``BaseOutput`` subclasses.
    """

    @staticmethod
    def is_matching_cls(cls):
        # An OrderedDict subclass that neither overrides __init__ nor defines
        # __post_init__ is constructed exactly like a plain OrderedDict.
        if (
            issubclass(cls, collections.OrderedDict)
            and cls.__init__ is collections.OrderedDict.__init__
            and not hasattr(cls, "__post_init__")
        ):
            return True
        return _is_matching_transformers_cls(cls) or _is_matching_diffusers_cls(cls)

    @classmethod
    def is_matching_object(cls, obj):
        return cls.is_matching_cls(type(obj))

    @classmethod
    def create(cls, user_cls, args, kwargs, options):
        """Trace ``user_cls(*args, **kwargs)``.

        Supported forms:

        - Dataclass construction.
        - Keyword-only construction.
        - Copy-construction from another traced dict.

        Anything else graph-breaks.
        """
        # Mark these methods with ``skip_code`` (presumably so Dynamo does not
        # trace into them; their effect is modeled by this tracker).
        for attr_name in ("__init__", "__post_init__", "__setattr__", "__setitem__"):
            if hasattr(user_cls, attr_name):
                fn = getattr(user_cls, attr_name)
                assert callable(fn), f"expect callable attr {attr_name}"
                if hasattr(fn, "__code__"):
                    skip_code(fn.__code__)

        if dataclasses.is_dataclass(user_cls):
            # Bind constructor arguments to field names and fill in defaults.
            bound = inspect.signature(user_cls).bind(*args, **kwargs)
            bound.apply_defaults()

            def make_var(x):
                if isinstance(x, VariableTracker):
                    return x
                elif ConstantVariable.is_literal(x):
                    return ConstantVariable.create(x)
                else:
                    unimplemented(
                        "expect VariableTracker or ConstantVariable.is_literal"
                    )

            bound_args = {}
            if _is_matching_transformers_cls(user_cls) or _is_matching_diffusers_cls(
                user_cls
            ):
                # HF output classes: omit fields that are None (either raw
                # None or a None constant tracker).
                for k, v in bound.arguments.items():
                    if isinstance(v, ConstantVariable) and v.value is None or v is None:
                        continue
                    bound_args[k] = v
            else:
                bound_args = bound.arguments

            items = {
                ConstantVariable.create(k): make_var(v) for k, v in bound_args.items()
            }
        elif not args:
            # Keyword-only construction, e.g. ``MyDict(a=1, b=2)``.
            items = {ConstantVariable.create(k): v for k, v in kwargs.items()}
        elif len(args) == 1 and isinstance(args[0], ConstDictVariable) and not kwargs:
            # Copy-construction from another traced dict. ConstDictVariable's
            # constructor builds a fresh dict, so storage is not shared.
            items = args[0].items
        else:
            unimplemented("custom dict init with args/kwargs unimplemented")

        return cls(items, user_cls, **options)

    @classmethod
    def wrap(cls, builder, obj):
        raise NotImplementedError

    def __init__(self, items, user_cls, **options):
        super().__init__(items, user_cls, **options)
        assert self.is_matching_cls(user_cls)

    def as_proxy(self):
        raise NotImplementedError

    def reconstruct(self, codegen):
        """Emit ``user_cls(<key>=<value>, ...)``.

        For HF output classes the class is first wrapped in
        ``torch._dynamo.disable``. That is presumably so the reconstructed
        constructor call (including ``__post_init__``) is not traced again —
        NOTE(review): confirm.
        """
        is_hf_model_output = _is_matching_transformers_cls(
            self.user_cls
        ) or _is_matching_diffusers_cls(self.user_cls)

        if is_hf_model_output:
            # Load ``torch._dynamo.disable`` so it can be applied to the class.
            codegen.append_output(codegen.create_load_global("torch", True, add=True))
            codegen.append_output(codegen.create_load_attr("_dynamo"))
            codegen.append_output(codegen.create_load_attr("disable"))
        codegen.extend_output([codegen._create_load_const(self.user_cls)])

        if is_hf_model_output:
            # Call ``torch._dynamo.disable(user_cls)``.
            codegen.extend_output(create_call_function(1, False))

        # Then call the (possibly wrapped) class with keyword arguments.
        d = self.keys_as_python_constant()
        codegen.foreach(d.values())
        keys = tuple(d.keys())
        codegen.extend_output(codegen.create_call_function_kw(len(keys), keys, True))

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Dispatch a method call on the custom dict.

        Methods inherited from ``dict``/``OrderedDict`` use the
        ConstDictVariable implementation. A few user-overridable methods are
        inlined from the user class. Everything else graph-breaks.
        """
        # Use a default so an unknown method name graph-breaks instead of
        # raising AttributeError inside the tracer.
        fn = getattr(self.user_cls, name, None)
        if fn is None:
            unimplemented(f"custom dict: call_method unimplemented name={name}")
        source = None if self.source is None else AttrSource(self.source, name)

        if hasattr(fn, "__objclass__") and fn.__objclass__ in (
            dict,
            collections.OrderedDict,
        ):
            # Built-in dict method: handled by ConstDictVariable.
            return super().call_method(tx, name, args, kwargs)
        elif name in ("__getitem__", "to_tuple", "__setitem__", "__setattr__"):
            # User-level implementation: inline it with ``self`` bound.
            return tx.inline_user_function_return(
                variables.UserFunctionVariable(fn, source=source),
                [self] + list(args),
                kwargs,
            )

        # ``unimplemented`` takes only a message; the old call passed ``name``
        # as an extra positional argument, which raised TypeError.
        unimplemented(f"custom dict: call_method unimplemented name={name}")

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """Attribute access reads the matching dict item when present."""
        name_vt = ConstantVariable.create(name)
        if name_vt in self:
            return self.call_method(tx, "__getitem__", [name_vt], {})
        # Previously the result was discarded, so attribute reads returned None.
        return super().var_getattr(tx, name)

    call_hasattr = _call_hasattr_customobj
|
|
|
|
|
@functools.lru_cache(None) |
|
def _install_PretrainedConfig_patch(): |
|
import transformers |
|
|
|
|
|
|
|
|
|
def _dynamo_overriden_transformers_eq(self, other): |
|
if not hasattr(other, "__dict__"): |
|
return False |
|
return self.__dict__ == other.__dict__ |
|
|
|
transformers.configuration_utils.PretrainedConfig.__eq__ = ( |
|
_dynamo_overriden_transformers_eq |
|
) |
|
|
|
|
|
class HFPretrainedConfigVariable(VariableTracker):
    """
    Hack for HuggingFace PretrainedConfig: attribute reads and ``hasattr``
    checks are answered directly from the wrapped config object as constants.
    """

    @staticmethod
    def is_matching_cls(cls):
        config_mod = sys.modules.get("transformers.configuration_utils")
        if config_mod is None:
            return False
        matched = issubclass(cls, config_mod.PretrainedConfig)
        # Install the __eq__ monkeypatch lazily, the first time a
        # PretrainedConfig subclass is seen.
        if matched:
            _install_PretrainedConfig_patch()
        return matched

    @classmethod
    def is_matching_object(cls, obj):
        return cls.is_matching_cls(type(obj))

    def __init__(self, obj, **kwargs):
        super().__init__(**kwargs)
        self.obj = obj
        assert self.is_matching_cls(type(obj))

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        from . import ConstantVariable

        attr_value = getattr(self.obj, name)
        return ConstantVariable.create(attr_value)

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        present = hasattr(self.obj, name)
        return variables.ConstantVariable.create(present)
|
|
|
|
|
class PythonSysModulesVariable(VariableTracker):
    """Special case for sys.modules.

    Without this we will guard on the exact set of modules imported in the
    lifetime of the python program. Instead, each lookup installs a
    DICT_CONTAINS guard on just the key being queried.
    """

    def python_type(self):
        return dict

    def reconstruct(self, codegen):
        # Emit ``sys.modules``.
        codegen.extend_output(
            [
                codegen.create_load_python_module(sys, True),
                codegen.create_load_attr("modules"),
            ]
        )

    def call_method(
        self, tx, name, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
    ):
        """Support ``sys.modules[k]``, ``sys.modules.get(k[, default])`` and ``k in sys.modules``."""
        if name == "__getitem__":
            return self.call_getitem(tx, *args, **kwargs)
        elif name == "get":
            return self.call_get(tx, *args, **kwargs)
        elif name == "__contains__":
            return self.call_contains(tx, *args, **kwargs)
        unimplemented(f"sys.modules.{name}(*{args}, **{kwargs})")

    def _contains_helper(self, tx, key: VariableTracker):
        """Check ``key`` in ``sys.modules`` and guard on that membership result.

        Returns ``(k, has_key)``, where ``k`` is the key as a Python constant.
        """
        k = key.as_python_constant()
        has_key = k in sys.modules
        install_guard(
            self.make_guard(
                functools.partial(GuardBuilder.DICT_CONTAINS, key=k, invert=not has_key)
            )
        )
        return k, has_key

    def call_contains(self, tx, key: VariableTracker):
        _, has_key = self._contains_helper(tx, key)
        return ConstantVariable.create(value=has_key)

    def call_get(
        self, tx, key: VariableTracker, default: Optional[VariableTracker] = None
    ):
        """``sys.modules.get(key, default)``: wrap the module, else the default or None."""
        from .builder import VariableBuilder

        k, has_key = self._contains_helper(tx, key)

        if has_key:
            return VariableBuilder(
                tx,
                GetItemSource(self.source, k),
            )(sys.modules[k])

        if default is not None:
            return default

        return ConstantVariable.create(value=None)

    def call_getitem(self, tx, key: VariableTracker):
        """``sys.modules[key]``: wrap the module, graph-breaking if the key is absent."""
        from .builder import VariableBuilder

        k, has_key = self._contains_helper(tx, key)
        if not has_key:
            # At runtime this lookup raises KeyError. Graph-break so that
            # happens in eager Python rather than letting a raw KeyError
            # escape from the tracer.
            unimplemented(f"sys.modules[{k!r}]: key not present")
        return VariableBuilder(
            tx,
            GetItemSource(self.source, k),
        )(sys.modules[k])
|
|