|
import os
|
|
import torch
|
|
import argparse
|
|
|
|
from easy_ViTPose.vit_models.model import ViTPose
|
|
from easy_ViTPose.vit_utils.util import infer_dataset_by_path, dyn_model_import
|
|
|
|
|
|
# CLI for converting a ViTPose .pth checkpoint to ONNX (and optionally TensorRT).
parser = argparse.ArgumentParser()
parser.add_argument('--model-ckpt', type=str, required=True,
                    help='The torch model that shall be used for conversion')
parser.add_argument('--model-name', type=str, required=True,
                    choices=['s', 'b', 'l', 'h'],
                    help='[s: ViT-S, b: ViT-B, l: ViT-L, h: ViT-H]')
parser.add_argument('--output', type=str, default='ckpts/',
                    help='File (without extension) or dir path for checkpoint output')
# Fixed typo in the help text ('it"s' -> "it's") and removed the backslash
# continuation that embedded stray whitespace into the message.
parser.add_argument('--dataset', type=str, required=False, default=None,
                    help="Name of the dataset. If None it's extracted from the "
                         'file name. ["coco", "coco_25", "wholebody", "mpii", '
                         '"ap10k", "apt36k", "aic"]')
|
|
args = parser.parse_args()

# Resolve the dataset: explicit CLI value wins, otherwise infer it from the
# checkpoint filename (e.g. 'vitpose-b-coco.pth' -> 'coco').
dataset = args.dataset
if dataset is None:
    dataset = infer_dataset_by_path(args.model_ckpt)

# Validate explicitly rather than with `assert`, which is silently stripped
# when Python runs with optimizations (-O).
SUPPORTED_DATASETS = ['mpii', 'coco', 'coco_25', 'wholebody', 'aic', 'ap10k', 'apt36k']
if dataset not in SUPPORTED_DATASETS:
    raise ValueError(
        f'The specified dataset is not valid: {dataset!r}. '
        f'Supported datasets: {SUPPORTED_DATASETS}'
    )

# Dynamically import the model config matching dataset + model size.
model_cfg = dyn_model_import(dataset, args.model_name)
|
|
|
|
|
|
print('>>> Converting to ONNX')

CKPT_PATH = args.model_ckpt
C, H, W = 3, 256, 192  # channels / height / width of the fixed ViTPose input

# Build the model and restore its weights on CPU. Some checkpoints nest the
# weights under a 'state_dict' key (trainer-style dumps); plain ones do not.
model = ViTPose(model_cfg)
loaded = torch.load(CKPT_PATH, map_location='cpu', weights_only=True)
state_dict = loaded.get('state_dict', loaded)
model.load_state_dict(state_dict)
model.eval()
|
|
|
|
# Names of the ONNX graph's inputs/outputs; the batch dimension is marked
# dynamic so the exported model accepts any batch size at inference time.
input_names = ["input_0"]
output_names = ["output_0"]

# Dummy tracing input, created on the same device as the model's parameters.
device = next(model.parameters()).device
inputs = torch.randn(1, C, H, W).to(device)

dynamic_axes = {name: {0: 'batch_size'} for name in ("input_0", "output_0")}
|
|
|
|
def _resolve_onnx_path(model_ckpt: str, output: str) -> str:
    """Return the destination path for the exported ONNX file.

    If *output* is a directory (existing, or spelled with a trailing
    separator) the file is named after the checkpoint with its extension
    swapped to '.onnx' inside that directory. Otherwise *output* is treated
    as a file path and a '.onnx' suffix is appended when missing (the CLI
    help documents "file without extension").
    """
    if os.path.isdir(output) or output.endswith(('/', os.sep)):
        # Join against the directory itself: the previous
        # os.path.dirname(output) dropped the final path component whenever
        # the directory was given without a trailing separator, silently
        # writing the file into the current working directory instead.
        base = os.path.splitext(os.path.basename(model_ckpt))[0] + '.onnx'
        return os.path.join(output, base)
    if not output.endswith('.onnx'):
        output += '.onnx'
    return output


output_onnx = _resolve_onnx_path(args.model_ckpt, args.output)
|
|
|
|
# Export the traced model. torch.onnx.export returns None when given a file
# path, so the previous dead `torch_out = ...` assignment has been dropped.
torch.onnx.export(model, inputs, output_onnx,
                  export_params=True, verbose=False,
                  input_names=input_names, output_names=output_names,
                  dynamic_axes=dynamic_axes)

print(f">>> Saved at: {os.path.abspath(output_onnx)}")
print('=' * 80)
print()
|
|
|
|
# TensorRT conversion is optional: if torch_tensorrt is not installed,
# report it and stop here (the ONNX file has already been written).
try:
    import torch_tensorrt
except ModuleNotFoundError:
    print('>>> TRT module not found, skipping')
    raise SystemExit
|
|
|
|
|
|
print('>>> Converting to TRT')

# Compile the eager model into a TorchScript TensorRT module with a fixed
# batch size of 1 and full fp32 precision.
trt_input = torch_tensorrt.Input(shape=[1, C, H, W], dtype=torch.float32)
trt_ts_module = torch_tensorrt.compile(
    model,
    inputs=[trt_input],
    enabled_precisions={torch.float32},
    workspace_size=1 << 28,  # 256 MiB of builder scratch space
)
|
|
|
|
|
|
# Serialize the compiled TorchScript module next to the ONNX file, swapping
# the extension to '.engine'. The re-assignments of input_names/output_names/
# device that used to sit here were dead code (never read again) and were
# removed.
output_trt = output_onnx.replace('.onnx', '.engine')
torch.jit.save(trt_ts_module, output_trt)

print(f">>> Saved at: {os.path.abspath(output_trt)}")
|
|
|