from torch.fx import (
    GraphModule,
    Node,
    map_arg
)
from torch.fx.graph import Graph
from .match_utils import (
    _is_match,
    MatchAllNode,
)
from .pattern_utils import (
    _sorted_patterns_dict,
)

from ..backend_config import (
    BackendConfig,
    get_native_backend_config,
)
from ..backend_config.utils import (
    get_fuser_method_mapping,
    get_fusion_pattern_to_root_node_getter,
    get_fusion_pattern_to_extra_inputs_getter,
)

from .custom_config import FuseCustomConfig

from .fuse_handler import (
    _get_fusion_pattern_to_fuse_handler_cls,
    FuseHandler,
)

from typing import Any, Callable, Dict, List, Tuple, Union
import warnings

from torch.ao.quantization.utils import Pattern, NodePattern


__all__ = [
    "fuse",
    "FuseHandler",
]


def fuse(
    model: GraphModule,
    is_qat: bool,
    fuse_custom_config: Union[FuseCustomConfig, Dict[str, Any], None] = None,
    backend_config: Union[BackendConfig, Dict[str, Any], None] = None,
) -> GraphModule:
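    """Fuse patterns of ops in ``model`` (for example conv + bn) into single fused
    modules, following the fusion patterns registered in ``backend_config``.

    If ``backend_config`` is ``None``, the native backend config is used. A new
    ``GraphModule`` is returned whose graph has each matched pattern replaced by the
    output of that pattern's ``FuseHandler``.

    Illustrative sketch only (``MyModel`` is a placeholder, not part of this module)::

        traced = torch.fx.symbolic_trace(MyModel())
        fused = fuse(traced, is_qat=False)
    """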
    if fuse_custom_config is None:
        fuse_custom_config = FuseCustomConfig()

    if isinstance(fuse_custom_config, dict):
        warnings.warn(
            "Passing a fuse_custom_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a FuseCustomConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        fuse_custom_config = FuseCustomConfig.from_dict(fuse_custom_config)

    if isinstance(backend_config, dict):
        warnings.warn(
            "Passing a backend_config_dict to fuse is deprecated and will not be supported "
            "in a future version. Please pass in a BackendConfig instead.",
            FutureWarning,
            stacklevel=2,
        )
        backend_config = BackendConfig.from_dict(backend_config)

    named_modules = dict(model.named_modules())

    if backend_config is None:
        backend_config = get_native_backend_config()

    fusion_pattern_to_fuse_handler_cls = _sorted_patterns_dict(_get_fusion_pattern_to_fuse_handler_cls(backend_config))
    fuser_method_mapping = get_fuser_method_mapping(backend_config)
    fusion_pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
    fusion_pattern_to_extra_inputs_getter = get_fusion_pattern_to_extra_inputs_getter(backend_config)
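
    # First pass: match the backend's fusion patterns against the graph.
    # ``fusion_pairs`` maps node name -> (last node of the match, pattern,
    # matched node pattern, fuse handler instance, node_to_subpattern).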
    fusion_pairs = _find_matches(
        model, model.graph, fusion_pattern_to_fuse_handler_cls)

    # Second pass: rebuild the graph, replacing each matched pattern with the
    # output of its fuse handler and copying everything else unchanged.
    fused_graph = Graph()
    env: Dict[Any, Any] = {}

    def load_arg(a):
        # remap arguments to their already-copied counterparts in the new graph
        return map_arg(a, lambda node: env[node.name])

    def default_root_node_getter(node_pattern):
        # by default, the root is the last Node of the (possibly nested) matched pattern
        while not isinstance(node_pattern[-1], Node):
            node_pattern = node_pattern[-1]
        return node_pattern[-1]

    for node in model.graph.nodes:
        maybe_last_node, pattern, matched_node_pattern, obj, node_to_subpattern = \
            fusion_pairs.get(node.name, (None, None, None, None, None))
        # look up the subpattern this node matched, if it is part of a match
        if node_to_subpattern is not None:
            node_subpattern = node_to_subpattern.get(node, None)
        else:
            node_subpattern = None
        if maybe_last_node is node:
            assert obj is not None
            root_node_getter = fusion_pattern_to_root_node_getter.get(pattern, default_root_node_getter)
            root_node = root_node_getter(matched_node_pattern)
            extra_inputs_getter = fusion_pattern_to_extra_inputs_getter.get(pattern, None)
            extra_inputs = []
            if extra_inputs_getter is not None:
                extra_inputs = extra_inputs_getter(matched_node_pattern)
            # the fuse handler emits the fused node(s) into fused_graph
            env[node.name] = obj.fuse(
                load_arg, named_modules, fused_graph, root_node, extra_inputs, matched_node_pattern,
                fuse_custom_config, fuser_method_mapping, is_qat)
        elif maybe_last_node is None or node_subpattern is MatchAllNode:
            env[node.name] = fused_graph.node_copy(node, load_arg)
        # nodes that belong to a match but are not its last node are dropped here

    model = GraphModule(model, fused_graph)
    return model


def _find_matches(
    root: GraphModule,
    graph: Graph,
    pattern_to_fuse_handler_cls: Dict[Pattern, Callable],
) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]]:
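    """Match the fusion patterns in ``pattern_to_fuse_handler_cls`` against ``graph``.

    Returns a dict mapping each matched node's name to a tuple of
    (last node of the match, pattern, matched node pattern, fuse handler instance,
    node_to_subpattern), where ``node_to_subpattern`` maps matched nodes to the
    subpattern they matched within their fusion pattern.
    """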
    modules = dict(root.named_modules())
    # node name -> (last node of the match, pattern, matched node pattern,
    # fuse handler instance, node_to_subpattern)
    match_map: Dict[
        str, Tuple[Node, Pattern, NodePattern, FuseHandler, Dict[Node, Any]]] = {}
    # node -> the subpattern it matched within its fusion pattern
    node_to_subpattern: Dict[Node, Any] = {}

    def apply_match(pattern, node, match, matched_node_pattern, node_to_subpattern):
        if isinstance(pattern, tuple):
            # nested pattern: match the first element against this node and the
            # remaining elements against the node's input args
            s, *args = pattern
            current_node_pattern: List[Node] = []
            apply_match(s, node, match, current_node_pattern, node_to_subpattern)
            for subpattern, arg in zip(args, node.args):
                apply_match(subpattern, arg, match, current_node_pattern, node_to_subpattern)
            matched_node_pattern.append(tuple(current_node_pattern))
        else:
            # the first pattern that matches a node takes precedence
            if node.name not in match_map:
                matched_node_pattern.append(node)
                # nodes matched by MatchAllNode (wildcard inputs) are not recorded
                # in match_map, so the fuse loop still copies them into the fused graph
                if pattern is not MatchAllNode:
                    node_to_subpattern[node] = pattern
                    root_node, pattern, handler = match
                    match_map[node.name] = (root_node, pattern, matched_node_pattern, handler, node_to_subpattern)

    for node in reversed(graph.nodes):
        if node.name not in match_map:
            for pattern, fuse_handler_cls in pattern_to_fuse_handler_cls.items():
                matched_node_pattern: List[Node] = []
                if _is_match(modules, node, pattern):
                    apply_match(pattern, node, (node, pattern, fuse_handler_cls(node)), matched_node_pattern, node_to_subpattern)
                    break

    return match_map