""" ONNX optimization script

Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc.

NOTE: This isn't working consistently in recent PyTorch/ONNX combos (i.e. PyTorch 1.6 and ONNX 1.7);
it seems time to switch to the onnxruntime online optimizer, which can also save the optimized model
for offline use. A minimal sketch of that route is included further down in this file.

Copyright 2020 Ross Wightman
"""
import argparse
import warnings

import onnx
from onnx import optimizer


parser = argparse.ArgumentParser(description="Optimize ONNX model")

parser.add_argument("model", help="The ONNX model")
parser.add_argument("--output", required=True, help="The optimized model output filename")
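
# Example invocation (the script and model filenames here are illustrative, not fixed by the script):
#   python onnx_optimize.py ./model.onnx --output ./model-optimized.onnx
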
def traverse_graph(graph, prefix=''):
    # Recursively walk the graph (and any subgraphs) and return the total node count
    # along with a printable summary of every node.
    content = []
    indent = prefix + ' '
    graphs = []
    num_nodes = 0
    for node in graph.node:
        pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True)
        assert isinstance(gs, list)
        content.append(pn)
        graphs.extend(gs)
        num_nodes += 1
    for g in graphs:
        g_count, g_str = traverse_graph(g)
        content.append('\n' + g_str)
        num_nodes += g_count
    return num_nodes, '\n'.join(content)
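

# The NOTE above and the warning in main() point at onnxruntime's own graph optimizer as the likely
# replacement for onnx.optimizer. The helper below is a minimal sketch of that route and is not part
# of the original script: it assumes the `onnxruntime` package is installed, and uses SessionOptions
# to request basic graph optimizations and write the optimized model to disk.
def optimize_with_onnxruntime(model_path, output_path):
    import onnxruntime as ort  # local import so the rest of the script works without onnxruntime
    sess_options = ort.SessionOptions()
    # Apply the basic, hardware-independent graph optimizations.
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    # When set, onnxruntime serializes the optimized graph to this path for offline use.
    sess_options.optimized_model_filepath = output_path
    # Building the session runs the optimizer and writes the optimized model as a side effect.
    ort.InferenceSession(model_path, sess_options)
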
def main():
    args = parser.parse_args()
    onnx_model = onnx.load(args.model)
    num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph)

    # Optimization passes to apply.
    passes = [
        'eliminate_identity',
        'eliminate_nop_dropout',
        'eliminate_nop_pad',
        'eliminate_nop_transpose',
        'eliminate_unused_initializer',
        'extract_constant_to_initializer',
        'fuse_add_bias_into_conv',
        'fuse_bn_into_conv',
        'fuse_consecutive_concats',
        'fuse_consecutive_reduce_unsqueeze',
        'fuse_consecutive_squeezes',
        'fuse_consecutive_transposes',
        'fuse_pad_into_conv',
    ]

    warnings.warn(
        "I've had issues with the optimizer in recent versions of PyTorch / ONNX. "
        "Try onnxruntime optimization if this doesn't work.")
    optimized_model = optimizer.optimize(onnx_model, passes)

    num_optimized_nodes, optimized_graph_str = traverse_graph(optimized_model.graph)
    print('==> The model after optimization:\n{}\n'.format(optimized_graph_str))
    print('==> The optimized model has {} nodes, the original had {}.'.format(
        num_optimized_nodes, num_original_nodes))

    # Save the optimized model.
    onnx.save(optimized_model, args.output)


if __name__ == "__main__":
    main()