Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,804 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import argparse
from collections import OrderedDict
import torch
from mmengine.fileio import load
from mmengine.runner import save_checkpoint
def convert(src: str, dst: str, prefix: str = 'd2_model') -> None:
    """Convert a Detectron2 checkpoint to MMDetection style.

    Every entry of the checkpoint's ``'model'`` dict is converted to a
    ``torch.Tensor`` and stored under the key ``f'{prefix}.{name}'``. The
    result is saved to ``dst`` as ``dict(state_dict=..., meta=dict())``.

    Args:
        src (str): The Detectron2 checkpoint path; must end with ``.pkl``.
        dst (str): The MMDetection checkpoint path.
        prefix (str): The prefix prepended to every parameter name,
            defaults to 'd2_model'.

    Raises:
        ValueError: If ``src`` does not end with ``.pkl``, or if the loaded
            checkpoint has no ``'model'`` entry.
    """
    # Validate explicitly rather than with ``assert``: asserts are stripped
    # when Python runs with -O, which would let bad input slip through.
    if not src.endswith('.pkl'):
        raise ValueError(
            'the source Detectron2 checkpoint should end with `.pkl`, '
            f'got {src!r}.')

    # NOTE(review): this unpickles ``src``; only convert trusted files.
    # ``encoding='latin1'`` is forwarded to the loader, presumably so that
    # pickles written under Python 2 (containing numpy arrays) can be read.
    d2_model = load(src, encoding='latin1').get('model')
    if d2_model is None:
        raise ValueError(
            f'no `model` entry found in the Detectron2 checkpoint {src!r}.')

    # Prefix every parameter name with ``prefix``. Presumably the MMDetection
    # model holds the Detectron2 model as a submodule of that name; confirm.
    dst_state_dict = OrderedDict()
    for name, value in d2_model.items():
        # ``as_tensor`` shares memory with numpy arrays (like ``from_numpy``),
        # returns existing tensors unchanged, and also accepts plain scalars
        # or sequences, which ``from_numpy`` would reject.
        dst_state_dict[f'{prefix}.{name}'] = torch.as_tensor(value)

    mmdet_model = dict(state_dict=dst_state_dict, meta=dict())
    save_checkpoint(mmdet_model, dst)
    print(f'Convert Detectron2 model {src} to MMDetection model {dst}')
def main(argv=None):
    """Parse command-line arguments and run ``convert``.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            makes argparse read them from ``sys.argv[1:]``, matching the
            previous behavior.
    """
    parser = argparse.ArgumentParser(
        description='Convert Detectron2 checkpoint to MMDetection style')
    parser.add_argument('src', help='Detectron2 model path')
    parser.add_argument('dst', help='MMDetection model save path')
    parser.add_argument(
        '--prefix', default='d2_model', type=str, help='prefix of the model')
    args = parser.parse_args(argv)
    convert(args.src, args.dst, args.prefix)
if __name__ == '__main__':
    # Run the command-line converter only when this file is executed as a
    # script; importing the module defines the functions without side effects.
    main()
|