|
|
|
import re |
|
from typing import Callable, Dict, Optional, Set, Union |
|
|
|
import torch.fx |
|
from torch.fx.node import map_arg |
|
from torch.fx.passes.split_module import split_module |
|
|
|
|
|
__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs'] |
|
|
|
class FoldedGraphModule(torch.fx.GraphModule): |
|
""" |
|
FoldedGraphModule is a GraphModule which also contains another |
|
`const_subgraph_module` representing a subgraph which has all const attr |
|
inputs and which can be run once before running the main standard |
|
`graph`. The `const_output_names` are the ordered list names of attrs which |
|
represent what each respective output from the const_subgraph should be set |
|
on which attrs. |
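
    A minimal usage sketch (`my_module` and `x` here are hypothetical; a
    FoldedGraphModule is normally produced by `split_const_subgraphs`)::

        folded = split_const_subgraphs(my_module)
        folded.run_folding()  # optional; otherwise folding runs on the first call
        out = folded(x)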
|
""" |
|
|
|
def __init__( |
|
self, |
|
root: torch.nn.Module, |
|
graph: torch.fx.Graph, |
|
const_subgraph: Optional[torch.fx.Graph] = None, |
|
fx_const_folded_attrs_name: Optional[str] = None, |
|
device_for_folded_attrs: str = "cuda", |
|
): |
|
super().__init__(root, graph) |
|
self.const_subgraph_module = ( |
|
None |
|
if const_subgraph is None |
|
else torch.fx.GraphModule(root, const_subgraph) |
|
) |
|
self.has_folding_been_run = False |
|
self.fx_const_folded_attrs_name = fx_const_folded_attrs_name |
|
self.device_for_folded_attrs = device_for_folded_attrs |
|
|
|
def __call__(self, *args, **kwargs): |
|
if not self.has_folding_been_run: |
|
self.run_folding() |
|
        return super().__call__(*args, **kwargs)
|
|
|
    def run_folding(self):
        # If there is no const subgraph module or no attr name to store its
        # outputs on, there is no const folding to perform, so return early.
        if (
            self.const_subgraph_module is None
            or self.fx_const_folded_attrs_name is None
        ):
            return
|
|
|
        assert not self.has_folding_been_run
        self.has_folding_been_run = True

        # Actually run the const folding subgraph. It returns a single value
        # when there is one folded output, or a tuple when there are several.
        folded_attrs = self.const_subgraph_module()
|
|
|
def _create_param(i): |
|
return torch.nn.Parameter( |
|
i.detach().clone() |
|
if not isinstance(i, int) |
|
else torch.Tensor([i]).to(device=self.device_for_folded_attrs), |
|
requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False, |
|
) |
|
|
|
params = ( |
|
torch.nn.ParameterList([_create_param(i) for i in folded_attrs]) |
|
if isinstance(folded_attrs, tuple) |
|
else _create_param(folded_attrs) |
|
) |
|
setattr(self, self.fx_const_folded_attrs_name, params) |
|
|
|
|
|
def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str): |
|
""" |
|
    Given `gm` and the target name `inline_mod_name` of a submodule that `gm`
    calls, this helper inlines all of that submodule's nodes into `gm` in place
    of the corresponding `call_module` node.
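
    A rough usage sketch (`my_module` is hypothetical; "submod_0" is whatever
    target name `split_module` assigned to the partition)::

        gm = torch.fx.symbolic_trace(my_module)
        split = split_module(gm, my_module, lambda node: 0)
        _inline_module(split, "submod_0")
        # split.graph now contains submod_0's nodes directly.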
|
""" |
|
|
|
inline_mod = dict(gm.named_modules())[inline_mod_name] |
|
assert isinstance(inline_mod, torch.fx.GraphModule) |
|
call_mod_node_to_replace = None |
|
for node in gm.graph.nodes: |
|
if node.op == "call_module" and node.target == inline_mod_name: |
|
call_mod_node_to_replace = node |
|
break |
|
assert call_mod_node_to_replace is not None |
|
|
|
|
|
|
|
    # Now actually do the swap. We record the copy of each of inline_mod's nodes
    # that lands in `gm` via replacement_mapping, so that later copied nodes can
    # have their args remapped to the already-copied ones.
    call_mod_args = call_mod_node_to_replace.args
    replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    ph_count = 0
|
|
|
def replacement_fn(node): |
|
new_node = replacement_mapping[node] |
|
new_node.meta = node.meta.copy() |
|
return new_node |
|
|
|
for inline_node in inline_mod.graph.nodes: |
|
if inline_node.op == "placeholder": |
|
replacement_mapping[inline_node] = call_mod_args[ph_count] |
|
ph_count += 1 |
|
continue |
|
|
|
if inline_node.op == "output": |
|
outputs = inline_node.args[0] |
|
output_replacements = map_arg(outputs, replacement_fn) |
|
call_mod_node_to_replace.replace_all_uses_with(output_replacements) |
|
continue |
|
|
|
with gm.graph.inserting_before(call_mod_node_to_replace): |
|
new_node = gm.graph.node_copy(inline_node, replacement_fn) |
|
replacement_mapping[inline_node] = new_node |
|
|
|
gm.graph.eliminate_dead_code() |
|
|
|
|
|
def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str: |
|
""" |
|
    Make sure the name is unique (in a module) and can represent an attr.
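
    For example (assuming `mod_traced` already has attributes `foo` and `foo_1`
    but no other similarly named ones)::

        get_unique_attr_name_in_module(mod_traced, "foo!bar")  # -> "foo_bar"
        get_unique_attr_name_in_module(mod_traced, "foo")      # -> "foo_2"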
|
""" |
|
|
|
    # Replace characters that are illegal in a Python identifier with "_", and
    # prefix with "_" if the name would otherwise start with a digit.
    name = re.sub("[^0-9a-zA-Z_]+", "_", name)
    if name[0].isdigit():
        name = f"_{name}"
|
|
|
while hasattr(mod_traced, name): |
|
match = re.match(r"(.*)_(\d+)$", name) |
|
if match is None: |
|
name = name + "_1" |
|
else: |
|
base, num = match.group(1, 2) |
|
name = f"{base}_{int(num) + 1}" |
|
|
|
return name |
|
|
|
|
|
def split_const_subgraphs( |
|
module: Union[torch.nn.Module, torch.fx.GraphModule], |
|
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, |
|
device_for_folded_attrs: str = "cpu", |
|
) -> FoldedGraphModule: |
|
""" |
|
    Looks through `module` for any nodes that have all constant attribute
    inputs and separates them out into their own constant subgraph. Returns a
    FoldedGraphModule which runs that constant subgraph the first time it is
    called, in order to set attributes on the module prior to running the
    non-constant portion of the graph.
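
    Example (a minimal sketch; `MyModule`, its parameter `w`, and the inputs
    are hypothetical)::

        class MyModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = torch.nn.Parameter(torch.randn(4, 4))

            def forward(self, x):
                # `self.w + 1` has only constant inputs, so it gets folded.
                return x @ (self.w + 1)

        folded = split_const_subgraphs(MyModule())
        out = folded(torch.randn(2, 4))  # folding runs once, on the first call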
|
""" |
|
if not isinstance(module, torch.fx.GraphModule): |
|
mod_traced = torch.fx.symbolic_trace(module) |
|
else: |
|
mod_traced = module |
|
|
|
|
|
|
|
    # Build up a set of "const nodes": nodes that are themselves get_attrs, or
    # whose inputs are all const nodes.
    const_nodes: Set[torch.fx.Node] = set()
    found_const_folding = False
    for node in mod_traced.graph.nodes:
        # Skip over placeholders/outputs because they can't be const folded.
        if node.op in {"placeholder", "output"}:
            continue

        # If the node is neither a get_attr nor fed solely by const nodes, it
        # cannot be folded.
        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
            const_nodes
        ):
            continue

        # If the provided skip-folding callback says to skip this node, skip it.
        if skip_folding_node_fn and skip_folding_node_fn(node):
            continue

        # Don't fold side-effectful (impure) nodes.
        if node.is_impure():
            continue

        # The node must be const foldable at this point.
        const_nodes.add(node)
        if node.op != "get_attr":
            found_const_folding = True
|
|
|
|
|
    # If no const folding opportunity was found, return early without a const
    # subgraph.
    if not found_const_folding:
        return FoldedGraphModule(mod_traced, mod_traced.graph)
|
|
|
|
|
|
|
    # Partition the module into two: submod_0 holds the const folding subgraph
    # and submod_1 holds everything else.
    def mod_partition(node: torch.fx.Node):
        return 0 if node in const_nodes else 1

    split = split_module(mod_traced, module, mod_partition)
|
|
|
const_gm, non_const_gm = split.submod_0, split.submod_1 |
|
const_mod_name, non_const_mod_name = "submod_0", "submod_1" |
|
|
|
|
|
|
|
|
|
    # call_module nodes inside const_gm/non_const_gm refer to submodules owned by
    # those split submodules. Re-register those submodules on `split` itself so
    # that the const graph can be re-rooted onto `split` below and the non-const
    # graph can be inlined into it with targets that still resolve.
    for node in non_const_gm.graph.nodes:
        if node.op == "call_module":
            setattr(split, node.target, getattr(non_const_gm, node.target))
    for node in const_gm.graph.nodes:
        if node.op == "call_module":
            setattr(split, node.target, getattr(const_gm, node.target))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # split_module does not use get_attrs inside the submodules it creates;
    # instead the parent module does the get_attrs and passes the results in as
    # args. Grab the args of the call to the const submodule so that each of its
    # placeholders can be mapped back to the corresponding get_attr below.
    call_const_gm_args = None
    for node in split.graph.nodes:
        if node.op == "call_module" and node.target == const_mod_name:
            call_const_gm_args = node.args
            break
    assert call_const_gm_args is not None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Do the actual replacement of the const subgraph's placeholders with
    # get_attrs. Note that const_gm.graph is re-rooted onto `split` (as
    # root_const_gm) because the attrs are fetched directly from the root split
    # module rather than from const_gm.
    root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
    for node in root_const_gm.graph.nodes:
        if node.op == "output":
            multiple_outputs = isinstance(node.args[0], tuple)
            continue
        if node.op != "placeholder":
            continue
        # Each placeholder corresponds to a get_attr arg in the parent's call to
        # the const submodule; swap the placeholder for that get_attr.
        in_node = next(n for n in call_const_gm_args if n.name == node.target)
        assert in_node.op == "get_attr"
        with root_const_gm.graph.inserting_before(node):
            new_node = root_const_gm.graph.get_attr(in_node.target)
            new_node.meta = node.meta.copy()
        node.replace_all_uses_with(new_node)
        root_const_gm.graph.erase_node(node)
    assert "multiple_outputs" in locals()
|
|
|
|
|
|
|
|
|
|
|
    # Now find the call to const_gm inside split, and replace it with a get_attr
    # to the folded tensor(s) produced by running constant folding. We don't need
    # to worry about whether this is one or more tensors because the parent graph
    # already uses getitem to extract individual tensors when there are multiple.
    fx_const_folded_attrs_name = get_unique_attr_name_in_module(
        split, "_FX_CONST_FOLDED_ATTRS"
    )
    setattr(
        split,
        fx_const_folded_attrs_name,
        torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),
    )
    for node in split.graph.nodes:
        if node.op == "call_module" and node.target == const_mod_name:
            with node.graph.inserting_before(node):
                folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name)
                folded_attrs.meta = node.meta.copy()
            node.replace_all_uses_with(folded_attrs)
            break
|
|
|
split.graph.eliminate_dead_code() |
|
|
|
|
|
|
|
|
|
    # Finally, inline the non-constant submod back into the split module so the
    # caller gets back a graph module traced at the same granularity as the one
    # that was passed in.
    _inline_module(split, non_const_mod_name)
|
|
|
return FoldedGraphModule( |
|
split, |
|
split.graph, |
|
root_const_gm.graph, |
|
fx_const_folded_attrs_name, |
|
device_for_folded_attrs, |
|
) |
|
|