import argparse
import copy
import logging
from collections import defaultdict
from dataclasses import dataclass
from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple

import torch
from torch.fx.passes.graph_manipulation import get_size_of_node
from torch.fx.node import map_arg
from torch.fx._compatibility import compatibility

from .operator_support import (
    get_node_target,
    OperatorSupportBase,
)
from .graph_drawer import FxGraphDrawer
from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
    FxNetAccFusionsFinder,
    CALLABLE_NODE_OPS,
    Tensors,
    NodeList,
    NodeSet,
    is_node_output_tensor,
)

__all__ = [
    'FxNetAccNodesFinder',
    'FxNetSplitterInternalError',
    'Subgraph',
    'SplitResult',
    'generate_inputs_for_submodules',
]

_LOGGER = logging.getLogger(__name__)

DEFAULT_MIN_ACC_MODULE_SIZE = 1
DEFAULT_SKIP_FUSION = False
DEFAULT_ALLOW_NON_TENSOR = False

class _SplitterSettingBase:
    def __init__(
        self,
        min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
        skip_fusion=DEFAULT_SKIP_FUSION,
        allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR,
    ):
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--min-acc-module-size",
            "--min_acc_module_size",
            required=False,
            type=int,
            help="Minimum size limit of an accelerator subgraph.",
        )
        parser.add_argument(
            "--skip-fusion",
            "--skip_fusion",
            default=False,
            action="store_true",
            help="If set, no fusion groups are created. Fusion groups are used to "
            "enforce that no non-tensor data flows between submodules. If that "
            "constraint is not needed, skipping fusion is recommended because it "
            "reduces overhead.",
        )
        parser.add_argument(
            "--allow-non-tensor",
            "--allow_non_tensor",
            default=False,
            action="store_true",
            help="Some backends don't allow non-tensor data flow between themselves "
            "and the cpu. For those backends, if a node is supported by the "
            "accelerator but has non-tensor inputs from or outputs to a cpu node, "
            "we treat it as a cpu node during splitting. For backends that don't "
            "care about non-tensor data flow, set this option to disable that "
            "behavior.",
        )
        args, unknown = parser.parse_known_args()

        # Command-line flags, when provided, take precedence over constructor arguments.
        self.min_acc_module_size: int = (
            args.min_acc_module_size if args.min_acc_module_size is not None else min_acc_module_size
        )
        self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion
        self.allow_non_tensor: bool = args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor
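
# Behavior sketch (hedged, not from the original file): launching a program that
# constructs these settings with `--min-acc-module-size=10 --skip-fusion` would
# yield min_acc_module_size == 10 and skip_fusion == True, while calling
# `_SplitterSettingBase()` with no command-line flags keeps the DEFAULT_* values
# passed to the constructor.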

@compatibility(is_backward_compatible=False)
class FxNetAccNodesFinder:
    """
    Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
    input/output to cpu nodes, to prevent non-tensor data flow between backends and cpu.

    I.e. if we have a chain:

        ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1

    where every ACC node produces a non-tensor output, then they all should be treated as CPU nodes.

    This behavior can be turned off by passing allow_non_tensor=True.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        operator_support: OperatorSupportBase,
        allow_non_tensor: bool,
    ):
        self.module = module
        self.operator_support = operator_support
        self.allow_non_tensor = allow_non_tensor

    def reduce_acc_nodes_non_tensor_input_helper(
        self, cpu_worklist: NodeList
    ):
        """
        Transitively excludes nodes from the ACC supported set.
        For every node in the worklist:
        - removes its downstream ACC nodes from the ACC supported set,
        - if any downstream ACC node produces a non-tensor output,
          it gets added into the worklist.
        """
        while cpu_worklist:
            node = cpu_worklist.pop(0)

            for user in node.users:
                if user in self.acc_nodes:
                    self.acc_nodes.remove(user)
                    if not is_node_output_tensor(user):
                        cpu_worklist.append(user)

    def reduce_acc_nodes_non_tensor_input(self):
        """
        Excludes nodes from the ACC supported set that have direct
        upstream CPU nodes producing non-tensor outputs.
        """
        non_tensor_cpu_nodes: NodeList = []

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if node in self.acc_nodes:
                continue
            if is_node_output_tensor(node):
                continue
            non_tensor_cpu_nodes.append(node)

        self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)

    def reduce_acc_nodes_non_tensor_output(self):
        """
        Excludes nodes from the ACC supported set that produce non-tensor
        outputs and have downstream CPU nodes.
        """
        while True:
            new_cpu_nodes: NodeList = []

            for acc_node in self.acc_nodes:
                if is_node_output_tensor(acc_node):
                    continue
                for user in acc_node.users:
                    if user not in self.acc_nodes:
                        new_cpu_nodes.append(acc_node)
                        break

            if not new_cpu_nodes:
                break

            for new_cpu_node in new_cpu_nodes:
                self.acc_nodes.remove(new_cpu_node)

            self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)

    def __call__(self) -> NodeSet:
        submodules = dict(self.module.named_modules())
        self.acc_nodes = {
            n
            for n in self.module.graph.nodes
            if n.op in CALLABLE_NODE_OPS
            and self.operator_support.is_node_supported(submodules, n)
        }

        if not self.allow_non_tensor:
            self.reduce_acc_nodes_non_tensor_input()
            self.reduce_acc_nodes_non_tensor_output()

        return self.acc_nodes
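
# Minimal usage sketch (hedged; `my_gm` and `my_op_support` are placeholders for
# a traced GraphModule and an OperatorSupportBase implementation you provide):
#
#     finder = FxNetAccNodesFinder(my_gm, my_op_support, allow_non_tensor=False)
#     acc_nodes = finder()  # set of nodes considered runnable on the accelerator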

@compatibility(is_backward_compatible=False)
class FxNetSplitterInternalError(Exception):
    pass


@compatibility(is_backward_compatible=False)
@dataclass
class Subgraph:
    is_acc: bool
    nodes: NodeList

@compatibility(is_backward_compatible=False)
class SplitResult(NamedTuple):
    """
    Stores the results of the splitter.

    Attributes:
        split_module: root module after splitting.
        submodule_inputs: a dict that maps submodule name to its inputs.
        non_acc_submodule_prefix: the prefix for non-acc submodules. For
            acc submodules the prefix is always "_run_on_acc_".
    """

    split_module: torch.fx.GraphModule
    submodule_inputs: Dict[str, Any]
    non_acc_submodule_prefix: str

@compatibility(is_backward_compatible=False)
def generate_inputs_for_submodules(
    model: torch.nn.Module,
    inputs: Sequence[Any],
    target_submodules: Iterable[str],
    deepcopy: bool = False,
) -> Dict[str, Any]:
    """
    Generate inputs for the target submodules in the given model. Note that if two
    submodule names refer to the same module object, this function doesn't work.

    Args:
        model: root model.
        inputs: inputs to the root model.
        target_submodules: names of the submodules we want to generate inputs for.
        deepcopy: if True, store a deep copy of each submodule's inputs.

    Returns:
        A dict that maps from submodule name to its inputs.
    """
    handles = []
    results = {}
    submodule_to_names = {mod: name for name, mod in model.named_modules()}

    def pre_forward(module, module_inputs):
        results[submodule_to_names[module]] = copy.deepcopy(module_inputs) if deepcopy else module_inputs

    for name, mod in model.named_modules():
        if name in target_submodules:
            handles.append(mod.register_forward_pre_hook(pre_forward))

    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        # Always remove the hooks, even if the forward pass raises.
        for h in handles:
            h.remove()

    return results
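
# Usage sketch (hedged; `split_mod` and `sample_inputs` are assumed to come from
# an earlier split, and "_run_on_acc_0" is a hypothetical submodule name):
#
#     sub_inputs = generate_inputs_for_submodules(
#         split_mod, sample_inputs, ["_run_on_acc_0"], deepcopy=True
#     )
#     acc_inputs = sub_inputs["_run_on_acc_0"]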

@compatibility(is_backward_compatible=False)
class _SplitterBase:
    """
    Splits a GraphModule into sub-GraphModules for execution on the CPU or the accelerator.
    The output is a GraphModule with supported and unsupported operators grouped into as few
    sub-GraphModules as possible. Assumes that only "call_module", "call_function" and
    "call_method" nodes from FX IR can potentially be executed on the accelerator.

    Given the following graph:

             ==> b ==>
           //         \\
          a             d
           \\         //
             ==> c ==>

    class SimpleModule(torch.nn.Module):
        def forward(self, a):
            b = torch.sin(a)
            c = torch.cos(a)
            d = b + c
            return d

    and providing "operator_support" that indicates that 'b' and 'c' can be executed on the
    accelerator, we will get the following split result:

    main:
    def forward(self, a):
        run_on_acc_0_0 = self._run_on_acc_0_0(a)
        getitem = run_on_acc_0_0[0]
        getitem_1 = run_on_acc_0_0[1]
        run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
        return run_on_cpu_1_1

    _run_on_acc_0_0:
    def forward(self, a):
        sin_1 = torch.sin(a)
        cos_1 = torch.cos(a)
        return (sin_1, cos_1)

    _run_on_cpu_1_1:
    def forward(self, sin_1, cos_1):
        add_1 = sin_1 + cos_1
        return add_1
    """

    # PCIe bandwidth for the backend, default to 100 GB/s
    PCIe_BW = 100 * 2 ** 30

    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Sequence[Any],
        operator_support: OperatorSupportBase,
        settings: _SplitterSettingBase,
        non_acc_submodule_name: str = "_run_on_cpu_",
    ):
        """
        Preprocesses the graph before splitting:
        - finds nodes supported by ACC,
        - finds fusion groups for ACC nodes having non-tensor IO,
        - builds a graph of direct dependencies,
        - builds a map of fused nodes to their fusions.
        As a result we get self.acc_nodes, self.deps and self.fusions.
        """
        assert isinstance(module, torch.fx.GraphModule)

        self.module = module
        ShapeProp(self.module).propagate(*sample_input)

        self.settings = settings
        self.operator_support = operator_support
        self.sample_input = sample_input
        self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)()

        if self.settings.skip_fusion:
            self.fusions = {}
        else:
            self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()

        # Modify deps to add more deps for fused nodes
        self.deps = self.find_deps()
        self.update_deps_for_fusions()

        self.non_acc_submodule_name = non_acc_submodule_name

        self._node_submodule_map: Dict[str, str] = {}

    # ===============================================================
    # Helpers for ctor and initial state
    # ===============================================================

    def get_node_submodule_map(self) -> Dict[str, str]:
        """
        Returns a map from node name to submodule name, e.g.

        node:
            main_module_impl_impl_over_arch_unary_multiple_embedding
            _pooling_embedding_pooling_sparse_entity_equivalence_key
            _proxy_embedding_bag

        maps to submodule name of: _run_on_acc_1
        """
        return self._node_submodule_map

    def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
        """
        Builds a graph of node dependencies. Leaf nodes don't have any
        dependencies and the "output" node doesn't have nodes depending on it.

        The resulting graph has only direct dependencies, i.e. there are no
        transitive dependencies.
        """
        deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            for user in node.users:
                if user.op != "output":
                    deps[user].add(node)
        return deps
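
    # Sketch of what this yields for the SimpleModule example in the class
    # docstring: only `d` gets a recorded entry, deps[d] == {b, c}; `b` and `c`
    # have no recorded deps because their input `a` is a placeholder and
    # placeholders are not in CALLABLE_NODE_OPS, and the "output" node is skipped.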

    def update_deps_for_fusions(self):
        """
        Updates the graph of dependencies so that:
        - nodes from the same fusion depend on the same set of outer nodes,
        - outer nodes depending on a fusion depend on all nodes in that fusion.
        """
        for node in self.fusions:
            fusion = self.fusions[node]
            for fused_neighbor in fusion:
                self.deps[node].update(self.deps[fused_neighbor] - fusion)

                for user in fused_neighbor.users:
                    if user not in fusion:
                        self.deps[user].add(node)
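
    # Illustrative sketch (x, y, p, q are hypothetical nodes, not part of this
    # module): given a fusion group {x, y} with deps[x] == {p} and deps[y] == {q},
    # after this pass both x and y depend on {p, q}, and any user of x or y
    # outside the group depends on both x and y.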

    # ===============================================================
    # Helpers for preview
    # ===============================================================

    def _lower_model_to_backend(
        self, mod: torch.fx.GraphModule, inputs: Tensors
    ) -> torch.nn.Module:
        """
        Lower the model to a backend.
        """
        return mod

    def _find_culprit(
        self, mod: torch.fx.GraphModule, inputs: Tensors
    ) -> str:
        """
        When an error occurs during lowering or running the lowered mod, we use this
        function to find the culprits in the `mod` that cause the error.
        """
        return "Unable to find a culprit because the _find_culprit() function is not implemented."

    def _draw_graph_based_on_node_support(
        self, mod: torch.fx.GraphModule, supported_nodes: NodeList
    ):
        color_map = {
            "default": "AliceBlue",
            "supported": "chartreuse1",
            "unsupported": "crimson",
        }

        class CustomDrawer(FxGraphDrawer):
            def _get_node_style(self, node):
                template = super()._get_node_style(node)
                if node in supported_nodes:
                    template["fillcolor"] = color_map["supported"]
                elif node.op in CALLABLE_NODE_OPS:
                    template["fillcolor"] = color_map["unsupported"]
                else:
                    template["fillcolor"] = color_map["default"]

                return template

        drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
        dot_graph = drawer.get_main_dot_graph()
        dot_graph.write_raw("node_support.dot")

    def node_support_preview(self, dump_graph: bool = False):
        submodules = dict(self.module.named_modules())

        supported_nodes: NodeList = []
        supported_node_types = defaultdict(set)
        unsupported_node_types = defaultdict(set)

        def get_dtype(arg):
            tensor_meta = arg.meta.get("tensor_meta")
            return getattr(tensor_meta, "dtype", None)

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            target = get_node_target(submodules, node)

            # Store the dtype of each arg in node.args. If an arg doesn't have a
            # dtype, i.e. it's not a tensor, we store None.
            arg_dtypes = [
                get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
                for arg in node.args
            ]

            # Find the index right after the last non-None element. If all
            # elements are None, this is 0, which strips the whole list below.
            last_index = len(arg_dtypes) - next(
                (
                    i
                    for i, dtype in enumerate(reversed(arg_dtypes))
                    if dtype is not None
                ),
                len(arg_dtypes),
            )

            # Strip None elements at the end.
            arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
            kwarg_dtypes_tuple = tuple(
                (k, get_dtype(arg))
                for k, arg in node.kwargs.items()
                if isinstance(arg, torch.fx.Node)
            )

            if self.operator_support.is_node_supported(submodules, node):
                supported_nodes.append(node)
                supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
            else:
                unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))

        if dump_graph:
            self._draw_graph_based_on_node_support(self.module, supported_nodes)

        reports = "\nSupported node types in the model:\n"
        for t, dtypes in supported_node_types.items():
            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

        reports += "\nUnsupported node types in the model:\n"
        for t, dtypes in unsupported_node_types.items():
            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

        print(reports)

        # Return reports for testing purposes.
        return reports

    def split_preview(self, dump_graph: bool = False):
        reports = ""
        subgraphs = self.put_nodes_into_subgraphs()
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        for i, subgraph in enumerate(subgraphs):
            reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: "
            reports += f"{len(subgraph.nodes)} node(s)\n"

        self.tag(subgraphs)
        split_mod = self.split(remove_tag=True)
        split_mod.eval()

        if dump_graph:
            drawer = FxGraphDrawer(
                split_mod, "preview", ignore_getattr=True
            )
            dot_graphs = drawer.get_all_dot_graphs()
            for name, dot_graph in dot_graphs.items():
                dot_graph.write_raw(f"{name}.dot")

        max_qps: float = self.PCIe_BW
        bottleneck_module = ""

        for node in split_mod.graph.nodes:
            if node.op == "call_module" and "acc" in node.target:
                reports += f"\nProcessing acc submodule {node.target}\n"

                submod = getattr(split_mod, node.target)

                def get_submod_inputs(main_mod, submod, example_inputs):
                    sub_inputs = None

                    def get_inputs(module, inputs):
                        nonlocal sub_inputs
                        sub_inputs = inputs

                    handle = submod.register_forward_pre_hook(get_inputs)
                    main_mod(*example_inputs)
                    handle.remove()
                    return sub_inputs

                submod_inputs = get_submod_inputs(
                    split_mod, submod, self.sample_input
                )
                ShapeProp(submod).propagate(*submod_inputs)

                total_input_bytes = 0
                total_output_bytes = 0

                reports += "Checking inputs...\n"
                for n in submod.graph.nodes:
                    if n.op == "placeholder":
                        if not is_node_output_tensor(n):
                            reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                        else:
                            total_input_bytes += get_size_of_node(submod, n)[0]
                    if n.op == "output":
                        output_node = n

                reports += "Checking outputs...\n"

                def get_bytes(node: torch.fx.Node):
                    nonlocal total_output_bytes
                    nonlocal reports
                    if not is_node_output_tensor(node):
                        reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
                    else:
                        total_output_bytes += get_size_of_node(submod, node)[0]

                map_arg(output_node.args, get_bytes)  # type: ignore[possibly-undefined]
                qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes)
                reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
                reports += f" theoretical max qps (bounded by PCIe bandwidth) for this submodule is {qps}.\n"

                if qps < max_qps:
                    max_qps = qps
                    bottleneck_module = node.target

                try:
                    lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
                except RuntimeError:
                    reports += "Ran into an error during lowering!\n"
                    reports += self._find_culprit(submod, submod_inputs)
                    continue

                try:
                    lowered_submod(*submod_inputs)
                except RuntimeError:
                    reports += "Ran into an error during inference!\n"
                    reports += self._find_culprit(submod, submod_inputs)
                else:
                    reports += "Lowering and running succeeded!\n"

        reports += f"\nTheoretical max qps (bounded by PCIe bandwidth) for this model is {max_qps},"
        reports += f" bottleneck is submodule {bottleneck_module}."
        print(reports)

        # Return the reports for testing purposes.
        return reports
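
    # Hedged usage note: on a concrete splitter subclass (e.g. one overriding
    # _lower_model_to_backend), `splitter.node_support_preview(dump_graph=True)`
    # and `splitter.split_preview(dump_graph=True)` print these reports and write
    # the .dot files into the current working directory; `splitter` here is a
    # placeholder for an instance constructed as in __init__ above.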

    # ===============================================================
    # Helpers for extend_acc_subgraph() method
    # ===============================================================

    def find_reverse_deps(
        self, tag_id: Optional[int] = None
    ) -> Dict[torch.fx.Node, NodeSet]:
        """
        Builds reversed topological node dependencies. If tag_id is specified,
        we ignore nodes that are in a later subgraph, i.e. nodes with a greater tag_id.
        """
        result: Dict[torch.fx.Node, NodeSet] = defaultdict(set)

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            for user in node.users:
                if user.op not in CALLABLE_NODE_OPS:
                    continue

                if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id):
                    result[node].add(user)

        return result

    def update_reverse_deps_for_fusions(
        self, deps: Dict[torch.fx.Node, NodeSet]
    ):
        processed_node = set()

        for node, fusion in self.fusions.items():
            if node in processed_node:
                continue

            new_dep = set()

            # Create a new dependency set which includes all the
            # dependencies of the nodes in the fusion group.
            for n in fusion:
                new_dep.update(deps[n])

            # Exclude nodes in the fusion.
            new_dep.difference_update(fusion)

            # Update dependencies.
            for n in fusion:
                deps[n] = new_dep

                for arg in n.all_input_nodes:
                    if arg not in fusion:
                        deps[arg].update(fusion)

                processed_node.add(n)

    def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
        """
        Finds parent nodes of the `tag` subgraph.

        Traverses the inputs of nodes in the subgraph; if an input doesn't belong to the
        subgraph and is not a placeholder, we consider it a parent node of the subgraph.
        """
        parent_nodes = set()

        for node in self.module.graph.nodes:
            if node.op in CALLABLE_NODE_OPS and node.tag == tag:
                for arg in node.all_input_nodes:
                    if arg.op in CALLABLE_NODE_OPS and arg.tag != tag:
                        parent_nodes.add(arg)

        return parent_nodes

    def extend_acc_subgraph(self, tag: str):
        """
        Extends the acc subgraph with `tag`, walking in the reverse topological direction.
        """
        # Dict that maps a node to its users, ignoring users that
        # are in a subgraph with a greater tag.
        deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1]))
        self.update_reverse_deps_for_fusions(deps)

        # Parent nodes of the subgraph.
        parent_nodes = self.find_parent_nodes_of_subgraph(tag)

        visited_nodes: NodeSet = set()

        while parent_nodes:
            node = None

            # Find an acc node that depends only on visited nodes.
            for n in parent_nodes:
                if deps[n] <= visited_nodes and n in self.acc_nodes:
                    node = n
                    break

            if node is None:
                break

            # Put the node into the `tag` subgraph.
            node.tag = tag  # type: ignore[attr-defined]
            parent_nodes.remove(node)
            visited_nodes.add(node)

            # If the node is in a fusion group, add all fusion buddies to parent nodes.
            if node in self.fusions:
                for fusion_node in self.fusions[node]:
                    if fusion_node not in visited_nodes:
                        parent_nodes.add(fusion_node)

            # Add inputs of the node to parent nodes.
            for arg in node.all_input_nodes:
                if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
                    parent_nodes.add(arg)

    # ===============================================================
    # Helpers for split() method
    # ===============================================================

    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
        """
        Finds nodes that consume module inputs or get_attr nodes.
        """
        starter_cpu_nodes: NodeSet = set()
        starter_acc_nodes: NodeSet = set()
        for node in self.module.graph.nodes:
            if node.op not in {"placeholder", "get_attr"}:
                continue
            for user in node.users:
                if user in self.acc_nodes:
                    starter_acc_nodes.add(user)
                else:
                    starter_cpu_nodes.add(user)
        return starter_cpu_nodes, starter_acc_nodes

    def put_nodes_into_subgraphs(self) -> List[Subgraph]:
        # We start graph traversal from leaf nodes.
        current_cpu_nodes, current_acc_nodes = self.starter_nodes()
        visited_nodes: NodeSet = set()

        # Determine which subgraph to start from based on which subgraph has
        # a node with zero dependencies.
        acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes)

        current_subgraph_nodes: NodeList = []

        # Result accumulator.
        subgraphs: List[Subgraph] = []
        while current_cpu_nodes or current_acc_nodes:
            # Find the first node that should belong to the current subgraph and has all dependencies resolved.
            current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
            node = next(
                (n for n in current_nodes if self.deps[n] <= visited_nodes),
                None,
            )

            # If nothing was found, then it's time to flip the mode and start a new subgraph.
            if node is None:
                if not current_subgraph_nodes:
                    raise FxNetSplitterInternalError("Subgraph can't be empty")

                subgraphs.append(
                    Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
                )
                acc_subgraph = not acc_subgraph
                current_subgraph_nodes = []
                continue

            current_nodes.remove(node)
            visited_nodes.add(node)
            current_subgraph_nodes.append(node)

            # Add fusion buddies.
            if node in self.fusions:
                if node in self.acc_nodes:
                    current_acc_nodes.update(self.fusions[node] - visited_nodes)
                else:
                    current_cpu_nodes.update(self.fusions[node] - visited_nodes)

            # Put depending nodes into the queue.
            for user in node.users:
                if user.op not in CALLABLE_NODE_OPS:
                    continue

                # Add downstream nodes.
                if user in self.acc_nodes:
                    current_acc_nodes.add(user)
                else:
                    current_cpu_nodes.add(user)

        # Check if the last subgraph was not created.
        if current_subgraph_nodes:
            subgraphs.append(
                Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
            )

        if not subgraphs:
            raise FxNetSplitterInternalError("Couldn't create subgraphs")

        return subgraphs
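
    # Sketch of this traversal on the SimpleModule example from the class
    # docstring (with b and c supported): starter_nodes() yields {b, c} as acc
    # starters and no cpu starters, so we begin in acc mode, emit
    # Subgraph(is_acc=True, nodes=[b, c]), flip to cpu mode once d's dependencies
    # are satisfied, and finally emit Subgraph(is_acc=False, nodes=[d]).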

    def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
        """
        This pass finds acc subgraphs that are smaller than the specified size and merges
        them with adjacent CPU subgraphs.
        """
        result: List[Subgraph] = []
        for subgraph in subgraphs:
            if subgraph.is_acc:
                if len(subgraph.nodes) >= self.settings.min_acc_module_size:
                    result.append(subgraph)
                else:
                    print(
                        "Eliminating acc subgraph because it's smaller than the threshold: "
                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
                    )
                    if result:
                        result[-1].nodes.extend(subgraph.nodes)
                    else:
                        subgraph.is_acc = False
                        result.append(subgraph)
            else:
                if result and not result[-1].is_acc:
                    result[-1].nodes.extend(subgraph.nodes)
                else:
                    result.append(subgraph)
        return result
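
    # Illustrative sketch (hypothetical sizes): with min_acc_module_size=2 and
    # subgraphs [acc(1 node), cpu(3 nodes), acc(5 nodes)], the leading 1-node acc
    # subgraph is demoted to cpu, the following cpu subgraph is merged into it,
    # and the result is [cpu(4 nodes), acc(5 nodes)].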

    def tag(self, subgraphs: List[Subgraph]):
        self.tags: List[str] = []
        for subgraph in subgraphs:
            tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
            self.tags.append(tag)
            for node in subgraph.nodes:
                if hasattr(node, "tag"):
                    raise FxNetSplitterInternalError(f"Node {node} was already tagged")

                node.tag = tag  # type: ignore[attr-defined]
                self._node_submodule_map[node.name] = tag

    def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
        split_module = split_by_tags(self.module, self.tags)
        if remove_tag:
            for node in self.module.graph.nodes:
                if hasattr(node, "tag"):
                    del node.tag
        return split_module

    def __call__(self) -> torch.fx.GraphModule:
        subgraphs = self.put_nodes_into_subgraphs()
        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
        non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
        print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs")
        self.tag(subgraphs)
        return self.split()

    def generate_split_results(self) -> SplitResult:
        split_module = self()
        submodule_names = [name for name, _ in split_module.named_children()]
        submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
        return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
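
# End-to-end usage sketch (hedged): a concrete splitter is expected to subclass
# _SplitterBase; MySplitter, MyOperatorSupport, my_module and my_sample_input
# below are placeholders, not part of this file.
#
#     class MySplitter(_SplitterBase):
#         def _lower_model_to_backend(self, mod, inputs):
#             return mod  # plug in real backend lowering here
#
#     traced = torch.fx.symbolic_trace(my_module)
#     splitter = MySplitter(
#         traced, my_sample_input, MyOperatorSupport(), _SplitterSettingBase()
#     )
#     split_result = splitter.generate_split_results()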