Spaces:
Build error
Build error
File size: 4,823 Bytes
d7a991a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from argparse import ArgumentParser, Namespace
from tempfile import TemporaryDirectory
import mmcv
import torch
from mmcv.runner import CheckpointLoader
try:
from model_archiver.model_packaging import package_model
from model_archiver.model_packaging_utils import ModelExportUtils
except ImportError:
package_model = None
def mmpose2torchserve(config_file: str,
checkpoint_file: str,
output_folder: str,
model_name: str,
model_version: str = '1.0',
force: bool = False):
"""Converts MMPose model (config + checkpoint) to TorchServe `.mar`.
Args:
config_file:
In MMPose config format.
The contents vary for each task repository.
checkpoint_file:
In MMPose checkpoint format.
The contents vary for each task repository.
output_folder:
Folder where `{model_name}.mar` will be created.
The file created will be in TorchServe archive format.
model_name:
If not None, used for naming the `{model_name}.mar` file
that will be created under `output_folder`.
If None, `{Path(checkpoint_file).stem}` will be used.
model_version:
Model's version.
force:
If True, if there is an existing `{model_name}.mar`
file under `output_folder` it will be overwritten.
"""
mmcv.mkdir_or_exist(output_folder)
config = mmcv.Config.fromfile(config_file)
with TemporaryDirectory() as tmpdir:
model_file = osp.join(tmpdir, 'config.py')
config.dump(model_file)
handler_path = osp.join(osp.dirname(__file__), 'mmpose_handler.py')
model_name = model_name or osp.splitext(
osp.basename(checkpoint_file))[0]
# use mmcv CheckpointLoader if checkpoint is not from a local file
if not osp.isfile(checkpoint_file):
ckpt = CheckpointLoader.load_checkpoint(checkpoint_file)
checkpoint_file = osp.join(tmpdir, 'checkpoint.pth')
with open(checkpoint_file, 'wb') as f:
torch.save(ckpt, f)
args = Namespace(
**{
'model_file': model_file,
'serialized_file': checkpoint_file,
'handler': handler_path,
'model_name': model_name,
'version': model_version,
'export_path': output_folder,
'force': force,
'requirements_file': None,
'extra_files': None,
'runtime': 'python',
'archive_format': 'default'
})
manifest = ModelExportUtils.generate_manifest_json(args)
package_model(args, manifest)
def parse_args():
parser = ArgumentParser(
description='Convert MMPose models to TorchServe `.mar` format.')
parser.add_argument('config', type=str, help='config file path')
parser.add_argument('checkpoint', type=str, help='checkpoint file path')
parser.add_argument(
'--output-folder',
type=str,
required=True,
help='Folder where `{model_name}.mar` will be created.')
parser.add_argument(
'--model-name',
type=str,
default=None,
help='If not None, used for naming the `{model_name}.mar`'
'file that will be created under `output_folder`.'
'If None, `{Path(checkpoint_file).stem}` will be used.')
parser.add_argument(
'--model-version',
type=str,
default='1.0',
help='Number used for versioning.')
parser.add_argument(
'-f',
'--force',
action='store_true',
help='overwrite the existing `{model_name}.mar`')
args = parser.parse_args()
return args
if __name__ == '__main__':
args = parse_args()
# Following strings of text style are from colorama package
bright_style, reset_style = '\x1b[1m', '\x1b[0m'
red_text, blue_text = '\x1b[31m', '\x1b[34m'
white_background = '\x1b[107m'
msg = white_background + bright_style + red_text
msg += 'DeprecationWarning: This tool will be deprecated in future. '
msg += blue_text + 'Welcome to use the unified model deployment toolbox '
msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
msg += reset_style
warnings.warn(msg)
if package_model is None:
raise ImportError('`torch-model-archiver` is required.'
'Try: pip install torch-model-archiver')
mmpose2torchserve(args.config, args.checkpoint, args.output_folder,
args.model_name, args.model_version, args.force)
|