|
|
|
|
|
import hashlib |
|
import torch |
|
import torch.fx |
|
from typing import Any, Dict, Optional, TYPE_CHECKING |
|
from torch.fx.node import _get_qualified_name, _format_arg |
|
from torch.fx.graph import _parse_stack_trace |
|
from torch.fx.passes.shape_prop import TensorMetadata |
|
from torch.fx._compatibility import compatibility |
|
from itertools import chain |
|
|
|
__all__ = ['FxGraphDrawer'] |
|
try: |
|
import pydot |
|
HAS_PYDOT = True |
|
except ImportError: |
|
HAS_PYDOT = False |
|
|
|
_COLOR_MAP = { |
|
"placeholder": '"AliceBlue"', |
|
"call_module": "LemonChiffon1", |
|
"get_param": "Yellow2", |
|
"get_attr": "LightGrey", |
|
"output": "PowderBlue", |
|
} |
|
|
|
_HASH_COLOR_MAP = [ |
|
"CadetBlue1", |
|
"Coral", |
|
"DarkOliveGreen1", |
|
"DarkSeaGreen1", |
|
"GhostWhite", |
|
"Khaki1", |
|
"LavenderBlush1", |
|
"LightSkyBlue", |
|
"MistyRose1", |
|
"MistyRose2", |
|
"PaleTurquoise2", |
|
"PeachPuff1", |
|
"Salmon", |
|
"Thistle1", |
|
"Thistle3", |
|
"Wheat1", |
|
] |
|
|
|
_WEIGHT_TEMPLATE = { |
|
"fillcolor": "Salmon", |
|
"style": '"filled,rounded"', |
|
"fontcolor": "#000000", |
|
} |
|
|
|
if HAS_PYDOT: |
|
@compatibility(is_backward_compatible=False) |
|
class FxGraphDrawer: |
|
""" |
|
Visualize a torch.fx.Graph with graphviz |
|
Basic usage: |
|
g = FxGraphDrawer(symbolic_traced, "resnet18") |
|
g.get_dot_graph().write_svg("a.svg") |
|
""" |
|
|
|
def __init__( |
|
self, |
|
graph_module: torch.fx.GraphModule, |
|
name: str, |
|
ignore_getattr: bool = False, |
|
ignore_parameters_and_buffers: bool = False, |
|
skip_node_names_in_args: bool = True, |
|
parse_stack_trace: bool = False, |
|
dot_graph_shape: Optional[str] = None, |
|
): |
|
self._name = name |
|
self.dot_graph_shape = ( |
|
dot_graph_shape if dot_graph_shape is not None else "record" |
|
) |
|
_WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape |
|
|
|
self._dot_graphs = { |
|
name: self._to_dot( |
|
graph_module, name, ignore_getattr, ignore_parameters_and_buffers, skip_node_names_in_args, parse_stack_trace |
|
) |
|
} |
|
|
|
for node in graph_module.graph.nodes: |
|
if node.op != "call_module": |
|
continue |
|
|
|
leaf_node = self._get_leaf_node(graph_module, node) |
|
|
|
if not isinstance(leaf_node, torch.fx.GraphModule): |
|
continue |
|
|
|
|
|
self._dot_graphs[f"{name}_{node.target}"] = self._to_dot( |
|
leaf_node, |
|
f"{name}_{node.target}", |
|
ignore_getattr, |
|
ignore_parameters_and_buffers, |
|
skip_node_names_in_args, |
|
parse_stack_trace, |
|
) |
|
|
|
def get_dot_graph(self, submod_name=None) -> pydot.Dot: |
|
""" |
|
Visualize a torch.fx.Graph with graphviz |
|
Example: |
|
>>> # xdoctest: +REQUIRES(module:pydot) |
|
>>> # xdoctest: +REQUIRES(module:ubelt) |
|
>>> # define module |
|
>>> class MyModule(torch.nn.Module): |
|
>>> def __init__(self): |
|
>>> super().__init__() |
|
>>> self.linear = torch.nn.Linear(4, 5) |
|
>>> def forward(self, x): |
|
>>> return self.linear(x).clamp(min=0.0, max=1.0) |
|
>>> module = MyModule() |
|
>>> # trace the module |
|
>>> symbolic_traced = torch.fx.symbolic_trace(module) |
|
>>> # setup output file |
|
>>> import ubelt as ub |
|
>>> dpath = ub.Path.appdir('torch/tests/FxGraphDrawer').ensuredir() |
|
>>> fpath = dpath / 'linear.svg' |
|
>>> # draw the graph |
|
>>> g = FxGraphDrawer(symbolic_traced, "linear") |
|
>>> g.get_dot_graph().write_svg(fpath) |
|
""" |
|
if submod_name is None: |
|
return self.get_main_dot_graph() |
|
else: |
|
return self.get_submod_dot_graph(submod_name) |
|
|
|
def get_main_dot_graph(self) -> pydot.Dot: |
|
return self._dot_graphs[self._name] |
|
|
|
def get_submod_dot_graph(self, submod_name) -> pydot.Dot: |
|
return self._dot_graphs[f"{self._name}_{submod_name}"] |
|
|
|
def get_all_dot_graphs(self) -> Dict[str, pydot.Dot]: |
|
return self._dot_graphs |
|
|
|
def _get_node_style(self, node: torch.fx.Node) -> Dict[str, str]: |
|
|
|
template = { |
|
"shape": self.dot_graph_shape, |
|
"fillcolor": "#CAFFE3", |
|
"style": '"filled,rounded"', |
|
"fontcolor": "#000000", |
|
} |
|
if node.op in _COLOR_MAP: |
|
template["fillcolor"] = _COLOR_MAP[node.op] |
|
else: |
|
|
|
target_name = node._pretty_print_target(node.target) |
|
target_hash = int(hashlib.md5(target_name.encode()).hexdigest()[:8], 16) |
|
template["fillcolor"] = _HASH_COLOR_MAP[target_hash % len(_HASH_COLOR_MAP)] |
|
return template |
|
|
|
def _get_leaf_node( |
|
self, module: torch.nn.Module, node: torch.fx.Node |
|
) -> torch.nn.Module: |
|
py_obj = module |
|
assert isinstance(node.target, str) |
|
atoms = node.target.split(".") |
|
for atom in atoms: |
|
if not hasattr(py_obj, atom): |
|
raise RuntimeError( |
|
str(py_obj) + " does not have attribute " + atom + "!" |
|
) |
|
py_obj = getattr(py_obj, atom) |
|
return py_obj |
|
|
|
def _typename(self, target: Any) -> str: |
|
if isinstance(target, torch.nn.Module): |
|
ret = torch.typename(target) |
|
elif isinstance(target, str): |
|
ret = target |
|
else: |
|
ret = _get_qualified_name(target) |
|
|
|
|
|
|
|
|
|
return ret.replace("{", r"\{").replace("}", r"\}") |
|
|
|
|
|
|
|
|
|
def _shorten_file_name( |
|
self, |
|
full_file_name: str, |
|
truncate_to_last_n: int = 2, |
|
): |
|
splits = full_file_name.split('/') |
|
if len(splits) >= truncate_to_last_n: |
|
return '/'.join(splits[-truncate_to_last_n:]) |
|
return full_file_name |
|
|
|
|
|
def _get_node_label( |
|
self, |
|
module: torch.fx.GraphModule, |
|
node: torch.fx.Node, |
|
skip_node_names_in_args: bool, |
|
parse_stack_trace: bool, |
|
) -> str: |
|
def _get_str_for_args_kwargs(arg): |
|
if isinstance(arg, tuple): |
|
prefix, suffix = r"|args=(\l", r",\n)\l" |
|
arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg] |
|
elif isinstance(arg, dict): |
|
prefix, suffix = r"|kwargs={\l", r",\n}\l" |
|
arg_strs_list = [ |
|
f"{k}: {_format_arg(v, max_list_len=8)}" |
|
for k, v in arg.items() |
|
] |
|
else: |
|
return "" |
|
|
|
|
|
if skip_node_names_in_args: |
|
arg_strs_list = [a for a in arg_strs_list if "%" not in a] |
|
if len(arg_strs_list) == 0: |
|
return "" |
|
arg_strs = prefix + r",\n".join(arg_strs_list) + suffix |
|
if len(arg_strs_list) == 1: |
|
arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "") |
|
return arg_strs.replace("{", r"\{").replace("}", r"\}") |
|
|
|
|
|
label = "{" + f"name=%{node.name}|op_code={node.op}\n" |
|
|
|
if node.op == "call_module": |
|
leaf_module = self._get_leaf_node(module, node) |
|
label += r"\n" + self._typename(leaf_module) + r"\n|" |
|
extra = "" |
|
if hasattr(leaf_module, "__constants__"): |
|
extra = r"\n".join( |
|
[f"{c}: {getattr(leaf_module, c)}" for c in leaf_module.__constants__] |
|
) |
|
label += extra + r"\n" |
|
else: |
|
label += f"|target={self._typename(node.target)}" + r"\n" |
|
if len(node.args) > 0: |
|
label += _get_str_for_args_kwargs(node.args) |
|
if len(node.kwargs) > 0: |
|
label += _get_str_for_args_kwargs(node.kwargs) |
|
label += f"|num_users={len(node.users)}" + r"\n" |
|
|
|
tensor_meta = node.meta.get('tensor_meta') |
|
label += self._tensor_meta_to_label(tensor_meta) |
|
|
|
|
|
|
|
buf_meta = node.meta.get('buf_meta', None) |
|
if buf_meta is not None: |
|
label += f"|buf={buf_meta.name}" + r"\n" |
|
label += f"|n_origin={buf_meta.n_origin}" + r"\n" |
|
|
|
|
|
|
|
if parse_stack_trace and node.stack_trace is not None: |
|
parsed_stack_trace = _parse_stack_trace(node.stack_trace) |
|
fname = self._shorten_file_name(parsed_stack_trace.file) |
|
label += f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + r"\n" |
|
|
|
|
|
return label + "}" |
|
|
|
def _tensor_meta_to_label(self, tm) -> str: |
|
if tm is None: |
|
return "" |
|
elif isinstance(tm, TensorMetadata): |
|
return self._stringify_tensor_meta(tm) |
|
elif isinstance(tm, list): |
|
result = "" |
|
for item in tm: |
|
result += self._tensor_meta_to_label(item) |
|
return result |
|
elif isinstance(tm, dict): |
|
result = "" |
|
for v in tm.values(): |
|
result += self._tensor_meta_to_label(v) |
|
return result |
|
elif isinstance(tm, tuple): |
|
result = "" |
|
for item in tm: |
|
result += self._tensor_meta_to_label(item) |
|
return result |
|
else: |
|
raise RuntimeError(f"Unsupported tensor meta type {type(tm)}") |
|
|
|
def _stringify_tensor_meta(self, tm: TensorMetadata) -> str: |
|
result = "" |
|
if not hasattr(tm, "dtype"): |
|
print("tm", tm) |
|
result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n" |
|
result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n" |
|
result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n" |
|
result += "|" + "stride" + "=" + str(tm.stride) + r"\n" |
|
if tm.is_quantized: |
|
assert tm.qparams is not None |
|
assert "qscheme" in tm.qparams |
|
qscheme = tm.qparams["qscheme"] |
|
if qscheme in { |
|
torch.per_tensor_affine, |
|
torch.per_tensor_symmetric, |
|
}: |
|
result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n" |
|
result += "|" + "q_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" |
|
elif qscheme in { |
|
torch.per_channel_affine, |
|
torch.per_channel_symmetric, |
|
torch.per_channel_affine_float_qparams, |
|
}: |
|
result += "|" + "q_per_channel_scale" + "=" + str(tm.qparams["scale"]) + r"\n" |
|
result += "|" + "q_per_channel_zero_point" + "=" + str(tm.qparams["zero_point"]) + r"\n" |
|
result += "|" + "q_per_channel_axis" + "=" + str(tm.qparams["axis"]) + r"\n" |
|
else: |
|
raise RuntimeError(f"Unsupported qscheme: {qscheme}") |
|
result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n" |
|
return result |
|
|
|
def _get_tensor_label(self, t: torch.Tensor) -> str: |
|
return str(t.dtype) + str(list(t.shape)) + r"\n" |
|
|
|
|
|
|
|
def _to_dot( |
|
self, |
|
graph_module: torch.fx.GraphModule, |
|
name: str, |
|
ignore_getattr: bool, |
|
ignore_parameters_and_buffers: bool, |
|
skip_node_names_in_args: bool, |
|
parse_stack_trace: bool, |
|
) -> pydot.Dot: |
|
""" |
|
Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph. |
|
If ignore_parameters_and_buffers is True, the parameters and buffers |
|
created with the module will not be added as nodes and edges. |
|
""" |
|
|
|
|
|
dot_graph = pydot.Dot(name, rankdir="TB") |
|
|
|
|
|
buf_name_to_subgraph = {} |
|
|
|
for node in graph_module.graph.nodes: |
|
if ignore_getattr and node.op == "get_attr": |
|
continue |
|
|
|
style = self._get_node_style(node) |
|
dot_node = pydot.Node( |
|
node.name, label=self._get_node_label(graph_module, node, skip_node_names_in_args, parse_stack_trace), **style |
|
) |
|
|
|
current_graph = dot_graph |
|
|
|
buf_meta = node.meta.get('buf_meta', None) |
|
if buf_meta is not None and buf_meta.n_origin > 1: |
|
buf_name = buf_meta.name |
|
if buf_name not in buf_name_to_subgraph: |
|
buf_name_to_subgraph[buf_name] = pydot.Cluster(buf_name, label=buf_name) |
|
current_graph = buf_name_to_subgraph.get(buf_name) |
|
|
|
current_graph.add_node(dot_node) |
|
|
|
def get_module_params_or_buffers(): |
|
for pname, ptensor in chain( |
|
leaf_module.named_parameters(), leaf_module.named_buffers() |
|
): |
|
pname1 = node.name + "." + pname |
|
label1 = ( |
|
pname1 + "|op_code=get_" + "parameter" |
|
if isinstance(ptensor, torch.nn.Parameter) |
|
else "buffer" + r"\l" |
|
) |
|
dot_w_node = pydot.Node( |
|
pname1, |
|
label="{" + label1 + self._get_tensor_label(ptensor) + "}", |
|
**_WEIGHT_TEMPLATE, |
|
) |
|
dot_graph.add_node(dot_w_node) |
|
dot_graph.add_edge(pydot.Edge(pname1, node.name)) |
|
|
|
if node.op == "call_module": |
|
leaf_module = self._get_leaf_node(graph_module, node) |
|
|
|
if not ignore_parameters_and_buffers and not isinstance(leaf_module, torch.fx.GraphModule): |
|
get_module_params_or_buffers() |
|
|
|
for subgraph in buf_name_to_subgraph.values(): |
|
subgraph.set('color', 'royalblue') |
|
subgraph.set('penwidth', '2') |
|
dot_graph.add_subgraph(subgraph) |
|
|
|
for node in graph_module.graph.nodes: |
|
if ignore_getattr and node.op == "get_attr": |
|
continue |
|
|
|
for user in node.users: |
|
dot_graph.add_edge(pydot.Edge(node.name, user.name)) |
|
|
|
return dot_graph |
|
|
|
else: |
|
if not TYPE_CHECKING: |
|
@compatibility(is_backward_compatible=False) |
|
class FxGraphDrawer: |
|
def __init__( |
|
self, |
|
graph_module: torch.fx.GraphModule, |
|
name: str, |
|
ignore_getattr: bool = False, |
|
ignore_parameters_and_buffers: bool = False, |
|
skip_node_names_in_args: bool = True, |
|
parse_stack_trace: bool = False, |
|
dot_graph_shape: Optional[str] = None, |
|
): |
|
raise RuntimeError('FXGraphDrawer requires the pydot package to be installed. Please install ' |
|
'pydot through your favorite Python package manager.') |
|
|