|
|
|
|
|
import collections |
|
from enum import Enum |
|
from typing import Any, Callable, Dict, List |
|
|
|
from .. import variables |
|
from ..current_scope_id import current_scope_id |
|
from ..exc import unimplemented |
|
from ..source import AttrSource, Source |
|
from ..utils import istype |
|
|
|
|
|
class MutableLocalSource(Enum):
    """
    Origin of the variable that a VariableTracker.mutable_local describes.

    - Existing: the variable already existed when Dynamo began tracking it
      during introspection.
    - Local: the variable is new and was created during Dynamo introspection.
    """

    # Variable pre-dates Dynamo's tracking.
    Existing = 0
    # Variable was created while Dynamo was introspecting.
    Local = 1
|
|
|
|
|
class MutableLocalBase:
    """
    Base class for Variable.mutable_local.

    Each instance records a ``scope`` id chosen from the kind of source it
    was built for; ``is_side_effect_safe`` in this module compares that id
    against the current scope id.
    """

    def __init__(self, typ: MutableLocalSource):
        """
        Set ``self.scope`` according to ``typ``.

        Args:
            typ: ``MutableLocalSource.Local`` records the value returned by
                ``current_scope_id()`` at construction time;
                ``MutableLocalSource.Existing`` records scope ``0``.
                Any other value is reported through ``unimplemented``.
        """
        if typ is MutableLocalSource.Local:
            # Created during introspection: tag with the scope active right now.
            self.scope = current_scope_id()
        elif typ is MutableLocalSource.Existing:
            # Pre-existing variables always get scope id 0.
            self.scope = 0
        else:
            unimplemented(f"Unsupported MutableLocalSource: {typ}")
|
|
|
|
|
class MutableLocal(MutableLocalBase):
    """
    Marker used to indicate this (list, iter, etc) was constructed in
    local scope and can be mutated safely in analysis without leaking
    state.

    Hashing and equality are identity-based: two distinct markers never
    compare equal.
    """

    def __init__(self):
        """Create a marker whose source kind is ``MutableLocalSource.Local``."""
        super().__init__(MutableLocalSource.Local)

    def __eq__(self, other):
        """Equal only to the very same object."""
        return self is other

    def __hash__(self):
        """Hash by object identity, consistent with ``__eq__``."""
        return id(self)
|
|
|
|
|
def _is_top_level_scope(scope_id): |
|
return scope_id == 1 |
|
|
|
|
|
def is_side_effect_safe(m: MutableLocalBase):
    """
    Decide whether ``m`` may be mutated from the current scope.

    Returns:
        ``True`` when the current scope is the top-level scope; otherwise
        the result of comparing ``m.scope`` with the current scope id.
    """
    active_scope = current_scope_id()
    # At top level everything is allowed; elsewhere only objects tagged
    # with the active scope id are.
    return True if _is_top_level_scope(active_scope) else m.scope == active_scope
|
|
|
|
|
class VariableTrackerMeta(type):
    """Metaclass that registers every class it creates and customizes isinstance."""

    # Every class created through this metaclass, in creation order.
    all_subclasses = []

    def __instancecheck__(cls, instance) -> bool:
        """Make isinstance work with LazyVariableTracker"""
        lazy_cls = variables.LazyVariableTracker
        is_lazy = type.__instancecheck__(lazy_cls, instance)
        # A lazy tracker is realized before the check, unless the check is
        # against VariableTracker or LazyVariableTracker themselves.
        if is_lazy and cls not in (VariableTracker, lazy_cls):
            instance = instance.realize()
        return type.__instancecheck__(cls, instance)

    def __init__(cls, name, bases, attrs):
        """Finish class creation, then record the new class in ``all_subclasses``."""
        super().__init__(name, bases, attrs)
        VariableTrackerMeta.all_subclasses.append(cls)
|
|
|
|
|
class VariableTracker(metaclass=VariableTrackerMeta):
    """
    Base class for tracked locals and stack values

    VariableTracker instances are immutable and should be copied in
    order to change them.
    """

    # Attribute names that visit() does NOT descend into when walking a
    # tracker's __dict__.
    _nonvar_fields = {
        "value",
        "guards",
        "source",
        "mutable_local",
        "parents_tracker",
        "user_code_variable_name",
    }

    def clone(self, **kwargs):
        """Shallow copy with some (optional) changes"""
        # Current attributes, overridden by whatever the caller passed.
        ctor_kwargs = {**self.__dict__, **kwargs}
        return self.__class__(**ctor_kwargs)

    @classmethod
    def visit(
        cls,
        fn: Callable[["VariableTracker"], None],
        value,
        cache=None,
    ):
        """
        Walk value and call fn on all the VariableTracker instances

        Recurses into the attributes of each tracker (except those named in
        ``_nonvar_fields``), into the items of values that ``istype``
        reports as list/tuple, and into the values of those it reports as
        dict/OrderedDict. Every object is visited at most once per walk.
        """
        if cache is None:
            cache = {}

        obj_id = id(value)
        if obj_id in cache:
            return
        # The object itself is stored under its id.
        cache[obj_id] = value

        if isinstance(value, VariableTracker):
            value = value.unwrap()
            fn(value)
            # NOTE(review): unwrap again after fn — presumably because fn may
            # change what unwrap() returns (e.g. by realizing a lazy tracker).
            value = value.unwrap()
            skipped = value._nonvar_fields
            for attr_name, attr_value in value.__dict__.items():
                if attr_name in skipped:
                    continue
                cls.visit(fn, attr_value, cache)
        elif istype(value, (list, tuple)):
            for item in value:
                cls.visit(fn, item, cache)
        elif istype(value, (dict, collections.OrderedDict)):
            for item in value.values():
                cls.visit(fn, item, cache)

    def __repr__(self):
        """Return the class name followed by ``()``."""
        cls_name = self.__class__.__name__
        return f"{cls_name}()"

    def debug_repr(self):
        """Return repr of the Python constant if there is one, else ``repr(self)``."""
        try:
            constant = self.as_python_constant()
            return repr(constant)
        except NotImplementedError:
            return repr(self)

    def python_type(self):
        """
        Return the Python type represented by this tracker.

        Subclasses override this to report the type (such as int, str,
        list, etc.) of the variable they track. If the type cannot be
        determined or is not relevant, leaving it undefined or invoking
        super() is always sound.

        Example:
            class SetVariable(VariableTracker):
                def python_type(self):
                    return set

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f"{self} has no type")

    def as_python_constant(self):
        """For constants"""
        raise NotImplementedError(f"{self} is not a constant")

    def guard_as_python_constant(self):
        """
        Like as_python_constant(), but a NotImplementedError is reported
        through ``unimplemented`` with the same message.

        NOTE(review): this base implementation installs no guards itself;
        presumably subclasses override it to add guards (e.g. ID_MATCH) that
        force a value to become a constant — confirm against subclasses.
        """
        try:
            return self.as_python_constant()
        except NotImplementedError as err:
            unimplemented(str(err))

    def is_python_constant(self):
        """Return True when as_python_constant() does not raise NotImplementedError."""
        try:
            self.as_python_constant()
        except NotImplementedError:
            return False
        return True

    def make_guard(self, fn):
        """
        Delegate guard creation to ``self.source``.

        Raises:
            NotImplementedError: If this tracker has no (truthy) source.
        """
        if not self.source:
            raise NotImplementedError
        return self.source.make_guard(fn)

    def const_getattr(self, tx, name: str) -> Any:
        """getattr(self, name) returning a python constant"""
        raise NotImplementedError

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """getattr(self, name) returning a new variable"""
        value = self.const_getattr(tx, name)
        if not variables.ConstantVariable.is_literal(value):
            raise NotImplementedError
        # The attribute's source is derived from ours, when we have one.
        attr_source = AttrSource(self.source, name) if self.source else None
        return variables.ConstantVariable.create(value, source=attr_source)

    def is_proxy(self):
        """Return True when as_proxy() does not raise NotImplementedError."""
        try:
            self.as_proxy()
        except NotImplementedError:
            return False
        return True

    def as_proxy(self):
        """Return the proxy for this tracker; the base class has none."""
        raise NotImplementedError(str(self))

    def maybe_fx_node(self):
        """Return ``as_proxy().node`` if the proxy is a torch.fx.Proxy, else None."""
        try:
            proxy = self.as_proxy()
            import torch.fx

            return proxy.node if isinstance(proxy, torch.fx.Proxy) else None
        except NotImplementedError:
            return None

    def reconstruct(self, codegen):
        """Emit reconstruction code via ``codegen``; not supported by the base class."""
        raise NotImplementedError

    def can_reconstruct(self, tx):
        """If it is possible to reconstruct the Python object this
        VariableTracker represents."""
        assert tx is tx.output.root_tx, "Only root tx can reconstruct"
        try:
            from ..codegen import PyCodegen

            # Trial run against a fresh codegen; only success/failure matters.
            self.reconstruct(PyCodegen(tx))
        except NotImplementedError:
            return False
        else:
            return True

    def unpack_var_sequence(self, tx) -> List["VariableTracker"]:
        """Return the items of this variable as a list of trackers; unsupported here."""
        raise NotImplementedError

    def has_unpack_var_sequence(self, tx) -> bool:
        """Return True when unpack_var_sequence(tx) does not raise NotImplementedError."""
        try:
            self.unpack_var_sequence(tx)
        except NotImplementedError:
            return False
        return True

    def inspect_parameter_names(self) -> List[str]:
        """Not supported by the base class; reported through ``unimplemented``."""
        unimplemented(f"inspect_parameter_names: {self}")

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        """Not supported by the base class; reported through ``unimplemented``."""
        unimplemented(f"hasattr {self.__class__.__name__} {name}")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Not supported by the base class; reported through ``unimplemented``."""
        unimplemented(f"call_function {self} {args} {kwargs}")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """
        Handle a method call on this variable.

        The base class supports only ``__len__`` (when the variable can be
        unpacked into a sequence) and ``__getattr__`` with a single constant
        argument; anything else is reported through ``unimplemented``.
        """
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not args and not kwargs
            items = self.unpack_var_sequence(tx)
            return variables.ConstantVariable.create(len(items))

        if (
            name == "__getattr__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
        ):
            attr_name = args[0].as_python_constant()
            return self.var_getattr(tx, attr_name)

        unimplemented(f"call_method {self} {name} {args} {kwargs}")

    def set_name_hint(self, name):
        """Accept a name hint; the base class ignores it."""
        pass

    def realize(self) -> "VariableTracker":
        """Used by LazyVariableTracker to build the real VariableTracker"""
        return self

    def unwrap(self) -> "VariableTracker":
        """Used by LazyVariableTracker to return the real VariableTracker if it already exists"""
        return self

    def is_realized(self):
        """Used by LazyVariableTracker to indicate an unrealized node"""
        return True

    def next_variable(self, tx):
        """Not supported by the base class; reported through ``unimplemented``."""
        unimplemented(f"next({self})")

    def is_strict_mode(self, tx):
        """Return ``tx.strict_checks_fn(self)`` if ``tx.strict_checks_fn`` is truthy, else that falsy value."""
        return tx.strict_checks_fn and tx.strict_checks_fn(self)

    def __init__(
        self,
        *,
        source: Source = None,
        mutable_local: MutableLocal = None,
    ):
        """
        Store the keyword-only ``source`` and ``mutable_local`` on the instance.

        Both default to ``None``.
        """
        super().__init__()
        self.source = source
        self.mutable_local = mutable_local
|
|
|
|
|
def typestr(*objs):
    """
    Describe the given objects by type.

    With exactly one argument, return ``str(obj)`` for a VariableTracker
    and the type name for anything else. With any other number of
    arguments, return the space-joined descriptions of each argument.
    """
    if len(objs) != 1:
        return " ".join(typestr(item) for item in objs)

    only = objs[0]
    if isinstance(only, VariableTracker):
        return str(only)
    return type(only).__name__
|
|