import collections
import itertools
import logging
from copy import copy
from typing import Dict, Iterable, List, Optional, Sequence, Set

from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, _get_qualified_name
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.fx.passes.utils.fuser_utils import fuse_by_partitions

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class Partition:
    """A set of fx nodes grouped into a single fusion candidate, identified by an id."""

    def __init__(self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None):
        self.id = id
        self.nodes: Set[Node] = set(nodes) if nodes is not None else set()

    def __repr__(self) -> str:
        return str(self.nodes)

    def add_node(self, node: Node):
        self.nodes.add(node)

    def remove_node(self, node: Node):
        self.nodes.remove(node)

    def size(self):
        return len(self.nodes)


class _DependencyViewer:
    """Precomputes, for every node, the transitive sets of upstream and downstream
    nodes. Relies on graph.nodes being in topological order."""

    def __init__(self, graph_module: GraphModule):
        self.upstreams = collections.defaultdict(set)
        self.downstreams = collections.defaultdict(set)

        for node in graph_module.graph.nodes:
            for input_node in node.all_input_nodes:
                # record input_node and everything upstream of input_node
                self.upstreams[node].add(input_node)
                self.upstreams[node].update(self.upstreams[input_node])

        for node in reversed(graph_module.graph.nodes):
            for output_node in node.users:
                # record output_node and everything downstream of output_node
                self.downstreams[node].add(output_node)
                self.downstreams[node].update(self.downstreams[output_node])

    def downstreams_of(self, node: Node) -> Set[Node]:
        return self.downstreams[node]

    def upstreams_of(self, node: Node) -> Set[Node]:
        return self.upstreams[node]
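
# For intuition, on a hypothetical chain graph a -> b -> c (the graph and node
# names are illustrative, not defined in this file), a _DependencyViewer built
# over it would give:
#
#     viewer.upstreams_of(c)    == {a, b}
#     viewer.downstreams_of(a)  == {b, c}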


class CapabilityBasedPartitioner:
    """Groups nodes accepted by `operator_support` into partitions that can be fused
    without introducing a cyclic dependency between partitions."""

    def __init__(self,
                 graph_module: GraphModule,
                 operator_support: OperatorSupportBase,
                 allows_single_node_partition: bool = False,
                 non_compute_ops: Optional[Sequence[str]] = None,
                 allowed_single_node_partition_ops: Optional[Sequence[str]] = None,
                 ) -> None:
        self.graph_module = graph_module
        self.operator_support = operator_support
        self.allows_single_node_partition = allows_single_node_partition
        self.non_compute_ops = non_compute_ops if non_compute_ops is not None else []
        self.allowed_single_node_partition_ops = (
            allowed_single_node_partition_ops
            if allowed_single_node_partition_ops is not None
            else []
        )
        self.dependency_viewer = _DependencyViewer(graph_module)

    def __is_node_supported(self, node: Node) -> bool:
        return (
            self.operator_support.is_node_supported(dict(self.graph_module.named_modules()), node)
        )
    def propose_partitions(self) -> List[Partition]:
        # partition_map is a mapping from partition id to a set of partition ids.
        # The value set contains all the partition ids that can be reached by doing a
        # DFS starting from the partition id in the key.
        partition_map: Dict[int, Set] = collections.defaultdict(set)

        # assumption: nodes in the candidate list are sorted in topological order
        assignment: Dict[Node, int] = {}   # mapping from node to partition_id
        partitions_by_id: Dict[int, Partition] = {}  # mapping from partition_id to partition
        new_partition_id = itertools.count()

        # try to merge partition other_id into partition self_id.
        # the merge only happens if the resulting graph doesn't contain a cyclic
        # dependency; returns `True` when the merge happens, `False` otherwise.
        def maybe_merge_partition(self_id: int, other_id: int):
            # merged_nodes is the union of nodes in the two partitions to be merged
            merged_nodes = copy(partitions_by_id[self_id].nodes)
            merged_nodes.update(partitions_by_id[other_id].nodes)

            def dfs_iter_find_cycle(all_user_nodes: List[Node]):
                for user_node in all_user_nodes:
                    visited_partition_ids = set()

                    for path_node in self.dependency_viewer.downstreams_of(user_node):
                        # If any node on the DFS path from this user node is in the
                        # merged_nodes set, then there is a cycle in the graph.
                        if path_node in merged_nodes:
                            return True

                        # If a node on the DFS path has already been assigned to a
                        # partition, we have to make sure that partition does not form
                        # a cycle with the partitions being merged, i.e. that neither
                        # self_id nor other_id is reachable from it in partition_map.
                        if path_node in assignment:
                            partition_id = assignment[path_node]
                            # If the partition id has already been visited then we know
                            # it doesn't form a cycle with the partitions being merged.
                            if partition_id in visited_partition_ids:
                                continue
                            p_map = partition_map[partition_id]
                            if self_id in p_map or other_id in p_map:
                                return True

                            visited_partition_ids.add(partition_id)

                return False

            # check if the merge would create a cyclic dependency.
            all_user_nodes = []
            for node in merged_nodes:
                for user_node in node.users:
                    if user_node not in merged_nodes:
                        all_user_nodes.append(user_node)

            if dfs_iter_find_cycle(all_user_nodes):
                # a cyclic dependency was found; abort the merge
                return False

            # no cyclic dependency found, move forward with the merge
            # update partition nodes
            partitions_by_id[self_id].nodes = merged_nodes
            # update assignment map
            for node in partitions_by_id[other_id].nodes:
                assignment[node] = self_id
            # delete the other partition
            del partitions_by_id[other_id]

            partition_map[self_id] = partition_map[self_id].union(partition_map[other_id])
            del partition_map[other_id]

            return True
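        # For intuition: given partitions P1 = {a} and P2 = {c} with an unsupported
        # node b in between (a -> b -> c), merging P1 and P2 would make the fused
        # subgraph both produce b's input and consume b's output, i.e. a cycle, so
        # dfs_iter_find_cycle rejects it and maybe_merge_partition is a no-op.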
        def merge_single_node(node: Node, id: Optional[int]):
            def _update_partition_map(node: Node, id: int):
                # Iterate through all the downstream nodes of this node and update the
                # partition map to indicate that there is a path from the partition id
                # of this node to the target partition id.
                downstream_nodes = self.dependency_viewer.downstreams_of(node)
                for curr_node in downstream_nodes:
                    target_id = assignment.get(curr_node, None)
                    if target_id is not None:
                        partition_map[id].add(target_id)

                # Iterate through all the upstream nodes of this node and update the
                # partition map to indicate that there is a path from the partition id
                # of the upstream node to the current node's partition id.
                upstream_nodes = self.dependency_viewer.upstreams_of(node)
                for curr_node in upstream_nodes:
                    source_id = assignment.get(curr_node, None)
                    if source_id is not None:
                        partition_map[source_id].add(id)

            if node in assignment:
                partitions_by_id[assignment[node]].remove_node(node)

            if id is None:
                assignment.pop(node)
            elif id not in partitions_by_id:
                assignment[node] = id
                partitions_by_id[id] = Partition(id=id, nodes=[node])
                _update_partition_map(node, id)
            else:
                assignment[node] = id
                partitions_by_id[id].add_node(node)
                _update_partition_map(node, id)
logger.debug("Proposing partitions...")
for node in reversed(self.graph_module.graph.nodes):
# use Dict as an ordered set to ensure deterministic partitioning result, don't care value
merge_candidates: Dict[int, None] = {}
# Note a limited horizontal fusion is enabled:
# when `node` is not supported, the code below attempts to fuse consumer of `node`.
#
# I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut
# the fusion by adding an `else` block here to skip horizontal fusion.
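            # For example, if `node` is an unsupported op with two supported users
            # u1 and u2 (hypothetical), both users' partitions end up in
            # merge_candidates below and may be fused into a single horizontal
            # partition.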
            if self.__is_node_supported(node) and node not in assignment:
                partition_id = next(new_partition_id)
                merge_single_node(node, partition_id)
                merge_candidates[partition_id] = None

            # merge all possible partitions
            for node in assignment:
                merge_candidates[assignment[node]] = None

            merge_candidates_list = list(merge_candidates.keys())
            if len(merge_candidates_list) > 1:
                self_id = merge_candidates_list[0]
                for other_id in merge_candidates_list[1:]:
                    # note: merge partition `other_id` into partition `self_id` if it
                    # doesn't create a cyclic dependency in the graph; otherwise this
                    # is a no-op
                    maybe_merge_partition(self_id, other_id)
        # post processing to re-assign "getitem" nodes into upstream partitions
        logger.debug("Reassigning getitem nodes to their producer nodes' partitions...")
        nodes_reassignment: Dict[Node, int] = {}
        for node in self.graph_module.graph.nodes:
            is_tuple_output = True
            for user in node.users:
                if user.op != "call_function" or \
                        _get_qualified_name(user.target) != "_operator.getitem":  # type: ignore[arg-type]
                    is_tuple_output = False
                    break

            # node has tuple outputs; re-assign all its getitem users into node's partition
            if is_tuple_output:
                id = assignment.get(node, None)  # type: ignore[arg-type]
                for user in node.users:
                    if assignment.get(user, None) != id:  # type: ignore[arg-type]
                        nodes_reassignment[user] = id  # type: ignore[assignment]
        for node, id in nodes_reassignment.items():
            merge_single_node(node, id)
        # filter out single node partitions
        if not self.allows_single_node_partition:
            logger.debug("Filtering out single node partitions...")
            default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"}
            non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops))
            partitions_to_remove: List[int] = []
            for id, partition in partitions_by_id.items():
                compute_node_count = 0
                for node in partition.nodes:
                    if node.op == "call_function":
                        assert callable(node.target)
                        if _get_qualified_name(node.target) not in non_compute_ops:
                            compute_node_count += 1
                        if _get_qualified_name(node.target) in self.allowed_single_node_partition_ops:
                            compute_node_count += 1
                if compute_node_count <= 1:
                    partitions_to_remove.append(id)
            for id in partitions_to_remove:
                del partitions_by_id[id]

        logger.debug("Partitions proposed:")
        for id, partition in partitions_by_id.items():
            logger.debug("partition #%s: %s", id, [node.name for node in partition.nodes])

        return list(partitions_by_id.values())
    def fuse_partitions(self, partitions: List[Partition]) -> GraphModule:
        logger.debug("Fusing partitions...")
        # fuse_by_partitions expects partitions as List[List[Node]]: [[node0, node1], [node2, node3]]
        return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions])
    # remove non-compute ops that sit at the boundaries of a partition.
    def remove_bookend_non_compute_ops(self, partitions: List[Partition]):
        non_compute_ops = set(self.non_compute_ops)

        def is_non_compute_node(node: Node):
            return node.op == "call_function" and \
                _get_qualified_name(node.target) in non_compute_ops  # type: ignore[arg-type]

        # cache transparent nodes
        transparent_input_nodes: Dict[Node, bool] = {}
        transparent_output_nodes: Dict[Node, bool] = {}

        def is_transparent_input_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
            if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
                return True
            if node in transparent_input_nodes:
                return transparent_input_nodes[node]
            if is_non_compute_node(node):
                for input_n in node.all_input_nodes:
                    if not is_transparent_input_node(input_n, partition, removed_nodes):
                        transparent_input_nodes[node] = False
                        return False
                transparent_input_nodes[node] = True
                return True
            transparent_input_nodes[node] = False
            return False

        def is_transparent_output_node(node: Node, partition: Set[Node], removed_nodes: Set[Node]):
            if node.op == "placeholder" or (node not in partition) or (node in removed_nodes):
                return True
            if node in transparent_output_nodes:
                return transparent_output_nodes[node]
            if is_non_compute_node(node):
                for output_n in node.users:
                    if not is_transparent_output_node(output_n, partition, removed_nodes):
                        transparent_output_nodes[node] = False
                        return False
                transparent_output_nodes[node] = True
                return True
            transparent_output_nodes[node] = False
            return False
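        # Intuition: a non-compute node is "transparent" toward its inputs (outputs)
        # if every path in that direction leaves the partition through nothing but
        # other non-compute nodes. E.g. a leading `getitem -> view` chain feeding the
        # partition's first compute op is transparent on the input side, so removing
        # it cannot strand any compute node inside the partition.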
        for partition in partitions:
            # Note it's ok to use `set` here: we only query whether a node has been
            # removed and NEVER iterate over the nodes inside the set, so the
            # non-deterministic iteration order of a set cannot affect the result.
            remove_node: Set[Node] = set()
            for node in partition.nodes:
                if is_non_compute_node(node) and \
                        (is_transparent_input_node(node, partition.nodes, remove_node) or
                         is_transparent_output_node(node, partition.nodes, remove_node)):
                    remove_node.add(node)

            if len(remove_node) != 0:
                partition.nodes = partition.nodes - remove_node
    def partition_and_fuse(self) -> GraphModule:
        partitions = self.propose_partitions()
        fused_gm = self.fuse_partitions(partitions)
        return fused_gm
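

if __name__ == "__main__":
    # A minimal, self-contained usage sketch (not part of the original file):
    # partition a traced toy module so that only add calls count as supported.
    # The toy module `_Toy` and the policy `_AddOnlySupport` are hypothetical,
    # for illustration only.
    import operator

    import torch
    from torch.fx import symbolic_trace

    class _AddOnlySupport(OperatorSupportBase):
        # Treat only operator.add / torch.add call_function nodes as supported.
        def is_node_supported(self, submodules, node: Node) -> bool:
            return node.op == "call_function" and node.target in (operator.add, torch.add)

    class _Toy(torch.nn.Module):
        def forward(self, x):
            y = x + x          # supported: traces to operator.add
            z = torch.relu(y)  # unsupported under this policy
            return z + y       # supported: traces to operator.add

    gm = symbolic_trace(_Toy())
    partitioner = CapabilityBasedPartitioner(
        gm, _AddOnlySupport(), allows_single_node_partition=True
    )
    # The two adds stay in separate partitions: merging them would have to route
    # around the unsupported relu, which would create a cyclic dependency.
    fused_gm = partitioner.partition_and_fuse()
    print(fused_gm.graph)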