Spaces:
Running
Running
from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional | |
import collections | |
from dataclasses import dataclass | |
import torch | |
import torch.fx | |
from torch.fx.node import _get_qualified_name | |
from torch.fx._compatibility import compatibility | |
# Names exported by `from <this module> import *`.
__all__ = [
    'get_acc_ops_name',
    'get_node_target',
    'is_node_output_tensor',
    'FxNetAccFusionsFinder',
    'legalize_graph',
]

# Type aliases used in the signatures of this module.
Tensors = Union[Tuple[torch.Tensor], List[torch.Tensor]]
TensorOrTensors = Union[torch.Tensor, Tensors]
NodeList = List[torch.fx.Node]
NodeSet = Set[torch.fx.Node]
Names = List[str]

# Values of `node.op` that the helpers in this module treat as callable nodes.
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
def get_acc_ops_name(k):
    """
    Return a printable name for the op target `k`.

    * If `k` is a string, it is returned unchanged.
    * If `k.__module__` is non-empty and contains "acc_ops", the result is
      "acc_ops.<k.__name__>".
    * Otherwise the result is "<module>.<k.__name__>", where every occurrence of
      "torch._ops" in the module name is replaced by "torch.ops". When
      `k.__module__` is None or empty the module part is the empty string, so the
      result is ".<k.__name__>".
    """
    if isinstance(k, str):
        return k

    module = k.__module__
    if module and "acc_ops" in module:
        return f"acc_ops.{k.__name__}"

    # `__module__` may be None; only rewrite it when there is a string to rewrite,
    # instead of calling `.replace` on None.
    # WAR for bug in how torch.ops assigns module
    module = module.replace('torch._ops', 'torch.ops') if module else ''
    return f"{module}.{k.__name__}"
def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str:
    """
    Return the target typename of a callable `node`.

    Args:
        submodules: Mapping from `call_module` target names to the modules they refer to.
        node: A node whose op is one of CALLABLE_NODE_OPS.

    Returns:
        * "call_module": the name `get_acc_ops_name` gives for the submodule's
          `_base_class_origin` attribute, or for the submodule's type when that
          attribute is absent.
        * "call_function": "acc_ops.<name>" when the function's `__module__` contains
          "acc_ops", otherwise whatever `_get_qualified_name` returns. Per the original
          note, "_VariableFunctionsClass" in that name is reported as "torch"
          (e.g. torch.relu); that substitution is not done in this function itself.
        * "call_method": `node.target`, i.e. the name of the method being called. This
          could potentially conflict with other names, but is normally a tensor method.
    """
    assert node.op in CALLABLE_NODE_OPS, (
        "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}"
    )

    if node.op == "call_module":
        assert isinstance(node.target, str)
        owner = submodules[node.target]
        # Prefer the recorded original class when the module carries one.
        owner_type = getattr(owner, "_base_class_origin", type(owner))
        return get_acc_ops_name(owner_type)

    if node.op == "call_function":
        fn: Any = node.target
        fn_module = fn.__module__
        if fn_module is not None and "acc_ops" in fn_module:
            return f"acc_ops.{fn.__name__}"
        return _get_qualified_name(fn)

    # Remaining case: the target is the method name itself.
    assert isinstance(node.target, str)
    return node.target
def is_node_output_tensor(node: torch.fx.Node) -> bool:
    """Tell whether `node` is recorded as producing a Tensor.

    The answer comes from the "type" entry of `node.meta`: the result is True only
    when that entry exists and is `torch.Tensor` or a subclass of it.

    NOTE: Per the original documentation, this metadata is produced by `ShapeProp`,
    so `ShapeProp` has to be run on the containing fx graph before calling this.
    """
    recorded_type = node.meta.get("type")
    if recorded_type is None:
        return False
    return issubclass(recorded_type, torch.Tensor)
class FxNetAccFusionsFinder:
    """
    Finds groups of connected ACC nodes that pass non-tensor data between each other.
    Such groups are called fusion groups.

    Calling an instance returns a mapping from every node of an accepted fusion group
    to the set of nodes forming that group.
    """

    def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
        """
        Args:
            module: Graph module whose graph is searched for fusion groups.
            acc_nodes: Set of ACC nodes. `__call__` removes nodes from this set
                in place when their fusion group contains a non-ACC node.
        """
        self.module = module
        # Snapshot of the graph's node order; a node's position in this list is its idx.
        self.nodes = list(module.graph.nodes)
        self.acc_nodes = acc_nodes

    # `@dataclass` generates the keyword `__init__` that `__call__` relies on when it
    # builds a group via `self.FusionGroup(top_node_idx=..., nodes=..., ...)`.
    @dataclass
    class FusionGroup:
        # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model.
        top_node_idx: int

        # Nodes in this fusion group.
        nodes: NodeSet

        # Inputs to this fusion group.
        inputs: NodeSet

        # Nodes in the fusion group that haven't been processed yet.
        nodes_need_process: NodeSet

        def add_node(self, node):
            """
            Add a node to fusion group.

            The node is queued for processing, dropped from the group's inputs, and
            its own callable input nodes that are not yet in the group become inputs
            of the group. Adding a node that is already in the group is a no-op.
            """
            if node in self.nodes:
                return

            self.nodes_need_process.add(node)
            self.nodes.add(node)
            self.inputs.discard(node)
            self.inputs.update(
                {
                    n
                    for n in node.all_input_nodes
                    if n.op in CALLABLE_NODE_OPS and n not in self.nodes
                }
            )

    def recursive_add_node(
        self,
        fusion_group: "FxNetAccFusionsFinder.FusionGroup",
        inputs: Union[NodeSet, NodeList],
        visited: Optional[NodeSet] = None,
    ):
        """
        Start from inputs and going reverse topological order. If any upstream node
        is in the fusion group, add all the nodes in this path to fusion group.

        Args:
            fusion_group: The group being grown.
            inputs: Nodes to start the upstream search from.
            visited: Optional set of nodes already examined. When given, nodes in it
                are skipped and every node examined here is added to it.

        Returns:
            True as soon as one of `inputs` is in the fusion group or has an upstream
            path leading into it (the nodes on that path are added to the group);
            False when no such node is found.
        """
        for arg in inputs:
            # skip the node if already seen
            if visited is not None:
                if arg in visited:
                    continue
                visited.add(arg)

            # Skip placeholder and get_attr because they won't be in the fusion group.
            if arg.op not in CALLABLE_NODE_OPS:
                continue

            # If the node has smaller idx, it's already an upstream node of the fusion
            # group. We don't need to check it anymore.
            if self.nodes.index(arg) < fusion_group.top_node_idx:
                continue

            # If the node is in the fusion group, return True.
            if arg in fusion_group.nodes:
                return True

            # Check the upstream nodes of the node, if any of them is in the fusion group
            # we'll add this node to fusion group and return True.
            if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
                fusion_group.add_node(arg)
                return True

        return False

    def __call__(self) -> Dict[torch.fx.Node, NodeSet]:
        """
        Build the fusion groups.

        A group is seeded from every callable ACC node whose meta has no
        "tensor_meta" entry and that is not already part of a reported group.

        Returns:
            Mapping from each node of an accepted group to that group's node set
            (all members of a group map to the same set object). A group is accepted
            only when all of its nodes are in `self.acc_nodes`; otherwise its nodes
            are removed from `self.acc_nodes` in place and the group is not reported.
        """
        result: Dict[torch.fx.Node, NodeSet] = {}
        # Iterate over a copy because `self.acc_nodes` may shrink inside the loop.
        acc_nodes = list(self.acc_nodes)

        for node in acc_nodes:
            if node in result:
                continue
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if "tensor_meta" in node.meta:
                continue
            # The node may have been dropped from the ACC set by an earlier group.
            if node not in self.acc_nodes:
                continue

            fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
                top_node_idx=self.nodes.index(node),
                nodes={node},
                inputs=set(node.all_input_nodes),
                nodes_need_process={node},
            )
            while fusion_group.nodes_need_process:
                node = fusion_group.nodes_need_process.pop()
                self.recursive_add_node(
                    fusion_group,
                    fusion_group.inputs,
                    visited=set(),
                )

                # Optionally add downstream nodes
                if "tensor_meta" not in node.meta:
                    for user in node.users:
                        if user.op not in CALLABLE_NODE_OPS:
                            continue
                        if user in fusion_group.nodes:
                            continue

                        fusion_group.add_node(user)
                        self.recursive_add_node(
                            fusion_group,
                            fusion_group.inputs,
                            visited=set(),
                        )

                # Add some upstream nodes
                for arg in node.all_input_nodes:
                    if arg.op not in CALLABLE_NODE_OPS:
                        continue
                    if "tensor_meta" in arg.meta:
                        continue
                    if arg in fusion_group.nodes:
                        continue

                    fusion_group.add_node(arg)
                    # An upstream node may sit earlier in the graph than the current top.
                    fusion_group.top_node_idx = min(
                        fusion_group.top_node_idx, self.nodes.index(arg)
                    )
                    self.recursive_add_node(
                        fusion_group,
                        fusion_group.inputs,
                        visited=set(),
                    )

            if not (set(fusion_group.nodes) <= self.acc_nodes):
                # The group pulled in a non-ACC node: none of its nodes stay ACC nodes.
                self.acc_nodes -= fusion_group.nodes
            else:
                for n in fusion_group.nodes:
                    result[n] = fusion_group.nodes

        return result
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Replace the graph of `gm` with one holding copies of the same nodes in
    topologically sorted order.

    Nodes are copied into a fresh graph as soon as all the nodes they depend on have
    been copied, starting from the nodes that depend on nothing.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Returns:
        The same graph module, now holding the sorted graph.

    Raises:
        RuntimeError: If some nodes never had all their dependencies satisfied,
            i.e. the input graph has a cycle.
    """
    old_graph = gm.graph
    sorted_graph = torch.fx.Graph()

    # For every node, the number of producers that have not been copied yet.
    pending = dict.fromkeys(old_graph.nodes, 0)
    for producer in old_graph.nodes:
        for consumer in producer.users:
            pending[consumer] += 1

    # Work list, seeded with the nodes that have no dependencies at all.
    ready: collections.deque = collections.deque(
        n for n in old_graph.nodes if pending[n] == 0
    )

    # Maps each original node to its copy in the sorted graph.
    copied: Dict[torch.fx.Node, torch.fx.Node] = {}

    while ready:
        current = ready.popleft()
        copied[current] = sorted_graph.node_copy(current, lambda old: copied[old])
        for consumer in current.users:
            pending[consumer] -= 1
            # The consumer becomes ready once its last dependency has been copied.
            if pending[consumer] == 0:
                ready.append(consumer)

    # Fewer copied nodes than original nodes means some dependencies were never
    # satisfied, which can only happen when the graph has a cycle.
    if len(sorted_graph.nodes) < len(old_graph.nodes):
        stuck = [n for n, count in pending.items() if count != 0]
        raise RuntimeError(f"Input graph has cycles, unable to add {stuck}")

    sorted_graph._codegen = old_graph._codegen
    gm.graph = sorted_graph
    return gm