Spaces:
Running
Running
File size: 33,171 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 |
import argparse
import copy
import logging
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import NamedTuple, Sequence, Iterable, Any, List, Dict, Optional, Tuple

import torch
from torch.fx.passes.graph_manipulation import get_size_of_node
from torch.fx.node import map_arg
from torch.fx._compatibility import compatibility

from .operator_support import (
    get_node_target,
    OperatorSupportBase,
)
from .graph_drawer import FxGraphDrawer
from .shape_prop import ShapeProp
from .split_utils import split_by_tags
from .tools_common import (
    FxNetAccFusionsFinder,
    CALLABLE_NODE_OPS,
    Tensors,
    NodeList,
    NodeSet,
    is_node_output_tensor,
)
# Public API of this module.
__all__ = [
    "FxNetAccNodesFinder",
    "FxNetSplitterInternalError",
    "Subgraph",
    "SplitResult",
    "generate_inputs_for_submodules",
]

# Logger named after this module.
_LOGGER = logging.getLogger(__name__)
DEFAULT_MIN_ACC_MODULE_SIZE = 1
DEFAULT_SKIP_FUSION = False
DEFAULT_ALLOW_NON_TENSOR = False


class _SplitterSettingBase:
    """
    Settings consumed by the splitter.

    Each setting starts from the corresponding constructor argument and may be
    overridden from the command line. ``sys.argv`` is parsed with
    ``parse_known_args``, so unrelated flags are ignored.

    - ``--min-acc-module-size`` / ``--min_acc_module_size``: when given (even
      as ``0``) it replaces ``min_acc_module_size``.
    - ``--skip-fusion`` / ``--skip_fusion`` and ``--allow-non-tensor`` /
      ``--allow_non_tensor``: store-true flags; when present they force the
      corresponding setting on. They can never switch a setting off.
    """

    def __init__(
        self,
        min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE,
        skip_fusion=DEFAULT_SKIP_FUSION,
        allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR
    ):
        parser = argparse.ArgumentParser()
        parser.add_argument(
            "--min-acc-module-size",
            "--min_acc_module_size",
            required=False,
            type=int,
            help="Minimum size limit of an accelerator subgraph.",
        )
        parser.add_argument(
            "--skip-fusion",
            "--skip_fusion",
            default=False,
            action="store_true",
            help="If true then no fusion groups. Fusion group is used to "
            "enforce no non-tensor data flow between submodules. If we don't "
            "have this constrain, setting this to false is recommended as it "
            "can reduce overhead.",
        )
        parser.add_argument(
            "--allow-non-tensor",
            "--allow_non_tensor",
            default=False,
            action="store_true",
            help="For some backends non-tensor data flow between cpu and them "
            "are not allowed. Therefore, if a node supported by accelerator but "
            "it has non-tensor inputs or outputs to a cpu node we would want to "
            "consider it as a cpu node during splitting. However, for some backends "
            "we might not care about non-tensor data flow and we can set this option "
            "to true to disable the functionality that prevent non-tensor data flow.",
        )
        args, _unknown = parser.parse_known_args()

        # Compare against None instead of relying on truthiness: an explicit
        # `--min-acc-module-size 0` must override the constructor value.
        self.min_acc_module_size: int = (
            args.min_acc_module_size
            if args.min_acc_module_size is not None
            else min_acc_module_size
        )
        # store_true flags default to False, so they can only switch a setting on.
        self.skip_fusion: bool = args.skip_fusion or skip_fusion
        self.allow_non_tensor: bool = args.allow_non_tensor or allow_non_tensor
@compatibility(is_backward_compatible=False)
class FxNetAccNodesFinder:
    """
    Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor
    input/output to cpu nodes to prevent non-tensor data flow between backends and cpu.

    I.e. if we have a chain:

    ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1

    where every ACC node produces non-tensor output, then they all should be treated as CPU nodes.

    This behavior can be turned off by passing allow_non_tensor=True.
    """

    def __init__(
        self,
        module: torch.fx.GraphModule,
        operator_support: OperatorSupportBase,
        allow_non_tensor: bool,
    ):
        self.module = module
        self.operator_support = operator_support
        self.allow_non_tensor = allow_non_tensor
        # NOTE: `self.acc_nodes` is only populated by `__call__`; the
        # `reduce_*` helpers below rely on it already existing.

    def reduce_acc_nodes_non_tensor_input_helper(
        self, cpu_worklist: NodeList
    ):
        """
        Transitively excludes nodes from ACC supported set.

        For every node in the worklist:
        - removes its downstream ACC nodes from ACC supported set,
        - if any downstream ACC node produces non-tensor output,
          then it gets added into the worklist.

        Nodes are processed in FIFO order from a deque copy of
        `cpu_worklist`; the list passed in is not modified.
        """
        # deque.popleft is O(1); list.pop(0) made this loop quadratic.
        worklist = deque(cpu_worklist)
        while worklist:
            node = worklist.popleft()

            for user in node.users:
                if user in self.acc_nodes:
                    self.acc_nodes.remove(user)
                    # A demoted node that yields a non-tensor would now feed
                    # non-tensor data from CPU into its own ACC users, so its
                    # users must be demoted as well.
                    if not is_node_output_tensor(user):
                        worklist.append(user)

    def reduce_acc_nodes_non_tensor_input(self):
        """
        Excludes nodes from ACC supported set that have direct
        upstream CPU nodes that produce non-tensor outputs.
        """
        non_tensor_cpu_nodes: NodeList = [
            node
            for node in self.module.graph.nodes
            if node.op in CALLABLE_NODE_OPS
            and node not in self.acc_nodes
            and not is_node_output_tensor(node)
        ]

        self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes)

    def reduce_acc_nodes_non_tensor_output(self):
        """
        Excludes nodes from ACC supported set that produce non-tensor
        outputs and have downstream CPU nodes.

        Repeats until a pass demotes nothing, because each demotion can make
        further ACC nodes feed non-tensor data into CPU nodes.
        """
        while True:
            new_cpu_nodes: NodeList = []

            for acc_node in self.acc_nodes:
                if is_node_output_tensor(acc_node):
                    continue
                # Any user outside the ACC set (including the graph output
                # node) counts as a CPU consumer.
                if any(user not in self.acc_nodes for user in acc_node.users):
                    new_cpu_nodes.append(acc_node)

            if not new_cpu_nodes:
                break

            for new_cpu_node in new_cpu_nodes:
                self.acc_nodes.remove(new_cpu_node)

            self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes)

    def __call__(self) -> NodeSet:
        """
        Return the set of callable nodes that `operator_support` accepts.

        Unless `allow_non_tensor` is set, the set is then pruned so that no
        non-tensor value flows between ACC and CPU nodes.
        """
        submodules = dict(self.module.named_modules())
        self.acc_nodes = {
            n
            for n in self.module.graph.nodes
            if n.op in CALLABLE_NODE_OPS
            and self.operator_support.is_node_supported(submodules, n)
        }

        if not self.allow_non_tensor:
            self.reduce_acc_nodes_non_tensor_input()
            self.reduce_acc_nodes_non_tensor_output()

        return self.acc_nodes
@compatibility(is_backward_compatible=False)
class FxNetSplitterInternalError(Exception):
    """Error raised by the splitter when its own bookkeeping is inconsistent
    (for example an empty subgraph or a node that is tagged twice)."""
@compatibility(is_backward_compatible=False)
@dataclass
class Subgraph:
    """A group of nodes that the splitter places into one submodule."""

    # True when the group is destined for the accelerator, False for CPU.
    is_acc: bool
    # Member nodes of the group.
    nodes: NodeList
@compatibility(is_backward_compatible=False)
class SplitResult(NamedTuple):
    """
    Stores the results of the splitter.

    Attributes:
        split_module: root module after splitting.
        submodule_inputs: a dict that maps submodule name to its inputs.
        non_acc_submodule_prefix: the prefix for non acc submodules. For
            acc submodule the prefix is always "_run_on_acc_".
    """

    split_module: torch.fx.GraphModule
    submodule_inputs: Dict[str, Any]
    non_acc_submodule_prefix: str
@compatibility(is_backward_compatible=False)
def generate_inputs_for_submodules(
    model: torch.nn.Module,
    inputs: Sequence[Any],
    target_submodules: Iterable[str],
    deepcopy: bool = False,
) -> Dict[str, Any]:
    """
    Generate inputs for targeting submodules in the given model. Note that if two
    submodules refer to the same obj, this function doesn't work.

    Runs ``model(*inputs)`` once under ``torch.no_grad()`` with a forward
    pre-hook on every targeted submodule. Each hook records the positional
    inputs the submodule receives.

    Args:
        model: root model.
        inputs: inputs to the root model.
        target_submodules: names (as produced by ``model.named_modules()``) of
            submodules that we want to generate inputs for. Any iterable is
            accepted, including a one-shot generator.
        deepcopy: if True, the recorded inputs are passed through
            ``copy.deepcopy`` before being stored.

    Returns:
        A dict that maps from submodule name to its inputs.

    Raises:
        Whatever ``model(*inputs)`` raises; the hooks are removed first.
    """
    # Materialize once: a membership test on a generator consumes it, so
    # only the first `in` check would ever see the requested names.
    targets = set(target_submodules)
    results: Dict[str, Any] = {}
    submodule_to_names = {mod: name for name, mod in model.named_modules()}

    def pre_forward(module, module_inputs):
        results[submodule_to_names[module]] = (
            copy.deepcopy(module_inputs) if deepcopy else module_inputs
        )

    handles = [
        mod.register_forward_pre_hook(pre_forward)
        for name, mod in model.named_modules()
        if name in targets
    ]

    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        # Detach the hooks whether or not the forward pass raised.
        for handle in handles:
            handle.remove()

    return results
class _SplitterBase:
    """
    Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator.
    Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible.
    Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator.

    Given the following graph:
          ==> b ==>
        //         \\
       a             d
        \\         //
          ==> c ==>

    class SimpleModule(torch.nn.Module):
        def forward(self, a):
            b = torch.sin(a)
            c = torch.cos(a)
            d = b + c
            return d

    and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator,
    we will get the following split result:

    main:
    def forward(self, a):
        run_on_acc_0_0 = self._run_on_acc_0_0(a)
        getitem = run_on_acc_0_0[0]
        getitem_1 = run_on_acc_0_0[1]
        run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1)
        return run_on_cpu_1_1

    _run_on_acc_0_0:
    def forward(self, a):
        sin_1 = torch.sin(a)
        cos_1 = torch.cos(a)
        return (sin_1, cos_1)

    _run_on_cpu_1_1:
    def forward(self, sin_1, cos_1):
        add_1 = sin_1 + cos_1
        return add_1
    """

    # PCIe bandwidth for the backend, default to 100 GB/s
    PCIe_BW = 100 * 2 ** 30

    def __init__(
        self,
        module: torch.fx.GraphModule,
        sample_input: Sequence[Any],
        operator_support: OperatorSupportBase,
        settings: _SplitterSettingBase,
        non_acc_submodule_name: str = "_run_on_cpu_",
    ):
        """
        Preprocesses graph before splitting:
        - finds nodes supported by ACC,
        - finds fusion groups for ACC nodes having non-tensor IO,
        - builds a graph of direct dependencies,
        - builds a map of fused nodes to their fusions.
        As a result we get self.acc_nodes, self.deps and self.fusions.

        `sample_input` is run through the module once (via ShapeProp) so
        that node metadata such as output types and tensor metadata exists.
        """
        assert isinstance(module, torch.fx.GraphModule)

        self.module = module
        ShapeProp(self.module).propagate(*sample_input)

        self.settings = settings
        self.operator_support = operator_support
        self.sample_input = sample_input
        self.acc_nodes = FxNetAccNodesFinder(self.module, self.operator_support, self.settings.allow_non_tensor)()

        if self.settings.skip_fusion:
            self.fusions = {}
        else:
            self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)()

        # Modify deps to add more deps for fused nodes
        self.deps = self.find_deps()
        self.update_deps_for_fusions()

        self.non_acc_submodule_name = non_acc_submodule_name
        self._node_submodule_map: Dict[str, str] = {}

    # ===============================================================
    # Helpers for ctor and initial state
    # ===============================================================

    def get_node_submodule_map(self) -> Dict[str, str]:
        """ Returns a map from node name to submodule name, e.g.
            node: main_module_impl_impl_over_arch_unary_multiple_embedding
              _pooling_embedding_pooling_sparse_entity_equivalence_key
              _proxy_embedding_bag
            maps to submodule name of: _run_on_acc_1

        The map is filled in by `tag()`.
        """
        return self._node_submodule_map

    def find_deps(self) -> Dict[torch.fx.Node, NodeSet]:
        """
        Builds a graph of node dependencies. Leaf nodes don't have any
        dependencies and the "output" node doesn't have nodes depending on it.

        Resulting graph has only direct dependencies, i.e. there are no
        transitive dependencies.
        """
        deps: Dict[torch.fx.Node, NodeSet] = defaultdict(set)
        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            for user in node.users:
                if user.op != "output":
                    deps[user].add(node)
        return deps

    def update_deps_for_fusions(self):
        """
        Updates graph of dependencies so that:
        - nodes from the same fusion depend on the same set of outer nodes,
        - outer nodes depending on a fusion depend on all nodes in that fusion.
        """
        for node in self.fusions:
            fusion = self.fusions[node]
            for fused_neighbor in fusion:
                self.deps[node].update(self.deps[fused_neighbor] - fusion)

                for user in fused_neighbor.users:
                    if user not in fusion:
                        self.deps[user].add(node)

    # ===============================================================
    # Helpers for preview
    # ===============================================================

    def _lower_model_to_backend(
        self, mod: torch.fx.GraphModule, inputs: Tensors
    ) -> torch.nn.Module:
        """
        Lower the model to a backend.

        Hook for subclasses; the base implementation returns `mod` unchanged.
        """

        return mod

    def _find_culprit(
        self, mod: torch.fx.GraphModule, inputs: Tensors
    ) -> str:
        """
        When an error occurs during lowering or running the lowered mod, we use this
        function to find culprits in the `mod` that causes the error.

        Hook for subclasses; the base implementation only returns a
        placeholder message.
        """

        return "Unable to find a culprit because _find_culprit() function is not implemented."

    def _draw_graph_based_on_node_support(
        self, mod: torch.fx.GraphModule, supported_nodes: NodeList
    ):
        """
        Write `node_support.dot` for `mod`: nodes in `supported_nodes` are
        colored "chartreuse1", other callable nodes "crimson", and the rest
        "AliceBlue".
        """
        color_map = {
            "default": "AliceBlue",
            "supported": "chartreuse1",
            "unsupported": "crimson",
        }

        class CustomDrawer(FxGraphDrawer):
            def _get_node_style(self, node):
                template = super()._get_node_style(node)
                if node in supported_nodes:
                    template["fillcolor"] = color_map["supported"]
                elif node.op in CALLABLE_NODE_OPS:
                    template["fillcolor"] = color_map["unsupported"]
                else:
                    template["fillcolor"] = color_map["default"]

                return template

        drawer = CustomDrawer(mod, "node_support", ignore_getattr=True)
        dot_graph = drawer.get_main_dot_graph()
        dot_graph.write_raw("node_support.dot")

    def node_support_preview(self, dump_graph: bool = False):
        """
        Report which node targets of `self.module` `self.operator_support`
        accepts or rejects, with the dtypes of each node's Node-valued
        args and kwargs.

        The report is printed and returned. When `dump_graph` is True,
        `node_support.dot` is written as well.
        """
        submodules = dict(self.module.named_modules())

        supported_nodes: NodeList = []
        supported_node_types = defaultdict(set)
        unsupported_node_types = defaultdict(set)

        def get_dtype(arg):
            tensor_meta = arg.meta.get("tensor_meta")
            return getattr(tensor_meta, "dtype", None)

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            target = get_node_target(submodules, node)

            # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None.
            arg_dtypes = [
                get_dtype(arg) if isinstance(arg, torch.fx.Node) else None
                for arg in node.args
            ]

            # Index just past the last non-None dtype; 0 when every entry is None.
            last_index = len(arg_dtypes) - next(
                (
                    i
                    for i, dtype in enumerate(reversed(arg_dtypes))
                    if dtype is not None
                ),
                len(arg_dtypes),
            )

            # Strip None elements at the end.
            arg_dtypes_tuple = tuple(arg_dtypes[:last_index])
            kwarg_dtypes_tuple = tuple(
                (k, get_dtype(arg))
                for k, arg in node.kwargs.items()
                if isinstance(arg, torch.fx.Node)
            )

            if self.operator_support.is_node_supported(submodules, node):
                supported_nodes.append(node)
                supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))
            else:
                unsupported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple))

        if dump_graph:
            self._draw_graph_based_on_node_support(self.module, supported_nodes)

        reports = "\nSupported node types in the model:\n"
        for t, dtypes in supported_node_types.items():
            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

        reports += "\nUnsupported node types in the model:\n"
        for t, dtypes in unsupported_node_types.items():
            for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes:
                reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n"

        print(reports)

        # Return reports for testing purpose
        return reports

    def split_preview(self, dump_graph: bool = False):
        """
        Build, print and return a textual preview of the split.

        The report contains:
        - subgraph counts before and after small acc subgraphs are removed,
        - the node count of every subgraph,
        - for every acc submodule: non-tensor inputs/outputs, the total tensor
          bytes in and out, a PCIe-bandwidth-bound qps estimate, and whether
          `_lower_model_to_backend` and running the lowered module succeed.

        Tags are removed from `self.module` afterwards (`split(remove_tag=True)`).
        When `dump_graph` is True, a `.dot` file is written per graph of the
        split module.
        """
        reports = ""
        subgraphs = self.put_nodes_into_subgraphs()
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_num = len([g for g in subgraphs if g.is_acc])
        cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num
        reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:"
        reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n"

        for i, subgraph in enumerate(subgraphs):
            reports += f"_run_on_acc_{i}: " if subgraph.is_acc else f"{self.non_acc_submodule_name}{i}: "
            reports += f"{len(subgraph.nodes)} node(s)\n"

        self.tag(subgraphs)
        split_mod = self.split(remove_tag=True)
        split_mod.eval()

        if dump_graph:
            drawer = FxGraphDrawer(
                split_mod, "preview", ignore_getattr=True
            )
            dot_graphs = drawer.get_all_dot_graphs()
            for name, dot_graph in dot_graphs.items():
                dot_graph.write_raw(f"{name}.dot")

        def get_submod_inputs(main_mod, submod, example_inputs):
            """Run `main_mod` once and return the positional inputs `submod` received."""
            sub_inputs = None

            def capture_inputs(module, inputs):
                nonlocal sub_inputs
                sub_inputs = inputs

            handle = submod.register_forward_pre_hook(capture_inputs)
            try:
                main_mod(*example_inputs)
            finally:
                # Never leave the capture hook attached, even if the run fails.
                handle.remove()
            return sub_inputs

        max_qps: float = self.PCIe_BW
        bottleneck_module = ""

        for node in split_mod.graph.nodes:
            # Match the exact prefix `tag()` gives accelerator submodules. A
            # plain `"acc" in target` check would also match CPU submodules
            # whose custom `non_acc_submodule_name` happens to contain "acc".
            if node.op != "call_module" or not node.target.startswith("_run_on_acc_"):
                continue

            reports += f"\nProcessing acc submodule {node.target}\n"
            submod = getattr(split_mod, node.target)

            submod_inputs = get_submod_inputs(
                split_mod, submod, self.sample_input
            )
            ShapeProp(submod).propagate(*submod_inputs)

            total_input_bytes = 0
            total_output_bytes = 0
            output_node = None

            reports += "Checking inputs...\n"
            for n in submod.graph.nodes:
                if n.op == "placeholder":
                    if not is_node_output_tensor(n):
                        reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n"
                    else:
                        total_input_bytes += get_size_of_node(submod, n)[0]
                if n.op == "output":
                    output_node = n

            reports += "Checking outputs...\n"

            def get_bytes(node: torch.fx.Node):
                nonlocal total_output_bytes
                nonlocal reports
                if not is_node_output_tensor(node):
                    reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n"
                else:
                    total_output_bytes += get_size_of_node(submod, node)[0]

            if output_node is not None:
                map_arg(output_node.args, get_bytes)

            transferred_bytes = max(total_input_bytes, total_output_bytes)
            # A submodule without tensor inputs/outputs moves no data over
            # PCIe, so its bandwidth bound is unlimited. Dividing by zero here
            # used to raise ZeroDivisionError.
            qps = self.PCIe_BW / transferred_bytes if transferred_bytes else float("inf")
            reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes},"
            reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n"

            if qps < max_qps:
                max_qps = qps
                bottleneck_module = node.target

            try:
                lowered_submod = self._lower_model_to_backend(submod, submod_inputs)
            except RuntimeError:
                reports += "Run into an error during lowering!\n"
                reports += self._find_culprit(submod, submod_inputs)
                continue

            try:
                lowered_submod(*submod_inputs)
            except RuntimeError:
                reports += "Run into an error during inference!\n"
                reports += self._find_culprit(submod, submod_inputs)
            else:
                reports += "Lowering and running succeed!\n"

        reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps},"
        reports += f" bottleneck is submodule {bottleneck_module}."
        print(reports)

        # return the reports for testing purposes
        return reports

    # ===============================================================
    # Helpers for extend_acc_subgraph() method
    # ===============================================================

    def find_reverse_deps(
        self, tag_id: Optional[int] = None
    ) -> Dict[torch.fx.Node, NodeSet]:
        """
        Builds reversed topological node dependencies, if tag_id is specified,
        we ignore nodes that are in later subgraph i.e. nodes have greater tag_id.

        A user is kept only when the integer suffix of its `tag` is
        < `tag_id`, so every user must already be tagged in that case.
        """
        result: Dict[torch.fx.Node, NodeSet] = defaultdict(set)

        for node in self.module.graph.nodes:
            if node.op not in CALLABLE_NODE_OPS:
                continue

            for user in node.users:
                if user.op not in CALLABLE_NODE_OPS:
                    continue

                if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id):
                    result[node].add(user)

        return result

    def update_reverse_deps_for_fusions(
        self, deps: Dict[torch.fx.Node, NodeSet]
    ):
        """
        Update reverse dependencies `deps` in place for fusion groups.

        Every node of a fusion group shares one dependency set: the union of
        the members' sets minus the group itself. Every outside input of a
        member additionally depends on the whole group.
        """
        processed_node = set()

        for node, fusion in self.fusions.items():
            if node in processed_node:
                continue

            new_dep = set()

            # Create a new dependency set which include all the
            # dependencies of the nodes in the fusion group
            for n in fusion:
                new_dep.update(deps[n])

            # Exclude nodes in the fusion
            new_dep.difference_update(fusion)

            # Update dependency
            for n in fusion:
                deps[n] = new_dep

                for arg in n.all_input_nodes:
                    if arg not in fusion:
                        deps[arg].update(fusion)

                processed_node.add(n)

    def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet:
        """
        Finds parent nodes of the `tag` subgraph.

        Traverse the inputs of nodes in the subgraph, if input doesn't belong to the subgraph
        and is not a placeholder, we consider it as the parent node of the subgraph.
        """
        parent_nodes = set()

        for node in self.module.graph.nodes:
            if node.op in CALLABLE_NODE_OPS and node.tag == tag:
                for arg in node.all_input_nodes:
                    if arg.op in CALLABLE_NODE_OPS and arg.tag != tag:
                        parent_nodes.add(arg)

        return parent_nodes

    def extend_acc_subgraph(self, tag: str):
        """
        Extend the acc subgraph with `tag` going the reversed topological direction.

        Parent nodes are moved into the subgraph (their `tag` is rewritten)
        one at a time. A parent qualifies when it is an acc node and every
        reverse dependency it has was already moved.
        """
        # Dict that maps node to its users and ignore users that
        # are in the subgraph that has greater tag
        deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1]))
        self.update_reverse_deps_for_fusions(deps)

        # Parent nodes of the subgraph
        parent_nodes = self.find_parent_nodes_of_subgraph(tag)

        visited_nodes: NodeSet = set()

        while parent_nodes:
            node = None

            # Find a acc node that depends on visited nodes only
            for n in parent_nodes:
                if deps[n] <= visited_nodes and n in self.acc_nodes:
                    node = n
                    break

            if node is None:
                break

            # Put the node into `tag` subgraph
            node.tag = tag  # type: ignore[attr-defined]
            parent_nodes.remove(node)
            visited_nodes.add(node)

            # If node is in a fusion group, add all fusion buddies to parent nodes
            if node in self.fusions:
                for fusion_node in self.fusions[node]:
                    if fusion_node not in visited_nodes:
                        parent_nodes.add(fusion_node)

            # Add inputs of the node to parent nodes
            for arg in node.all_input_nodes:
                if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes:
                    parent_nodes.add(arg)

    # ===============================================================
    # Helpers for split() method
    # ===============================================================

    def starter_nodes(self) -> Tuple[NodeSet, NodeSet]:
        """
        Finds nodes that consume module inputs or get_attr nodes.

        Returns:
            (cpu starter nodes, acc starter nodes)
        """
        starter_cpu_nodes: NodeSet = set()
        starter_acc_nodes: NodeSet = set()
        for node in self.module.graph.nodes:
            if node.op not in {"placeholder", "get_attr"}:
                continue
            for user in node.users:
                if user in self.acc_nodes:
                    starter_acc_nodes.add(user)
                else:
                    starter_cpu_nodes.add(user)
        return starter_cpu_nodes, starter_acc_nodes

    def put_nodes_into_subgraphs(self) -> List[Subgraph]:
        """
        Group callable nodes into alternating acc / cpu subgraphs.

        The current subgraph keeps taking any queued node of its kind whose
        dependencies have all been visited. When none is ready, the subgraph
        is closed and the kind flips.

        Raises:
            FxNetSplitterInternalError: if a subgraph would be empty or no
                subgraph could be created.
        """
        # We start graph traversal from leaf nodes
        current_cpu_nodes, current_acc_nodes = self.starter_nodes()
        visited_nodes: NodeSet = set()

        # Determine which subgraph to start from based on which subgraph has
        # 0-dep node
        acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes)

        current_subgraph_nodes: NodeList = []

        # Result accumulator
        subgraphs: List[Subgraph] = []
        while current_cpu_nodes or current_acc_nodes:
            # Find the first node that should belong to the current subgraph and has all dependencies resolved
            current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes
            node = next(
                (n for n in current_nodes if self.deps[n] <= visited_nodes),
                None,
            )

            # If nothing was found, then it's time to flip the mode and start a new subgraph
            if node is None:
                if not current_subgraph_nodes:
                    raise FxNetSplitterInternalError("Subgraph can't be empty")

                subgraphs.append(
                    Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
                )
                acc_subgraph = not acc_subgraph
                current_subgraph_nodes = []
                continue

            current_nodes.remove(node)
            visited_nodes.add(node)
            current_subgraph_nodes.append(node)

            # Add fusion buddies
            if node in self.fusions:
                if node in self.acc_nodes:
                    current_acc_nodes.update(self.fusions[node] - visited_nodes)
                else:
                    current_cpu_nodes.update(self.fusions[node] - visited_nodes)

            # Put depending nodes into the queue
            for user in node.users:
                if user.op not in CALLABLE_NODE_OPS:
                    continue

                # Add downstream nodes
                if user in self.acc_nodes:
                    current_acc_nodes.add(user)
                else:
                    current_cpu_nodes.add(user)

        # Check if the last subgraph was not created
        if current_subgraph_nodes:
            subgraphs.append(
                Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes)
            )

        if not subgraphs:
            raise FxNetSplitterInternalError("Couldn't create subgraphs")

        return subgraphs

    def remove_small_acc_subgraphs(self, subgraphs: List[Subgraph]) -> List[Subgraph]:
        """
        This pass finds ACC submodules with less than specified size and merges
        them with adjacent CPU submodules.

        A small acc subgraph is appended to the preceding result subgraph, or
        turned into a CPU subgraph if it comes first. Consecutive CPU
        subgraphs are then merged.
        """
        result: List[Subgraph] = []
        for subgraph in subgraphs:
            if subgraph.is_acc:
                if len(subgraph.nodes) >= self.settings.min_acc_module_size:
                    result.append(subgraph)
                else:
                    print(
                        "Eliminating acc subgraph because it's smaller than the threshold: "
                        f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}"
                    )
                    if result:
                        result[-1].nodes.extend(subgraph.nodes)
                    else:
                        subgraph.is_acc = False
                        result.append(subgraph)
            else:
                if result and not result[-1].is_acc:
                    result[-1].nodes.extend(subgraph.nodes)
                else:
                    result.append(subgraph)
        return result

    def tag(self, subgraphs: List[Subgraph]):
        """
        Set a `tag` attribute on every node of `subgraphs` and record the tags.

        The i-th subgraph is tagged `_run_on_acc_<i>` if it is an acc
        subgraph, otherwise `<non_acc_submodule_name><i>`. Tags are collected
        in `self.tags`, and `node.name -> tag` is recorded in
        `_node_submodule_map`.

        Raises:
            FxNetSplitterInternalError: if a node already carries a tag.
        """
        self.tags: List[str] = []
        for subgraph in subgraphs:
            tag = f"_run_on_acc_{len(self.tags)}" if subgraph.is_acc else f"{self.non_acc_submodule_name}{len(self.tags)}"
            self.tags.append(tag)
            for node in subgraph.nodes:
                if hasattr(node, "tag"):
                    raise FxNetSplitterInternalError(f"Node {node} was already tagged")

                node.tag = tag  # type: ignore[attr-defined]
                self._node_submodule_map[node.name] = tag

    def split(self, remove_tag: bool = False) -> torch.fx.GraphModule:
        """
        Split `self.module` into one submodule per tag in `self.tags`.

        If `remove_tag` is True, the `tag` attribute is afterwards deleted
        from every node of `self.module` that has one.
        """
        split_module = split_by_tags(self.module, self.tags)
        if remove_tag:
            for node in self.module.graph.nodes:
                if hasattr(node, "tag"):
                    del node.tag
        return split_module

    def __call__(self) -> torch.fx.GraphModule:
        """Build subgraphs, drop small acc ones, tag the nodes and return the split module."""
        subgraphs = self.put_nodes_into_subgraphs()
        subgraphs = self.remove_small_acc_subgraphs(subgraphs)
        acc_subgraphs_count = len([s for s in subgraphs if s.is_acc])
        non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count
        print(f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs")
        self.tag(subgraphs)
        return self.split()

    def generate_split_results(self) -> SplitResult:
        """
        Split the module and capture the inputs of each direct child
        submodule by running `self.sample_input` through the result.
        """
        split_module = self()
        submodule_names = [name for name, _ in split_module.named_children()]
        submodule_inputs = generate_inputs_for_submodules(split_module, self.sample_input, submodule_names)
        return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name)
|