File size: 4,394 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
# mypy: allow-untyped-defs
import logging
import warnings
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
import torch.export
import torch.export._trace
from torch._utils_internal import log_export_usage
# Module-level logger; every diagnostic emitted by this module goes through it.
log = logging.getLogger(name=__name__)

# The only public entry point of this module.
__all__ = ["report_exportability"]
def _generate_inputs_for_submodules(
model: torch.nn.Module,
target_submodules: Iterable[str],
args: Tuple[Any, ...],
kwargs: Optional[Dict[str, Any]] = None,
) -> Dict[str, Tuple[Any, Any]]:
"""
Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this
function doesn't work.
Args:
model: root model.
inputs: inputs to the root model.
target_submodules: submodules that we want to generate inputs for.
Returns:
A dict that maps from submodule name to its inputs.
"""
kwargs = kwargs or {}
handles = []
results = {}
submodule_to_names = {mod: name for name, mod in model.named_modules()}
def pre_forward(module, module_args, module_kwargs):
results[submodule_to_names[module]] = (module_args, module_kwargs)
try:
for name, mod in model.named_modules():
if name in target_submodules:
handles.append(
mod.register_forward_pre_hook(pre_forward, with_kwargs=True)
)
model(*args, **kwargs)
except Exception as e:
warnings.warn(
f"Failed to generate submodule inputs because of the following error:\n{e}"
)
finally:
for h in handles:
h.remove()
return results
def _short_exception_message(exc: BaseException) -> str:
    """
    Return the first line of ``repr(exc)``.

    ``repr`` of a built-in exception escapes newlines in its message (a backslash followed
    by ``n``), while a custom ``__repr__`` may contain real newlines, so truncate at
    whichever of the two appears first.
    """
    text = repr(exc)
    for separator in ("\n", "\\n"):
        text = text.split(separator, 1)[0]
    return text


def report_exportability(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    strict: bool = True,
    pre_dispatch: bool = False,
) -> Dict[str, Optional[Exception]]:
    """
    Report exportability issues for a module in one-shot.

    The root module is exported first with ``args``/``kwargs``. If that fails, each child
    is tried with the inputs it received during a forward pass of the root, recursing into
    the children of any module that fails to export. Modules for which no inputs were
    captured are not tried themselves (and get no report entry), but their children are.
    Submodules of a module that exports successfully are not visited.

    Args:
        mod: root module.
        args: args to the root module.
        kwargs: kwargs to the root module.
        strict: forwarded to ``torch.export._trace._export``.
        pre_dispatch: forwarded to ``torch.export._trace._export``.

    Returns:
        A dict that maps from submodule name to the exception that was raised when trying to export it.
        `None` means the module is exportable without issue.

        Sample output:
            {
                '': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
                'submod_1': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
                'submod_2': None
            }
    """
    log_export_usage(event="export.report_exportability")

    kwargs = kwargs or {}

    all_submod_names = {name for name, _ in mod.named_modules() if name != ""}
    submod_inputs = _generate_inputs_for_submodules(mod, all_submod_names, args, kwargs)

    report: Dict[str, Optional[Exception]] = {}

    def try_export(module, module_name, args, kwargs):
        # (None, None) means no inputs were captured for this module: skip it and
        # go straight to its children.
        if args is not None or kwargs is not None:
            try:
                torch.export._trace._export(
                    module,
                    args,
                    kwargs,
                    strict=strict,
                    pre_dispatch=pre_dispatch,
                )
            except Exception as e:
                log.warning(
                    "Failed exporting `%s` with exception: %s",
                    module_name,
                    _short_exception_message(e),
                )
                report[module_name] = e
            else:
                report[module_name] = None
                log.info("Successfully exported `%s`", module_name)
                # The whole module exports; its children need no separate check.
                return

        for name, submod in module.named_children():
            sub_module_name = name if module_name == "" else f"{module_name}.{name}"
            submod_args, submod_kwargs = submod_inputs.get(
                sub_module_name, (None, None)
            )
            try_export(submod, sub_module_name, submod_args, submod_kwargs)

    try_export(mod, "", args, kwargs)

    # Deduplicate while keeping first-seen (traversal) order so the log is deterministic.
    unique_issues = list(
        dict.fromkeys(
            _short_exception_message(exception)
            for exception in report.values()
            if exception is not None
        )
    )

    log.warning("Found %d export issues:", len(unique_issues))
    for issue in unique_issues:
        # Pass as an argument, not as the format string.
        log.warning("%s", issue)

    return report
|