""" Caffe2 benchmark script

This script runs the Caffe2 benchmark on an ONNX model that has been exported
to a pair of Caffe2 init/predict protobuf nets.

It is a useful tool for reporting model FLOPS.

Copyright 2020 Ross Wightman
"""
|
import argparse |
|
from caffe2.python import core, workspace, model_helper |
|
from caffe2.proto import caffe2_pb2 |
|
|
|
|
|
# Command-line interface. Either --c2-prefix, or both --c2-init and
# --c2-predict, must be given to locate the exported Caffe2 nets.
# Help strings use %(default)s so they can never disagree with the defaults.
parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark')
parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME',
                    help='caffe2 model .pb path prefix; loads NAME.init.pb and '
                         'NAME.predict.pb (overrides --c2-init/--c2-predict)')
parser.add_argument('--c2-init', default='', type=str, metavar='PATH',
                    help='caffe2 model init .pb')
parser.add_argument('--c2-predict', default='', type=str, metavar='PATH',
                    help='caffe2 model predict .pb')
parser.add_argument('-b', '--batch-size', default=1, type=int,
                    metavar='N', help='mini-batch size (default: %(default)s)')
parser.add_argument('--img-size', default=224, type=int,
                    metavar='N',
                    help='input image height and width (default: %(default)s)')
|
|
|
|
|
def _load_net_def(path):
    """Read and parse a serialized Caffe2 ``NetDef`` protobuf.

    Args:
        path: filesystem path to a binary ``.pb`` file.

    Returns:
        The parsed ``caffe2_pb2.NetDef`` message.
    """
    net_def = caffe2_pb2.NetDef()
    with open(path, 'rb') as f:
        net_def.ParseFromString(f.read())
    return net_def


def main():
    """Load the Caffe2 init/predict nets named on the command line and benchmark them.

    The predict net's first external input is filled with Gaussian noise of
    shape ``(batch_size, 3, img_size, img_size)``, so no dataset is needed.
    The init net is run once, and ``workspace.BenchmarkNet`` then times the
    predict net.
    """
    args = parser.parse_args()
    if args.c2_prefix:
        # A prefix expands to the conventional pair of exported files and
        # takes precedence over --c2-init / --c2-predict.
        args.c2_init = args.c2_prefix + '.init.pb'
        args.c2_predict = args.c2_prefix + '.predict.pb'
    if not args.c2_init or not args.c2_predict:
        # Report a usage error instead of failing later with open('').
        parser.error('specify --c2-prefix, or both --c2-init and --c2-predict')

    # ModelHelper only serves as a container for the two nets loaded below.
    # Parameter initialisation comes from the exported init net.
    model = model_helper.ModelHelper(name='c2_benchmark', init_params=False)
    model.param_init_net = core.Net(_load_net_def(args.c2_init))
    model.net = core.Net(_load_net_def(args.c2_predict))

    # NOTE(review): assumes the predict net's first external input is the image
    # tensor with 3 channels in dim 1 (presumably NCHW). Confirm this for
    # models produced by other exporters.
    input_blob = model.net.external_inputs[0]
    # Append a random-input fill to the init net so that running the init net
    # once yields both the weights and a synthetic input batch.
    model.param_init_net.GaussianFill(
        [],
        input_blob.GetUnscopedName(),
        shape=(args.batch_size, 3, args.img_size, args.img_size),
        mean=0.0,
        std=1.0)
    workspace.RunNetOnce(model.param_init_net)
    workspace.CreateNet(model.net, overwrite=True)
    # Positional args: 5 warm-up runs, 20 main runs, per-operator runs enabled.
    workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True)
|
|
|
|
|
if __name__ == '__main__':
    # Start the benchmark only when executed as a script; importing this
    # module has no side effects beyond building the argument parser.
    main()
|
|