""" |
|
Functions in this file are responsible for modifying the eval frame |
|
handler at RUNTIME. Therefore, all functions in this file are performance-critical ("hot").
|
Functions that only execute at compile time should be placed |
|
in torch._dynamo.convert_frame. |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import contextlib |
|
import functools |
|
import inspect |
|
import logging |
|
import os |
|
import sys |
|
import textwrap |
|
import traceback |
|
import types |
|
import warnings |
|
import weakref |
|
from enum import Enum |
|
from os.path import dirname, join |
|
from typing import ( |
|
Any, |
|
Callable, |
|
Dict, |
|
List, |
|
NamedTuple, |
|
Optional, |
|
Set, |
|
Tuple, |
|
TYPE_CHECKING, |
|
Union, |
|
) |
|
from unittest.mock import patch |
|
|
|
import torch |
|
import torch.fx |
|
import torch.utils._pytree as pytree |
|
import torch.utils.checkpoint |
|
from torch import _guards |
|
from torch._utils_internal import log_export_usage |
|
from torch.export.dynamic_shapes import _process_dynamic_shapes |
|
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode |
|
from torch.fx.experimental.symbolic_shapes import ( |
|
ConstraintViolationError, |
|
DimDynamic, |
|
StatelessSymbolicContext, |
|
) |
|
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo |
|
|
|
from ..fx import GraphModule |
|
from .backends.registry import CompilerFn, lookup_backend |
|
|
|
from .hooks import Hooks |
|
|
|
|
|
reset_code = torch._C._dynamo.eval_frame.reset_code |
|
set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame |
|
set_guard_error_hook = torch._C._dynamo.eval_frame.set_guard_error_hook |
|
skip_code = torch._C._dynamo.eval_frame.skip_code |
|
unsupported = torch._C._dynamo.eval_frame.unsupported |
|
|
|
from . import config, convert_frame, external_utils, trace_rules, utils |
|
from .code_context import code_context |
|
from .exc import CondOpArgsMismatchError, UserError, UserErrorType |
|
from .mutation_guard import install_generation_tagging_init |
|
from .utils import common_constant_types, compile_times |
|
|
|
log = logging.getLogger(__name__) |
|
|
|
from torch._dispatch.python import enable_python_dispatcher |
|
|
|
always_optimize_code_objects = utils.ExactWeakKeyDictionary() |
|
null_context = contextlib.nullcontext |
|
|
|
|
|
import sympy |
|
|
|
if TYPE_CHECKING: |
|
from torch._subclasses import fake_tensor |
|
from .types import CacheEntry, DynamoCallback |
|
|
|
|
|
|
|
class Unset(Enum): |
|
token = 0 |
|
|
|
|
|
cached_backends: Dict[int, CompilerFn] = {} |
|
|
|
unset = Unset.token |
|
|
|
|
|
def _reset_guarded_backend_cache(): |
|
global cached_backends |
|
for backend in cached_backends.values(): |
|
if hasattr(backend, "reset"): |
|
backend.reset() |
|
cached_backends.clear() |
|
|
|
|
|
DONT_WRAP_FILES = { |
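    # Source files whose functions must never be wrapped by wrap_inline below.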
|
|
|
inspect.getsourcefile(GraphModule), |
|
join(dirname(dirname(__file__)), "onnx/_internal/fx/dynamo_graph_extractor.py"), |
|
} |
|
|
|
|
|
def _debug_get_cache_entry_list( |
|
code: Union[types.CodeType, Callable[..., Any]] |
|
) -> List[CacheEntry]: |
|
""" |
|
Given a code object or a callable object, retrieve the cache entries |
|
stored in this code. |
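
    A minimal usage sketch (``fn`` here stands for any function that has
    already been compiled and run under ``torch.compile``)::

        entries = _debug_get_cache_entry_list(fn)
        print(len(entries))  # one entry per cached compilation of fn.__code__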
|
""" |
|
if callable(code): |
|
code = code.__code__ |
|
return torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code) |
|
|
|
|
|
class OptimizedModule(torch.nn.Module): |
|
""" |
|
    Wraps the original nn.Module object and later patches its

    forward method with an optimized ``self.forward`` method.
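
    An OptimizedModule is produced when a _TorchDynamoContext is applied to an
    nn.Module; attribute reads and writes fall through to the original module
    via __getattr__/__setattr__ below.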
|
""" |
|
|
|
_torchdynamo_orig_callable: Callable[..., Any] |
|
get_compiler_config: Callable[[], Any] |
|
|
|
_opt_mod_attributes = { |
|
"_orig_mod", |
|
"dynamo_ctx", |
|
"_torchdynamo_orig_callable", |
|
"get_compiler_config", |
|
"forward", |
|
"_forward", |
|
"__dict__", |
|
"named_children_walk", |
|
} |
|
|
|
def __init__(self, mod: torch.nn.Module, dynamo_ctx): |
|
super().__init__() |
|
|
|
self._orig_mod = mod |
|
self.dynamo_ctx = dynamo_ctx |
|
self._initialize() |
|
|
|
def _initialize(self): |
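        # Runs from both __init__ and __setstate__, so pickled modules re-wrap
        # themselves on load.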
|
|
|
if isinstance(self.dynamo_ctx, DisableContext): |
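            # Dynamo is disabled for this module; no trace-rule check needed.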
|
|
|
self.forward = self.dynamo_ctx(self._orig_mod.__call__) |
|
elif isinstance(self._orig_mod.forward, types.MethodType) and trace_rules.check( |
|
self._orig_mod.forward |
|
): |
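            # forward is allowed by trace rules, so Dynamo would see no frame
            # to intercept; wrap_inline adds an extra frame we can capture.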
|
|
|
|
|
|
|
self.forward = self.dynamo_ctx(external_utils.wrap_inline(self._orig_mod)) |
|
else: |
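            # Run nn.Module hooks outside Dynamo, then compile the inner frame.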
|
|
|
self.forward = self.dynamo_ctx(self._orig_mod.__call__) |
|
|
|
if hasattr(self._orig_mod, "_initialize_hook"): |
|
self._forward = self.forward |
|
self.forward = self._call_lazy_check |
|
|
|
def __reduce__(self): |
|
return (self.__class__, (self._orig_mod, self.dynamo_ctx)) |
|
|
|
def __getstate__(self): |
|
state = dict(self.__dict__) |
|
state.pop("forward", None) |
|
state.pop("__call__", None) |
|
return state |
|
|
|
def __setstate__(self, state): |
|
self.__dict__ = state |
|
self._initialize() |
|
|
|
def __getattr__(self, name): |
|
if name == "_orig_mod": |
|
return self._modules["_orig_mod"] |
|
return getattr(self._orig_mod, name) |
|
|
|
def __setattr__(self, name, val): |
|
|
|
if hasattr(type(self), name): |
|
return super().__setattr__(name, val) |
|
|
|
if name in OptimizedModule._opt_mod_attributes: |
|
return super().__setattr__(name, val) |
|
return setattr(self._orig_mod, name, val) |
|
|
|
def _call_lazy_check(self, *args, **kwargs): |
|
if hasattr(self._orig_mod, "_initialize_hook"): |
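            # Lazy module: infer parameter shapes from the first real inputs
            # before running the compiled forward; the lazy pre-hooks remove
            # themselves once initialization has run.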
|
|
|
|
|
|
|
|
|
self._orig_mod._infer_parameters(self._orig_mod, args, kwargs) |
|
return self._forward(*args, **kwargs) |
|
|
|
def __dir__(self): |
|
orig_mod_attrs = self._orig_mod.__dir__() |
|
return orig_mod_attrs + [ |
|
attr for attr in super().__dir__() if attr not in orig_mod_attrs |
|
] |
|
|
|
|
|
def remove_from_cache(f): |
|
""" |
|
Make sure f.__code__ is not cached to force a recompile |
|
""" |
|
if isinstance(f, types.CodeType): |
|
reset_code(f) |
|
elif hasattr(f, "__code__"): |
|
reset_code(f.__code__) |
|
elif hasattr(getattr(f, "forward", None), "__code__"): |
|
reset_code(f.forward.__code__) |
|
else: |
|
from . import reset |
|
|
|
reset() |
|
log.warning("could not determine __code__ for %s", f) |
|
|
|
|
|
def nothing(): |
|
pass |
|
|
|
|
|
def always_false(): |
|
return False |
|
|
|
|
|
def innermost_fn(fn): |
|
""" |
|
In case of nesting of _TorchDynamoContext calls, find the innermost |
|
    function. TorchDynamo caches on the fn.__code__ object, so it's necessary to find

    the innermost function to apply optimize, run, disable, etc.
|
""" |
|
unaltered_fn = fn |
|
while hasattr(unaltered_fn, "_torchdynamo_orig_callable"): |
|
unaltered_fn = unaltered_fn._torchdynamo_orig_callable |
|
assert callable(unaltered_fn) |
|
return unaltered_fn |
|
|
|
|
|
def make_set_enable_dynamic(enable: bool): |
|
assert isinstance(enable, bool) |
|
if enable: |
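        # Assume all dimensions are dynamic by default.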
|
|
|
return config._make_closure_patcher(assume_static_by_default=False) |
|
else: |
|
return config._make_closure_patcher( |
|
automatic_dynamic_shapes=False, assume_static_by_default=True |
|
) |
|
|
|
|
|
class _TorchDynamoContext: |
|
def __init__( |
|
self, |
|
callback: DynamoCallback, |
|
on_enter=nothing, |
|
backend_ctx_ctor=null_context, |
|
patch_fn=nothing, |
|
first_ctx=False, |
|
*, |
|
export=False, |
|
dynamic=None, |
|
compiler_config=None, |
|
): |
|
super().__init__() |
|
assert callable(callback) or callback is False or callback is None |
|
self.callback: DynamoCallback = callback |
|
self._backend_ctx_ctor = backend_ctx_ctor |
|
self.prior: Union[Unset, DynamoCallback] = unset |
|
self.first_ctx = first_ctx |
|
self.export = export |
|
self._dynamic = dynamic |
|
self.compiler_config = compiler_config |
|
self.cleanup_fns: List[Callable[[], Any]] = [] |
|
self.enter_exit_hooks = [] |
|
patch_fn() |
|
|
|
|
|
        # Remember the backend so that _reset_guarded_backend_cache (used by
        # torch._dynamo.reset) can reset it later.
        backend = innermost_fn(callback)
|
cached_backends.setdefault(id(backend), backend) |
|
|
|
if dynamic is not None: |
|
self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic)) |
|
|
|
if on_enter is not nothing: |
|
|
|
def call_on_enter(): |
|
on_enter() |
|
return nothing |
|
|
|
self.enter_exit_hooks.append(call_on_enter) |
|
|
|
if backend_ctx_ctor is not contextlib.nullcontext: |
|
|
|
def call_backend_ctx(): |
|
ctx = backend_ctx_ctor() |
|
ctx.__enter__() |
|
return functools.partial(ctx.__exit__, None, None, None) |
|
|
|
self.enter_exit_hooks.append(call_backend_ctx) |
|
|
|
def __enter__(self): |
|
if config.raise_on_ctx_manager_usage: |
|
raise RuntimeError( |
|
"torch._dynamo.optimize(...) is used with a context manager. " |
|
"Please refer to https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html " |
|
"to use torch._dynamo.optimize(...) as an annotation/decorator. " |
|
) |
|
self.cleanup_fns = [enter() for enter in self.enter_exit_hooks] |
|
self.prior = set_eval_frame(self.callback) |
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb): |
|
assert self.prior is not unset |
|
set_eval_frame(self.prior) |
|
self.prior = unset |
|
for cleanup in self.cleanup_fns: |
|
cleanup() |
|
self.cleanup_fns.clear() |
|
|
|
def __call__(self, fn): |
|
|
|
def get_compiler_config(): |
|
return self.compiler_config |
|
|
|
fn = innermost_fn(fn) |
|
|
|
|
|
if isinstance(fn, GraphModule): |
|
|
|
            # Stash a weakref to the GraphModule in its forward's code context
            # so later compilation of that code can recover the original module.
            code_context.get_context(fn.forward.__code__)[
|
"orig_graphmodule" |
|
] = weakref.ref(fn) |
|
|
|
|
|
if isinstance(fn, torch.nn.Module): |
|
mod = fn |
|
new_mod = OptimizedModule(mod, self) |
|
|
|
|
|
            # Save the original callable so innermost_fn can unwrap nested
            # decorators back to it.
            new_mod._torchdynamo_orig_callable = mod.forward
|
|
|
|
|
|
|
            # When compiling an nn.Module, expose the compiler config through
            # the public OptimizedModule.get_compiler_config() hook.
            assert not hasattr(new_mod, "get_compiler_config")
|
new_mod.get_compiler_config = get_compiler_config |
|
|
|
return new_mod |
|
|
|
if inspect.isclass(fn): |
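            # A class was decorated: optimize its __call__ (and _call_impl for
            # nn.Module subclasses, which Dynamo inlines directly).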
|
|
|
|
|
cls_obj = fn |
|
cls_obj.__call__ = self(cls_obj.__call__) |
|
if issubclass(cls_obj, torch.nn.Module): |
|
|
|
cls_obj._call_impl = self(cls_obj._call_impl) |
|
return cls_obj |
|
|
|
assert callable(fn) |
|
|
|
try: |
|
filename = inspect.getsourcefile(fn) |
|
except TypeError: |
|
filename = None |
|
if ( |
|
(filename is None or trace_rules.check(fn)) |
|
and ( |
|
getattr(fn, "__name__", "") |
|
not in ["_call_impl", "_wrapped_call_impl", "_lazy_forward"] |
|
) |
|
and filename not in DONT_WRAP_FILES |
|
): |
|
|
|
            # fn may be a builtin or a trace_rules-allowed function with no
            # Python frame of its own to intercept; wrap it to create one.
            fn = external_utils.wrap_inline(fn)
|
|
|
        def do_nothing(*args, **kwargs):
|
pass |
|
|
|
if hasattr(self, "callback"): |
|
callback = self.callback |
|
else: |
|
callback = do_nothing |
|
|
|
is_jit_tracing = torch._C._is_tracing |
|
is_fx_tracing = torch.fx._symbolic_trace.is_fx_tracing |
|
|
|
@functools.wraps(fn) |
|
def _fn(*args, **kwargs): |
|
if is_fx_tracing(): |
|
if config.error_on_nested_fx_trace: |
|
raise RuntimeError( |
|
"Detected that you are using FX to symbolically trace " |
|
"a dynamo-optimized function. This is not supported at the moment." |
|
) |
|
else: |
|
return fn(*args, **kwargs) |
|
|
|
if is_jit_tracing(): |
|
if config.error_on_nested_jit_trace: |
|
raise RuntimeError( |
|
"Detected that you are using FX to torch.jit.trace " |
|
"a dynamo-optimized function. This is not supported at the moment." |
|
) |
|
else: |
|
return fn(*args, **kwargs) |
|
|
|
cleanups = [enter() for enter in self.enter_exit_hooks] |
|
prior = set_eval_frame(callback) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
            # Remember the functorch DynamicLayerStack depth so the finally
            # block below can pop back to it if compiled code raises after
            # pushing layers (generated code is not guarded by try/finally).
            saved_dynamic_layer_stack_depth = (
|
torch._C._functorch.get_dynamic_layer_stack_depth() |
|
) |
|
|
|
try: |
|
return fn(*args, **kwargs) |
|
finally: |
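                # Pop any functorch layers the compiled region left behind.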
|
|
|
torch._C._functorch.pop_dynamic_layer_stack_and_undo_to_depth( |
|
saved_dynamic_layer_stack_depth |
|
) |
|
|
|
set_eval_frame(prior) |
|
for cleanup in cleanups: |
|
cleanup() |
|
|
|
|
|
        # Expose the unwrapped function so outer compiled frames can inline it.
        _fn._torchdynamo_inline = fn
|
|
|
|
|
|
|
        # Save the original callable so innermost_fn can unwrap nested
        # decorators back to it.
        _fn._torchdynamo_orig_callable = fn
|
|
|
|
|
|
|
        # When compiling a plain function (rather than an nn.Module), expose
        # the compiler config through _fn.get_compiler_config().
        assert not hasattr(_fn, "get_compiler_config")
|
_fn.get_compiler_config = get_compiler_config |
|
|
|
|
|
|
|
        # If fn was decorated via torch._dynamo.optimize, prevent Dynamo from
        # ever skipping its code object.
        if callback not in (None, False):
|
if not hasattr(fn, "__code__"): |
|
raise RuntimeError( |
|
textwrap.dedent( |
|
""" |
|
|
|
                        torch._dynamo.optimize was called on a non-function object.
|
If this is a callable class, please wrap the relevant code into a function and optimize the |
|
wrapper function. |
|
|
|
>> class CallableClass: |
|
>> def __init__(self): |
|
>> super().__init__() |
|
>> self.relu = torch.nn.ReLU() |
|
>> |
|
>> def __call__(self, x): |
|
>> return self.relu(torch.sin(x)) |
|
>> |
|
>> def print_hello(self): |
|
>> print("Hello world") |
|
>> |
|
>> mod = CallableClass() |
|
|
|
If you want to optimize the __call__ function and other code, wrap that up in a function |
|
|
|
>> def wrapper_fn(x): |
|
>> y = mod(x) |
|
>> return y.sum() |
|
|
|
and then optimize the wrapper_fn |
|
|
|
                        >> opt_wrapper_fn = torch._dynamo.optimize()(wrapper_fn)
|
""" |
|
) |
|
) |
|
always_optimize_code_objects[fn.__code__] = True |
|
|
|
return _fn |
|
|
|
|
|
class OptimizeContext(_TorchDynamoContext): |
|
def __init__( |
|
self, |
|
callback, |
|
backend_ctx_ctor, |
|
first_ctx=False, |
|
*, |
|
export=False, |
|
dynamic=None, |
|
compiler_config=None, |
|
rebuild_ctx: Optional[ |
|
Callable[[], Union[OptimizeContext, _NullDecorator]] |
|
] = None, |
|
): |
|
def on_enter(): |
|
install_generation_tagging_init() |
|
|
|
super().__init__( |
|
callback=callback, |
|
on_enter=on_enter, |
|
backend_ctx_ctor=backend_ctx_ctor, |
|
patch_fn=TorchPatcher.patch, |
|
first_ctx=first_ctx, |
|
export=export, |
|
dynamic=dynamic, |
|
compiler_config=compiler_config, |
|
) |
|
|
|
if config.compiled_autograd: |
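            # Enter a compiled-autograd context built from an equivalent copy
            # of this context's compiler (via rebuild_ctx) so backward graphs
            # are compiled too.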
|
|
|
def call_compiled_autograd(): |
|
assert rebuild_ctx is not None |
|
compiler_fn = rebuild_ctx() |
|
ctx = torch._dynamo.compiled_autograd.enable(compiler_fn) |
|
ctx.__enter__() |
|
return functools.partial(ctx.__exit__, None, None, None) |
|
|
|
self.enter_exit_hooks.append(call_compiled_autograd) |
|
|
|
def __reduce__(self): |
|
return ( |
|
self.__class__, |
|
(self.callback, self._backend_ctx_ctor, self.first_ctx), |
|
{ |
|
"export": self.export, |
|
"dynamic": self._dynamic, |
|
"compiler_config": self.compiler_config, |
|
}, |
|
) |
|
|
|
|
|
class RunOnlyContext(_TorchDynamoContext): |
|
def __init__(self): |
|
|
|
def on_enter(): |
|
torch._dynamo.mutation_guard.GenerationTracker.generation += 1 |
|
|
|
super().__init__(callback=False, on_enter=on_enter) |
|
|
|
def __reduce__(self): |
|
return (self.__class__, ()) |
|
|
|
|
|
class DisableContext(_TorchDynamoContext): |
|
def __init__(self): |
|
super().__init__(callback=None) |
|
|
|
def __call__(self, fn): |
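        # Mirrors _TorchDynamoContext.__call__, but with callback=None there
        # are no trace-rule checks and no frame wrapping to do.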
|
|
|
|
|
|
|
|
|
fn = innermost_fn(fn) |
|
|
|
if isinstance(fn, torch.nn.Module): |
|
mod = fn |
|
new_mod = OptimizedModule(mod, self) |
|
new_mod._torchdynamo_orig_callable = mod.forward |
|
return new_mod |
|
|
|
if inspect.isclass(fn): |
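            # A class was decorated with disable: patch __init__, __call__, and
            # _call_impl (for nn.Module subclasses) so none of them are traced.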
|
|
|
|
|
cls_obj = fn |
|
|
|
|
|
|
|
cls_obj.__init__ = self(cls_obj.__init__) |
|
cls_obj.__call__ = self(cls_obj.__call__) |
|
if issubclass(cls_obj, torch.nn.Module): |
|
|
|
cls_obj._call_impl = self(cls_obj._call_impl) |
|
return cls_obj |
|
|
|
assert callable(fn) |
|
|
|
callback = self.callback |
|
|
|
@functools.wraps(fn) |
|
def _fn(*args, **kwargs): |
|
prior = set_eval_frame(callback) |
|
try: |
|
return fn(*args, **kwargs) |
|
finally: |
|
set_eval_frame(prior) |
|
|
|
_fn._torchdynamo_disable = True |
|
|
|
|
|
|
|
_fn._torchdynamo_orig_callable = fn |
|
|
|
return _fn |
|
|
|
def __reduce__(self): |
|
return (self.__class__, ()) |
|
|
|
|
|
def _optimize_catch_errors( |
|
compile_fn, |
|
hooks: Hooks, |
|
backend_ctx_ctor=null_context, |
|
export=False, |
|
dynamic=None, |
|
compiler_config=None, |
|
rebuild_ctx=None, |
|
): |
|
return OptimizeContext( |
|
convert_frame.catch_errors_wrapper(compile_fn, hooks), |
|
backend_ctx_ctor=backend_ctx_ctor, |
|
first_ctx=True, |
|
export=export, |
|
dynamic=dynamic, |
|
compiler_config=compiler_config, |
|
rebuild_ctx=rebuild_ctx, |
|
) |
|
|
|
|
|
def get_compiler_fn(compiler_fn): |
|
from .repro.after_dynamo import wrap_backend_debug |
|
|
|
if hasattr(compiler_fn, "compiler_name"): |
|
compiler_str = compiler_fn.compiler_name |
|
elif isinstance(compiler_fn, str): |
|
compiler_str = compiler_fn |
|
else: |
|
compiler_str = None |
|
compiler_fn = lookup_backend(compiler_fn) |
|
return wrap_backend_debug(compiler_fn, compiler_str) |
|
|
|
|
|
class _NullDecorator(contextlib.nullcontext): |
|
def __call__(self, fn): |
|
assert callable(fn) |
|
return fn |
|
|
|
|
|
def check_if_dynamo_supported(): |
|
if sys.version_info >= (3, 13): |
|
raise RuntimeError("Python 3.13+ not yet supported for torch.compile") |
|
|
|
|
|
def is_dynamo_supported(): |
|
try: |
|
check_if_dynamo_supported() |
|
return True |
|
except Exception: |
|
return False |
|
|
|
|
|
def check_if_inductor_supported(): |
|
check_if_dynamo_supported() |
|
|
|
if sys.platform == "win32": |
|
raise RuntimeError("Windows not yet supported for inductor") |
|
|
|
|
|
def is_inductor_supported(): |
|
try: |
|
check_if_inductor_supported() |
|
return True |
|
except Exception: |
|
return False |
|
|
|
|
|
def optimize(*args, **kwargs): |
|
def rebuild_ctx(): |
|
return optimize(*args, **kwargs) |
|
|
|
return _optimize(rebuild_ctx, *args, **kwargs) |
|
|
|
|
|
def _optimize( |
|
rebuild_ctx: Callable[[], Union[OptimizeContext, _NullDecorator]], |
|
backend="inductor", |
|
*, |
|
nopython=False, |
|
guard_export_fn=None, |
|
guard_fail_fn=None, |
|
disable=False, |
|
dynamic=None, |
|
) -> Union[OptimizeContext, _NullDecorator]: |
|
""" |
|
    The main entrypoint of TorchDynamo. Does graph capture and calls

    backend() to optimize the extracted graphs.
|
|
|
Args: |
|
        backend: One of two things:
|
- Either, a function/callable taking a torch.fx.GraphModule and |
|
example_inputs and returning a python callable that runs the |
|
graph faster. |
|
One can also provide additional context for the backend, like |
|
torch.jit.fuser("fuser2"), by setting the backend_ctx_ctor attribute. |
|
See AOTAutogradMemoryEfficientFusionWithContext for the usage. |
|
- Or, a string backend name in `torch._dynamo.list_backends()` |
|
nopython: If True, graph breaks will be errors and there will |
|
be a single whole-program graph. |
|
disable: If True, turn this decorator into a no-op |
|
dynamic: If True, upfront compile as dynamic a kernel as possible. If False, |
|
disable all dynamic shapes support (always specialize). If None, automatically |
|
detect when sizes vary and generate dynamic kernels upon recompile. |
|
|
|
Example Usage:: |
|
|
|
@torch._dynamo.optimize() |
|
def toy_example(a, b): |
|
... |
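
        A backend can also be any callable taking (gm: torch.fx.GraphModule,
        example_inputs) and returning a callable that runs the graph. A minimal
        illustrative backend (``my_backend`` is a placeholder name)::

            def my_backend(gm, example_inputs):
                gm.graph.print_tabular()  # inspect the captured FX graph
                return gm.forward  # run the captured graph unmodified

            @torch._dynamo.optimize(my_backend)
            def toy_example(a, b):
                ...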
|
""" |
|
check_if_dynamo_supported() |
|
|
|
|
|
|
|
|
|
|
|
    # Hooks are threaded through explicitly rather than stored globally so that
    # nested .optimize(...) calls each keep their own guard hooks.
    hooks = Hooks(guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn)
|
torch._C._log_api_usage_once("torch._dynamo.optimize") |
|
if disable or os.environ.get("TORCHDYNAMO_DISABLE", "") == "1": |
|
return _NullDecorator() |
|
|
|
backend = get_compiler_fn(backend) |
|
|
|
|
|
    # Pick up any extra context manager the backend wants entered while
    # compiled code runs.
    backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context)
|
|
|
if nopython: |
|
return optimize_assert( |
|
backend, |
|
dynamic=dynamic, |
|
hooks=hooks, |
|
rebuild_ctx=rebuild_ctx, |
|
) |
|
|
|
|
|
|
|
return _optimize_catch_errors( |
|
convert_frame.convert_frame(backend, hooks=hooks), |
|
hooks, |
|
backend_ctx_ctor, |
|
dynamic=dynamic, |
|
compiler_config=backend.get_compiler_config() |
|
if hasattr(backend, "get_compiler_config") |
|
else None, |
|
rebuild_ctx=rebuild_ctx, |
|
) |
|
|
|
|
|
|
|
@patch("torch._dynamo.symbolic_convert.explain", True) |
|
def explain(f, *extra_args, **extra_kwargs): |
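    """
    Run ``f`` once under a graph-accumulating backend and summarize what
    Dynamo captured: graphs, graph-break reasons, op counts, and guards.
    Call it as ``explain(f)(*args, **kwargs)``; it returns an ExplainOutput.
    """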
|
def inner(*args, **kwargs): |
|
|
|
from . import reset |
|
|
|
reset() |
|
|
|
graphs: List[torch.fx.GraphModule] = [] |
|
break_reasons: List[Any] = [] |
|
op_count: int = 0 |
|
ops_per_graph: List[torch.fx.Node] = [] |
|
out_guards: List[_guards.Guard] = [] |
|
|
|
def dynamo_graph_accumulating_compiler( |
|
gm: torch.fx.GraphModule, example_inputs |
|
): |
|
from .backends.debugging import _explain_graph_detail |
|
|
|
nonlocal graphs |
|
nonlocal op_count |
|
nonlocal ops_per_graph |
|
nonlocal break_reasons |
|
|
|
gm, graphs, op_count, ops_per_graph, break_reasons = _explain_graph_detail( |
|
gm, graphs, op_count, ops_per_graph, break_reasons |
|
) |
|
|
|
return gm.forward |
|
|
|
def guard_export_print(guards): |
|
nonlocal out_guards |
|
out_guards.extend(guards) |
|
|
|
opt_f = optimize( |
|
dynamo_graph_accumulating_compiler, |
|
nopython=False, |
|
guard_export_fn=guard_export_print, |
|
)(f) |
|
|
|
opt_f(*args, **kwargs) |
|
|
|
graph_count = len(graphs) |
|
graph_break_count = graph_count - 1 |
|
compile_time = compile_times(repr="str") |
|
|
|
|
|
        # Clear Dynamo state again so the graph-accumulating backend does not
        # leak into later compilations.
        reset()
|
from .backends.debugging import ExplainOutput |
|
|
|
return ExplainOutput( |
|
graphs, |
|
graph_count, |
|
graph_break_count, |
|
break_reasons, |
|
op_count, |
|
ops_per_graph, |
|
out_guards, |
|
compile_time, |
|
) |
|
|
|
if extra_args or extra_kwargs: |
|
warnings.warn( |
|
"explain(f, *args, **kwargs) is deprecated, use explain(f)(*args, **kwargs) instead. " |
|
"If you don't migrate, we may break your explain call in the future if your user defined kwargs " |
|
"conflict with future kwargs added to explain(f).", |
|
FutureWarning, |
|
stacklevel=2, |
|
) |
|
return inner(*extra_args, **extra_kwargs) |
|
else: |
|
return inner |
|
|
|
|
|
class FlattenInputOutputSignature(torch.fx.interpreter.Transformer): |
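    """
    FX Transformer that rewrites a Dynamo-captured graph to match the
    flattened user signature: it creates one placeholder per flattened input
    (backfilling fake values for inputs the graph never saw) and reorders or
    constant-fills outputs according to the matched positions computed in
    rewrite_signature.
    """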
|
def __init__( |
|
self, |
|
m: torch.fx.GraphModule, |
|
flat_args: Tuple[Any], |
|
matched_input_elements_positions: List[int], |
|
flat_results: List[Any], |
|
matched_output_elements_positions: List[int], |
|
example_fake_inputs: List[torch.Tensor], |
|
flat_args_dynamic_dims: List[Set[int]], |
|
fake_mode: Optional[fake_tensor.FakeTensorMode] = None, |
|
): |
|
super().__init__(m) |
|
|
|
assert len(flat_args_dynamic_dims) == len(flat_args) |
|
matched_input_elements_to_fake = { |
|
val: example_fake_inputs[ix] |
|
for ix, val in enumerate(matched_input_elements_positions) |
|
} |
|
|
|
self.new_args = [] |
|
for i in range(0, len(flat_args)): |
|
arg = super().placeholder(f"arg{i}", (), {}) |
|
if i in matched_input_elements_to_fake: |
|
arg.node.meta["val"] = matched_input_elements_to_fake[i] |
|
else: |
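                # This input was never seen by the traced graph; still give the
                # placeholder a fake value so shape metadata is populated.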
|
|
|
|
|
if fake_mode is not None and isinstance(flat_args[i], torch.Tensor): |
|
|
|
arg.node.meta["val"] = fake_mode.from_tensor( |
|
flat_args[i], |
|
symbolic_context=StatelessSymbolicContext( |
|
dynamic_sizes=[ |
|
DimDynamic.DYNAMIC |
|
if d in flat_args_dynamic_dims[i] |
|
else DimDynamic.STATIC |
|
for d in range(len(flat_args[i].shape)) |
|
], |
|
constraint_sizes=[None] * len(flat_args[i].shape), |
|
), |
|
) |
|
self.new_args.append(arg) |
|
self.old_args_gen = (self.new_args[i] for i in matched_input_elements_positions) |
|
self.matched_output_elements_positions = matched_output_elements_positions |
|
self.flat_results = flat_results |
|
|
|
def placeholder(self, target, args, kwargs): |
|
arg = next(self.old_args_gen) |
|
if "val" in self.current_node.meta: |
|
arg.node.meta["val"] = self.current_node.meta["val"] |
|
if "tensor_dict" in self.current_node.meta: |
|
arg.node.meta["tensor_dict"] = self.current_node.meta["tensor_dict"] |
|
if "example_value" in self.current_node.meta: |
|
|
|
arg.node.meta["example_value"] = self.current_node.meta["example_value"] |
|
if "unbacked_bindings" in self.current_node.meta: |
|
arg.node.meta["unbacked_bindings"] = self.current_node.meta[ |
|
"unbacked_bindings" |
|
] |
|
return arg |
|
|
|
def output(self, target, args, kwargs): |
|
dynamo_result_flat = args[0] |
|
lookup = [*dynamo_result_flat, *self.new_args] |
|
new_results_flat = [] |
|
for i in range(len(self.flat_results)): |
|
if self.matched_output_elements_positions[i] is not None: |
|
new_results_flat.append( |
|
lookup[self.matched_output_elements_positions[i]] |
|
) |
|
else: |
|
const_val = self.flat_results[i] |
|
assert isinstance(const_val, tuple(common_constant_types)) |
|
new_results_flat.append(const_val) |
|
return super().output(target, (new_results_flat,), {}) |
|
|
|
def run_node(self, n): |
|
self.current_node = n |
|
result_proxy = super().run_node(n) |
|
if "val" in self.current_node.meta: |
|
result_proxy.node.meta["val"] = self.current_node.meta["val"] |
|
if "example_value" in self.current_node.meta: |
|
|
|
result_proxy.node.meta["example_value"] = self.current_node.meta[ |
|
"example_value" |
|
] |
|
if "unbacked_bindings" in self.current_node.meta: |
|
result_proxy.node.meta["unbacked_bindings"] = self.current_node.meta[ |
|
"unbacked_bindings" |
|
] |
|
if self.current_node.op != "output": |
|
result_proxy.node._rename( |
|
getattr(self.current_node, "name", result_proxy.node.name) |
|
) |
|
return result_proxy |
|
|
|
def transform(self): |
|
result_gm = super().transform() |
|
if "dynamo_flat_name_to_original_fqn" in self.module.meta: |
|
result_gm.meta["dynamo_flat_name_to_original_fqn"] = self.module.meta[ |
|
"dynamo_flat_name_to_original_fqn" |
|
] |
|
return result_gm |
|
|
|
|
|
class ExportResult(NamedTuple): |
|
graph_module: torch.fx.GraphModule |
|
guards: _guards.GuardsSet |
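    # NB: callers destructure this as a (graph_module, guards) pair, so adding
    # fields would be backwards-incompatible.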
|
|
|
|
|
|
|
|
|
def check_signature_rewritable(graph): |
|
input_errors = [] |
|
for node in graph.graph.find_nodes(op="placeholder"): |
|
assert hasattr(node, "_dynamo_source") |
|
source = node._dynamo_source |
|
user_stacks = graph._source_to_user_stacks.get(source) |
|
if user_stacks is None: |
|
continue |
|
assert len(user_stacks) > 0 |
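
        # Some recorded accesses have empty stacks (e.g. closed-over free
        # variables); find the first non-empty stack to report.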
|
|
|
|
|
stack = None |
|
for s in user_stacks: |
|
if len(s) == 0: |
|
continue |
|
stack = s |
|
break |
|
if stack is None: |
|
msg = f"{source.name()}, a closed over free variable" |
|
else: |
|
tb = "".join(traceback.format_list(stack)) |
|
extra = "" |
|
if len(user_stacks) > 1: |
|
extra = f"(elided {len(user_stacks) - 1} more accesses)" |
|
msg = f"{source.name()}, accessed at:\n{tb}{extra}" |
|
|
|
input_errors.append(msg) |
|
|
|
if input_errors: |
|
raise UserError( |
|
UserErrorType.INVALID_INPUT, |
|
"Cannot export model which references tensors that are neither " |
|
"buffers/parameters/constants nor are direct inputs. For each tensor, if you'd " |
|
"like this tensor to be an explicit input, add it as a dummy argument " |
|
"to the top-level model definition you are exporting; if you would " |
|
"like its value to be embedded as an exported constant, wrap its access " |
|
"in a function marked with @assume_constant_result.\n\n" |
|
+ "\n\n".join(input_errors), |
|
) |
|
|
|
|
|
def rewrite_signature( |
|
f_sig, |
|
graph, |
|
fake_mode, |
|
flat_args, |
|
in_spec, |
|
example_fake_inputs, |
|
graph_captured_input, |
|
graph_captured_output, |
|
dynamo_traced_result, |
|
flat_args_dynamic_dims, |
|
): |
|
orig_args, orig_kwargs = pytree.tree_unflatten(flat_args, in_spec) |
|
|
|
def check_user_input_output(flat_values, error_type): |
|
supported_types = [ |
|
torch.Tensor, |
|
torch.SymInt, |
|
torch.SymFloat, |
|
torch.SymBool, |
|
torch._C.ScriptObject, |
|
] + list(common_constant_types) |
|
|
|
def is_supported_type(val): |
|
return isinstance(val, tuple(supported_types)) |
|
|
|
value_type = "input" if error_type == UserErrorType.INVALID_INPUT else "output" |
|
|
|
for v in flat_values: |
|
if not is_supported_type(v): |
|
if error_type == UserErrorType.INVALID_INPUT and v is None: |
|
continue |
|
|
|
raise UserError( |
|
error_type, |
|
f"It looks like one of the {value_type}s with type `{type(v)}` " |
|
"is not supported or pytree-flattenable. \n" |
|
f"Exported graphs {value_type}s can only contain the " |
|
f"following supported types: {supported_types}. \n" |
|
"If you are using a custom class object, " |
|
"please register a pytree_flatten/unflatten function " |
|
"using `torch.utils._pytree.register_pytree_node` or " |
|
"`torch.export.register_dataclass`.", |
|
) |
|
|
|
check_user_input_output(flat_args, UserErrorType.INVALID_INPUT) |
|
flat_results_traced, out_spec_traced = pytree.tree_flatten(dynamo_traced_result) |
|
check_user_input_output(flat_results_traced, UserErrorType.INVALID_OUTPUT) |
|
|
|
def produce_matching(debug_type, sources, candidates): |
|
matched_elements_positions: List[Optional[int]] = [] |
|
dict_of_source_vals = {} |
|
for i, val in enumerate(sources): |
|
dict_of_source_vals[id(val)] = i |
|
|
|
for i, val in enumerate(candidates): |
|
if isinstance(val, tuple(common_constant_types)): |
|
matched_elements_positions.append(None) |
|
elif id(val) not in dict_of_source_vals: |
|
raise AssertionError( |
|
f"Unexpectedly found a {type(val)} in the {debug_type}.\n" |
|
'Please file an issue along with a paste of the logs from TORCH_LOGS="+export"' |
|
) |
|
else: |
|
matched_elements_positions.append(dict_of_source_vals[id(val)]) |
|
|
|
return matched_elements_positions |
|
|
|
matched_input_elements_positions = produce_matching( |
|
"inputs", flat_args, graph_captured_input |
|
) |
|
|
|
assert graph_captured_output is not None |
|
matched_output_elements_positions = produce_matching( |
|
"outputs", list(graph_captured_output) + flat_args, flat_results_traced |
|
) |
|
|
|
new_graph = FlattenInputOutputSignature( |
|
graph, |
|
flat_args, |
|
matched_input_elements_positions, |
|
flat_results_traced, |
|
matched_output_elements_positions, |
|
example_fake_inputs, |
|
flat_args_dynamic_dims, |
|
fake_mode, |
|
).transform() |
|
|
|
|
|
def argument_names(f_sig, args, kwargs) -> List[str]: |
|
def signature_to_fullargspec(sig: inspect.Signature): |
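            # Reassemble an inspect.FullArgSpec from the Signature so the
            # numbered steps below can reason about positional/keyword args.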
|
|
|
params = list(sig.parameters.values()) |
|
|
|
args = [ |
|
p.name |
|
for p in params |
|
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD |
|
] |
|
kwonlyargs = [ |
|
p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY |
|
] |
|
varargs = next( |
|
(p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL), |
|
None, |
|
) |
|
varkw = next( |
|
(p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD), |
|
None, |
|
) |
|
|
|
defaults = tuple( |
|
p.default |
|
for p in params |
|
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD |
|
and p.default is not inspect.Parameter.empty |
|
) |
|
kwonlydefaults = { |
|
p.name: p.default |
|
for p in params |
|
if p.kind == inspect.Parameter.KEYWORD_ONLY |
|
and p.default is not inspect.Parameter.empty |
|
} |
|
|
|
annotations = {} |
|
if sig.return_annotation: |
|
annotations = {"return": sig.return_annotation} |
|
for parameter in params: |
|
annotations[parameter.name] = parameter.annotation |
|
|
|
return inspect.FullArgSpec( |
|
args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations |
|
) |
|
|
|
fullargspec = signature_to_fullargspec(f_sig) |
|
|
|
|
|
        # 1. Positional arguments in `args` map 1-to-1 onto the signature's
        #    positional parameter names.
        input_strs = fullargspec.args[: len(args)]
|
|
|
if len(args) > len(fullargspec.args): |
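            # 2. Extra positional args map onto *varargs; name them
            #    {varargs}_0, {varargs}_1, ...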
|
|
|
|
|
assert fullargspec.varargs is not None, "More arguments than expected" |
|
input_strs += [ |
|
f"{fullargspec.varargs}_{i}" |
|
for i in range(0, len(args) - len(input_strs)) |
|
] |
|
elif len(args) < len(fullargspec.args): |
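            # 3. Fewer positional args than parameters: the remainder must be
            #    covered by defaults (which are not exported as part of the
            #    signature) or by kwargs, which are appended in step 4.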
|
|
|
|
|
|
|
|
|
|
|
for unprovided_arg in fullargspec.args[ |
|
len(args) : -len(fullargspec.defaults or []) |
|
]: |
|
assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}" |
|
|
|
|
|
        # 4. Keyword arguments provided in `kwargs` appear under their own names.
        input_strs += list(kwargs.keys())
|
|
|
|
|
|
|
        # 5. Keyword-only arguments must either be provided or carry defaults;
        #    defaults are not exported as part of the signature.
        for kwonly_arg in fullargspec.kwonlyargs:
|
kwonlydefaults = fullargspec.kwonlydefaults or {} |
|
assert ( |
|
kwonly_arg in kwargs or kwonly_arg in kwonlydefaults |
|
), f"Missing keyword only argument {kwonly_arg}" |
|
|
|
return input_strs |
|
|
|
new_graph.graph._codegen = _PyTreeCodeGen( |
|
_PyTreeInfo( |
|
argument_names(f_sig, orig_args, orig_kwargs), |
|
in_spec, |
|
out_spec_traced, |
|
) |
|
) |
|
new_graph.recompile() |
|
return new_graph |
|
|
|
|
|
def export( |
|
f: Callable[..., Any], |
|
*extra_args, |
|
aten_graph: bool = False, |
|
pre_dispatch: bool = False, |
|
decomposition_table: Optional[ |
|
Dict[torch._ops.OpOverload, Callable[..., Any]] |
|
] = None, |
|
tracing_mode: str = "symbolic", |
|
dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None, |
|
assume_static_by_default: bool = False, |
|
same_signature: bool = True, |
|
disable_constraint_solver: bool = False, |
|
prefer_deferred_runtime_asserts_over_guards: bool = False, |
|
_allow_complex_guards_as_runtime_asserts: bool = False, |
|
_log_export_usage: bool = True, |
|
**extra_kwargs, |
|
) -> Callable[..., ExportResult]: |
|
""" |
|
Export an input function f to a format that can be executed outside of PyTorch using the FX graph. |
|
|
|
Args: |
|
f (callable): A PyTorch function to be exported. |
|
|
|
aten_graph (bool): If True, exports a graph with ATen operators. |
|
If False, exports a graph with Python operators. Default is False. |
|
|
|
pre_dispatch (bool): If True, exports a graph with ATen operators, |
|
but before any logic in the PyTorch dispatcher has run. |
|
This can be useful if you want to apply further transformations on a graph before running it |
|
through autograd, autocast, or any other functionalities that are integrated into the dispatcher. |
|
This flag is only valid if aten_graph=True is set. |
|
Default is False. |
|
|
|
decomposition_table (dict): A dictionary that maps operators to their decomposition functions. |
|
Required if aten_graph or tracing_mode is specified. Default is None. |
|
|
|
tracing_mode (str): If "symbolic", turn on dynamic shapes support. Default is "symbolic". |
|
|
|
dynamic_shapes: |
|
An optional argument where the type should either be: |
|
1) a dict from argument names of ``f`` to their dynamic shape specifications, |
|
2) a tuple that specifies dynamic shape specifications for each input in original order. |
|
If you are specifying dynamism on keyword args, you will need to pass them in the order that |
|
is defined in the original function signature. |
|
|
|
The dynamic shape of a tensor argument can be specified as either |
|
(1) a dict from dynamic dimension indices to :func:`Dim` types, where it is |
|
not required to include static dimension indices in this dict, but when they are, |
|
they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, |
|
where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions |
|
are denoted by None. Arguments that are dicts or tuples / lists of tensors are |
|
recursively specified by using mappings or sequences of contained specifications. |
|
|
|
same_signature (bool): If True, rewrite the returned graph's signature to be the same as f. |
|
|
|
disable_constraint_solver (bool): Whether the dim constraint solver must be disabled. |
|
|
|
Returns: |
|
A function that given args and kwargs, returns a tuple of (graph, guards) |
|
Graph: An FX graph representing the execution of the input PyTorch function with the provided arguments and options. |
|
Guards: The guards we accumulated during tracing f above |
|
|
|
Raises: |
|
AssertionError: If decomposition_table is specified without setting aten_graph=True, |
|
or if graph breaks during tracing in export. |
|
|
|
AssertionError: If Dynamo input and output is not consistent with traced input/output. |
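
    Example (illustrative)::

        def f(x):
            return x.sin() + x.cos()

        gm, guards = torch._dynamo.export(f)(torch.randn(3))
        # gm is a torch.fx.GraphModule that replays f; guards were
        # accumulated while tracing it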
|
|
|
    Note: this docstring was originally authored by ChatGPT and lightly edited by the author.
|
""" |
|
if _log_export_usage: |
|
log_export_usage(event="export.private_api", flags={"_dynamo"}) |
|
|
|
|
|
    # Rebind under fresh names so inner() can rebind them locally without
    # hitting "local variable referenced before assignment".
    _f = f
|
_assume_static_by_default = assume_static_by_default |
|
|
|
def inner(*args, **kwargs): |
|
constraints = _process_dynamic_shapes(_f, args, kwargs, dynamic_shapes) |
|
f = _f |
|
assume_static_by_default = _assume_static_by_default |
|
check_if_dynamo_supported() |
|
torch._C._log_api_usage_once("torch._dynamo.export") |
|
if decomposition_table is not None: |
|
assert ( |
|
aten_graph |
|
), "Specifying a decomposition_table table or tracing mode is illegal without setting aten_graph=True" |
|
if pre_dispatch: |
|
assert aten_graph, "pre_dispatch=True can only be used when aten_graph=True" |
|
f = innermost_fn(f) |
|
call_to_inspect = f.forward if isinstance(f, torch.nn.Module) else f |
|
original_signature = inspect.signature(call_to_inspect) |
|
graph = None |
|
out_guards = None |
|
graph_captured_input = None |
|
graph_captured_result: Optional[Tuple[torch.Tensor, ...]] = None |
|
fake_mode = None |
|
|
|
def guard_export_print(guards: _guards.GuardsSet): |
|
nonlocal out_guards |
|
assert ( |
|
out_guards is None |
|
), "whole graph export entails exactly one guard export" |
|
out_guards = guards |
|
|
|
example_inputs = [] |
|
|
|
def dynamo_normalization_capturing_compiler( |
|
gm: torch.fx.GraphModule, inner_example_inputs |
|
): |
|
nonlocal graph |
|
assert ( |
|
graph is None |
|
), "Tried to emit a second graph during export. Tracing through 'f' must produce a single graph." |
|
graph = gm |
|
|
|
nonlocal fake_mode, example_inputs |
|
|
|
|
|
|
|
|
|
|
|
            # Detect the fake mode Dynamo installed on the tracing context;
            # deliberately do not pass inner_example_inputs here.
            fake_mode = _guards.detect_fake_mode()
|
example_inputs = inner_example_inputs |
|
|
|
def result_capturing_wrapper(*graph_inputs): |
|
nonlocal graph_captured_result |
|
nonlocal graph_captured_input |
|
|
|
graph_captured_input = graph_inputs |
|
assert graph is not None |
|
|
|
named_parameters = dict(graph.named_parameters(remove_duplicate=False)) |
|
named_buffers = dict(graph.named_buffers(remove_duplicate=False)) |
|
|
|
ambient_fake_mode = ( |
|
_guards.detect_fake_mode(graph_inputs) |
|
if _guards.detect_fake_mode(graph_inputs) is not None |
|
else fake_mode |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
            # Re-running fake tensor propagation below may mint fresh unbacked
            # SymInts that nothing consumes; tell the ShapeEnv to ignore them.
            ignore_fresh_unbacked = null_context()
|
if shape_env := ambient_fake_mode.shape_env: |
|
ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols() |
|
|
|
with ( |
|
ambient_fake_mode |
|
), enable_python_dispatcher(), ignore_fresh_unbacked: |
|
params_and_buffers = { |
|
**named_parameters, |
|
**named_buffers, |
|
} |
|
fake_params_buffers = dict() |
|
|
|
for name, value in params_and_buffers.items(): |
|
fake_params_buffers[name] = ambient_fake_mode.from_tensor( |
|
value, static_shapes=True |
|
) |
|
|
|
fake_graph_inputs = pytree.tree_map( |
|
ambient_fake_mode.from_tensor, graph_inputs |
|
) |
|
graph_captured_result = torch.func.functional_call( |
|
graph, fake_params_buffers, fake_graph_inputs |
|
) |
|
|
|
return graph_captured_result |
|
|
|
return result_capturing_wrapper |
|
|
|
|
|
|
|
        # Flatten before calling the optimized function: rewrite_signature
        # needs the pre-call values and the user program may mutate its inputs.
        flat_args, in_spec = pytree.tree_flatten((args, kwargs))
|
|
|
remove_from_cache(f) |
|
constraint_violation_error = None |
|
if tracing_mode != "symbolic": |
|
assume_static_by_default = True |
|
with config.patch( |
|
specialize_int=True, |
|
assume_static_by_default=assume_static_by_default, |
|
automatic_dynamic_shapes=False, |
|
capture_dynamic_output_shape_ops=True, |
|
capture_scalar_outputs=True, |
|
prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, |
|
_allow_complex_guards_as_runtime_asserts=_allow_complex_guards_as_runtime_asserts, |
|
): |
|
opt_f = optimize_assert( |
|
dynamo_normalization_capturing_compiler, |
|
hooks=Hooks( |
|
guard_export_fn=guard_export_print, |
|
guard_fail_fn=None, |
|
), |
|
export=True, |
|
export_constraints=constraints, |
|
)(f) |
|
|
|
try: |
|
result_traced = opt_f(*args, **kwargs) |
|
except ConstraintViolationError as e: |
|
constraint_violation_error = e |
|
remove_from_cache(f) |
|
|
|
if ( |
|
not disable_constraint_solver |
|
and (shape_env := getattr(fake_mode, "shape_env", None)) is not None |
|
and (dim_constraints := shape_env.dim_constraints) is not None |
|
and not isinstance( |
|
call_to_inspect, (torch._ops.OpOverloadPacket, torch._ops.OpOverload) |
|
) |
|
and not trace_rules.check(call_to_inspect) |
|
): |
|
dim_constraints.solve() |
|
dim_constraints.remove_redundant_dynamic_results() |
|
forced_specializations = dim_constraints.forced_specializations() |
|
msg = dim_constraints.prettify_results( |
|
original_signature, |
|
dynamic_shapes, |
|
constraint_violation_error, |
|
forced_specializations, |
|
) |
|
if constraint_violation_error: |
|
constraint_violation_error.args = ( |
|
constraint_violation_error.args[0] + msg, |
|
) |
|
else: |
|
if forced_specializations: |
|
constraint_violation_error = ConstraintViolationError(msg) |
|
else: |
|
log.info( |
|
"Summary of dimension constraints:%s", |
|
msg, |
|
) |
|
|
|
|
|
            # Error if a constraint was placed on a value that turned out to be
            # static.
            for k in shape_env.var_to_range.keys():
|
if isinstance(k, sympy.Integer): |
|
constraint_violation_error = ConstraintViolationError( |
|
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n" |
|
"It appears that you're trying to set a constraint on a " |
|
f"value which we evaluated to have a static value of {k}. " |
|
'Set TORCH_LOGS="+export" for more information.' |
|
) |
|
if constraint_violation_error: |
|
raise constraint_violation_error |
|
|
|
assert ( |
|
graph is not None |
|
), "Failed to produce a graph during tracing as no tensor operations were found." |
|
assert hasattr(graph, "_source_to_user_stacks") |
|
assert out_guards is not None, "Failed to produce guards during tracing" |
|
assert fake_mode is not None |
|
|
|
log.info( |
|
"Dynamo captured graph:\n\n%s", graph.print_readable(print_output=False) |
|
) |
|
|
|
|
|
|
|
if same_signature: |
|
check_signature_rewritable(graph) |
|
|
|
|
|
        # Fakeify the example inputs under the same fake mode used while tracing.
        example_fake_inputs = [fake_mode.from_tensor(t) for t in example_inputs]
|
|
|
if aten_graph: |
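            # Lower the captured graph to ATen by re-tracing it with make_fx;
            # running it through an Interpreter preserves node-level stack
            # traces in the new graph.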
|
|
|
def graph_with_interpreter(*args): |
|
with torch.fx.traceback.preserve_node_meta(): |
|
return torch.fx.Interpreter(graph).run(*args) |
|
|
|
with maybe_disable_fake_tensor_mode(), enable_python_dispatcher(), ( |
|
fake_mode |
|
): |
|
try: |
|
graph = make_fx( |
|
graph_with_interpreter, |
|
decomposition_table=decomposition_table, |
|
tracing_mode="real", |
|
_allow_non_fake_inputs=True, |
|
pre_dispatch=pre_dispatch, |
|
_allow_fake_constant=False, |
|
)(*example_fake_inputs) |
|
except CondOpArgsMismatchError as e: |
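                # Surface the internal cond() mismatch as a user-facing error.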
|
|
|
raise UserError( |
|
UserErrorType.DYNAMIC_CONTROL_FLOW, |
|
str(e), |
|
case_name="cond_operands", |
|
) |
|
|
|
assert graph is not None |
|
for node in graph.graph.find_nodes(op="get_attr"): |
|
if isinstance(getattr(graph, node.target), torch.Tensor): |
|
node.meta["val"] = fake_mode.from_tensor( |
|
getattr(graph, node.target), static_shapes=True |
|
) |
|
|
|
if same_signature: |
|
flat_args_dynamic_dims = [ |
|
{c.dim for c in (constraints or ()) if c.w_tensor() is x} |
|
for x in flat_args |
|
] |
|
graph = rewrite_signature( |
|
original_signature, |
|
graph, |
|
fake_mode, |
|
flat_args, |
|
in_spec, |
|
example_fake_inputs, |
|
graph_captured_input, |
|
graph_captured_result, |
|
result_traced, |
|
flat_args_dynamic_dims, |
|
) |
|
|
|
assert graph is not None |
|
graph.meta["input_shape_constraints"] = ( |
|
[constraint.serializable_spec for constraint in constraints] |
|
if constraints |
|
else [] |
|
) |
|
|
|
return ExportResult(graph, out_guards) |
|
|
|
if extra_args or extra_kwargs: |
|
warnings.warn( |
|
"export(f, *args, **kwargs) is deprecated, use export(f)(*args, **kwargs) instead. " |
|
"If you don't migrate, we may break your export call in the future if your user defined kwargs " |
|
"conflict with future kwargs added to export(f).", |
|
FutureWarning, |
|
stacklevel=2, |
|
) |
|
return inner(*extra_args, **extra_kwargs) |
|
else: |
|
return inner |
|
|
|
|
|
def optimize_assert( |
|
backend, |
|
*, |
|
hooks=Hooks(None, None), |
|
export=False, |
|
export_constraints=None, |
|
dynamic=None, |
|
rebuild_ctx=None, |
|
): |
|
""" |
|
The same as `torch._dynamo.optimize(backend, nopython=True)` |
|
""" |
|
backend = get_compiler_fn(backend) |
|
|
|
|
|
backend_ctx_ctor = getattr(backend, "backend_ctx_ctor", null_context) |
|
|
|
return _optimize_catch_errors( |
|
convert_frame.convert_frame_assert( |
|
backend, export=export, export_constraints=export_constraints |
|
), |
|
hooks, |
|
backend_ctx_ctor, |
|
export=export, |
|
dynamic=dynamic, |
|
rebuild_ctx=rebuild_ctx, |
|
) |
|
|
|
|
|
class TorchPatcher: |
|
@staticmethod |
|
@functools.lru_cache(None) |
|
def patch(): |
|
|
|
|
|
|
|
        # Imported here rather than at module top level to avoid a circular
        # import; `disable` keeps Dynamo out of the tracers and optimizer
        # internals patched below.
        from .decorators import disable
|
|
|
torch.jit.trace = disable(torch.jit.trace) |
|
torch.jit.trace_module = disable(torch.jit.trace_module) |
|
torch.jit._get_trace_graph = disable(torch.jit._get_trace_graph) |
|
torch.fx._symbolic_trace.Tracer.trace = disable( |
|
torch.fx._symbolic_trace.Tracer.trace |
|
) |
|
torch.distributions.Distribution.set_default_validate_args(False) |
|
|
|
from ..optim import ( |
|
adadelta, |
|
adagrad, |
|
adam, |
|
adamax, |
|
adamw, |
|
asgd, |
|
lbfgs, |
|
nadam, |
|
radam, |
|
rmsprop, |
|
rprop, |
|
sgd, |
|
sparse_adam, |
|
) |
|
|
|
optimizer_modules = { |
|
adadelta, |
|
adagrad, |
|
adam, |
|
adamax, |
|
adamw, |
|
asgd, |
|
lbfgs, |
|
nadam, |
|
radam, |
|
rmsprop, |
|
rprop, |
|
sgd, |
|
sparse_adam, |
|
} |
|
|
|
for opt_mod in optimizer_modules: |
|
opt_name = opt_mod.__name__.split(".")[-1] |
|
fused_fn_name = f"_fused_{opt_name}" |
|
single_tensor_fn_name = f"_single_tensor_{opt_name}" |
|
|
|
if hasattr(opt_mod, fused_fn_name): |
|
setattr( |
|
opt_mod, fused_fn_name, disable(getattr(opt_mod, fused_fn_name)) |
|
) |
|
|
|
optimizer_classes = [ |
|
opt |
|
for opt in torch.optim.__dict__.values() |
|
if inspect.isclass(opt) and issubclass(opt, torch.optim.Optimizer) |
|
] |
|
|
|
|
|
        # Optimizer classes whose step() Dynamo must never trace.
        excluded_optimizer_classes = {
|
torch.optim.SparseAdam, |
|
torch.optim.LBFGS, |
|
} |
|
|
|
for opt in optimizer_classes: |
|
if opt in excluded_optimizer_classes: |
|
opt.step = disable(opt.step) |
|
|
|
if hasattr(opt, "_init_group"): |
|
opt._init_group = disable(opt._init_group) |
|
|
|
@staticmethod |
|
def suppress_torch_distributed_warnings(fn): |
|
def inner_fn(*args, **kwargs): |
|
warnings.filterwarnings( |
|
"ignore", category=UserWarning, module="torch.distributed" |
|
) |
|
return fn(*args, **kwargs) |
|
|
|
return inner_fn |
|
|