import dataclasses
import functools
import inspect
import logging
import re
import time
import warnings
from contextlib import contextmanager, nullcontext
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

import torch
import torch._dynamo
import torch.fx

import torch.utils._pytree as pytree
from torch._dynamo.exc import UserError, UserErrorType
from torch._export.non_strict_utils import (
    make_constraints,
    make_fake_inputs,
    make_fake_params_buffers,
)
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
    _AddRuntimeAssertionsForInlineConstraintsPass,
)
from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass
from torch._export.passes.lift_constants_pass import (
    ConstantAttrMap,
    lift_constants_pass,
    rewrite_script_object_meta,
)
from torch._export.wrappers import _wrap_submodules
from torch._functorch.aot_autograd import aot_export_module
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch._utils_internal import log_export_usage
from torch.export.exported_program import OutputKind
from torch.fx.experimental.symbolic_shapes import (
    ConstraintViolationError,
    free_unbacked_symbols,
    GuardOnDataDependentSymNode,
    ShapeEnv,
)
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from torch.utils._sympy.value_ranges import ValueRangeError

from ._safeguard import AutogradStateOpsFailSafeguard
from .dynamic_shapes import _process_constraints, Constraint
from .exported_program import (
    _disable_prexisiting_fake_mode,
    ExportedProgram,
    InputKind,
    ModuleCallEntry,
    ModuleCallSignature,
)
from .graph_signature import (
    _sig_to_specs,
    ArgumentSpec,
    ConstantArgument,
    CustomObjArgument,
    ExportGraphSignature,
    SymIntArgument,
    TensorArgument,
)


log = logging.getLogger(__name__)


@dataclasses.dataclass
class ExportDynamoConfig:
    """
    Manage Export-specific configurations of Dynamo.
    """

    allow_rnn: bool = True
    reorderable_logging_functions: Set[Callable] = dataclasses.field(
        default_factory=set
    )


DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig()
DEFAULT_EXPORT_DYNAMO_CONFIG.reorderable_logging_functions = {
    logging.critical,
    logging.debug,
    logging.error,
    logging.exception,
    logging.info,
    logging.log,
    logging.warning,
    print,
    warnings.warn,
}


@contextmanager
def _ignore_backend_decomps():
    orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False)
    orig_nnpack_flag = torch.backends.nnpack.set_flags(False)
    try:
        yield
    finally:
        torch.backends.mkldnn.set_flags(*orig_mkldnn_flag)
        torch.backends.nnpack.set_flags(*orig_nnpack_flag)


def _convert_input_to_fake(gm, args, kwargs):
    params_buffers = _get_params_buffers(gm)
    fake_inps: List[torch.Tensor] = []
    for node in gm.graph.nodes:
        if node.op == "placeholder" and "val" in node.meta:
            fake_val = node.meta["val"]
            if fake_val is not None and isinstance(fake_val, torch.Tensor):
                fake_inps.append(fake_val)

    if detected_fake_mode := detect_fake_mode(fake_inps):
        fake_mode = detected_fake_mode
    else:
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())

    if len(args) == 0 and len(kwargs) == 0:
        return (), {}, params_buffers, fake_mode

    count = 0
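
    # Dynamo records the fake "val" of each tensor placeholder in graph order,
    # which matches the pytree flattening order of the positional args, so a
    # running counter can pair them up; kwargs are re-fakified below instead.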
    def convert_to_fake(x):
        nonlocal count
        val = fake_inps[count]
        count += 1
        return val

    fake_args = pytree.tree_map_only(torch.Tensor, convert_to_fake, args)
    # TODO properly use the cached fake tensor
    fake_kwargs = pytree.tree_map_only(torch.Tensor, fake_mode.from_tensor, kwargs)
    fake_params_buffers = pytree.tree_map_only(
        torch.Tensor,
        functools.partial(fake_mode.from_tensor, static_shapes=True),
        params_buffers,
    )
    return fake_args, fake_kwargs, fake_params_buffers, fake_mode


def _replace_param_buffer_names(param_buffer_table, sig):
    for spec in sig.input_specs:
        if spec.kind in (
            InputKind.PARAMETER,
            InputKind.BUFFER,
        ):
            spec.target = param_buffer_table[spec.target]
    for spec in sig.output_specs:
        if spec.kind in (
            OutputKind.BUFFER_MUTATION,
            OutputKind.GRADIENT_TO_PARAMETER,
        ):
            spec.target = param_buffer_table[spec.target]


def _convert_to_positional_args(orig_arg_names, args, kwargs):
    assert len(orig_arg_names) == len(args) + len(kwargs), (
        f"Total number of arg names is expected to be {len(orig_arg_names)} "
        f"but got {len(args)} positional args, {len(kwargs)} kwargs."
    )
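    # Reorder the kwargs to follow the declared order of the original
    # signature's trailing parameters, then splice them after the positional
    # args.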
    reordered_kwargs = [kwargs[kw_name] for kw_name in orig_arg_names[len(args) :]]
    return (
        *args,
        *reordered_kwargs,
    )


def _normalize_nn_module_stack(gm_torch_level, root_cls):
    # Append a root module to every nn_module_stack.
    root = "L['self']"
    root_key = re.sub(r"[^a-zA-Z0-9]", "_", root)
    for gm in gm_torch_level.modules():
        if not isinstance(gm, torch.fx.GraphModule):
            continue
        for node in gm.graph.nodes:
            if node.op in ["placeholder", "output"]:
                continue
            add_root = True
            if nn_module_stack := node.meta.get("nn_module_stack", {}):
                path, ty = next(iter(nn_module_stack.values()))
                # After deserialization the class `ty` might not exist anymore,
                # in which case it is represented as a string.
                if inspect.isclass(ty) and issubclass(ty, torch.nn.Module):
                    # TODO Figure out why sometimes we have root sometimes we don't.
                    if path == root and ty is root_cls:
                        add_root = False
                else:
                    assert isinstance(ty, str)
            if add_root:
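
                # normalize_path replays a Dynamo source string such as
                # "L['self'].blocks[0].attn" against a recording proxy object,
                # producing a dotted FQN like "blocks.0.attn".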
                def normalize_path(path):
                    try:
                        parts = []

                        class Path:
                            def __getattr__(self, name):
                                parts.append(name)
                                return self

                            def __getitem__(self, idx):
                                parts.append(str(idx))
                                return self

                        eval(path, {"L": {"self": Path()}})
                        return ".".join(parts)
                    except Exception:  # TODO(zhxchen17) Remove this.
                        return path

                nn_module_stack = {root_key: (root, root_cls), **nn_module_stack}
                node.meta["nn_module_stack"] = {
                    key: (normalize_path(path), ty)
                    for key, (path, ty) in nn_module_stack.items()
                }


def _get_param_buffer_mapping(
    original_module: torch.nn.Module,
    traced_module: torch.nn.Module,
) -> Dict[str, str]:
    """
    Returns a mapping of parameter/buffer names from the traced module to the
    original module. This helps restore the FQNs of a traced module's
    parameters/buffers to those of the original module.
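
    For example, Dynamo typically flattens an attribute like ``self.linear.weight``
    into a mangled name such as ``L__self___linear_weight`` (illustrative), in
    which case the returned table would contain
    ``{"L__self___linear_weight": "linear.weight"}``.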
""" | |
param_lookup: Dict[int, List[str]] = {} | |
buffer_lookup: Dict[int, List[str]] = {} | |
for name, param in original_module.named_parameters(remove_duplicate=False): | |
param_lookup.setdefault(id(param), []).append(name) | |
for name, buffer in original_module.named_buffers(remove_duplicate=False): | |
buffer_lookup.setdefault(id(buffer), []).append(name) | |
param_buffer_table: Dict[str, str] = {} | |
for dynamo_name, dynamo_param in traced_module.named_parameters( | |
remove_duplicate=False | |
): | |
assert dynamo_name not in param_buffer_table | |
if id(dynamo_param) in param_lookup: | |
param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)].pop() | |
for dynamo_name, dynamo_buffer in traced_module.named_buffers( | |
remove_duplicate=False | |
): | |
assert dynamo_name not in param_buffer_table | |
if id(dynamo_buffer) in buffer_lookup: | |
param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)].pop() | |
return param_buffer_table | |


def _remap_constants(
    orig_constant_attrs: ConstantAttrMap,
    graph_signature: ExportGraphSignature,
    constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],
) -> None:
    """Rewrite the graph signature and constants table to use the FQN from the original module."""
    remap_table: Dict[str, str] = {}
    for name, value in constants.items():
        if value in orig_constant_attrs:
            remap_table[name] = orig_constant_attrs[value]

    for spec in graph_signature.input_specs:
        if spec.kind in (
            InputKind.CONSTANT_TENSOR,
            InputKind.CUSTOM_OBJ,
        ):
            orig_target = spec.target
            assert orig_target is not None
            spec.target = remap_table.get(orig_target, orig_target)

            constant = constants[orig_target]
            del constants[orig_target]
            constants[spec.target] = constant


def _restore_state_dict(
    original_module: torch.nn.Module, traced_module: torch.fx.GraphModule
) -> None:
    """
    Restores the state dict of the traced module to that of the original module.
    """
    param_buffer_table = _get_param_buffer_mapping(original_module, traced_module)
    # Since the graph module is flattened (no module hierarchy), we
    # need to normalize the FQNs by replacing "." with "_". If we
    # don't, it will try to save the weight to a submodule which no
    # longer exists.
    for name, fqn in param_buffer_table.items():
        param_buffer_table[name] = fqn.replace(".", "_")

    # Replace state dict attr names with the fqn
    for name, fqn in param_buffer_table.items():
        if not hasattr(traced_module, name):
            continue

        attr = getattr(traced_module, name)
        if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter):
            traced_module.register_buffer(fqn, attr)
        else:
            setattr(traced_module, fqn, attr)
        delattr(traced_module, name)

    # Replace graph getattr nodes with the correct name
    for node in traced_module.graph.nodes:
        if node.op == "get_attr":
            attr_name = node.target
            if attr_name in param_buffer_table:
                node.target = param_buffer_table[attr_name]

    traced_module.recompile()


def _export_to_torch_ir(
    f: Callable,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    constraints: Optional[List[Constraint]] = None,
    *,
    preserve_module_call_signature: Tuple[str, ...] = (),
    disable_constraint_solver: bool = False,
    restore_fqn: bool = True,
    _log_export_usage: bool = True,
) -> torch.fx.GraphModule:
    """
    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside and produces a torch.fx.GraphModule in torch IR.
    """
    if _log_export_usage:
        log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"})

    kwargs = kwargs or {}

    if not isinstance(args, tuple):
        raise UserError(
            UserErrorType.INVALID_INPUT,
            f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}",
        )

    with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
        try:
            module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
            with _wrap_submodules(
                f, preserve_module_call_signature, module_call_specs
            ), _ignore_backend_decomps():
                gm_torch_level, _ = torch._dynamo.export(
                    f,
                    constraints=constraints,  # type: ignore[arg-type]
                    assume_static_by_default=True,
                    tracing_mode="symbolic",
                    disable_constraint_solver=disable_constraint_solver,
                    _log_export_usage=_log_export_usage,
                )(
                    *args,
                    **kwargs,
                )
        except (ConstraintViolationError, ValueRangeError) as e:
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: TRY200
        except GuardOnDataDependentSymNode as e:
            raise UserError(  # noqa: TRY200
                UserErrorType.ANTI_PATTERN,
                f"Consider annotating your code using torch._constrain_as_*(). {str(e)}",
                case_name="constrain_as_size_example",
            )

    gm_torch_level.meta["module_call_specs"] = module_call_specs

    if isinstance(f, torch.nn.Module) and restore_fqn:
        _restore_state_dict(f, gm_torch_level)

    return gm_torch_level


def _gather_constant_attrs(m: torch.nn.Module) -> ConstantAttrMap:
    """Search the module hierarchy, gathering up all tensor and ScriptObject constants.

    Returns a dictionary mapping hash(value) to the name of the constant. We
    have to abuse `hash` here unfortunately, see: [ScriptObject hash].
    """
    constants = ConstantAttrMap()
    buffers_parameters = set(m.buffers())
    buffers_parameters.update(m.parameters())

    def inner(m: torch.nn.Module, prefix_atoms: List[str], constants):
        for k, v in m.__dict__.items():
            if isinstance(v, (torch.Tensor, torch.ScriptObject)):
                if v in buffers_parameters:
                    # filter out buffers and parameters, leaving only constants
                    continue

                fqn = ".".join(prefix_atoms + [k])
                if v in constants:
                    raise ValueError(
                        f"Duplicate reference to constant attribute found: '{constants[v]}' and '{fqn}'."
                    )

                constants[v] = fqn
        for k, v in m.named_children():
            inner(v, prefix_atoms + [k], constants)

    inner(m, [], constants)
    return constants


def _export_non_strict(
    mod: torch.nn.Module,
    fake_args,
    fake_kwargs,
    fake_params_buffers,
    constant_attrs: ConstantAttrMap,
    *,
    transform=lambda x: x,  # TODO(zhxchen17) Revisit if this is needed later.
    pre_dispatch=False,
):
    # [NOTE] If the user is exporting under training mode, we want to detect any
    # state change in the autograd global state and error out. If the user is
    # exporting under inference mode, we don't care.
    is_grad_enabled = torch._C.is_grad_enabled()
    grad_safe_guard = (
        AutogradStateOpsFailSafeguard() if is_grad_enabled else nullcontext()
    )

    @contextmanager
    def _compiling_state_context():
        old_value = torch.compiler._is_compiling_flag
        try:
            torch.compiler._is_compiling_flag = True
            yield
        finally:
            torch.compiler._is_compiling_flag = old_value

    # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode,
    # otherwise aot_export_module will error out because it sees a mix of fake_modes.
    # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about.
    with torch.nn.utils.stateless._reparametrize_module(
        mod, fake_params_buffers
    ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context():  # type: ignore[attr-defined]
        gm, graph_signature = transform(aot_export_module)(
            mod,
            fake_args,
            trace_joint=False,
            pre_dispatch=pre_dispatch,
            kwargs=fake_kwargs,
        )

    # TODO unfortunately preserving graph-level metadata is not
    # working well with aot_export. So we manually copy it.
    # (The node-level meta is addressed above.)
    if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
        gm.meta.update(mod.meta)

    if pre_dispatch:
        from torch._export.passes.replace_set_grad_with_hop_pass import (
            replace_set_grad_with_hop_pass,
        )

        gm = replace_set_grad_with_hop_pass(gm)

    # NOTE: aot_export adds symint metadata for placeholders with int values;
    # since these become specialized, we replace such metadata with the original values
    flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
    index = 0
    total_non_user_inputs = (
        len(graph_signature.parameters)
        + len(graph_signature.buffers)
        + len(graph_signature.input_tokens)
    )
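    # Placeholders are ordered as [params, buffers, input tokens, *user inputs],
    # so only indices at or past total_non_user_inputs correspond to entries of
    # flat_args.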
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if index >= total_non_user_inputs:
                user_arg = flat_args[index - total_non_user_inputs]
                if not isinstance(user_arg, torch.Tensor):
                    node.meta["val"] = user_arg
            index += 1

    is_joint = graph_signature.backward_signature is not None

    def make_argument_spec(node) -> ArgumentSpec:
        if isinstance(node, (int, bool, float, type(None))):
            # For const outputs we just directly return this
            return ConstantArgument(value=node)

        assert (
            "val" in node.meta
        ), f"{node} is not a constant or a node with a 'val' metadata field"
        val = node.meta["val"]
        if isinstance(val, FakeTensor):
            return TensorArgument(name=node.name)
        elif isinstance(val, torch.SymInt):
            return SymIntArgument(name=node.name)
        elif isinstance(val, torch.ScriptObject):
            return CustomObjArgument(
                name=node.name, class_fqn=val._type().qualified_name()  # type: ignore[attr-defined]
            )
        else:
            # TODO: this branch is likely wrong; all permissible ConstantArgument
            # types should have been handled already
            return ConstantArgument(value=val)

    input_specs, output_specs = _sig_to_specs(
        user_inputs=set(graph_signature.user_inputs),
        inputs_to_parameters=graph_signature.inputs_to_parameters,  # type: ignore[arg-type]
        inputs_to_buffers=graph_signature.inputs_to_buffers,  # type: ignore[arg-type]
        user_outputs=set(graph_signature.user_outputs),  # type: ignore[arg-type]
        buffer_mutations=graph_signature.buffers_to_mutate,  # type: ignore[arg-type]
        user_input_mutations=graph_signature.user_inputs_to_mutate,  # type: ignore[arg-type]
        grad_params=graph_signature.backward_signature.gradients_to_parameters if is_joint else {},  # type: ignore[arg-type, union-attr]
        grad_user_inputs=graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {},  # type: ignore[arg-type, union-attr]
        loss_output=graph_signature.backward_signature.loss_output if is_joint else None,  # type: ignore[arg-type, union-attr]
        inputs=[
            make_argument_spec(node)
            for node in gm.graph.nodes
            if node.op == "placeholder"
        ],
        outputs=[
            make_argument_spec(node)
            for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args)
        ],
        input_tokens=graph_signature.input_tokens,
        output_tokens=graph_signature.output_tokens,
    )
    export_graph_signature = ExportGraphSignature(
        input_specs=input_specs, output_specs=output_specs
    )

    constants = rewrite_script_object_meta(gm)
    constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs))

    @dataclasses.dataclass
    class _ExportedProgramNonStrict:
        gm: torch.fx.GraphModule
        sig: ExportGraphSignature
        constants: Dict[str, Union[torch.Tensor, torch._C.ScriptObject]]

    return _ExportedProgramNonStrict(
        gm,
        export_graph_signature,
        constants,
    )


def _get_params_buffers(mod: torch.nn.Module) -> Dict[str, torch.Tensor]:
    params_buffers: Dict[str, torch.Tensor] = {}
    for name, param in mod.named_parameters(remove_duplicate=False):
        params_buffers[name] = param

    for name, buffer in mod.named_buffers(remove_duplicate=False):
        params_buffers[name] = buffer
    return params_buffers


def _rewrite_dynamo_tensor_constants(
    orig_mod_buffers: Set[torch.Tensor],
    traced_mod_buffers: Dict[str, torch.Tensor],
    graph_signature: ExportGraphSignature,
    constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],
):
    """Dynamo erroneously marks tensor attributes on modules as buffers.

    Rewrite them to be tensor constants.
    """
    for spec in graph_signature.input_specs:
        if spec.kind == InputKind.BUFFER:
            assert spec.target is not None
            value = traced_mod_buffers[spec.target]
            if value not in orig_mod_buffers:
                # This was a tensor constant erroneously marked as a buffer.
                # Convert it into a constant in the graph signature, and add its
                # value to the constants table.
                spec.kind = InputKind.CONSTANT_TENSOR
                constants[spec.target] = value


def _rewrite_non_persistent_buffers(
    orig_mod: torch.nn.Module,
    graph_signature: ExportGraphSignature,
    constants: Dict[str, Union[torch.Tensor, torch.ScriptObject]],
):
    """Dynamo erroneously drops the persistent flag on buffers.

    Rewrite non-persistent buffers to reflect the original module.
    """
    state_dict = orig_mod.state_dict()
    for spec in graph_signature.input_specs:
        if spec.kind == InputKind.BUFFER:
            assert spec.target is not None
            if spec.target not in state_dict:
                assert spec.target not in constants
                spec.persistent = False
                constants[spec.target] = orig_mod.get_buffer(spec.target)


def get_ep_stats(ep: ExportedProgram) -> Dict[str, Any]:
    op_count = 0
    op_set = set()
    for m in ep.graph_module.modules():
        if not isinstance(m, torch.fx.GraphModule):
            continue
        for node in m.graph.nodes:
            if node.op != "call_function":
                continue
            op_count += 1
            assert hasattr(node.target, "__module__")
            assert hasattr(node.target, "__name__")
            op_set.add(f"{node.target.__module__}.{node.target.__name__}")
    return {"op_count": op_count, "op_set": op_set}


_EXPORT_FLAGS: Optional[Set[str]] = None


def _log_export_wrapper(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        global _EXPORT_FLAGS
        try:
            start = time.time()
            ep = fn(*args, **kwargs)
            end = time.time()
            log_export_usage(
                event="export.time",
                metrics=end - start,
                flags=_EXPORT_FLAGS,
                **get_ep_stats(ep),
            )
        except Exception as e:
            t = type(e)
            error_type = t.__module__ + "." + t.__qualname__
            log_export_usage(
                event="export.error",
                type=error_type,
                message=str(e),
                flags=_EXPORT_FLAGS,
            )
            raise e
        finally:
            _EXPORT_FLAGS = None

        return ep

    return wrapper


@_log_export_wrapper
@_disable_prexisiting_fake_mode
def _export(
    mod: torch.nn.Module,
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any], List[Any]]] = None,
    *,
    strict: bool = True,
    preserve_module_call_signature: Tuple[str, ...] = (),
    pre_dispatch: bool = False,
) -> ExportedProgram:
    """
    Traces either an nn.Module's forward function or just a callable with PyTorch
    operations inside and produces an ExportedProgram.

    Args:
        mod: the `nn.Module` to trace.

        args: example positional inputs.

        kwargs: optional example keyword inputs.

        dynamic_shapes:
         An optional argument where the type should either be:
         1) a dict from argument names of ``mod.forward`` to their dynamic shape specifications,
         2) a tuple that specifies dynamic shape specifications for each input in original order.
         If you are specifying dynamism on keyword args, you will need to pass them in the order that
         is defined in the original function signature.

         The dynamic shape of a tensor argument can be specified as either
         (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is
         not required to include static dimension indices in this dict, but when they are,
         they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None,
         where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions
         are denoted by None. Arguments that are dicts or tuples / lists of tensors are
         recursively specified by using mappings or sequences of contained specifications.

        preserve_module_call_signature: A list of submodule paths for which the original
            calling conventions are preserved as metadata.
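
    Example (a minimal sketch; the toy module and shapes are illustrative)::

        class M(torch.nn.Module):
            def forward(self, x):
                return x.sum(dim=1)

        batch = torch.export.Dim("batch")
        ep = _export(M(), (torch.randn(4, 3),), dynamic_shapes={"x": {0: batch}})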
    Returns:
        An ExportedProgram containing the traced method.
    """
    from .dynamic_shapes import _process_dynamic_shapes

    global _EXPORT_FLAGS
    flags = set()
    flags.add("strict" if strict else "non_strict")
    flags.add("pre_dispatch" if pre_dispatch else "aot_dispatch")
    log_export_usage(event="export.enter", flags=flags)
    _EXPORT_FLAGS = flags

    constraints = _process_dynamic_shapes(mod, args, kwargs, dynamic_shapes) or []

    kwargs = kwargs or {}

    constant_attrs = _gather_constant_attrs(mod)

    flat_args, orig_in_spec = pytree.tree_flatten((args, kwargs))

    if not strict:
        out_spec = None

        module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
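
        # In non-strict mode the user module is wrapped under an `_export_root`
        # attribute (see Wrapper below). strip_root removes that prefix from
        # FQNs, and fixup_key re-prefixes nn_module_stack keys with "L__self__"
        # so the metadata matches what strict (Dynamo) tracing would produce.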
        def strip_root(x):
            if isinstance(x, str) and x.startswith("_export_root"):
                stripped = x[len("_export_root") :]
                return stripped[1:] if stripped.startswith(".") else stripped
            return x

        def fixup_key(x):
            return "L__self__" + strip_root(x)
        def _tuplify_outputs(aot_export):
            def _aot_export_non_strict(mod, args, kwargs=None, **flags):
                kwargs = kwargs or {}

                class Wrapper(torch.nn.Module):
                    def __init__(self, mod):
                        super().__init__()
                        self._export_root = mod

                    def forward(self, *args, **kwargs):
                        nonlocal out_spec
                        if isinstance(self._export_root, torch.fx.GraphModule):
                            with torch.fx.traceback.preserve_node_meta():
                                tree_out = torch.fx.Interpreter(self._export_root).run(
                                    *args, **kwargs
                                )
                        else:
                            tree_out = self._export_root(*args, **kwargs)
                        flat_outs, out_spec = pytree.tree_flatten(tree_out)
                        return tuple(flat_outs)

                wrapped_mod = Wrapper(mod)
                # Patch the `_export_root` prefix into the preserved call
                # signatures so that the wrapper module correctly populates the
                # in/out specs.
                new_preserved_call_signatures = [
                    "_export_root." + i for i in preserve_module_call_signature
                ]
                with _wrap_submodules(
                    wrapped_mod, new_preserved_call_signatures, module_call_specs
                ):
                    gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)

                sig.parameters = pytree.tree_map(strip_root, sig.parameters)
                sig.buffers = pytree.tree_map(strip_root, sig.buffers)
                sig.inputs_to_buffers = pytree.tree_map(
                    strip_root, sig.inputs_to_buffers
                )
                sig.inputs_to_parameters = pytree.tree_map(
                    strip_root, sig.inputs_to_parameters
                )
                sig.buffers_to_mutate = pytree.tree_map(
                    strip_root, sig.buffers_to_mutate
                )
                for node in gm.graph.nodes:
                    if "nn_module_stack" in node.meta:
                        nn_module_stack = node.meta["nn_module_stack"]
                        node.meta["nn_module_stack"] = {
                            fixup_key(key): val
                            for key, val in pytree.tree_map(
                                strip_root, nn_module_stack
                            ).items()
                        }

                return gm, sig

            return _aot_export_non_strict

        (
            fake_mode,
            fake_args,
            fake_kwargs,
            equalities_inputs,
            original_signature,
        ) = make_fake_inputs(mod, args, kwargs, constraints)

        fake_params_buffers = make_fake_params_buffers(
            fake_mode, _get_params_buffers(mod)
        )
        with fake_mode:
            ep_non_strict = _export_non_strict(
                mod,
                fake_args,
                fake_kwargs,
                fake_params_buffers,
                constant_attrs,
                pre_dispatch=pre_dispatch,
                transform=_tuplify_outputs,
            )
        try:
            range_constraints = make_constraints(
                fake_mode,
                equalities_inputs,
                original_signature,
                ep_non_strict.gm,
            )
        except (ConstraintViolationError, ValueRangeError) as e:
            raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e))  # noqa: TRY200

        assert out_spec is not None

        gm = ep_non_strict.gm

        module_call_signatures = {
            strip_root(fqn): ModuleCallSignature(inputs=[], outputs=[], **specs)
            for fqn, specs in module_call_specs.items()
        }

        if len(preserve_module_call_signature) > 0:
            for node in gm.graph.nodes:
                if node.target == torch.ops.higher_order._export_tracepoint:
                    if "path" in node.kwargs:
                        path = strip_root(node.kwargs["path"])
                        with gm.graph.inserting_before(node):
                            new_node = gm.graph.create_node(
                                "call_function",
                                torch.ops.higher_order._export_tracepoint,
                                args=node.args,
                                kwargs={
                                    "path": path,
                                    "kind": node.kwargs["kind"],
                                },
                            )
                            node.replace_all_uses_with(new_node)
                            gm.graph.erase_node(node)

            res = CollectTracepointsPass(module_call_signatures, ep_non_strict.sig)(gm)
            assert res is not None
            gm = res.graph_module

        _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants)

        return ExportedProgram(
            root=gm,
            graph=gm.graph,
            graph_signature=ep_non_strict.sig,
            state_dict=mod.state_dict(keep_vars=True),
            range_constraints=range_constraints,
            module_call_graph=[
                ModuleCallEntry(
                    "",
                    ModuleCallSignature(
                        inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=out_spec
                    ),
                )
            ]
            + [
                ModuleCallEntry(fqn, sig)
                for fqn, sig in module_call_signatures.items()
            ],
            example_inputs=(args, kwargs),
            constants=ep_non_strict.constants,
        )

    gm_torch_level = _export_to_torch_ir(
        mod,
        args,
        kwargs,
        constraints,
        preserve_module_call_signature=preserve_module_call_signature,
        restore_fqn=False,  # don't need to restore because we will do it later
        _log_export_usage=False,
    )

    # We detect the fake_mode by looking at gm_torch_level's placeholders; this is the fake_mode created in dynamo.
    (
        fake_args,
        fake_kwargs,
        fake_params_buffers,
        dynamo_fake_mode,
    ) = _convert_input_to_fake(gm_torch_level, args, kwargs)

    # First, we want to pass through the graph to try populating
    # val field for getattr if there is anything missing.
    # This can happen when quantization adds extra params and forgets
    # to update "val"
    for node in gm_torch_level.graph.nodes:
        if node.op == "get_attr" and "val" not in node.meta:
            attr = getattr(gm_torch_level, node.target)
            # Checks if it is not a HigherOrderOp branch or a module
            if not isinstance(attr, torch.nn.Module):
                assert (
                    dynamo_fake_mode is not None
                ), "Cannot find dynamo_fake_mode. This could be due to the exported graph module having no placeholders."
                node.meta["val"] = dynamo_fake_mode.from_tensor(
                    attr, static_shapes=True
                )

    # When aot_export lifts the params, we lose the nn_module_stack
    # and source_fn from the param nodes as they are treated as fresh inputs.
    # Therefore, we manually extract them before calling into aot_export.
    params_buffers_to_node_meta = {}
    for node in gm_torch_level.graph.nodes:
        target = node.target
        meta = node.meta
        if node.op == "call_module":
            submodule = getattr(gm_torch_level, target)
            if isinstance(submodule, torch.nn.Module):
                for name, _ in submodule.named_parameters(
                    recurse=True, remove_duplicate=False
                ):
                    params_buffers_to_node_meta[target + "." + name] = meta

                for name, _ in submodule.named_buffers(
                    recurse=True, remove_duplicate=False
                ):
                    params_buffers_to_node_meta[target + "." + name] = meta

        if node.op == "get_attr":
            submodule = getattr(gm_torch_level, target)
            if not isinstance(submodule, torch.fx.GraphModule):
                params_buffers_to_node_meta[target] = meta

        # If the call_function uses param as input, we also need to update params' meta
        # with this call_function node's meta.
        # This is basically the same flow as torch.fx.traceback.preserve_node_meta()
        if node.op == "call_function" and not isinstance(
            node.target, torch._ops.HigherOrderOperator
        ):
            for arg in node._input_nodes:
                if arg.op == "get_attr":
                    for entry in torch.fx.proxy._COPY_META_FIELDS:
                        if entry in meta:
                            params_buffers_to_node_meta[arg.target][entry] = meta[entry]

    # Fix the graph output signature to be tuple if scalar
    out_spec = orig_out_spec = gm_torch_level._out_spec
    assert out_spec is not None
    # aot_export expects the return type to always be a tuple.
    if out_spec.type not in (list, tuple):
        out_spec = pytree.TreeSpec(tuple, None, [out_spec])

    orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args  # type: ignore[attr-defined]

    gm_torch_level.graph._codegen = _PyTreeCodeGen(
        _PyTreeInfo(
            orig_arg_names,
            gm_torch_level._in_spec,
            out_spec,
        )
    )
    gm_torch_level.recompile()

    _normalize_nn_module_stack(gm_torch_level, type(mod))

    # NOTE: graph module expects only positional args
    ep_non_strict = _export_non_strict(
        gm_torch_level,
        _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs),
        {},
        fake_params_buffers,
        constant_attrs,
        pre_dispatch=pre_dispatch,
    )

    gm = ep_non_strict.gm
    export_graph_signature = ep_non_strict.sig
    constants = ep_non_strict.constants

    # After aot_export, set the param/buffer metadata back into placeholders.
    # Technically, users can still construct this data from param names
    # without relying on this metadata
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if node.target in export_graph_signature.inputs_to_parameters:
                param_name = export_graph_signature.inputs_to_parameters[node.target]
                if param_name in params_buffers_to_node_meta:
                    for k, v in params_buffers_to_node_meta[param_name].items():
                        node.meta[k] = v
            if node.target in export_graph_signature.inputs_to_buffers:
                buffer_name = export_graph_signature.inputs_to_buffers[node.target]
                if buffer_name in params_buffers_to_node_meta:
                    for k, v in params_buffers_to_node_meta[buffer_name].items():
                        node.meta[k] = v

    # The unbacked symint symbols are updated in aot_export
    # so we serialize them here instead of inside dynamo
    gm.meta["inline_constraints"] = {
        k: v
        for k, v in dynamo_fake_mode.shape_env.var_to_range.items()
        if free_unbacked_symbols(k)
    }
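
    # Everything before the first USER_INPUT spec is a lifted input (parameters,
    # buffers, lifted constants); count those so constraint processing can line
    # the remaining placeholders up with flat_args.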
    num_lifted = next(
        (
            i
            for i, s in enumerate(export_graph_signature.input_specs)
            if s.kind == InputKind.USER_INPUT
        ),
        len(export_graph_signature.input_specs),
    )
    range_constraints = _process_constraints(
        dynamo_fake_mode,
        gm,
        num_lifted,
        flat_args,
    )

    # Do some cleanups on the graph module to restore the state dict to the
    # expected form. Each of these steps should probably get fixed upstream.
    # 1. Remove tensor constants that were added as buffers.
    _rewrite_dynamo_tensor_constants(
        orig_mod_buffers=set(mod.buffers()),
        traced_mod_buffers=dict(gm_torch_level.named_buffers()),
        graph_signature=ep_non_strict.sig,
        constants=ep_non_strict.constants,
    )
    # 2. Restore FQN of param/buffers
    param_buffer_table: Dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level)
    _replace_param_buffer_names(param_buffer_table, export_graph_signature)

    # 3. Remove non-persistent buffers from the graph signature
    _rewrite_non_persistent_buffers(mod, ep_non_strict.sig, ep_non_strict.constants)

    # 4. Rewrite constants to have the same FQN as the original module.
    _remap_constants(constant_attrs, export_graph_signature, constants)

    module_call_signatures = {
        fqn: ModuleCallSignature(inputs=[], outputs=[], **specs)
        for fqn, specs in gm_torch_level.meta["module_call_specs"].items()
    }

    if len(preserve_module_call_signature) > 0:
        res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm)
        assert res is not None
        gm = res.graph_module

    assert orig_out_spec is not None
    exported_program = ExportedProgram(
        root=gm,
        graph=gm.graph,
        graph_signature=export_graph_signature,
        state_dict=mod.state_dict(keep_vars=True),
        range_constraints=range_constraints,
        module_call_graph=[
            ModuleCallEntry(
                "",
                ModuleCallSignature(
                    inputs=[], outputs=[], in_spec=orig_in_spec, out_spec=orig_out_spec
                ),
            )
        ]
        + [ModuleCallEntry(fqn, sig) for fqn, sig in module_call_signatures.items()],
        example_inputs=(args, kwargs),
        constants=constants,
    )
    log.debug("Exported program from AOTAutograd:\n%s", exported_program)

    if len(range_constraints) > 0:
        exported_program = exported_program._transform_do_not_use(
            _AddRuntimeAssertionsForInlineConstraintsPass(range_constraints)
        )

    return exported_program