|
|
|
import collections |
|
import collections.abc |
|
import contextlib |
|
import copy |
|
import dataclasses |
|
import dis |
|
import functools |
|
import importlib |
|
import inspect |
|
import itertools |
|
import linecache |
|
import logging |
|
import operator |
|
import sys |
|
import textwrap |
|
import threading |
|
import traceback |
|
import types |
|
import typing |
|
import weakref |
|
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type |
|
from unittest.mock import patch |
|
|
|
import torch |
|
import torch._logging |
|
from torch._guards import tracing, TracingContext |
|
|
|
from . import config, exc, logging as torchdynamo_logging, trace_rules, variables |
|
from .bytecode_analysis import ( |
|
get_indexof, |
|
JUMP_OPNAMES, |
|
livevars_analysis, |
|
propagate_line_nums, |
|
) |
|
from .bytecode_transformation import ( |
|
cleaned_instructions, |
|
create_call_function, |
|
create_instruction, |
|
create_jump_absolute, |
|
create_swap, |
|
get_code_keys, |
|
Instruction, |
|
is_generator, |
|
unique_id, |
|
) |
|
from .code_context import code_context |
|
from .codegen import PyCodegen |
|
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported |
|
from .funcname_cache import get_funcname |
|
from .guards import GuardBuilder, install_guard |
|
from .output_graph import GraphCompileReason, OutputGraph |
|
from .replay_record import DummyModule, ExecutionRecorder |
|
from .resume_execution import ContinueExecutionCache, ReenterWith |
|
from .source import ( |
|
AttrSource, |
|
GetItemSource, |
|
GlobalSource, |
|
GlobalWeakRefSource, |
|
LocalSource, |
|
Source, |
|
) |
|
from .trace_rules import is_builtin_constant, is_forbidden |
|
from .utils import ( |
|
counters, |
|
get_fake_value, |
|
get_instruction_source_311, |
|
graph_break_dup_warning_checker, |
|
istype, |
|
LazyString, |
|
proxy_args_kwargs, |
|
) |
|
from .variables.base import is_side_effect_safe, MutableLocal, typestr, VariableTracker |
|
from .variables.builder import VariableBuilder, wrap_fx_proxy |
|
from .variables.builtin import BuiltinVariable |
|
from .variables.constant import ConstantVariable |
|
from .variables.ctx_manager import ( |
|
ContextWrappingVariable, |
|
GenericContextWrappingVariable, |
|
WithExitFunctionVariable, |
|
) |
|
from .variables.dicts import ConstDictVariable, SetVariable |
|
from .variables.functions import ( |
|
BaseUserFunctionVariable, |
|
NestedUserFunctionVariable, |
|
SkipFunctionVariable, |
|
UserFunctionVariable, |
|
UserMethodVariable, |
|
) |
|
from .variables.lists import ( |
|
BaseListVariable, |
|
ListIteratorVariable, |
|
ListVariable, |
|
SliceVariable, |
|
TupleVariable, |
|
) |
|
from .variables.misc import ( |
|
ClosureVariable, |
|
GetAttrVariable, |
|
InlinedClosureVariable, |
|
NullVariable, |
|
PythonModuleVariable, |
|
UnknownVariable, |
|
) |
|
from .variables.nn_module import NNModuleVariable |
|
from .variables.tensor import supported_comparison_ops, SymNodeVariable, TensorVariable |
|
from .variables.user_defined import ( |
|
RemovableHandleVariable, |
|
UserDefinedClassVariable, |
|
UserDefinedObjectVariable, |
|
) |
|
|
|
# Module logger plus the artifact loggers used below for graph-break,
# call, source-line and bytecode trace output.
log = logging.getLogger(__name__)
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
trace_bytecode_log = torch._logging.getArtifactLogger(__name__, "trace_bytecode")
# Thread-local storage owned by this module.
tls = threading.local()
# One `call_function` handler per entry of `supported_comparison_ops`,
# keyed by the same keys as that mapping.
compare_op_handlers: Dict[str, Any] = {
    k: BuiltinVariable(v).call_function for k, v in supported_comparison_ops.items()
}
handle_contains = BuiltinVariable(operator.contains).call_function
handle_not = BuiltinVariable(operator.not_).call_function
# `operator.contains(container, item)` takes its operands in the opposite
# order to `item in container`, so the argument list is reversed first.
compare_op_handlers["in"] = lambda tx, args, _: handle_contains(
    tx, [*reversed(args)], {}
)
# `not in` is `in` wrapped in `operator.not_`.
compare_op_handlers["not in"] = lambda tx, args, _: handle_not(
    tx, [handle_contains(tx, [*reversed(args)], {})], {}
)
|
|
|
|
|
@dataclasses.dataclass
class SpeculationEntry:
    """
    A single speculation point, identified by file, line and instruction index.

    `failed` starts out False and is flipped by fail_and_restart_analysis();
    `reason`, when set, supplies the text used as the restart reason.
    """

    filename: str
    lineno: int
    instruction_pointer: int
    failed: bool = False
    reason: Optional[GraphCompileReason] = None

    def fail_and_restart_analysis(self):
        """
        Record that this speculation failed and abort the current trace.

        Always raises exc.SpeculationRestartAnalysis; the restart reason is
        taken from `self.reason` when one was recorded.
        """
        self.failed = True
        restart_reason = (
            "Unknown fail_and_restart_analysis"
            if self.reason is None
            else self.reason.reason
        )
        raise exc.SpeculationRestartAnalysis(restart_reason=restart_reason)
|
|
|
|
|
@dataclasses.dataclass
class SpeculationLog:
    """
    SpeculationLog replaces the prior copy_graphstate/restore_graphstate
    checkpointing.  Rather than saving/restoring state, we restart the
    dynamo conversion process over from the beginning -- but when we
    hit the start of the speculation that failed, we instead generate
    a graph break.
    """

    # Entries in the order the speculation points were first reached.
    entries: List[SpeculationEntry] = dataclasses.field(default_factory=list)
    # Position in `entries` of the entry the next call to next() hands out.
    index: int = 0

    def restart(self):
        """Rewind to the first entry while keeping every recorded entry."""
        self.index = 0

    def clear(self):
        """Forget all recorded entries and rewind."""
        self.entries.clear()
        self.index = 0

    def next(
        self, filename: str, lineno: int, instruction_pointer: int
    ) -> SpeculationEntry:
        """
        Lookup or create a SpeculationEntry() that is shared across
        RestartAnalysis calls.  Args are used only for debug checks.

        A new entry is appended only when every existing entry has already
        been handed out; otherwise the existing entry at `index` is reused.

        Raises:
            AssertionError: the entry at this position was recorded for a
                different filename, line or instruction pointer.
        """
        if len(self.entries) == self.index:
            self.entries.append(SpeculationEntry(filename, lineno, instruction_pointer))
        entry = self.entries[self.index]
        self.index += 1
        # The message is only formatted when the check fails.
        assert (
            entry.instruction_pointer == instruction_pointer
            and entry.filename == filename
            and entry.lineno == lineno
        ), textwrap.dedent(
            f"""
            SpeculationLog diverged at {self.index} of {len(self.entries)}:
            - Run1: {entry.filename}:{entry.lineno} (ip={entry.instruction_pointer})
            - Run2: {filename}:{lineno} (ip={instruction_pointer})
            Please submit a bug report.
            """
        )
        return entry
|
|
|
|
|
@functools.lru_cache(maxsize=None)
def _step_logger():
    """Return the step logger built on the module logger; created once and cached."""
    step_logger = torchdynamo_logging.get_step_logger(log)
    return step_logger
|
|
|
|
|
@dataclasses.dataclass
class BlockStackEntry:
    """
    One entry of the translator's block stack.

    `inst` is the instruction that opened the block and `target` the
    instruction the block is associated with.  `stack_index` and
    `with_context` are optional and only needed by resume_fn()/exit().
    """

    inst: Instruction
    target: Instruction
    stack_index: Optional[int] = None
    with_context: Optional[ContextWrappingVariable] = None

    def can_restore(self):
        """Return True when a context manager is attached to this block."""
        return self.with_context is not None

    def resume_fn(self):
        """
        Build the ReenterWith describing how to re-enter this block.

        The context's target values are passed along (as a tuple) only when
        a context is attached and its `target_values` is truthy.
        """
        assert self.stack_index is not None
        context = self.with_context
        if context and context.target_values:
            return ReenterWith(self.stack_index, tuple(context.target_values))
        return ReenterWith(self.stack_index)

    def exit(self, tx):
        """Delegate to the attached context's exit(); a context must be present."""
        assert self.with_context is not None
        return self.with_context.exit(tx)
|
|
|
|
|
class ReturnValueOp(Exception):
    """Control-flow exception; `step()` catches it and reports that tracing should stop."""
|
|
|
|
|
def stack_op(fn: typing.Callable[..., object]):
    """
    Turn a plain Python callable into an opcode handler.

    The handler pops as many operands as `fn` has parameters, applies the
    BuiltinVariable wrapping `fn` to them, and pushes the result.
    """
    operand_count = len(inspect.signature(fn).parameters)
    builtin = BuiltinVariable(fn)

    @functools.wraps(fn)
    def impl(self: "InstructionTranslatorBase", inst: Instruction):
        operands = self.popn(operand_count)
        outcome = builtin.call_function(self, operands, {})
        self.push(outcome)

    return impl
|
|
|
|
|
def _detect_and_normalize_assert_statement( |
|
self: "InstructionTranslatorBase", |
|
truth_fn: typing.Callable[[object], bool], |
|
push: bool, |
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (truth_fn is not operator.truth) or push: |
|
return False |
|
|
|
assert isinstance(self.instruction_pointer, int) |
|
current_instruction_pointer = self.instruction_pointer |
|
inst = self.instructions[current_instruction_pointer] |
|
|
|
if sys.version_info < (3, 9): |
|
if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError": |
|
return False |
|
else: |
|
if inst.opname != "LOAD_ASSERTION_ERROR": |
|
return False |
|
|
|
current_instruction_pointer += 1 |
|
|
|
|
|
error_msg = "assertion error" |
|
|
|
inst = self.instructions[current_instruction_pointer] |
|
|
|
if inst.opname == "LOAD_CONST": |
|
if not isinstance(inst.argval, str): |
|
return False |
|
error_msg = inst.argval |
|
|
|
|
|
|
|
current_instruction_pointer += 1 |
|
inst = self.instructions[current_instruction_pointer] |
|
if inst.opname not in ("CALL_FUNCTION", "PRECALL", "CALL"): |
|
return False |
|
|
|
|
|
|
|
current_instruction_pointer += 1 |
|
if inst.opname == "PRECALL": |
|
current_instruction_pointer += 1 |
|
inst = self.instructions[current_instruction_pointer] |
|
|
|
if inst.opname != "RAISE_VARARGS": |
|
return False |
|
|
|
self.push(ConstantVariable.create(error_msg)) |
|
|
|
return True |
|
|
|
|
|
def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
    """
    Build the handler for one conditional-jump opcode.

    Args:
        truth_fn: predicate applied to the statically known condition; the
            jump is taken when it returns a true value.
        push: when True, the popped condition is pushed back before jumping.

    Returns:
        `inner(self, inst)`, to be installed as an opcode method.
    """

    def take_jump_if(self, inst, value, operand):
        # Shared tail of every statically decidable case.
        if truth_fn(operand):
            if push:
                self.push(value)
            self.jump(inst)

    def jump_graph_break(self, inst, value, extra_msg=""):
        # Compile what has been traced so far and emit the conditional jump
        # itself as output bytecode, with a resume call for each outcome.
        if not self.should_compile_partial_graph():
            unimplemented("should_compile_partial_graph=False")

        if self.maybe_has_backedge():
            skip_msg = (
                "Skipping frame because there is a graph break in a for/while loop\n"
                f"{self.frame_summary()}"
            )
            log.info(skip_msg)
            raise exc.SkipFrame(skip_msg)

        # The condition is on the stack only while the subgraph is compiled.
        self.push(value)
        log.debug("generic_jump triggered compile")
        compile_reason = GraphCompileReason(
            f"generic_jump {typestr(value)}{extra_msg}", [self.frame_summary()]
        )
        self.output.compile_subgraph(self, reason=compile_reason)
        self.pop()

        if_next = self.create_call_resume_at(self.next_instruction)
        if push:
            self.push(value)
        if_jump = self.create_call_resume_at(inst.target)

        branch = create_instruction(inst.opname, target=if_jump[0])
        self.output.add_output_instructions([branch] + if_next + if_jump)

    def inner(self: "InstructionTranslatorBase", inst: Instruction):
        value: VariableTracker = self.pop()

        rewriting_assert = (
            config.rewrite_assert_with_torch_assert
            and _detect_and_normalize_assert_statement(self, truth_fn, push)
        )
        if rewriting_assert:
            # The detector pushed the assertion message; take it back.
            error_msg: VariableTracker = self.pop()

            if value.is_python_constant():
                if bool(value.as_python_constant()):
                    return self.jump(inst)
                else:
                    # NOTE(review): there is no `return` after this call, so
                    # control continues into the code below -- confirm that
                    # this fall-through is intended.
                    jump_graph_break(self, inst, value)

            # The Python assert is replaced by torch._assert_async in the
            # graph, and the assert's own instructions are jumped over.
            if isinstance(value, TensorVariable):
                self.output.create_proxy(
                    "call_function",
                    torch._assert_async,
                    *proxy_args_kwargs((value, error_msg), {}),
                )
                self.jump(inst)
                return

            if isinstance(value, SymNodeVariable):
                sym_expr = value.sym_num
                if not isinstance(sym_expr, torch.SymBool):
                    sym_expr = sym_expr != 0

                result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr)
                if not result:
                    unimplemented(
                        "Assertion failed on symbolic shapes. Did you make sure eager mode succeeds?"
                    )
                self.jump(inst)
                return

            # Anything else is converted with torch.scalar_tensor first.
            scalar_to_tensor_proxy = self.output.create_proxy(
                "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
            )
            scalar_to_tensor = wrap_fx_proxy(
                self,
                scalar_to_tensor_proxy,
                example_value=get_fake_value(scalar_to_tensor_proxy.node, self),
            )
            self.output.create_proxy(
                "call_function",
                torch._assert_async,
                *proxy_args_kwargs((scalar_to_tensor, error_msg), {}),
            )
            self.jump(inst)
            return

        # The cases below are mutually exclusive and tested in this order.
        if value.is_python_constant():
            take_jump_if(self, inst, value, value.as_python_constant())
            return

        if isinstance(value, TensorVariable) and self.should_compile_partial_graph():
            jump_graph_break(self, inst, value)
            return

        if isinstance(value, NNModuleVariable):
            mod = self.output.get_submodule(value.module_key)
            take_jump_if(self, inst, value, mod)
            return

        if isinstance(value, UserDefinedObjectVariable):
            truth_method = value.var_getattr(self, "__bool__")
            # A GetAttrVariable result makes the code try __len__ instead.
            if isinstance(truth_method, GetAttrVariable):
                truth_method = value.var_getattr(self, "__len__")

            if not isinstance(truth_method, UserMethodVariable):
                # No usable method: the object counts as true.
                take_jump_if(self, inst, value, True)
                return

            result = truth_method.call_function(self, [], {})
            if isinstance(result, ConstantVariable) and isinstance(
                result.value, (bool, int)
            ):
                take_jump_if(self, inst, value, result.value)
            else:
                unimplemented(
                    "generic_jump on UserDefined with __bool__ returning non-constant"
                )
            return

        if not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
            self
        ):
            # Truthiness of a sequence is decided by its length.
            take_jump_if(self, inst, value, len(value.unpack_var_sequence(self)))
            return

        if isinstance(value, SymNodeVariable):
            try:
                eval_result = value.evaluate_expr(self.output)
            except exc.UserError as e:
                if self.should_compile_partial_graph():
                    return jump_graph_break(self, inst, value, extra_msg=f"\n{e}")
                raise
            take_jump_if(self, inst, value, eval_result)
            return

        if isinstance(value, variables.BackwardHookVariable):
            take_jump_if(self, inst, value, True)
            return

        from .source import is_constant_source

        if value.source is not None and is_constant_source(value.source):
            take_jump_if(self, inst, value, value.get_real_value())
            return

        raise exc.UserError(
            exc.UserErrorType.DYNAMIC_CONTROL_FLOW,
            "Dynamic control flow is not supported at the moment. Please use "
            "functorch.experimental.control_flow.cond to explicitly capture the control flow.",
            case_name="cond_operands",
        )

    return inner
|
|
|
|
|
# Module-level flag read by the wrapper in break_graph_if_unsupported: while
# it is true, the detailed "Graph break: from user code" debug log is not
# emitted.  NOTE(review): presumably toggled by the explain tooling -- confirm.
explain = False
|
|
|
|
|
def break_graph_if_unsupported(*, push):
    """
    Decorator factory for opcode handlers that may need a graph break.

    Args:
        push: number of values the instruction leaves on the stack; that many
            UnknownVariable placeholders are pushed after a graph break.

    The wrapped handler runs under a speculation.  If it raises Unsupported
    and a partial graph may be compiled, the speculation is marked failed and
    analysis restarts; on the next pass the same speculation is found failed
    and the instruction is emitted as bytecode after compiling the subgraph.
    """

    def decorator(inner_fn):
        def handle_graph_break(
            self: "InstructionTranslatorBase",
            inst: Instruction,
            reason: GraphCompileReason,
        ):
            self.output.compile_subgraph(self, reason=reason)
            cg = PyCodegen(self)
            cleanup: List[Instruction] = []

            # Re-create every block-stack context in the output bytecode.
            for block in self.block_stack:
                assert block.with_context is not None
                block.with_context.reconstruct_type(cg)
                cg.extend_output(block.resume_fn().try_except(cg.code_options, cleanup))
            self.output.add_output_instructions(cg.get_instructions())
            del cg

            if sys.version_info >= (3, 11) and inst.opname == "CALL":
                if self.kw_names is not None:
                    kw_names = self.kw_names.as_python_constant()
                else:
                    kw_names = ()
                if len(kw_names) > 0:
                    self.output.add_output_instructions(
                        [create_instruction("KW_NAMES", argval=kw_names)]
                    )
                self.output.add_output_instructions(
                    create_call_function(inst.arg, False)
                )
            else:
                # Emit a copy of the instruction with its exception table
                # entry cleared.
                assert inst.target is None
                inst_copy = copy.copy(inst)
                inst_copy.exn_tab_entry = None
                self.output.add_output_instructions([inst_copy])

            self.output.add_output_instructions(cleanup)

            precall_and_call = (
                (3, 11) <= sys.version_info < (3, 12) and inst.opname == "CALL"
            )
            if precall_and_call:
                # The effects of both PRECALL and CALL are summed.
                stack_effect = dis.stack_effect(
                    dis.opmap["PRECALL"], inst.arg
                ) + dis.stack_effect(dis.opmap["CALL"], inst.arg)
            else:
                stack_effect = dis.stack_effect(inst.opcode, inst.arg)
            # Drop the instruction's inputs from the simulated stack ...
            self.popn(push - stack_effect)

            # ... and stand in for its outputs with unknown values.
            for _ in range(push):
                self.push(UnknownVariable())
            self.output.add_output_instructions(
                self.create_call_resume_at(self.next_instruction)
            )

        @functools.wraps(inner_fn)
        def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
            speculation = self.speculate()
            if speculation.failed:
                # A previous pass already failed here: break the graph.
                assert speculation.reason is not None
                return handle_graph_break(self, inst, speculation.reason)
            try:
                return inner_fn(self, inst)
            except Unsupported as excp:
                if self.generic_context_manager_depth > 0:
                    # No graph break is attempted at a non-zero generic
                    # context manager depth; re-report as unimplemented.
                    excp.remove_from_stats()
                    unimplemented("Graph break under GenericContextWrappingVariable")

                if isinstance(excp, exc.UncapturedHigherOrderOpError):
                    raise

                if not self.should_compile_partial_graph():
                    raise

                user_stack = excp.real_stack

                try:
                    innermost = user_stack[-1]
                    frame_loc = (innermost.filename, innermost.lineno)
                except IndexError:
                    # Empty user stack: fall back to the start of the code object.
                    code_options = self.code_options
                    frame_loc = (
                        code_options["co_filename"],
                        code_options["co_firstlineno"],
                    )

                log_details = (
                    graph_break_log.isEnabledFor(logging.DEBUG)
                    and not explain
                    and graph_break_dup_warning_checker.add(frame_loc)
                )
                if log_details:
                    user_stack_formatted = "".join(traceback.format_list(user_stack))
                    graph_break_log.debug(
                        "Graph break: from user code at:\n%s",
                        user_stack_formatted,
                        exc_info=True,
                    )
                else:
                    log.debug(
                        "Unsupported break in user code at %s:%s (details suppressed)",
                        *frame_loc,
                    )

                if self.maybe_has_backedge():
                    skip_msg = (
                        "Skipping frame because there is a graph break in a for/while loop\n"
                        f"{self.frame_summary()}"
                    )
                    log.info(skip_msg)
                    raise exc.SkipFrame(skip_msg) from excp

                # Re-file the exception under the "graph_break" statistic.
                excp.remove_from_stats()
                excp.add_to_stats("graph_break")
                speculation.reason = GraphCompileReason(excp.msg, user_stack)
            # Only reached from the except path above; always raises.
            speculation.fail_and_restart_analysis()

        return wrapper

    return decorator
|
|
|
|
|
class BytecodeDistpatchTableMeta(type):
    """Installs a `cls.dispatch_table` on every subclass to speed up calls to self.OPCODE()"""

    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)

        def _missing(opname, *args):
            unimplemented(f"missing: {opname}")

        # opcode number -> handler; opcodes without a method named after
        # them get a placeholder that reports the opcode as unimplemented.
        handlers = {}
        for opname, opcode in dis.opmap.items():
            fallback = functools.partial(_missing, opname)
            handlers[opcode] = getattr(cls, opname, fallback)

        # Flat 256-slot list indexed by opcode; unassigned numbers hold None.
        cls.dispatch_table = [handlers.get(code) for code in range(2**8)]
|
|
|
|
|
class InstructionTranslatorBase( |
|
metaclass=BytecodeDistpatchTableMeta, |
|
): |
|
# Graph being built; also receives the generated output instructions.
output: OutputGraph
# Local name -> tracker; read by LOAD_FAST/LOAD_DEREF, written by STORE_FAST.
symbolic_locals: Dict[str, VariableTracker]
# Per-name entries created by STORE_GLOBAL and consulted by LOAD_GLOBAL.
symbolic_globals: Dict[str, VariableTracker]
# Simulated value stack manipulated through push()/pop()/popn().
stack: List[VariableTracker]
# Index into `instructions` of the next instruction; None makes step() stop.
instruction_pointer: Optional[int]
# Instruction most recently fetched by step().
current_instruction: Instruction
block_stack: List[BlockStackEntry]
# Last line number reported through starts_line().
lineno: int
kw_names: Optional[ConstantVariable]
accept_prefix_inst: bool
prefix_insts: List[Instruction]
# Greater than zero adds an "(inline depth: N)" suffix to log headers.
inline_depth: int
# Set to True by mark_inconsistent_side_effects().
inconsistent_side_effects: bool
# Speculation taken by step() when the stack is empty, if any.
current_speculation: Optional[SpeculationEntry]
# Installed by the metaclass; indexed by opcode number in step().
dispatch_table: List[Any]
exn_vt_stack: List[VariableTracker]
# When set, records the locals/globals/modules the trace reads.
exec_recorder: Optional[ExecutionRecorder]
strict_checks_fn: Optional[Callable[[VariableTracker], bool]]
|
|
|
def mark_inconsistent_side_effects(self):
    """
    Flag that this translator has seen instructions which may make dynamo
    observe a different history than eager execution would.

    See: https://github.com/pytorch/pytorch/issues/110765
    """
    self.inconsistent_side_effects = True
|
|
|
def maybe_has_backedge(self):
    """
    Heuristically detect whether tracing is currently inside a loop.

    Scans from `self.instruction_pointer` onward and stops at the first
    RETURN_VALUE/RETURN_CONST.  Returns True as soon as a jump instruction is
    found whose `argval` is lower than the current instruction's offset, and
    False if no such jump is seen.
    """
    current_offset = self.current_instruction.offset
    assert self.instruction_pointer is not None
    remaining = self.instructions[self.instruction_pointer :]
    for upcoming in remaining:
        opname = upcoming.opname
        if opname in ("RETURN_VALUE", "RETURN_CONST"):
            return False
        # A jump to a lower offset is a backward edge.
        if opname in JUMP_OPNAMES and upcoming.argval < current_offset:
            return True
    return False
|
|
|
def cell_and_freevars(self):
    """
    Return the tuple of cell and free variable names for this frame.

    The tuple is computed on first use and cached on the instance.  For an
    InliningInstructionTranslator the parent's names are appended as well.
    """
    if hasattr(self, "_cell_and_freevars"):
        return self._cell_and_freevars

    cellvars = tuple(self.code_options["co_cellvars"] or [])
    freevars = tuple(self.code_options["co_freevars"] or [])
    self._cell_and_freevars = cellvars + freevars

    # Names visible in the parent frame are included for inlined frames.
    if isinstance(self, InliningInstructionTranslator):
        self._cell_and_freevars += self.parent.cell_and_freevars()
    return self._cell_and_freevars
|
|
|
def prune_dead_locals(self):
    """
    Drop symbolic locals that are no longer read.

    A local is kept when the liveness analysis from the current instruction
    reports it as read, or when it is a cell/free variable.  Afterwards the
    side-effect tracker is asked to prune dead newly created objects.
    """
    live_names = livevars_analysis(self.instructions, self.current_instruction)
    # Cell and free variables are always kept.
    live_names = live_names | set(self.cell_and_freevars())

    surviving = {}
    for name, tracker in self.symbolic_locals.items():
        if name in live_names:
            surviving[name] = tracker
    self.symbolic_locals = surviving

    self.output.side_effects.prune_dead_object_new(self)
|
|
|
def call_function(
    self,
    fn: VariableTracker,
    args: List[VariableTracker],
    kwargs: Dict[str, VariableTracker],
):
    """
    Call the tracked callable `fn` and push the result.

    Raises:
        AssertionError: an argument has the wrong type, or the underlying
            Python callable is forbidden from being traced.
    """
    assert isinstance(fn, VariableTracker)
    assert isinstance(args, list)
    assert isinstance(kwargs, dict)
    assert all(
        isinstance(operand, VariableTracker)
        for operand in itertools.chain(args, kwargs.values())
    )

    # Find the wrapped Python object; `fn.fn` wins over `fn.value`.
    underlying = None
    if hasattr(fn, "value"):
        underlying = fn.value
    if hasattr(fn, "fn"):
        underlying = fn.fn
    if underlying and callable(underlying) and is_forbidden(underlying):
        raise AssertionError(f"Attempt to trace forbidden callable {underlying}")

    result = fn.call_function(self, args, kwargs)
    self.push(result)
|
|
|
def inline_user_function_return(self, fn, args, kwargs):
    """
    Inline a call to a user-defined function and return whatever
    InliningInstructionTranslator.inline_call produces for it.
    """
    inlined_result = InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
    return inlined_result
|
|
|
def get_line_of_code_header(self, lineno=None):
    """
    Format a "file:line in function" header for log output.

    `lineno` defaults to the translator's current line.  A looked-up
    function name and the inline depth (when above zero) are appended in
    parentheses.
    """
    if lineno is None:
        lineno = self.lineno

    code = self.f_code
    pieces = [f"{code.co_filename}:{lineno} in {code.co_name}"]

    funcname = get_funcname(code.co_filename, lineno)
    if funcname is not None:
        pieces.append(f" ({funcname})")
    if self.inline_depth > 0:
        pieces.append(f" (inline depth: {self.inline_depth})")
    return "".join(pieces)
|
|
|
def get_log_starts_line_log_str(self):
    """
    Build the two-line "TRACE starts_line" log message: the location header
    followed by the right-stripped source text of the current line.
    """
    header = f"TRACE starts_line {self.get_line_of_code_header()}\n"
    source_line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip()
    return header + f" {source_line}"
|
|
|
def starts_line(self, lineno):
    """
    Note that tracing moved to a new source line.

    Does nothing when `lineno` is already the current line.  Otherwise the
    line is stored, the tracing context's current location is updated, and
    the line is logged lazily if source tracing is enabled.
    """
    if lineno == self.lineno:
        return
    self.lineno = lineno

    code = self.f_code
    TracingContext.set_current_loc(code.co_filename, lineno, code.co_name)

    if trace_source_log.isEnabledFor(logging.DEBUG):
        trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str))
|
|
|
def step(self):
    """Process exactly one instruction; return False if we should exit"""
    index = self.instruction_pointer
    if index is None:
        return False

    inst = self.instructions[index]
    self.current_instruction = inst
    self.instruction_pointer = index + 1

    if inst.starts_line:
        self.starts_line(inst.starts_line)

    # A speculation is only taken when the stack is empty, a partial graph
    # may be compiled, and the graph is non-empty.
    can_speculate = (
        not self.stack
        and self.should_compile_partial_graph()
        and self.is_non_empty_graph()
    )
    if can_speculate:
        self.current_speculation = self.speculate()
        if self.current_speculation.failed:
            return self.step_graph_break(inst)

    if trace_bytecode_log.isEnabledFor(logging.DEBUG):
        trace_bytecode_log.debug(
            "TRACE %s %s %s", inst.opname, inst.argval, self.stack
        )

    self.update_block_stack(inst)

    try:
        handler = self.dispatch_table[inst.opcode]
        handler(self, inst)
        return not self.output.should_exit
    except exc.ObservedException:
        self.exception_handler()
        return True
    except ReturnValueOp:
        return False
    except Unsupported:
        if self.current_speculation is None:
            log.debug("empty checkpoint")
            raise
        log.debug("step triggered compile", exc_info=True)

    # Only reached after an Unsupported with a speculation to fall back on.
    self.current_speculation.fail_and_restart_analysis()
|
|
|
if sys.version_info >= (3, 11):

    def update_block_stack(self, inst):
        """
        Keep `self.block_stack` in step with the instruction's exception
        table entry.  Blocks are only ever popped here, never pushed.
        """
        entry = inst.exn_tab_entry
        stack = self.block_stack
        if entry:
            # The entry targets the second-from-top block rather than the
            # top one: the top block has been left.
            left_top_block = (
                len(stack) >= 2
                and entry.target is not stack[-1].target
                and entry.target is stack[-2].target
            )
            if left_top_block:
                stack.pop()
        elif stack and inst.opname not in ("NOP", "JUMP_BACKWARD"):
            # No exception table entry covers this instruction.  NOP and
            # JUMP_BACKWARD are ignored; otherwise the one remaining block
            # is popped.
            assert len(stack) == 1
            stack.pop()

else:

    def update_block_stack(self, inst):
        """No bookkeeping is done on interpreters older than 3.11."""
|
|
|
@property
def next_instruction(self):
    """The instruction `self.instruction_pointer` currently refers to."""
    upcoming_index = self.instruction_pointer
    return self.instructions[upcoming_index]
|
|
|
def step_graph_break(self, continue_inst):
    """
    Compile the graph traced so far and continue in the original bytecode.

    Requires that no output instructions exist yet and that a speculation is
    active.  After compiling, a jump to `continue_inst` followed by all of
    this frame's instructions is emitted.
    """
    assert not self.output.output_instructions
    assert self.current_speculation is not None

    reason = GraphCompileReason("step_unsupported", [self.frame_summary()])
    self.output.compile_subgraph(self, partial_convert=True, reason=reason)

    resume_code = [create_jump_absolute(continue_inst)] + self.instructions
    self.output.add_output_instructions(resume_code)
|
|
|
def run_ctx_mgr(self):
    """
    Return the context manager that run() wraps the trace in:
    `TracingContext.current_frame` called with None.
    """
    frame_context = TracingContext.current_frame(None)
    return frame_context
|
|
|
def run(self):
    """
    Trace the frame by calling step() until it returns a false value.

    BackendCompilerFailed propagates untouched.  Any other exception gets
    the execution record attached (when a recorder is present) before it is
    re-raised.  The translator is always popped from the output graph, and
    an InstructionTranslator additionally cleans the output graph up.
    """
    with self.run_ctx_mgr():
        try:
            self.output.push_tx(self)
            keep_going = self.step()
            while keep_going:
                keep_going = self.step()
        except BackendCompilerFailed:
            raise
        except Exception as err:
            if self.exec_recorder:
                err.exec_record = self.exec_recorder.get_record()
            raise
        finally:
            self.output.pop_tx()
            # Cleanup is limited to InstructionTranslator instances.
            if isinstance(self, InstructionTranslator):
                self.output.cleanup()
|
|
|
def push(self, val: Optional[VariableTracker]):
    """Append `val` (a VariableTracker, or None) to the simulated stack."""
    if val is not None:
        assert isinstance(
            val, VariableTracker
        ), f"push expects VariableTracker, got {typestr(val)}"
    self.stack.append(val)
|
|
|
def push_many(self, vals: List[VariableTracker]):
    """Push every element of `vals` in order, one push() call each."""
    for item in vals:
        self.push(item)
|
|
|
def pop(self) -> VariableTracker:
    """Remove and return the top of the simulated stack."""
    top = self.stack.pop()
    return top
|
|
|
def popn(self, n: int) -> List[VariableTracker]:
    """
    Pop `n` values and return them in the order they sat on the stack
    (deepest first).
    """
    popped = [self.pop() for _ in range(n)]
    popped.reverse()
    return popped
|
|
|
def LOAD_FAST(self, inst):
    """
    Push the symbolic local named by `inst.argval`.

    A missing name starting with "." is retried with every "." replaced by
    "implicit"; any other missing name is reported as unimplemented.  Names
    starting with "___stack" are removed from the locals once loaded.
    """
    name = inst.argval

    recorder = self.exec_recorder
    if recorder and name in self.f_locals:
        recorder.add_local_var(name, self.f_locals[name])

    try:
        self.push(self.symbolic_locals[name].unwrap())
    except KeyError:
        if not name.startswith("."):
            unimplemented("undefined LOAD_FAST")
        else:
            implicit_name = name.replace(".", "implicit")
            try:
                self.push(self.symbolic_locals[implicit_name])
            except KeyError:
                unimplemented("undefined LOAD_FAST (implicit)")

    # "___stack" locals are dropped after their first load.
    if name.startswith("___stack"):
        self.symbolic_locals.pop(name)
|
|
|
def LOAD_DEREF(self, inst):
    """
    Push the cell or free variable named by `inst.argval`.

    The name must be among cell_and_freevars(); a name without a symbolic
    local is reported as unimplemented.
    """
    name = inst.argval
    assert name in self.cell_and_freevars()

    recorder = self.exec_recorder
    if recorder and name in self.f_locals:
        recorder.add_local_var(name, self.f_locals[name])

    if name not in self.symbolic_locals:
        unimplemented(f"undefined LOAD_DEREF {name}")
    self.push(self.symbolic_locals[name])
|
|
|
def STORE_FAST(self, inst):
    """
    Pop the top of stack and bind it to the local named by `inst.argval`,
    first giving the popped tracker that name as a hint.
    """
    tracker = self.pop()
    target_name = inst.argval
    tracker.set_name_hint(target_name)
    self.symbolic_locals[target_name] = tracker
|
|
|
def DELETE_FAST(self, inst):
    """Remove the local named by `inst.argval`; KeyError if it is not bound."""
    name = inst.argval
    del self.symbolic_locals[name]
|
|
|
# STORE_DEREF shares STORE_FAST's implementation: the value is stored in
# symbolic_locals under the instruction's argval.
STORE_DEREF = STORE_FAST
|
|
|
def LOAD_CLOSURE(self, inst):
    """Push a ClosureVariable carrying the name in `inst.argval`."""
    closure = ClosureVariable(name=inst.argval)
    self.push(closure)
|
|
|
def _load_const(self, inst): |
|
i = inst.arg |
|
if i is None: |
|
return ConstantVariable.create(value=inst.argval) |
|
val = self._constants_cache[i] |
|
if not val: |
|
self._constants_cache[i] = val = ConstantVariable.create(value=inst.argval) |
|
return val |
|
|
|
def LOAD_CONST(self, inst):
    """Push the constant that `_load_const` produces for `inst`."""
    constant = self._load_const(inst)
    self.push(constant)
|
|
|
def LOAD_GLOBAL(self, inst):
    """
    Push the global (or builtin) named by `inst.argval`.

    Lookup order: globals already stored during this trace, then the
    frame's globals, then builtins via load_builtin().  On Python 3.11+ an
    odd `inst.arg` pushes a NULL first.
    """
    if sys.version_info >= (3, 11) and inst.arg % 2:
        self.PUSH_NULL(inst)

    name = inst.argval

    recorder = self.exec_recorder
    if recorder:
        if name in self.f_globals:
            recorder.add_global_var(name, self.f_globals[name])
        else:
            assert name in self.f_builtins
            recorder.builtins[name] = self.f_builtins[name]

    # A global written earlier in this trace is read back through the
    # side-effect tracker.
    if name in self.symbolic_globals:
        side_effects = self.output.side_effects
        variable = side_effects[self.symbolic_globals[name]]
        self.push(side_effects.load_global(variable, name))
        return

    try:
        value = self.f_globals[name]
    except KeyError:
        return self.load_builtin(inst)

    builder = VariableBuilder(self, GlobalSource(name))
    self.push(builder(value))
|
|
|
def STORE_GLOBAL(self, inst):
    """
    Pop the top of stack and record it as a store to the global named by
    `inst.argval`.  Storing a RemovableHandleVariable is unimplemented.
    """
    value = self.pop()
    name = inst.argval
    source = GlobalSource(name)

    # First store to this name: create the object that identifies it.
    if name not in self.symbolic_globals:
        self.symbolic_globals[name] = object()

    side_effects = self.output.side_effects
    variable = side_effects.track_global_existing(source, self.symbolic_globals[name])

    if isinstance(value, RemovableHandleVariable):
        unimplemented("Storing handles in globals - NYI")
    self.output.side_effects.store_global(variable, name, value)
|
|
|
def import_source(self, module_name):
    """Create an alias to a module for use in guards"""
    if "torch_package" in module_name:
        imported = torch.package.package_importer._package_imported_modules
        value = imported[module_name]
        # Replace the characters ">", "<" and "." in the alias.
        alias = module_name
        for old, new in ((">", "_"), ("<", "_"), (".", "_dot_")):
            alias = alias.replace(old, new)
    else:
        value = importlib.import_module(module_name)
        alias = "__import_" + module_name.replace(".", "_dot_")

    global_scope = self.output.global_scope
    # The alias may already exist, but only for this very module object.
    assert alias not in global_scope or global_scope[alias] is value
    global_scope[alias] = value
    self.output.update_co_names(alias)
    return GlobalSource(alias)
|
|
|
def resolve_name(self, name, package, level):
    """
    Copied from the Cpython implementation of __import__
    Resolve a relative module name to an absolute one.
    https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902

    Raises:
        ImportError: `level` climbs above the top-level package.
    """
    # Each level beyond the first strips one trailing package component.
    components = package.rsplit(".", level - 1)
    if len(components) < level:
        raise ImportError("attempted relative import beyond top-level package")
    base = components[0]
    if not name:
        return base
    return f"{base}.{name}"
|
|
|
def calc_package(self):
    """
    Copied from the Cpython implementation of __import__
    https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090

    Returns the package name from `__package__`, else `__spec__.parent`,
    else one derived from `__name__` (with a warning).
    """
    frame_globals = self.f_globals
    package = frame_globals.get("__package__")
    spec = frame_globals.get("__spec__")

    if package is not None:
        if spec is not None and package != spec.parent:
            log.warning(
                "__package__ != __spec__.parent (%r != %r)",
                package,
                spec.parent,
                stacklevel=3,
            )
        return package

    if spec is not None:
        return spec.parent

    log.warning(
        "can't resolve package from __spec__ or __package__, "
        "falling back on __name__ and __path__",
        stacklevel=3,
    )
    package = self.f_globals["__name__"]
    # Without __path__ the last component of __name__ is dropped.
    if "__path__" not in self.f_globals:
        package = package.rpartition(".")[0]
    return package
|
|
|
def IMPORT_NAME(self, inst):
    """
    Perform the import named by `inst.argval` and push the module.

    The level and fromlist are popped as constants.  A module recorded under
    the replay name in the frame globals is reused; otherwise `__import__`
    runs and a guard source is created for the module.  Results that are
    not modules are unimplemented.
    """
    level_vt, fromlist_vt = self.popn(2)
    level = level_vt.as_python_constant()
    fromlist = fromlist_vt.as_python_constant()
    module_name = inst.argval

    recorded_name = (
        f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}"
    )
    if recorded_name in self.f_globals:
        value = self.f_globals[recorded_name]
        source = GlobalSource(recorded_name)
    else:
        try:
            value = __import__(
                module_name,
                fromlist=fromlist,
                level=level,
                globals=self.f_globals,
            )
        except ImportError:
            unimplemented("import a module that does not exist")

        # Relative imports are resolved to an absolute name first.
        if level != 0:
            pkg = self.calc_package()
            module_name = self.resolve_name(module_name, pkg, level)

        # With an empty fromlist __import__ returns the top-level package,
        # so the source has to name that package rather than the submodule.
        if fromlist:
            source_module = module_name
        else:
            source_module = module_name.partition(".")[0]
        source = self.import_source(source_module)

    if self.exec_recorder:
        self.exec_recorder.add_local_mod(recorded_name, value)

    if not istype(value, (types.ModuleType, DummyModule)):
        unimplemented(f"IMPORT_NAME {typestr(value)}")
    else:
        self.push(PythonModuleVariable(value, source=source))
|
|
|
def IMPORT_FROM(self, inst):
    """
    Duplicate the top of stack, then load the attribute named by `inst`
    from the duplicate.
    """
    self.DUP_TOP(inst)
    self._load_attr(inst)
|
|
|
def load_builtin_from_argval(self, argval):
    """
    Push the builtin called `argval`.

    Callables are wrapped with a source that indexes the builtins dict;
    anything else must be a builtin constant and is pushed as one.

    Raises:
        NameError: `argval` is not a builtin of this frame.
    """
    if argval not in self.f_builtins:
        raise NameError(f"name '{argval}' is not defined")
    builtin_value = self.f_builtins[argval]

    if not callable(builtin_value):
        assert is_builtin_constant(builtin_value)
        self.push(ConstantVariable.create(value=builtin_value))
        return

    builtins_source = GlobalSource(self.output.name_of_builtins_dict_key_in_fglobals)
    item_source = GetItemSource(builtins_source, argval)
    self.push(VariableBuilder(self, item_source)(builtin_value))
|
|
|
def load_builtin(self, inst):
    """Push the builtin named by `inst.argval`."""
    builtin_name = inst.argval
    self.load_builtin_from_argval(builtin_name)
|
|
|
def jump(self, inst):
    """Move the instruction pointer to the index of ``inst.target``."""
    destination = inst.target
    self.instruction_pointer = self.indexof[destination]
|
|
|
# Unconditional jump opcodes share the plain `jump` handler.
JUMP_FORWARD = jump
JUMP_ABSOLUTE = jump

# Conditional jump opcodes are built by `generic_jump(predicate, flag)`.
# The predicate is `operator.not_` for the *_FALSE variants and
# `operator.truth` for the *_TRUE variants; the flag is True only for the
# *_OR_POP variants.
# NOTE(review): presumably the flag controls whether the tested value stays
# on the stack when the jump is taken — confirm against generic_jump.
POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)
|
|
|
def SETUP_LOOP(self, inst):
    """Record a block-stack entry for the loop, keyed by the instruction's target."""
    entry = BlockStackEntry(inst, inst.target)
    self.block_stack.append(entry)
|
|
|
def SETUP_EXCEPT(self, inst):
    """Record a block-stack entry for the except block, keyed by the instruction's target."""
    entry = BlockStackEntry(inst, inst.target)
    self.block_stack.append(entry)
|
|
|
def POP_BLOCK(self, inst):
    """Discard the innermost block-stack entry; ``inst`` is unused."""
    blocks = self.block_stack
    blocks.pop()
|
|
|
def SETUP_WITH(self, inst):
    """Delegate to ``setup_or_before_with``, the handler shared with ``BEFORE_WITH``."""
    self.setup_or_before_with(inst)
|
|
|
def SETUP_FINALLY(self, inst):
    """Record a block-stack entry for the finally/handler block at ``inst.target``."""
    entry = BlockStackEntry(inst, inst.target)
    self.block_stack.append(entry)
|
|
|
def BEGIN_FINALLY(self, inst):
    """Push a raw ``None`` marker (not a VariableTracker) onto the stack."""
    marker = None
    self.push(marker)
|
|
|
def WITH_CLEANUP_START(self, inst):
    """Call the context manager's exit function for the no-exception path.

    Pops the exit function and the ``None`` marker, pushes the marker back,
    then pushes the result of calling exit with three ``None`` constants.
    """
    exit_fn, marker = self.popn(2)
    assert marker is None
    self.push(marker)
    none_args = [ConstantVariable.create(None)] * 3
    self.push(exit_fn.call_function(self, none_args, {}))
|
|
|
def WITH_CLEANUP_FINISH(self, inst):
    """Drop the two values left by ``WITH_CLEANUP_START`` and push a raw ``None``."""
    self.popn(2)
    marker = None
    self.push(marker)
|
|
|
def CALL_FINALLY(self, inst):
    """
    Push the index of the next instruction as a constant, then jump to the
    target of ``inst``.
    """
    return_index = self.indexof[self.next_instruction]
    self.push(ConstantVariable.create(return_index))
    self.jump(inst)
|
|
|
def END_FINALLY(self, inst):
    """Pop the top of stack; if it is a constant, resume at that instruction index.

    A constant here is the return index pushed by ``CALL_FINALLY``. Any other
    value is simply discarded.
    """
    top = self.pop()
    if not isinstance(top, ConstantVariable):
        return
    self.instruction_pointer = top.as_python_constant()
|
|
|
def POP_FINALLY(self, inst):
    """Remove one stack entry, keeping the top of stack when ``inst.argval`` is truthy.

    With a truthy ``argval`` the entry *below* the top is removed; otherwise
    the top itself is removed.
    """
    if inst.argval:
        kept = self.pop()
        self.pop()
        self.push(kept)
    else:
        self.pop()
|
|
|
def FOR_ITER(self, inst):
    """Advance the iterator on top of the stack by one step.

    On success the iterator and the produced value are pushed. On exhaustion
    control jumps to ``inst``'s target; on Python 3.12+ the iterator and a
    ``None`` constant are pushed first.
    """
    iterator = self.pop().realize()
    try:
        produced = iterator.next_variable(self)
        self.push(iterator)
        self.push(produced)
    except (StopIteration, exc.UserStopIteration):
        if sys.version_info >= (3, 12):
            # END_FOR pops two values, so leave two on the stack for it.
            self.push(iterator)
            self.push(ConstantVariable.create(None))
        self.jump(inst)
|
|
|
def RAISE_VARARGS(self, inst):
    """Trace a ``raise`` statement.

    ``inst.arg`` selects the form:
      * 0 (bare re-raise) and 2 (``raise ... from ...``) are unsupported.
      * 1 raises the value on top of the stack.

    Raises:
        exc.UserStopIteration: when the raised value is the ``StopIteration``
            builtin or a ``StopIterationVariable``.
        exc.ObservedException: when the raised value is an ``ExceptionVariable``.
    """
    if inst.arg == 0:
        unimplemented("re-raise")
    elif inst.arg == 1:
        val = self.pop()

        # StopIteration gets its own exception type; FOR_ITER catches it.
        if (
            isinstance(val, BuiltinVariable) and val.fn is StopIteration
        ) or isinstance(val, variables.StopIterationVariable):
            raise exc.UserStopIteration

        # `raise SomeError` with a bare class: call it with no arguments to
        # obtain the instance.
        if isinstance(val, variables.BuiltinVariable):
            val = val.call_function(self, [], {})

        # Remember the exception so the handler can push it later.
        self.exn_vt_stack.append(val)

        if isinstance(val, variables.ExceptionVariable):
            raise exc.ObservedException(f"raised exception {val}")
        # Fix: report the raised value. The original interpolated `exc`,
        # which is the imported exceptions module, not the value.
        unimplemented(f"raise {val}")
    else:
        unimplemented("raise ... from ...")
|
|
|
def exception_handler(self):
    """Transfer control to the handler for the exception on ``exn_vt_stack``.

    Python 3.11+ uses the current instruction's exception-table entry; older
    versions use ``self.block_stack``. When no handler exists the value stack
    is cleared and the exception is propagated: ``Unsupported`` if ``self`` is
    exactly an ``InstructionTranslator``, otherwise ``exc.ObservedException``.
    """
    if sys.version_info >= (3, 11):
        exn_tab_entry = self.current_instruction.exn_tab_entry
        if exn_tab_entry:
            # 1) Pop values until the stack depth matches the handler's depth.
            while len(self.stack) > exn_tab_entry.depth:
                self.pop()

            # 2) 'lasti' entries would need the raising offset pushed.
            if exn_tab_entry.lasti:
                # NOTE(review): if `unimplemented` always raises, the push
                # below is unreachable — confirm before relying on it.
                unimplemented("lasti=True while exception handling")
                self.push(
                    variables.ConstantVariable(self.current_instruction.offset)
                )

            # 3) Push the most recently raised exception.
            assert len(self.exn_vt_stack)
            self.push(self.exn_vt_stack[-1])

            # 4) Jump to the handler.
            self.jump(exn_tab_entry)
        else:
            # No handler in this frame: clear the stack and propagate.
            self.stack.clear()
            if type(self) is InstructionTranslator:
                raise Unsupported("Observed exception")
            raise exc.ObservedException
    else:
        if len(self.block_stack):
            assert len(self.exn_vt_stack)
            exception_var = self.exn_vt_stack[-1]

            block_stack_entry = self.block_stack.pop()

            # Unwind EXCEPT_HANDLER entries, dropping the three stack values
            # (tb, value, type) pushed for each one.
            while block_stack_entry.inst.opname == "EXCEPT_HANDLER":
                self.popn(3)
                if len(self.block_stack) == 0:
                    unimplemented(
                        "exception is raised when block stack " "is empty"
                    )
                block_stack_entry = self.block_stack.pop()

            if block_stack_entry.inst.opname != "SETUP_FINALLY":
                unimplemented(
                    "exception is raised when top of the block stack "
                    "is not exception handler (e.g. try .. with .. except). "
                    f"Current TOS is {block_stack_entry.inst}"
                )

            # Push a dummy EXCEPT_HANDLER block entry; POP_EXCEPT expects to
            # find it on top of the block stack.
            except_handler_inst = Instruction(1e6, "EXCEPT_HANDLER", None, 0)
            self.block_stack.append(BlockStackEntry(except_handler_inst, None))

            # Push the previous exception as (tb, value, type).
            if len(self.exn_vt_stack) >= 2:
                old_exception = self.exn_vt_stack[-2]

                # The traceback slot is filled with an UnknownVariable.
                self.push(variables.UnknownVariable())
                self.push(old_exception)
                self.push(variables.BuiltinVariable(old_exception.exc_type))
            else:
                # No previous exception: three None constants.
                self.push(variables.ConstantVariable(None))
                self.push(variables.ConstantVariable(None))
                self.push(variables.ConstantVariable(None))

            # Push the new exception as (tb, value, type).
            self.push(variables.UnknownVariable())
            self.push(exception_var)
            self.push(variables.BuiltinVariable(exception_var.exc_type))

            # Jump to the handler target.
            self.jump(block_stack_entry)
        else:
            # No handler in this frame: clear the stack and propagate.
            self.stack.clear()
            if type(self) is InstructionTranslator:
                raise Unsupported("Observed exception")
            raise exc.ObservedException
|
|
|
def PUSH_EXC_INFO(self, inst):
    """Insert the most recently raised exception just below the top of the stack."""
    top = self.pop()
    assert len(self.exn_vt_stack)
    current_exception = self.exn_vt_stack[-1]
    self.push(current_exception)
    self.push(top)
|
|
|
def POP_EXCEPT(self, inst):
    """Leave an exception handler and forget the exception being handled.

    On 3.11+ the exception is popped from the value stack. On older versions
    the ``EXCEPT_HANDLER`` block entry and its three stack values are removed.
    In both cases the top of ``exn_vt_stack`` is dropped at the end.
    """
    if sys.version_info >= (3, 11):
        handled = self.pop()
        assert isinstance(handled, variables.ExceptionVariable)
    else:
        assert len(self.block_stack) > 0
        top_block = self.block_stack[-1]
        if top_block.inst.opname != "EXCEPT_HANDLER":
            raise AssertionError(
                "Bug in Dynamo tracing of exception handling."
                "Top of the block stack is not EXCEPT_HANDLER."
            )
        self.block_stack.pop()

        self.popn(3)

    # This exception is handled, so it no longer belongs on the stack.
    assert len(self.exn_vt_stack)
    self.exn_vt_stack.pop()
|
|
|
def check_if_exc_matches(self):
    """Return whether the exception below the top of stack matches the expected type(s).

    Pops the expected exception type (a ``BuiltinVariable``) or tuple of types
    and leaves the exception itself on the stack. Unsupported operand kinds go
    through ``unimplemented``.
    """
    assert len(self.stack) >= 2
    expected_exc_types = self.pop()
    exc_instance = self.stack[-1]

    if not isinstance(expected_exc_types, (BuiltinVariable, TupleVariable)):
        unimplemented(
            f"except has an unsupported types of objects {expected_exc_types}"
        )

    if sys.version_info >= (3, 11):
        if not isinstance(exc_instance, variables.ExceptionVariable):
            unimplemented(
                f"except expects to recieve an object of exception type but received {exc_instance}"
            )

    candidates = (
        expected_exc_types.items
        if isinstance(expected_exc_types, TupleVariable)
        else [expected_exc_types]
    )

    for candidate in candidates:
        if not isinstance(candidate, BuiltinVariable):
            unimplemented(
                f"except has an unsupported types of object {candidate}"
            )
        # The raised object is either an exception instance or a bare class.
        if (
            isinstance(exc_instance, variables.ExceptionVariable)
            and issubclass(exc_instance.exc_type, candidate.fn)
        ) or (
            isinstance(exc_instance, variables.BuiltinVariable)
            and issubclass(exc_instance.fn, candidate.fn)
        ):
            return True

    return False
|
|
|
def CHECK_EXC_MATCH(self, inst):
    """Push a constant boolean saying whether the exception matches."""
    matched = self.check_if_exc_matches()
    self.push(variables.ConstantVariable(matched))
|
|
|
def JUMP_IF_NOT_EXC_MATCH(self, inst):
    """Jump to ``inst``'s target unless the exception matches the expected type(s)."""
    if self.check_if_exc_matches():
        return
    self.jump(inst)
|
|
|
def COMPARE_OP(self, inst):
    """Dispatch a comparison.

    ``"exception match"`` goes to ``CHECK_EXC_MATCH``; every other operator is
    looked up in ``compare_op_handlers`` and applied to the top two values.
    """
    if inst.argval == "exception match":
        self.CHECK_EXC_MATCH(inst)
        return
    handler = compare_op_handlers[inst.argval]
    self.push(handler(self, self.popn(2), {}))
|
|
|
def GET_ITER(self, inst):
    """Replace the top of the stack with the result of calling ``iter`` on it."""
    iterable = self.pop()
    self.call_function(BuiltinVariable(iter), [iterable], {})
|
|
|
@break_graph_if_unsupported(push=1)
def CALL_FUNCTION(self, inst):
    """Call the callable sitting below ``inst.argval`` positional arguments."""
    positional = self.popn(inst.argval)
    callee = self.pop()
    self.call_function(callee, positional, {})
|
|
|
@break_graph_if_unsupported(push=1)
def CALL_FUNCTION_EX(self, inst):
    """Trace ``fn(*args)`` / ``fn(*args, **kwargs)``.

    ``inst.argval`` is 0 for the args-only form and 1 when a kwargs mapping is
    also on the stack. The args must be (or unpack into) a list-like variable
    and the kwargs must be a ``ConstDictVariable``; otherwise the call goes
    through ``unimplemented``.
    """
    kwargsvars: VariableTracker
    has_kwargs = inst.argval
    if has_kwargs == 0:
        kwargsvars = ConstDictVariable({})
        argsvars = self.pop()
    elif has_kwargs == 1:
        kwargsvars = self.pop()
        argsvars = self.pop()
    else:
        unimplemented("CALL_FUNCTION_EX")
    fn = self.pop()
    if sys.version_info >= (3, 11):
        # On 3.11+ a NULL sits below the callable.
        null = self.pop()
        assert isinstance(null, NullVariable)

    # Special case: tensor.view(*x) with a single constant/tensor argument is
    # rewritten to tensor.view((x,)).
    is_tensor_view_call = (
        isinstance(fn, GetAttrVariable)
        and isinstance(fn.obj, TensorVariable)
        and fn.name == "view"
        and isinstance(argsvars, (ConstantVariable, TensorVariable))
    )
    if is_tensor_view_call:
        argsvars = TupleVariable([argsvars])

    if not isinstance(
        argsvars, BaseListVariable
    ) and argsvars.has_unpack_var_sequence(self):
        argsvars = TupleVariable(argsvars.unpack_var_sequence(self))

    if not (
        isinstance(argsvars, BaseListVariable)
        and isinstance(kwargsvars, ConstDictVariable)
    ):
        unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}")

    # Convert the kwargs variable into a dict with Python-constant keys.
    keyword_args = kwargsvars.keys_as_python_constant()
    self.call_function(fn, argsvars.items, keyword_args)
|
|
|
@break_graph_if_unsupported(push=1)
def CALL_FUNCTION_KW(self, inst):
    """Call a function with positional and keyword arguments.

    Stack layout, top first: a constant tuple of keyword names, then
    ``inst.argval`` argument values (keyword values last), then the callable.
    """
    argnames = self.pop()
    args = self.popn(inst.argval)
    fn = self.pop()
    assert isinstance(argnames, TupleVariable) and argnames.is_python_constant()
    argnames = argnames.as_python_constant()
    # Split by an explicit index rather than `-len(argnames)`. With an empty
    # names tuple the negative form becomes `args[:-0]`, which is `args[:0]`
    # and silently drops every positional argument.
    split = len(args) - len(argnames)
    args, kwargs_list = args[:split], args[split:]
    kwargs = dict(zip(argnames, kwargs_list))
    assert len(kwargs) == len(argnames)
    self.call_function(fn, args, kwargs)
|
|
|
def LOAD_METHOD_SUPER(self, inst):
    """Call ``super`` with two arguments, then load a method from the result.

    The attribute name is ``co_names[inst.argval[0]]``. Before 3.11 a plain
    attribute load is used; from 3.11 on it goes through ``LOAD_METHOD``.
    """
    self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
    name_index = inst.argval[0]
    attr_name = self.code_options["co_names"][name_index]
    attr_inst = dataclasses.replace(inst, argval=attr_name)
    if sys.version_info >= (3, 11):
        self.LOAD_METHOD(attr_inst)
    else:
        self._load_attr(attr_inst)
|
|
|
def LOAD_ATTR_SUPER(self, inst):
    """Call ``super`` with two arguments, then load ``co_names[inst.argval[0]]`` from it."""
    self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
    name_index = inst.argval[0]
    attr_name = self.code_options["co_names"][name_index]
    self._load_attr(dataclasses.replace(inst, argval=attr_name))
|
|
|
def LOAD_METHOD(self, inst):
    """Load an attribute and arrange the stack for a following method call.

    On 3.11+ the result is ``NULL, attr``; on older versions it is
    ``attr, None`` (a raw ``None`` placeholder that ``CALL_METHOD`` pops).
    """
    self._load_attr(inst)
    loaded = self.pop()
    if sys.version_info < (3, 11):
        self.push(loaded)
        self.push(None)
        return
    # NULL + callable form: the loaded attribute is called without an extra
    # self argument.
    self.PUSH_NULL(inst)
    self.push(loaded)
|
|
|
def CALL_METHOD(self, inst):
    """Call the method prepared by ``LOAD_METHOD`` with ``inst.argval`` positional args.

    The raw ``None`` placeholder between the callable and its arguments is
    popped and checked.
    """
    positional = self.popn(inst.argval)
    placeholder = self.pop()
    assert placeholder is None
    callee = self.pop()
    self.call_function(callee, positional, {})
|
|
|
def _load_attr(self, inst):
    """Replace the top of the stack with ``getattr(top, inst.argval)``, traced via the builtin."""
    target = self.pop()
    getattr_vt = BuiltinVariable(getattr)
    attr_name = ConstantVariable.create(inst.argval)
    self.push(getattr_vt.call_function(self, [target, attr_name], {}))
|
|
|
def LOAD_ATTR(self, inst):
    """Load an attribute; on 3.12+ an odd ``inst.arg`` means a method load instead."""
    is_method_load = sys.version_info >= (3, 12) and inst.arg % 2
    if is_method_load:
        self.LOAD_METHOD(inst)
    else:
        self._load_attr(inst)
|
|
|
def STORE_ATTR(self, inst):
    """Trace ``obj.attr = val``, breaking the graph if the store is unsupported.

    If this speculation point has already failed, the graph-break path is
    taken directly. Otherwise ``setattr`` is traced; on ``Unsupported`` the
    speculation is marked failed and analysis restarts, provided partial-graph
    compilation is allowed.
    """
    speculation = self.speculate()
    if speculation.failed:
        return self.store_attr_graph_break(inst)
    val, obj = self.popn(2)

    stores_non_constant_on_module = isinstance(
        obj, NNModuleVariable
    ) and not isinstance(val, ConstantVariable)
    if stores_non_constant_on_module:
        # Side effects on modules are not allowed during export.
        assert (
            not self.export
        ), f"Mutating module attribute {inst.argval} during export."

    try:
        BuiltinVariable(setattr).call_function(
            self, [obj, ConstantVariable.create(inst.argval), val], {}
        )
    except Unsupported as e:
        if not self.should_compile_partial_graph():
            raise
        log.debug("STORE_ATTR triggered compile", exc_info=True)
        # Re-categorise the failure as a graph break in the stats.
        e.remove_from_stats()
        e.add_to_stats("graph_break")
    else:
        return
    speculation.fail_and_restart_analysis()
|
|
|
def store_attr_graph_break(self, inst):
    """Compile the graph so far and emit the ``STORE_ATTR`` as real bytecode.

    After the subgraph, a copy of ``inst`` is emitted, its two operands are
    popped from the symbolic stack, and a resume call at the next instruction
    is emitted.
    """
    if not self.should_compile_partial_graph():
        unimplemented("should_compile_partial_graph=False")
    reason = GraphCompileReason("store_attr", [self.frame_summary()])
    self.output.compile_subgraph(self, reason=reason)
    self.output.add_output_instructions([copy.copy(inst)])
    self.popn(2)
    resume_insts = self.create_call_resume_at(self.next_instruction)
    self.output.add_output_instructions(resume_insts)
|
|
|
def DELETE_ATTR(self, inst):
    """Trace ``del obj.attr`` by calling the ``delattr`` builtin on the popped object."""
    target = self.pop()
    delattr_vt = BuiltinVariable(delattr)
    delattr_vt.call_function(
        self, [target, ConstantVariable.create(inst.argval)], {}
    )
|
|
|
def create_call_resume_at(self, offset):
    """Placeholder that always raises ``AssertionError`` naming the concrete type."""
    message = f"create_call_resume_at not overridden by subclass {type(self)}"
    raise AssertionError(message)
|
|
|
def should_compile_partial_graph(self) -> bool:
    """Placeholder that always raises ``AssertionError`` naming the concrete type."""
    message = f"should_compile_partial_graph not overridden by subclass {type(self)}"
    raise AssertionError(message)
|
|
|
@break_graph_if_unsupported(push=0)
def STORE_SUBSCR(self, inst):
    """Trace ``obj[key] = val`` by calling ``__setitem__`` on the container.

    The three operands are popped in stack order ``val, obj, key``. Nothing is
    pushed, so the call's return value is deliberately not kept.
    """
    val, obj, key = self.popn(3)
    obj.call_method(self, "__setitem__", [key, val], {})
|
|
|
def DELETE_SUBSCR(self, inst):
    """Trace ``del obj[key]`` by calling ``__delitem__`` on the container."""
    container, key = self.popn(2)
    container.call_method(self, "__delitem__", [key], {})
|
|
|
def BUILD_TUPLE(self, inst):
    """Collapse the top ``inst.argval`` stack values into one ``TupleVariable``."""
    elements = self.popn(inst.argval)
    self.push(TupleVariable(elements))
|
|
|
def BUILD_SLICE(self, inst):
    """Collapse the top ``inst.argval`` stack values into one ``SliceVariable``."""
    bounds = self.popn(inst.argval)
    self.push(SliceVariable(bounds))
|
|
|
def BUILD_LIST(self, inst):
    """Collapse the top ``inst.argval`` stack values into a new mutable ``ListVariable``."""
    elements = self.popn(inst.argval)
    new_list = ListVariable(elements, mutable_local=MutableLocal())
    self.push(new_list)
|
|
|
def BUILD_SET(self, inst):
    """Collapse the top ``inst.argval`` stack values into a new mutable ``SetVariable``.

    A testing-only config flag forces this opcode through ``unimplemented``.
    """
    if config.inject_BUILD_SET_unimplemented_TESTING_ONLY:
        unimplemented("missing: BUILD_SET")
    elements = self.popn(inst.argval)
    self.push(SetVariable(elements, mutable_local=MutableLocal()))
|
|
|
def BUILD_LIST_UNPACK(self, inst, cls=ListVariable):
    """Concatenate the top ``inst.argval`` sequences into one new ``cls`` container.

    A sequence whose ``unpack_var_sequence`` raises ``NotImplementedError``
    goes through ``unimplemented``.
    """
    sequences = self.popn(inst.argval)
    flattened = []
    for seq in sequences:
        try:
            flattened.extend(seq.unpack_var_sequence(self))
        except NotImplementedError:
            unimplemented(f"BUILD_LIST_UNPACK {seq}")
    self.push(cls(flattened, mutable_local=MutableLocal()))
|
|
|
def BUILD_TUPLE_UNPACK(self, inst):
    """Same as ``BUILD_LIST_UNPACK`` but the result is a ``TupleVariable``."""
    self.BUILD_LIST_UNPACK(inst, cls=TupleVariable)

# The *_WITH_CALL variant is handled identically.
BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK
|
|
|
def BUILD_MAP(self, inst):
    """Build a mutable ``ConstDictVariable`` from ``inst.argval`` key/value pairs.

    The popped values alternate key, value, key, value, ...
    """
    flat = self.popn(inst.argval * 2)
    keys = flat[::2]
    values = flat[1::2]
    mapping = dict(zip(keys, values))
    self.push(ConstDictVariable(mapping, mutable_local=MutableLocal()))
|
|
|
def BUILD_MAP_UNPACK(self, inst):
    """Merge the top ``inst.argval`` mappings into one new ``ConstDictVariable``.

    Each operand is first passed through the ``dict`` builtin; later mappings
    override earlier ones.
    """
    operands = self.popn(inst.argval)
    # Convert every operand before merging any of them.
    as_dicts = [
        BuiltinVariable(dict).call_function(self, [operand], {})
        for operand in operands
    ]
    merged = {}
    for dict_vt in as_dicts:
        assert isinstance(dict_vt, ConstDictVariable)
        merged.update(dict_vt.items)
    self.push(ConstDictVariable(merged, mutable_local=MutableLocal()))

# The *_WITH_CALL variant is handled identically.
BUILD_MAP_UNPACK_WITH_CALL = BUILD_MAP_UNPACK
|
|
|
def BUILD_CONST_KEY_MAP(self, inst):
    """Build a mutable dict from a constant tuple of keys and ``inst.argval`` values."""
    key_tuple = self.pop()
    values = self.popn(inst.argval)
    assert isinstance(key_tuple, TupleVariable)
    assert key_tuple.is_python_constant()

    key_vts = key_tuple.unpack_var_sequence(self)
    assert len(key_vts) == len(values)

    mapping = dict(zip(key_vts, values))
    self.push(ConstDictVariable(mapping, mutable_local=MutableLocal()))
|
|
|
def MAP_ADD(self, inst):
    """Store a popped key/value pair into the dict ``inst.arg`` slots down the stack."""
    key, value = self.popn(2)
    assert inst.argval > 0
    target = self.stack[-inst.arg].realize()
    assert isinstance(target, ConstDictVariable)
    target.call_method(self, "__setitem__", (key, value), {})
|
|
|
def SET_ADD(self, inst):
    """Add the popped value to the mutable set ``inst.arg`` slots down the stack.

    Returns whatever the set's ``add`` call returns.
    """
    element = self.pop()
    assert inst.argval > 0
    target = self.stack[-inst.arg]
    assert isinstance(target, SetVariable)
    assert target.mutable_local
    return target.call_method(self, "add", [element], {})
|
|
|
def SET_UPDATE(self, inst):
    """Update the mutable set ``inst.arg`` slots down the stack with the popped iterable."""
    addition = self.pop()
    assert inst.argval > 0
    target = self.stack[-inst.arg]
    assert isinstance(target, SetVariable)
    assert target.mutable_local
    target.call_method(self, "update", [addition], {})
|
|
|
def LIST_APPEND(self, inst):
    """Append the popped value to the mutable list ``inst.arg`` slots down the stack.

    The mutation is registered with the side-effect tracker before the list's
    items are changed.
    """
    element = self.pop()
    assert inst.argval > 0
    target = self.stack[-inst.arg].realize()
    assert isinstance(target, ListVariable)
    assert target.mutable_local
    self.output.side_effects.mutation(target)
    target.items.append(element)
|
|
|
def MAKE_FUNCTION(self, inst):
    """Build a ``NestedUserFunctionVariable`` from the code object on the stack.

    Before 3.11 the function name is popped from the stack. From 3.11 on it
    is taken from the code object's ``co_qualname``. Bits of ``inst.arg`` say
    which optional operands follow, popped in this order: closure (0x08),
    annotations (0x04), keyword defaults (0x02), positional defaults (0x01).
    """
    flags = inst.arg
    if sys.version_info < (3, 11):
        fn_name = self.pop()
        code = self.pop()
    else:
        code = self.pop()
        assert hasattr(code.value, "co_qualname")
        fn_name = ConstantVariable.create(value=code.value.co_qualname)

    # Pop order matters: it mirrors how the operands sit on the stack.
    closure = self.pop() if flags & 0x08 else None
    annotations = self.pop() if flags & 0x04 else None
    kwdefaults = self.pop() if flags & 0x02 else None
    defaults = self.pop() if flags & 0x01 else None

    self.push(
        NestedUserFunctionVariable(
            fn_name,
            code,
            self.f_globals,
            defaults,
            kwdefaults,
            annotations,
            closure,
            closure_scope=self,
        )
    )
|
|
|
def UNPACK_SEQUENCE(self, inst):
    """Unpack the top of the stack into ``inst.argval`` values, first element on top.

    Tensors and attributes of tensors are indexed element by element. Other
    values must support ``unpack_var_sequence``. A length mismatch goes through
    ``unimplemented``.
    """
    expected = inst.argval
    seq = self.pop()
    if isinstance(seq, TensorVariable):
        unpacked = seq.unpack_var_sequence(self, idxes=range(expected))
    elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable):
        # An attribute of a tensor (e.g. x.shape): index the proxy directly.
        proxy = getattr(seq.obj.as_proxy(), seq.name)
        unpacked = [wrap_fx_proxy(self, proxy[i]) for i in range(expected)]
    elif seq.has_unpack_var_sequence(self):
        unpacked = seq.unpack_var_sequence(self)
    else:
        unimplemented(f"UNPACK_SEQUENCE {seq}")
    if len(unpacked) != expected:
        unimplemented("UNPACK_SEQUENCE length mismatch")
    for element in reversed(unpacked):
        self.push(element)
|
|
|
def UNPACK_EX(self, inst):
    """Trace starred unpacking such as ``a, *rest, z = seq``.

    The low byte of ``inst.argval`` is the number of names before the star and
    the high byte the number after it. Values are pushed so the first target
    ends up on top; the starred part is pushed as a ``TupleVariable``.
    """
    assert 0 <= inst.argval <= 0xFFFF
    n_after, n_before = divmod(inst.argval, 0x100)
    seq = self.pop()
    if seq.has_unpack_var_sequence(self):
        values = list(seq.unpack_var_sequence(self))
        assert len(values) >= n_before + n_after
        tail_start = len(values) - n_after
        leading = values[:n_before]
        starred = values[n_before:tail_start]
        trailing = values[tail_start:]
        for element in reversed(trailing):
            self.push(element)
        self.push(TupleVariable(starred))
        for element in reversed(leading):
            self.push(element)
    else:
        unimplemented(f"UNPACK_EX {seq}")
|
|
|
def NOP(self, inst):
    """Do nothing."""
    return None
|
|
|
def POP_TOP(self, inst):
    """Discard the top of the stack."""
    _discarded = self.pop()
|
|
|
def ROT_TWO(self, inst):
    """Swap the two topmost stack values."""
    top, second = self.pop(), self.pop()
    for value in (top, second):
        self.push(value)
|
|
|
def ROT_THREE(self, inst):
    """Move the top value down to third place, lifting the second and third up by one."""
    top, second, third = self.pop(), self.pop(), self.pop()
    for value in (top, third, second):
        self.push(value)
|
|
|
def ROT_FOUR(self, inst):
    """Move the top value down to fourth place, lifting the next three up by one."""
    top, second, third, fourth = self.pop(), self.pop(), self.pop(), self.pop()
    for value in (top, fourth, third, second):
        self.push(value)
|
|
|
def DUP_TOP(self, inst):
    """Duplicate the top of the stack (the same object is pushed twice)."""
    top = self.pop()
    for _ in range(2):
        self.push(top)
|
|
|
def DUP_TOP_TWO(self, inst):
    """Duplicate the two topmost stack values, preserving their order."""
    top = self.pop()
    second = self.pop()
    for _ in range(2):
        self.push(second)
        self.push(top)
|
|
|
def FORMAT_VALUE(self, inst):
    """Trace one f-string replacement field.

    Bit 0x04 of ``inst.arg`` means a format spec is on the stack. The low two
    bits select a conversion: 1 = ``str``, 2 = ``repr``, 3 = ``ascii``. The
    value is finally formatted through ``str.format`` with ``"{:<spec>}"``.
    """
    flags = inst.arg
    if (flags & 0x04) == 0x04:
        fmt_spec = self.pop()
    else:
        fmt_spec = ConstantVariable.create("")

    value = self.pop()
    if isinstance(value, SymNodeVariable):
        # Symbolic numbers are turned into their string form up front.
        value = ConstantVariable.create(str(value.sym_num))

    conversions = {0x01: str, 0x02: repr, 0x03: ascii}
    converter = conversions.get(flags & 0x03)
    if converter is not None:
        value = BuiltinVariable(converter).call_function(self, [value], {})

    fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}")

    self.call_function(BuiltinVariable(str.format), [fmt_var, value], {})
|
|
|
def BUILD_STRING(self, inst):
    """Concatenate the top ``inst.arg`` string pieces into one ``StringFormatVariable``.

    Constants contribute a ``"{}"`` placeholder plus themselves as an
    argument. Existing ``StringFormatVariable`` pieces contribute their format
    string and arguments. Conflicting keyword names, or any other piece type,
    go through ``unimplemented``.
    """
    pieces: List[str] = []
    args: List[VariableTracker] = []
    kwargs: Dict[str, VariableTracker] = {}
    for part in self.popn(inst.arg):
        if isinstance(part, ConstantVariable):
            pieces.append("{}")
            args.append(part)
        elif isinstance(part, variables.StringFormatVariable):
            pieces.append(part.format_string)
            args.extend(part.sym_args)
            if not set(kwargs.keys()).isdisjoint(part.sym_kwargs.keys()):
                unimplemented(
                    f"BUILD_STRING key conflict {kwargs} & {part.sym_kwargs}"
                )
            kwargs.update(part.sym_kwargs)
        else:
            unimplemented(f"BUILD_STRING {part}")
    combined_format = "".join(pieces)
    self.push(variables.StringFormatVariable.create(combined_format, args, kwargs))
|
|
|
def IS_OP(self, inst):
    """Trace ``is`` (argval 0) or ``is not`` (argval 1) by delegating to ``COMPARE_OP``."""
    assert inst.argval in (0, 1)
    comparison = "is" if inst.argval == 0 else "is not"
    self.COMPARE_OP(create_instruction("COMPARE_OP", argval=comparison))
|
|
|
def CONTAINS_OP(self, inst):
    """Trace ``left in right`` (argval 0) or ``left not in right`` (argval 1).

    The result of ``right.__contains__(left)`` is pushed, and negated through
    ``UNARY_NOT`` for the ``not in`` form.
    """
    assert inst.argval in (0, 1)
    needle, container = self.popn(2)
    negate = inst.argval == 1
    self.push(container.call_method(self, "__contains__", [needle], {}))
    if negate:
        self.UNARY_NOT(inst)
|
|
|
def LIST_EXTEND(self, inst):
    """Extend the mutable list ``inst.arg`` slots down the stack with the popped iterable."""
    addition = self.pop()
    assert inst.argval > 0
    target = self.stack[-inst.arg]
    assert isinstance(target, ListVariable)
    assert target.mutable_local
    target.call_method(self, "extend", [addition], {})
|
|
|
def LIST_TO_TUPLE(self, inst):
    """Replace the top of the stack with the result of the ``tuple`` builtin applied to it."""
    tuple_vt = BuiltinVariable(tuple)
    self.push(tuple_vt.call_function(self, [self.pop()], {}))
|
|
|
def DICT_MERGE(self, inst):
    """Update the mutable dict ``inst.arg`` slots down the stack with the popped mapping."""
    addition = self.pop()
    assert inst.argval > 0
    target = self.stack[-inst.arg].realize()
    assert isinstance(target, ConstDictVariable)
    assert target.mutable_local
    target.call_method(self, "update", [addition], {})

# DICT_UPDATE is handled identically.
DICT_UPDATE = DICT_MERGE
|
|
|
def GEN_START(self, inst):
    """Discard the top of the stack."""
    _discarded = self.pop()
|
|
|
def GET_LEN(self, inst):
    """Push the length of the top-of-stack value without popping it.

    Python constants are measured directly; anything else goes through its
    ``__len__`` method.
    """
    subject = self.stack[-1]
    if not subject.is_python_constant():
        self.push(subject.call_method(self, "__len__", [], {}))
        return
    length = len(subject.as_python_constant())
    self.push(ConstantVariable.create(length))
|
|
|
def MATCH_MAPPING(self, inst):
    """Push a constant bool saying whether the subject's ``items`` is a mapping.

    The subject stays on the stack and must be a ``ConstDictVariable``.
    """
    subject = self.stack[-1]
    assert isinstance(subject, ConstDictVariable)
    is_mapping = isinstance(subject.items, collections.abc.Mapping)
    self.push(ConstantVariable.create(True if is_mapping else False))
|
|
|
def MATCH_SEQUENCE(self, inst):
    """Push a constant bool saying whether the constant subject is a sequence.

    ``str``, ``bytes`` and ``bytearray`` do not count as sequences here. The
    subject stays on the stack.
    """
    subject = self.stack[-1]
    assert subject.is_python_constant()
    subject_value = subject.as_python_constant()
    is_sequence = isinstance(
        subject_value, collections.abc.Sequence
    ) and not isinstance(subject_value, (str, bytes, bytearray))
    self.push(ConstantVariable.create(True if is_sequence else False))
|
|
|
def MATCH_KEYS(self, inst):
    """Look up the keys on top of the stack in the dict below them.

    If every key is present, a tuple of the matching values is pushed;
    otherwise a ``None`` constant is pushed. Before 3.11 an extra ``True`` /
    ``False`` constant follows to report success.
    """
    keys = self.stack[-1]
    subject = self.stack[-2]
    assert isinstance(subject, ConstDictVariable)

    pushes_flag = sys.version_info < (3, 11)
    if all(k in subject for k in keys):
        self.push(TupleVariable([subject.getitem_const(k) for k in keys]))
        if pushes_flag:
            self.push(ConstantVariable.create(True))
    else:
        self.push(ConstantVariable.create(None))
        if pushes_flag:
            self.push(ConstantVariable.create(False))
|
|
|
def LOAD_ASSERTION_ERROR(self, inst):
    """Push the ``AssertionError`` builtin."""
    builtin_name = "AssertionError"
    self.load_builtin_from_argval(builtin_name)
|
|
|
# Unary operator opcodes: each handler is built by `stack_op` from the
# matching function in the `operator` module.
UNARY_POSITIVE = stack_op(operator.pos)
UNARY_NEGATIVE = stack_op(operator.neg)
UNARY_NOT = stack_op(operator.not_)
UNARY_INVERT = stack_op(operator.invert)

# Binary operator opcodes. BINARY_MODULO and BINARY_REMAINDER both map to
# `operator.mod`. BINARY_SUBSCR is additionally wrapped in
# `break_graph_if_unsupported(push=1)`.
BINARY_POWER = stack_op(operator.pow)
BINARY_MULTIPLY = stack_op(operator.mul)
BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul)
BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv)
BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
BINARY_MODULO = stack_op(operator.mod)
BINARY_REMAINDER = stack_op(operator.mod)
BINARY_ADD = stack_op(operator.add)
BINARY_SUBTRACT = stack_op(operator.sub)
BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem))
BINARY_LSHIFT = stack_op(operator.lshift)
BINARY_RSHIFT = stack_op(operator.rshift)
BINARY_AND = stack_op(operator.and_)
BINARY_OR = stack_op(operator.or_)
BINARY_XOR = stack_op(operator.xor)

# In-place operator opcodes, using the `operator.i*` variants.
# INPLACE_MODULO and INPLACE_REMAINDER both map to `operator.imod`.
INPLACE_POWER = stack_op(operator.ipow)
INPLACE_MULTIPLY = stack_op(operator.imul)
INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul)
INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv)
INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv)
INPLACE_MODULO = stack_op(operator.imod)
INPLACE_REMAINDER = stack_op(operator.imod)
INPLACE_ADD = stack_op(operator.iadd)
INPLACE_SUBTRACT = stack_op(operator.isub)
INPLACE_LSHIFT = stack_op(operator.ilshift)
INPLACE_RSHIFT = stack_op(operator.irshift)
INPLACE_AND = stack_op(operator.iand)
INPLACE_XOR = stack_op(operator.ixor)
INPLACE_OR = stack_op(operator.ior)
|
|
|
|
|
def RESUME(self, inst):
    """Handle ``RESUME``.

    With ``arg == 0`` the instruction is recorded as a prefix instruction and
    no further prefix instructions are accepted. Any other ``arg`` must arrive
    after prefix collection has ended.
    """
    if inst.arg != 0:
        assert not self.accept_prefix_inst
        return
    self.append_prefix_inst(inst)
    self.accept_prefix_inst = False
|
|
|
if sys.version_info >= (3, 11):

    def BINARY_OP(self, inst):
        """Dispatch to the handler registered for ``inst.arg`` in ``_binary_op_lookup``."""
        handler = _binary_op_lookup[inst.arg]
        return handler(self, inst)
|
|
|
def PRECALL(self, inst):
    """Do nothing for ``PRECALL``."""
    return None
|
|
|
def KW_NAMES(self, inst):
    """Stash the keyword-name tuple at ``co_consts[inst.arg]`` for the next ``CALL``.

    The constant must be a tuple of ``str``, and no earlier names may still be
    pending.
    """
    names = self.code_options["co_consts"][inst.arg]
    assert isinstance(names, tuple)
    for entry in names:
        assert isinstance(entry, str)
    assert self.kw_names is None
    self.kw_names = ConstantVariable.create(value=names)
|
|
|
def PUSH_NULL(self, inst):
    """Push a fresh ``NullVariable``."""
    null_marker = NullVariable()
    self.push(null_marker)
|
|
|
@break_graph_if_unsupported(push=1)
def CALL(self, inst):
    """Trace the 3.11+ ``CALL`` opcode.

    ``inst.arg + 2`` values are popped. If the first is a ``NullVariable``,
    the second is the callable. Otherwise the first is the callable and the
    second is its leading (self) argument. Keyword names stashed by
    ``KW_NAMES`` apply to the trailing arguments.
    """
    contents = self.popn(inst.arg + 2)
    if isinstance(contents[0], NullVariable):
        fn = contents[1]
        args = []
    else:
        fn = contents[0]
        args = [contents[1]]
    kw_names = self.kw_names.value if self.kw_names else ()
    if kw_names:
        args = args + contents[2 : -len(kw_names)]
        kwargs_list = contents[-len(kw_names) :]
        kwargs = dict(zip(kw_names, kwargs_list))
        assert len(kwargs) == len(kw_names)
    else:
        args = args + contents[2:]
        kwargs = {}

    try:
        self.call_function(fn, args, kwargs)
    finally:
        # Always clear the pending names, even when the call raises.
        # Otherwise a later CALL could pick up stale keyword names, and
        # KW_NAMES would trip its "kw_names is None" assertion.
        self.kw_names = None
|
|
|
def COPY(self, inst):
    """Push the value ``inst.arg`` slots from the top of the stack (1 = top)."""
    depth = inst.arg
    self.push(self.stack[-depth])
|
|
|
def SWAP(self, inst):
    """Exchange the top of the stack with the value ``inst.arg`` slots down."""
    stack = self.stack
    other = -inst.arg
    stack[-1], stack[other] = stack[other], stack[-1]
|
|
|
# Backward jumps reuse the plain `jump` handler; the direction is already
# encoded in the instruction's target.
JUMP_BACKWARD = jump
JUMP_BACKWARD_NO_INTERRUPT = jump

# Direction-specific conditional jumps. The forward and backward variants get
# the same handler: `operator.truth` for *_IF_TRUE, `operator.not_` for
# *_IF_FALSE, always with a False second argument.
POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False)
POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False)
POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False)
POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False)
|
|
|
def CACHE(self, inst):
    """Do nothing for ``CACHE`` entries."""
    return None
|
|
|
def BEFORE_WITH(self, inst):
    """Delegate to ``setup_or_before_with``, the handler shared with ``SETUP_WITH``."""
    self.setup_or_before_with(inst)
|
|
|
def setup_or_before_with(self, inst):
    """Enter a ``with`` block (shared by ``SETUP_WITH`` and ``BEFORE_WITH``).

    Pops the context manager, which must be a ``ContextWrappingVariable``,
    optionally records a block-stack entry for it, then pushes the exit
    function followed by the result of ``ctx.enter``.
    """
    ctx = self.pop()
    if not isinstance(ctx, ContextWrappingVariable):
        unimplemented(f"{inst.opname} {ctx}")

    # Count how many generic context managers are currently active.
    if isinstance(ctx, GenericContextWrappingVariable):
        self.generic_context_manager_depth += 1

    exit = WithExitFunctionVariable(
        ctx,
        inst.target,
    )

    if sys.version_info >= (3, 11):
        # No block is pushed (target = None) when this instruction has an
        # exception-table entry and either the block stack is empty or that
        # entry's target differs from the top block's target. Otherwise the
        # target comes from the next instruction's exception-table entry.
        if inst.exn_tab_entry and (
            not self.block_stack
            or inst.exn_tab_entry.target is not self.block_stack[-1].target
        ):
            target = None
        else:
            target = self.next_instruction.exn_tab_entry.target
    else:
        target = inst.target

    if target:
        # The root translator also records the stack depth and the context
        # manager on the entry; other translators record only inst and target.
        if isinstance(self, InstructionTranslator):
            self.block_stack.append(
                BlockStackEntry(inst, target, len(self.stack), ctx)
            )
        else:
            self.block_stack.append(BlockStackEntry(inst, target))

    self.push(exit)
    self.push(ctx.enter(self))
|
|
|
def append_prefix_inst(self, inst):
    """Record ``inst`` as a prefix instruction; only legal while prefixes are accepted."""
    assert self.accept_prefix_inst
    collected = self.prefix_insts
    collected.append(inst)
|
|
|
def MAKE_CELL(self, inst):
    """Handle ``MAKE_CELL``.

    While prefix instructions are still accepted, or on Python before 3.12,
    the instruction is recorded as a prefix instruction. Otherwise the local
    named ``inst.argval``, which must currently be null, is replaced by a
    newly tracked cell.
    """
    is_prefix = sys.version_info < (3, 12) or self.accept_prefix_inst
    if is_prefix:
        self.append_prefix_inst(inst)
        return
    name = inst.argval
    assert isinstance(self.symbolic_locals[name], NullVariable)
    self.symbolic_locals[name] = self.output.side_effects.track_cell_new()
|
|
|
def COPY_FREE_VARS(self, inst):
    """Record ``COPY_FREE_VARS`` as a prefix instruction."""
    self.append_prefix_inst(inst)
|
|
|
def RETURN_GENERATOR(self, inst):
    """Record ``RETURN_GENERATOR`` as a prefix instruction."""
    self.append_prefix_inst(inst)
|
|
|
|
|
|
|
|
|
|
|
def END_FOR(self, inst):
    """Discard the two topmost stack values."""
    _discarded = self.popn(2)
|
|
|
def LOAD_FAST_CHECK(self, inst):
    """Load a local, going through ``unimplemented`` first if it is currently null."""
    current = self.symbolic_locals[inst.argval]
    if isinstance(current, NullVariable):
        unimplemented("LOAD_FAST_CHECK on uninitialized variable")
    self.LOAD_FAST(inst)
|
|
|
def LOAD_FAST_AND_CLEAR(self, inst):
    """Push the local named ``inst.argval`` (or null if unset), then reset it to null."""
    name = inst.argval
    if name in self.symbolic_locals:
        self.LOAD_FAST(inst)
    else:
        self.push(NullVariable())
    # The local is cleared whichever branch was taken.
    self.symbolic_locals[name] = NullVariable()
|
|
|
def LOAD_SUPER_ATTR(self, inst):
    """Call ``super`` with two arguments, then load from the result.

    The low bit of ``inst.arg`` selects a method load over a plain attribute
    load.
    """
    self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
    loader = self.LOAD_METHOD if inst.arg & 1 else self._load_attr
    loader(inst)
|
|
|
def CALL_INTRINSIC_1(self, inst):
    """Trace the supported one-argument intrinsics.

    Operand 5 is treated as unary plus and operand 6 converts the top of the
    stack into a ``TupleVariable``. Every other operand goes through
    ``unimplemented``.
    """
    operand = inst.argval
    if operand == 5:
        self.UNARY_POSITIVE(inst)
        return
    if operand == 6:
        unpacked = self.pop().unpack_var_sequence(self)
        self.push(TupleVariable(unpacked))
        return
    unimplemented(f"missing CALL_INTRINSIC_1 operand {inst.argval}")
|
|
|
def END_SEND(self, inst):
    """Remove the value just below the top of the stack."""
    stack = self.stack
    del stack[-2]
|
|
|
def is_non_empty_graph(self):
    """Return True once the output graph reports more than one call.

    After the first True answer, the method is replaced on the instance by a
    constant-True lambda, so later calls skip the count.
    """
    if self.output.count_calls() <= 1:
        return False
    self.is_non_empty_graph = lambda: True
    return True
|
|
|
def format_frame_summary(self, additional_stack_frames=None):
    """Format this frame's summary, followed by ``additional_stack_frames`` in reverse order."""
    extra = [] if additional_stack_frames is None else additional_stack_frames
    frames = [self.frame_summary()]
    frames.extend(reversed(extra))
    return "".join(traceback.format_list(frames))
|
|
|
def frame_summary(self):
    """Build a ``traceback.FrameSummary`` for the current line.

    The file name and function name fall back to ``"<unknown>"`` when
    ``f_code`` lacks them. Source-line lookup is disabled at construction.
    """
    code = self.f_code
    filename = getattr(code, "co_filename", "<unknown>")
    func_name = getattr(code, "co_name", "<unknown>")
    return traceback.FrameSummary(
        filename, self.lineno, func_name, lookup_line=False
    )
|
|
|
def store_global_weakref_by_id(self, prefix, value):
    """Install a weak reference to ``value`` as a global and guard that it stays alive.

    Returns the name under which the global was installed.
    """
    weak = weakref.ref(value)
    global_name = self.output.install_global_by_id(prefix, weak)
    alive_guard = GlobalWeakRefSource(global_name).make_guard(
        GuardBuilder.WEAKREF_ALIVE
    )
    install_guard(alive_guard)
    return global_name
|
|
|
@property
def fake_mode(self):
    """The ``fake_mode`` of the output graph's tracing context."""
    context = self.output.tracing_context
    return context.fake_mode
|
|
|
def find_symbolic_locals_name(self, tensor_variable):
    """Return the first local name bound to this exact object (by identity), or ``None``."""
    matches = (
        name
        for name, bound in self.symbolic_locals.items()
        if bound is tensor_variable
    )
    return next(matches, None)
|
|
|
@contextlib.contextmanager
def strict_translation_mode(self, check_fn: Callable[[VariableTracker], bool]):
    """
    Temporarily install ``check_fn`` as ``self.strict_checks_fn``.

    Strict mode applies per VariableTracker, according to what ``check_fn``
    returns for it. The previous value is restored on exit, even if the body
    raises.
    """
    previous_check = self.strict_checks_fn
    self.strict_checks_fn = check_fn
    try:
        yield
    finally:
        self.strict_checks_fn = previous_check
|
|
|
def speculate(self) -> SpeculationEntry:
    """Return the next speculation-log entry for the current file, line and instruction pointer."""
    filename = self.f_code.co_filename
    return self.speculation_log.next(
        filename, self.lineno, self.instruction_pointer
    )
|
|
|
def __init__(
    self,
    output: OutputGraph,
    instructions: List[Instruction],
    f_locals: Dict[str, Any],
    f_globals: Dict[str, Any],
    f_builtins: Dict[str, Any],
    code_options: Dict[str, Any],
    symbolic_locals: Dict[str, VariableTracker],
    symbolic_globals: Dict[str, VariableTracker],
    f_code: types.CodeType,
    export: bool,
    inline_depth: int,
    speculation_log: SpeculationLog,
):
    """Initialise the translator state for one code object.

    Args:
        output: graph that traced operations are recorded into.
        instructions: the instruction list to interpret.
        f_locals, f_globals, f_builtins: the frame's namespaces.
        code_options: code-object attributes kept as a dict.
        symbolic_locals, symbolic_globals: name -> VariableTracker maps.
        f_code: the code object being traced.
        export: stored as ``self.export``.
        inline_depth: stored as ``self.inline_depth``.
        speculation_log: log consulted by ``speculate()``.
    """
    super().__init__()
    self.speculation_log = speculation_log

    # Interpreter state: value stack, program counter, block stack, etc.
    self.output = output
    self.symbolic_locals = symbolic_locals
    self.symbolic_globals = symbolic_globals
    self.stack = []
    self.instruction_pointer = 0
    self.current_instruction = create_instruction("NOP")
    self.block_stack = []

    # Incremented by setup_or_before_with for generic context managers.
    self.generic_context_manager_depth = 0
    self.lineno = -1
    # Keyword names stashed by KW_NAMES and consumed by CALL.
    self.kw_names = None
    # Prefix instructions are collected until RESUME with arg 0 is seen.
    self.accept_prefix_inst = True
    self.prefix_insts = []
    # Exceptions raised while tracing; the most recent one is last.
    self.exn_vt_stack = []

    # Properties of the input code.
    self.instructions: List[Instruction] = instructions
    self.indexof: Dict[Instruction, int] = get_indexof(self.instructions)
    self.f_locals: Dict[
        str, Any
    ] = f_locals
    self.f_globals: Dict[str, Any] = f_globals
    self.f_builtins: Dict[str, Any] = f_builtins
    self.code_options: Dict[str, Any] = code_options
    self.f_code: types.CodeType = f_code

    # An execution recorder exists only when replay recording is enabled.
    if config.replay_record_enabled:
        self.exec_recorder = ExecutionRecorder(
            code=f_code, code_options=code_options
        )
    else:
        self.exec_recorder = None

    # NOTE(review): per the annotation, maps a key to (name, module type);
    # the exact key/name semantics are not visible here — confirm.
    self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {}

    self.export = export

    self.current_speculation = None

    # Installed temporarily by strict_translation_mode().
    self.strict_checks_fn = None

    if sys.version_info >= (3, 10):
        from .resume_execution import (
            CO_ASYNC_GENERATOR,
            CO_COROUTINE,
            CO_GENERATOR,
            CO_ITERABLE_COROUTINE,
        )

        # Generator-like code objects start with one value on the stack.
        if f_code.co_flags & (
            CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
        ):
            self.push(BuiltinVariable(None))

    self.inline_depth = inline_depth
    self.inconsistent_side_effects = False
    # One cache slot per entry in co_consts.
    self._constants_cache: List[Optional[VariableTracker]] = [None] * len(
        f_code.co_consts
    )
    # Register the file with linecache for lazy source loading.
    linecache.lazycache(f_code.co_filename, f_globals)
|
|
|
|
|
class InstructionTranslator(InstructionTranslatorBase):
    """Translator that owns the output graph for the frame being traced.

    Its ``__init__`` builds a fresh ``OutputGraph`` and passes
    ``inline_depth=0`` to the base class. It can also emit calls to
    generated resume functions (``create_call_resume_at``) and compile the
    graph on return (``_return``).
    """

    # Names handed to __init__ by the caller and stored unchanged; the
    # inlining translator's STORE_DEREF adds to this set through
    # ``self.output.root_tx``.
    mutated_closure_cell_contents: Set[str]
|
|
|
@staticmethod
def current_tx() -> "InstructionTranslator":
    """Return the translator currently stored on ``tls.current_tx``."""
    active_tx = tls.current_tx
    return active_tx
|
|
|
@contextlib.contextmanager
def set_current_tx(self):
    """Context manager that installs ``self`` as ``tls.current_tx``.

    Whatever was stored before (or None when the attribute was absent) is
    put back on exit, including when the body raises.
    """
    previous_tx = getattr(tls, "current_tx", None)
    tls.current_tx = self
    try:
        yield
    finally:
        tls.current_tx = previous_tx
|
|
|
def __init__(
    self,
    instructions: List[Instruction],
    f_code,
    f_locals,
    f_globals,
    f_builtins,
    code_options,
    compiler_fn,
    one_graph,
    export,
    export_constraints,
    mutated_closure_cell_contents: Set[str],
    frame_state,
    speculation_log: SpeculationLog,
):
    """Create the translator for a frame and seed its symbolic locals.

    A new ``OutputGraph`` is built from the frame information and handed to
    the base-class initializer with empty symbolic locals/globals and
    ``inline_depth=0``. Afterwards, with the tracing context and
    ``set_current_tx`` active, every variable name (``co_varnames`` followed
    by the cell/free variables not already listed) that is present in
    ``f_locals`` gets a lazily-built tracker sourced from ``LocalSource``.

    Args:
        instructions: instruction list to interpret.
        f_code: code object of the frame.
        f_locals / f_globals / f_builtins: the frame's namespaces.
        code_options: dict of code-object fields; ``co_varnames``,
            ``co_freevars``, ``co_filename`` and ``co_firstlineno`` are read.
        compiler_fn, export_constraints, frame_state: forwarded to
            ``OutputGraph``.
        one_graph: stored as ``self.one_graph``; must be truthy when
            ``export`` is truthy.
        export: when truthy, all symbolic locals are realized eagerly.
        mutated_closure_cell_contents: stored on the instance as-is.
        speculation_log: forwarded to the base class.
    """
    _step_logger()(
        logging.INFO,
        f"torchdynamo start tracing {f_code.co_name} {code_options['co_filename']}:{code_options['co_firstlineno']}",
    )
    super().__init__(
        output=OutputGraph(
            code_options,
            compiler_fn,
            self,
            export,
            export_constraints,
            frame_state,
            local_scope=f_locals,
            global_scope=f_globals,
            f_code=f_code,
        ),
        instructions=instructions,
        f_locals=f_locals,
        f_globals=f_globals,
        f_builtins=f_builtins,
        code_options=code_options,
        # Filled in below, once the tracing context is active.
        symbolic_locals={},
        symbolic_globals={},
        f_code=f_code,
        export=export,
        inline_depth=0,
        speculation_log=speculation_log,
    )

    self._throw_if_in_functorch()

    # Everything that may build trackers runs with the tracing context and
    # the thread-local "current tx" set.
    with tracing(self.output.tracing_context), self.set_current_tx():
        self.one_graph: bool = one_graph
        self.export = export
        self.mutated_closure_cell_contents = mutated_closure_cell_contents
        if self.export:
            assert (
                self.one_graph
            ), "Export without one graph - something has gone wrong."

        # Named ``local_names`` rather than ``vars`` so the builtin is not
        # shadowed; membership is tested against a set instead of rescanning
        # the list for every cell/free variable.
        local_names = list(code_options["co_varnames"])
        declared_names = set(local_names)
        cells_and_freevars = [
            x for x in self.cell_and_freevars() if x not in declared_names
        ]
        local_names.extend(cells_and_freevars)
        cells_and_freevars_set = set(cells_and_freevars)

        self.symbolic_locals = {
            k: variables.LazyVariableTracker.create(
                f_locals[k],
                source=LocalSource(k, cell_or_freevar=k in cells_and_freevars_set),
            )
            for k in local_names
            if k in f_locals
        }

        self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
        if export:
            # In export mode nothing is left lazy.
            self.symbolic_locals = variables.LazyVariableTracker.realize_all(
                self.symbolic_locals
            )

        # id() of each free variable's value in the real frame; used by
        # match_nested_cell to recognise the same object in a nested cell.
        self._freevars_ids = {
            name: id(f_locals[name])
            for name in self.code_options["co_freevars"]
            if name in f_locals
        }
|
|
|
def _throw_if_in_functorch(self):
    """Refuse to trace under an active vmap/grad/jvp functorch transform.

    Nothing happens when the functorch interpreter stack is empty, when its
    top entry is not one of Vmap/Grad/Jvp, or when the configured compiler
    is the "eager" backend; otherwise ``unimplemented`` is called.
    """
    eager = torch._dynamo.lookup_backend("eager")
    configured = self.output.compiler_fn
    # Look through one level of wrapping: use the object's own
    # ``compiler_fn`` attribute when it has one, else the object itself.
    compiler_fn = inspect.getattr_static(configured, "compiler_fn", configured)
    ci = torch._C._functorch.peek_interpreter_stack()
    transform_type = torch._C._functorch.TransformType
    forbidden_keys = (
        transform_type.Vmap,
        transform_type.Grad,
        transform_type.Jvp,
    )
    if ci is None or ci.key() not in forbidden_keys or compiler_fn is eager:
        return

    name = ci.key().name.lower()
    msg = f"torch.func.{name}(fn) requires the function to be inlined by dynamo"
    unimplemented(msg)
|
|
|
def get_example_value(self, source: Source):
    """Look up the real frame value that ``source`` refers to.

    ``LocalSource`` is resolved in ``f_locals`` and ``GlobalSource`` in
    ``f_globals``.

    Raises:
        KeyError: for any other source type, or when the name is missing.
    """
    if isinstance(source, LocalSource):
        scope, key = self.f_locals, source.local_name
    elif isinstance(source, GlobalSource):
        scope, key = self.f_globals, source.global_name
    else:
        raise KeyError
    return scope[key]
|
|
|
def run(self):
    """Run the translator by delegating to the base-class ``run``.

    The base-class result is intentionally not returned.
    """
    super().run()
|
|
|
def match_nested_cell(self, name, cell):
    """Map a closure cell of an inlined function onto this frame's local.

    Returns ``self.symbolic_locals[name]`` when the cell currently holds the
    very object (same ``id``) recorded for free variable ``name`` in
    ``self._freevars_ids``; returns None for an empty cell or a mismatch.
    """
    try:
        contents = cell.cell_contents
    except ValueError:
        # Reading an empty cell raises ValueError.
        return None

    if id(contents) == self._freevars_ids.get(name):
        return self.symbolic_locals[name]
    return None
|
|
|
def should_compile_partial_graph(self):
    """Decide whether a partial graph may be compiled at this point.

    On Python 3.11+ the answer is False when the current instruction has an
    exception-table entry and either the block stack is empty or the entry's
    target is not the target of the innermost block. Otherwise the answer is
    True only if every block can be restored, ``one_graph`` is not set, and
    ``generic_context_manager_depth`` is zero.
    """
    if sys.version_info >= (3, 11):
        exn_entry = self.current_instruction.exn_tab_entry
        if exn_entry:
            if not self.block_stack:
                return False
            if exn_entry.target is not self.block_stack[-1].target:
                return False

    for block in self.block_stack:
        if not block.can_restore():
            return False
    if self.one_graph:
        return False
    return self.generic_context_manager_depth == 0
|
|
|
def create_call_resume_at(self, inst):
    """Build the bytecode that hands execution over to a resume function.

    For a return instruction the result is just that return instruction.
    Otherwise a continuation code object is obtained from
    ``ContinueExecutionCache.lookup`` for ``inst.offset``, installed as a
    global under a unique name, and bytecode is emitted that loads it, loads
    the live locals, calls it and returns its result.

    Args:
        inst: the instruction at which execution should resume.

    Returns:
        The list of instructions produced (``cg.get_instructions()`` in the
        general case).
    """
    self.instruction_pointer = None

    if inst.opname == "RETURN_VALUE":
        return [create_instruction("RETURN_VALUE")]
    elif inst.opname == "RETURN_CONST":
        return [create_instruction("RETURN_CONST", argval=inst.argval)]

    # Locals that are read from ``inst`` onwards and are not cell/free vars
    # become arguments of the resume function.
    reads = livevars_analysis(self.instructions, inst)
    all_argnames = tuple(
        k
        for k in self.symbolic_locals.keys()
        if k in reads and k not in self.cell_and_freevars()
    )
    # The type checks below use ``type.__instancecheck__`` instead of
    # ``isinstance``. NOTE(review): presumably so that lazy trackers are not
    # realized by the check - confirm.
    argnames = tuple(
        k
        for k in all_argnames
        if not type.__instancecheck__(NullVariable, self.symbolic_locals[k])
    )
    argnames_null = tuple(
        k
        for k in all_argnames
        if type.__instancecheck__(NullVariable, self.symbolic_locals[k])
    )
    if sys.version_info < (3, 12):
        assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"

    cg = PyCodegen(self)

    # Context-manager variables sitting on the stack: remember their stack
    # index and target values, and emit code that swaps each one for the
    # value produced by ``reconstruct_type``.
    stack_ctx_vars = []
    for i, var in enumerate(self.stack):
        if type.__instancecheck__(ContextWrappingVariable, var):
            ctx = cast(ContextWrappingVariable, var)
            target_values = (
                () if ctx.target_values is None else tuple(ctx.target_values)
            )
            stack_ctx_vars.append((i, target_values))
            # Push the replacement, swap it into the old slot, drop the old
            # value.
            ctx.reconstruct_type(cg)
            cg.extend_output(create_swap(len(self.stack) - i + 1))
            cg.append_output(create_instruction("POP_TOP"))

    # Same treatment for context-manager variables held in live locals; the
    # replacement is stored back into the local.
    argnames_ctx_vars = []
    for name in argnames:
        if type.__instancecheck__(
            ContextWrappingVariable, var := self.symbolic_locals[name]
        ):
            ctx = cast(ContextWrappingVariable, var)
            target_values = (
                () if ctx.target_values is None else tuple(ctx.target_values)
            )
            argnames_ctx_vars.append((name, target_values))
            ctx.reconstruct_type(cg)
            cg.append_output(create_instruction("STORE_FAST", argval=name))

    # NULL entries on the stack are popped here and only their positions
    # are passed to the lookup below.
    # ``null_idxes`` holds the stack indices of NullVariables, ascending.
    null_idxes: List[int] = []
    if sys.version_info >= (3, 11):
        for i, var in enumerate(self.stack):
            if type.__instancecheck__(NullVariable, var):
                null_idxes.append(i)

        # Walk from the top of the stack; for each NULL emit SWAPs that
        # bring it to the top (accounting for NULLs already popped), then
        # pop it.
        null_cnt = 0
        for i, var in enumerate(reversed(self.stack)):
            if type.__instancecheck__(NullVariable, var):
                for j in range(2, i + 2 - null_cnt):
                    cg.append_output(create_instruction("SWAP", arg=j))
                cg.extend_output(cg.pop_null())
                null_cnt += 1

    # NULLs were popped above, so they do not count as arguments.
    stack_len = len(self.stack) - len(null_idxes)
    nargs = stack_len + len(argnames)

    name = unique_id(f"__resume_at_{inst.offset}")

    new_code: types.CodeType = ContinueExecutionCache.lookup(
        self.f_code,
        self.lineno,
        inst.offset,
        tuple(b.target.offset for b in self.block_stack),
        stack_len,
        argnames,
        argnames_null,
        tuple(b.resume_fn() for b in self.block_stack),
        tuple(stack_ctx_vars),
        tuple(argnames_ctx_vars),
        tuple(null_idxes),
    )

    # If the code being traced has an "orig_graphmodule" entry in its code
    # context, carry it over to the resume code. The stored entry is called
    # to obtain the module (the default yields None).
    orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
        "orig_graphmodule", lambda: None
    )()
    if orig_graphmodule_maybe is not None:
        code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
            orig_graphmodule_maybe
        )

    if new_code.co_freevars:
        # With free variables the function has to be created at runtime
        # with a closure; the raw code object is what gets installed.
        self.output.install_global_unsafe(name, new_code)
        cg.make_function_with_closure(name, new_code, True, stack_len)
    else:
        # Without free variables a ready-made function object is installed
        # under the unique name and loaded by name.
        self.output.install_global_unsafe(
            name, types.FunctionType(new_code, self.f_globals, name)
        )
        cg.extend_output(cg.load_function_name(name, True, stack_len))

    # Load the live locals as trailing arguments, call, and return.
    cg.extend_output([cg.create_load(k) for k in argnames])
    cg.extend_output(create_call_function(nargs, False))
    cg.append_output(create_instruction("RETURN_VALUE"))
    return cg.get_instructions()
|
|
|
def symbolic_locals_contain_module_class(self):
    """Return True if any symbolic local is a class deriving from nn.Module.

    Only ``UserDefinedClassVariable`` entries are inspected; their Python
    constant is tested with ``issubclass`` against ``torch.nn.Module``.
    """
    return any(
        isinstance(tracker, UserDefinedClassVariable)
        and issubclass(tracker.as_python_constant(), torch.nn.Module)
        for tracker in self.symbolic_locals.values()
    )
|
|
|
def _return(self, inst):
    """Shared handler for RETURN_VALUE and RETURN_CONST.

    Raises ``exc.SkipFrame`` when nothing worth compiling was traced (no
    calls counted, no inconsistent side effects, no nn.Module class among
    the locals, and not exporting). Otherwise compiles the subgraph, appends
    a copy of the return instruction to the output and raises
    ``ReturnValueOp``.
    """
    nothing_to_compile = (
        self.output.count_calls() == 0
        and not self.inconsistent_side_effects
        and not self.symbolic_locals_contain_module_class()
        and not self.export
    )
    if nothing_to_compile:
        raise exc.SkipFrame("because no content in function call")

    self.instruction_pointer = None
    _step_logger()(
        logging.INFO,
        f"torchdynamo done tracing {self.f_code.co_name} ({inst.opname})",
    )
    log.debug("%s triggered compile", inst.opname)
    self.output.compile_subgraph(
        self,
        reason=GraphCompileReason(
            "return_value", [self.frame_summary()], graph_break=False
        ),
    )

    # Re-emit the same kind of return instruction that triggered this.
    if inst.opname == "RETURN_VALUE":
        return_inst = create_instruction("RETURN_VALUE")
    else:
        return_inst = create_instruction("RETURN_CONST", argval=inst.argval)
    self.output.add_output_instructions([return_inst])
    raise ReturnValueOp
|
|
|
def RETURN_VALUE(self, inst):
    """Handle RETURN_VALUE through the shared ``_return`` path."""
    self._return(inst)
|
|
|
def RETURN_CONST(self, inst):
    """Handle RETURN_CONST through the shared ``_return`` path."""
    self._return(inst)
|
|
|
|
|
# On Python 3.11+ build a list of InstructionTranslator handlers with one
# entry per item of ``dis._nb_ops``, in the same order.
# NOTE(review): presumably indexed by the BINARY_OP argument - confirm at the
# use site.
if sys.version_info >= (3, 11):
    _binary_op_lookup = [
        getattr(
            InstructionTranslator,
            # Drop the first three characters of the op name; names that
            # contain "INPLACE" are used as-is after that, all others get a
            # "BINARY_" prefix.
            opname[3:] if "INPLACE" in opname else f"BINARY_{opname[3:]}",
        )
        for opname, _ in dis._nb_ops
    ]
|
|
|
|
|
class InliningInstructionTranslator(InstructionTranslatorBase):
    """Trace and inline a called method

    The translator shares its parent's output graph and speculation log and
    runs one level deeper (``inline_depth = parent.inline_depth + 1``). It
    never compiles partial graphs and cannot create resume functions.
    """

    # None until a return instruction is traced, then the returned tracker.
    # NOTE(review): RETURN_VALUE stores whatever was popped and RETURN_CONST
    # stores the result of ``_load_const``, so the value is not necessarily a
    # TensorVariable - the annotation looks narrower than actual use.
    symbolic_result: Optional[TensorVariable]
|
|
|
@classmethod
def inline_call(cls, parent, func, args, kwargs):
    """Inline ``func`` into ``parent`` with redirected failure counting.

    While the call runs, ``counters["unimplemented"]`` is temporarily
    replaced by ``counters["inline_call"]``; ``patch.dict`` restores the
    original mapping afterwards.
    """
    counter_override = {"unimplemented": counters["inline_call"]}
    with patch.dict(counters, counter_override):
        return cls.inline_call_(parent, func, args, kwargs)
|
|
|
@staticmethod
def check_inlineable(func):
    """Check whether ``func`` may be inlined and return the rule result.

    Calls ``unimplemented`` when the function has a bound ``__self__``, when
    the trace rules say it is skipped (unless its ``fn._origin`` is the
    allow-listed ``produce_trampoline_autograd_apply``), or when it is a
    user function marked with ``_torchdynamo_disable``.

    Returns:
        The ``trace_rules.check_verbose`` result, or a non-skipped
        ``SkipResult`` for the allow-listed case.
    """
    if func.has_self():
        unimplemented("inline with __self__")

    result = trace_rules.check_verbose(func, is_inlined_call=True)
    if result.skipped:
        from torch._dynamo.variables.misc import produce_trampoline_autograd_apply

        allowed_origins = (produce_trampoline_autograd_apply,)
        wrapped_fn = getattr(func, "fn", None)
        if hasattr(wrapped_fn, "_origin") and func.fn._origin in allowed_origins:
            # Allow-listed origin: report as not skipped.
            return trace_rules.SkipResult(
                False, "allowlist in dynamo known function"
            )

        if hasattr(func, "fn"):
            fn_qualname = func.fn.__qualname__
        else:
            fn_qualname = ""
        unimplemented(
            f"'inline in skipfiles: {fn_qualname} | {func.get_name()} {func.get_filename()}, {result.reason}'"
        )

    wraps_disabled_fn = isinstance(
        func, UserFunctionVariable
    ) and inspect.getattr_static(func.get_function(), "_torchdynamo_disable", False)
    if not wraps_disabled_fn:
        return result
    unimplemented(
        f"call torch._dynamo.disable() wrapped function {func.get_function()}"
    )
|
|
|
@staticmethod
def inline_call_(
    parent, func: VariableTracker, args: List[VariableTracker], kwargs
):
    """Trace ``func`` inline on behalf of ``parent`` and return its result.

    Steps: validate that the function may be inlined, bind the arguments,
    build a child translator (the generator flavour when the code object is
    a generator), run it, then merge state back into ``parent``.

    Args:
        parent: the translator performing the call.
        func: tracker for the callee; must be a ``UserFunctionVariable`` or
            ``NestedUserFunctionVariable``.
        args: positional argument trackers.
        kwargs: keyword argument trackers.

    Returns:
        The child's ``symbolic_result``; for generator code, a
        ``ListIteratorVariable`` over the items the child yielded.

    Raises:
        ArgsMismatchError: when binding the arguments raises ``TypeError``
            (chained to the original error).
        Unsupported: when the child raises ``exc.SkipFrame``.
        Any other exception raised while tracing is logged and re-raised.
    """
    if isinstance(func, SkipFunctionVariable):
        unimplemented("inline with functions in skip files")
    assert isinstance(
        func,
        (UserFunctionVariable, NestedUserFunctionVariable),
    )
    result = InliningInstructionTranslator.check_inlineable(func)
    assert result.skipped is False
    try:
        sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
    except TypeError as e:
        # Re-raise with call details; ``from e`` keeps the binding error as
        # the explicit cause.
        raise ArgsMismatchError(
            "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format(
                reason=str(e),
                func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
                args=[arg.python_type() for arg in args],
                kwargs=kwargs,
            ),
        ) from e

    for v in itertools.chain(sub_locals.values(), closure_cells.values()):
        if not isinstance(v, VariableTracker):
            unimplemented(f"unconverted arg {v}")

    code: types.CodeType = func.get_code()
    # __setitem__/__setattr__ are only inlined when the receiver is a
    # customized dict or a user-defined object.
    if code.co_name in ("__setitem__", "__setattr__") and not (
        args
        and isinstance(
            args[0],
            (variables.CustomizedDictVariable, variables.UserDefinedObjectVariable),
        )
    ):
        unimplemented(f"inline {code.co_name}")

    suffix = ""
    # Only pay for a disassembly dump when the "bytecode" artifact is on.
    if torch._logging._internal.log_state.is_artifact_enabled("bytecode"):
        suffix = f"\n{dis.Bytecode(code).dis()}"
    if sys.version_info >= (3, 11):
        cur_inst = parent.current_instruction
        parent_code = parent.f_code
        header = parent.get_line_of_code_header(lineno=cur_inst.positions.lineno)

        def get_trace_call_log_str():
            line = get_instruction_source_311(parent_code, cur_inst).rstrip()
            return f"TRACE inlined call {code.co_name} from {header}\n{line}"

        # LazyString defers building the message until it is rendered.
        trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
    log.debug("INLINING %s%s, %s", code, suffix, result.reason)

    # When the first argument tracks a GraphModule, record a weak reference
    # to it in the code context of its ``forward`` code object.
    if args and isinstance(args[0], NNModuleVariable):
        module = parent.output.get_submodule(args[0].module_key)
        if isinstance(module, torch.fx.GraphModule):
            code_context.get_context(module.forward.__code__)[
                "orig_graphmodule"
            ] = weakref.ref(module)

    # Both translator flavours take the same constructor arguments.
    inlining_generator = is_generator(code)
    tracer_cls = (
        InliningGeneratorInstructionTranslator
        if inlining_generator
        else InliningInstructionTranslator
    )
    tracer: InliningInstructionTranslator = tracer_cls(
        parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
    )

    # Propagate the parent's strict-check function to the child, if any.
    strict_ctx: Any = contextlib.nullcontext()
    if parent.strict_checks_fn:
        strict_ctx = tracer.strict_translation_mode(parent.strict_checks_fn)
    try:
        with strict_ctx:
            tracer.run()
    except exc.ObservedException as e:
        msg = f"Observed exception DURING INLINING {code} : {e}"
        # Hand the child's exception stack to the parent before the
        # exception continues to propagate.
        parent.exn_vt_stack.extend(tracer.exn_vt_stack)
        log.debug(msg)
        raise
    except exc.SkipFrame as e:
        msg = f"SKIPPED INLINING {code}: {e}"
        log.debug(msg)
        raise Unsupported(msg) from e
    except Exception:
        log.debug("FAILED INLINING %s", code)
        raise
    assert tracer.symbolic_result is not None
    func.export_freevars(parent, tracer)

    if tracer.f_globals is parent.f_globals:
        # Same globals namespace: merge the child's symbolic globals back.
        parent.symbolic_globals.update(tracer.symbolic_globals)

    parent.inconsistent_side_effects |= tracer.inconsistent_side_effects

    log.debug("DONE INLINING %s", code)

    if inlining_generator:
        assert isinstance(tracer, InliningGeneratorInstructionTranslator)
        assert tracer.symbolic_result.as_python_constant() is None
        return ListIteratorVariable(
            tracer.generated_items,
            mutable_local=MutableLocal(),
        )
    else:
        return tracer.symbolic_result
|
|
|
def __init__(
    self,
    parent: InstructionTranslatorBase,
    code: types.CodeType,
    symbolic_locals: Dict[str, VariableTracker],
    symbolic_globals: Dict[str, VariableTracker],
    closure_cells: Dict[str, VariableTracker],
    funcvar: BaseUserFunctionVariable,
):
    """Create a child translator for inlining ``code`` under ``parent``.

    Globals come from ``funcvar``; builtins are read from the globals'
    ``__builtins__`` entry (its ``__dict__`` when it is not already a
    dict). Output graph, export flag and speculation log are inherited from
    ``parent``, and the inline depth is one more than the parent's.
    """
    globals_ns = funcvar.get_globals()
    builtins_entry = globals_ns["__builtins__"]
    builtins_ns = (
        builtins_entry
        if isinstance(builtins_entry, dict)
        else builtins_entry.__dict__
    )

    instructions = cleaned_instructions(code)
    propagate_line_nums(instructions)
    code_options = {key: getattr(code, key) for key in get_code_keys()}

    super().__init__(
        output=parent.output,
        instructions=instructions,
        # No real frame exists for the inlined call.
        f_locals={},
        f_globals=globals_ns,
        f_builtins=builtins_ns,
        code_options=code_options,
        symbolic_locals=symbolic_locals,
        symbolic_globals=symbolic_globals,
        f_code=code,
        export=parent.export,
        inline_depth=parent.inline_depth + 1,
        speculation_log=parent.speculation_log,
    )
    self.parent = parent
    self.symbolic_result = None
    self.closure_cells = closure_cells
    # Work on a copy so the parent's module stack is left untouched.
    self.nn_module_stack = parent.nn_module_stack.copy()
|
|
|
@property
def fake_mode(self):
    """The parent translator's ``fake_mode``."""
    parent_tx = self.parent
    return parent_tx.fake_mode
|
|
|
def run_ctx_mgr(self):
    """Return ``TracingContext.current_frame`` for the parent's frame summary."""
    caller_frame = self.parent.frame_summary()
    return TracingContext.current_frame(caller_frame)
|
|
|
def STORE_DEREF(self, inst):
    """Store the top of stack into the cell named by ``inst.argval``.

    * Known closure cell: a ``ClosureVariable`` is written into the root
      translator's symbolic locals (only allowed for the root tracer); any
      other cell goes through ``side_effects.store_cell``.
    * Local ``NewCellVariable``: stored through ``side_effects.store_cell``.
    * Any other tracked local whose source name is not yet in the root's
      ``mutated_closure_cell_contents``: the name is recorded and
      ``exc.UnspecializeRestartAnalysis`` is raised.
    * Everything else is unsupported.
    """
    cell_name = inst.argval

    if cell_name in self.closure_cells:
        cell = self.closure_cells[cell_name]
        val = self.pop()
        if not isinstance(cell, ClosureVariable):
            self.output.side_effects.store_cell(cell, val)
            return
        if not self.output.is_root_tracer():
            unimplemented(
                "HigherOrderOperator: Mutating a variable not in the current scope (ClosureVariable)"
            )
        self.output.root_tx.symbolic_locals[cell.name] = val
        return

    maybe_cell = self.symbolic_locals.get(cell_name)
    if isinstance(maybe_cell, variables.NewCellVariable):
        self.output.side_effects.store_cell(
            self.symbolic_locals[cell_name], self.pop()
        )
        return

    if maybe_cell is not None:
        source_name = maybe_cell.source.name()
        mutated = self.output.root_tx.mutated_closure_cell_contents
        if source_name not in mutated:
            # Remember the mutated cell, then abort this attempt by raising
            # the restart exception.
            mutated.add(source_name)
            raise exc.UnspecializeRestartAnalysis
    unimplemented("write to __closure__ while inlining")
|
|
|
def LOAD_DEREF(self, inst):
    """Push the contents of the cell named by ``inst.argval``.

    Known closure cells are read from the root translator's symbolic locals
    (``ClosureVariable``) or via ``side_effects.load_cell``; a local
    ``NewCellVariable`` is also read via ``load_cell``; anything else is
    left to the base-class handler.
    """
    cell_name = inst.argval

    if cell_name in self.closure_cells:
        cell = self.closure_cells[cell_name]
        if isinstance(cell, ClosureVariable):
            loaded = self.output.root_tx.symbolic_locals[cell.name]
        else:
            loaded = self.output.side_effects.load_cell(cell)
        self.push(loaded)
        return

    maybe_sym_local = self.symbolic_locals.get(cell_name, None)
    if isinstance(maybe_sym_local, variables.NewCellVariable):
        self.push(self.output.side_effects.load_cell(maybe_sym_local))
        return

    super().LOAD_DEREF(inst)
|
|
|
def LOAD_CLOSURE(self, inst):
    """Push the cell object for ``inst.argval``.

    A cell bound at call time is pushed as-is; otherwise a fresh
    ``InlinedClosureVariable`` carrying the name is pushed.
    """
    cell_name = inst.argval
    assert cell_name in self.cell_and_freevars()
    if cell_name in self.closure_cells:
        captured = self.closure_cells[cell_name]
    else:
        captured = InlinedClosureVariable(name=cell_name)
    self.push(captured)
|
|
|
def check_replace_is_safe(self, oldvar):
    """Reject replacing ``oldvar`` unless its ``mutable_local`` is side-effect safe."""
    if is_side_effect_safe(oldvar.mutable_local):
        return
    unimplemented(
        "HigherOrderOperator: Mutating a variable not in the current scope (replace_all)"
    )
|
|
|
def should_compile_partial_graph(self):
    """Always False: an inlining translator never compiles a partial graph."""
    return False
|
|
|
def create_call_resume_at(self, offset):
    """Unsupported while inlining; always reports ``unimplemented``."""
    unimplemented("cant resume while inlining")
|
|
|
def RETURN_VALUE(self, inst):
    """Record the popped top of stack as the result and stop tracing.

    Clears the instruction pointer and raises ``ReturnValueOp``.
    """
    returned = self.pop()
    self.symbolic_result = returned
    self.instruction_pointer = None
    raise ReturnValueOp
|
|
|
def RETURN_CONST(self, inst):
    """Record the constant loaded for ``inst`` as the result and stop tracing.

    Clears the instruction pointer and raises ``ReturnValueOp``.
    """
    returned = self._load_const(inst)
    self.symbolic_result = returned
    self.instruction_pointer = None
    raise ReturnValueOp
|
|
|
def get_globals_source_and_value(self, name):
    """Describe where global ``name`` of the inlined function lives.

    When the globals dict has a ``__name__``, the owning module is resolved
    (from the torch.package importer's table if the module name contains
    "torch_package", otherwise by importing it) and the global is addressed
    as an attribute of that module. Without a ``__name__``, the globals dict
    itself is installed as a global and the name is addressed as an item.

    Returns:
        ``(globals_value, globals_tracker, source_for_name)``.
    """
    if "__name__" not in self.f_globals:
        globals_name = self.output.install_global_by_id(
            "___unnamed_scope", self.f_globals
        )
        globals_source = GlobalSource(globals_name)
        fglobals_value = self.f_globals
        fglobals_vt = VariableBuilder(self, globals_source)(fglobals_value)
        return fglobals_value, fglobals_vt, GetItemSource(globals_source, name)

    module_name = self.f_globals["__name__"]
    module_source = self.import_source(module_name)
    if "torch_package" in module_name:
        fglobals_value = torch.package.package_importer._package_imported_modules[module_name]
    else:
        fglobals_value = importlib.import_module(module_name)
    fglobals_vt = VariableBuilder(self, module_source)(fglobals_value)
    return fglobals_value, fglobals_vt, AttrSource(module_source, name)
|
|
|
def LOAD_GLOBAL(self, inst):
    """Load a global of the inlined function onto the stack.

    If the function's globals are the output graph's global scope, the
    base-class handler is used. Otherwise the value is taken from a pending
    attribute mutation when one exists, else from ``f_globals``; a missing
    name is handed to ``load_builtin``. On Python 3.11+ an odd ``inst.arg``
    pushes a NULL first.
    """
    if self.output.global_scope is self.f_globals:
        super().LOAD_GLOBAL(inst)
        return

    if sys.version_info >= (3, 11) and inst.arg % 2:
        self.PUSH_NULL(inst)

    name = inst.argval
    _, fglobals_vt, global_source = self.get_globals_source_and_value(name)

    side_effects = self.output.side_effects
    if side_effects.has_pending_mutation_of_attr(fglobals_vt, name):
        self.push(side_effects.load_attr(fglobals_vt, name))
        return

    try:
        value = self.f_globals[name]
    except KeyError:
        return self.load_builtin(inst)

    self.push(VariableBuilder(self, global_source)(value))
|
|
|
def STORE_GLOBAL(self, inst):
    """Store the top of stack into a global of the inlined function.

    When the function shares its globals dict with the parent translator,
    the base-class handler is used. Otherwise the store is recorded as an
    attribute side effect on the tracked globals object. Storing a
    ``RemovableHandleVariable`` is unsupported.
    """
    if self.f_globals is self.parent.f_globals:
        super().STORE_GLOBAL(inst)
        return

    value = self.pop()
    if isinstance(value, RemovableHandleVariable):
        unimplemented("Storing handles in globals - NYI")

    name = inst.argval
    fglobals_value, fglobals_vt, _ = self.get_globals_source_and_value(name)
    side_effects = self.output.side_effects
    tracked_globals = side_effects.track_object_existing(
        fglobals_value, fglobals_vt
    )
    side_effects.store_attr(tracked_globals, name, value)
|
|
|
|
|
class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
    """Inlining translator for generator code.

    Each traced ``YIELD_VALUE`` appends the yielded tracker to
    ``generated_items`` rather than suspending.
    """

    # Trackers yielded so far, in order; reset to an empty list in __init__.
    generated_items: List[VariableTracker]
|
|
|
def __init__(self, *args, **kwargs):
    """Initialize like the parent class and start with no generated items."""
    super().__init__(*args, **kwargs)
    self.generated_items = []
|
|
|
def YIELD_VALUE(self, inst: Instruction):
    """Collect the yielded value and leave a None constant in its place."""
    yielded = self.pop()
    self.generated_items.append(yielded)
    self.push(ConstantVariable.create(None))
|
|
|
def GET_YIELD_FROM_ITER(self, inst):
    """Make sure the top of stack is an iterator.

    A ``ListIteratorVariable`` is left alone; anything else is replaced by
    the result of calling the ``iter`` builtin on it.
    """
    tos = self.stack[-1]
    if isinstance(tos, ListIteratorVariable):
        return
    self.pop()
    self.push(BuiltinVariable(iter).call_function(self, [tos], {}))
|
|
|
def YIELD_FROM(self, inst):
    """Trace one step of ``yield from`` over the iterator below the top.

    The popped value must be a None constant (sending anything else is
    unsupported). If the iterator yields another item, it is recorded via
    ``YIELD_VALUE`` and the instruction pointer is moved back by one so this
    instruction runs again. When the iterator is exhausted, it is replaced
    on the stack by the exception's ``value``.
    """
    assert len(self.stack) >= 2
    sent = self.pop()
    sub_iter = self.stack[-1]
    if not (isinstance(sent, ConstantVariable) and sent.value is None):
        unimplemented("Unreachable sub-generator code")

    try:
        item = sub_iter.next_variable(self)
    except (StopIteration, exc.UserStopIteration) as ex:
        # Exhausted: drop the iterator and push the final value.
        self.pop()
        self.push(ConstantVariable.create(ex.value))
        return

    self.push(item)
    self.YIELD_VALUE(inst)

    # Step back one instruction so YIELD_FROM is executed again.
    assert (
        isinstance(self.instruction_pointer, int)
        and self.instruction_pointer > 0
    )
    self.instruction_pointer -= 1
|
|
|
def SEND(self, inst):
    """Trace ``SEND`` for the iterator just below the top of the stack.

    Supported receivers are ``ListIteratorVariable`` and user-defined
    objects whose value is a ``collections.abc.Iterator``; only sending a
    None constant is supported. The next item is pushed when there is one.
    On exhaustion the exception's ``value`` is pushed and the translator
    jumps to the instruction's target; before Python 3.12 the iterator is
    popped first.
    """
    assert len(self.stack) >= 2
    sent = self.pop()
    tos = self.stack[-1]

    if not (
        isinstance(tos, ListIteratorVariable)
        or (
            isinstance(tos, UserDefinedObjectVariable)
            and isinstance(tos.value, collections.abc.Iterator)
        )
    ):
        unimplemented(f"SEND {typestr(tos)}")
    elif not (isinstance(sent, ConstantVariable) and sent.value is None):
        unimplemented("Unreachable sub-generator code")
    else:
        try:
            item = tos.next_variable(self)
        except (StopIteration, exc.UserStopIteration) as ex:
            # Only versions before 3.12 pop the iterator here.
            if sys.version_info < (3, 12):
                self.pop()
            self.push(ConstantVariable.create(ex.value))
            self.jump(inst)
        else:
            self.push(item)
|
|