Spaces:
Running
Running
# Copyright (c) Tencent Inc. All rights reserved. | |
import argparse | |
import logging | |
import os | |
import os.path as osp | |
from functools import partial | |
import mmengine | |
import torch.multiprocessing as mp | |
from torch.multiprocessing import Process, set_start_method | |
from mmdeploy.apis import (create_calib_input_data, extract_model, | |
get_predefined_partition_cfg, torch2onnx, | |
torch2torchscript, visualize_model) | |
from mmdeploy.apis.core import PIPELINE_MANAGER | |
from mmdeploy.apis.utils import to_backend | |
from mmdeploy.backend.sdk.export_info import export2SDK | |
from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename, | |
get_ir_config, get_partition_config, | |
get_root_logger, load_config, target_wrapper) | |
def parse_args():
    """Parse the command-line arguments of the model conversion script.

    Returns:
        argparse.Namespace: The parsed arguments. ``deploy_cfg``,
        ``model_cfg``, ``checkpoint`` and ``img`` are required positionals;
        every option falls back to the default given in its
        ``add_argument`` call below.
    """
    parser = argparse.ArgumentParser(description='Export model to backends.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert model')
    parser.add_argument(
        '--test-img',
        default=None,
        type=str,
        nargs='+',
        help='image used to test model')
    parser.add_argument(
        '--work-dir',
        # Evaluated each time parse_args() runs, i.e. the caller's cwd.
        default=os.getcwd(),
        help='the dir to save logs and models')
    parser.add_argument(
        '--calib-dataset-cfg',
        # Implicit concatenation instead of a backslash continuation, so the
        # next source line's indentation is not embedded in the string.
        help=('dataset config path used to calibrate in int8 mode. If not '
              'specified, it will use "val" dataset in model config '
              'instead.'),
        default=None)
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        # NOTE: relies on logging's private name->level table.
        choices=list(logging._nameToLevel.keys()))
    parser.add_argument(
        '--show', action='store_true', help='Show detection outputs')
    parser.add_argument(
        '--dump-info', action='store_true', help='Output information for SDK')
    parser.add_argument(
        '--quant-image-dir',
        default=None,
        help='Image directory for quantize model.')
    parser.add_argument(
        '--quant', action='store_true', help='Quantize model to low bit.')
    parser.add_argument(
        '--uri',
        default='192.168.1.1:60000',
        help='Remote ipv4:port or ipv6:port for inference on edge device.')
    args = parser.parse_args()
    return args
def create_process(name, target, args, kwargs, ret_value=None):
    """Run ``target`` in a child process and block until it finishes.

    The target is wrapped with ``target_wrapper`` together with the current
    root-logger level and ``ret_value``, then started in a
    ``torch.multiprocessing.Process``.

    Args:
        name (str): Human-readable stage name used in the log messages.
        target (Callable): Function to execute in the child process.
        args (tuple): Positional arguments forwarded to the wrapped target.
        kwargs (dict): Keyword arguments forwarded to the wrapped target.
        ret_value (multiprocessing.Value, optional): Shared status value
            handed to ``target_wrapper``. When given, a non-zero value after
            the child exits, or a non-zero child exit code, is treated as
            failure. Defaults to None, in which case the outcome is not
            checked.

    Raises:
        SystemExit: With status 1 when ``ret_value`` is given and the stage
            failed.
    """
    logger = get_root_logger()
    logger.info(f'{name} start.')
    log_level = logger.level
    wrap_func = partial(target_wrapper, target, log_level, ret_value)
    process = Process(target=wrap_func, args=args, kwargs=kwargs)
    process.start()
    process.join()

    if ret_value is None:
        return

    # The same shared value is reused by several stages, so a child that
    # dies before the wrapper ever runs (e.g. a failure while bootstrapping
    # a spawned process) would leave a stale value from an earlier stage.
    # The exit code catches that case.
    if ret_value.value != 0 or process.exitcode != 0:
        logger.error(f'{name} failed.')
        # The builtin ``exit`` is injected by the ``site`` module for the
        # interactive shell; SystemExit is always available.
        raise SystemExit(1)
    logger.info(f'{name} success.')
def torch2ir(ir_type: IR):
    """Look up the function that converts a PyTorch model to ``ir_type``.

    Args:
        ir_type (IR): The intermediate representation to convert to.

    Returns:
        Callable: ``torch2onnx`` for ``IR.ONNX`` or ``torch2torchscript`` for
        ``IR.TORCHSCRIPT``.

    Raises:
        KeyError: If ``ir_type`` is not one of the supported IR types.
    """
    known_converters = (
        (IR.ONNX, torch2onnx),
        (IR.TORCHSCRIPT, torch2torchscript),
    )
    for candidate, converter in known_converters:
        if ir_type == candidate:
            return converter
    raise KeyError(f'Unexpected IR type {ir_type}')
def _partition_ir_model(deploy_cfg, ir_file, work_dir):
    """Split ``ir_file`` according to the deploy config's partition settings.

    Args:
        deploy_cfg: The loaded deploy config.
        ir_file (str): Path of the full IR model written by the IR export.
        work_dir (str): Directory the extracted sub-models are saved to.

    Returns:
        list[str]: ``[ir_file]`` when the deploy config has no partition
        config; otherwise the paths of the extracted sub-models, in the
        order the partition configs are listed.
    """
    partition_cfgs = get_partition_config(deploy_cfg)
    if partition_cfgs is None:
        return [ir_file]

    if 'partition_cfg' in partition_cfgs:
        # Partitions spelled out explicitly in the deploy config.
        partition_cfgs = partition_cfgs.get('partition_cfg', None)
    else:
        # Otherwise use a predefined partition scheme selected by type.
        assert 'type' in partition_cfgs
        partition_cfgs = get_predefined_partition_cfg(
            deploy_cfg, partition_cfgs['type'])

    partitioned_files = []
    for partition_cfg in partition_cfgs:
        save_path = osp.join(work_dir, partition_cfg['save_file'])
        extract_model(
            ir_file,
            partition_cfg['start'],
            partition_cfg['end'],
            dynamic_axes=partition_cfg.get('dynamic_axes', None),
            save_file=save_path)
        partitioned_files.append(save_path)
    return partitioned_files


def _create_rknn_dataset_file(img):
    """Write a one-line RKNN dataset file that lists ``img``.

    Args:
        img (str): Image path written (as an absolute path) to the file.

    Returns:
        str: Path of the created ``.txt`` file. The file is intentionally
        not deleted, because the caller stores its path in the quantization
        config for later use.
    """
    import tempfile

    # Create and write the file through one handle. Reopening the name of an
    # already-deleted NamedTemporaryFile is racy (same issue as mktemp).
    with tempfile.NamedTemporaryFile(
            mode='w', suffix='.txt', delete=False) as f:
        f.writelines([osp.abspath(img)])
        dataset_file = f.name
    return dataset_file


def main():
    """Convert a PyTorch checkpoint to the backend set in the deploy config.

    Steps:
    1. Export the model to the IR (ONNX or TorchScript).
    2. Optionally partition the IR and build calibration data.
    3. Apply backend-specific preprocessing (RKNN, Ascend, VACC).
    4. Convert to the backend, with optional ncnn int8 quantization.
    5. Visualize both the backend and the PyTorch results.

    Stages run through ``create_process`` abort the script when they fail.
    """
    args = parse_args()
    set_start_method('spawn', force=True)
    logger = get_root_logger()
    log_level = logging.getLevelName(args.log_level)
    logger.setLevel(log_level)

    pipeline_funcs = [
        torch2onnx, torch2torchscript, extract_model, create_calib_input_data
    ]
    PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs)
    PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs)

    deploy_cfg_path = args.deploy_cfg
    model_cfg_path = args.model_cfg
    checkpoint_path = args.checkpoint
    quant = args.quant
    quant_image_dir = args.quant_image_dir

    # load deploy_cfg
    deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)

    # create work_dir if not
    mmengine.mkdir_or_exist(osp.abspath(args.work_dir))

    if args.dump_info:
        export2SDK(
            deploy_cfg,
            model_cfg,
            args.work_dir,
            pth=checkpoint_path,
            device=args.device)

    # Single shared status slot reused by every create_process() stage.
    ret_value = mp.Value('d', 0, lock=False)

    # convert to IR
    ir_config = get_ir_config(deploy_cfg)
    ir_save_file = ir_config['save_file']
    ir_type = IR.get(ir_config['type'])
    torch2ir(ir_type)(
        args.img,
        args.work_dir,
        ir_save_file,
        deploy_cfg_path,
        model_cfg_path,
        checkpoint_path,
        device=args.device)

    # partition model (returns the single IR file when no partition is set)
    ir_files = _partition_ir_model(
        deploy_cfg, osp.join(args.work_dir, ir_save_file), args.work_dir)

    # calib data
    calib_filename = get_calib_filename(deploy_cfg)
    if calib_filename is not None:
        calib_path = osp.join(args.work_dir, calib_filename)
        create_calib_input_data(
            calib_path,
            deploy_cfg_path,
            model_cfg_path,
            checkpoint_path,
            dataset_cfg=args.calib_dataset_cfg,
            dataset_type='val',
            device=args.device)

    backend_files = ir_files
    backend = get_backend(deploy_cfg)

    # preprocess deploy_cfg
    if backend == Backend.RKNN:
        # TODO: Add this to task_processor in the future
        from mmdeploy.utils import (get_common_config, get_normalization,
                                    get_quantization_config,
                                    get_rknn_quantization)
        quantization_cfg = get_quantization_config(deploy_cfg)
        common_params = get_common_config(deploy_cfg)
        if get_rknn_quantization(deploy_cfg) is True:
            transform = get_normalization(model_cfg)
            common_params.update(
                dict(
                    mean_values=[transform['mean']],
                    std_values=[transform['std']]))
        # Only create the dataset file when the config does not name one.
        if quantization_cfg.get('dataset', None) is None:
            quantization_cfg['dataset'] = _create_rknn_dataset_file(args.img)

    if backend == Backend.ASCEND:
        # TODO: Add this to backend manager in the future
        if args.dump_info:
            from mmdeploy.backend.ascend import update_sdk_pipeline
            update_sdk_pipeline(args.work_dir)

    if backend == Backend.VACC:
        # TODO: Add this to task_processor in the future
        from onnx2vacc_quant_dataset import get_quant
        from mmdeploy.utils import get_model_inputs

        deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
        model_inputs = get_model_inputs(deploy_cfg)
        # zip() with ir_files only bounds the number of iterations; the IR
        # path itself is not used here.
        for _, model_input in zip(ir_files, model_inputs):
            quant_mode = model_input.get('qconfig', {}).get('dtype', 'fp16')
            assert quant_mode in ['int8', 'fp16'], \
                quant_mode + ' not support now'
            shape_dict = model_input.get('shape', {})
            if quant_mode == 'int8':
                create_process(
                    'vacc quant dataset',
                    target=get_quant,
                    args=(deploy_cfg, model_cfg, shape_dict, checkpoint_path,
                          args.work_dir, args.device),
                    kwargs=dict(),
                    ret_value=ret_value)

    # convert to backend
    PIPELINE_MANAGER.set_log_level(log_level, [to_backend])
    if backend == Backend.TENSORRT:
        PIPELINE_MANAGER.enable_multiprocess(True, [to_backend])
    backend_files = to_backend(
        backend,
        ir_files,
        work_dir=args.work_dir,
        deploy_cfg=deploy_cfg,
        log_level=log_level,
        device=args.device,
        uri=args.uri)

    # ncnn quantization
    if backend == Backend.NCNN and quant:
        from onnx2ncnn_quant_table import get_table
        from mmdeploy.apis.ncnn import get_quant_model_file, ncnn2int8
        # Assumes to_backend returned alternating [param, bin, ...] paths.
        model_param_paths = backend_files[::2]
        model_bin_paths = backend_files[1::2]
        backend_files = []
        for onnx_path, model_param_path, model_bin_path in zip(
                ir_files, model_param_paths, model_bin_paths):
            deploy_cfg, model_cfg = load_config(deploy_cfg_path,
                                                model_cfg_path)
            quant_onnx, quant_table, quant_param, quant_bin = get_quant_model_file(  # noqa: E501
                onnx_path, args.work_dir)

            create_process(
                'ncnn quant table',
                target=get_table,
                args=(onnx_path, deploy_cfg, model_cfg, quant_onnx,
                      quant_table, quant_image_dir, args.device),
                kwargs=dict(),
                ret_value=ret_value)

            create_process(
                'ncnn_int8',
                target=ncnn2int8,
                args=(model_param_path, model_bin_path, quant_table,
                      quant_param, quant_bin),
                kwargs=dict(),
                ret_value=ret_value)
            backend_files += [quant_param, quant_bin]

    if args.test_img is None:
        args.test_img = args.img

    extra = dict(
        backend=backend,
        output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
        show_result=args.show)
    if backend == Backend.SNPE:
        extra['uri'] = args.uri

    # get backend inference result, try render
    create_process(
        f'visualize {backend.value} model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
              args.device),
        kwargs=extra,
        ret_value=ret_value)

    # get pytorch model inference result, try visualize if possible
    create_process(
        'visualize pytorch model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
              args.test_img, args.device),
        kwargs=dict(
            backend=Backend.PYTORCH,
            output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
            show_result=args.show),
        ret_value=ret_value)
    logger.info('All process success.')


if __name__ == '__main__':
    main()