|
|
|
import os |
|
import textwrap |
|
from enum import auto, Enum |
|
from traceback import extract_stack, format_exc, format_list, StackSummary |
|
from typing import Any, cast, NoReturn, Optional |
|
|
|
import torch._guards |
|
|
|
from . import config |
|
|
|
from .utils import counters |
|
|
|
|
|
def exportdb_error_message(case_name):
    """Build a pointer to the exportdb documentation entry for *case_name*.

    The exportdb site anchors use hyphens where case names use underscores,
    so the name is translated before being appended to the URL fragment.
    """
    anchor = case_name.replace("_", "-")
    return (
        "For more information about this error, see: "
        f"https://pytorch.org/docs/main/generated/exportdb/index.html#{anchor}"
    )
|
|
|
|
|
import logging

# General-purpose logger for this module's warnings and errors.
log = logging.getLogger(__name__)

# Dedicated artifact logger for graph-break diagnostics; messages sent here
# are surfaced through torch._logging's "graph_breaks" artifact channel.
graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
|
|
|
|
|
class TorchDynamoException(RuntimeError):
    """Common base class for every exception raised by TorchDynamo."""
|
|
|
|
|
class InternalTorchDynamoError(TorchDynamoException):
    """An unexpected failure inside Dynamo itself (i.e. a Dynamo bug),
    as opposed to unsupported user code."""
|
|
|
|
|
class RestartAnalysis(TorchDynamoException):
    """Thrown to abort the current trace and restart analysis of the frame.

    Attributes:
        restart_reason: optional human-readable explanation of why the
            restart was requested.
    """

    restart_reason: str

    def __init__(self, *args, restart_reason=None):
        super().__init__(*args)
        self.restart_reason = restart_reason
|
|
|
|
|
class SpeculationRestartAnalysis(RestartAnalysis):
    """RestartAnalysis variant distinguishing speculation-triggered restarts."""
|
|
|
|
|
class UnspecializeRestartAnalysis(RestartAnalysis):
    """RestartAnalysis variant requesting that re-analysis unspecialize."""
|
|
|
|
|
class SkipFrame(TorchDynamoException):
    """Signals that the current frame should be skipped rather than compiled."""
|
|
|
|
|
class TorchRuntimeError(TorchDynamoException):
    """Wraps a RuntimeError raised by torch code encountered during tracing."""
|
|
|
|
|
class InvalidBackend(TorchDynamoException):
    """Raised when a backend name is not recognized by torch.compile."""

    def __init__(self, name):
        # Keep the message actionable by pointing at the discovery API.
        message = (
            f"Invalid backend: {name!r}, "
            "see `torch._dynamo.list_backends()` for available backends."
        )
        super().__init__(message)
|
|
|
|
|
class ResetRequired(TorchDynamoException):
    """Raised when `torch._dynamo.reset()` must be called before recompiling."""

    def __init__(self):
        message = textwrap.dedent(
            """
            Must call `torch._dynamo.reset()` before changing backends. Detected two calls to
            `torch.compile()` with a different backend compiler arguments.
            """
        )
        super().__init__(message)
|
|
|
|
|
class BackendCompilerFailed(TorchDynamoException):
    """Wraps an exception raised by a compiler backend.

    Keeps both the backend's name and the original exception available to
    callers (e.g. for minifier hints appended later).
    """

    def __init__(self, backend_fn, inner_exception):
        # Backends are usually plain functions; fall back to "?" for
        # callables without a __name__.
        name = getattr(backend_fn, "__name__", "?")
        self.backend_name = name
        self.inner_exception = inner_exception
        failure = f"{type(inner_exception).__name__}: {inner_exception}"
        super().__init__(f"backend={name!r} raised:\n{failure}")
|
|
|
|
|
class Unsupported(TorchDynamoException):
    """Raised when Dynamo cannot continue tracing the current code.

    On construction the exception captures the user-code stack from the
    active TracingContext and registers itself in the global ``counters``
    stats (category "unimplemented" by default).
    """

    def __init__(self, msg):
        super().__init__(msg)
        # Snapshot of the user-code stack at the point of failure; consumed
        # later by the error-formatting helpers in this module.
        self.real_stack = torch._guards.TracingContext.extract_stack()
        self.msg = msg
        self.category: Optional[str] = None
        self.add_to_stats()

    def remove_from_stats(self):
        """Withdraw this exception's contribution to the stats counters."""
        assert self.category is not None
        bucket = counters[self.category]
        bucket[self.msg] -= 1
        if bucket[self.msg] <= 0:
            del bucket[self.msg]

    def add_to_stats(self, category="unimplemented"):
        """Record this exception's message under ``category`` in the counters."""
        self.category = category
        counters[category][self.msg] += 1
|
|
|
|
|
class RecompileError(TorchDynamoException):
    """Raised when a recompilation-related failure occurs."""
|
|
|
|
|
class ArgsMismatchError(Unsupported):
    """Unsupported case caused by mismatched call arguments."""

    def __init__(self, msg):
        super().__init__(msg)
|
|
|
|
|
class AttributeMutationError(Unsupported):
    """Unsupported case caused by an attribute mutation Dynamo cannot handle."""

    def __init__(self, msg):
        super().__init__(msg)
|
|
|
|
|
class CondOpArgsMismatchError(ArgsMismatchError):
    """Internal error raised by cond() when its branch arguments mismatch."""

    def __init__(self, msg):
        super().__init__(msg)
|
|
|
|
|
class UserErrorType(Enum):
    """Categories of user-caused errors reported via :class:`UserError`.

    Values are auto-assigned; only the names are meaningful. Do not reorder
    members, since ``auto()`` values are positional.
    """

    DYNAMIC_CONTROL_FLOW = auto()
    ANTI_PATTERN = auto()
    STANDARD_LIBRARY = auto()
    CONSTRAINT_VIOLATION = auto()
    DYNAMIC_DIM = auto()
    INVALID_INPUT = auto()
    INVALID_OUTPUT = auto()
|
|
|
|
|
class UserError(Unsupported):
    def __init__(self, error_type: UserErrorType, msg, case_name=None):
        """An error that would be valid in eager mode but is unsupported here.

        The message should tell the user what to do next.

        Args:
            error_type: category of the user error.
            msg: actionable error message.
            case_name: optional snake_case name of the matching exportdb
                usage example; when given, a documentation link is appended.
        """
        if case_name is not None:
            assert isinstance(case_name, str)
            # Append the exportdb link on the same line after a trailing
            # period, otherwise on a fresh line.
            glue = " " if msg.endswith(".") else "\n"
            msg = msg + glue + exportdb_error_message(case_name)
        super().__init__(msg)
        self.error_type = error_type
        self.message = msg
|
|
|
|
|
class UserStopIteration(TorchDynamoException):
    """Wraps a ``StopIteration`` raised by user code during tracing.

    Mirrors ``StopIteration.value`` so a generator's return value (if any)
    survives the translation into a Dynamo exception.
    """

    # Value carried by the original StopIteration, or None when absent.
    value: Optional[Any]

    def __init__(self, *args, **kwargs):
        super().__init__("unhandled `raise StopIteration`")
        self.value = args[0] if args else None
|
|
|
|
|
class UnsafeScriptObjectError(TorchDynamoException):
    """Raised when a script object cannot be handled safely by Dynamo."""
|
|
|
|
|
class UncapturedHigherOrderOpError(TorchDynamoException):
    """Raised when a higher-order op cannot be captured into the graph."""
|
|
|
|
|
class IncorrectUsage(Exception):
    """Raised for incorrect use of a Dynamo API.

    Note: deliberately derives from plain ``Exception`` rather than
    ``TorchDynamoException``.
    """
|
|
|
|
|
class ObservedException(TorchDynamoException):
    """Dynamo-side representation of an exception observed while tracing."""
|
|
|
|
|
|
|
# Fake-tensor exceptions that, per the name, are eligible for fallback
# handling rather than hard failure. NOTE(review): the exact fallback
# behavior is decided at the call sites that consult this tuple — confirm
# there before extending it.
exceptions_allowed_to_be_fallback = (
    torch._subclasses.fake_tensor.DataDependentOutputException,
    torch._subclasses.fake_tensor.DynamicOutputShapeException,
    torch._subclasses.fake_tensor.UnsupportedOperatorException,
    torch._subclasses.fake_tensor.UnsupportedFakeTensorException,
)
|
|
|
|
|
def unimplemented_with_warning(e: Exception, code, msg: str) -> NoReturn:
    """Log a graph break caused by ``e`` and raise :class:`Unsupported`.

    Emits the full verbose report (Dynamo stack plus user stack) to the
    graph_breaks artifact logger at DEBUG level, warns ``msg`` on the module
    logger, then raises ``Unsupported(msg)`` chained from ``e``.

    Args:
        e: the exception that triggered the graph break.
        code: code object of the frame being converted (for the report).
        msg: short user-facing description of the break.
    """
    verbose_msg = format_error_msg_verbose(e, code)
    graph_breaks_log.debug("%s", verbose_msg)
    log.warning(msg)
    unimplemented(msg, from_exc=e)
|
|
|
|
|
# Sentinel distinguishing "no from_exc supplied" from an explicit
# from_exc=None (which would mean `raise ... from None`, i.e. suppress the
# exception context).
_NOTHING = object()
|
|
|
|
|
def unimplemented(msg: str, *, from_exc: Any = _NOTHING) -> NoReturn:
    """Raise :class:`Unsupported` with ``msg``, optionally chained to ``from_exc``.

    Passing ``from_exc`` (even ``None``) sets the explicit exception cause
    via ``raise ... from``; omitting it leaves the implicit context intact.
    """
    # Debugging hook: export BREAK=<exact message> to make this assert trip,
    # so a debugger stops at the offending unimplemented() call.
    assert msg != os.environ.get("BREAK", False)
    if from_exc is not _NOTHING:
        raise Unsupported(msg) from from_exc
    raise Unsupported(msg)
|
|
|
|
|
def warning(msg: str) -> None:
    """Count ``msg`` in the "warnings" stats bucket (does not log)."""
    counters["warnings"][msg] += 1
    # Debugging hook: export BREAK=<exact message> to trip this assert and
    # stop a debugger at the call site that produced the warning.
    assert msg != os.environ.get("BREAK", False)
|
|
|
|
|
|
|
|
|
class KeyErrorMsg:
    """Message wrapper whose repr equals its str.

    ``KeyError`` repr()s its argument when printed, which would mangle a
    multi-line augmented message; wrapping the text in this class makes it
    render as plain text instead.
    """

    def __init__(self, value):
        self.value = value

    def __str__(self) -> str:
        return str(self.value)

    def __repr__(self) -> str:
        return str(self)
|
|
|
|
|
def augment_exc_message(exc: Exception, msg: str = "\n", export: bool = False) -> None:
    """Append troubleshooting context to ``exc``'s message, in place.

    Adds, when available: the user-code stack, replay-record instructions,
    a verbosity hint, minifier pointers from an inner exception, and a note
    on how to suppress the error. Mutates ``exc.args`` (and sets
    ``exc.innermost_user_frame_summary``) rather than returning anything.

    Args:
        exc: exception to annotate; may carry ``real_stack``,
            ``record_filename``, or ``inner_exception`` attributes.
        msg: initial text placed before the generated sections.
        export: when True, skip the suppress_errors suggestion.
    """
    # Fix: use the module-level `format_list` (imported at the top of this
    # file) instead of re-importing the traceback module locally.
    # Expose the innermost user frame so callers can report a precise
    # user-code location.
    exc.innermost_user_frame_summary = None  # type: ignore[attr-defined]

    real_stack = get_real_stack(exc)
    if real_stack is not None and len(real_stack) > 0:
        exc.innermost_user_frame_summary = real_stack[-1]  # type: ignore[attr-defined]
        msg += f"\nfrom user code:\n {''.join(format_list(real_stack))}"

    if config.replay_record_enabled and hasattr(exc, "record_filename"):
        msg += f"\nLast frame execution written to {exc.record_filename}. To run only this frame while debugging, run\
 torch._dynamo.replay('{exc.record_filename}').\n"

    if not config.verbose and hasattr(exc, "real_stack"):
        msg += '\nSet TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information\n'

    if hasattr(exc, "inner_exception") and hasattr(
        exc.inner_exception, "minifier_path"
    ):
        if hasattr(exc.inner_exception, "buck_command"):
            msg += (
                f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
                f"this buck command to find the smallest traced graph "
                f"which reproduces this error: {exc.inner_exception.buck_command}\n"
            )
        else:
            msg += (
                f"\nMinifier script written to {exc.inner_exception.minifier_path}. Run "
                "this script to find the smallest traced graph which reproduces this error.\n"
            )

    if not config.suppress_errors and not export:
        msg += (
            "\n\n"
            "You can suppress this exception and fall back to eager by setting:\n"
            " import torch._dynamo\n"
            " torch._dynamo.config.suppress_errors = True\n"
        )

    old_msg = "" if len(exc.args) == 0 else str(exc.args[0])

    if isinstance(exc, KeyError):
        # KeyError repr()s its argument when printed; wrap so the augmented
        # message renders as plain text.
        exc.args = (KeyErrorMsg(old_msg + msg),) + exc.args[1:]
    else:
        new_msg = old_msg + msg
        exc.args = (new_msg,) + exc.args[1:]
|
|
|
|
|
def get_real_stack(exc: Exception, frame=None) -> Optional[StackSummary]:
    """Return the user-code stack attached to ``exc``, or ``None``.

    ``exc.real_stack`` is set at raise time (see :class:`Unsupported`). When
    ``frame`` is given, frames extracted from the *current* call stack are
    filtered and prepended, so the report also shows how execution reached
    the failure.
    """
    real_stack = getattr(exc, "real_stack", None)
    if real_stack is None:
        return None

    stack_above_dynamo = []
    if frame is not None:
        # NOTE(review): `frame` is used only as a flag — extract_stack() is
        # deliberately called with no argument, capturing the stack at the
        # point of THIS call, and filter_stack() strips the dynamo-internal
        # frames from it. Presumably the passed frame cannot be fed to
        # traceback directly; confirm against callers before changing this
        # to extract_stack(frame).
        stack_above_dynamo = filter_stack(extract_stack())

    return cast(StackSummary, stack_above_dynamo + real_stack)
|
|
|
|
|
|
|
def filter_stack(stack): |
|
user_stack = [] |
|
for frame in stack: |
|
if "convert_frame" in frame.filename: |
|
break |
|
if "eval_frame" in frame.filename or "torch._dynamo.optimize(" in frame.line: |
|
continue |
|
user_stack.append(frame) |
|
|
|
return user_stack |
|
|
|
|
|
def format_error_msg_verbose(
    exc: Exception, code, record_filename=None, frame=None
) -> str:
    """Build the verbose "WON'T CONVERT" report for ``exc``.

    The report contains the currently-handled exception's traceback followed
    by the user-code stack (when ``exc`` carries one). ``record_filename`` is
    accepted for interface compatibility but not used here.
    """
    banner = "=" * 10
    parts = [
        f"WON'T CONVERT {code.co_name} {code.co_filename} line {code.co_firstlineno}\n",
        f"{banner} TorchDynamo Stack Trace {banner}\n",
        format_exc(),
    ]
    user_stack = get_real_stack(exc, frame)
    if user_stack is not None:
        parts.append(
            f"\n{banner} The above exception occurred "
            f"while processing the following code {banner}\n\n"
        )
        parts.append("".join(format_list(user_stack)))
        parts.append("\n")
        parts.append(banner)

    return "".join(parts)
|
|
|
|
|
def format_error_msg(exc: Exception, code, record_filename=None, frame=None) -> str:
    """Return the "WON'T CONVERT" message for ``exc``.

    Delegates to :func:`format_error_msg_verbose` when ``config.verbose`` is
    set; otherwise produces a short one-header summary with the current
    traceback. ``record_filename`` and ``frame`` are forwarded to the verbose
    formatter only.

    Fix: dropped the dead ``msg = os.linesep * 2`` initializer — both
    branches unconditionally reassigned it.
    """
    if config.verbose:
        return format_error_msg_verbose(exc, code, record_filename, frame)
    return f"WON'T CONVERT {code.co_name} {code.co_filename}\
 line {code.co_firstlineno} \ndue to: \n{format_exc()}"
|
|