import collections
import contextlib
import copy
import dataclasses
import dis
import functools
import importlib
import inspect
import itertools
import linecache
import logging
import operator
import sys
import textwrap
import threading
import traceback
import types
import typing
import weakref
from typing import Any, Dict, List, NamedTuple, Optional, Set, Tuple, Type
from unittest.mock import patch

import torch
import torch._logging
from torch._guards import Checkpointable, tracing, TracingContext

from . import config, exc, logging as torchdynamo_logging, trace_rules, variables
from .bytecode_analysis import (
    get_indexof,
    JUMP_OPNAMES,
    livevars_analysis,
    propagate_line_nums,
)
from .bytecode_transformation import (
    cleaned_instructions,
    create_call_function,
    create_instruction,
    create_jump_absolute,
    Instruction,
    is_generator,
    unique_id,
)
from .code_context import code_context
from .codegen import PyCodegen
from .current_scope_id import current_scope_id
from .exc import ArgsMismatchError, BackendCompilerFailed, unimplemented, Unsupported
from .funcname_cache import get_funcname
from .guards import GuardBuilder, install_guard
from .output_graph import GraphCompileReason, OutputGraph, OutputGraphState
from .replay_record import DummyModule, ExecutionRecorder
from .resume_execution import ContinueExecutionCache, ReenterWith
from .source import (
    AttrSource,
    GetItemSource,
    GlobalSource,
    GlobalWeakRefSource,
    LocalSource,
    Source,
)
from .trace_rules import is_builtin_constant, is_forbidden
from .utils import (
    counters,
    get_fake_value,
    get_instruction_source_311,
    graph_break_dup_warning_checker,
    istype,
    LazyString,
    proxy_args_kwargs,
)
from .variables.base import (
    _is_top_level_scope,
    is_side_effect_safe,
    MutableLocal,
    typestr,
    VariableTracker,
)
from .variables.builder import VariableBuilder, wrap_fx_proxy
from .variables.builtin import BuiltinVariable
from .variables.constant import ConstantVariable
from .variables.ctx_manager import (
    ContextWrappingVariable,
    GenericContextWrappingVariable,
    WithExitFunctionVariable,
)
from .variables.dicts import ConstDictVariable, SetVariable
from .variables.functions import (
    BaseUserFunctionVariable,
    NestedUserFunctionVariable,
    SkipFunctionVariable,
    UserFunctionVariable,
    UserMethodVariable,
)
from .variables.lists import (
    BaseListVariable,
    ListIteratorVariable,
    ListVariable,
    SliceVariable,
    TupleVariable,
)
from .variables.misc import (
    ClosureVariable,
    GetAttrVariable,
    InlinedClosureVariable,
    NullVariable,
    PythonModuleVariable,
    UnknownVariable,
)
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
    supported_const_comparison_ops,
    supported_tensor_comparison_ops,
    SymNodeVariable,
    TensorVariable,
)
from .variables.user_defined import (
    RemovableHandleVariable,
    UserDefinedClassVariable,
    UserDefinedObjectVariable,
    UserDefinedVariable,
)

log = logging.getLogger(__name__)
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
trace_source_log = torch._logging.getArtifactLogger(__name__, "trace_source")
tls = threading.local()


@dataclasses.dataclass
class SpeculationEntry:
    filename: str
    lineno: int
    instruction_pointer: int
    failed: bool = False
    reason: Optional[GraphCompileReason] = None

    def fail_and_restart_analysis(self):
        """
        Start tracing of the current frame over again, and don't take this branch.
        """
        self.failed = True
        raise exc.SpeculationRestartAnalysis()


@dataclasses.dataclass
class SpeculationLog:
    """
    SpeculationLog replaces the prior copy_graphstate/restore_graphstate
    checkpointing. Rather than saving/restoring state, we restart the
    dynamo conversion process over from the beginning -- but when we
    hit the start of the speculation that failed, we instead generate
    a graph break.
    """

    entries: List[SpeculationEntry] = dataclasses.field(default_factory=list)
    index: int = 0

    def restart(self):
        self.index = 0

    def clear(self):
        self.entries.clear()
        self.index = 0

    def next(self, filename: str, lineno: int, instruction_pointer) -> SpeculationEntry:
        """
        Lookup or create a SpeculationEntry() that is shared across
        RestartAnalysis calls. Args are used only for debug checks.
        """
        if len(self.entries) == self.index:
            self.entries.append(SpeculationEntry(filename, lineno, instruction_pointer))
        entry = self.entries[self.index]
        self.index += 1
        assert (
            entry.instruction_pointer == instruction_pointer
            and entry.filename == filename
            and entry.lineno == lineno
        ), textwrap.dedent(
            f"""
            SpeculationLog diverged at {self.index} of {len(self.entries)}:
            - Run1: {entry.filename}:{entry.lineno} (ip={entry.instruction_pointer})
            - Run2: {filename}:{lineno} (ip={instruction_pointer})
            Please submit a bug report.
            """
        )
        return entry
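

# Illustrative sketch (hypothetical helper, not used by Dynamo itself): how a
# driver loop interacts with SpeculationLog. Each analysis attempt replays the
# log from the start; once an entry is marked failed, the replayed pass takes
# the graph-break path instead of speculating again.
def _speculation_log_example():  # pragma: no cover
    spec_log = SpeculationLog()
    while True:
        spec_log.restart()
        try:
            entry = spec_log.next("example.py", 1, 0)
            if not entry.failed:
                # First pass: speculation fails, mark the entry and restart.
                entry.fail_and_restart_analysis()
            break  # Replay pass: entry.failed is True, take the graph break.
        except exc.SpeculationRestartAnalysis:
            continue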


def _step_logger():
    return torchdynamo_logging.get_step_logger(log)


@dataclasses.dataclass
class BlockStackEntry:
    target: Instruction
    stack_index: Optional[int] = None
    with_context: Optional[ContextWrappingVariable] = None

    def can_restore(self):
        return self.with_context is not None

    def resume_fn(self):
        assert self.stack_index is not None
        if self.with_context and self.with_context.target_values:
            return ReenterWith(self.stack_index, tuple(self.with_context.target_values))
        else:
            return ReenterWith(self.stack_index)

    def exit(self, tx):
        assert self.with_context is not None
        return self.with_context.exit(tx)


class InstructionTranslatorGraphState(NamedTuple):
    output: OutputGraphState
    symbolic_locals: Dict[str, VariableTracker]
    stack: List[VariableTracker]
    block_stack: List[BlockStackEntry]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
    next_instruction: Optional[Instruction]
    lineno: int

    def diff(self, other: "InstructionTranslatorGraphState") -> Optional[str]:
        for k in self._fields:
            if k == "output":
                return self.output.diff(other.output, prefix=f"{k}.")
            sv = getattr(self, k)
            ov = getattr(other, k)
            if sv != ov:
                return f"{k} mismatch: {sv} != {ov}"
        return None


def stack_op(fn: typing.Callable[..., object]):
    nargs = len(inspect.signature(fn).parameters)
    fn_var = BuiltinVariable(fn)

    @functools.wraps(fn)
    def impl(self: "InstructionTranslatorBase", inst: Instruction):
        self.push(fn_var.call_function(self, self.popn(nargs), {}))

    return impl
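

# Quick sketch of what stack_op produces (illustrative only): the generated
# handler pops as many operands as `fn` takes and pushes one result, so e.g.
# BINARY_ADD below is stack_op(operator.add), whose two parameters mean two
# stack values are popped per instruction.
def _stack_op_arity_example():  # pragma: no cover
    assert len(inspect.signature(operator.add).parameters) == 2  # binary op
    assert len(inspect.signature(operator.not_).parameters) == 1  # unary op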


def _detect_and_normalize_assert_statement(
    self: "InstructionTranslatorBase",
    truth_fn: typing.Callable[[object], bool],
    push: bool,
):
    # Detect whether this jump instruction is part of an assert statement,
    # and normalize the assert by pushing a dummy error message when none
    # is given.
    #
    # A Python 3.9+ assertion has the following format:
    # 18 POP_JUMP_IF_TRUE 28
    # 20 LOAD_ASSERTION_ERROR
    # 22 LOAD_CONST 3 ('Assert message') -> optional instruction
    # 24 CALL_FUNCTION 1 -> optional instruction
    # 26 RAISE_VARARGS
    #
    # A Python 3.8 assertion has the following format:
    # 18 POP_JUMP_IF_TRUE 28
    # 20 LOAD_GLOBAL 0 (AssertionError)
    # 22 LOAD_CONST 3 ('Assert message') -> optional instruction
    # 24 CALL_FUNCTION 1 -> optional instruction
    # 26 RAISE_VARARGS 1
    if (truth_fn is not operator.truth) or push:
        return False

    assert isinstance(self.instruction_pointer, int)
    current_instruction_pointer = self.instruction_pointer
    inst = self.instructions[current_instruction_pointer]
    # Detect LOAD_ASSERTION_ERROR or LOAD_GLOBAL 0
    if sys.version_info < (3, 9):
        if inst.opname != "LOAD_GLOBAL" or inst.argval != "AssertionError":
            return False
    else:
        if inst.opname != "LOAD_ASSERTION_ERROR":
            return False

    current_instruction_pointer += 1

    # Use a dummy error message if it's hard to extract the real one
    error_msg = "assertion error"

    inst = self.instructions[current_instruction_pointer]
    # Detect RAISE_VARARGS or LOAD_CONST
    if inst.opname == "LOAD_CONST":
        if not isinstance(inst.argval, str):
            return False
        error_msg = inst.argval

        # if it is LOAD_CONST, it must be followed by CALL_FUNCTION
        # (PRECALL for Python 3.11+)
        current_instruction_pointer += 1
        inst = self.instructions[current_instruction_pointer]
        if inst.opname not in ("CALL_FUNCTION", "PRECALL"):
            return False

        # for Python 3.11+, PRECALL should be followed by CALL, then RAISE_VARARGS
        # for Python < 3.11, CALL_FUNCTION should be followed by RAISE_VARARGS
        current_instruction_pointer += 1
        if inst.opname == "PRECALL":
            current_instruction_pointer += 1
        inst = self.instructions[current_instruction_pointer]

    if inst.opname != "RAISE_VARARGS":
        return False

    self.push(ConstantVariable.create(error_msg))
    return True
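

# Illustrative helper (hypothetical, for exploration only): disassembling a
# plain assert shows the POP_JUMP_IF_* / LOAD_ASSERTION_ERROR / RAISE_VARARGS
# pattern that the matcher above walks through.
def _show_assert_bytecode():  # pragma: no cover
    def f(x):
        assert x, "Assert message"

    dis.dis(f)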


def generic_jump(truth_fn: typing.Callable[[object], bool], push: bool):
    def inner(self: "InstructionTranslatorBase", inst: Instruction):
        value: VariableTracker = self.pop()
        if (
            config.rewrite_assert_with_torch_assert
            and _detect_and_normalize_assert_statement(self, truth_fn, push)
        ):
            error_msg: VariableTracker = self.pop()
            # Skip over things like `assert True`
            if value.is_python_constant() and bool(value.as_python_constant()):
                self.jump(inst)
                return

            # TODO: maybe respect users' DtoH sync intention here later?
            # Manually insert torch._assert_async instead of a Python assert
            # and jump over the assert-related instructions, as we don't need
            # them anymore.

            # If the asserted value is already a Tensor, there is no need to
            # call scalar_tensor.
            if isinstance(value, TensorVariable):
                self.output.create_proxy(
                    "call_function",
                    torch._assert_async,
                    *proxy_args_kwargs((value, error_msg), {}),
                )
                self.jump(inst)
                return

            if isinstance(value, SymNodeVariable):
                # If the assertion is a plain shape expression, just install
                # a guard and bail out.
                sym_expr = value.sym_num
                if not isinstance(sym_expr, torch.SymBool):
                    sym_expr = sym_expr != 0

                result = torch.fx.experimental.symbolic_shapes.expect_true(sym_expr)
                if not result:
                    raise unimplemented(
                        "Assertion failed on symbolic shapes. Did you make sure eager mode succeeds?"
                    )
                self.jump(inst)
                return

            scalar_to_tensor_proxy = self.output.create_proxy(
                "call_function", torch.scalar_tensor, *proxy_args_kwargs((value,), {})
            )

            scalar_to_tensor = wrap_fx_proxy(
                self,
                scalar_to_tensor_proxy,
                example_value=get_fake_value(scalar_to_tensor_proxy.node, self),
            )

            self.output.create_proxy(
                "call_function",
                torch._assert_async,
                *proxy_args_kwargs((scalar_to_tensor, error_msg), {}),
            )
            self.jump(inst)
            return

        if value.is_python_constant():
            if truth_fn(value.as_python_constant()):
                push and self.push(value)
                self.jump(inst)
        elif (
            isinstance(value, (TensorVariable)) and self.should_compile_partial_graph()
        ):
            # compile a partial subgraph prefix then jump into user code
            if self.has_backedge():
                msg = (
                    "Skipping frame because there is a graph break in a for/while loop\n"
                    f"{self.frame_summary()}"
                )
                log.info(msg)
                raise exc.SkipFrame(msg)

            self.push(value)
            log.debug("generic_jump triggered compile")
            self.output.compile_subgraph(
                self,
                reason=GraphCompileReason(
                    f"generic_jump {typestr(value)}", [self.frame_summary()]
                ),
            )
            self.pop()

            if_next = self.create_call_resume_at(self.next_instruction)
            push and self.push(value)
            if_jump = self.create_call_resume_at(inst.target)

            self.output.add_output_instructions(
                [create_instruction(inst.opname, target=if_jump[0])] + if_next + if_jump
            )
        elif isinstance(value, NNModuleVariable):
            # Equivalent of "self.nn_module is not None"
            mod = self.output.get_submodule(value.module_key)
            if truth_fn(mod):
                push and self.push(value)
                self.jump(inst)
        elif isinstance(value, UserDefinedObjectVariable):
            x = value.var_getattr(self, "__bool__")
            # if __bool__ is missing, try __len__ to infer a truth value.
            if isinstance(x, GetAttrVariable):
                x = value.var_getattr(self, "__len__")

            # __bool__ or __len__ is a function
            if isinstance(x, UserMethodVariable):
                result = x.call_function(self, [], {})
                if isinstance(result, ConstantVariable) and isinstance(
                    result.value, (bool, int)
                ):
                    if truth_fn(result.value):
                        push and self.push(value)
                        self.jump(inst)
                else:
                    unimplemented(
                        "generic_jump on UserDefined with __bool__ returning non-constant"
                    )
            # __bool__ or __len__ is not a function, or does not exist on the
            # user-defined object
            else:
                if truth_fn(True):
                    push and self.push(value)
                    self.jump(inst)
        elif not isinstance(value, TensorVariable) and value.has_unpack_var_sequence(
            self
        ):
            if truth_fn(len(value.unpack_var_sequence(self))):
                push and self.push(value)
                self.jump(inst)
        elif isinstance(value, SymNodeVariable):
            eval_result = value.evaluate_expr(self.output)
            if truth_fn(eval_result):
                push and self.push(value)
                self.jump(inst)
        elif isinstance(value, variables.BackwardHookVariable):
            if truth_fn(True):
                push and self.push(value)
                self.jump(inst)
        else:
            from .source import is_constant_source

            if value.source is not None and is_constant_source(value.source):
                if truth_fn(value.get_real_value()):  # type: ignore[attr-defined]
                    push and self.push(value)
                    self.jump(inst)
            else:
                # TODO: link the torch.cond doc later
                raise exc.UserError(
                    exc.UserErrorType.DYNAMIC_CONTROL_FLOW,
                    "Dynamic control flow is not supported at the moment. Please use "
                    "functorch.experimental.control_flow.cond to explicitly capture the control flow.",
                    case_name="cond_operands",
                )

    return inner
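

# Sketch of the rewrite that the UserError above asks users to make
# (torch.cond is the documented capture API; exact argument spelling may
# differ by version):
#
#     if x.sum() > 0:                  # data-dependent branch -> UserError
#         return x + 1
#     return x - 1
#
# becomes
#
#     torch.cond(x.sum() > 0, lambda x: x + 1, lambda x: x - 1, (x,))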


explain = False


def break_graph_if_unsupported(*, push):
    def decorator(inner_fn):
        @functools.wraps(inner_fn)
        def wrapper(self: "InstructionTranslatorBase", inst: Instruction):
            speculation = self.speculate()
            if speculation.failed:
                assert speculation.reason is not None
                return handle_graph_break(self, inst, speculation.reason)
            try:
                TracingContext.set_current_loc(
                    self.f_code.co_filename, self.lineno, self.f_code.co_name
                )
                return inner_fn(self, inst)
            except Unsupported as excp:
                if self.generic_context_manager_depth > 0:
                    # We don't support graph breaks under a
                    # GenericContextWrappingVariable; if one occurs, roll back
                    # to the checkpoint and fall back.
                    excp.remove_from_stats()
                    unimplemented("Graph break under GenericContextWrappingVariable")

                if isinstance(excp, exc.UncapturedHigherOrderOpError):
                    raise

                if not self.should_compile_partial_graph():
                    raise

                user_stack = excp.real_stack
                # TODO: Also report the traceback from the parent frame
                user_stack_formatted = "".join(traceback.format_list(user_stack))
                frame_loc = (user_stack[-1].filename, user_stack[-1].lineno)
                # torch._dynamo.explain() formats this a little nicer, and presents a slightly
                # more actionable user code pointer
                if (
                    graph_break_log.isEnabledFor(logging.DEBUG)
                    and not explain
                    and graph_break_dup_warning_checker.add(frame_loc)
                ):
                    # This log line is exercised from
                    #   python test/dynamo/test_exc.py -k test_graph_break_log
                    graph_break_log.debug(
                        "Graph break: from user code at:\n%s",
                        user_stack_formatted,
                        exc_info=True,
                    )
                else:
                    # This log line MUST NOT contain the string "Graph break",
                    # exercised by
                    #   python test/dynamo/test_misc.py -k test_duplicate_graph_break_log
                    log.debug(
                        "Unsupported break in user code at %s:%s (details suppressed)",
                        *frame_loc,
                    )

                if self.has_backedge():
                    msg = (
                        "Skipping frame because there is a graph break in a for/while loop\n"
                        f"{self.frame_summary()}"
                    )
                    log.info(msg)
                    raise exc.SkipFrame(msg) from excp

                excp.remove_from_stats()
                excp.add_to_stats("graph_break")
                speculation.reason = GraphCompileReason(excp.msg, user_stack)
            speculation.fail_and_restart_analysis()

        def handle_graph_break(
            self: "InstructionTranslatorBase",
            inst: Instruction,
            reason: GraphCompileReason,
        ):
            self.output.compile_subgraph(self, reason=reason)
            cg = PyCodegen(self)
            cleanup: List[Instruction] = []
            # Reconstruct the context variables in the block stack
            for b in self.block_stack:
                assert b.with_context is not None
                cg(b.with_context)
                cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
            self.output.add_output_instructions(cg.get_instructions())
            del cg

            if sys.version_info >= (3, 11) and inst.opname == "CALL":
                kw_names = (
                    self.kw_names.as_python_constant()
                    if self.kw_names is not None
                    else ()
                )
                if len(kw_names) > 0:
                    self.output.add_output_instructions(
                        [create_instruction("KW_NAMES", argval=kw_names)]
                    )
                self.output.add_output_instructions(
                    create_call_function(inst.arg, False)
                )
            else:
                # copy the instruction, but without exception table data
                assert inst.target is None
                inst_copy = copy.copy(inst)
                inst_copy.exn_tab_entry = None
                self.output.add_output_instructions([inst_copy])

            self.output.add_output_instructions(cleanup)

            if sys.version_info >= (3, 11) and inst.opname == "CALL":
                # stack effect for PRECALL + CALL is split between the two instructions
                stack_effect = dis.stack_effect(
                    dis.opmap["PRECALL"], inst.arg
                ) + dis.stack_effect(dis.opmap["CALL"], inst.arg)
            else:
                stack_effect = dis.stack_effect(inst.opcode, inst.arg)
            self.popn(push - stack_effect)

            for _ in range(push):
                self.push(UnknownVariable())
            self.output.add_output_instructions(
                self.create_call_resume_at(self.next_instruction)
            )

        return wrapper

    return decorator


class InstructionTranslatorBase(Checkpointable[InstructionTranslatorGraphState]):
    output: OutputGraph
    symbolic_locals: Dict[str, VariableTracker]
    symbolic_globals: Dict[str, VariableTracker]
    stack: List[VariableTracker]
    instruction_pointer: Optional[int]
    current_instruction: Instruction
    next_instruction: Optional[Instruction]
    block_stack: List[BlockStackEntry]
    lineno: int
    kw_names: Optional[ConstantVariable]
    accept_prefix_inst: bool
    prefix_insts: List[Instruction]
    inline_depth: int
    inconsistent_side_effects: bool
    current_speculation: Optional[SpeculationEntry]

    def mark_inconsistent_side_effects(self):
        """
        InstructionTranslator has encountered instructions which may cause
        dynamo to see a different version of history from eager
        See: https://github.com/pytorch/pytorch/issues/110765
        """
        self.inconsistent_side_effects = True

    def has_backedge(self):
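        # A backedge is any jump whose target offset precedes the current
        # instruction, i.e. evidence that we are inside a for/while loop.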
        cur_offset = self.current_instruction.offset
        assert self.instruction_pointer is not None
        for inst in self.instructions[self.instruction_pointer :]:
            if inst.opname in JUMP_OPNAMES:
                jump_offset = inst.argval
                if jump_offset < cur_offset:
                    return True
        return False

    def cell_and_freevars(self):
        if not hasattr(self, "_cell_and_freevars"):
            self._cell_and_freevars = tuple(
                self.code_options["co_cellvars"] or []
            ) + tuple(self.code_options["co_freevars"] or [])
        return self._cell_and_freevars

    def prune_dead_locals(self):
        reads = livevars_analysis(self.instructions, self.current_instruction)
        # implicit use by super()
        # reads = reads | {"__class__"}
        # output variables?
        reads = reads | set(self.cell_and_freevars())
        self.symbolic_locals = {
            k: v for k, v in self.symbolic_locals.items() if k in reads
        }
        self.output.side_effects.prune_dead_object_new(self)

    def call_function(
        self,
        fn: VariableTracker,
        args: List[VariableTracker],
        kwargs: Dict[str, VariableTracker],
    ):
        assert isinstance(fn, VariableTracker)
        assert isinstance(args, list)
        assert isinstance(kwargs, dict)
        assert all(
            isinstance(x, VariableTracker)
            for x in itertools.chain(args, kwargs.values())
        )
        inner_fn = None
        if hasattr(fn, "value"):
            inner_fn = fn.value
        if hasattr(fn, "fn"):
            inner_fn = fn.fn
        if inner_fn and callable(inner_fn) and is_forbidden(inner_fn):
            raise AssertionError(f"Attempt to trace forbidden callable {inner_fn}")
        self.push(fn.call_function(self, args, kwargs))

    def inline_user_function_return(self, fn, args, kwargs):
        """
        Call a user-defined function by inlining it.
        """
        return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)

    def get_line_of_code_header(self, lineno=None):
        if lineno is None:
            lineno = self.lineno
        inline_depth_str = (
            f" (inline depth: {self.inline_depth})" if self.inline_depth > 0 else ""
        )
        funcname = get_funcname(self.f_code.co_filename, lineno)
        funcname_str = "" if funcname is None else f" ({funcname})"
        return f"{self.f_code.co_filename}:{lineno} in {self.f_code.co_name}{funcname_str}{inline_depth_str}"

    def get_log_starts_line_log_str(self):
        log_str = f"TRACE starts_line {self.get_line_of_code_header()}\n"
        line = linecache.getline(self.f_code.co_filename, self.lineno).rstrip()
        log_str += f" {line}"
        return log_str

    def log_starts_line(self):
        trace_source_log.debug("%s", LazyString(self.get_log_starts_line_log_str))

    def step(self):
        """Process exactly one instruction, return False if we should exit"""
        assert isinstance(self.instruction_pointer, int)
        inst = self.instructions[self.instruction_pointer]
        self.current_instruction = inst
        self.instruction_pointer += 1
        if self.instruction_pointer < len(self.instructions):
            self.next_instruction = self.instructions[self.instruction_pointer]
        else:
            self.instruction_pointer = None
            self.next_instruction = None
        if inst.starts_line and self.lineno != inst.starts_line:
            self.lineno = inst.starts_line
            self.log_starts_line()

        if (
            len(self.stack) == 0
            and self.should_compile_partial_graph()
            and self.is_non_empty_graph()
        ):
            self.current_speculation = self.speculate()
            if self.current_speculation.failed:
                return self.step_graph_break(inst)

        log.debug("TRACE %s %s %s", inst.opname, inst.argval, self.stack)

        # 3.11 no longer uses a block stack, but we still keep track of one
        # so that we know which contexts are currently active.
        # For our purposes, all exception table entries with the same target
        # are considered to be part of the same "block".
        if sys.version_info >= (3, 11):
            entry = inst.exn_tab_entry
            if not (
                # still in the same block
                self.block_stack
                and entry
                and self.block_stack[-1].target is entry.target
            ):
                if not entry:
                    # no longer in any block
                    # It is possible for NOPs to be between two instructions
                    # in the same block, but the NOPs are not covered by an
                    # exception table entry. In this case, assume that we
                    # are still in the same block.
                    if self.block_stack and inst.opname != "NOP":
                        # If we really escape from a block and the current
                        # instruction is not in another block, then there
                        # should be no other nested blocks that we are in.
                        assert len(self.block_stack) == 1
                        self.block_stack.pop()
                elif (
                    # current instruction is in the previous block
                    len(self.block_stack) > 1
                    and self.block_stack[-2].target is entry.target
                ):
                    # exit the current block
                    self.block_stack.pop()
                else:
                    # current instruction is in a new block
                    # push block to stack - note, BEFORE_WITH blocks won't
                    # be pushed here since BEFORE_WITH pushes the block, and
                    # the current instruction would be counted as being in that block.
                    self.block_stack.append(
                        BlockStackEntry(entry.target, len(self.stack))
                    )

        try:
            if not hasattr(self, inst.opname):
                unimplemented(f"missing: {inst.opname}")
            TracingContext.set_current_loc(
                self.f_code.co_filename, self.lineno, self.f_code.co_name
            )
            getattr(self, inst.opname)(inst)
            return inst.opname != "RETURN_VALUE"
        except Unsupported:
            if self.current_speculation is None:
                log.debug("empty checkpoint")
                raise
            log.debug("step triggered compile", exc_info=True)

        self.current_speculation.fail_and_restart_analysis()

    def step_graph_break(self, continue_inst):
        # generate code from checkpoint
        assert not self.output.output_instructions
        assert self.current_speculation is not None
        self.output.compile_subgraph(
            self,
            partial_convert=True,
            reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
        )
        self.output.add_output_instructions(
            [create_jump_absolute(continue_inst)] + self.instructions
        )

    def run_ctx_mgr(self):
        # NB: Don't push the top level frame summary; set_current_loc will
        # take care of it. However, DO make sure we attach real_stack to
        # exceptions
        return TracingContext.current_frame(None)

    def run(self):
        with self.run_ctx_mgr():
            try:
                self.output.push_tx(self)
                while (
                    self.instruction_pointer is not None
                    and not self.output.should_exit
                    and self.step()
                ):
                    pass
            except BackendCompilerFailed:
                raise
            except Exception as e:
                if config.replay_record_enabled:
                    e.exec_record = self.exec_recorder.get_record()  # type: ignore[attr-defined]
                raise
            finally:
                self.output.pop_tx()
                # Cleanup the outputGraph to delete the held tensors. We perform the
                # cleanup only for InstructionTranslator and not
                # InliningInstructionTranslator. The InliningInstructionTranslator
                # mutates the output object and is restored to original state if
                # there was an exception.
                if isinstance(self, InstructionTranslator):
                    self.output.cleanup()

    def push(self, val: Optional[VariableTracker]):
        assert val is None or isinstance(
            val, VariableTracker
        ), f"push expects VariableTracker, got {typestr(val)}"
        self.stack.append(val)  # type: ignore[arg-type]

    def push_many(self, vals: List[VariableTracker]):
        for val in vals:
            self.push(val)

    def pop(self) -> VariableTracker:
        return self.stack.pop()

    def popn(self, n: int) -> List[VariableTracker]:
        assert n >= 0
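        # Values come back in original push order: popn(2) on a stack
        # [..., a, b] returns [a, b].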
        return list(reversed([self.pop() for _ in range(n)]))

    def LOAD_FAST(self, inst):
        name = inst.argval

        if name in self.f_locals and config.replay_record_enabled:
            self.exec_recorder.add_local_var(name, self.f_locals[name])

        if name.startswith(".") and name not in self.symbolic_locals:
            # This happens in dict/list comprehensions
            name = name.replace(".", "implicit")
        assert name not in self.cell_and_freevars()
        if name not in self.symbolic_locals:
            unimplemented("undefined LOAD_FAST")
        self.push(self.symbolic_locals[name])
        if name.startswith("___stack"):
            self.symbolic_locals.pop(name)

    def LOAD_DEREF(self, inst):
        assert inst.argval in self.cell_and_freevars()

        if inst.argval in self.f_locals and config.replay_record_enabled:
            self.exec_recorder.add_local_var(inst.argval, self.f_locals[inst.argval])

        if inst.argval not in self.symbolic_locals:
            unimplemented(f"undefined LOAD_DEREF {inst.argval}")
        self.push(self.symbolic_locals[inst.argval])

    def STORE_FAST(self, inst):
        loaded_vt = self.pop()
        name = inst.argval
        # Only rename at the top-level scope. This avoids confusion between
        # mutating a variable and renaming it (e.g. a = b) while speculating
        # a higher-order op, where mutation is prohibited and difficult to
        # distinguish from renaming.
        if _is_top_level_scope(current_scope_id()):
            loaded_vt = loaded_vt.rename(self, name)
        self.symbolic_locals[name] = loaded_vt

    def DELETE_FAST(self, inst):
        del self.symbolic_locals[inst.argval]

    STORE_DEREF = STORE_FAST

    def LOAD_CLOSURE(self, inst):
        self.push(ClosureVariable(name=inst.argval))

    def LOAD_CONST(self, inst):
        # For empty tuples, create an empty TupleVariable
        if isinstance(inst.argval, tuple) and not inst.argval:
            self.push(TupleVariable([]))
        else:
            self.push(ConstantVariable.create(value=inst.argval))

    def get_global_source(self, name):
        source: Source
        if self.output.global_scope is self.f_globals:
            source = GlobalSource(name)
        else:
            if "__name__" in self.f_globals:
                source = AttrSource(
                    self.import_source(self.f_globals["__name__"]), name
                )
            else:
                mangled_name = self.output.install_global_by_id(
                    "___unnamed_scope", self.f_globals
                )
                source = GetItemSource(GlobalSource(mangled_name), name)
        return source

    def LOAD_GLOBAL(self, inst):
        if sys.version_info >= (3, 11):
            if inst.arg % 2:
                self.PUSH_NULL(inst)

        name = inst.argval

        if config.replay_record_enabled:
            if name in self.f_globals:
                self.exec_recorder.add_global_var(name, self.f_globals[name])
            else:
                assert name in self.f_builtins
                self.exec_recorder.builtins[name] = self.f_builtins[name]

        if inst.argval == "AssertionError":
            unimplemented("assert with non-string message")

        if name in self.symbolic_globals:
            variable = self.output.side_effects[self.symbolic_globals[name]]
            self.push(self.output.side_effects.load_global(variable, name))
            return

        try:
            value = self.f_globals[name]
        except KeyError:
            return self.load_builtin(inst)

        source = self.get_global_source(name)
        self.push(VariableBuilder(self, source)(value))

    def STORE_GLOBAL(self, inst):
        value = self.pop()
        name = inst.argval
        source = self.get_global_source(name)
        if name not in self.symbolic_globals:
            self.symbolic_globals[name] = object()  # type: ignore[assignment]  # sentinel object
        variable = self.output.side_effects.track_global_existing(
            source, self.symbolic_globals[name]
        )
        if isinstance(value, RemovableHandleVariable):
            unimplemented("Storing handles in globals - NYI")
        self.output.side_effects.store_global(variable, name, value)

    def import_source(self, module_name):
        """Create an alias to a module for use in guards"""
        if "torch_package" in module_name:
            value = torch.package.package_importer._package_imported_modules[
                module_name
            ]
            alias = (
                module_name.replace(">", "_").replace("<", "_").replace(".", "_dot_")
            )
        else:
            value = importlib.import_module(module_name)
            alias = f"__import_{module_name.replace('.', '_dot_')}"
        f_globals = self.output.global_scope
        assert alias not in f_globals or f_globals[alias] is value
        f_globals[alias] = value
        self.output.update_co_names(alias)
        return GlobalSource(alias)

    def resolve_name(self, name, package, level):
        """
        Copied from the CPython implementation of __import__.
        Resolve a relative module name to an absolute one.
        https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L902
        """
        bits = package.rsplit(".", level - 1)
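        # Worked example: `from ..m import x` inside package "pkg.sub" has
        # level=2, so bits == ["pkg", "sub"] and the result is "pkg.m".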

    def calc_package(self):
        """
        Copied from the CPython implementation of __import__.
        https://github.com/python/cpython/blob/5a094f0255eea1db58fb2cf14c200971e64ec36e/Lib/importlib/_bootstrap.py#L1090
        """
        package = self.f_globals.get("__package__")
        spec = self.f_globals.get("__spec__")
        if package is not None:
            if spec is not None and package != spec.parent:
                log.warning(
                    "__package__ != __spec__.parent (%r != %r)",
                    package,
                    spec.parent,
                    stacklevel=3,
                )
            return package
        elif spec is not None:
            return spec.parent
        else:
            log.warning(
                "can't resolve package from __spec__ or __package__, "
                "falling back on __name__ and __path__",
                stacklevel=3,
            )
            package = self.f_globals["__name__"]
            if "__path__" not in self.f_globals:
                package = package.rpartition(".")[0]
        return package

    def IMPORT_NAME(self, inst):
        level, fromlist = self.popn(2)
        level = level.as_python_constant()
        fromlist = fromlist.as_python_constant()
        module_name = inst.argval

        # Are we replaying? If so, load the recorded module.
        recorded_name = (
            f"{ExecutionRecorder.LOCAL_MOD_PREFIX}_{level}_{fromlist}_{module_name}"
        )
        if recorded_name in self.f_globals:
            value = self.f_globals[recorded_name]
            source = GlobalSource(recorded_name)
        else:
            value = __import__(
                module_name,
                fromlist=fromlist,
                level=level,
                globals=self.f_globals,
            )

            if level != 0:
                pkg = self.calc_package()
                module_name = self.resolve_name(module_name, pkg, level)

            # For __import__, when the name variable is of the form package.module,
            # normally, the top-level package (the name up till the first dot) is
            # returned, not the module named by module_name. However, when a
            # non-empty fromlist argument is given, the module named by name is
            # returned. Therefore, we set the source correctly here.
            if not fromlist:
                top_level_module_name = module_name.partition(".")[0]
                source = self.import_source(top_level_module_name)
            else:
                source = self.import_source(module_name)

        if config.replay_record_enabled:
            self.exec_recorder.add_local_mod(recorded_name, value)

        if istype(value, (types.ModuleType, DummyModule)):
            self.push(PythonModuleVariable(value, source=source))
        else:
            unimplemented(f"IMPORT_NAME {typestr(value)}")

    def IMPORT_FROM(self, inst):
        self.DUP_TOP(inst)
        self.LOAD_ATTR(inst)

    def load_builtin(self, inst):
        if inst.argval not in self.f_builtins:
            raise NameError(f"name '{inst.argval}' is not defined")
        val = self.f_builtins[inst.argval]

        if callable(val):
            self.push(VariableBuilder(self, GlobalSource(inst.argval))(val))
        else:
            assert is_builtin_constant(val)
            self.push(ConstantVariable.create(value=val))

    def jump(self, inst):
        self.instruction_pointer = self.indexof[inst.target]

    JUMP_FORWARD = jump
    JUMP_ABSOLUTE = jump

    POP_JUMP_IF_FALSE = generic_jump(operator.not_, False)
    POP_JUMP_IF_TRUE = generic_jump(operator.truth, False)
    JUMP_IF_FALSE_OR_POP = generic_jump(operator.not_, True)
    JUMP_IF_TRUE_OR_POP = generic_jump(operator.truth, True)

    def SETUP_LOOP(self, inst):
        # only exists in python<=3.7
        self.block_stack.append(BlockStackEntry(inst.target))

    def SETUP_EXCEPT(self, inst):
        # only exists in python<=3.7
        self.block_stack.append(BlockStackEntry(inst.target))

    def POP_BLOCK(self, inst):
        self.block_stack.pop()

    def SETUP_WITH(self, inst):
        self.setup_or_before_with(inst)

    def SETUP_FINALLY(self, inst):
        self.block_stack.append(BlockStackEntry(inst.target))

    def BEGIN_FINALLY(self, inst):
        self.push(None)

    def WITH_CLEANUP_START(self, inst):
        exit, exc = self.popn(2)
        assert exc is None
        self.push(exc)
        self.push(exit.call_function(self, [ConstantVariable.create(None)] * 3, {}))

    def WITH_CLEANUP_FINISH(self, inst):
        self.popn(2)
        self.push(None)

    def CALL_FINALLY(self, inst):
        """
        Pushes the address of the next instruction onto the stack and
        increments the bytecode counter by delta.
        """
        # Python 3.8 only
        assert self.next_instruction is not None
        addr = self.indexof[self.next_instruction]
        self.push(ConstantVariable.create(addr))
        self.instruction_pointer = self.indexof[inst.target]

    def END_FINALLY(self, inst):
        # Python 3.8 only
        # https://docs.python.org/3.8/library/dis.html#opcode-END_FINALLY
        tos = self.pop()
        if isinstance(tos, ConstantVariable):
            self.instruction_pointer = tos.as_python_constant()
        else:
            pass

    def POP_FINALLY(self, inst):
        # Python 3.8 only
        preserve_tos = inst.argval
        if preserve_tos:
            tos = self.pop()
        _ = self.pop()
        if preserve_tos:
            self.push(tos)  # type: ignore[possibly-undefined]

    def FOR_ITER(self, inst):
        it = self.pop().realize()
        if isinstance(it, (variables.ListIteratorVariable, variables.IteratorVariable)):
            try:
                val, next_iter = it.next_variables(self)
                self.push(next_iter)
                self.push(val)
            except StopIteration:
                self.jump(inst)
        else:
            unimplemented(f"FOR_ITER {typestr(it)}")

    def COMPARE_OP(self, inst):
        left, right = self.popn(2)
        op = inst.argval
        supported_any = dict(
            itertools.chain(
                supported_tensor_comparison_ops.items(),
                supported_const_comparison_ops.items(),
            )
        )
        if (
            isinstance(
                left,
                (
                    TensorVariable,
                    SymNodeVariable,
                    NNModuleVariable,
                    BaseListVariable,
                    UserDefinedVariable,
                    BaseUserFunctionVariable,
                    ConstDictVariable,
                ),
            )
            and isinstance(right, ConstantVariable)
            and right.value is None
            and op in supported_const_comparison_ops
        ):
            # <non-None> is None
            self.push(
                ConstantVariable.create(
                    supported_const_comparison_ops[op](object(), right.value)
                )
            )
        elif (
            left.is_python_constant()
            and right.is_python_constant()
            and op in supported_any
        ):
            # constant fold
            self.push(
                ConstantVariable.create(
                    supported_any[op](
                        left.as_python_constant(), right.as_python_constant()
                    ),
                )
            )
        elif op in ("in", "not in"):
            self.push(right.call_method(self, "__contains__", [left], {}))
            if op == "not in":
                self.UNARY_NOT(inst)
        else:
            self.push(
                BuiltinVariable(supported_any[op]).call_function(
                    self, [left, right], {}
                )
            )

    def GET_ITER(self, inst):
        self.call_function(BuiltinVariable(iter), [self.pop()], {})

    @break_graph_if_unsupported(push=1)
    def CALL_FUNCTION(self, inst):
        args = self.popn(inst.argval)
        fn = self.pop()
        self.call_function(fn, args, {})

    @break_graph_if_unsupported(push=1)
    def CALL_FUNCTION_EX(self, inst):
        kwargsvars: VariableTracker
        if inst.argval == 0:
            kwargsvars = ConstDictVariable({})
            argsvars = self.pop()
        elif inst.argval == 1:
            kwargsvars = self.pop()
            argsvars = self.pop()
        else:
            unimplemented("CALL_FUNCTION_EX")
        fn = self.pop()
        if sys.version_info >= (3, 11):
            null = self.pop()
            assert isinstance(null, NullVariable)

        if (
            isinstance(fn, GetAttrVariable)
            and isinstance(fn.obj, TensorVariable)
            and fn.name == "view"
            and isinstance(argsvars, (ConstantVariable, TensorVariable))
        ):
            # Hack to handle special case in some bert models. Converts
            # x.view(*shape) into x.view(shape), which is correct for view()
            # but not generally. See test_transpose_for_scores().
            argsvars = TupleVariable([argsvars])

        if not isinstance(
            argsvars, BaseListVariable
        ) and argsvars.has_unpack_var_sequence(self):
            argsvars = TupleVariable(argsvars.unpack_var_sequence(self))

        if not isinstance(argsvars, BaseListVariable) or not isinstance(
            kwargsvars, ConstDictVariable
        ):
            unimplemented(f"non-static call {typestr(argsvars)} {typestr(kwargsvars)}")

        # Map to a dictionary of str -> VariableTracker
        kwargsvars = kwargsvars.keys_as_python_constant()
        self.call_function(fn, argsvars.items, kwargsvars)

    @break_graph_if_unsupported(push=1)
    def CALL_FUNCTION_KW(self, inst):
        argnames = self.pop()
        args = self.popn(inst.argval)
        fn = self.pop()
        assert isinstance(argnames, TupleVariable) and argnames.is_python_constant()
        argnames = argnames.as_python_constant()
        args, kwargs_list = args[: -len(argnames)], args[-len(argnames) :]
        kwargs = dict(zip(argnames, kwargs_list))
        assert len(kwargs) == len(argnames)
        self.call_function(fn, args, kwargs)

    def LOAD_METHOD_SUPER(self, inst):
        self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
        arg = inst.argval[0]
        argval = self.code_options["co_names"][arg]
        if sys.version_info < (3, 11):
            self.LOAD_ATTR(dataclasses.replace(inst, argval=argval))
        else:
            self.LOAD_METHOD(dataclasses.replace(inst, argval=argval))

    def LOAD_ATTR_SUPER(self, inst):
        self.CALL_FUNCTION(dataclasses.replace(inst, argval=2))
        arg = inst.argval[0]
        argval = self.code_options["co_names"][arg]
        self.LOAD_ATTR(dataclasses.replace(inst, argval=argval))

    def LOAD_METHOD(self, inst):
        self.LOAD_ATTR(inst)
        obj = self.pop()
        if sys.version_info >= (3, 11):
            # always follow the NULL + fn convention, since if obj
            # is actually a method, self is already bound to it, so it
            # doesn't need to be passed in as an arg.
            self.PUSH_NULL(inst)
            self.push(obj)
        else:
            self.push(obj)
            self.push(None)

    @break_graph_if_unsupported(push=1)
    def CALL_METHOD(self, inst):
        args = self.popn(inst.argval)
        dummy = self.pop()
        assert dummy is None
        fn = self.pop()
        self.call_function(fn, args, {})

    def LOAD_ATTR(self, inst):
        obj = self.pop()
        result = BuiltinVariable(getattr).call_function(
            self, [obj, ConstantVariable.create(inst.argval)], {}
        )
        self.push(result)

    def STORE_ATTR(self, inst):
        speculation = self.speculate()
        if speculation.failed:
            return self.store_attr_graph_break(inst)
        val, obj = self.popn(2)

        if isinstance(obj, NNModuleVariable):
            # We don't allow side effects during export
            # https://github.com/pytorch/torchdynamo/issues/1475
            assert (
                not self.export
            ), f"Mutating module attribute {inst.argval} during export."

        try:
            BuiltinVariable(setattr).call_function(
                self, [obj, ConstantVariable.create(inst.argval), val], {}
            )
            return
        except Unsupported as e:
            if not self.should_compile_partial_graph():
                raise
            log.debug("STORE_ATTR triggered compile", exc_info=True)
            e.remove_from_stats()
            e.add_to_stats("graph_break")
        speculation.fail_and_restart_analysis()

    def store_attr_graph_break(self, inst):
        self.output.compile_subgraph(
            self, reason=GraphCompileReason("store_attr", [self.frame_summary()])
        )
        self.output.add_output_instructions([copy.copy(inst)])
        self.popn(2)
        self.output.add_output_instructions(
            self.create_call_resume_at(self.next_instruction)
        )

    def DELETE_ATTR(self, inst):
        obj = self.pop()
        BuiltinVariable(delattr).call_function(
            self, [obj, ConstantVariable.create(inst.argval)], {}
        )

    def create_call_resume_at(self, offset):
        raise AssertionError(
            f"create_call_resume_at not overridden by subclass {type(self)}"
        )

    def should_compile_partial_graph(self) -> bool:
        raise AssertionError(
            f"should_compile_partial_graph not overridden by subclass {type(self)}"
        )

    @break_graph_if_unsupported(push=0)
    def STORE_SUBSCR(self, inst):
        val, obj, key = self.popn(3)
        result = obj.call_method(self, "__setitem__", [key, val], {})

    def BUILD_TUPLE(self, inst):
        items = self.popn(inst.argval)
        self.push(TupleVariable(items))

    def BUILD_SLICE(self, inst):
        items = self.popn(inst.argval)
        self.push(SliceVariable(items))

    def BUILD_LIST(self, inst):
        items = self.popn(inst.argval)
        self.push(ListVariable(items, mutable_local=MutableLocal()))

    def BUILD_SET(self, inst):
        if config.inject_BUILD_SET_unimplemented_TESTING_ONLY:
            unimplemented("missing: BUILD_SET")
        items = self.popn(inst.argval)
        new_set = SetVariable(items, mutable_local=MutableLocal())
        self.push(new_set)

    def BUILD_LIST_UNPACK(self, inst, cls=ListVariable):
        seqs = self.popn(inst.argval)
        items = list()
        for seq in seqs:
            try:
                items.extend(seq.unpack_var_sequence(self))
            except NotImplementedError:
                unimplemented(f"BUILD_LIST_UNPACK {seq}")
        self.push(cls(items, mutable_local=MutableLocal()))

    def BUILD_TUPLE_UNPACK(self, inst):
        self.BUILD_LIST_UNPACK(inst, cls=TupleVariable)

    BUILD_TUPLE_UNPACK_WITH_CALL = BUILD_TUPLE_UNPACK

    def BUILD_MAP(self, inst):
        items = self.popn(inst.argval * 2)
        d = dict(zip(items[::2], items[1::2]))
        self.push(ConstDictVariable(d, mutable_local=MutableLocal()))

    def BUILD_MAP_UNPACK(self, inst):
        items = self.popn(inst.argval)
        # ensure everything is a dict
        items = [BuiltinVariable(dict).call_function(self, [x], {}) for x in items]
        result = dict()
        for x in items:
            assert isinstance(x, ConstDictVariable)
            result.update(x.items)
        self.push(
            ConstDictVariable(
                result,
                mutable_local=MutableLocal(),
            )
        )

    BUILD_MAP_UNPACK_WITH_CALL = BUILD_MAP_UNPACK

    def BUILD_CONST_KEY_MAP(self, inst):
        keys = self.pop()
        values = self.popn(inst.argval)
        assert isinstance(keys, TupleVariable)
        assert keys.is_python_constant()
        keys = keys.unpack_var_sequence(self)
        assert len(keys) == len(values)
        self.push(
            ConstDictVariable(
                dict(zip(keys, values)),
                mutable_local=MutableLocal(),
            )
        )

    def MAP_ADD(self, inst):
        k, v = self.popn(2)
        assert inst.argval > 0
        obj = self.stack[-inst.arg].realize()
        assert isinstance(obj, ConstDictVariable)
        obj.call_method(self, "__setitem__", (k, v), {})  # type: ignore[arg-type]

    def SET_ADD(self, inst):
        v = self.pop()
        assert inst.argval > 0
        obj = self.stack[-inst.arg]
        assert isinstance(obj, SetVariable)
        assert obj.mutable_local
        return obj.call_method(self, "add", [v], {})

    def LIST_APPEND(self, inst):
        v = self.pop()
        assert inst.argval > 0
        obj = self.stack[-inst.arg].realize()
        assert isinstance(obj, ListVariable)
        assert obj.mutable_local
        self.output.side_effects.mutation(obj)
        obj.items.append(v)

    def MAKE_FUNCTION(self, inst):
        flags = inst.arg
        old_stack = list(self.stack)
        if sys.version_info < (3, 11):
            fn_name = self.pop()
        code = self.pop()
        if sys.version_info >= (3, 11):
            # MAKE_FUNCTION behavior actually changed in 3.11, see
            # https://github.com/python/cpython/pull/93189/
            assert hasattr(code.value, "co_qualname")  # type: ignore[attr-defined]
            fn_name = ConstantVariable.create(value=code.value.co_qualname)  # type: ignore[attr-defined]
        defaults = None
        closure = None
        annotations = None
        kwdefaults = None

        if flags & 0x08:
            closure = self.pop()
        if flags & 0x04:
            annotations = self.pop()
        if flags & 0x02:
            kwdefaults = self.pop()
        if flags & 0x01:
            defaults = self.pop()

        self.push(
            NestedUserFunctionVariable(
                fn_name,
                code,
                self.f_globals,
                defaults,
                kwdefaults,
                annotations,
                closure,
                closure_scope=self,
            )
        )

    def UNPACK_SEQUENCE(self, inst):
        seq = self.pop()
        if isinstance(seq, TensorVariable):
            val = seq.unpack_var_sequence(self, idxes=range(inst.argval))
        elif isinstance(seq, GetAttrVariable) and isinstance(seq.obj, TensorVariable):
            # x, y = a.shape
            proxy = getattr(seq.obj.as_proxy(), seq.name)
            val = [wrap_fx_proxy(self, proxy[i]) for i in range(inst.argval)]
        elif seq.has_unpack_var_sequence(self):
            val = seq.unpack_var_sequence(self)
        else:
            unimplemented(f"UNPACK_SEQUENCE {seq}")
        if len(val) != inst.argval:
            unimplemented("UNPACK_SEQUENCE length mismatch")
        for i in reversed(val):
            self.push(i)

    def UNPACK_EX(self, inst):
        assert 0 <= inst.argval <= 0xFFFF
        prefix = inst.argval & 0xFF  # low byte
        suffix = inst.argval >> 8  # high byte
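        # e.g. `a, *b, c = seq` compiles to UNPACK_EX with argval 0x0101:
        # prefix=1 (a), suffix=1 (c); b receives whatever is left in between.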
        seq = self.pop()
        if seq.has_unpack_var_sequence(self):
            vals = list(seq.unpack_var_sequence(self))
            assert len(vals) >= prefix + suffix
            vals_prefix = vals[:prefix]
            vals_list = vals[prefix : len(vals) - suffix]
            vals_suffix = vals[len(vals) - suffix :]
            for item in reversed(vals_suffix):
                self.push(item)
            self.push(TupleVariable(vals_list))
            for item in reversed(vals_prefix):
                self.push(item)
        else:
            unimplemented(f"UNPACK_EX {seq}")

    def NOP(self, inst):
        pass

    def POP_TOP(self, inst):
        self.pop()

    def ROT_TWO(self, inst):
        a = self.pop()
        b = self.pop()
        self.push(a)
        self.push(b)

    def ROT_THREE(self, inst):
        a = self.pop()
        b = self.pop()
        c = self.pop()
        self.push(a)
        self.push(c)
        self.push(b)

    def ROT_FOUR(self, inst):
        a = self.pop()
        b = self.pop()
        c = self.pop()
        d = self.pop()
        self.push(a)
        self.push(d)
        self.push(c)
        self.push(b)

    def DUP_TOP(self, inst):
        a = self.pop()
        self.push(a)
        self.push(a)

    def DUP_TOP_TWO(self, inst):
        a = self.pop()
        b = self.pop()
        self.push(b)
        self.push(a)
        self.push(b)
        self.push(a)

    def FORMAT_VALUE(self, inst):
        flags = inst.arg
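        # Flags encode the f-string conversion, e.g. f"{x!r:>10}" sets
        # 0x02 (repr) | 0x04 (format spec on stack) == 0x06.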
        if (flags & 0x04) == 0x04:
            fmt_spec = self.pop()
        else:
            fmt_spec = ConstantVariable.create("")

        value = self.pop()
        if isinstance(value, SymNodeVariable):
            value = ConstantVariable.create(str(value.sym_num))

        if (flags & 0x03) == 0x01:
            value = BuiltinVariable(str).call_function(self, [value], {})
        elif (flags & 0x03) == 0x02:
            value = BuiltinVariable(repr).call_function(self, [value], {})
        elif (flags & 0x03) == 0x03:
            value = BuiltinVariable(ascii).call_function(self, [value], {})

        fmt_var = ConstantVariable.create("{:" + fmt_spec.as_python_constant() + "}")

        self.call_function(BuiltinVariable(str.format), [fmt_var, value], {})

    def BUILD_STRING(self, inst):
        format_string_parts: List[str] = []
        args: List[VariableTracker] = []
        kwargs: Dict[str, VariableTracker] = {}
        for part in self.popn(inst.arg):
            if isinstance(part, ConstantVariable):
                format_string_parts.append("{}")
                args.append(part)
            elif isinstance(part, variables.StringFormatVariable):
                format_string_parts.append(part.format_string)
                args.extend(part.sym_args)
                if set(kwargs.keys()) & set(part.sym_kwargs.keys()):
                    unimplemented(
                        f"BUILD_STRING key conflict {kwargs} & {part.sym_kwargs}"
                    )
                kwargs.update(part.sym_kwargs)
            else:
                unimplemented(f"BUILD_STRING {part}")
        self.push(
            variables.StringFormatVariable.create(
                "".join(format_string_parts), args, kwargs
            )
        )

    def IS_OP(self, inst):
        assert inst.argval == 0 or inst.argval == 1
        if inst.argval == 0:
            new_argval = "is"
        else:
            new_argval = "is not"
        new_inst = create_instruction("COMPARE_OP", argval=new_argval)
        self.COMPARE_OP(new_inst)

    def CONTAINS_OP(self, inst):
        assert inst.argval == 0 or inst.argval == 1
        left, right = self.popn(2)
        op = inst.argval
        self.push(right.call_method(self, "__contains__", [left], {}))
        if op == 1:
            self.UNARY_NOT(inst)

    def LIST_EXTEND(self, inst):
        v = self.pop()
        assert inst.argval > 0
        obj = self.stack[-inst.arg]
        assert isinstance(obj, ListVariable)
        assert obj.mutable_local
        obj.call_method(self, "extend", [v], {})

    def LIST_TO_TUPLE(self, inst):
        self.push(BuiltinVariable(tuple).call_function(self, [self.pop()], {}))

    def DICT_MERGE(self, inst):
        v = self.pop()
        assert inst.argval > 0
        obj = self.stack[-inst.arg].realize()
        assert isinstance(obj, ConstDictVariable)
        assert obj.mutable_local
        obj.call_method(self, "update", [v], {})

    DICT_UPDATE = DICT_MERGE

    def GEN_START(self, inst):
        self.pop()

    def GET_LEN(self, inst):
        tos = self.stack[-1]
        if tos.is_python_constant():
            self.push(ConstantVariable.create(len(tos.as_python_constant())))
        else:
            self.push(tos.call_method(self, "__len__", [], {}))

    def MATCH_MAPPING(self, inst):
        tos = self.stack[-1]
        assert isinstance(tos, ConstDictVariable)
        if isinstance(tos.items, collections.abc.Mapping):
            self.push(ConstantVariable.create(True))
        else:
            self.push(ConstantVariable.create(False))

    def MATCH_SEQUENCE(self, inst):
        tos = self.stack[-1]
        assert tos.is_python_constant()
        tos_value = tos.as_python_constant()
        if isinstance(tos_value, collections.abc.Sequence) and not isinstance(
            tos_value, (str, bytes, bytearray)
        ):
            self.push(ConstantVariable.create(True))
        else:
            self.push(ConstantVariable.create(False))

    def MATCH_KEYS(self, inst):
        tos = self.stack[-1]
        tos1 = self.stack[-2]
        assert isinstance(tos1, ConstDictVariable)

        if all(k in tos1 for k in tos):  # type: ignore[attr-defined]
            self.push(TupleVariable([tos1.getitem_const(k) for k in tos]))  # type: ignore[attr-defined]
            if sys.version_info < (3, 11):
                self.push(ConstantVariable.create(True))
        else:
            self.push(ConstantVariable.create(None))
            if sys.version_info < (3, 11):
                self.push(ConstantVariable.create(False))

    def LOAD_ASSERTION_ERROR(self, inst):
        unimplemented("assert with non-string message")

    UNARY_POSITIVE = stack_op(operator.pos)
    UNARY_NEGATIVE = stack_op(operator.neg)
    UNARY_NOT = stack_op(operator.not_)
    UNARY_INVERT = stack_op(operator.invert)

    BINARY_POWER = stack_op(operator.pow)
    BINARY_MULTIPLY = stack_op(operator.mul)
    BINARY_MATRIX_MULTIPLY = stack_op(operator.matmul)
    BINARY_FLOOR_DIVIDE = stack_op(operator.floordiv)
    BINARY_TRUE_DIVIDE = stack_op(operator.truediv)
    BINARY_MODULO = stack_op(operator.mod)
    BINARY_REMAINDER = stack_op(operator.mod)
    BINARY_ADD = stack_op(operator.add)
    BINARY_SUBTRACT = stack_op(operator.sub)
    BINARY_SUBSCR = break_graph_if_unsupported(push=1)(stack_op(operator.getitem))
    BINARY_LSHIFT = stack_op(operator.lshift)
    BINARY_RSHIFT = stack_op(operator.rshift)
    BINARY_AND = stack_op(operator.and_)
    BINARY_OR = stack_op(operator.or_)
    BINARY_XOR = stack_op(operator.xor)

    INPLACE_POWER = stack_op(operator.ipow)
    INPLACE_MULTIPLY = stack_op(operator.imul)
    INPLACE_MATRIX_MULTIPLY = stack_op(operator.imatmul)
    INPLACE_FLOOR_DIVIDE = stack_op(operator.ifloordiv)
    INPLACE_TRUE_DIVIDE = stack_op(operator.itruediv)
    INPLACE_MODULO = stack_op(operator.imod)
    INPLACE_REMAINDER = stack_op(operator.imod)
    INPLACE_ADD = stack_op(operator.iadd)
    INPLACE_SUBTRACT = stack_op(operator.isub)
    INPLACE_LSHIFT = stack_op(operator.ilshift)
    INPLACE_RSHIFT = stack_op(operator.irshift)
    INPLACE_AND = stack_op(operator.iand)
    INPLACE_XOR = stack_op(operator.ixor)
    INPLACE_OR = stack_op(operator.ior)

    # 3.11 opcodes
    def RESUME(self, inst):
        if inst.arg == 0:
            self.append_prefix_inst(inst)
            self.accept_prefix_inst = False
        else:
            assert not self.accept_prefix_inst

    def BINARY_OP(self, inst):
        if sys.version_info >= (3, 11):
            opname = dis._nb_ops[inst.arg][0][3:]  # type: ignore[attr-defined]
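            # e.g. NB_ADD -> "ADD" -> self.BINARY_ADD(inst), while
            # NB_INPLACE_ADD -> "INPLACE_ADD" -> self.INPLACE_ADD(inst)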
            if opname.startswith("INPLACE"):
                return getattr(self, "INPLACE_" + opname[8:])(inst)
            return getattr(self, "BINARY_" + opname)(inst)
        else:
            unimplemented("BINARY_OP requires Python 3.11+")

    def PRECALL(self, inst):
        pass

    def KW_NAMES(self, inst):
        kw_names = self.code_options["co_consts"][inst.arg]
        assert isinstance(kw_names, tuple)
        for name in kw_names:
            assert isinstance(name, str)
        assert self.kw_names is None
        self.kw_names = ConstantVariable.create(value=kw_names)  # type: ignore[assignment]

    def PUSH_NULL(self, inst):
        self.push(NullVariable())

    @break_graph_if_unsupported(push=1)
    def CALL(self, inst):
        # see https://docs.python.org/3.11/library/dis.html#opcode-CALL
        # for the convention
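        # e.g. `f(a, b)` pushes [NULL, f, a, b] and runs CALL(2): a NULL at
        # contents[0] means there is no bound self, so positional args start
        # at contents[2].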
contents = self.popn(inst.arg + 2) | |
if isinstance(contents[0], NullVariable): | |
fn = contents[1] | |
args = [] | |
else: | |
fn = contents[0] | |
args = [contents[1]] | |
kw_names = self.kw_names.value if self.kw_names else () | |
if kw_names: | |
args = args + contents[2 : -len(kw_names)] | |
kwargs_list = contents[-len(kw_names) :] | |
kwargs = dict(zip(kw_names, kwargs_list)) | |
assert len(kwargs) == len(kw_names) | |
else: | |
args = args + contents[2:] | |
kwargs = {} | |
self.call_function(fn, args, kwargs) | |
self.kw_names = None | |
def COPY(self, inst): | |
self.push(self.stack[-inst.arg]) | |
def SWAP(self, inst): | |
self.stack[-1], self.stack[-inst.arg] = self.stack[-inst.arg], self.stack[-1] | |
JUMP_BACKWARD = jump | |
JUMP_BACKWARD_NO_INTERRUPT = jump | |
POP_JUMP_FORWARD_IF_TRUE = generic_jump(operator.truth, False) | |
POP_JUMP_BACKWARD_IF_TRUE = generic_jump(operator.truth, False) | |
POP_JUMP_FORWARD_IF_FALSE = generic_jump(operator.not_, False) | |
POP_JUMP_BACKWARD_IF_FALSE = generic_jump(operator.not_, False) | |
    def CACHE(self, inst):
        pass

    def BEFORE_WITH(self, inst):
        self.setup_or_before_with(inst)

    def setup_or_before_with(self, inst):
        ctx = self.pop()
        if not isinstance(ctx, ContextWrappingVariable):
            unimplemented(f"{inst.opname} {ctx}")

        if isinstance(ctx, GenericContextWrappingVariable):
            self.generic_context_manager_depth += 1

        exit = WithExitFunctionVariable(
            ctx,
            inst.target,
        )
        if sys.version_info >= (3, 11):
            # see create_call_resume_at for block stack details
            assert self.next_instruction
            assert self.next_instruction.exn_tab_entry
            target = self.next_instruction.exn_tab_entry.target
        else:
            target = inst.target
        if isinstance(self, InstructionTranslator):
            self.block_stack.append(BlockStackEntry(target, len(self.stack), ctx))
        else:
            self.block_stack.append(BlockStackEntry(target))

        self.push(exit)
        self.push(ctx.enter(self))
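    # This leaves [exit_function, enter_result] on the simulated stack, mirroring
    # what SETUP_WITH (<=3.10) / BEFORE_WITH (3.11) leave on the real stack; the
    # BlockStackEntry records where execution should land (and how much stack to
    # keep) if we graph-break inside the with-block and need a resume function.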
    def append_prefix_inst(self, inst):
        assert self.accept_prefix_inst
        self.prefix_insts.append(inst)

    def MAKE_CELL(self, inst):
        self.append_prefix_inst(inst)

    def COPY_FREE_VARS(self, inst):
        self.append_prefix_inst(inst)

    def RETURN_GENERATOR(self, inst):
        self.append_prefix_inst(inst)

    def copy_graphstate(self) -> InstructionTranslatorGraphState:
        """Create a checkpoint of the current state by copying everything"""
        return InstructionTranslatorGraphState(
            self.output.copy_graphstate(),
            dict(self.symbolic_locals),
            list(self.stack),
            list(self.block_stack),
            self.instruction_pointer,
            self.current_instruction,
            self.next_instruction,
            self.lineno,
        )

    def restore_graphstate(self, state: InstructionTranslatorGraphState):
        """Restore a checkpoint created by self.copy_graphstate()"""
        (
            output_state,
            self.symbolic_locals,
            self.stack,
            self.block_stack,
            self.instruction_pointer,
            self.current_instruction,
            self.next_instruction,
            self.lineno,
        ) = state
        self.output.restore_graphstate(output_state)

    def is_non_empty_graph(self):
        if self.output.count_calls() > 1:
            # perf optimization only
            self.is_non_empty_graph = lambda: True  # type: ignore[method-assign]
            return True
        return False
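    # Once the graph holds more than one call it can never become empty again,
    # so the method above rebinds itself on the instance to a constant lambda,
    # skipping the recount on all future queries.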
    def format_frame_summary(self, additional_stack_frames=None):
        if additional_stack_frames is None:
            additional_stack_frames = []
        return "".join(
            traceback.format_list(
                [self.frame_summary()] + list(reversed(additional_stack_frames))
            )
        )

    def frame_summary(self):
        return traceback.FrameSummary(
            getattr(self.f_code, "co_filename", "<unknown>"),
            self.lineno,
            getattr(self.f_code, "co_name", "<unknown>"),
            lookup_line=False,
        )

    def store_global_weakref_by_id(self, prefix, value):
        global_name = self.output.install_global_by_id(prefix, weakref.ref(value))
        install_guard(
            GlobalWeakRefSource(global_name).make_guard(GuardBuilder.WEAKREF_ALIVE)
        )
        return global_name
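    # The WEAKREF_ALIVE guard ties the compiled code's validity to the referent:
    # if the object is garbage collected, the guard fails and the frame is
    # recompiled rather than dereferencing a dead weakref.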
    @property
    def fake_mode(self):
        return self.output.tracing_context.fake_mode

    def find_symbolic_locals_name(self, tensor_variable):
        for key, value in self.symbolic_locals.items():
            if value is tensor_variable:
                return key
        return None

    @contextlib.contextmanager
    def strict_translation_mode(self):
        self.strict_checks_enabled = True
        try:
            yield
        finally:
            self.strict_checks_enabled = False

    def speculate(self) -> SpeculationEntry:
        return self.speculation_log.next(
            self.f_code.co_filename, self.lineno, self.instruction_pointer
        )
    def __init__(
        self,
        output: OutputGraph,
        instructions: List[Instruction],
        f_locals: Dict[str, Any],
        f_globals: Dict[str, Any],
        f_builtins: Dict[str, Any],
        code_options: Dict[str, Any],
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        f_code: types.CodeType,
        export: bool,
        inline_depth: int,
        speculation_log: SpeculationLog,
    ):
        super().__init__()
        self.speculation_log = speculation_log

        # Mutable state checkpointed by copy_graphstate()
        self.output = output
        self.symbolic_locals = symbolic_locals
        self.symbolic_globals = symbolic_globals
        self.stack = []
        self.instruction_pointer = 0
        self.current_instruction = create_instruction("NOP")
        self.next_instruction = None
        self.block_stack = []
        # states before SETUP_WITH for checkpointing and fallback
        self.generic_context_manager_depth = 0
        self.lineno = code_options["co_firstlineno"]
        self.kw_names = None
        self.accept_prefix_inst = True
        self.prefix_insts = []

        # Properties of the input/output code
        self.instructions: List[Instruction] = instructions
        self.indexof: Dict[Instruction, int] = get_indexof(self.instructions)
        self.f_locals: Dict[
            str, Any
        ] = f_locals  # needed for recording accessed locals for replay
        self.f_globals: Dict[str, Any] = f_globals
        self.f_builtins: Dict[str, Any] = f_builtins
        self.code_options: Dict[str, Any] = code_options
        self.f_code: types.CodeType = f_code

        # Execution record for replaying errors
        self.exec_recorder = ExecutionRecorder(code=f_code, code_options=code_options)
        # Stack of modules being parsed; the current nn.Module is at the end of
        # the ordered dict. The first field of the tuple is the fully qualified
        # name of the current module in the original hierarchy; the second field
        # is its type.
        self.nn_module_stack: Dict[str, Tuple[str, Type[Any]]] = {}
        # Flag to indicate whether tracing is used for export.
        self.export = export

        self.current_speculation = None

        self.strict_checks_enabled = False

        if sys.version_info >= (3, 10):
            from .resume_execution import (
                CO_ASYNC_GENERATOR,
                CO_COROUTINE,
                CO_GENERATOR,
                CO_ITERABLE_COROUTINE,
            )

            if f_code.co_flags & (
                CO_GENERATOR | CO_COROUTINE | CO_ITERABLE_COROUTINE | CO_ASYNC_GENERATOR
            ):
                self.push(BuiltinVariable(None))

        self.inline_depth = inline_depth
        self.inconsistent_side_effects = False
        linecache.lazycache(f_code.co_filename, f_globals)
        self.log_starts_line()

class InstructionTranslator(InstructionTranslatorBase):
    mutated_closure_cell_contents: Set[str]

    @staticmethod
    def current_tx() -> "InstructionTranslator":
        return tls.current_tx

    @contextlib.contextmanager
    def set_current_tx(self):
        prior = getattr(tls, "current_tx", None)
        tls.current_tx = self
        try:
            yield
        finally:
            tls.current_tx = prior
    def __init__(
        self,
        instructions: List[Instruction],
        f_code,
        f_locals,
        f_globals,
        f_builtins,
        code_options,
        compiler_fn,
        one_graph,
        export,
        export_constraints,
        mutated_closure_cell_contents: Set[str],
        frame_state,
        speculation_log: SpeculationLog,
    ):
        _step_logger()(
            logging.INFO,
            f"torchdynamo start tracing {f_code.co_name} {code_options['co_filename']}:{code_options['co_firstlineno']}",
        )
        super().__init__(
            output=OutputGraph(
                code_options,
                compiler_fn,
                self,
                export,
                export_constraints,
                frame_state,
                local_scope=f_locals,
                global_scope=f_globals,
                f_code=f_code,
            ),
            instructions=instructions,
            f_locals=f_locals,
            f_globals=f_globals,
            f_builtins=f_builtins,
            code_options=code_options,
            symbolic_locals={},  # set below
            # A global var is inserted only after a STORE_GLOBAL happens to it
            symbolic_globals={},
            f_code=f_code,
            export=export,
            inline_depth=0,
            speculation_log=speculation_log,
        )

        self._throw_if_in_functorch()

        # As soon as we create the tracing context we should keep it active, so
        # any calls into dynamo APIs can rely on finding it.
        with tracing(self.output.tracing_context), self.set_current_tx():
            self.one_graph: bool = one_graph
            self.export = export
            self.mutated_closure_cell_contents = mutated_closure_cell_contents
            if self.export:
                assert (
                    self.one_graph
                ), "Export without one graph - something has gone wrong."

            vars = list(code_options["co_varnames"])
            cells_and_freevars = [x for x in self.cell_and_freevars() if x not in vars]
            vars.extend(cells_and_freevars)

            cells_and_freevars_set = set(cells_and_freevars)

            self.symbolic_locals = {
                k: variables.LazyVariableTracker.create(
                    f_locals[k],
                    source=LocalSource(k, cell_or_freevar=k in cells_and_freevars_set),
                )
                for k in vars
                if k in f_locals
            }

            self.debug_locals: List[Tuple[VariableTracker, List[VariableTracker]]] = []
            if export:
                # export gets confused if we never realize unused inputs;
                # in export mode just eagerly realize everything
                self.symbolic_locals = VariableTracker.apply(
                    lambda x: x.realize(), self.symbolic_locals
                )

            self._freevars_ids = dict()
            for name in self.code_options["co_freevars"]:
                if name in f_locals:
                    self._freevars_ids[name] = id(f_locals[name])
    def _throw_if_in_functorch(self):
        # Fallback to eager in case of a graph break inside vmap
        eager = torch._dynamo.lookup_backend("eager")
        compiler_fn = inspect.getattr_static(
            self.output.compiler_fn, "compiler_fn", self.output.compiler_fn
        )
        ci = torch._C._functorch.peek_interpreter_stack()
        forbidden_keys = (
            torch._C._functorch.TransformType.Vmap,
            torch._C._functorch.TransformType.Grad,
        )
        if ci is not None and ci.key() in forbidden_keys and compiler_fn is not eager:
            # if it reaches here, it means Dynamo failed to inline a functorch function
            name = ci.key().name.lower()
            msg = f"torch.func.{name}(fn) requires the function to be inlined by dynamo"
            unimplemented(msg)
    def get_example_value(self, source: Source):
        if isinstance(source, LocalSource):
            return self.f_locals[source.local_name]
        if isinstance(source, GlobalSource):
            return self.f_globals[source.global_name]
        raise KeyError()

    def run(self):
        super().run()

    def match_nested_cell(self, name, cell):
        """Match a cell in this method to one in a function we are inlining"""
        try:
            value = cell.cell_contents
        except ValueError:
            return None
        # TODO(jansel): check the id of the cell rather than the contents
        if id(value) != self._freevars_ids.get(name):
            return None
        return self.symbolic_locals[name]

    def should_compile_partial_graph(self):
        return (
            all(b.can_restore() for b in self.block_stack)
            and not self.one_graph
            and self.generic_context_manager_depth == 0
        )
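    # create_call_resume_at synthesizes a `__resume_at_<offset>` function that
    # contains the rest of this frame, then emits bytecode that calls it with
    # the live stack values and locals, so execution continues in the original
    # Python after the compiled portion returns.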
    def create_call_resume_at(self, inst):
        self.instruction_pointer = None

        if inst.opname == "RETURN_VALUE":
            return [create_instruction("RETURN_VALUE")]

        reads = livevars_analysis(self.instructions, inst)
        argnames = tuple(
            k
            for k in self.symbolic_locals.keys()
            if k in reads and k not in self.cell_and_freevars()
        )

        cg = PyCodegen(self)

        # Python does not allow null to be an arg to a function, so
        # we remove nulls from the stack and restore them in the
        # prologue of the resume function

        # sorted list of indices of nulls on the stack
        null_idxes: List[int] = []
        if sys.version_info >= (3, 11):
            # find indices of NullVariables
            for i, var in enumerate(self.stack):
                if isinstance(var, NullVariable):
                    null_idxes.append(i)
            # generate bytecode to pop the nulls
            null_cnt = 0
            for i, var in enumerate(reversed(self.stack)):
                if isinstance(var, NullVariable):
                    for j in range(2, i + 2 - null_cnt):
                        cg.append_output(create_instruction("SWAP", arg=j))
                    cg.extend_output(cg.pop_null())
                    null_cnt += 1
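        # Illustrative: for stack [a, NULL, b] (b on top), the loop above emits
        # SWAP 2 to bubble the null to TOS ([a, b, NULL]) and then pops it,
        # leaving [a, b] at runtime before the resume function is called.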
        # we popped all nulls from the stack at runtime,
        # so we should not count NullVariables
        stack_len = len(self.stack) - len(null_idxes)
        nargs = stack_len + len(argnames)

        name = unique_id(f"__resume_at_{inst.offset}")

        new_code: types.CodeType = ContinueExecutionCache.lookup(
            self.f_code,
            self.lineno,
            inst.offset,
            tuple(b.target.offset for b in self.block_stack),
            stack_len,
            argnames,
            tuple(b.resume_fn() for b in self.block_stack),
            tuple(null_idxes),
        )

        # Add original GraphModule context to the resume function to handle
        # the case of a graph break while tracing a GraphModule
        orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
            "orig_graphmodule", lambda: None
        )()
        if orig_graphmodule_maybe is not None:
            code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
                orig_graphmodule_maybe
            )

        if new_code.co_freevars:
            cg.make_function_with_closure(name, new_code, True, stack_len)
        else:
            # This is safe: we pre-generate a unique name
            self.output.install_global_unsafe(
                name, types.FunctionType(new_code, self.f_globals, name)
            )
            cg.extend_output(cg.load_function_name(name, True, stack_len))

        cg.extend_output([cg.create_load(k) for k in argnames])
        cg.extend_output(create_call_function(nargs, False))
        cg.append_output(create_instruction("RETURN_VALUE"))
        return cg.get_instructions()
    def symbolic_locals_contain_module_class(self):
        for v in self.symbolic_locals.values():
            if isinstance(v, UserDefinedClassVariable) and issubclass(
                v.as_python_constant(), torch.nn.Module
            ):
                return True
        return False

    def RETURN_VALUE(self, inst):
        if (
            self.output.count_calls() == 0
            and not self.inconsistent_side_effects
            and not self.symbolic_locals_contain_module_class()
            and not self.export
        ):
            raise exc.SkipFrame("because no content in function call")
        self.instruction_pointer = None
        _step_logger()(
            logging.INFO,
            f"torchdynamo done tracing {self.f_code.co_name} (RETURN_VALUE)",
        )
        log.debug("RETURN_VALUE triggered compile")
        self.output.compile_subgraph(
            self,
            reason=GraphCompileReason(
                "return_value", [self.frame_summary()], graph_break=False
            ),
        )
        self.output.add_output_instructions([create_instruction("RETURN_VALUE")])

class InliningInstructionTranslator(InstructionTranslatorBase):
    """Trace and inline a called method"""

    symbolic_result: Optional[TensorVariable]

    @classmethod
    def inline_call(cls, parent, func, args, kwargs):
        with patch.dict(counters, {"unimplemented": counters["inline_call"]}):
            return cls.inline_call_(parent, func, args, kwargs)

    @staticmethod
    def check_inlineable(func):
        if func.has_self():
            unimplemented("inline with __self__")

        result = trace_rules.check_verbose(func, is_inlined_call=True)
        if result.skipped:
            from torch._dynamo.variables.misc import produce_trampoline_autograd_apply

            # _origin marks this as coming from an internal dynamo known function
            # that is safe to trace through.
            if hasattr(getattr(func, "fn", None), "_origin") and func.fn._origin in [
                produce_trampoline_autograd_apply,
            ]:
                # Known sound
                return trace_rules.SkipResult(
                    False, "allowlist in dynamo known function"
                )
            fn_qualname = func.fn.__qualname__ if hasattr(func, "fn") else ""
            unimplemented(
                f"'inline in skipfiles: {fn_qualname} | {func.get_name()} {func.get_filename()}, {result.reason}'"
            )

        if isinstance(func, UserFunctionVariable) and inspect.getattr_static(
            func.get_function(), "_torchdynamo_disable", False
        ):
            unimplemented(
                f"call torch._dynamo.disable() wrapped function {func.get_function()}"
            )
        else:
            return result
    @staticmethod
    def inline_call_(
        parent, func: VariableTracker, args: List[VariableTracker], kwargs
    ):
        if isinstance(func, SkipFunctionVariable):
            unimplemented("inline with functions in skip files")
        assert isinstance(
            func,
            (UserFunctionVariable, NestedUserFunctionVariable),
        )
        result = InliningInstructionTranslator.check_inlineable(func)
        assert result.skipped is False
        try:
            sub_locals, closure_cells = func.bind_args(parent, args, kwargs)
        except TypeError as e:
            # Wrap the general TypeError during bind_args() into the internal
            # ArgsMismatchError with detailed info
            raise ArgsMismatchError(  # noqa: TRY200
                "{reason}.\n func = {func}, args = {args}, kwargs = {kwargs}".format(
                    reason=str(e),
                    func=f"'{func.get_name()}' {func.get_filename()}:{func.get_code().co_firstlineno}",
                    args=[arg.python_type() for arg in args],
                    kwargs=kwargs,
                ),
            )

        for v in itertools.chain(sub_locals.values(), closure_cells.values()):
            if not isinstance(v, VariableTracker):
                unimplemented(f"unconverted arg {v}")

        code: types.CodeType = func.get_code()
        if code.co_name in ("__setitem__", "__setattr__") and not (
            args is not None
            and len(args) > 0
            and isinstance(args[0], variables.CustomizedDictVariable)
        ):
            unimplemented(f"inline {code.co_name}")

        suffix = ""
        # TODO: mlazos, add support for enabling multiple artifact logs
        # with a single alias
        if torch._logging._internal.log_state.is_artifact_enabled("output_code"):
            suffix = f"\n{dis.Bytecode(code).dis()}"
        if sys.version_info >= (3, 11):
            cur_inst = parent.current_instruction
            parent_code = parent.f_code
            header = parent.get_line_of_code_header(lineno=cur_inst.positions.lineno)

            def get_trace_call_log_str():
                line = get_instruction_source_311(parent_code, cur_inst).rstrip()
                return f"TRACE inlined call {code.co_name} from {header}\n{line}"

            trace_call_log.debug("%s", LazyString(get_trace_call_log_str))
        log.debug("INLINING %s%s, %s", code, suffix, result.reason)

        # Detect inlined GraphModule calls in order to propagate node metadata,
        # by checking if the first argument (self) is a variable tracking a GraphModule.
        if args and isinstance(args[0], NNModuleVariable):
            module = parent.output.get_submodule(args[0].module_key)
            if isinstance(module, torch.fx.GraphModule):
                # The inlined call might not actually be a call to `forward`,
                # but it is enough to add a context for `forward` in case it is called.
                code_context.get_context(module.forward.__code__)[
                    "orig_graphmodule"
                ] = weakref.ref(module)

        tracer: InliningInstructionTranslator
        if is_generator(code):
            tracer = InliningGeneratorInstructionTranslator(
                parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
            )
        else:
            tracer = InliningInstructionTranslator(
                parent, code, sub_locals, parent.symbolic_globals, closure_cells, func
            )

        strict_ctx: Any = contextlib.nullcontext()
        if parent.strict_checks_enabled:
            strict_ctx = tracer.strict_translation_mode()
        try:
            with strict_ctx:
                tracer.run()
        except exc.SkipFrame as e:
            msg = f"SKIPPED INLINING {code}: {e}"
            log.debug(msg)
            raise Unsupported(msg) from e
        except Exception:
            log.debug("FAILED INLINING %s", code)
            raise
        assert tracer.symbolic_result is not None
        func.export_freevars(parent, tracer)

        if tracer.f_globals is parent.f_globals:
            # Merge symbolic_globals back if parent and child are in the same namespace
            parent.symbolic_globals.update(tracer.symbolic_globals)

        parent.inconsistent_side_effects |= tracer.inconsistent_side_effects

        log.debug("DONE INLINING %s", code)

        if is_generator(code):
            assert isinstance(tracer, InliningGeneratorInstructionTranslator)
            assert tracer.symbolic_result.as_python_constant() is None
            return ListIteratorVariable(
                tracer.generated_items,
                mutable_local=MutableLocal(),
            )
        else:
            return tracer.symbolic_result
    def __init__(
        self,
        parent: InstructionTranslatorBase,
        code: types.CodeType,
        symbolic_locals: Dict[str, VariableTracker],
        symbolic_globals: Dict[str, VariableTracker],
        closure_cells: Dict[str, VariableTracker],
        funcvar: BaseUserFunctionVariable,
    ):
        f_globals = funcvar.get_globals()  # type: ignore[attr-defined]
        f_builtins = f_globals["__builtins__"]
        if not isinstance(f_builtins, dict):
            f_builtins = f_builtins.__dict__
        instructions = cleaned_instructions(code)
        propagate_line_nums(instructions)
        super().__init__(
            output=parent.output,
            f_locals={},
            f_globals=f_globals,
            f_builtins=f_builtins,
            symbolic_locals=symbolic_locals,
            symbolic_globals=symbolic_globals,
            instructions=instructions,
            code_options={k: getattr(code, k) for k in dir(code)},
            f_code=code,
            export=parent.export,
            inline_depth=parent.inline_depth + 1,
            speculation_log=parent.speculation_log,
        )
        self.parent = parent
        self.symbolic_result = None
        self.closure_cells = closure_cells
        self.nn_module_stack = parent.nn_module_stack.copy()

    @property
    def fake_mode(self):
        return self.parent.fake_mode

    def run_ctx_mgr(self):
        return TracingContext.current_frame(self.parent.frame_summary())
    def STORE_DEREF(self, inst):
        if inst.argval in self.closure_cells:
            cell = self.closure_cells[inst.argval]
            val = self.pop()
            if isinstance(cell, ClosureVariable):
                if not self.output.is_root_tracer():
                    unimplemented(
                        "HigherOrderOperator: Mutating a variable not in the current scope (ClosureVariable)"
                    )
                self.output.root_tx.symbolic_locals[cell.name] = val
            else:
                self.output.side_effects.store_cell(cell, val)
        else:
            maybe_cell = self.symbolic_locals.get(inst.argval)
            if isinstance(
                maybe_cell,
                variables.NewCellVariable,
            ):
                self.output.side_effects.store_cell(
                    self.symbolic_locals[inst.argval], self.pop()
                )
            else:
                if (
                    maybe_cell is not None
                    and maybe_cell.source.name()
                    not in self.output.root_tx.mutated_closure_cell_contents
                ):
                    # Why is the source name here unique?
                    # mutated_closure_cell_contents is a per-frame
                    # concept, and sources identify, e.g., particular
                    # locals from the frame. If you had two locals,
                    # they'd get different source names, and therefore
                    # differ here.
                    self.output.root_tx.mutated_closure_cell_contents.add(
                        maybe_cell.source.name()
                    )
                    raise exc.UnspecializeRestartAnalysis()
                unimplemented("write to __closure__ while inlining")

    def LOAD_DEREF(self, inst):
        if inst.argval in self.closure_cells:
            cell = self.closure_cells[inst.argval]
            if isinstance(cell, ClosureVariable):
                self.push(self.output.root_tx.symbolic_locals[cell.name])
            else:
                self.push(self.output.side_effects.load_cell(cell))
        else:
            maybe_sym_local = self.symbolic_locals.get(inst.argval, None)
            if isinstance(maybe_sym_local, variables.NewCellVariable):
                self.push(self.output.side_effects.load_cell(maybe_sym_local))
            else:
                super().LOAD_DEREF(inst)

    def LOAD_CLOSURE(self, inst):
        assert inst.argval in self.cell_and_freevars()
        if inst.argval in self.closure_cells:
            self.push(self.closure_cells[inst.argval])
        else:
            self.push(InlinedClosureVariable(name=inst.argval))

    def check_replace_is_safe(self, oldvar):
        if not is_side_effect_safe(oldvar.mutable_local):
            unimplemented(
                "HigherOrderOperator: Mutating a variable not in the current scope (replace_all)"
            )

    def should_compile_partial_graph(self):
        return False  # inlining functions is all-or-nothing

    def create_call_resume_at(self, offset):
        unimplemented("can't resume while inlining")

    def RETURN_VALUE(self, inst):
        self.symbolic_result = self.pop()  # type: ignore[assignment]
        self.instruction_pointer = None

class InliningGeneratorInstructionTranslator(InliningInstructionTranslator):
    generated_items: List[VariableTracker]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.generated_items = []

    def YIELD_VALUE(self, inst: Instruction):
        self.generated_items.append(self.pop())
        # TODO(jansel): figure out why this is needed, it isn't in the docs for YIELD_VALUE
        self.push(ConstantVariable.create(None))
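        # (Best guess: a real YIELD_VALUE leaves the value sent back into the
        # generator on the stack when it resumes; a plain `for` loop / next()
        # sends None, so we push None here instead of actually suspending.)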
    def GET_YIELD_FROM_ITER(self, inst):
        tos = self.stack[-1]
        if not isinstance(tos, ListIteratorVariable):
            self.pop()
            res = BuiltinVariable(iter).call_function(self, [tos], {})
            self.push(res)
        return self.YIELD_FROM(inst)

    def YIELD_FROM(self, inst):
        while True:
            tos = self.stack[-1].realize()
            if isinstance(tos, ConstantVariable) and tos.value is None:
                self.pop()
                return
            if isinstance(
                tos, (variables.ListIteratorVariable, variables.IteratorVariable)
            ):
                try:
                    val, next_iter = tos.next_variables(self)
                    self.push(val)
                    # TODO(voz): Unclear if we need the push None in YIELD_VALUE?
                    self.YIELD_VALUE(inst)
                    self.pop()
                    self.push(next_iter)
                except StopIteration:
                    return
            else:
                unimplemented(f"YIELD_FROM {typestr(tos)}")

    def SEND(self, inst):
        assert len(self.stack) >= 2
        val = self.pop()
        tos = self.stack[-1]
        if isinstance(tos, ListIteratorVariable):
            if isinstance(val, ConstantVariable) and val.value is None:
                self.push(val)
                self.instruction_pointer = self.indexof[inst.target]
            else:
                # invoke send
                # Unreachable code - if you hit this, you are implementing generator
                # support and have lifted the `unimplemented("generator")` in frame
                # conversion. This codepath handles subgenerators and lines up with
                # this line in Python 3.11:
                # https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597
                unimplemented("Unreachable sub-generator code")
        else:
            unimplemented(f"SEND {typestr(tos)}")