|
|
|
|
|
import contextlib |
|
import functools |
|
import inspect |
|
import itertools |
|
import logging |
|
import math |
|
import operator |
|
import types |
|
from collections import defaultdict, OrderedDict |
|
from typing import Dict, List |
|
|
|
import torch |
|
from torch import sym_float, sym_int |
|
|
|
from .. import config, polyfill, variables |
|
from ..exc import ( |
|
AttributeMutationError, |
|
unimplemented, |
|
Unsupported, |
|
UserError, |
|
UserErrorType, |
|
) |
|
from ..guards import GuardBuilder, install_guard |
|
from ..replay_record import DummyModule |
|
from ..source import AttrSource, GetItemSource, is_constant_source, TypeSource |
|
from ..utils import ( |
|
check_constant_args, |
|
check_numpy_ndarray_args, |
|
check_unspec_or_constant_args, |
|
check_unspec_python_args, |
|
extract_fake_example_value, |
|
get_fake_value, |
|
guard_if_dyn, |
|
istype, |
|
numpy_operator_wrapper, |
|
proxy_args_kwargs, |
|
tensortype_to_dtype, |
|
) |
|
from .base import MutableLocal, VariableTracker |
|
from .constant import ConstantVariable |
|
from .ctx_manager import EventVariable, StreamVariable |
|
from .dicts import ( |
|
ConstDictVariable, |
|
DefaultDictVariable, |
|
DictView, |
|
is_hashable, |
|
SetVariable, |
|
) |
|
from .lists import ( |
|
BaseListVariable, |
|
ListIteratorVariable, |
|
ListVariable, |
|
SizeVariable, |
|
TupleIteratorVariable, |
|
TupleVariable, |
|
) |
|
from .tensor import ( |
|
FakeItemVariable, |
|
supported_comparison_ops, |
|
SymNodeVariable, |
|
TensorVariable, |
|
UnspecializedPythonVariable, |
|
) |
|
from .user_defined import UserDefinedObjectVariable, UserDefinedVariable |
|
|
|
log = logging.getLogger(__name__) |
|
|
|
|
|
IN_PLACE_DESUGARING_MAP = { |
|
operator.iadd: operator.add, |
|
operator.isub: operator.sub, |
|
operator.imul: operator.mul, |
|
operator.ifloordiv: operator.floordiv, |
|
operator.itruediv: operator.truediv, |
|
operator.imod: operator.mod, |
|
operator.imatmul: operator.imatmul, |
|
operator.ilshift: operator.lshift, |
|
operator.irshift: operator.rshift, |
|
operator.ipow: operator.pow, |
|
operator.iand: operator.and_, |
|
operator.ior: operator.or_, |
|
operator.ixor: operator.xor, |
|
} |
|
|
|
|
|
def _polyfill_call_impl(name): |
|
"""Create a BuiltinVariable.call_{name} method that inlines through polyfill.{name}""" |
|
|
|
def call_fn(self, tx, *args, **kwargs): |
|
return tx.inline_user_function_return( |
|
variables.UserFunctionVariable(fn), args, kwargs |
|
) |
|
|
|
fn = getattr(polyfill, name) |
|
call_fn.__name__ = f"call_{name}" |
|
return call_fn |
|
|
|
|
|
class BuiltinVariable(VariableTracker): |
|
_SENTINEL = object() |
|
_nonvar_fields = { |
|
"fn", |
|
*VariableTracker._nonvar_fields, |
|
} |
|
|
|
@classmethod |
|
def create_with_source(cls, value, source): |
|
install_guard(source.make_guard(GuardBuilder.BUILTIN_MATCH)) |
|
return BuiltinVariable(value, source=source) |
|
|
|
@staticmethod |
|
@functools.lru_cache(None) |
|
def _constant_fold_functions(): |
|
fns = { |
|
abs, |
|
all, |
|
any, |
|
bool, |
|
callable, |
|
chr, |
|
divmod, |
|
float, |
|
getattr, |
|
int, |
|
len, |
|
max, |
|
min, |
|
ord, |
|
pow, |
|
repr, |
|
round, |
|
str, |
|
str.format, |
|
sum, |
|
type, |
|
operator.abs, |
|
operator.pos, |
|
operator.neg, |
|
operator.not_, |
|
operator.truth, |
|
operator.invert, |
|
operator.pow, |
|
operator.mul, |
|
operator.matmul, |
|
operator.floordiv, |
|
operator.truediv, |
|
operator.mod, |
|
operator.add, |
|
operator.sub, |
|
operator.getitem, |
|
operator.length_hint, |
|
operator.lshift, |
|
operator.rshift, |
|
operator.and_, |
|
operator.or_, |
|
operator.xor, |
|
operator.ipow, |
|
operator.imul, |
|
operator.imatmul, |
|
operator.ifloordiv, |
|
operator.itruediv, |
|
operator.imod, |
|
operator.iadd, |
|
operator.isub, |
|
operator.ilshift, |
|
operator.irshift, |
|
operator.iand, |
|
operator.ixor, |
|
operator.ior, |
|
operator.index, |
|
} |
|
from .tensor import supported_comparison_ops |
|
|
|
fns.update(supported_comparison_ops.values()) |
|
fns.update(x for x in math.__dict__.values() if isinstance(x, type(math.sqrt))) |
|
return fns |
|
|
|
def can_constant_fold_through(self): |
|
return self.fn in self._constant_fold_functions() |
|
|
|
@staticmethod |
|
@functools.lru_cache(None) |
|
def _fx_graph_functions(): |
|
fns = { |
|
operator.abs, |
|
operator.pos, |
|
operator.neg, |
|
operator.not_, |
|
operator.invert, |
|
operator.pow, |
|
operator.mul, |
|
operator.matmul, |
|
operator.floordiv, |
|
operator.truediv, |
|
operator.mod, |
|
operator.add, |
|
operator.lt, |
|
operator.gt, |
|
operator.ge, |
|
operator.le, |
|
operator.ne, |
|
operator.eq, |
|
operator.sub, |
|
operator.getitem, |
|
operator.length_hint, |
|
operator.lshift, |
|
operator.rshift, |
|
operator.and_, |
|
operator.or_, |
|
operator.xor, |
|
operator.ipow, |
|
operator.imul, |
|
operator.imatmul, |
|
operator.ifloordiv, |
|
operator.itruediv, |
|
operator.imod, |
|
operator.iadd, |
|
operator.isub, |
|
operator.ilshift, |
|
operator.irshift, |
|
operator.iand, |
|
operator.ixor, |
|
operator.ior, |
|
} |
|
return fns |
|
|
|
@staticmethod |
|
@functools.lru_cache(None) |
|
def _binops(): |
|
|
|
fns = { |
|
operator.add: (["__add__", "__radd__", "__iadd__"], operator.iadd), |
|
operator.sub: (["__sub__", "__rsub__", "__isub__"], operator.isub), |
|
operator.mul: (["__mul__", "__rmul__", "__imul__"], operator.imul), |
|
operator.truediv: ( |
|
["__truediv__", "__rtruediv__", "__itruediv__"], |
|
operator.itruediv, |
|
), |
|
operator.floordiv: ( |
|
["__floordiv__", "__rfloordiv__", "__ifloordiv__"], |
|
operator.ifloordiv, |
|
), |
|
operator.mod: (["__mod__", "__rmod__", "__imod__"], operator.imod), |
|
pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow), |
|
operator.pow: (["__pow__", "__rpow__", "__ipow__"], operator.ipow), |
|
operator.lshift: ( |
|
["__lshift__", "__rlshift__", "__ilshift__"], |
|
operator.ilshift, |
|
), |
|
operator.rshift: ( |
|
["__rshift__", "__rrshift__", "__irshift__"], |
|
operator.irshift, |
|
), |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
return fns |
|
|
|
@staticmethod |
|
@functools.lru_cache(None) |
|
def _binop_handlers(): |
|
|
|
|
|
|
|
|
|
from .dicts import DictKeys, SetVariable |
|
from .functions import BaseUserFunctionVariable, UserFunctionVariable |
|
from .nn_module import NNModuleVariable |
|
from .tensor import supported_const_comparison_ops |
|
from .torch import BaseTorchVariable |
|
from .user_defined import ( |
|
UserDefinedClassVariable, |
|
UserDefinedObjectVariable, |
|
UserDefinedVariable, |
|
) |
|
|
|
|
|
op_handlers = {} |
|
for ( |
|
op, |
|
(magic_method_names, in_place_op), |
|
) in BuiltinVariable._binops().items(): |
|
op_handlers[op] = [] |
|
op_handlers[in_place_op] = [] |
|
|
|
forward_name, reverse_name, inplace_name = magic_method_names |
|
|
|
|
|
def user_defined_handler( |
|
tx, |
|
a, |
|
b, |
|
*, |
|
forward_name=forward_name, |
|
reverse_name=reverse_name, |
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if isinstance(a, UserDefinedVariable): |
|
return a.call_method(tx, forward_name, [b], {}) |
|
else: |
|
return b.call_method(tx, reverse_name, [a], {}) |
|
|
|
op_handlers[op].append( |
|
((UserDefinedVariable, VariableTracker), user_defined_handler) |
|
) |
|
op_handlers[op].append( |
|
((VariableTracker, UserDefinedVariable), user_defined_handler) |
|
) |
|
|
|
def user_defined_inplace_handler(tx, a, b, *, forward_name=inplace_name): |
|
return a.call_method(tx, forward_name, [b], {}) |
|
|
|
op_handlers[in_place_op].append( |
|
((UserDefinedVariable, VariableTracker), user_defined_inplace_handler) |
|
) |
|
op_handlers[in_place_op].append( |
|
((VariableTracker, UserDefinedVariable), user_defined_inplace_handler) |
|
) |
|
|
|
|
|
def dynamic_handler(tx, a, b, *, fn=op): |
|
from .builder import wrap_fx_proxy |
|
|
|
return wrap_fx_proxy( |
|
tx, |
|
tx.output.create_proxy( |
|
"call_function", fn, *proxy_args_kwargs([a, b], {}) |
|
), |
|
) |
|
|
|
op_handlers[op].append( |
|
((SymNodeVariable, VariableTracker), dynamic_handler) |
|
) |
|
op_handlers[op].append( |
|
((VariableTracker, SymNodeVariable), dynamic_handler) |
|
) |
|
|
|
|
|
op_handlers[in_place_op].append( |
|
((SymNodeVariable, VariableTracker), dynamic_handler) |
|
) |
|
op_handlers[in_place_op].append( |
|
((VariableTracker, SymNodeVariable), dynamic_handler) |
|
) |
|
|
|
|
|
|
|
|
|
def tuple_add_handler(tx, a, b): |
|
return TupleVariable([*a.items, *b.unpack_var_sequence(tx)]) |
|
|
|
def size_add_handler(tx, a, b): |
|
return SizeVariable([*a.items, *b.unpack_var_sequence(tx)]) |
|
|
|
list_like_addition_handlers = [ |
|
|
|
|
|
|
|
( |
|
(SizeVariable, SizeVariable), |
|
size_add_handler, |
|
), |
|
( |
|
(TupleVariable, TupleVariable), |
|
tuple_add_handler, |
|
), |
|
( |
|
(TupleVariable, ConstantVariable), |
|
tuple_add_handler, |
|
), |
|
( |
|
(ConstantVariable, TupleVariable), |
|
lambda tx, a, b: TupleVariable( |
|
[*a.unpack_var_sequence(tx), *b.items], |
|
), |
|
), |
|
( |
|
( |
|
ListVariable, |
|
(BaseListVariable, ConstantVariable, ListIteratorVariable), |
|
), |
|
lambda tx, a, b: ListVariable( |
|
[*a.items, *b.unpack_var_sequence(tx)], mutable_local=MutableLocal() |
|
), |
|
), |
|
( |
|
(BaseListVariable, BaseListVariable), |
|
lambda tx, a, b: type(a)([*a.items, *b.items]), |
|
), |
|
] |
|
op_handlers[operator.add].extend(list_like_addition_handlers) |
|
|
|
def list_iadd_handler(tx, a, b): |
|
if not a.mutable_local or not b.has_unpack_var_sequence(tx): |
|
|
|
return None |
|
|
|
seq = b.unpack_var_sequence(tx) |
|
tx.output.side_effects.mutation(a) |
|
a.items.extend(seq) |
|
return a |
|
|
|
list_like_iadd_handlers = [ |
|
( |
|
(ListVariable, VariableTracker), |
|
list_iadd_handler, |
|
), |
|
( |
|
(TupleVariable, TupleVariable), |
|
tuple_add_handler, |
|
), |
|
( |
|
(TupleVariable, ConstantVariable), |
|
tuple_add_handler, |
|
), |
|
] |
|
op_handlers[operator.iadd].extend(list_like_iadd_handlers) |
|
|
|
|
|
def expand_list_like(tx, lst, const): |
|
if isinstance(lst, ConstantVariable): |
|
lst, const = const, lst |
|
return lst.__class__( |
|
items=lst.items * const.as_python_constant(), |
|
mutable_local=MutableLocal(), |
|
) |
|
|
|
list_like_expansion_handlers = [ |
|
((ListVariable, ConstantVariable), expand_list_like), |
|
((TupleVariable, ConstantVariable), expand_list_like), |
|
((ConstantVariable, ListVariable), expand_list_like), |
|
((ConstantVariable, TupleVariable), expand_list_like), |
|
] |
|
op_handlers[operator.mul].extend(list_like_expansion_handlers) |
|
|
|
size_or_tuple = (SizeVariable, TupleVariable) |
|
has_set_items = (SetVariable, DictKeys) |
|
|
|
def create_cmp_op_handlers(op): |
|
def compare_by_value(tx, a, b): |
|
return ConstantVariable(op(a.value, b.value)) |
|
|
|
result = [((ConstantVariable, ConstantVariable), compare_by_value)] |
|
|
|
if op in supported_const_comparison_ops.values(): |
|
|
|
none_result = op(object(), None) |
|
if op.__name__.startswith("is_"): |
|
|
|
def never(tx, a, b): |
|
return ConstantVariable(none_result) |
|
|
|
obj_op_none = never |
|
none_op_obj = never |
|
else: |
|
|
|
def obj_op_none(tx, a, b: ConstantVariable): |
|
if b.value is None or b.value is True or b.value is False: |
|
return ConstantVariable(none_result) |
|
|
|
def none_op_obj(tx, a: ConstantVariable, b): |
|
if a.value is None or a.value is True or a.value is False: |
|
return ConstantVariable(none_result) |
|
|
|
types_that_are_never_none = ( |
|
TensorVariable, |
|
SymNodeVariable, |
|
NNModuleVariable, |
|
BaseListVariable, |
|
UserDefinedVariable, |
|
BaseUserFunctionVariable, |
|
ConstDictVariable, |
|
BaseTorchVariable, |
|
) |
|
result.extend( |
|
[ |
|
( |
|
(types_that_are_never_none, ConstantVariable), |
|
obj_op_none, |
|
), |
|
( |
|
(ConstantVariable, types_that_are_never_none), |
|
none_op_obj, |
|
), |
|
] |
|
) |
|
|
|
def list_compare_nocheck(tx, left, right): |
|
return BaseListVariable.list_compare(tx, op, left, right) |
|
|
|
def list_compare_check(tx, left, right): |
|
if type(left) is not type( |
|
right |
|
): |
|
unimplemented(f"{op.__name__}({left}, {right})") |
|
return BaseListVariable.list_compare(tx, op, left, right) |
|
|
|
def compare_set_items(tx, left, right): |
|
return ConstantVariable(op(left.set_items, right.set_items)) |
|
|
|
def compare_via_method(tx, left, right): |
|
return left.call_method(tx, f"__{op.__name__}__", [right], {}) |
|
|
|
if op.__name__.startswith("is_"): |
|
compare_user_defined = compare_by_value |
|
else: |
|
compare_user_defined = compare_via_method |
|
|
|
op_var = BuiltinVariable(op) |
|
result.extend( |
|
[ |
|
( |
|
( |
|
(UserFunctionVariable, BuiltinVariable), |
|
(UserFunctionVariable, BuiltinVariable), |
|
), |
|
lambda tx, a, b: ConstantVariable(op(a.fn, b.fn)), |
|
), |
|
( |
|
( |
|
NNModuleVariable, |
|
NNModuleVariable, |
|
), |
|
lambda tx, a, b: ConstantVariable( |
|
op( |
|
tx.output.get_submodule(a.module_key), |
|
tx.output.get_submodule(b.module_key), |
|
) |
|
), |
|
), |
|
((size_or_tuple, size_or_tuple), list_compare_nocheck), |
|
( |
|
(variables.BaseListVariable, variables.BaseListVariable), |
|
list_compare_check, |
|
), |
|
((has_set_items, has_set_items), compare_set_items), |
|
( |
|
(UserDefinedObjectVariable, UserDefinedObjectVariable), |
|
compare_user_defined, |
|
), |
|
( |
|
(UserDefinedClassVariable, UserDefinedClassVariable), |
|
compare_user_defined, |
|
), |
|
( |
|
( |
|
(StreamVariable, EventVariable, ConstantVariable), |
|
(StreamVariable, EventVariable, ConstantVariable), |
|
), |
|
compare_by_value, |
|
), |
|
( |
|
(TensorVariable, VariableTracker), |
|
op_var._comparison_with_tensor, |
|
), |
|
( |
|
(VariableTracker, TensorVariable), |
|
op_var._comparison_with_tensor, |
|
), |
|
( |
|
(SymNodeVariable, VariableTracker), |
|
op_var._comparison_with_symnode, |
|
), |
|
( |
|
(VariableTracker, SymNodeVariable), |
|
op_var._comparison_with_symnode, |
|
), |
|
] |
|
) |
|
|
|
if op.__name__.startswith("is_"): |
|
|
|
def handle_is(tx, left, right): |
|
|
|
|
|
if type(left) is not type(right): |
|
return ConstantVariable.create(op.__name__ != "is_") |
|
|
|
result.append(((VariableTracker, VariableTracker), handle_is)) |
|
|
|
return result |
|
|
|
for op in supported_comparison_ops.values(): |
|
assert callable(op) |
|
assert op not in op_handlers |
|
op_handlers[op] = create_cmp_op_handlers(op) |
|
|
|
return op_handlers |
|
|
|
@staticmethod |
|
def _find_binop_handler(op, a_type, b_type): |
|
handlers = BuiltinVariable._binop_handlers().get(op) |
|
if handlers is None: |
|
return None |
|
|
|
matches = [] |
|
for (type1, type2), handler in handlers: |
|
if issubclass(a_type, type1) and issubclass(b_type, type2): |
|
matches.append(handler) |
|
return matches |
|
|
|
def can_insert_in_graph(self): |
|
return self.fn in self._fx_graph_functions() |
|
|
|
def __init__(self, fn, **kwargs): |
|
super().__init__(**kwargs) |
|
self.fn = fn |
|
|
|
def __str__(self): |
|
if self.fn is None: |
|
name = "None" |
|
else: |
|
name = self.fn.__name__ |
|
|
|
return f"{self.__class__.__name__}({name})" |
|
|
|
def python_type(self): |
|
return type(self.fn) |
|
|
|
def as_python_constant(self): |
|
return self.fn |
|
|
|
def as_proxy(self): |
|
DTYPE = { |
|
bool: torch.bool, |
|
int: torch.int64, |
|
float: torch.float64, |
|
} |
|
if self.fn in DTYPE: |
|
return DTYPE[self.fn] |
|
return super().as_proxy() |
|
|
|
def reconstruct(self, codegen): |
|
name = self.fn.__name__ |
|
assert self.fn.__module__ == "builtins" |
|
assert name not in codegen.tx.f_globals, "shadowed global" |
|
codegen.append_output(codegen.create_load_global(name, False, add=True)) |
|
|
|
def constant_args(self, *args, **kwargs): |
|
return check_constant_args(args, kwargs) |
|
|
|
def tensor_args(self, *args): |
|
any_tensor = False |
|
for arg in args: |
|
if isinstance(arg, variables.GetAttrVariable): |
|
return False |
|
any_tensor = any_tensor or isinstance(arg, variables.TensorVariable) |
|
return any_tensor |
|
|
|
def tensor_args_type(self, arg_types): |
|
any_tensor = False |
|
for arg_type in arg_types: |
|
if issubclass(arg_type, variables.GetAttrVariable): |
|
return False |
|
any_tensor = any_tensor or issubclass(arg_type, variables.TensorVariable) |
|
return any_tensor |
|
|
|
def python_and_tensor_constant_only(self, *args, **kwargs): |
|
tensor_args = [] |
|
non_tensor_args = [] |
|
for i in itertools.chain(args, kwargs.values()): |
|
if isinstance(i, variables.TensorVariable): |
|
tensor_args.append(i) |
|
else: |
|
non_tensor_args.append(i) |
|
return all( |
|
is_constant_source(t.source) if t.source is not None else False |
|
for t in tensor_args |
|
) and self.constant_args(*non_tensor_args) |
|
|
|
@staticmethod |
|
def unwrap_unspec_args_kwargs(args, kwargs): |
|
return [x.as_python_constant() for x in args], { |
|
k: v.as_python_constant() for k, v in kwargs.items() |
|
} |
|
|
|
def has_constant_handler(self, args, kwargs): |
|
return self.can_constant_fold_through() and check_unspec_or_constant_args( |
|
args, kwargs |
|
) |
|
|
|
@staticmethod |
|
def _make_handler(fn, arg_types: List[type], has_kwargs: bool): |
|
from .builder import SourcelessBuilder |
|
from .lazy import LazyVariableTracker |
|
|
|
obj = BuiltinVariable(fn) |
|
handlers = [] |
|
|
|
if any(issubclass(t, LazyVariableTracker) for t in arg_types): |
|
return lambda tx, args, kwargs: obj.call_function( |
|
tx, [v.realize() for v in args], kwargs |
|
) |
|
|
|
if inspect.isclass(fn) and issubclass(fn, Exception): |
|
|
|
def create_exception_class_object(tx, args, kwargs): |
|
if fn is AssertionError and not all( |
|
isinstance(x, variables.ConstantVariable) |
|
and isinstance(x.value, str) |
|
for x in args |
|
): |
|
unimplemented("assert with non-string message") |
|
|
|
return variables.ExceptionVariable(fn, args, **kwargs) |
|
|
|
return create_exception_class_object |
|
|
|
if obj.can_insert_in_graph() and not ( |
|
fn is operator.getitem |
|
and not issubclass(arg_types[0], variables.TensorVariable) |
|
): |
|
if obj.tensor_args_type(arg_types): |
|
return obj._handle_insert_op_in_graph |
|
elif has_kwargs: |
|
|
|
handlers.append(obj._handle_insert_op_in_graph) |
|
|
|
|
|
|
|
if len(arg_types) == 2 and not has_kwargs: |
|
|
|
binop_handlers = BuiltinVariable._find_binop_handler(fn, *arg_types) |
|
if not binop_handlers: |
|
pass |
|
elif len(binop_handlers) == 1: |
|
(binop_handler,) = binop_handlers |
|
handlers.append(lambda tx, args, _: binop_handler(tx, *args)) |
|
else: |
|
|
|
def call_binop_handlers(tx, args, _): |
|
for fn in binop_handlers: |
|
rv = fn(tx, *args) |
|
if rv: |
|
return rv |
|
|
|
handlers.append(call_binop_handlers) |
|
|
|
self_handler = getattr(obj, f"call_{fn.__name__}", None) |
|
if self_handler: |
|
|
|
def call_self_handler(tx, args, kwargs): |
|
try: |
|
result = self_handler(tx, *args, **kwargs) |
|
if result is not None: |
|
return result |
|
except TypeError: |
|
|
|
|
|
try: |
|
inspect.signature(self_handler).bind(tx, *args, **kwargs) |
|
except TypeError as e: |
|
has_constant_handler = obj.has_constant_handler(args, kwargs) |
|
if not has_constant_handler: |
|
log.warning( |
|
"incorrect arg count %s %s and no constant handler", |
|
self_handler, |
|
e, |
|
) |
|
unimplemented( |
|
f"invalid handler args {self_handler} {args} {kwargs}" |
|
) |
|
else: |
|
raise |
|
except Unsupported as exc: |
|
has_constant_handler = obj.has_constant_handler(args, kwargs) |
|
if not has_constant_handler: |
|
raise |
|
|
|
exc.remove_from_stats() |
|
|
|
handlers.append(call_self_handler) |
|
|
|
if obj.can_constant_fold_through(): |
|
builder = SourcelessBuilder.create |
|
|
|
if ( |
|
all(issubclass(x, ConstantVariable) for x in arg_types) |
|
and not has_kwargs |
|
): |
|
|
|
def constant_fold_handler(tx, args, kwargs): |
|
|
|
try: |
|
res = fn( |
|
*[x.as_python_constant() for x in args], |
|
) |
|
except Exception as exc: |
|
unimplemented(f"constant fold exception: {repr(exc)}") |
|
return builder(tx, res) |
|
|
|
else: |
|
|
|
def constant_fold_handler(tx, args, kwargs): |
|
|
|
if check_unspec_or_constant_args(args, kwargs): |
|
try: |
|
res = fn( |
|
*[x.as_python_constant() for x in args], |
|
**{ |
|
k: v.as_python_constant() for k, v in kwargs.items() |
|
}, |
|
) |
|
except Exception as exc: |
|
unimplemented(f"constant fold exception: {repr(exc)}") |
|
return builder(tx, res) |
|
|
|
handlers.append(constant_fold_handler) |
|
|
|
error_msg = f"builtin: {fn.__name__} {arg_types} {has_kwargs}" |
|
if len(handlers) == 0: |
|
return lambda *args: unimplemented(error_msg) |
|
elif len(handlers) == 1: |
|
(handler,) = handlers |
|
|
|
def builtin_dipatch(tx, args, kwargs): |
|
rv = handler(tx, args, kwargs) |
|
if rv: |
|
return rv |
|
unimplemented(error_msg) |
|
|
|
else: |
|
|
|
def builtin_dipatch(tx, args, kwargs): |
|
for fn in handlers: |
|
rv = fn(tx, args, kwargs) |
|
if rv: |
|
return rv |
|
unimplemented(error_msg) |
|
|
|
return builtin_dipatch |
|
|
|
def _handle_insert_op_in_graph(self, tx, args, kwargs): |
|
from .builder import wrap_fx_proxy, wrap_fx_proxy_cls |
|
|
|
if kwargs and not self.tensor_args(*args, *kwargs.values()): |
|
return |
|
|
|
fn = self.fn |
|
try: |
|
|
|
if self.python_and_tensor_constant_only(*args, **kwargs): |
|
from ..bytecode_transformation import unique_id |
|
from .functions import invoke_and_store_as_constant |
|
|
|
return invoke_and_store_as_constant( |
|
tx, fn, unique_id(fn.__name__), args, kwargs |
|
) |
|
|
|
if fn in IN_PLACE_DESUGARING_MAP and isinstance( |
|
args[0], variables.ConstantVariable |
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
fn, args = IN_PLACE_DESUGARING_MAP[fn], [args[0], args[1]] |
|
|
|
if fn is operator.getitem and isinstance(args[1], SymNodeVariable): |
|
|
|
|
|
|
|
fn, args = torch.select, [ |
|
args[0], |
|
variables.ConstantVariable.create(0), |
|
args[1], |
|
] |
|
|
|
|
|
|
|
if check_numpy_ndarray_args(args, kwargs) and not any( |
|
type(arg) == variables.TensorVariable for arg in args |
|
): |
|
proxy = tx.output.create_proxy( |
|
"call_function", |
|
numpy_operator_wrapper(fn), |
|
*proxy_args_kwargs(args, kwargs), |
|
) |
|
|
|
return wrap_fx_proxy_cls(variables.NumpyNdarrayVariable, tx, proxy) |
|
|
|
proxy = tx.output.create_proxy( |
|
"call_function", |
|
fn, |
|
*proxy_args_kwargs(args, kwargs), |
|
) |
|
if any(isinstance(arg, FakeItemVariable) for arg in args): |
|
return wrap_fx_proxy_cls( |
|
FakeItemVariable, |
|
tx, |
|
proxy, |
|
) |
|
elif check_unspec_python_args(args, kwargs): |
|
_args, _kwargs = self.unwrap_unspec_args_kwargs(args, kwargs) |
|
raw_value = fn(*_args, **_kwargs) |
|
|
|
need_unwrap = any( |
|
x.need_unwrap |
|
for x in itertools.chain(args, kwargs.values()) |
|
if isinstance(x, variables.UnspecializedPythonVariable) |
|
) |
|
|
|
return wrap_fx_proxy_cls( |
|
UnspecializedPythonVariable, |
|
tx, |
|
proxy, |
|
raw_value=raw_value, |
|
need_unwrap=need_unwrap, |
|
) |
|
elif all(isinstance(x, SymNodeVariable) for x in args): |
|
return SymNodeVariable.create(tx, proxy, None) |
|
else: |
|
|
|
|
|
if fn is operator.truediv and isinstance( |
|
args[0], variables.UnspecializedPythonVariable |
|
): |
|
args[0] = args[0].convert_to_constant(tx) |
|
return wrap_fx_proxy(tx, proxy) |
|
|
|
except NotImplementedError: |
|
unimplemented(f"partial tensor op: {self} {args} {kwargs}") |
|
|
|
call_function_handler_cache = {} |
|
|
|
def call_function( |
|
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]" |
|
) -> "VariableTracker": |
|
if kwargs: |
|
kwargs = {k: v.realize() for k, v in kwargs.items()} |
|
key = (self.fn, *(type(x) for x in args), True) |
|
else: |
|
key = (self.fn, *(type(x) for x in args)) |
|
|
|
handler = self.call_function_handler_cache.get(key) |
|
if not handler: |
|
self.call_function_handler_cache[key] = handler = self._make_handler( |
|
self.fn, [type(x) for x in args], bool(kwargs) |
|
) |
|
return handler(tx, args, kwargs) |
|
|
|
def call_method( |
|
self, |
|
tx, |
|
name, |
|
args: "List[VariableTracker]", |
|
kwargs: "Dict[str, VariableTracker]", |
|
) -> "VariableTracker": |
|
if self.fn == object and name == "__setattr__": |
|
assert len(args) == 3 |
|
assert len(kwargs) == 0 |
|
obj, name_var, val = args |
|
obj = obj.realize() |
|
if ( |
|
isinstance(obj, UserDefinedObjectVariable) |
|
and tx.output.side_effects.is_attribute_mutation(obj) |
|
and name_var.is_python_constant() |
|
): |
|
return obj.method_setattr_standard(tx, name_var, val) |
|
if self.fn == dict and name == "fromkeys": |
|
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) |
|
if self.fn == itertools.chain and name == "from_iterable": |
|
assert len(args) == 1 |
|
assert len(kwargs) == 0 |
|
obj = args[0] |
|
items = [] |
|
for item in obj.unpack_var_sequence(tx): |
|
items.extend(item.unpack_var_sequence(tx)) |
|
return variables.TupleVariable(items) |
|
|
|
return super().call_method(tx, name, args, kwargs) |
|
|
|
def _call_int_float(self, tx, arg): |
|
|
|
|
|
if isinstance(arg, (SymNodeVariable, variables.TensorVariable)): |
|
if isinstance(arg, variables.TensorVariable): |
|
item = arg.call_method(tx, "item", [], {}) |
|
else: |
|
item = arg |
|
fn_ = sym_int if self.fn is int else sym_float |
|
from torch._dynamo.variables.builder import wrap_fx_proxy |
|
|
|
return wrap_fx_proxy( |
|
tx=tx, |
|
proxy=tx.output.create_proxy( |
|
"call_function", |
|
fn_, |
|
(item.as_proxy(),), |
|
{}, |
|
), |
|
) |
|
|
|
call_int = _call_int_float |
|
call_float = _call_int_float |
|
|
|
def call_str(self, tx, arg): |
|
|
|
if isinstance(arg, (variables.UserFunctionVariable)): |
|
return variables.ConstantVariable.create(value=str(arg.fn)) |
|
|
|
def _call_min_max(self, tx, *args): |
|
if len(args) == 1 and args[0].has_unpack_var_sequence(tx): |
|
|
|
items = args[0].unpack_var_sequence(tx) |
|
return self._call_min_max_seq(tx, items) |
|
elif len(args) == 2: |
|
return self._call_min_max_binary(tx, args[0], args[1]) |
|
elif len(args) > 2: |
|
return self._call_min_max_seq(tx, args) |
|
|
|
def _call_min_max_seq(self, tx, items): |
|
assert len(items) > 0 |
|
if len(items) == 1: |
|
return items[0] |
|
|
|
return functools.reduce(functools.partial(self._call_min_max_binary, tx), items) |
|
|
|
def _call_min_max_binary(self, tx, a, b): |
|
if self.tensor_args(a, b): |
|
if not isinstance(a, variables.TensorVariable): |
|
a, b = b, a |
|
assert isinstance(a, variables.TensorVariable) |
|
|
|
|
|
if isinstance(a, FakeItemVariable): |
|
a = variables.TorchInGraphFunctionVariable(torch.tensor).call_function( |
|
tx, [a], {} |
|
) |
|
|
|
|
|
if isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable): |
|
from .builder import wrap_fx_proxy_cls |
|
|
|
return wrap_fx_proxy_cls( |
|
type(a), |
|
tx=tx, |
|
proxy=tx.output.create_proxy( |
|
"call_function", |
|
self.fn, |
|
*proxy_args_kwargs([a, b], {}), |
|
), |
|
) |
|
|
|
|
|
if b.is_python_constant(): |
|
if isinstance(a, variables.NumpyNdarrayVariable): |
|
import numpy as np |
|
|
|
fn = variables.NumpyVariable(np.clip) |
|
else: |
|
fn = variables.TorchInGraphFunctionVariable(torch.clamp) |
|
kwargs = {"min": b} if (self.fn is max) else {"max": b} |
|
result = fn.call_function(tx, [a], kwargs) |
|
else: |
|
if isinstance(a, variables.NumpyNdarrayVariable): |
|
import numpy as np |
|
|
|
fn = {max: np.maximum, min: np.minimum}[self.fn] |
|
fn = variables.NumpyVariable(fn) |
|
else: |
|
fn = {max: torch.maximum, min: torch.minimum}[self.fn] |
|
fn = variables.TorchInGraphFunctionVariable(fn) |
|
result = fn.call_function(tx, [a, b], {}) |
|
|
|
|
|
if all( |
|
isinstance( |
|
i, |
|
( |
|
variables.UnspecializedPythonVariable, |
|
variables.ConstantVariable, |
|
), |
|
) |
|
for i in [a, b] |
|
): |
|
if any(isinstance(val, FakeItemVariable) for val in [a, b]): |
|
return variables.FakeItemVariable.from_tensor_variable(result) |
|
|
|
if b.is_python_constant(): |
|
raw_b = b.as_python_constant() |
|
else: |
|
raw_b = b.raw_value |
|
if self.fn is max: |
|
raw_res = max(a.raw_value, raw_b) |
|
else: |
|
raw_res = min(a.raw_value, raw_b) |
|
|
|
need_unwrap = any( |
|
x.need_unwrap |
|
for x in [a, b] |
|
if isinstance(x, variables.UnspecializedPythonVariable) |
|
) |
|
return variables.UnspecializedPythonVariable.from_tensor_variable( |
|
result, raw_res, need_unwrap |
|
) |
|
|
|
else: |
|
return result |
|
elif isinstance(a, SymNodeVariable) or isinstance(b, SymNodeVariable): |
|
fn = torch.sym_max if self.fn is max else torch.sym_min |
|
proxy = tx.output.create_proxy( |
|
"call_function", fn, *proxy_args_kwargs([a, b], {}) |
|
) |
|
return SymNodeVariable.create(tx, proxy, None) |
|
|
|
call_min = _call_min_max |
|
call_max = _call_min_max |
|
|
|
def call_abs(self, tx, arg: "VariableTracker"): |
|
|
|
abs_method = BuiltinVariable(getattr).call_function( |
|
tx, [arg, ConstantVariable.create("__abs__")], {} |
|
) |
|
return abs_method.call_function(tx, [], {}) |
|
|
|
def call_pos(self, tx, arg: "VariableTracker"): |
|
|
|
pos_method = BuiltinVariable(getattr).call_function( |
|
tx, [arg, ConstantVariable.create("__pos__")], {} |
|
) |
|
return pos_method.call_function(tx, [], {}) |
|
|
|
def call_index(self, tx, arg: "VariableTracker"): |
|
if isinstance(arg, variables.TensorVariable): |
|
unimplemented("unsupported index(tensor)") |
|
|
|
arg = guard_if_dyn(arg) |
|
constant_value = operator.index(arg) |
|
return variables.ConstantVariable.create(constant_value) |
|
|
|
def call_round(self, tx, arg, *args, **kwargs): |
|
|
|
round_method = BuiltinVariable(getattr).call_function( |
|
tx, [arg, ConstantVariable.create("__round__")], {} |
|
) |
|
return round_method.call_function(tx, args, kwargs) |
|
|
|
def call_range(self, tx, *args): |
|
if check_unspec_or_constant_args(args, {}): |
|
return variables.RangeVariable(args) |
|
elif self._dynamic_args(*args): |
|
args = [ |
|
variables.ConstantVariable.create(guard_if_dyn(arg)) for arg in args |
|
] |
|
return variables.RangeVariable(args) |
|
|
|
return None |
|
|
|
def _dynamic_args(self, *args, **kwargs): |
|
return any(isinstance(x, SymNodeVariable) for x in args) or any( |
|
isinstance(x, SymNodeVariable) for x in kwargs.values() |
|
) |
|
|
|
def call_slice(self, tx, *args): |
|
return variables.SliceVariable(args) |
|
|
|
def _dyn_proxy(self, tx, *args, **kwargs): |
|
from .builder import wrap_fx_proxy |
|
|
|
return wrap_fx_proxy( |
|
tx, |
|
tx.output.create_proxy( |
|
"call_function", self.fn, *proxy_args_kwargs(args, kwargs) |
|
), |
|
) |
|
|
|
def _call_iter_tuple_list(self, tx, obj=None, *args, **kwargs): |
|
if self._dynamic_args(*args, **kwargs): |
|
return self._dyn_proxy(tx, *args, **kwargs) |
|
|
|
if isinstance(obj, variables.IteratorVariable): |
|
|
|
|
|
return obj |
|
|
|
cls = variables.BaseListVariable.cls_for(self.fn) |
|
if obj is None: |
|
return cls( |
|
[], |
|
mutable_local=MutableLocal(), |
|
) |
|
elif obj.has_unpack_var_sequence(tx): |
|
if obj.source and not is_constant_source(obj.source): |
|
if isinstance(obj, TupleIteratorVariable): |
|
install_guard( |
|
obj.source.make_guard(GuardBuilder.TUPLE_ITERATOR_LEN) |
|
) |
|
else: |
|
if ( |
|
getattr(obj, "source", False) |
|
and isinstance(obj, ConstDictVariable) |
|
and not istype(obj, SetVariable) |
|
): |
|
tx.output.guard_on_key_order.add(obj.source.name()) |
|
|
|
install_guard(obj.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) |
|
|
|
return cls( |
|
list(obj.unpack_var_sequence(tx)), |
|
mutable_local=MutableLocal(), |
|
) |
|
|
|
def call_iter(self, tx, obj, *args, **kwargs): |
|
|
|
ret = self._call_iter_tuple_list(tx, obj, *args, **kwargs) |
|
|
|
if ret is None: |
|
|
|
|
|
|
|
return obj.call_method(tx, "__iter__", args, kwargs) |
|
return ret |
|
|
|
call_tuple = _call_iter_tuple_list |
|
call_list = _call_iter_tuple_list |
|
|
|
def call_callable(self, tx, arg): |
|
from .functions import BaseUserFunctionVariable |
|
from .nn_module import NNModuleVariable |
|
|
|
if isinstance( |
|
arg, |
|
( |
|
variables.UserDefinedClassVariable, |
|
BaseUserFunctionVariable, |
|
NNModuleVariable, |
|
), |
|
): |
|
return variables.ConstantVariable.create(True) |
|
elif isinstance(arg, UserDefinedVariable): |
|
return variables.ConstantVariable.create(callable(arg.value)) |
|
elif isinstance(arg, (ConstantVariable, SymNodeVariable, TensorVariable)): |
|
return variables.ConstantVariable.create(False) |
|
|
|
def call_cast(self, _, *args, **kwargs): |
|
if len(args) == 2: |
|
return args[1] |
|
|
|
unimplemented(f"unsupported args to builtin cast(): {args} {kwargs}") |
|
|
|
def call_dict(self, tx, *args, **kwargs): |
|
return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) |
|
|
|
@staticmethod |
|
def call_custom_dict(tx, user_cls, *args, **kwargs): |
|
if not kwargs: |
|
if not args: |
|
args = ({},) |
|
assert len(args) == 1 |
|
arg = args[0] |
|
if isinstance(arg, dict): |
|
return ConstDictVariable(arg, user_cls, mutable_local=MutableLocal()) |
|
elif isinstance(arg, variables.ConstDictVariable): |
|
return arg.clone(user_cls=user_cls, mutable_local=MutableLocal()) |
|
elif isinstance( |
|
arg, |
|
( |
|
ListVariable, |
|
TupleVariable, |
|
ListIteratorVariable, |
|
), |
|
): |
|
items = dict( |
|
x.unpack_var_sequence(tx) for x in arg.unpack_var_sequence(tx) |
|
) |
|
return ConstDictVariable(items, user_cls, mutable_local=MutableLocal()) |
|
elif not args and kwargs: |
|
items = {ConstantVariable.create(k): v for k, v in kwargs.items()} |
|
return variables.ConstDictVariable( |
|
items, user_cls=user_cls, mutable_local=MutableLocal() |
|
) |
|
unimplemented(f"{user_cls.__name__}(): {args} {kwargs}") |
|
|
|
@staticmethod |
|
def call_custom_dict_fromkeys(tx, user_cls, *args, **kwargs): |
|
assert user_cls in {dict, OrderedDict, defaultdict} |
|
if kwargs: |
|
|
|
assert user_cls is OrderedDict |
|
assert len(args) == 1 and len(kwargs) == 1 and "value" in kwargs |
|
args = (*args, kwargs.pop("value")) |
|
if len(args) == 0: |
|
raise UserError(TypeError, "fromkeys expected at least 1 argument, got 0") |
|
if len(args) == 1: |
|
args = (*args, ConstantVariable.create(None)) |
|
assert len(args) == 2 |
|
arg, value = args |
|
DictVariableType = ( |
|
ConstDictVariable if user_cls is not defaultdict else DefaultDictVariable |
|
) |
|
|
|
if isinstance(arg, dict): |
|
arg = [ConstantVariable.create(k) for k in arg.keys()] |
|
return DictVariableType( |
|
dict.fromkeys(arg, value), user_cls, mutable_local=MutableLocal() |
|
) |
|
elif arg.has_unpack_var_sequence(tx) and all( |
|
is_hashable(v) for v in arg.unpack_var_sequence(tx) |
|
): |
|
keys = arg.unpack_var_sequence(tx) |
|
return DictVariableType( |
|
dict.fromkeys(keys, value), user_cls, mutable_local=MutableLocal() |
|
) |
|
unimplemented(f"{user_cls.__name__}.fromkeys(): {args} {kwargs}") |
|
|
|
def call_set(self, tx, *args, **kwargs): |
|
|
|
assert not kwargs |
|
if not args: |
|
return SetVariable([], mutable_local=MutableLocal()) |
|
assert len(args) == 1 |
|
arg = args[0] |
|
if isinstance(arg, variables.SetVariable): |
|
return arg.clone(mutable_local=MutableLocal()) |
|
elif arg.has_unpack_var_sequence(tx): |
|
items = arg.unpack_var_sequence(tx) |
|
return SetVariable(items, mutable_local=MutableLocal()) |
|
else: |
|
unimplemented(f"set(): {args} {kwargs}") |
|
|
|
def call_zip(self, tx, *args, **kwargs): |
|
if kwargs: |
|
assert len(kwargs) == 1 and "strict" in kwargs |
|
if all(x.has_unpack_var_sequence(tx) for x in args): |
|
unpacked = [arg.unpack_var_sequence(tx) for arg in args] |
|
if kwargs.pop("strict", False) and len(unpacked) > 0: |
|
if not all(len(u) == len(unpacked[0]) for u in unpacked): |
|
raise UserError( |
|
ValueError, |
|
"zip() has one argument of len differing from others", |
|
) |
|
items = [variables.TupleVariable(list(item)) for item in zip(*unpacked)] |
|
return variables.TupleVariable(items) |
|
|
|
def call_enumerate(self, tx, *args): |
|
if len(args) == 1: |
|
start = 0 |
|
else: |
|
assert len(args) == 2 |
|
assert isinstance(args[1], variables.ConstantVariable) |
|
start = args[1].as_python_constant() |
|
if args[0].has_unpack_var_sequence(tx): |
|
items = [ |
|
variables.TupleVariable( |
|
[variables.ConstantVariable.create(idx), var], |
|
) |
|
for idx, var in enumerate(args[0].unpack_var_sequence(tx), start) |
|
] |
|
return variables.TupleVariable(items) |
|
|
|
def call_len(self, tx, *args, **kwargs): |
|
return args[0].call_method(tx, "__len__", args[1:], kwargs) |
|
|
|
def call_getitem(self, tx, *args, **kwargs): |
|
return args[0].call_method(tx, "__getitem__", args[1:], kwargs) |
|
|
|
def call_isinstance(self, tx, arg, isinstance_type): |
|
try: |
|
arg_type = arg.python_type() |
|
except NotImplementedError: |
|
unimplemented( |
|
f"isinstance({arg}, {isinstance_type}): can't determine type of {arg}" |
|
) |
|
|
|
isinstance_type = isinstance_type.as_python_constant() |
|
|
|
if isinstance(arg, variables.TensorVariable) and arg.dtype is not None: |
|
|
|
def _tensor_isinstance(tensor_var, tensor_type): |
|
def check_type(ty): |
|
if ty not in tensortype_to_dtype: |
|
return issubclass(arg.python_type(), ty) |
|
|
|
dtypes = tensortype_to_dtype[ty] |
|
return arg.dtype in dtypes |
|
|
|
if type(tensor_type) is tuple: |
|
return any(check_type(ty) for ty in tensor_type) |
|
else: |
|
return check_type(tensor_type) |
|
|
|
return variables.ConstantVariable.create( |
|
_tensor_isinstance(arg, isinstance_type) |
|
) |
|
|
|
|
|
if isinstance(arg, variables.UserDefinedObjectVariable) and isinstance( |
|
arg.value, types.MemberDescriptorType |
|
): |
|
unimplemented( |
|
f"isinstance called on UserDefinedClass {arg} {isinstance_type}" |
|
) |
|
|
|
if ( |
|
isinstance(arg, variables.UserDefinedObjectVariable) |
|
and "__instancecheck__" in isinstance_type.__class__.__dict__ |
|
): |
|
return variables.ConstantVariable.create( |
|
isinstance_type.__class__.__instancecheck__(isinstance_type, arg.value) |
|
) |
|
|
|
try: |
|
val = issubclass(arg_type, isinstance_type) |
|
except TypeError: |
|
val = arg_type is isinstance_type |
|
return variables.ConstantVariable.create(val) |
|
|
|
def call_issubclass(self, tx, left_ty, right_ty): |
|
"""Checks if first arg is subclass of right arg""" |
|
try: |
|
left_ty_py = left_ty.as_python_constant() |
|
right_ty_py = right_ty.as_python_constant() |
|
except NotImplementedError: |
|
unimplemented( |
|
f"call_issubclass args not constant left_ty: {left_ty}, right_ty: {right_ty}" |
|
) |
|
|
|
return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py)) |
|
|
|
def call_super(self, tx, a, b): |
|
return variables.SuperVariable(a, b) |
|
|
|
def call_next(self, tx, arg: VariableTracker): |
|
try: |
|
return arg.next_variable(tx) |
|
except Unsupported as ex: |
|
if isinstance(arg, variables.BaseListVariable): |
|
ex.remove_from_stats() |
|
return arg.items[0] |
|
raise |
|
|
|
def call_hasattr(self, tx, obj, attr): |
|
if attr.is_python_constant(): |
|
name = attr.as_python_constant() |
|
if isinstance(obj, variables.BuiltinVariable): |
|
return variables.ConstantVariable(hasattr(obj.fn, name)) |
|
return obj.call_hasattr(tx, name) |
|
|
|
def call_map(self, tx, fn, seq): |
|
if seq.has_unpack_var_sequence(tx): |
|
items = [fn.call_function(tx, [x], {}) for x in seq.unpack_var_sequence(tx)] |
|
return variables.TupleVariable(items) |
|
|
|
def call_sum(self, tx, seq, start=_SENTINEL): |
|
|
|
if isinstance(seq, (variables.ListVariable, variables.TupleVariable)) and all( |
|
isinstance(x, variables.ConstantVariable) |
|
and isinstance(x.value, (int, float)) |
|
for x in seq.items |
|
): |
|
if start is self._SENTINEL: |
|
return variables.ConstantVariable.create( |
|
sum(x.value for x in seq.items), |
|
) |
|
if isinstance(start, variables.ConstantVariable) and isinstance( |
|
start.value, (int, float) |
|
): |
|
return variables.ConstantVariable.create( |
|
sum((x.value for x in seq.items), start=start.value), |
|
) |
|
if seq.has_unpack_var_sequence(tx): |
|
if start is self._SENTINEL: |
|
start = variables.ConstantVariable.create(0) |
|
items = seq.unpack_var_sequence(tx) |
|
return BuiltinVariable(functools.reduce).call_function( |
|
tx, |
|
[ |
|
BuiltinVariable(operator.add), |
|
variables.TupleVariable(items), |
|
start, |
|
], |
|
{}, |
|
) |
|
|
|
def call_StopIteration(self, tx, *args): |
|
return variables.StopIterationVariable([*args]) |
|
|
|
def call_reduce(self, tx, function, iterable, initial=_SENTINEL): |
|
if iterable.has_unpack_var_sequence(tx): |
|
items = iterable.unpack_var_sequence(tx) |
|
if initial is self._SENTINEL: |
|
value, items = items[0], items[1:] |
|
else: |
|
value = initial |
|
for element in items: |
|
value = function.call_function(tx, [value, element], {}) |
|
return value |
|
|
|
def call_getattr( |
|
self, tx, obj: VariableTracker, name_var: VariableTracker, default=None |
|
): |
|
from .. import trace_rules |
|
from . import ( |
|
ConstantVariable, |
|
GetAttrVariable, |
|
PythonModuleVariable, |
|
TorchInGraphFunctionVariable, |
|
UserFunctionVariable, |
|
) |
|
from .builder import SourcelessBuilder, VariableBuilder |
|
|
|
name = name_var.as_python_constant() |
|
|
|
if not name_var.is_python_constant(): |
|
unimplemented("non-const getattr() name") |
|
|
|
if tx.output.side_effects.is_attribute_mutation(obj): |
|
if isinstance(obj, variables.UnspecializedNNModuleVariable): |
|
if ( |
|
name |
|
in ( |
|
"named_parameters", |
|
"parameters", |
|
"named_buffers", |
|
"buffers", |
|
"named_modules", |
|
"modules", |
|
) |
|
and obj.is_state_mutated |
|
and tx.output.side_effects.has_pending_mutation(obj) |
|
): |
|
unimplemented( |
|
f"pending mutation on nn module, so graph breaking at {name!r} call" |
|
) |
|
|
|
try: |
|
|
|
return tx.output.side_effects.load_attr(obj, name) |
|
except KeyError: |
|
pass |
|
|
|
if default is not None: |
|
hasattr_var = self.call_hasattr(tx, obj, name_var) |
|
assert hasattr_var.as_python_constant() in (True, False) |
|
if not hasattr_var.as_python_constant(): |
|
return default |
|
|
|
options = {} |
|
if obj.source: |
|
source = AttrSource(obj.source, name) |
|
options["source"] = source |
|
else: |
|
source = None |
|
|
|
if name == "__bases__": |
|
try: |
|
value = obj.as_python_constant() |
|
if isinstance(value, type): |
|
bases = value.__bases__ |
|
if source is not None: |
|
tuple_args = [ |
|
VariableBuilder(tx, GetItemSource(source, i))(b) |
|
for i, b in enumerate(bases) |
|
] |
|
else: |
|
tuple_args = [SourcelessBuilder.create(tx, b) for b in bases] |
|
|
|
return variables.TupleVariable(tuple_args, **options) |
|
except NotImplementedError: |
|
pass |
|
|
|
if isinstance(obj, variables.NNModuleVariable): |
|
return obj.var_getattr(tx, name) |
|
elif isinstance( |
|
obj, |
|
( |
|
variables.TensorVariable, |
|
variables.NamedTupleVariable, |
|
variables.ConstantVariable, |
|
variables.DistributedVariable, |
|
variables.UserDefinedClassVariable, |
|
variables.UserDefinedObjectVariable, |
|
), |
|
): |
|
try: |
|
return obj.var_getattr(tx, name) |
|
except NotImplementedError: |
|
return GetAttrVariable(obj, name, **options) |
|
elif isinstance(obj, TorchInGraphFunctionVariable): |
|
|
|
member = getattr(obj.value, name) |
|
if isinstance( |
|
member, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) |
|
) and trace_rules.is_aten_op_or_tensor_method(member): |
|
return TorchInGraphFunctionVariable(member, **options) |
|
elif isinstance(obj, (PythonModuleVariable, DummyModule)): |
|
if obj.is_torch or name not in obj.value.__dict__: |
|
member = getattr(obj.value, name) |
|
else: |
|
member = obj.value.__dict__[name] |
|
|
|
if config.replay_record_enabled: |
|
tx.exec_recorder.record_module_access(obj.value, name, member) |
|
|
|
if source is not None: |
|
return VariableBuilder(tx, source)(member) |
|
else: |
|
return SourcelessBuilder.create(tx, member) |
|
elif istype(obj, UserFunctionVariable) and name in ("__name__", "__module__"): |
|
return ConstantVariable.create(getattr(obj.fn, name)) |
|
else: |
|
try: |
|
return obj.var_getattr(tx, name) |
|
except NotImplementedError: |
|
return GetAttrVariable(obj, name, **options) |
|
|
|
def call_setattr( |
|
self, tx, obj: VariableTracker, name_var: VariableTracker, val: VariableTracker |
|
): |
|
if isinstance( |
|
obj, |
|
( |
|
variables.DataClassVariable, |
|
variables.CustomizedDictVariable, |
|
variables.PlacementVariable, |
|
variables.UserDefinedObjectVariable, |
|
), |
|
): |
|
return obj.call_method(tx, "__setattr__", [name_var, val], {}) |
|
elif ( |
|
tx.output.side_effects.is_attribute_mutation(obj) |
|
and name_var.is_python_constant() |
|
): |
|
name = name_var.as_python_constant() |
|
if isinstance(obj, variables.TensorVariable): |
|
from .builder import wrap_fx_proxy |
|
|
|
if name == "requires_grad": |
|
|
|
unimplemented( |
|
"mutating requires_grad can introduce a new leaf from non-leaf or vice versa in " |
|
"the middle of the graph, which aot_autograd does not currently know how to handle. " |
|
) |
|
if name == "data": |
|
|
|
|
|
|
|
|
|
|
|
to_remove = [] |
|
for tf in tx.output.tracked_fakes: |
|
if tf.source == obj.source: |
|
to_remove.append(tf) |
|
for tf in to_remove: |
|
tx.output.tracked_fakes.remove(tf) |
|
|
|
|
|
with dynamo_disable_grad(tx), torch.no_grad(): |
|
|
|
out = wrap_fx_proxy( |
|
tx, |
|
tx.output.create_proxy( |
|
"call_function", |
|
torch.Tensor.set_, |
|
*proxy_args_kwargs([obj, val], {}), |
|
), |
|
) |
|
|
|
|
|
|
|
|
|
|
|
def _lower_version_count_by_1(x): |
|
version = x._version |
|
if version > 0: |
|
version = version - 1 |
|
torch._C._autograd._unsafe_set_version_counter(x, version) |
|
return x |
|
|
|
tx.output.create_proxy( |
|
"call_function", |
|
_lower_version_count_by_1, |
|
(out.as_proxy(),), |
|
{}, |
|
) |
|
_lower_version_count_by_1(obj.as_proxy().node.meta["example_value"]) |
|
|
|
|
|
return out |
|
|
|
tx.output.side_effects.store_attr(obj, name, val) |
|
if name == "_grad": |
|
tx.output.side_effects.store_attr(obj, "grad", val) |
|
|
|
return val |
|
elif isinstance(obj, variables.UserDefinedObjectVariable): |
|
unimplemented( |
|
f"setattr(UserDefinedObjectVariable) {type(obj.value).__setattr__}" |
|
) |
|
elif isinstance(obj, variables.NNModuleVariable): |
|
if not tx.output.is_root_tracer(): |
|
raise AttributeMutationError( |
|
"Can't inplace modify module params/buffers inside HigherOrderOp" |
|
) |
|
if name_var.is_python_constant() and isinstance( |
|
val, variables.TensorVariable |
|
): |
|
assigning_fake_val = get_fake_value(val.as_proxy().node, tx) |
|
|
|
try: |
|
getattr_var = obj.var_getattr(tx, name_var.as_python_constant()) |
|
except AttributeError: |
|
getattr_var = None |
|
|
|
if isinstance(getattr_var, variables.TensorVariable): |
|
|
|
existing_fake_attr = get_fake_value(getattr_var.as_proxy().node, tx) |
|
|
|
|
|
mod_setattr = inspect.getattr_static(obj.module_type, "__setattr__") |
|
if ( |
|
existing_fake_attr is assigning_fake_val |
|
and mod_setattr is torch.nn.Module.__setattr__ |
|
): |
|
return getattr_var |
|
|
|
obj.convert_to_unspecialized(tx) |
|
|
|
|
|
|
|
|
|
|
|
elif isinstance(obj, variables.dicts.HFPretrainedConfigVariable) and tx.export: |
|
if name_var.is_python_constant() and isinstance( |
|
val, variables.ConstantVariable |
|
): |
|
setattr( |
|
obj.obj, name_var.as_python_constant(), val.as_python_constant() |
|
) |
|
return ConstantVariable(None) |
|
|
|
def call_delattr(self, tx, obj: VariableTracker, name_var: VariableTracker): |
|
return self.call_setattr(tx, obj, name_var, variables.DeletedVariable()) |
|
|
|
def call_type(self, tx, obj: VariableTracker): |
|
from .builder import SourcelessBuilder, VariableBuilder |
|
|
|
try: |
|
py_type = obj.python_type() |
|
except NotImplementedError as error: |
|
raise UserError( |
|
UserErrorType.INVALID_INPUT, |
|
str(error), |
|
case_name="unknown_python_type", |
|
) from None |
|
|
|
if obj.source is None: |
|
return SourcelessBuilder.create(tx, py_type) |
|
else: |
|
return VariableBuilder(tx, TypeSource(obj.source))(py_type) |
|
|
|
def call_reversed(self, tx, obj: VariableTracker): |
|
if obj.has_unpack_var_sequence(tx): |
|
items = list(reversed(obj.unpack_var_sequence(tx))) |
|
return variables.TupleVariable(items) |
|
|
|
def call_sorted(self, tx, obj: VariableTracker, **kwargs): |
|
if ( |
|
obj.has_unpack_var_sequence(tx) |
|
and not isinstance(obj, variables.TensorVariable) |
|
and all(x.is_python_constant() for x in obj.unpack_var_sequence(tx)) |
|
): |
|
function = kwargs.pop("key", None) |
|
reverse = kwargs.pop( |
|
"reverse", ConstantVariable.create(False) |
|
).as_python_constant() |
|
assert len(kwargs) == 0 |
|
if function: |
|
items = sorted( |
|
obj.unpack_var_sequence(tx), |
|
key=lambda x: function.call_function( |
|
tx, [x], {} |
|
).as_python_constant(), |
|
reverse=reverse, |
|
) |
|
else: |
|
items = sorted( |
|
obj.unpack_var_sequence(tx), |
|
key=lambda x: x.as_python_constant(), |
|
reverse=reverse, |
|
) |
|
return variables.ListVariable(items) |
|
|
|
def call_chain(self, tx, *args): |
|
if all(obj.has_unpack_var_sequence(tx) for obj in args): |
|
items = [] |
|
for obj in args: |
|
items.extend(obj.unpack_var_sequence(tx)) |
|
return variables.TupleVariable(items) |
|
|
|
def call_islice(self, tx, iterable, *args): |
|
if iterable.has_unpack_var_sequence(tx) and all( |
|
x.is_python_constant() for x in args |
|
): |
|
const_args = [x.as_python_constant() for x in args] |
|
items = iterable.unpack_var_sequence(tx) |
|
items = list(itertools.islice(items, *const_args)) |
|
return variables.TupleVariable(items) |
|
|
|
|
|
def call_neg(self, tx, a): |
|
if isinstance(a, SymNodeVariable): |
|
return SymNodeVariable.create( |
|
tx, |
|
(operator.neg)(a.as_proxy()), |
|
sym_num=None, |
|
) |
|
|
|
return None |
|
|
|
def call_format(self, tx, _format_string, *args, **kwargs): |
|
format_string = _format_string.as_python_constant() |
|
return variables.StringFormatVariable.create(format_string, args, kwargs) |
|
|
|
def call_id(self, tx, *args): |
|
if len(args) > 0 and isinstance(args[0], variables.NNModuleVariable): |
|
nn_mod_variable = args[0] |
|
mod = tx.output.get_submodule(nn_mod_variable.module_key) |
|
return variables.ConstantVariable.create(id(mod)) |
|
elif len(args) == 1 and isinstance( |
|
args[0], variables.UserDefinedObjectVariable |
|
): |
|
install_guard(args[0].source.make_guard(GuardBuilder.ID_MATCH)) |
|
constant_result = id(args[0].value) |
|
return variables.ConstantVariable.create(constant_result) |
|
else: |
|
unimplemented(f"call_id with args {args}") |
|
|
|
def call_deepcopy(self, tx, x): |
|
unimplemented(f"copy.deepcopy {repr(x)}") |
|
|
|
def _comparison_with_tensor(self, tx, left, right): |
|
from .builder import wrap_fx_proxy_cls |
|
from .tensor import supported_tensor_comparison_op_values |
|
|
|
op = self.fn |
|
|
|
if op in [operator.is_, operator.is_not]: |
|
is_result = ( |
|
isinstance(left, TensorVariable) |
|
and isinstance(right, TensorVariable) |
|
and id(extract_fake_example_value(left.as_proxy().node)) |
|
== id(extract_fake_example_value(right.as_proxy().node)) |
|
) |
|
if op is operator.is_: |
|
return ConstantVariable.create(is_result) |
|
else: |
|
return ConstantVariable.create(not is_result) |
|
|
|
if op not in supported_tensor_comparison_op_values: |
|
unimplemented(f"{op.__name__}({left}, {right})") |
|
if ( |
|
isinstance(left, TensorVariable) |
|
and isinstance(right, TensorVariable) |
|
and (left.size and right.size) is not None |
|
and left.size != right.size |
|
): |
|
try: |
|
torch.broadcast_shapes(left.size, right.size) |
|
except RuntimeError: |
|
|
|
unimplemented(f"{op.__name__}({left}, {right})") |
|
tensor_cls = left if isinstance(left, TensorVariable) else right |
|
proxy = tx.output.create_proxy( |
|
"call_function", op, (left.as_proxy(), right.as_proxy()), {} |
|
) |
|
return wrap_fx_proxy_cls( |
|
type(tensor_cls), |
|
tx, |
|
proxy, |
|
) |
|
|
|
def _comparison_with_symnode(self, tx, left, right): |
|
from .tensor import supported_tensor_comparison_op_values |
|
|
|
op = self.fn |
|
|
|
if op not in supported_tensor_comparison_op_values: |
|
unimplemented(f"{op.__name__}({left}, {right})") |
|
|
|
proxy = tx.output.create_proxy( |
|
"call_function", op, (left.as_proxy(), right.as_proxy()), {} |
|
) |
|
return SymNodeVariable.create( |
|
tx, |
|
proxy, |
|
sym_num=None, |
|
) |
|
|
|
def call_and_(self, tx, a, b): |
|
|
|
if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): |
|
return None |
|
if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( |
|
b, (SymNodeVariable, ConstantVariable) |
|
): |
|
return SymNodeVariable.create( |
|
tx, |
|
tx.output.create_proxy( |
|
"call_function", operator.and_, *proxy_args_kwargs([a, b], {}) |
|
), |
|
sym_num=None, |
|
) |
|
if hasattr(a, "set_items") and hasattr(b, "set_items"): |
|
return SetVariable(list(a.set_items & b.set_items)) |
|
|
|
|
|
def call_or_(self, tx, a, b): |
|
|
|
if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): |
|
return None |
|
if isinstance(a, (SymNodeVariable, ConstantVariable)) and isinstance( |
|
b, (SymNodeVariable, ConstantVariable) |
|
): |
|
return SymNodeVariable.create( |
|
tx, |
|
tx.output.create_proxy( |
|
"call_function", operator.or_, *proxy_args_kwargs([a, b], {}) |
|
), |
|
sym_num=None, |
|
) |
|
if hasattr(a, "set_items") and hasattr(b, "set_items"): |
|
return SetVariable(list(a.set_items | b.set_items)) |
|
|
|
return None |
|
|
|
def call_not_(self, tx, a): |
|
if isinstance(a, SymNodeVariable): |
|
return SymNodeVariable.create( |
|
tx, |
|
tx.output.create_proxy( |
|
"call_function", operator.not_, *proxy_args_kwargs([a], {}) |
|
), |
|
sym_num=None, |
|
) |
|
|
|
|
|
if isinstance(a, DictView): |
|
a = a.dv_dict |
|
if isinstance(a, (ListVariable, ConstDictVariable)): |
|
return ConstantVariable.create(len(a.items) == 0) |
|
|
|
return None |
|
|
|
def call_contains(self, tx, a: VariableTracker, b: VariableTracker): |
|
return a.call_method(tx, "__contains__", [b], {}) |
|
|
|
call_all = _polyfill_call_impl("all") |
|
call_any = _polyfill_call_impl("any") |
|
|
|
|
|
@contextlib.contextmanager |
|
def dynamo_disable_grad(tx): |
|
from . import GradModeVariable |
|
|
|
org_value = torch.is_grad_enabled() |
|
gmv = GradModeVariable.create(tx, False) |
|
try: |
|
gmv.enter(tx) |
|
yield |
|
finally: |
|
gmv.exit(tx) |
|
|