Spaces:
Running
Running
"""ONNX exporter exceptions.""" | |
from __future__ import annotations | |
import textwrap | |
from typing import Optional | |
from torch import _C | |
from torch.onnx import _constants | |
from torch.onnx._internal import diagnostics | |
# Public names of this module, i.e. what a star import of it exposes.
# Every entry is an exception or warning class defined in this file.
__all__ = [
    "OnnxExporterError", "OnnxExporterWarning", "CheckerError",
    "SymbolicValueError", "UnsupportedOperatorError",
]
class OnnxExporterWarning(UserWarning):
    """Common base class for every warning emitted by the ONNX exporter.

    Because it derives from ``UserWarning``, all exporter warnings can be
    filtered as one group by passing this class as the ``category`` to the
    standard :mod:`warnings` filters.
    """
class OnnxExporterError(RuntimeError):
    """Root of the ONNX exporter's exception hierarchy.

    It subclasses ``RuntimeError``, so existing ``except RuntimeError``
    handlers also catch exporter errors.
    """
class CheckerError(OnnxExporterError):
    """Signals that the ONNX checker rejected the model as invalid."""
class UnsupportedOperatorError(OnnxExporterError):
    """Raised when an operator is unsupported by the exporter.

    The message is produced by one of three diagnostic rules:
    ``operator_supported_in_newer_opset_version`` when ``supported_version``
    is given, ``missing_standard_symbolic_function`` for operators in the
    ``aten::``/``prim::``/``quantized::`` namespaces, and
    ``missing_custom_symbolic_function`` otherwise. The formatted text is
    passed to ``diagnostics.diagnose`` at ERROR level and then used as the
    exception message.

    Args:
        name: Qualified operator name; its namespace prefix (e.g. ``aten::``)
            selects the rule when ``supported_version`` is ``None``.
        version: Opset version that was requested.
        supported_version: Opset version in which the operator is supported
            (per the rule name, a newer one), or ``None`` if there is none.
    """

    def __init__(self, name: str, version: int, supported_version: Optional[int]):
        rule: diagnostics.infra.Rule
        if supported_version is not None:
            # The operator exists, just not at the requested opset.
            rule = diagnostics.rules.operator_supported_in_newer_opset_version
            format_args = (name, version, supported_version)
        elif name.startswith(("aten::", "prim::", "quantized::")):
            # Built-in namespace: the message is also given the PyTorch
            # GitHub issues URL.
            rule = diagnostics.rules.missing_standard_symbolic_function
            format_args = (name, version, _constants.PYTORCH_GITHUB_ISSUES_URL)
        else:
            # Any other namespace is handled by the custom-symbolic rule.
            rule = diagnostics.rules.missing_custom_symbolic_function
            format_args = (name,)

        msg = rule.format_message(*format_args)
        diagnostics.diagnose(rule, diagnostics.levels.ERROR, msg)
        super().__init__(msg)
class SymbolicValueError(OnnxExporterError):
    """Errors around TorchScript values and nodes.

    The exception message is ``msg`` followed by context about ``value``:
    - its string form and type;
    - the kind of ``value.node()``;
    - that node's source range, when non-empty;
    - a listing of the node's inputs and outputs.

    If an ``AttributeError`` occurs while building the listing, a short note
    is appended instead of the listing.

    Args:
        msg: Description of what went wrong.
        value: The TorchScript value the error is attributed to.
    """

    def __init__(self, msg: str, value: _C.Value):
        def render(items) -> str:
            # One "#index: item (type '...')" line per entry, or a placeholder
            # when there are no entries.
            joined = "\n".join(
                f" #{idx}: {item} (type '{item.type()}')"
                for idx, item in enumerate(items)
            )
            return joined or " Empty"

        text = (
            f"{msg} [Caused by the value '{value}' (type '{value.type()}') in the "
            f"TorchScript graph. The containing node has kind '{value.node().kind()}'.] "
        )

        source_range = value.node().sourceRange()
        if source_range:
            text += f"\n (node defined in {source_range})"

        # Separator before the inputs/outputs section; it is kept even when
        # the listing below cannot be built.
        text += "\n\n"
        try:
            io_listing = (
                "Inputs:\n"
                + render(value.node().inputs())
                + "\n"
                + "Outputs:\n"
                + render(value.node().outputs())
            )
        except AttributeError:
            # A missing attribute only drops the listing; constructing the
            # error itself still succeeds.
            text += (
                " Failed to obtain its input and output for debugging. "
                "Please refer to the TorchScript graph for debugging information."
            )
        else:
            text += textwrap.indent(io_listing, " ")

        super().__init__(text)