|
|
|
import abc |
|
import copy |
|
import operator |
|
from collections import defaultdict |
|
from copy import deepcopy |
|
from enum import Enum |
|
from typing import Any, cast, Dict, List, Optional, Set, Tuple, Union |
|
|
|
import torch |
|
import torch.fx._pytree as fx_pytree |
|
import torch.utils._pytree as pytree |
|
from torch._library.fake_class_registry import FakeScriptObject |
|
from torch.export._tree_utils import reorder_kwargs |
|
from torch.export.exported_program import ( |
|
ConstantArgument, |
|
ExportedProgram, |
|
InputKind, |
|
ModuleCallSignature, |
|
SymIntArgument, |
|
TensorArgument, |
|
) |
|
from torch.fx._symbolic_trace import is_fx_tracing |
|
from torch.utils._pytree import GetAttrKey, SequenceKey |
|
|
|
__all__ = ["InterpreterModule", "UnflattenedModule", "unflatten", "FlatArgsAdapter"] |
|
|
|
|
|
class _AttrKind(Enum): |
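    """The kinds of attributes that ``_assign_attr`` can set on a module."""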
|
PARAMETER = "parameter" |
|
BUFFER = "buffer" |
|
CONSTANT = "constant" |
|
|
|
|
|
|
|
|
|
def _assign_attr( |
|
from_obj: Union[torch.Tensor, torch.ScriptObject], |
|
to_module: torch.nn.Module, |
|
target: str, |
|
attr_kind: _AttrKind, |
|
persistent: bool = True, |
|
): |
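    """Set ``from_obj`` on ``to_module`` at the dotted path ``target``,
    creating empty intermediate Modules as needed, and registering it as a
    parameter, buffer, or plain attribute according to ``attr_kind``.
    """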
|
*prefix, field = target.split(".") |
|
for item in prefix: |
|
t = getattr(to_module, item, None) |
|
|
|
if t is None: |
|
t = torch.nn.Module() |
|
setattr(to_module, item, t) |
|
to_module = t |
|
|
|
if attr_kind == _AttrKind.PARAMETER: |
|
assert isinstance(from_obj, torch.nn.Parameter) |
|
to_module.register_parameter(field, from_obj) |
|
elif attr_kind == _AttrKind.BUFFER: |
|
assert isinstance(from_obj, torch.Tensor) |
|
to_module.register_buffer(field, from_obj, persistent=persistent) |
|
elif attr_kind == _AttrKind.CONSTANT: |
|
assert not isinstance( |
|
from_obj, FakeScriptObject |
|
), "FakeScriptObject should only exist during tracing." |
|
assert isinstance( |
|
from_obj, |
|
( |
|
torch.Tensor, |
|
torch.ScriptObject, |
|
), |
|
) |
|
setattr(to_module, field, from_obj) |
|
|
|
|
|
class InterpreterModule(torch.nn.Module): |
|
"""A module that uses torch.fx.Interpreter to execute instead of the usual |
|
codegen that GraphModule uses. This provides better stack trace information |
|
and makes it easier to debug execution. |
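
    Example (a minimal sketch; ``graph`` stands in for a fully constructed
    ``torch.fx.Graph``, and ``finalize()`` must be called before running)::

        mod = InterpreterModule(graph)
        mod.finalize()
        out = mod(*example_inputs)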
|
""" |
|
|
|
def __init__( |
|
self, |
|
graph: torch.fx.Graph, |
|
): |
|
super().__init__() |
|
self.graph = graph |
|
        self.graph.owning_module = self
        # Populated by finalize(); forward() asserts that it has been set.
        self.graph_module: Optional[torch.fx.GraphModule] = None
|
|
|
def forward(self, *args, **kwargs): |
|
assert self.graph_module is not None, "Didn't finalize this InterpreterModule" |
|
if torch.compiler.is_dynamo_compiling(): |
|
|
|
|
|
            # Dynamo cannot trace through torch.fx.Interpreter, so fall back to
            # GraphModule codegen in this instance.
            return self.graph_module(*args, **kwargs)
|
else: |
|
if kwargs: |
|
|
|
|
|
|
|
|
|
                # Handle **kwargs. FX only natively supports positional
                # arguments (through placeholders), so to pass in kwargs we
                # match the names of the trailing placeholders against the
                # keys in the kwargs dict.
                arg_list = list(args)
|
kwarg_names = self.arg_names[len(arg_list) :] |
|
for kwarg_name in kwarg_names: |
|
if kwarg_name in kwargs: |
|
arg_list.append(kwargs[kwarg_name]) |
|
|
|
|
|
|
|
|
|
                # Assert that the kwargs passed in exactly match the positional
                # arguments specified by the GraphModule. This should be
                # guaranteed by the unflattening process.
                assert len(kwarg_names) == len(kwargs)
|
assert len(arg_list) == len(self.arg_names) |
|
args = tuple(arg_list) |
|
|
|
return torch.fx.Interpreter(self, graph=self.graph).run( |
|
*args, enable_io_processing=False |
|
) |
|
|
|
def finalize(self): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # We need to "finalize" because GraphModule populates its own state_dict
        # based on the get_attrs observed in the graph. So we need to fully
        # construct the graph and call _sink_params before generating this
        # GraphModule.

        # Set `graph_module` directly on __dict__ to avoid it getting
        # registered as a submodule.
        self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph)
|
self.graph.lint() |
|
|
|
|
|
        # Cache the placeholder names so that kwargs can be matched to
        # positional arguments in forward().
        self.arg_names = []
|
for node in self.graph.nodes: |
|
if node.op == "placeholder": |
|
self.arg_names.append(node.target) |
|
|
|
|
|
class FlatArgsAdapter(abc.ABC): |
|
""" |
|
    Adapts input arguments with ``input_spec`` to align with ``target_spec``.
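
    Example (a hypothetical adapter that keeps only the first
    ``target_spec.num_leaves`` args; a real adapter should inspect both specs)::

        class TruncatingAdapter(FlatArgsAdapter):
            def adapt(self, target_spec, input_spec, input_args):
                return input_args[: target_spec.num_leaves]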
|
""" |
|
|
|
@abc.abstractmethod |
|
def adapt( |
|
self, |
|
target_spec: pytree.TreeSpec, |
|
input_spec: pytree.TreeSpec, |
|
input_args: List[Any], |
|
) -> List[Any]: |
|
"""NOTE: This adapter may mutate given ``input_args_with_path``.""" |
|
... |
|
|
|
|
|
class UnflattenedModule(torch.nn.Module): |
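    """A module produced by :func:`unflatten`: it executes the same flat graph
    as the :class:`ExportedProgram` it was built from, but carries the original
    eager module hierarchy reconstructed as submodules.
    """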
|
def __init__( |
|
self, |
|
export_module: ExportedProgram, |
|
flat_args_adapter: Optional[FlatArgsAdapter] = None, |
|
): |
|
super().__init__() |
|
if export_module.graph_signature.backward_signature is not None: |
|
raise ValueError("Unflattening on JointExportModule NYI") |
|
|
|
fqn_list = [entry.fqn for entry in export_module.module_call_graph] |
|
assert fqn_list[0] == "" |
|
export_graph = deepcopy(export_module.graph) |
|
self.graph_signature = deepcopy(export_module.graph_signature) |
|
self.graph = torch.fx.Graph() |
|
self.module_call_graph = deepcopy(export_module.module_call_graph) |
|
self.flat_args_adapter = flat_args_adapter |
|
|
|
self.adapted = False |
|
|
|
_inplace_buffer_mutations(export_graph, self.graph_signature) |
|
_outline_submodules(export_graph, self) |
|
|
|
self.range_constraints = export_module.range_constraints |
|
self.equality_constraints: List = [] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # Aliasing and unused parameter/buffer handling: in strict-mode export,
        # dynamo deduplicates aliased tensors and ignores unused ones, so the
        # graph signature may not cover every entry in the state dict. We
        # therefore assign *all* state-dict tensors as module attributes here
        # (deduplicating by object id so shared storage stays shared), and
        # later let _sink_params keep only the used ones in the forward pass.
        state_dict = export_module.state_dict
|
assigned_params: Set[str] = set() |
|
id_to_param: Dict[int, torch.nn.Parameter] = {} |
|
for name in self.graph_signature.parameters: |
|
param = state_dict[name] |
|
if id(param) not in id_to_param: |
|
id_to_param[id(param)] = torch.nn.Parameter(param.clone()) |
|
|
|
_assign_attr( |
|
id_to_param[id(param)], |
|
self, |
|
name, |
|
attr_kind=_AttrKind.PARAMETER, |
|
) |
|
assigned_params.add(name) |
|
|
|
non_persistent_buffers = set(self.graph_signature.non_persistent_buffers) |
|
assigned_buffers: Set[str] = set() |
|
        # Deduplicate buffers by object id to preserve aliasing; the value is
        # a (buffer, persistent) pair.
        id_to_buffer: Dict[int, Tuple[torch.Tensor, bool]] = {}
|
for name in self.graph_signature.buffers: |
|
if name in non_persistent_buffers: |
|
persistent = False |
|
buffer = export_module.constants[name] |
|
else: |
|
persistent = True |
|
buffer = state_dict[name] |
|
|
|
if id(buffer) not in id_to_buffer: |
|
id_to_buffer[id(buffer)] = (buffer.clone(), persistent) |
|
|
|
_assign_attr( |
|
id_to_buffer[id(buffer)][0], |
|
self, |
|
name, |
|
attr_kind=_AttrKind.BUFFER, |
|
persistent=persistent, |
|
) |
|
assigned_buffers.add(name) |
|
|
|
|
|
|
|
        # Restore aliased/unused params and buffers: these appear in the state
        # dict but not in the graph signature.
        for name, tensor in state_dict.items():
|
if name in assigned_params or name in assigned_buffers: |
|
continue |
|
|
|
            is_buffer = id(tensor) in id_to_buffer or not isinstance(
                tensor, torch.nn.Parameter
            )
|
|
|
if is_buffer: |
|
                if id(tensor) not in id_to_buffer:
                    id_to_buffer[id(tensor)] = (tensor, True)
|
_assign_attr( |
|
id_to_buffer[id(tensor)][0], |
|
self, |
|
name, |
|
attr_kind=_AttrKind.BUFFER, |
|
persistent=True, |
|
) |
|
else: |
|
if id(tensor) not in id_to_param: |
|
id_to_param[id(tensor)] = tensor |
|
_assign_attr( |
|
id_to_param[id(tensor)], |
|
self, |
|
name, |
|
attr_kind=_AttrKind.PARAMETER, |
|
) |
|
|
|
|
|
        # Deduplicate constants by object id so aliased constants are cloned
        # only once.
        id_to_const: Dict[int, Union[torch.Tensor, torch._C.ScriptObject]] = {}
|
for fqn, constant in export_module.constants.items(): |
|
if id(constant) not in id_to_const: |
|
if isinstance(constant, torch.Tensor): |
|
constant = constant.clone() |
|
id_to_const[id(constant)] = constant |
|
_constant = id_to_const[id(constant)] |
|
_assign_attr( |
|
_constant, |
|
self, |
|
fqn, |
|
attr_kind=_AttrKind.CONSTANT, |
|
) |
|
|
|
|
|
|
|
        # Handle tensors that are aliased by multiple targets:
        # object id -> list of (node_name, target_name) pairs.
        consts_map: Dict[int, List[Tuple[str, str]]] = defaultdict(list)
|
consts_targets: Set[str] = set() |
|
|
|
def add_to_consts_map(obj_id, node_name, target_name): |
|
name_list = consts_map[obj_id] |
|
name_list.append((node_name, target_name)) |
|
|
|
added_params_buffers: Set[str] = set() |
|
for s in self.graph_signature.input_specs: |
|
if s.kind == InputKind.PARAMETER or ( |
|
s.kind == InputKind.BUFFER and s.persistent |
|
): |
|
assert hasattr(s.arg, "name") |
|
assert isinstance(s.target, str) |
|
add_to_consts_map( |
|
id(export_module.state_dict[s.target]), s.arg.name, s.target |
|
) |
|
consts_targets.add(s.target) |
|
added_params_buffers.add(s.target) |
|
elif ( |
|
(s.kind == InputKind.BUFFER and not s.persistent) |
|
or s.kind == InputKind.CONSTANT_TENSOR |
|
or s.kind == InputKind.CUSTOM_OBJ |
|
): |
|
assert hasattr(s.arg, "name") |
|
assert isinstance(s.target, str) |
|
add_to_consts_map( |
|
id(export_module.constants[s.target]), s.arg.name, s.target |
|
) |
|
consts_targets.add(s.target) |
|
|
|
|
|
        # Add constants that are aliased and do not appear in the graph
        # signature.
        for const_name, const in export_module.constants.items():
|
if const_name not in consts_targets: |
|
assert ( |
|
id(const) in consts_map |
|
), "Constants should be either aliased or appear in graph signature" |
|
ph_name, _ = consts_map[id(const)][0] |
|
add_to_consts_map(id(const), ph_name, const_name) |
|
                added_params_buffers.add(const_name)
|
|
|
|
|
        # Add params/buffers that are aliased and do not appear in the graph
        # signature.
        for fqn, tensor in export_module.state_dict.items():
|
if fqn not in added_params_buffers: |
|
if id(tensor) not in consts_map: |
|
|
|
|
|
|
|
                    # Completely unused (not aliased to any graph input); it
                    # was already assigned as an attribute above, so there is
                    # nothing to map.
                    continue
|
ph_name, _ = consts_map[id(tensor)][0] |
|
add_to_consts_map(id(tensor), ph_name, fqn) |
|
|
|
|
|
inputs_to_state: Dict[str, List[str]] = {} |
|
for node_target in consts_map.values(): |
|
targets = [t[1] for t in node_target] |
|
for n, _ in node_target: |
|
inputs_to_state[n] = targets |
|
|
|
_sink_params(self, inputs_to_state, []) |
|
|
|
|
|
        # Helper to check that every input that should have been sunk into a
        # submodule's state has indeed been removed from its placeholders.
        def check_module_inputs(module, scope):
|
if hasattr(module, "graph"): |
|
for node in module.graph.nodes: |
|
|
|
|
|
|
|
|
|
|
|
|
|
                    # _sink_params() should have turned placeholders for
                    # in-scope attributes into get_attr nodes, so no
                    # placeholder should still map to state within this scope.
                    if (
|
node.op == "placeholder" |
|
and node.name in inputs_to_state |
|
and any( |
|
fqn.split(".")[: len(scope)] == scope |
|
for fqn in inputs_to_state[node.name] |
|
) |
|
): |
|
raise AssertionError( |
|
f"{node.name} was not sunk into the module {scope} which has the graph: {module.graph}" |
|
) |
|
|
|
            for name, submod in module.named_children():
                # Pass a fresh scope list so sibling recursions don't pollute
                # each other.
                check_module_inputs(submod, scope + [name])
|
|
|
|
|
        # Recursively check that all placeholder inputs have been processed.
        check_module_inputs(self, [])
|
|
|
|
|
|
|
|
|
        # Cache so we don't have to compute this every time.
        # NOTE: this needs to be kept in sync with the placeholders in
        # self.graph, but currently we have no way to guarantee that.
        self.input_placeholders = [
|
node for node in self.graph.nodes if node.op == "placeholder" |
|
] |
|
self.check_input_constraints = True |
|
|
|
fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)} |
|
|
|
for name, _ in self.named_modules(remove_duplicate=False): |
|
if name not in fqn_order: |
|
fqn_order[name] = len(fqn_order) |
|
_reorder_submodules(self, fqn_order) |
|
assert [fqn for fqn, _ in self.named_modules(remove_duplicate=False)] == list( |
|
fqn_order.keys() |
|
) |
|
|
|
def _print_graph(self): |
|
for fqn, mod in self.named_modules(): |
|
print(fqn + ":") |
|
if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph): |
|
print(mod.graph) |
|
|
|
def forward(self, *args, **kwargs): |
|
signature = self.module_call_graph[0].signature |
|
|
|
reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec) |
|
|
|
flat_args_with_path, in_spec = pytree.tree_flatten_with_path( |
|
(args, reordered_kwargs) |
|
) |
|
flat_args = [x[1] for x in flat_args_with_path] |
|
if is_fx_tracing(): |
|
return_val = torch.fx.Interpreter(self, graph=self.graph).run( |
|
*flat_args, enable_io_processing=False |
|
) |
|
|
|
            # For a scalar return value, fx.Graph wraps it in a tuple.
            if isinstance(return_val, tuple) and len(return_val) == 1:
|
return return_val[0] |
|
return return_val |
|
|
|
if in_spec != signature.in_spec: |
|
if not self.adapted: |
|
                print(
                    "Input treespec does not match the exported module's:\n"
                    f"Input treespec: {in_spec}.\n"
                    f"Exported module treespec: {signature.in_spec}"
                )
|
if self.flat_args_adapter is None: |
|
raise TypeError( |
|
"There is no flat args adapter sepcified. " |
|
"Are you sure you are calling this with the right arguments? " |
|
) |
|
else: |
|
if not self.adapted: |
|
print("Adapting flat arg to match exported module's treespec") |
|
flat_args = self.flat_args_adapter.adapt( |
|
target_spec=signature.in_spec, |
|
input_spec=in_spec, |
|
input_args=flat_args, |
|
) |
|
self.adapted = True |
|
if len(flat_args) != signature.in_spec.num_leaves: |
|
raise TypeError( |
|
f"Flat args adaption failed, number of args mismatch " |
|
f"Adatped: {len(flat_args)} \n" |
|
f"Exported module: {signature.in_spec.num_leaves}" |
|
) |
|
|
|
if self.check_input_constraints: |
|
|
|
|
|
            # Import here to avoid a circular dependency.
            from torch._export.utils import _check_input_constraints_for_graph
|
|
|
if self.adapted is True: |
|
|
|
|
|
|
|
                # The FlatArgsAdapter returns a list of flat args without
                # keypaths, so create dummy keypaths to associate with each arg.
                new_flat_args_with_path = [
|
((SequenceKey(idx=0), GetAttrKey(name="<unknown location>")), arg) |
|
for arg in flat_args |
|
] |
|
else: |
|
new_flat_args_with_path = flat_args_with_path |
|
|
|
_check_input_constraints_for_graph( |
|
self.input_placeholders, new_flat_args_with_path, self.range_constraints |
|
) |
|
tree_out = torch.fx.Interpreter(self, graph=self.graph).run( |
|
*flat_args, enable_io_processing=False |
|
) |
|
return pytree.tree_unflatten(tree_out, signature.out_spec) |
|
|
|
|
|
def unflatten( |
|
module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None |
|
) -> UnflattenedModule: |
|
"""Unflatten an ExportedProgram, producing a module with the same module |
|
hierarchy as the original eager module. This can be useful if you are trying |
|
to use :mod:`torch.export` with another system that expects a module |
|
    hierarchy instead of the flat graph that :mod:`torch.export` usually produces.
|
|
|
.. note:: The args/kwargs of unflattened modules will not necessarily match |
|
the eager module, so doing a module swap (e.g. :code:`self.submod = |
|
new_mod`) will not necessarily work. If you need to swap a module out, you |
|
need to set the :code:`preserve_module_call_signature` parameter of |
|
:func:`torch.export.export`. |
|
|
|
Args: |
|
module (ExportedProgram): The ExportedProgram to unflatten. |
|
flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's. |
|
|
|
Returns: |
|
An instance of :class:`UnflattenedModule`, which has the same module |
|
hierarchy as the original eager module pre-export. |
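
    Example::

        # A minimal sketch; ``M`` stands in for your eager nn.Module.
        ep = torch.export.export(M(), (torch.randn(3),))
        unflat = unflatten(ep)
        out = unflat(torch.randn(3))  # same behavior, original hierarchy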
|
""" |
|
return UnflattenedModule(module, flat_args_adapter) |
|
|
|
|
|
def _inplace_buffer_mutations(graph: torch.fx.Graph, graph_signature) -> None: |
|
"""Transform buffer mutations from their functionalized form into a copy_ |
|
node in the graph. |
|
|
|
    Functionalization represents buffer mutation by passing the buffer as both
    an input and an output. So, for example, the eager code:
|
def forward(self, x): |
|
self.buffer += x |
|
return x * x |
|
|
|
Will become a graph that looks like: |
|
def forward(self, buffer, x): |
|
mutated_buffer = aten.add(buffer, x) |
|
mul = aten.mul(x, x) |
|
return (mutated_buffer, mul) |
|
|
|
    We want to rewrite this in place so that it looks like the original eager code:
|
def forward(self, buffer, x): |
|
mutated_buffer = aten.add(buffer, x) |
|
buffer.copy_(mutated_buffer) |
|
mul = aten.mul(x, x) |
|
return (mul,) |
|
""" |
|
output_node = next(iter(reversed(graph.nodes))) |
|
assert output_node.op == "output" and len(output_node.args) == 1 |
|
return_args = output_node.args[0] |
|
|
|
mutation_node_to_buffer = graph_signature.buffers_to_mutate |
|
mutations = return_args[: len(mutation_node_to_buffer)] |
|
buffers_to_inputs = {v: k for k, v in graph_signature.inputs_to_buffers.items()} |
|
input_name_to_node = { |
|
node.name: node for node in graph.nodes if node.op == "placeholder" |
|
} |
|
|
|
for mutation in mutations: |
|
buffer_name = mutation_node_to_buffer[mutation.name] |
|
input_name = buffers_to_inputs[buffer_name] |
|
input_node = input_name_to_node[input_name] |
|
|
|
with graph.inserting_after(mutation): |
|
new_node = graph.create_node( |
|
"call_function", torch.ops.aten.copy_, (input_node, mutation) |
|
) |
|
for k, v in mutation.meta.items(): |
|
new_node.meta[k] = v |
|
|
|
mutation.replace_all_uses_with(new_node, lambda x: x is not new_node) |
|
|
|
|
|
|
|
|
|
    # Remove the mutated buffers from the graph outputs, since we no longer
    # need to thread them through; the mutation now happens in place via the
    # copy_ nodes inserted above.
    user_outputs = tuple(return_args[len(mutation_node_to_buffer) :])
    output_node.args = (user_outputs,)
|
|
|
|
|
def _is_prefix(candidate, target): |
|
"""Check whether `candidate` is a prefix of `target`.""" |
|
return len(candidate) < len(target) and target[: len(candidate)] == candidate |
|
|
|
|
|
def _compute_accessor(parent_fqn: str, child_fqn: str) -> str: |
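    """Compute the dotted accessor from ``parent_fqn`` down to ``child_fqn``,
    e.g. ``_compute_accessor("a.b", "a.b.c.d") == "c.d"``.
    """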
|
if parent_fqn == "": |
|
|
|
        # Handle the root module correctly.
        return child_fqn
|
|
|
parent_split = parent_fqn.split(".") |
|
child_split = child_fqn.split(".") |
|
|
|
assert ( |
|
child_split[: len(parent_split)] == parent_split |
|
), f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'" |
|
return ".".join(child_split[len(parent_split) :]) |
|
|
|
|
|
def _verify_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module): |
|
def graph_dump(graph: torch.fx.Graph) -> str: |
|
ret = [] |
|
nodes_idx: Dict[int, int] = {} |
|
|
|
def arg_dump(arg) -> str: |
|
if isinstance(arg, torch.fx.Node): |
|
return "%" + str(nodes_idx[id(arg)]) |
|
return str(arg) |
|
|
|
for i, node in enumerate(graph.nodes): |
|
args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)] |
|
args_dump += [ |
|
f"{key}={value}" |
|
for key, value in pytree.tree_map(arg_dump, node.kwargs).items() |
|
] |
|
target = node.target if node.op == "call_function" else "" |
|
ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})") |
|
nodes_idx[id(node)] = i |
|
return "\n".join(ret) |
|
|
|
assert graph_dump(x.graph) == graph_dump(y.graph) |
|
|
|
|
|
def _add_spec(gm: torch.nn.Module, spec) -> str: |
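    """Store ``spec`` on ``gm`` under a fresh ``_spec_{i}`` attribute name and
    return that name.
    """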
|
i = 0 |
|
while hasattr(gm, f"_spec_{i}"): |
|
i += 1 |
|
name = f"_spec_{i}" |
|
setattr(gm, name, spec) |
|
return name |
|
|
|
|
|
def _generate_flatten(gm: torch.nn.Module, node, spec) -> torch.fx.Node: |
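    """Emit a ``fx_pytree.tree_flatten_spec(node, spec)`` call into ``gm``'s
    graph, with ``spec`` stored as a module attribute.
    """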
|
name = _add_spec(gm, spec) |
|
spec_node = gm.graph.get_attr(name) |
|
return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node)) |
|
|
|
|
|
def _generate_unflatten(gm: torch.nn.Module, nodes, spec) -> torch.fx.Node: |
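    """Emit a ``pytree.tree_unflatten(nodes, spec)`` call into ``gm``'s graph,
    with ``spec`` stored as a module attribute.
    """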
|
name = _add_spec(gm, spec) |
|
spec_node = gm.graph.get_attr(name) |
|
return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node)) |
|
|
|
|
|
def _get_submodule(mod: torch.nn.Module, target: str): |
|
*prefix, field = target.split(".") |
|
|
|
for item in prefix: |
|
submod = getattr(mod, item, None) |
|
|
|
if submod is None: |
|
return None |
|
|
|
if not isinstance(submod, torch.nn.Module): |
|
return None |
|
|
|
mod = submod |
|
|
|
return getattr(mod, field, None) |
|
|
|
|
|
def _add_submodule(mod: torch.nn.Module, target: str, module_to_add: torch.nn.Module): |
|
*prefix, field = target.split(".") |
|
|
|
for item in prefix: |
|
submod = getattr(mod, item, None) |
|
|
|
if submod is None: |
|
submod = torch.nn.Module() |
|
setattr(mod, item, submod) |
|
|
|
if not isinstance(submod, torch.nn.Module): |
|
return False |
|
|
|
mod = submod |
|
|
|
    mod.add_module(field, module_to_add)
    return True
|
|
|
|
|
class _ModuleFrame: |
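    """One stack frame of the submodule outliner: copies nodes from the flat
    graph into this frame's module graph, recursing into a child frame whenever
    ``nn_module_stack`` metadata indicates that a new submodule has started.
    """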
|
def __init__( |
|
self, |
|
flat_graph, |
|
nodes, |
|
seen_nodes, |
|
seen_modules, |
|
parent, |
|
module_stack, |
|
module_id, |
|
module_call_graph: Dict[str, ModuleCallSignature], |
|
module: Optional[torch.nn.Module] = None, |
|
): |
|
self.flat_graph = flat_graph |
|
self.nodes = nodes |
|
self.seen_nodes = seen_nodes |
|
self.seen_modules = seen_modules |
|
self.parent = parent |
|
self.module_stack = module_stack |
|
self.module_id = module_id |
|
|
|
self.module_call_graph = module_call_graph |
|
self.verbose = False |
|
|
|
self.fqn = self.module_stack[-1] |
|
if module is not None: |
|
self.module = module |
|
else: |
|
self.module = InterpreterModule(torch.fx.Graph()) |
|
if self.module_id in self.seen_modules: |
|
self.cached_graph_module = self.seen_modules[self.module_id] |
|
else: |
|
self.cached_graph_module = None |
|
self.seen_modules[self.module_id] = self.module |
|
|
|
self.graph = self.module.graph |
|
|
|
|
|
        # Mapping of nodes in the flat graph to nodes in this graph.
        self.node_map: Dict[torch.fx.Node, torch.fx.Node] = {}
|
self.node_to_placeholder = {} |
|
|
|
self.parent_call_module: Optional[torch.fx.Node] = None |
|
if parent is not None: |
|
accessor = _compute_accessor(parent.fqn, self.fqn) |
|
_add_submodule( |
|
parent.module, |
|
accessor, |
|
( |
|
self.module |
|
if self.cached_graph_module is None |
|
else self.cached_graph_module |
|
), |
|
) |
|
self.parent_call_module = parent.graph.call_module(accessor) |
|
|
|
signature = module_call_graph.get(self.fqn) |
|
if signature is not None and self.parent is not None: |
|
assert signature.in_spec.num_children == 2 |
|
args_spec = signature.in_spec.children_specs[0] |
|
kwargs_spec = signature.in_spec.children_specs[1] |
|
assert args_spec.context is None |
|
assert kwargs_spec.context is not None |
|
|
|
with self.graph.inserting_after(None): |
|
arg_nodes = [] |
|
for idx in range(args_spec.num_children): |
|
arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}")) |
|
kwarg_nodes = {} |
|
for name in kwargs_spec.context: |
|
kwarg_nodes[name] = self.graph.placeholder(name) |
|
flat_args = _generate_flatten( |
|
self.module, |
|
(tuple(arg_nodes), kwarg_nodes), |
|
signature.in_spec, |
|
) |
|
for idx, arg in enumerate(signature.inputs): |
|
flat_arg_node = self.graph.create_node( |
|
op="call_function", |
|
target=operator.getitem, |
|
args=(flat_args, idx), |
|
name=( |
|
arg.name |
|
if not isinstance(arg, ConstantArgument) |
|
else f"_constant_{idx}" |
|
), |
|
) |
|
if isinstance(arg, ConstantArgument): |
|
continue |
|
flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) |
|
self.node_to_placeholder[self.seen_nodes[arg.name]] = flat_arg_node |
|
|
|
with self.parent.graph.inserting_before(self.parent_call_module): |
|
input_nodes: List[Optional[torch.fx.Node]] = [] |
|
for input in signature.inputs: |
|
if isinstance(input, ConstantArgument) and input.value is None: |
|
input_nodes.append(None) |
|
else: |
|
assert isinstance(input, (TensorArgument, SymIntArgument)) |
|
input_nodes.append( |
|
self.parent.remap_input(self.seen_nodes[input.name]) |
|
) |
|
|
|
inputs_node = _generate_unflatten( |
|
self.parent.module, |
|
input_nodes, |
|
signature.in_spec, |
|
) |
|
|
|
args_node = self.parent.graph.call_function( |
|
operator.getitem, (inputs_node, 0) |
|
) |
|
kwargs_node = self.parent.graph.call_function( |
|
operator.getitem, (inputs_node, 1) |
|
) |
|
arg_nodes = [ |
|
self.parent.graph.call_function(operator.getitem, (args_node, i)) |
|
for i in range(args_spec.num_children) |
|
] |
|
kwarg_nodes = { |
|
k: self.parent.graph.call_function( |
|
operator.getitem, (kwargs_node, k) |
|
) |
|
for k in kwargs_spec.context |
|
} |
|
assert self.parent_call_module is not None |
|
self.parent_call_module.args = tuple(arg_nodes) |
|
self.parent_call_module.kwargs = kwarg_nodes |
|
|
|
def add_placeholder(self, x): |
|
assert self.fqn != "", f"Cannot add placeholder {x} to root module" |
|
assert x.graph is self.flat_graph |
|
|
|
with self.graph.inserting_before(None): |
|
placeholder_node = self.graph.placeholder(x.name, type_expr=x.type) |
|
|
|
|
|
        # Copy all meta fields, even if some are irrelevant for the
        # placeholder node.
        placeholder_node.meta = copy.copy(x.meta)
|
self.node_to_placeholder[x] = placeholder_node |
|
|
|
def remap_input(self, x): |
|
assert x.graph is self.flat_graph |
|
if x in self.node_map: |
|
return self.node_map[x] |
|
if x not in self.node_to_placeholder: |
|
self.add_placeholder(x) |
|
if self.parent_call_module is not None: |
|
|
|
|
|
                # Important to *prepend* the argument to match how placeholder
                # nodes are inserted.
                self.parent_call_module.insert_arg(0, self.parent.remap_input(x))
|
return self.node_to_placeholder[x] |
|
|
|
def finalize_outputs(self): |
|
orig_outputs = [] |
|
|
|
signature = self.module_call_graph.get(self.fqn) |
|
if signature is not None and self.parent is not None: |
|
for output in signature.outputs: |
|
if isinstance(output, (TensorArgument, SymIntArgument)): |
|
orig_outputs.append(self.seen_nodes[output.name]) |
|
else: |
|
raise RuntimeError( |
|
f"Unsupported data type for output node: {output}" |
|
) |
|
|
|
tree_out_node = _generate_unflatten( |
|
self.module, |
|
tuple( |
|
self.node_map[self.seen_nodes[output.name]] |
|
for output in orig_outputs |
|
), |
|
signature.out_spec, |
|
) |
|
parent_out: Optional[torch.fx.Node] = _generate_flatten( |
|
self.parent.module, self.parent_call_module, signature.out_spec |
|
) |
|
graph_outputs: Union[torch.fx.Node, List[torch.fx.Node]] = tree_out_node |
|
else: |
|
graph_outputs = [] |
|
|
|
            # Iterate through the nodes we have copied into self.graph.
            for orig_node in self.node_map.keys():
|
for user_node in orig_node.users: |
|
if user_node.name not in self.seen_nodes: |
|
|
|
                        # External user node; needs to be exposed as an output.
                        orig_outputs.append(orig_node)
|
graph_outputs.append(self.node_map[orig_node]) |
|
break |
|
|
|
parent_out = self.parent_call_module |
|
if len(graph_outputs) == 1: |
|
graph_outputs = graph_outputs[0] |
|
|
|
assert isinstance(graph_outputs, (list, torch.fx.Node)) |
|
|
|
self.graph.output(graph_outputs) |
|
|
|
|
|
        # Rewrite outputs in the parent module.
        if parent_out is None:
|
return |
|
|
|
parent_out.meta["val"] = ( |
|
graph_outputs.meta.get("val") |
|
if isinstance(graph_outputs, torch.fx.Node) |
|
else [o.meta.get("val") for o in graph_outputs] |
|
) |
|
|
|
if len(orig_outputs) == 1 and signature is None: |
|
self.parent.node_map[orig_outputs[0]] = parent_out |
|
else: |
|
for i, orig_output in enumerate(orig_outputs): |
|
|
|
                # Use Proxy to record the getitem access.
                proxy_out = torch.fx.Proxy(parent_out)[i].node
|
proxy_out.meta["val"] = orig_output.meta.get("val") |
|
self.parent.node_map[orig_output] = proxy_out |
|
|
|
if self.cached_graph_module is not None: |
|
_verify_graph_equivalence(self.cached_graph_module, self.module) |
|
|
|
def copy_node(self, node): |
|
self.print("copying", node.format_node()) |
|
self.node_map[node] = self.graph.node_copy(node, self.remap_input) |
|
self.seen_nodes[node.name] = node |
|
|
|
def run_outer(self): |
|
i = 0 |
|
for node in self.flat_graph.nodes: |
|
self.print(i, node.meta.get("nn_module_stack"), node.format_node()) |
|
i += 1 |
|
|
|
|
|
        # Copy all graph nodes that should exist in the root frame.
        node_idx: int = 0
|
node = self.nodes[node_idx] |
|
while node.op == "placeholder": |
|
self.copy_node(node) |
|
node_idx += 1 |
|
node = self.nodes[node_idx] |
|
|
|
self.run_from(node_idx) |
|
|
|
|
|
        # Copy graph outputs.
        for node in self.flat_graph.nodes:
|
if node.op == "output": |
|
self.copy_node(node) |
|
|
|
def print(self, *args, **kwargs): |
|
if self.verbose: |
|
print(*args, **kwargs) |
|
|
|
def run_from(self, node_idx): |
|
module_idx = 0 |
|
|
|
while node_idx < len(self.nodes): |
|
node = self.nodes[node_idx] |
|
assert node.op != "placeholder" |
|
|
|
self.print() |
|
self.print("STEP", node_idx, node.format_node()) |
|
self.print(self.module_stack) |
|
if node.op == "output": |
|
if len(self.module_stack) == 1: |
|
|
|
|
|
|
|
                    # The output node of the original graph is handled
                    # specially by the outermost stack frame (in run_outer),
                    # so skip finalization here.
                    return node_idx
|
|
|
|
|
                # We've reached the end of the graph; wrap up all the existing
                # stack frames.
                self.finalize_outputs()
|
return node_idx |
|
|
|
if len(node.meta.get("nn_module_stack", {})) == 0: |
|
raise RuntimeError(f"Unable to find nn_module_stack for node {node}") |
|
|
|
nn_module_stack = node.meta["nn_module_stack"] |
|
from torch._export.passes._node_metadata_hook import ( |
|
_EMPTY_NN_MODULE_STACK_KEY, |
|
) |
|
|
|
if ( |
|
len(nn_module_stack) == 1 |
|
and _EMPTY_NN_MODULE_STACK_KEY in nn_module_stack |
|
): |
|
|
|
                # Empty case from the node_metadata_hook.
                node_module_stack = self.module_stack
|
else: |
|
node_module_stack = [ |
|
path for path, ty in node.meta["nn_module_stack"].values() |
|
] |
|
|
|
if node_module_stack[: len(self.module_stack)] != self.module_stack: |
|
|
|
|
|
|
|
|
|
|
|
                # The current module has finished executing and this node is
                # the beginning of a new module: finalize this frame and
                # return without incrementing the node counter.
                self.finalize_outputs()
|
self.print("outlining", self.fqn) |
|
self.print(self.graph) |
|
return node_idx |
|
|
|
assert node_module_stack is not None |
|
|
|
if _is_prefix(self.module_stack, node_module_stack): |
|
|
|
|
|
                # The current node represents the execution of a new module.
                next_module = node_module_stack[len(self.module_stack)]
|
self.print("Creating new stack frame for", next_module) |
|
|
|
|
|
                # Run a nested version of the module outliner from the current
                # node counter; once it completes, continue from that point.
                node_idx = _ModuleFrame(
|
self.flat_graph, |
|
self.nodes, |
|
self.seen_nodes, |
|
self.seen_modules, |
|
self, |
|
self.module_stack + [next_module], |
|
list(node.meta["nn_module_stack"].keys())[len(self.module_stack)], |
|
self.module_call_graph, |
|
).run_from(node_idx) |
|
module_idx += 1 |
|
continue |
|
|
|
|
|
|
|
            # The only remaining possibility is that we are in the right stack
            # frame: copy the node into this frame's graph and advance.
            assert node_module_stack == self.module_stack
|
self.copy_node(node) |
|
node_idx += 1 |
|
|
|
|
|
def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule): |
|
seen_nodes: Dict[str, torch.fx.Node] = {} |
|
seen_modules: Dict[int, torch.nn.Module] = {} |
|
_ModuleFrame( |
|
orig_graph, |
|
tuple(orig_graph.nodes), |
|
seen_nodes, |
|
seen_modules, |
|
None, |
|
[""], |
|
"", |
|
{ |
|
entry.fqn: entry.signature |
|
for entry in root_module.module_call_graph |
|
if entry.signature |
|
}, |
|
module=root_module, |
|
).run_outer() |
|
|
|
|
|
def _reorder_submodules( |
|
parent: torch.nn.Module, fqn_order: Dict[str, int], prefix: str = "" |
|
): |
|
|
|
    # At the root, first create any submodules from fqn_order that are
    # missing, so the reordering below can assume they all exist.
    if prefix == "":
|
for fqn in list(fqn_order.keys())[1:]: |
|
if _get_submodule(parent, fqn) is None: |
|
_add_submodule(parent, fqn, torch.nn.Module()) |
|
|
|
children = [] |
|
for name, child in list(parent._modules.items()): |
|
if child is None: |
|
continue |
|
fqn = prefix + name |
|
_reorder_submodules(child, fqn_order, prefix=fqn + ".") |
|
delattr(parent, name) |
|
children.append((fqn_order[fqn], name, child)) |
|
children.sort(key=operator.itemgetter(0)) |
|
for _, name, child in children: |
|
parent.register_module(name, child) |
|
|
|
|
|
def _sink_params( |
|
module: torch.nn.Module, |
|
inputs_to_state: Dict[str, List[str]], |
|
scope: List[str], |
|
): |
|
"""Sink params, buffers, and constants from graph inputs into get_attr nodes. |
|
|
|
Exported modules are purely functional, so they pass their parameters and |
|
buffers in as inputs to the graph. |
|
|
|
To replicate eager's semantics, we need to get them from the module state |
|
via get_attr instead. |
|
|
|
    module: GraphModule, potentially containing nested submodules.
    inputs_to_state: mapping of graph input names to the corresponding keys in the state_dict.
|
scope: tracks where we are in the module hierarchy, so that we can emit the |
|
right `getattr(self, "foo.bar")` calls, etc. |
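
    For example (a sketch): a placeholder input whose name maps to the state
    key ``"sub.weight"`` is erased, and its uses are redirected to a
    ``get_attr("sub.weight")`` node inserted after the last placeholder.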
|
""" |
|
|
|
|
|
|
|
    # This dict records inputs removed by child modules: it maps a module
    # object id to the list of placeholder node names removed in that module.
    module_id_to_inputs_removed: Dict[int, List[str]] = defaultdict(list)
|
|
|
|
|
|
|
    # Use _modules instead of named_children() because we explicitly want
    # duplicate (shared) modules to show up in the traversal.
    for name, submodule in module._modules.items():
|
submod_id_to_inputs_removed = _sink_params( |
|
cast(torch.nn.Module, submodule), inputs_to_state, scope + [name] |
|
) |
|
for k, v in submod_id_to_inputs_removed.items(): |
|
module_id_to_inputs_removed[k].extend(v) |
|
|
|
if not hasattr(module, "graph"): |
|
|
|
        # Not all modules have graphs defined; some are empty containers with
        # no operations (like ParameterList).
        return module_id_to_inputs_removed
|
|
|
graph = module.graph |
|
inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes)) |
|
the_last_input = inputs[-1] |
|
|
|
|
|
    # Also remove the sunk inputs from call_module node arguments.
    call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes)
|
for node in call_module_nodes: |
|
submodule = _recursive_getattr(module, node.target.split(".")) |
|
|
|
|
|
        # Remove placeholders from the call_module node's arguments, but only
        # if the corresponding placeholder node was erased in the submodule's
        # own _sink_params() call.
        if submodule is not None and id(submodule) in module_id_to_inputs_removed:
|
node.args = tuple( |
|
filter( |
|
lambda n: n.name not in module_id_to_inputs_removed[id(submodule)], |
|
node.args, |
|
) |
|
) |
|
|
|
|
|
    # Filter out the inputs_to_state entries that correspond to the current scope.
    inputs_to_state_of_scope: Dict[torch.fx.Node, List[str]] = {}
|
for node in inputs: |
|
if node.name not in inputs_to_state: |
|
continue |
|
|
|
state_name = None |
|
for sn in inputs_to_state[node.name]: |
|
sn_split = sn.split(".") |
|
if sn_split[: len(scope)] == scope: |
|
state_name = sn_split |
|
break |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        # If there is a mismatch between scope and state name, then multiple
        # scopes must point to the same state name, i.e. some modules are
        # shared. In that case, simply skip this input node: a later
        # iteration, where scope and state name uniquely match, will handle it.
        if state_name is None:
            continue
|
|
|
inputs_to_state_of_scope[node] = state_name |
|
|
|
|
|
    # Record the names of removed inputs for the return value.
    inputs_removed: List[str] = []
|
|
|
for node, state_name in inputs_to_state_of_scope.items(): |
|
if len(node.users) > 0: |
|
attr_path = state_name[len(scope) :] |
|
state_attr = _recursive_getattr(module, attr_path) |
|
assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject)) |
|
|
|
|
|
            # Make sure the newly created get_attr node is placed after the
            # last placeholder node.
            with graph.inserting_after(the_last_input):
|
new_node = graph.create_node("get_attr", ".".join(attr_path)) |
|
|
|
node.replace_all_uses_with(new_node, propagate_meta=True) |
|
|
|
graph.erase_node(node) |
|
inputs_removed.append(node.name) |
|
|
|
if isinstance(module, InterpreterModule): |
|
module.finalize() |
|
|
|
return {id(module): inputs_removed} |
|
|
|
|
|
def _recursive_getattr(obj, attr_path): |
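    """Follow ``attr_path`` (a list of attribute names) from ``obj``, returning
    None if any attribute along the way is missing; e.g.
    ``_recursive_getattr(mod, ["sub", "weight"])`` behaves like ``mod.sub.weight``.
    """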
|
for attr in attr_path: |
|
if not hasattr(obj, attr): |
|
return None |
|
obj = getattr(obj, attr) |
|
|
|
return obj |
|
|