|
|
|
import collections |
|
import contextlib |
|
import copy |
|
import functools |
|
import itertools |
|
import logging |
|
import operator |
|
import re |
|
import sys |
|
import traceback |
|
import weakref |
|
from dataclasses import dataclass |
|
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING, Union |
|
|
|
import sympy |
|
|
|
import torch._guards |
|
|
|
import torch._logging |
|
|
|
import torch.nn |
|
import torch.utils._pytree as pytree |
|
from torch import fx |
|
from torch._guards import GlobalContextCheckpointState, Source, TracingContext |
|
from torch._utils_internal import signpost_event |
|
from torch.fx._lazy_graph_module import _make_graph_module |
|
from torch.fx.experimental._backward_state import BackwardState |
|
from torch.fx.experimental.symbolic_shapes import free_symbols, is_symbolic, ShapeEnv |
|
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts |
|
from torch.utils._python_dispatch import is_traceable_wrapper_subclass |
|
|
|
from . import config, logging as torchdynamo_logging, variables |
|
from .backends.registry import CompiledFn, CompilerFn |
|
from .bytecode_transformation import ( |
|
create_call_function, |
|
create_instruction, |
|
Instruction, |
|
unique_id, |
|
) |
|
from .code_context import code_context |
|
from .codegen import PyCodegen |
|
from .current_scope_id import enter_new_scope |
|
from .exc import ( |
|
BackendCompilerFailed, |
|
exceptions_allowed_to_be_fallback, |
|
SkipFrame, |
|
unimplemented, |
|
unimplemented_with_warning, |
|
) |
|
from .guards import GuardBuilder, install_guard |
|
from .mutation_guard import is_dynamic_nn_module |
|
from .side_effects import AttributeMutationExisting, SideEffects |
|
from .source import ( |
|
AttrSource, |
|
BackwardStateSource, |
|
ConstantSource, |
|
GetItemSource, |
|
GlobalStateSource, |
|
is_constant_source, |
|
is_from_local_source, |
|
LocalSource, |
|
ParamBufferSource, |
|
ShapeEnvSource, |
|
SyntheticLocalSource, |
|
TensorProperty, |
|
TensorPropertySource, |
|
) |
|
from .utils import ( |
|
checkpoint_params, |
|
CleanupHook, |
|
clone_inputs, |
|
count_calls, |
|
counters, |
|
dynamo_timed, |
|
get_instruction_source_311, |
|
get_locals_to_steal, |
|
get_static_address_type, |
|
graph_break_reasons, |
|
increment_op_count, |
|
lazy_format_graph_code, |
|
LazyString, |
|
nn_module_proxy, |
|
same, |
|
set_example_value, |
|
) |
|
from .variables.base import VariableTracker |
|
from .variables.builder import ( |
|
BackwardStateGraphArg, |
|
GraphArg, |
|
TrackedFake, |
|
VariableBuilder, |
|
wrap_fx_proxy, |
|
) |
|
from .variables.lists import BaseListVariable |
|
from .variables.misc import NullVariable |
|
from .variables.nn_module import NNModuleVariable |
|
from .variables.tensor import ( |
|
NumpyNdarrayVariable, |
|
SymNodeVariable, |
|
TensorVariable, |
|
UnspecializedPythonVariable, |
|
) |
|
|
|
from .variables.torch_function import TensorWithTFOverrideVariable |
|
|
|
if TYPE_CHECKING: |
|
from torch._dynamo.symbolic_convert import InstructionTranslatorBase |
|
|
|
log = logging.getLogger(__name__) |
|
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph") |
|
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code") |
|
graph_sizes_log = torch._logging.getArtifactLogger(__name__, "graph_sizes") |
|
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") |
|
|
|
|
|
@dataclass(frozen=True) |
|
class VariableTrackerCacheKey: |
|
vt_id: int |
|
|
|
|
|
|
|
source: Source |
|
|
|
|
|
class VariableTrackerCache: |
|
def __init__(self): |
|
self.cache = {} |
|
|
|
def lookup(self, value, source): |
|
key = VariableTrackerCacheKey(id(value), source) |
|
if key not in self.cache: |
|
return None |
|
return self.cache[key] |
|
|
|
def add(self, value, source, vt): |
|
key = VariableTrackerCacheKey(id(value), source) |
|
self.cache[key] = vt |
|
|
|
def clone(self): |
|
|
|
new_cache = VariableTrackerCache() |
|
new_cache.cache.update(self.cache) |
|
return new_cache |
|
|
|
def clear(self): |
|
self.cache.clear() |
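
# Illustrative sketch (hypothetical values, not executed): the cache is keyed
# on (id(value), source), so the same Python object reached through two
# different sources gets two distinct VariableTrackers:
#
#   cache = VariableTrackerCache()
#   cache.add(value, source_a, vt_a)
#   cache.lookup(value, source_a)  # -> vt_a
#   cache.lookup(value, source_b)  # -> None: different source, cache miss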
|
|
|
|
|
@functools.lru_cache(None) |
|
def _step_logger(): |
|
return torchdynamo_logging.get_step_logger(log) |
|
|
|
|
|
@dataclass |
|
class GraphCompileReason: |
|
"""Stores why a given output graph was compiled; i.e. what caused the graph break.""" |
|
|
|
reason: str |
|
user_stack: List[traceback.FrameSummary] |
|
|
|
|
|
graph_break: bool = True |
|
|
|
def __post_init__(self): |
|
if self.graph_break: |
|
graph_break_reasons.append(self) |
|
|
|
|
|
def _get_gen_rand_values_fn(random_calls): |
|
def _gen_rand_values(): |
|
return [fn(*args, **kwargs) for fn, args, kwargs in random_calls] |
|
|
|
return _gen_rand_values |
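
# Illustrative sketch (hypothetical call list, not executed): replays the
# random() calls observed during tracing, in order, so the compiled frame
# sees the same values eager mode would have produced:
#
#   calls = [(random.uniform, (0.0, 1.0), {}), (random.randint, (0, 9), {})]
#   gen = _get_gen_rand_values_fn(calls)
#   gen()  # -> [uniform sample, randint sample]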
|
|
|
|
|
class FakeRootModule(torch.nn.Module): |
|
"""Trick the constructor of fx.GraphModule""" |
|
|
|
def __init__(self, nn_modules: Dict[str, torch.nn.Module]): |
|
super().__init__() |
|
for k, v in nn_modules.items(): |
|
setattr(self, k, v) |
|
|
|
def __repr__(self): |
|
return "FakeRootModule(...)" |
|
|
|
|
|
class WrapperBackend: |
|
def __init__(self, backend: CompilerFn): |
|
self.backend: CompilerFn = backend |
|
|
|
def __call__(self, gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): |
|
self.restore = checkpoint_params(gm) |
|
self.gm = gm |
|
copy_gm = copy.deepcopy(self.gm) |
|
self.candidate = self.backend(copy_gm, example_inputs) |
|
|
|
if self.candidate is None or self.candidate is self.gm.forward: |
|
return self.gm.forward |
|
|
|
if not config.verify_correctness: |
|
return self.candidate |
|
|
|
|
|
try: |
|
correct = self.gm.forward(*clone_inputs(example_inputs)) |
|
result = self.candidate(*clone_inputs(example_inputs)) |
|
|
|
|
|
if same(correct, result): |
|
return self.candidate |
|
|
|
raise RuntimeError(f"incorrect results of backend {self}") |
|
|
|
|
except Exception: |
|
log.exception("error in verify_correctness") |
|
raise |
|
finally: |
|
self.restore() |
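
# Illustrative sketch (hypothetical backend, not executed): with
# config.verify_correctness set, call_user_compiler wraps the user backend so
# its output is cross-checked against eager execution of the same graph:
#
#   def my_backend(gm, example_inputs):
#       return gm.forward
#
#   compiled = WrapperBackend(my_backend)(gm, example_inputs)
#   # raises RuntimeError if the backend's results diverge from eager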
|
|
|
|
|
Scope = Dict[str, object] |
|
|
|
|
|
class OutputGraph: |
|
""" |
|
Wrapper class to hold outputs of InstructionTranslator. Mainly the |
|
generated fx.Graph. |
|
|
|
OutputGraph is 1:1 with a frame being processed. Each frame is associated |
|
with some root InstructionTranslator. When user code calls a function, |
|
    we construct an InliningInstructionTranslator that continues to write into
|
the root InstructionTranslator's OutputGraph. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
code_options: Dict[str, Any], |
|
compiler_fn: Optional[CompilerFn], |
|
root_tx, |
|
export: bool, |
|
export_constraints, |
|
frame_state, |
|
local_scope: Scope, |
|
global_scope: Scope, |
|
f_code, |
|
): |
|
super().__init__() |
|
self.tracers = [SubgraphTracer(self, export_root=export)] |
|
|
|
|
|
self.input_source_to_var: Dict[Source, VariableTracker] = {} |
|
self.export = export |
|
self.export_constraints = export_constraints |
|
self.frame_state = frame_state |
|
|
|
self.input_source_to_sizes_strides: Dict[Source, Dict[str, Any]] = {} |
|
self.cleanup_hooks: List[Callable[[], Any]] = [] |
|
|
|
self.compile_id: int = next(_compile_id_counter) |
|
|
|
self.installed_globals: Set[str] = set() |
|
|
|
|
|
|
|
self.co_fields = { |
|
"co_name": f_code.co_name, |
|
"co_filename": f_code.co_filename, |
|
"co_firstlineno": f_code.co_firstlineno, |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
self.tracked_fakes: List[TrackedFake] = [] |
|
|
|
|
|
|
|
self.bound_symbols: Set[sympy.Symbol] = set() |
|
|
|
shape_env = ShapeEnv( |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tracked_fakes=self.tracked_fakes, |
|
allow_scalar_outputs=config.capture_scalar_outputs, |
|
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops, |
|
prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards, |
|
_allow_complex_guards_as_runtime_asserts=config._allow_complex_guards_as_runtime_asserts, |
|
co_fields=self.co_fields, |
|
) |
|
|
|
|
|
|
|
import torch._functorch.config as _config |
|
|
|
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): |
|
fake_mode = torch._subclasses.FakeTensorMode( |
|
shape_env=shape_env, |
|
|
|
                allow_non_fake_inputs=self.export,
|
export=self.export, |
|
) |
|
self.tracing_context: TracingContext = TracingContext(fake_mode) |
|
self.init_ambient_guards() |
|
|
|
|
|
|
|
|
|
|
|
|
|
self.tracked_fakes_id_to_source: Dict[ |
|
int, List[Source] |
|
] = collections.defaultdict(list) |
|
|
|
self.param_name_to_source: Optional[Dict[str, Source]] = dict() |
|
self.side_effects = SideEffects() |
|
|
|
|
|
self.variable_tracker_cache = VariableTrackerCache() |
|
self.unique_var_id = itertools.count() |
|
self.code_options = dict(code_options) |
|
self.output_instructions: List[Instruction] = [] |
|
|
|
|
|
self.timestamp = 0 |
|
|
|
|
|
self.register_finalizer_fns: List[Callable[[fx.GraphModule], None]] = [] |
|
|
|
|
|
self.compiler_fn: Optional[CompilerFn] = compiler_fn |
|
self.global_scope = global_scope |
|
self.local_scope = local_scope |
|
self.root_tx = root_tx |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.source_to_user_stacks: Dict[Source, List[traceback.StackSummary]] = {} |
|
|
|
self._current_tx: List[InstructionTranslatorBase] = [] |
|
self.cleanups: List[CleanupHook] = [] |
|
self.should_exit = False |
|
self.unspec_variable_map: Dict[str, UnspecializedPythonVariable] = {} |
|
self.torch_function_enabled = torch._C._is_torch_function_enabled() |
|
|
|
|
|
|
|
|
|
|
|
self.has_user_defined_allowed_in_graph = False |
|
|
|
|
|
|
|
        self.non_compliant_ops: Set[torch._ops.OpOverload] = set()
|
|
|
|
|
|
|
        self.compliant_custom_ops: Set[torch._ops.OpOverload] = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
self.save_global_state() |
|
|
|
|
|
|
|
self.dynamo_flat_name_to_original_fqn: Dict[str, str] = {} |
|
|
|
|
|
|
|
|
|
|
|
self.random_calls: List[ |
|
Tuple[Callable[..., object], Tuple[object, ...], Dict[str, object]] |
|
] = [] |
|
self.random_values_var = None |
|
|
|
|
|
self.pregraph_bytecode: List[Instruction] = [] |
|
|
|
|
|
self.backward_state: Dict[str, VariableTracker] = {} |
|
self.backward_state_proxy: Optional[torch.fx.Proxy] = None |
|
self.backward_state_var: Optional[str] = None |
|
|
|
self.name_of_builtins_dict_key_in_fglobals: str = ( |
|
self.install_builtins_dict_in_fglobals() |
|
) |
|
|
|
self.guard_on_key_order: Set[str] = set() |
|
|
|
def install_builtins_dict_in_fglobals(self): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
f_builtins = self.global_scope["__builtins__"] |
|
if not isinstance(f_builtins, dict): |
|
f_builtins = f_builtins.__dict__ |
|
return self.install_global("__builtins_dict__", f_builtins) |
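
    # Illustrative note: "__builtins__" may be bound to either the builtins
    # module or its __dict__, depending on how the frame was created (a
    # CPython detail), which is why the method above normalizes to a dict:
    #
    #   import builtins
    #   ns = {"__builtins__": builtins}             # module form
    #   ns["__builtins__"].__dict__["len"] is len   # -> True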
|
|
|
def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"): |
|
name = f"{prefix}{len(self.backward_state)}" |
|
assert name not in self.backward_state |
|
self.backward_state[name] = hook |
|
return name, self.get_backward_state_proxy() |
|
|
|
def get_backward_state_proxy(self): |
|
if self.backward_state_proxy is None: |
|
if self.export: |
|
unimplemented("backward_state does not support export") |
|
self.backward_state_proxy = self.root_tracer.create_graph_input( |
|
"dynamo_backward_state", BackwardState, source=BackwardStateSource() |
|
) |
|
self.backward_state_proxy.node.meta["grapharg"] = BackwardStateGraphArg() |
|
set_example_value(self.backward_state_proxy.node, BackwardState()) |
|
self.backward_state_var = self.new_var() |
|
return self.backward_state_proxy |
|
|
|
|
|
def init_ambient_guards(self): |
|
|
|
|
|
self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV)) |
|
|
|
self.guards.add( |
|
GlobalStateSource().make_guard(GuardBuilder.DETERMINISTIC_ALGORITHMS) |
|
) |
|
|
|
self.guards.add(GlobalStateSource().make_guard(GuardBuilder.GRAD_MODE)) |
|
|
|
self.guards.add(GlobalStateSource().make_guard(GuardBuilder.DEFAULT_DEVICE)) |
|
|
|
self.guards.add( |
|
GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE) |
|
) |
|
|
|
ci = torch._C._functorch.peek_interpreter_stack() |
|
if ci is not None: |
|
self.guards.add( |
|
GlobalStateSource().make_guard(GuardBuilder.FUNCTORCH_STACK_MATCH) |
|
) |
|
|
|
def synthetic_graph_input(self, fn, args): |
|
""" |
|
        Call fn(*args) before the graph runs and turn the result into a fake input.
|
""" |
|
example_value = fn(*args) |
|
varname = self.new_var() |
|
cg = PyCodegen(self.root_tx) |
|
cg.load_import_from( |
|
fn.__module__, |
|
fn.__name__, |
|
) |
|
cg.foreach(map(variables.ConstantVariable.create, args)) |
|
cg.call_function(len(args), True) |
|
cg.store(varname) |
|
self.pregraph_bytecode.extend(cg.get_instructions()) |
|
source = SyntheticLocalSource(varname) |
|
result = VariableBuilder(self.root_tx, source)(example_value) |
|
TracingContext.get().guards_context.dynamo_guards.remove_guards_with_source( |
|
source |
|
) |
|
return result |
|
|
|
def add_cleanup_hook(self, fn: Callable[[], Any]): |
|
self.cleanup_hooks.append(fn) |
|
|
|
def call_cleanup_hooks(self): |
|
for hook in reversed(self.cleanup_hooks): |
|
hook() |
|
self.cleanup_hooks.clear() |
|
|
|
@property |
|
def root_tracer(self): |
|
return self.tracers[0] |
|
|
|
@property |
|
def current_tracer(self): |
|
return self.tracers[-1] |
|
|
|
def is_root_tracer(self): |
|
|
|
return len(self.tracers) == 1 |
|
|
|
@property |
|
def graph(self): |
|
return self.current_tracer.graph |
|
|
|
|
|
@graph.setter |
|
def graph(self, value): |
|
self.current_tracer.graph = value |
|
|
|
@property |
|
def input_name_to_proxy(self): |
|
return self.current_tracer.input_name_to_proxy |
|
|
|
@property |
|
def real_value_cache(self): |
|
return self.current_tracer.real_value_cache |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_proxy(self, *args, **kwargs): |
|
return self.current_tracer.create_proxy(*args, **kwargs) |
|
|
|
def create_node(self, *args, **kwargs): |
|
return self.current_tracer.create_node(*args, **kwargs) |
|
|
|
def remove_node(self, *args, **kwargs): |
|
return self.current_tracer.remove_node(*args, **kwargs) |
|
|
|
@contextlib.contextmanager |
|
def subtracer(self, source_target, prior_tracer): |
|
new_scope_ctx = enter_new_scope() |
|
try: |
|
if prior_tracer: |
|
|
|
assert prior_tracer.parent is self.current_tracer |
|
new_scope_ctx.__enter__() |
|
tracer = ( |
|
prior_tracer |
|
if prior_tracer |
|
else SubgraphTracer( |
|
self, parent=self.current_tracer, source_target=source_target |
|
) |
|
) |
|
self.tracers.append(tracer) |
|
yield tracer |
|
finally: |
|
new_scope_ctx.__exit__(None, None, None) |
|
self.tracers.pop() |
|
|
|
@property |
|
def output(self): |
|
return self |
|
|
|
@property |
|
def fake_mode(self): |
|
return self.tracing_context.fake_mode |
|
|
|
@property |
|
def shape_env(self): |
|
return self.tracing_context.fake_mode.shape_env |
|
|
|
@property |
|
def guards(self) -> torch._guards.GuardsSet: |
|
return self.tracing_context.guards_context.dynamo_guards |
|
|
|
@property |
|
def nn_modules(self) -> Dict[str, Any]: |
|
return self.tracing_context.module_context.nn_modules |
|
|
|
def save_global_state(self, out=None): |
|
""" |
|
Saves to out if it is provided. Else saves to the tracing context's global_state. |
|
""" |
|
global_state = ( |
|
out if out is not None else self.tracing_context.global_context.global_state |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
global_state["torch_function_enabled"] = ( |
|
self.set_torch_function_state, |
|
self.torch_function_enabled, |
|
) |
|
global_state["grad_enabled"] = (torch.set_grad_enabled, torch.is_grad_enabled()) |
|
|
|
global_state["autocast_enabled"] = ( |
|
functools.partial(torch.set_autocast_enabled, "cuda"), |
|
torch.is_autocast_enabled("cuda"), |
|
) |
|
global_state["autocast_cpu_enabled"] = ( |
|
functools.partial(torch.set_autocast_enabled, "cpu"), |
|
torch.is_autocast_enabled("cpu"), |
|
) |
|
global_state["autocast_gpu_dtype"] = ( |
|
functools.partial(torch.set_autocast_dtype, "cuda"), |
|
torch.get_autocast_dtype("cuda"), |
|
) |
|
global_state["autocast_cpu_dtype"] = ( |
|
functools.partial(torch.set_autocast_dtype, "cpu"), |
|
torch.get_autocast_dtype("cpu"), |
|
) |
|
global_state["autocast_cache_enabled"] = ( |
|
torch.set_autocast_cache_enabled, |
|
torch.is_autocast_cache_enabled(), |
|
) |
|
|
|
def push_tx(self, tx): |
|
self._current_tx.append(tx) |
|
|
|
def pop_tx(self): |
|
return self._current_tx.pop() |
|
|
|
@property |
|
def current_tx(self): |
|
return self.root_tx if not self._current_tx else self._current_tx[-1] |
|
|
|
def add_symbol_bindings(self, arg: GraphArg): |
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.export: |
|
return |
|
|
|
assert arg.fake_tensor is not None |
|
|
|
def bind_symint(s, prop): |
|
if not (is_symbolic(s) and isinstance(s.node.expr, sympy.Symbol)): |
|
return |
|
s0 = s.node.expr |
|
if s0 in self.bound_symbols: |
|
return |
|
self.bound_symbols.add(s0) |
|
log.debug("bind_symint %s %s", s, prop.name()) |
|
|
|
|
|
proxy = self.root_tracer.create_graph_input( |
|
str(s0), |
|
torch.SymInt, |
|
before=True, |
|
source=prop, |
|
) |
|
set_example_value(proxy.node, s) |
|
proxy.node.meta["grapharg"] = GraphArg( |
|
prop, |
|
s, |
|
pass_arg_as_tensor=False, |
|
fake_tensor=None, |
|
is_tensor=False, |
|
) |
|
|
|
def handle_tensor(t, src): |
|
for i, s in enumerate(t.size()): |
|
bind_symint(s, TensorPropertySource(src, TensorProperty.SIZE, i)) |
|
if t.layout is torch.strided: |
|
for i, s in enumerate(t.stride()): |
|
bind_symint(s, TensorPropertySource(src, TensorProperty.STRIDE, i)) |
|
bind_symint( |
|
t.storage_offset(), |
|
TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), |
|
) |
|
elif t.layout is torch.sparse_coo: |
|
handle_tensor(t._indices(), src) |
|
handle_tensor(t._values(), src) |
|
elif t.layout in {torch.sparse_csr, torch.sparse_bsr}: |
|
handle_tensor(t.crow_indices(), src) |
|
handle_tensor(t.col_indices(), src) |
|
elif t.layout in {torch.sparse_csc, torch.sparse_bsc}: |
|
handle_tensor(t.ccol_indices(), src) |
|
handle_tensor(t.row_indices(), src) |
|
if is_traceable_wrapper_subclass(t): |
|
attrs, ctx = t.__tensor_flatten__() |
|
for attr in attrs: |
|
inner_t = getattr(t, attr) |
|
handle_tensor(inner_t, AttrSource(src, attr)) |
|
|
|
handle_tensor(arg.fake_tensor, arg.source) |
|
|
|
def count_calls(self): |
|
return count_calls(self.graph) |
|
|
|
def is_empty_graph(self): |
|
        return len(self.graph.nodes) == 0
|
|
|
def get_submodule(self, keys): |
|
assert keys |
|
obj: Union[torch.nn.Module, Dict[str, torch.nn.Module]] = self.nn_modules |
|
for k in keys.split("."): |
|
if isinstance(obj, dict): |
|
obj = obj[k] |
|
else: |
|
obj = getattr(obj, k) |
|
return obj |
|
|
|
def new_var(self, name="tmp"): |
|
existing = set(self.code_options["co_varnames"]) |
|
|
|
while True: |
|
var = f"{name}_{next(self.unique_var_id)}" |
|
if var not in existing: |
|
self.code_options["co_varnames"] += (var,) |
|
return var |
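
    # Illustrative sketch (hypothetical counter state, not executed): names are
    # suffixed with a shared counter and appended to co_varnames, e.g.:
    #
    #   self.new_var("graph_out")  # -> "graph_out_0"
    #   self.new_var("graph_out")  # -> "graph_out_1"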
|
|
|
def update_co_names(self, name): |
|
"""Ensure self.code_options.co_names contains name""" |
|
if name not in self.code_options["co_names"]: |
|
self.code_options["co_names"] += (name,) |
|
|
|
@staticmethod |
|
def module_key_name(*names): |
|
|
|
name = "_".join(map(str, names)) |
|
|
|
name = re.sub(r"^[GL]\['?(.*?)'?\]$", r"\1", name) |
|
|
|
name = re.sub(r"\[(\d+)\]", r"_\g<1>", name) |
|
|
|
name = re.sub(r"[^a-zA-Z0-9]", "_", name) |
|
|
|
if not name or not name[0].isalpha(): |
|
name = "sub" + name |
|
|
|
return name |
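
    # Illustrative sketch (hypothetical names, not executed): guard-source
    # names are mangled into valid Python identifiers, e.g.:
    #
    #   OutputGraph.module_key_name("L['mod']")   # -> "mod"
    #   OutputGraph.module_key_name("layers[0]")  # -> "layers_0"
    #   OutputGraph.module_key_name("0th_item")   # -> "sub0th_item"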
|
|
|
def register_attr_or_module( |
|
self, |
|
target: Union[torch.nn.Module, torch.Tensor, Any], |
|
*names, |
|
**options, |
|
): |
|
if is_dynamic_nn_module(target, self.root_tx.export): |
|
return variables.UnspecializedNNModuleVariable(target, **options) |
|
|
|
options = dict(options) |
|
assert "source" in options |
|
source = options["source"] |
|
assert not isinstance(source, ParamBufferSource) |
|
|
|
if isinstance(target, torch.Tensor): |
|
tracer = self.current_tracer |
|
if not self.is_root_tracer(): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tracer = self.root_tracer |
|
|
|
def wrap_name(module_key): |
|
assert self.param_name_to_source is not None |
|
self.param_name_to_source[module_key] = source |
|
|
|
|
|
|
|
if target in self.root_tx.output.side_effects: |
|
return self.root_tx.output.side_effects[target] |
|
|
|
if get_static_address_type(target) == "guarded": |
|
install_guard(source.make_guard(GuardBuilder.ID_MATCH)) |
|
elif not is_constant_source(source): |
|
install_guard(source.make_guard(GuardBuilder.TENSOR_MATCH)) |
|
|
|
vt = wrap_fx_proxy( |
|
self.root_tx, |
|
tracer.create_proxy("get_attr", module_key, tuple(), {}), |
|
example_value=target, |
|
**options, |
|
) |
|
|
|
|
|
|
|
vt = self.root_tx.output.side_effects.track_object_existing(target, vt) |
|
return vt |
|
|
|
elif isinstance(target, torch.nn.Module): |
|
assert isinstance(target, torch.nn.Module) |
|
|
|
if source: |
|
install_guard(source.make_guard(GuardBuilder.NN_MODULE)) |
|
|
|
def wrap_name(module_key): |
|
return NNModuleVariable(type(target), module_key, target, **options) |
|
|
|
else: |
|
|
|
|
|
|
|
|
|
def wrap_name(module_key): |
|
return variables.UnspecializedNNModuleVariable(target, **options) |
|
|
|
elif isinstance(target, (torch.SymInt, torch.SymFloat)): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def wrap_name(module_key): |
|
return SymNodeVariable.create( |
|
self, |
|
self.create_proxy("get_attr", module_key, tuple(), {}), |
|
sym_num=target, |
|
**options, |
|
) |
|
|
|
|
|
else: |
|
|
|
def wrap_name(module_key): |
|
self.output.update_co_names(module_key) |
|
self.global_scope[module_key] = target |
|
return VariableBuilder(self, ConstantSource(source_name=module_key))( |
|
target |
|
) |
|
|
|
for k, v in self.nn_modules.items(): |
|
if v is target: |
|
|
|
return wrap_name(k) |
|
|
|
name = OutputGraph.module_key_name(*names) |
|
|
|
base = name |
|
for i in itertools.count(): |
|
if name not in self.nn_modules: |
|
self.nn_modules[name] = target |
|
if isinstance(target, torch.nn.Module): |
|
|
|
def register_leaf_name(leaf_name): |
|
assert self.param_name_to_source is not None |
|
new_source = ParamBufferSource(source, leaf_name) |
|
new_name = f"{name}.{leaf_name}" |
|
self.param_name_to_source[new_name] = new_source |
|
if isinstance(source, LocalSource): |
|
self.dynamo_flat_name_to_original_fqn[ |
|
OutputGraph.module_key_name(new_source.name()) |
|
] = leaf_name |
|
|
|
|
|
|
|
if hasattr(target, "_parameters"): |
|
for leaf_name, _ in target.named_parameters(): |
|
register_leaf_name(leaf_name) |
|
if hasattr(target, "_buffers"): |
|
for leaf_name, _ in target.named_buffers(): |
|
register_leaf_name(leaf_name) |
|
|
|
return wrap_name(name) |
|
name = f"{base}_{i}" |
|
|
|
raise AssertionError("unreachable") |
|
|
|
def handle_aliases_for_stolen_lists(self, tx): |
|
|
|
maybe_gm = self.local_scope.get("self") |
|
stolen_list_names = get_locals_to_steal(maybe_gm) |
|
if not stolen_list_names: |
|
return [] |
|
|
|
alias_insts = [] |
|
needs_alias: Dict[ |
|
str, List[Union[VariableTracker, AttributeMutationExisting]] |
|
] = {} |
|
|
|
queue = [ |
|
*tx.stack, |
|
*tx.symbolic_locals.values(), |
|
*self.side_effects.store_attr_mutations.keys(), |
|
] |
|
|
|
while queue: |
|
x = queue.pop() |
|
if isinstance(x, BaseListVariable): |
|
assert isinstance(x.items, List) |
|
queue += x.items |
|
continue |
|
|
|
if not ( |
|
isinstance(x, (VariableTracker, AttributeMutationExisting)) |
|
and isinstance(x.source, GetItemSource) |
|
and isinstance(x.source.base, LocalSource) |
|
and x.source.base.local_name in stolen_list_names |
|
): |
|
continue |
|
|
|
stolen_name = x.source.base.local_name |
|
if stolen_name not in needs_alias: |
|
needs_alias[stolen_name] = [] |
|
needs_alias[stolen_name].append(x) |
|
|
|
visited = {} |
|
for arg in self.graphargs: |
|
if not ( |
|
isinstance(arg._example, list) |
|
and isinstance(arg.source, LocalSource) |
|
and arg.source.local_name in needs_alias |
|
): |
|
continue |
|
|
|
|
|
list_name = arg.source.local_name |
|
assert list_name in self.code_options["co_varnames"] |
|
for x in needs_alias[list_name]: |
|
list_idx = x.source.index |
|
if list_idx not in visited: |
|
                    alias_name = self.new_var(f"{list_name}_ref")
|
|
|
visited[list_idx] = alias_name |
|
|
|
alias_insts.extend( |
|
[ |
|
create_instruction("LOAD_FAST", argval=list_name), |
|
create_instruction("LOAD_CONST", argval=list_idx), |
|
create_instruction("BINARY_SUBSCR"), |
|
create_instruction("STORE_FAST", argval=alias_name), |
|
] |
|
) |
|
|
|
|
|
x.source = LocalSource(visited[list_idx]) |
|
|
|
return alias_insts |
|
|
|
def compile_subgraph( |
|
self, tx, partial_convert=False, reason: Optional[GraphCompileReason] = None |
|
): |
|
""" |
|
Generate a subgraph to continue execution on user code. |
|
Automatically restore live variables. |
|
""" |
|
assert reason is not None |
|
|
|
from .decorators import disable |
|
|
|
self.partial_convert = partial_convert |
|
self.compile_subgraph_reason = reason |
|
self.should_exit = True |
|
|
|
log.debug("COMPILING GRAPH due to %s", reason) |
|
|
|
if not all(block.can_restore() for block in tx.block_stack): |
|
unimplemented("compile_subgraph with block_depth != 0") |
|
|
|
prefix_insts: List[Instruction] = [] |
|
if sys.version_info >= (3, 11): |
|
|
|
for inst in tx.prefix_insts: |
|
if inst.opname == "MAKE_CELL": |
|
prefix_insts.append( |
|
create_instruction("MAKE_CELL", argval=inst.argval) |
|
) |
|
elif inst.opname == "COPY_FREE_VARS": |
|
prefix_insts.append( |
|
create_instruction( |
|
"COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"]) |
|
) |
|
) |
|
else: |
|
prefix_insts.append(copy.copy(inst)) |
|
assert not ( |
|
self.pregraph_bytecode and self.export |
|
), "export does not support pregraph_bytecode" |
|
prefix_insts.extend(self.pregraph_bytecode) |
|
prefix_insts.extend(self.handle_aliases_for_stolen_lists(tx)) |
|
|
|
def append_prefix_insts(): |
|
self.add_output_instructions(prefix_insts) |
|
prefix_insts.clear() |
|
|
|
for block in reversed(tx.block_stack): |
|
block.exit(tx) |
|
|
|
self.cleanup_graph() |
|
tx.prune_dead_locals() |
|
stack_values = list(tx.stack) |
|
|
|
|
|
|
|
for value in stack_values: |
|
value.realize() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nn_modules_proxies = { |
|
name: nn_module_proxy(mod) for name, mod in self.nn_modules.items() |
|
} |
|
root = FakeRootModule(nn_modules_proxies) |
|
|
|
restore_vars = [] |
|
val_to_names: Dict[VariableTracker, List[str]] = {} |
|
if stack_values: |
|
val_to_names[stack_values[-1]] = list() |
|
|
|
|
|
|
|
|
|
|
|
for k, v in tx.symbolic_locals.items(): |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(v.source, LocalSource) and v.source.local_name == k: |
|
continue |
|
|
|
if sys.version_info >= (3, 12): |
|
|
|
if type.__instancecheck__(NullVariable, v): |
|
continue |
|
else: |
|
|
|
assert not type.__instancecheck__(NullVariable, v) |
|
if v not in val_to_names: |
|
val_to_names[v] = list() |
|
val_to_names[v].append(k) |
|
for v in val_to_names.keys(): |
|
restore_vars.extend(val_to_names[v]) |
|
stack_values.extend([v] * len(val_to_names[v])) |
|
|
|
|
|
if len(self.random_calls) > 0: |
|
append_prefix_insts() |
|
random_calls_instructions = [] |
|
self.random_values_var = self.new_var("random_values") |
|
rand_fn = disable(_get_gen_rand_values_fn(self.random_calls)) |
|
rand_fn_name = self.install_global("__gen_rand_values", rand_fn) |
|
codegen = PyCodegen(tx, root) |
|
random_calls_instructions.extend( |
|
codegen.load_function_name(rand_fn_name, True) |
|
) |
|
random_calls_instructions.extend(create_call_function(0, False)) |
|
random_calls_instructions.append( |
|
codegen.create_store(tx.output.random_values_var), |
|
) |
|
self.add_output_instructions(random_calls_instructions) |
|
|
|
if ( |
|
stack_values |
|
and all( |
|
not isinstance( |
|
v, |
|
( |
|
UnspecializedPythonVariable, |
|
NumpyNdarrayVariable, |
|
TensorWithTFOverrideVariable, |
|
), |
|
) |
|
and not (isinstance(v, SymNodeVariable) and v.python_type() is float) |
|
for v in stack_values |
|
) |
|
and all(isinstance(x, TensorVariable) for x in stack_values) |
|
and len(set(stack_values)) == len(stack_values) |
|
and self.side_effects.is_empty() |
|
            and not tx.debug_locals
|
and not self.backward_state |
|
): |
|
append_prefix_insts() |
|
|
|
self.add_output_instructions( |
|
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root) |
|
+ [create_instruction("UNPACK_SEQUENCE", arg=len(stack_values))] |
|
) |
|
|
|
self.add_output_instructions( |
|
[PyCodegen(tx).create_store(var) for var in reversed(restore_vars)] |
|
) |
|
else: |
|
graph_output_var = self.new_var("graph_out") |
|
pass1 = PyCodegen(tx, root, graph_output_var) |
|
self.codegen_suffix(tx, stack_values, pass1) |
|
|
|
|
|
pass2 = PyCodegen( |
|
tx, |
|
root, |
|
graph_output_var, |
|
tempvars={val: None for val, count in pass1.uses.items() if count > 1}, |
|
) |
|
self.codegen_suffix(tx, stack_values, pass2) |
|
|
|
stored_graph_output_var = False |
|
output = [] |
|
if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0: |
|
output.extend( |
|
self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root) |
|
) |
|
|
|
if len(pass2.graph_outputs) != 0: |
|
output.append(pass2.create_store(graph_output_var)) |
|
stored_graph_output_var = True |
|
else: |
|
output.append(create_instruction("POP_TOP")) |
|
append_prefix_insts() |
|
self.add_output_instructions(output + pass2.get_instructions()) |
|
|
|
|
|
self.add_output_instructions( |
|
[PyCodegen(tx).create_store(var) for var in reversed(restore_vars)] |
|
) |
|
|
|
if stored_graph_output_var: |
|
self.add_output_instructions( |
|
[PyCodegen(tx).create_delete(graph_output_var)] |
|
) |
|
|
|
def codegen_suffix(self, tx, stack_values, cg): |
|
if self.backward_state: |
|
assert not self.export |
|
for name, val in self.backward_state.items(): |
|
cg(val) |
|
cg.append_output(cg.create_load(self.backward_state_var)) |
|
cg.store_attr(name) |
|
self.side_effects.codegen_hooks(cg) |
|
self.side_effects.codegen_save_tempvars(cg) |
|
|
|
|
|
for debug_var, args in tx.debug_locals: |
|
cg(debug_var) |
|
for arg in args: |
|
cg(arg) |
|
cg.extend_output(create_call_function(len(args), True)) |
|
cg.extend_output([create_instruction("POP_TOP")]) |
|
|
|
cg.restore_stack(stack_values, value_from_source=not tx.export) |
|
self.side_effects.codegen_update_mutated(cg) |
|
|
|
def cleanup_graph(self): |
|
""" |
|
Remove "creation_timestamp" from node meta |
|
|
|
Remove this pattern from the graph: |
|
torch._C._set_grad_enabled(False) |
|
torch._C._set_grad_enabled(True) |
|
""" |
|
assert self.should_exit |
|
nodes = list(self.graph.nodes) |
|
for node in nodes: |
|
node.meta.pop("creation_timestamp", None) |
|
|
|
grad_enabled = torch.is_grad_enabled() |
|
for node1, node2 in zip(nodes, nodes[1:]): |
|
if ( |
|
node1.target is torch._C._set_grad_enabled |
|
and tuple(node1.args) == (not grad_enabled,) |
|
and not node1._erased |
|
): |
|
grad_enabled = node1.args[0] |
|
if ( |
|
node2.target is torch._C._set_grad_enabled |
|
and tuple(node2.args) == (not grad_enabled,) |
|
and not node2._erased |
|
): |
|
grad_enabled = node2.args[0] |
|
self.graph.erase_node(node1) |
|
self.graph.erase_node(node2) |
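
    # Illustrative sketch (not executed): with grad enabled at trace time, a
    # no-op toggle pair like the following is what the loop above erases:
    #
    #   %set_grad_0 : call_function[target=torch._C._set_grad_enabled](args=(False,))
    #   %set_grad_1 : call_function[target=torch._C._set_grad_enabled](args=(True,))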
|
|
|
def get_graph_sizes_structured(self): |
|
ret = {} |
|
for node in self.graph.nodes: |
|
example_value = node.meta.get("example_value", None) |
|
if isinstance(example_value, torch._subclasses.FakeTensor): |
|
size = example_value.size() |
|
ret[node.name] = [s if isinstance(s, int) else repr(s) for s in size] |
|
return ret |
|
|
|
def get_graph_sizes(self, name: str): |
|
graph_sizes_str = "TRACED GRAPH TENSOR SIZES\n" |
|
graph_sizes_str += f"===== {name} =====\n" |
|
for node in self.graph.nodes: |
|
example_value = node.meta.get("example_value", None) |
|
if isinstance(example_value, torch._subclasses.FakeTensor): |
|
size = example_value.size() |
|
graph_sizes_str += f"{node.name}: {tuple(size)}\n" |
|
concrete_size = [] |
|
has_symint = False |
|
for sz in size: |
|
if isinstance(sz, int): |
|
concrete_size.append(sz) |
|
elif isinstance(sz, torch.SymInt): |
|
has_symint = True |
|
concrete_size.append(sz.node.hint) |
|
else: |
|
break |
|
else: |
|
if has_symint: |
|
graph_sizes_str += ( |
|
f"{node.name} (concrete): {tuple(concrete_size)}\n" |
|
) |
|
return graph_sizes_str |
|
|
|
@contextlib.contextmanager |
|
def restore_global_state(self): |
|
""" |
|
        Momentarily restores the global state to what it was prior to tracing the current output.
|
""" |
|
prior_global_state = self.tracing_context.global_context.copy_graphstate() |
|
current_global_state: Dict[str, Tuple[Any, bool]] = {} |
|
self.save_global_state(out=current_global_state) |
|
try: |
|
|
|
self.tracing_context.global_context.restore_graphstate(prior_global_state) |
|
yield |
|
finally: |
|
|
|
self.tracing_context.global_context.restore_graphstate( |
|
GlobalContextCheckpointState(current_global_state) |
|
) |
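
    # Illustrative usage (mirrors compile_and_call_fx_graph below): the user
    # compiler is invoked under the pre-tracing global state so autocast/grad
    # toggles traced into the graph do not leak into backend compilation:
    #
    #   with self.restore_global_state():
    #       compiled_fn = self.call_user_compiler(gm)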
|
|
|
@torch._guards.TracingContext.clear_frame() |
|
def compile_and_call_fx_graph(self, tx, rv, root): |
|
""" |
|
Generate code from self.graph and return the Instruction()s to |
|
call that generated code. |
|
""" |
|
from .decorators import disable |
|
|
|
assert self.should_exit |
|
|
|
name = unique_id("__compiled_fn") |
|
|
|
assert isinstance(rv, list) |
|
assert isinstance(root, FakeRootModule) |
|
self.create_node( |
|
"output", |
|
"output", |
|
(self.current_tracer.create_arg(tuple(x.as_proxy() for x in rv)),), |
|
{}, |
|
) |
|
if not config.do_not_emit_runtime_asserts: |
|
insert_deferred_runtime_asserts( |
|
fx.GraphModule(root, self.graph), |
|
self.shape_env, |
|
name, |
|
) |
|
|
|
|
|
self.remove_unused_graphargs() |
|
ncalls = count_calls(self.graph) |
|
counters["stats"]["calls_captured"] += ncalls |
|
|
|
|
|
self.real_value_cache.clear() |
|
|
|
gm = _make_graph_module(root, self.graph) |
|
for register_finalizer in self.register_finalizer_fns: |
|
register_finalizer(gm) |
|
|
|
gm.compile_subgraph_reason = self.compile_subgraph_reason |
|
gm.meta[ |
|
"dynamo_flat_name_to_original_fqn" |
|
] = self.dynamo_flat_name_to_original_fqn.copy() |
|
|
|
graph_code_log.debug( |
|
"%s", |
|
lazy_format_graph_code(name, gm, include_stride=True, include_device=True), |
|
) |
|
torch._logging.trace_structured( |
|
"dynamo_output_graph", |
|
lambda: {"sizes": self.get_graph_sizes_structured()}, |
|
payload_fn=lambda: gm.print_readable( |
|
print_output=False, include_stride=True, include_device=True |
|
), |
|
) |
|
self.call_cleanup_hooks() |
|
old_fake_mode = self.tracing_context.fake_mode |
|
if not self.export: |
|
import torch._functorch.config as _config |
|
|
|
with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): |
|
|
|
backend_fake_mode = torch._subclasses.FakeTensorMode( |
|
shape_env=old_fake_mode.shape_env, |
|
) |
|
|
|
|
|
|
|
self.tracing_context.fake_mode = backend_fake_mode |
|
|
|
with self.restore_global_state(): |
|
compiled_fn = self.call_user_compiler(gm) |
|
|
|
from torch.fx._lazy_graph_module import _LazyGraphModule |
|
|
|
if isinstance(compiled_fn, _LazyGraphModule) or ( |
|
isinstance(getattr(compiled_fn, "__self__", None), _LazyGraphModule) |
|
and compiled_fn.__name__ == "_lazy_forward" |
|
): |
|
|
|
|
|
|
|
|
|
|
|
lazy_gm = ( |
|
compiled_fn |
|
if isinstance(compiled_fn, _LazyGraphModule) |
|
else compiled_fn.__self__ |
|
) |
|
|
|
_LazyGraphModule.force_recompile(lazy_gm) |
|
|
|
if not isinstance(compiled_fn, _LazyGraphModule): |
|
|
|
compiled_fn = lazy_gm.forward |
|
|
|
compiled_fn = disable(compiled_fn) |
|
|
|
counters["stats"]["unique_graphs"] += 1 |
|
|
|
self.install_global_unsafe(name, compiled_fn) |
|
|
|
cg = PyCodegen(tx) |
|
cg.make_call_generated_code(name) |
|
return cg.get_instructions() |
|
|
|
@property |
|
def placeholders(self) -> List[fx.Node]: |
|
return self.graph.find_nodes(op="placeholder") |
|
|
|
@property |
|
def graphargs(self) -> List[GraphArg]: |
|
return [node.meta["grapharg"] for node in self.placeholders] |
|
|
|
@dynamo_timed(phase_name="backend_compile") |
|
def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn: |
|
assert self.compiler_fn is not None |
|
tot = 0 |
|
placeholders = [] |
|
for node in gm.graph.nodes: |
|
if node.op in ("call_function", "call_method", "call_module"): |
|
tot += 1 |
|
if node.op == "placeholder": |
|
placeholders.append(node) |
|
increment_op_count(tot) |
|
for pl in placeholders: |
|
arg = pl.meta["grapharg"] |
|
|
|
pl._dynamo_source = arg.source |
|
|
|
gm._param_name_to_source = self.param_name_to_source |
|
gm._source_to_user_stacks = self.source_to_user_stacks |
|
|
|
try: |
|
name = ( |
|
self.compiler_fn.__name__ |
|
if hasattr(self.compiler_fn, "__name__") |
|
else "" |
|
) |
|
_step_logger()(logging.INFO, f"calling compiler function {name}") |
|
compiler_fn = self.compiler_fn |
|
if config.verify_correctness: |
|
compiler_fn = WrapperBackend(compiler_fn) |
|
compiled_fn = compiler_fn(gm, self.example_inputs()) |
|
_step_logger()(logging.INFO, f"done compiler function {name}") |
|
assert callable(compiled_fn), "compiler_fn did not return callable" |
|
except exceptions_allowed_to_be_fallback as e: |
|
if self.has_user_defined_allowed_in_graph: |
|
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback( |
|
e.__traceback__ |
|
) from None |
|
msg = ( |
|
"Backend compiler failed with a fake tensor exception at \n" |
|
f"{self.root_tx.format_frame_summary()}" |
|
"Adding a graph break." |
|
) |
|
unimplemented_with_warning(e, self.root_tx.f_code, msg) |
|
except SkipFrame as e: |
|
|
|
|
|
raise e |
|
except Exception as e: |
|
raise BackendCompilerFailed(self.compiler_fn, e).with_traceback( |
|
e.__traceback__ |
|
) from None |
|
|
|
signpost_event( |
|
"dynamo", |
|
"OutputGraph.call_user_compiler", |
|
{ |
|
**self.co_fields, |
|
"op_count": tot, |
|
"node_count": len(gm.graph.nodes), |
|
"input_count": len(placeholders), |
|
}, |
|
) |
|
|
|
return compiled_fn |
|
|
|
def example_inputs(self) -> List[torch.Tensor]: |
|
result = [] |
|
for arg in self.graphargs: |
|
result.append(arg.example) |
|
return result |
|
|
|
def remove_unused_graphargs(self) -> None: |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
assert self.should_exit |
|
|
|
|
|
def is_static_true(b_node: fx.node.Argument): |
|
if b_node is True: |
|
return True |
|
if not isinstance(b_node, fx.Node): |
|
return False |
|
b = b_node.meta.get("example_value") |
|
if b is None: |
|
return False |
|
if b is True: |
|
return True |
|
if ( |
|
isinstance(b, torch.SymBool) |
|
and (r := b.node.maybe_as_bool()) is not None |
|
): |
|
return r |
|
|
|
|
|
return False |
|
|
|
def is_symnode_arg(a: fx.node.Argument): |
|
from torch.fx.experimental.sym_node import SymTypes |
|
|
|
if isinstance(a, (int, float, bool)): |
|
return True |
|
if isinstance(a, fx.Node): |
|
return isinstance(a.meta.get("example_value"), SymTypes) |
|
return False |
|
|
|
|
|
|
|
|
|
def is_symnode_compute_node(node): |
|
from torch.fx.experimental.sym_node import SymTypes |
|
|
|
if node.op != "call_function": |
|
return False |
|
|
|
if not isinstance(node.meta.get("example_value"), SymTypes): |
|
return False |
|
|
|
|
|
|
|
if not all(is_symnode_arg(a) for a in node.args): |
|
return False |
|
if not all(is_symnode_arg(a) for a in node.kwargs.values()): |
|
return False |
|
return True |
|
|
|
|
|
|
|
|
|
|
|
def is_accessor_node(node): |
|
if ( |
|
node.op == "call_method" |
|
and isinstance(node.args[0].meta.get("example_value"), torch.Tensor) |
|
and node.target in ["size", "stride", "storage_offset", "item"] |
|
): |
|
return True |
|
if node.op == "call_function" and node.target in [ |
|
torch.ops.aten.sym_size, |
|
torch.ops.aten.sym_size.default, |
|
torch.ops.aten.sym_size.int, |
|
torch.ops.aten.sym_stride, |
|
torch.ops.aten.sym_stride.default, |
|
torch.ops.aten.sym_stride.int, |
|
torch.ops.aten.sym_storage_offset, |
|
torch.ops.aten.sym_storage_offset.default, |
|
]: |
|
return True |
|
return False |
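
        # Illustrative sketch (hypothetical graph, not executed): an unused
        # size/stride query is safe to drop even though it "reads" a tensor:
        #
        #   %s0 : call_function[target=torch.ops.aten.sym_size.int](args=(%x, 0))
        #   # zero users -> removed by the loop below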
|
|
|
for node in reversed(list(self.graph.nodes)): |
|
if len(list(node.users)) == 0: |
|
if ( |
|
node.op == "get_attr" |
|
or (node.op == "call_function" and node.target is operator.getitem) |
|
or ( |
|
node.op == "call_function" |
|
and node.target is torch._check |
|
and is_static_true(node.args[0]) |
|
) |
|
or is_symnode_compute_node(node) |
|
or is_accessor_node(node) |
|
): |
|
self.remove_node(node) |
|
|
|
def placeholder_binds_symbol(node): |
|
arg = node.meta["grapharg"] |
|
example = arg.example |
|
if isinstance(example, torch.SymInt) and isinstance( |
|
example.node.expr, sympy.Symbol |
|
): |
|
return example.node.expr |
|
return None |
|
|
|
def remove_unused(node): |
|
log.debug("REMOVE UNUSED GRAPHARG %s", node.meta["grapharg"].source.name()) |
|
|
|
|
|
del node.meta["grapharg"] |
|
self.remove_node(node) |
|
self.real_value_cache.pop(node, None) |
|
|
|
used_symbols: Set[sympy.Symbol] = set() |
|
|
|
def update_used_symbols(used_symbols, fake: Union[torch.SymInt, torch.Tensor]): |
|
used_symbols |= free_symbols(fake) |
|
|
|
recheck_placeholders = [] |
|
for node in self.placeholders: |
|
binds_symbol = placeholder_binds_symbol(node) is not None |
|
|
|
if binds_symbol: |
|
if not node.users: |
|
recheck_placeholders.append(node) |
|
else: |
|
if not node.users and not isinstance( |
|
node.meta["grapharg"], BackwardStateGraphArg |
|
): |
|
remove_unused(node) |
|
else: |
|
|
|
arg = node.meta["grapharg"] |
|
if isinstance(arg, BackwardStateGraphArg): |
|
continue |
|
if isinstance(node.meta["grapharg"].example, torch.ScriptObject): |
|
real_script_obj = node.meta["grapharg"].example |
|
fake_script_obj = node.meta["grapharg"].example_strong_ref |
|
flat_dict = dict(real_script_obj.__obj_flatten__()) |
|
for attr in flat_dict.keys(): |
|
fake_attr_val = getattr(fake_script_obj.wrapped_obj, attr) |
|
pytree.tree_map_only( |
|
(torch.SymInt, torch.Tensor), |
|
lambda t: update_used_symbols(used_symbols, t), |
|
fake_attr_val, |
|
) |
|
continue |
|
fake = ( |
|
arg.fake_tensor if arg.fake_tensor is not None else arg.example |
|
) |
|
update_used_symbols(used_symbols, fake) |
|
|
|
|
|
for node in recheck_placeholders: |
|
symbol = placeholder_binds_symbol(node) |
|
if symbol is not None: |
|
if symbol not in used_symbols: |
|
remove_unused(node) |
|
else: |
|
|
|
used_symbols.remove(symbol) |
|
|
|
def add_output_instructions(self, prefix: List[Instruction]) -> None: |
|
""" |
|
We call this on the creation of a new compiled subgraph that is inserted |
|
before user code. |
|
""" |
|
self.output_instructions.extend(prefix) |
|
self.should_exit = True |
|
|
|
def install_global_unsafe(self, name, value) -> None: |
|
""" |
|
WARNING: prefer the safer `install_global_by_id/install_global`. |
|
torch.compile instances should be independent of each other; |
|
one footgun is to have one instance depend on the existence of |
|
a global installed by another instance. This can happen if we mangle |
|
a global the same way across both instances. |
|
""" |
|
assert name not in self.installed_globals |
|
self.installed_globals.add(name) |
|
self.cleanups.append(CleanupHook.create(self.global_scope, name, value)) |
|
|
|
def install_global_by_id(self, prefix, value) -> str: |
|
""" |
|
Installs a global if it hasn't been installed already. |
|
This is determined by (prefix, id(value)) pair. |
|
|
|
Returns the name of the newly installed global. |
|
""" |
|
|
|
|
|
name = f"{prefix}_{id(value)}_c{self.compile_id}" |
|
if name in self.installed_globals: |
|
return name |
|
self.install_global_unsafe(name, value) |
|
return name |
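
    # Illustrative sketch (hypothetical value, not executed): repeat installs
    # of the same object under the same prefix are deduplicated, since the
    # mangled name depends only on (prefix, id(value), compile_id):
    #
    #   name1 = self.install_global_by_id("__helper", fn)
    #   name2 = self.install_global_by_id("__helper", fn)
    #   assert name1 == name2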
|
|
|
def install_global(self, prefix, value) -> str: |
|
""" |
|
Installs a global, generating a unique name for it. |
|
|
|
Returns the name of the newly installed global. |
|
""" |
|
|
|
name = unique_id(prefix) |
|
self.install_global_unsafe(name, value) |
|
return name |
|
|
|
def cleanup(self) -> None: |
|
|
|
|
|
self.root_tx = None |
|
self.nn_modules.clear() |
|
self.param_name_to_source = None |
|
|
|
for node in self.graph.nodes: |
|
if "grapharg" in node.meta: |
|
del node.meta["grapharg"] |
|
self.real_value_cache.clear() |
|
self.input_name_to_proxy.clear() |
|
self.side_effects.clear() |
|
self.variable_tracker_cache.clear() |
|
self.register_finalizer_fns.clear() |
|
self.dynamo_flat_name_to_original_fqn.clear() |
|
self.tracing_context.clear() |
|
|
|
def set_torch_function_state(self, enabled: bool) -> None: |
|
self.torch_function_enabled = enabled |
|
|
|
def add_graph_finalizer( |
|
self, register_finalizer: Callable[[fx.GraphModule], None] |
|
) -> None: |
|
self.register_finalizer_fns.append(register_finalizer) |
|
|
|
def example_value_from_input_node(self, node: torch.fx.Node): |
|
"""Extract the non-fake example tensor""" |
|
if node.op == "placeholder": |
|
return node.meta["grapharg"].example |
|
assert node.op == "get_attr" |
|
return self.nn_modules[node.target] |
|
|
|
|
|
err_epilogue = ( |
|
"With the current config, we will graph break " |
|
"(and fall back to eager-mode PyTorch) on all ops " |
|
"that have do not have the 'pt2_compliant_tag'. " |
|
"Please see the following doc for how to mark this op as PT2 compliant " |
|
"https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html" |
|
) |
|
|
|
|
|
def check_pt2_compliant_op(output_graph, kind, target, args, kwargs): |
|
if kind != "call_function": |
|
return |
|
|
|
def encountered_compliant_op(target): |
|
if target.namespace in {"prim", "prims", "aten"}: |
|
return |
|
output_graph.compliant_custom_ops.add(target) |
|
|
|
def encountered_non_compliant_op(target, msg): |
|
output_graph.non_compliant_ops.add(target) |
|
if config.only_allow_pt2_compliant_ops: |
|
unimplemented(msg + " " + err_epilogue) |
|
|
|
if isinstance(target, torch._ops.OpOverload): |
|
if torch.Tag.pt2_compliant_tag in target.tags: |
|
encountered_compliant_op(target) |
|
return |
|
encountered_non_compliant_op( |
|
target, |
|
f"Encountered the torch.ops.OpOverload {target} " |
|
f"that is not PT2 compliant.", |
|
) |
|
return |
|
|
|
if isinstance(target, torch._ops.OpOverloadPacket): |
|
overloads = tuple(target.overloads()) |
|
|
|
|
|
if len(overloads) == 1: |
|
op = getattr(target, overloads[0]) |
|
if torch.Tag.pt2_compliant_tag in op.tags: |
|
encountered_compliant_op(op) |
|
return |
|
encountered_non_compliant_op( |
|
op, |
|
f"Encountered the non-overloaded " |
|
f"torch.ops.OpOverloadPacket {target} " |
|
f"that is not PT2 compliant. ", |
|
) |
|
return |
|
|
|
args, kwargs = torch._dynamo.utils.get_fake_values_from_nodes( |
|
output_graph.current_tx, (args, kwargs), False |
|
) |
|
try: |
|
overload = torch._C._jit_resolve_packet( |
|
target._qualified_op_name, *args, **kwargs |
|
) |
|
except RuntimeError as e: |
|
unimplemented(str(e)) |
|
|
|
op = getattr(target, overload) |
|
if torch.Tag.pt2_compliant_tag in op.tags: |
|
encountered_compliant_op(op) |
|
else: |
|
encountered_non_compliant_op( |
|
op, |
|
f"Encountered the torch.ops.OpOverloadPacket {target} " |
|
f"which resolves to the overload ({overload}) that is " |
|
f"not PT2 compliant.", |
|
) |
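
# Illustrative sketch (hypothetical op "mylib::foo", not executed): a custom op
# avoids the graph break above by carrying the compliance tag at definition
# time, e.g. via torch.library:
#
#   lib = torch.library.Library("mylib", "DEF")
#   lib.define("foo(Tensor x) -> Tensor", tags=(torch.Tag.pt2_compliant_tag,))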
|
|
|
|
|
_compile_id_counter = itertools.count() |
|
|
|
|
|
class SubgraphTracer(fx.Tracer): |
|
""" |
|
Holds an FX graph that is being traced. OutputGraph owns a SubgraphTracer |
|
and the separation of responsibilities is that SubgraphTracer is |
|
responsible for building the graph while OutputGraph is responsible for |
|
compiling and executing the graph. |
|
""" |
|
|
|
def __init__( |
|
self, output_graph, parent=None, export_root=False, source_target=None |
|
): |
|
super().__init__() |
|
self.output_graph = weakref.proxy(output_graph) |
|
self.graph = torch.fx.Graph() |
|
|
|
|
|
|
|
|
|
if export_root: |
|
assert parent is None |
|
self.export_root = export_root |
|
|
|
|
|
|
|
self.input_name_to_proxy: Dict[str, fx.Proxy] = {} |
|
|
|
self.real_value_cache: Dict[fx.Node, torch.Tensor] = {} |
|
|
|
|
|
self.parent = parent |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.lifted_freevars = {} |
|
self.prev_inst = None |
|
|
|
self._cur_code = None |
|
self._orig_gm_meta = None |
|
self._orig_gm_lineno_map = None |
|
self._orig_gm_firstlineno = None |
|
|
|
|
|
|
|
|
|
if self.parent is None: |
|
self.source_fn_stack = [] |
|
else: |
|
self.source_fn_stack = self.parent.source_fn_stack + [ |
|
(self.graph._target_to_str(source_target), source_target) |
|
] |
|
|
|
def create_proxy( |
|
self, |
|
kind, |
|
target, |
|
args, |
|
kwargs, |
|
name=None, |
|
type_expr=None, |
|
proxy_factory_fn=None, |
|
): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.parent is not None: |
|
flat_args, tree_spec = pytree.tree_flatten((args, kwargs)) |
|
new_flat_args = [] |
|
for arg in flat_args: |
|
maybe_new_arg = self.maybe_lift_tracked_freevar_to_input(arg) |
|
new_flat_args.append(maybe_new_arg) |
|
|
|
args, kwargs = pytree.tree_unflatten(new_flat_args, tree_spec) |
|
|
|
rv = super().create_proxy( |
|
kind, target, args, kwargs, name, type_expr, proxy_factory_fn |
|
) |
|
|
|
|
|
tx = self.output_graph.current_tx |
|
|
|
|
|
if sys.version_info >= (3, 11) and kind in ( |
|
"call_function", |
|
"call_method", |
|
"call_module", |
|
): |
|
cur_inst = tx.current_instruction |
|
if ( |
|
cur_inst is not self.prev_inst |
|
and cur_inst.positions is not None |
|
and cur_inst.positions.lineno is not None |
|
): |
|
tx_code = tx.f_code |
|
header = tx.get_line_of_code_header(lineno=cur_inst.positions.lineno) |
|
|
|
def get_trace_call_log_str(): |
|
line = get_instruction_source_311(tx_code, cur_inst).rstrip() |
|
return f"TRACE FX call {rv.node.name} from {header}\n{line}" |
|
|
|
trace_call_log.debug("%s", LazyString(get_trace_call_log_str)) |
|
self.prev_inst = cur_inst |
|
|
|
|
|
is_retracing = False |
|
if tx.f_code is not self._cur_code: |
|
orig_graphmodule_maybe = code_context.get_context(tx.f_code).get( |
|
"orig_graphmodule", lambda: None |
|
)() |
|
if isinstance(orig_graphmodule_maybe, torch.fx.GraphModule): |
|
is_retracing = True |
|
self._orig_gm_meta = [ |
|
nd.meta for nd in orig_graphmodule_maybe.graph.nodes |
|
] |
|
self._orig_gm_lineno_map = orig_graphmodule_maybe._lineno_map |
|
self._orig_gm_firstlineno = ( |
|
orig_graphmodule_maybe.forward.__code__.co_firstlineno |
|
) |
|
else: |
|
self._orig_gm_meta = None |
|
self._orig_gm_lineno_map = None |
|
self._orig_gm_firstlineno = None |
|
nn_module_stack = tx.nn_module_stack |
|
if nn_module_stack: |
|
rv.node.meta["nn_module_stack"] = nn_module_stack.copy() |
|
|
|
if kind in {"call_function", "call_method"}: |
|
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ |
|
(rv.node.name, target) |
|
] |
|
elif kind == "call_module": |
|
if self.parent is not None: |
|
unimplemented("Invoking an nn.Module inside HigherOrderOperator") |
|
|
|
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ |
|
( |
|
rv.node.name, |
|
rv.node.meta["nn_module_stack"][target][1], |
|
) |
|
] |
|
|
|
|
|
if ( |
|
self._orig_gm_meta |
|
and self._orig_gm_lineno_map |
|
and self._orig_gm_firstlineno |
|
): |
|
lineno = tx.current_instruction.starts_line |
|
node_idx = None |
|
if lineno is not None: |
|
node_idx = self._orig_gm_lineno_map.get( |
|
lineno - self._orig_gm_firstlineno, None |
|
) |
|
if node_idx is not None: |
|
meta = self._orig_gm_meta[node_idx] |
|
for field in fx.proxy._COPY_META_FIELDS: |
|
if field in meta: |
|
rv.node.meta[field] = meta[field] |
|
if "stack_trace" in meta: |
|
rv.node.meta["stack_trace"] = meta["stack_trace"] |
|
|
|
if not is_retracing: |
|
if "nn_module_stack" not in rv.node.meta: |
|
nn_module_stack = tx.nn_module_stack |
|
if nn_module_stack: |
|
rv.node.meta["nn_module_stack"] = nn_module_stack.copy() |
|
|
|
if "source_fn_stack" not in rv.node.meta: |
|
if kind in {"call_function", "call_method"}: |
|
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ |
|
(rv.node.name, target) |
|
] |
|
elif kind == "call_module": |
|
if self.parent is not None: |
|
unimplemented( |
|
"Invoking an nn.Module inside HigherOrderOperator" |
|
) |
|
|
|
rv.node.meta["source_fn_stack"] = self.source_fn_stack + [ |
|
( |
|
rv.node.name, |
|
rv.node.meta["nn_module_stack"][target][1], |
|
) |
|
] |
|
|
|
if "stack_trace" not in rv.node.meta: |
|
frame_summaries: List[traceback.FrameSummary] = [] |
|
while tx: |
|
frame_summaries.append(tx.frame_summary()) |
|
tx = getattr(tx, "parent", None) |
|
|
|
frame_summaries.reverse() |
|
|
|
|
|
msgs = traceback.StackSummary.from_list(frame_summaries).format() |
|
rv.node.stack_trace = "".join(msgs) |
|
|
|
return rv |
|
|
|
def create_node( |
|
self, op, target, args=None, kwargs=None, name=None, type_expr=None |
|
): |
|
check_pt2_compliant_op(self.output_graph, op, target, args, kwargs) |
|
if self.parent is not None: |
|
flat_args = pytree.arg_tree_leaves(*args, **kwargs) |
|
for arg in flat_args: |
|
if not isinstance(arg, torch.fx.Node): |
|
continue |
|
assert ( |
|
arg.graph == self.graph |
|
), "create_node using arg not from this SubgraphTracer" |
|
|
|
node = super().create_node(op, target, args, kwargs, name, type_expr) |
|
node.meta["creation_timestamp"] = self.output_graph.timestamp |
|
return node |
|
|
|
|
|
|
|
def remove_node(self, node): |
|
if len(node.users) > 0: |
|
user_graph_nodes: List[torch.fx.Node] = [] |
|
for user in node.users.keys(): |
|
|
|
|
|
if user.graph != self.graph: |
|
|
|
|
|
|
|
user_graph_nodes.extend(reversed(list(user.graph.nodes))) |
|
for other_graph_node in user_graph_nodes: |
|
other_graph_node.graph.erase_node(other_graph_node) |
|
self.graph.erase_node(node) |
|
self.input_name_to_proxy.pop(node.name, None) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_graph_input(self, name, type_expr=None, before=False, source=None): |
|
log.debug( |
|
"create_graph_input %s %s", |
|
name, |
|
source.name() if source is not None else "(none)", |
|
) |
|
if source is None: |
|
assert ( |
|
self.parent is not None |
|
), "you are required to provide a source for inputs on the root tracer" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.export_root: |
|
if not is_from_local_source(source, allow_cell_or_freevar=False): |
|
self.output_graph.source_to_user_stacks.setdefault(source, []).append( |
|
TracingContext.extract_stack() |
|
) |
|
|
|
|
|
if name in self.input_name_to_proxy: |
|
for i in itertools.count(): |
|
candidate_name = f"{name}_{i}" |
|
if candidate_name not in self.input_name_to_proxy: |
|
name = candidate_name |
|
break |
|
|
|
if self.input_name_to_proxy: |
|
prev_name = next(reversed(self.input_name_to_proxy)) |
|
node = self.input_name_to_proxy[prev_name].node |
|
if before: |
|
ctx = self.graph.inserting_before(node) |
|
else: |
|
ctx = self.graph.inserting_after(node) |
|
else: |
|
ctx = self.graph.inserting_before(None) |
|
with ctx: |
|
proxy = self.create_proxy("placeholder", name, (), {}, type_expr=type_expr) |
|
if self.input_name_to_proxy and before: |
|
k, v = self.input_name_to_proxy.popitem() |
|
self.input_name_to_proxy[name] = proxy |
|
self.input_name_to_proxy[k] = v |
|
else: |
|
self.input_name_to_proxy[name] = proxy |
|
return proxy |
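
    # Illustrative sketch (not executed): duplicate names get a numeric suffix,
    # and before=True inserts the new placeholder just before the most recently
    # created one (used when binding size/stride symbols):
    #
    #   tracer.create_graph_input("x", source=src_a)  # placeholder "x"
    #   tracer.create_graph_input("x", source=src_b)  # placeholder "x_0"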
|
|
|
|
|
def lift_tracked_freevar_to_input(self, proxy): |
|
|
|
|
|
assert ( |
|
self.parent is not None |
|
), "lift_tracked_freevar_to_input should not be called on root SubgraphTracer" |
|
|
|
|
|
|
|
if proxy in self.lifted_freevars: |
|
return self.lifted_freevars[proxy] |
|
new_proxy = self.create_graph_input(proxy.node.name) |
|
set_example_value(new_proxy.node, proxy.node.meta["example_value"]) |
|
self.lifted_freevars[proxy] = new_proxy |
|
if self.parent is not None and proxy.tracer != self.parent: |
|
self.parent.lift_tracked_freevar_to_input(proxy) |
|
return new_proxy |
|
|
|
def maybe_lift_tracked_freevar_to_input(self, arg): |
|
""" |
|
If arg is a free variable, then lift it to be an input. |
|
Returns the new lifted arg (if arg was a freevar), else the |
|
original arg. |
|
""" |
|
if not isinstance(arg, torch.fx.Proxy): |
|
return arg |
|
elif arg.tracer == self: |
|
return arg |
|
return self.lift_tracked_freevar_to_input(arg) |
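
    # Illustrative sketch (not executed): while tracing a HigherOrderOperator
    # body, a proxy owned by an outer tracer is lifted to a placeholder of this
    # subgraph (and recursively of any intermediate parent subgraphs):
    #
    #   inner_tracer.maybe_lift_tracked_freevar_to_input(outer_proxy)
    #   # -> new placeholder proxy in inner_tracer.graph, recorded in
    #   #    inner_tracer.lifted_freevars[outer_proxy]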
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|