Spaces:
Running
Running
File size: 10,407 Bytes
c61ccee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 |
from typing import List, Tuple, Union, Dict, Any, Set, Mapping, Optional
import collections
from dataclasses import dataclass
import torch
import torch.fx
from torch.fx.node import _get_qualified_name
from torch.fx._compatibility import compatibility
# Names exported by `from ... import *`.
__all__ = [
    'get_acc_ops_name',
    'get_node_target',
    'is_node_output_tensor',
    'FxNetAccFusionsFinder',
    'legalize_graph',
]
# A sequence of tensors of any length, given either as a tuple or as a list.
# (`Tuple[torch.Tensor]` would mean a tuple of exactly ONE tensor, hence `...`.)
Tensors = Union[Tuple[torch.Tensor, ...], List[torch.Tensor]]
# Either a single tensor or a sequence of tensors.
TensorOrTensors = Union[torch.Tensor, Tensors]
NodeList = List[torch.fx.Node]
NodeSet = Set[torch.fx.Node]
Names = List[str]
# `Node.op` values of nodes that invoke a callable (module, function or method).
CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"}
@compatibility(is_backward_compatible=False)
def get_acc_ops_name(k):
    """
    Return a printable, qualified name for an op target.

    Args:
        k: Either a string, which is returned unchanged, or an object exposing
            ``__name__`` and ``__module__`` (for example a function or a class).

    Returns:
        * ``k`` itself when it is a string.
        * ``"acc_ops.<name>"`` when ``k.__module__`` contains ``"acc_ops"``.
        * Otherwise ``"<module>.<name>"``, with ``torch._ops`` rewritten to
          ``torch.ops``. If ``k`` has no module (``__module__`` is ``None`` or
          empty), the result is ``".<name>"``.
    """
    if isinstance(k, str):
        return k
    # `__module__` can be None; normalize it so neither the membership test
    # nor `.replace` below fails on it.
    module = k.__module__ or ""
    if "acc_ops" in module:
        return f"acc_ops.{k.__name__}"
    # WAR for bug in how torch.ops assigns module
    module = module.replace('torch._ops', 'torch.ops')
    return f"{module}.{k.__name__}"
@compatibility(is_backward_compatible=False)
def get_node_target(submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node) -> str:
    """
    Return a string naming what a callable node invokes.

    Args:
        submodules: Mapping from submodule name to module; used to resolve the
            target of ``call_module`` nodes.
        node: A node whose ``op`` is ``call_module``, ``call_function`` or
            ``call_method``.

    Returns:
        * ``call_module``: the module's type (its ``_base_class_origin`` when the
          module defines one), formatted with ``get_acc_ops_name``.
        * ``call_function``: ``"acc_ops.<name>"`` for functions whose module name
          contains ``acc_ops``, otherwise ``_get_qualified_name(target)``. If the
          name contains "_VariableFunctionsClass" it is replaced by "torch",
          e.g. _VariableFunctionsClass.relu becomes torch.relu.
        * ``call_method``: the method name itself. Two classes' methods with the
          same name cannot be told apart, which is usually fine because the
          receiver is normally a tensor.

    Raises:
        AssertionError: If ``node.op`` is not one of the callable ops, or if a
            ``call_module`` / ``call_method`` target is not a string.
    """
    op = node.op
    assert op in CALLABLE_NODE_OPS, (
        "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}"
    )

    if op == "call_module":
        module_path = node.target
        assert isinstance(module_path, str)
        submod = submodules[module_path]
        submod_cls = getattr(submod, "_base_class_origin", type(submod))
        return get_acc_ops_name(submod_cls)

    if op == "call_function":
        fn: Any = node.target
        fn_module = fn.__module__
        if fn_module is not None and "acc_ops" in fn_module:
            return f"acc_ops.{fn.__name__}"
        return _get_qualified_name(fn)

    # Everything else is treated as call_method: the target is the method name.
    method_name = node.target
    assert isinstance(method_name, str)
    return method_name
@compatibility(is_backward_compatible=False)
def is_node_output_tensor(node: torch.fx.Node) -> bool:
    """Checks if the node output produces a Tensor or not.

    NOTE: This requires to run `ShapeProp` on the containing fx graph before
    calling this function. This is because it works by checking the `type`
    metadata on the node. This metadata is produced by the `ShapeProp`.

    Args:
        node: The node to inspect. Only ``node.meta["type"]`` is read.

    Returns:
        True if ``node.meta["type"]`` is a class that is ``torch.Tensor`` or a
        subclass of it. False if the entry is missing, ``None``, or not a class.
    """
    type_ = node.meta.get("type", None)
    # issubclass() raises TypeError for non-class values, which other passes
    # may store under "type". Such values are treated as "not a tensor".
    return isinstance(type_, type) and issubclass(type_, torch.Tensor)
@compatibility(is_backward_compatible=False)
class FxNetAccFusionsFinder:
    """
    Finds groups of connected ACC nodes that pass non-tensor data between each other.
    Such groups are called fusion groups.

    Calling an instance returns a mapping from every node of every reported
    fusion group to that group's node set. A group is reported only if all of
    its nodes are in ``acc_nodes``. Otherwise all of its nodes are removed from
    ``acc_nodes``. That removal happens in place, so the set passed to
    ``__init__`` is modified too.
    """

    def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet):
        self.module = module
        # All graph nodes in graph order. The index comparisons below rely on
        # this order being a topological order.
        self.nodes = list(module.graph.nodes)
        # Kept by reference on purpose: `__call__` removes nodes from it in place.
        self.acc_nodes = acc_nodes
        # Node -> position in `self.nodes`. This replaces an O(n) `list.index`
        # scan that used to run on every step of the recursive search.
        self._node_index: Dict[torch.fx.Node, int] = {
            n: i for i, n in enumerate(self.nodes)
        }

    def _node_idx(self, node: torch.fx.Node) -> int:
        """Return the position of `node` in `self.nodes`."""
        try:
            return self._node_index[node]
        except KeyError:
            # Unknown node: defer to list.index so the same ValueError is raised.
            return self.nodes.index(node)

    @dataclass
    class FusionGroup:
        # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model.
        top_node_idx: int

        # Nodes in this fusion group.
        nodes: NodeSet

        # Inputs to this fusion group.
        inputs: NodeSet

        # Nodes in the fusion group that haven't been processed yet.
        nodes_need_process: NodeSet

        def add_node(self, node):
            """
            Add a node to fusion group.

            A newly added node is queued in `nodes_need_process` and stops being
            a group input. Its callable inputs that are not already in the group
            become group inputs. Adding a node that is already a member does
            nothing.
            """
            if node in self.nodes:
                return

            self.nodes_need_process.add(node)
            self.nodes.add(node)
            self.inputs.discard(node)
            self.inputs.update(
                {
                    n
                    for n in node.all_input_nodes
                    if n.op in CALLABLE_NODE_OPS and n not in self.nodes
                }
            )

    def recursive_add_node(
        self,
        fusion_group: "FxNetAccFusionsFinder.FusionGroup",
        inputs: Union[NodeSet, NodeList],
        visited: Optional[NodeSet] = None,
    ) -> bool:
        """
        Starting from `inputs`, walk upstream (in reverse topological order). If
        any upstream node is in the fusion group, add all the nodes on that path
        to the fusion group.

        Only the first such path is added; the search returns True as soon as
        it finds one. Every newly added node is queued for processing, and
        `__call__` runs the search again for each processed node, so any
        remaining paths are found on later calls.

        Args:
            fusion_group: The group to grow.
            inputs: Nodes to start the upstream search from.
            visited: Nodes already explored in this search. It is updated in
                place; ``None`` disables the visited check.

        Returns:
            True if a path back into ``fusion_group`` was found and added,
            otherwise False.
        """
        for arg in inputs:
            # skip the node if already seen
            if visited is not None:
                if arg in visited:
                    continue
                visited.add(arg)

            # Skip placeholder and get_attr because they won't be in the fusion group.
            if arg.op not in CALLABLE_NODE_OPS:
                continue

            # If the node has smaller idx, it's already an upstream node of the fusion
            # group. We don't need to check it anymore.
            if self._node_idx(arg) < fusion_group.top_node_idx:
                continue

            # If the node is in the fusion group, return True.
            if arg in fusion_group.nodes:
                return True

            # Check the upstream nodes of the node, if any of them is in the fusion group
            # we'll add this node to fusion group and return True.
            # `inputs` may be `fusion_group.inputs`, which add_node mutates. This is
            # safe only because we return immediately and never resume iterating.
            if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited):
                fusion_group.add_node(arg)
                return True

        return False

    def __call__(self) -> Dict[torch.fx.Node, NodeSet]:
        """
        Compute the fusion groups.

        A group is seeded from each callable acc node that has no
        ``"tensor_meta"`` in its meta and is not yet grouped. The group then
        grows until no member is left to process. It does this by:

        * adding nodes on any path from the group's inputs back into the group;
        * adding the callable users of members that have no ``"tensor_meta"``;
        * adding callable inputs that have no ``"tensor_meta"``.

        Returns:
            Mapping from each node of every group lying fully inside
            ``self.acc_nodes`` to that group's node set. All members of one
            group share the same set object.
        """
        result: Dict[torch.fx.Node, NodeSet] = {}
        # Iterate over a snapshot because self.acc_nodes shrinks below.
        acc_nodes = list(self.acc_nodes)

        for node in acc_nodes:
            if node in result:
                continue
            if node.op not in CALLABLE_NODE_OPS:
                continue
            if "tensor_meta" in node.meta:
                continue
            # May already have been dropped from self.acc_nodes by an earlier group.
            if node not in self.acc_nodes:
                continue

            fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup(
                top_node_idx=self._node_idx(node),
                nodes={node},
                inputs=set(node.all_input_nodes),
                nodes_need_process={node},
            )
            while fusion_group.nodes_need_process:
                current = fusion_group.nodes_need_process.pop()
                self.recursive_add_node(
                    fusion_group,
                    fusion_group.inputs,
                    visited=set(),
                )

                # Optionally add downstream nodes
                if "tensor_meta" not in current.meta:
                    for user in current.users:
                        if user.op not in CALLABLE_NODE_OPS:
                            continue
                        if user in fusion_group.nodes:
                            continue

                        fusion_group.add_node(user)
                        self.recursive_add_node(
                            fusion_group,
                            fusion_group.inputs,
                            visited=set(),
                        )

                # Add some upstream nodes
                for arg in current.all_input_nodes:
                    if arg.op not in CALLABLE_NODE_OPS:
                        continue
                    if "tensor_meta" in arg.meta:
                        continue
                    if arg in fusion_group.nodes:
                        continue

                    fusion_group.add_node(arg)
                    fusion_group.top_node_idx = min(
                        fusion_group.top_node_idx, self._node_idx(arg)
                    )
                    self.recursive_add_node(
                        fusion_group,
                        fusion_group.inputs,
                        visited=set(),
                    )

            if not (fusion_group.nodes <= self.acc_nodes):
                # In-place set difference: this also updates the caller's set.
                self.acc_nodes -= fusion_group.nodes
            else:
                for n in fusion_group.nodes:
                    result[n] = fusion_group.nodes

        return result
@compatibility(is_backward_compatible=False)
def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Replace the graph of the given GraphModule with one containing copies of
    the same nodes, placed in topologically sorted order.

    Transformations such as merge_matmul can leave a GraphModule's nodes out
    of topological order. Running this afterwards restores a valid order
    before further transformation.

    The sort is Kahn's algorithm. Dependencies are taken from each node's
    ``users``. Ready nodes are emitted first-in, first-out, seeded in the
    original graph order.

    Arguments:
        gm: The graph module to topologically sort. It is modified in-place.

    Returns:
        The same ``gm`` object, now holding the sorted graph.

    Raises:
        RuntimeError: If the graph has a cycle, so some nodes can never be emitted.
    """
    old_graph = gm.graph

    # Number of producers each node is still waiting on.
    pending = {n: 0 for n in old_graph.nodes}
    for producer in old_graph.nodes:
        for consumer in producer.users:
            pending[consumer] += 1

    # Start with every node that depends on nothing, in original graph order.
    ready: collections.deque = collections.deque(
        n for n in old_graph.nodes if pending[n] == 0
    )

    sorted_graph = torch.fx.Graph()
    remap: Dict[torch.fx.Node, torch.fx.Node] = {}
    while ready:
        current = ready.popleft()
        remap[current] = sorted_graph.node_copy(current, lambda x: remap[x])
        for consumer in current.users:
            pending[consumer] -= 1
            if pending[consumer] == 0:
                ready.append(consumer)

    # Any node that was never emitted still has unmet dependencies -> a cycle.
    if len(sorted_graph.nodes) < len(old_graph.nodes):
        raise RuntimeError(f"Input graph has cycles, unable to add {[node for node in pending if pending[node] != 0]}")

    sorted_graph._codegen = old_graph._codegen
    gm.graph = sorted_graph
    return gm
|