Spaces:
Runtime error
Runtime error
File size: 2,266 Bytes
0b7b08a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.
import argparse
import os
from loguru import logger
import torch
from yolox.exp import get_exp
def make_parser():
    """Build the command-line parser for the YOLOX TorchScript export script.

    Returns:
        argparse.ArgumentParser: parser exposing the output file name, the
        batch size, the experiment file / experiment name / model name, the
        checkpoint path, the ``--decode_in_inference`` switch, and a trailing
        ``opts`` positional that collects every remaining argument.
    """
    parser = argparse.ArgumentParser("YOLOX torchscript deploy")
    parser.add_argument(
        "--output-name", type=str, default="yolox.torchscript.pt", help="output name of models"
    )
    parser.add_argument("--batch-size", type=int, default=1, help="batch size")
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="experiment description file",
    )
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="ckpt path")
    parser.add_argument(
        "--decode_in_inference",
        action="store_true",
        help="decode in inference or not"
    )
    # Everything left on the command line after the recognised options is
    # gathered verbatim into ``opts``.
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )

    return parser
@logger.catch
def main():
    """Export a YOLOX model to a TorchScript file.

    Parses the command line, builds the experiment via ``get_exp`` and merges
    the trailing ``opts`` into it, loads the checkpoint onto the CPU, traces
    the model with a random dummy input of shape
    ``(batch_size, 3, exp.test_size[0], exp.test_size[1])`` and saves the
    traced module to ``--output-name``.

    Exceptions raised inside this function are handled by the
    ``@logger.catch`` decorator.
    """
    args = make_parser().parse_args()
    logger.info("args value: {}".format(args))
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    # Fall back to the experiment's own name when none was given explicitly.
    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    model = exp.get_model()
    if args.ckpt is None:
        # No explicit checkpoint: use best_ckpt.pth in the experiment's output dir.
        file_name = os.path.join(exp.output_dir, args.experiment_name)
        ckpt_file = os.path.join(file_name, "best_ckpt.pth")
    else:
        ckpt_file = args.ckpt

    # load the model state dict
    ckpt = torch.load(ckpt_file, map_location="cpu")

    model.eval()
    # The checkpoint may wrap the weights under a "model" key; otherwise it is
    # used as the state dict directly.
    if "model" in ckpt:
        ckpt = ckpt["model"]
    model.load_state_dict(ckpt)
    model.head.decode_in_inference = args.decode_in_inference

    logger.info("loading checkpoint done.")
    dummy_input = torch.randn(args.batch_size, 3, exp.test_size[0], exp.test_size[1])

    # Export only needs a forward pass, so skip autograd bookkeeping while tracing.
    with torch.no_grad():
        mod = torch.jit.trace(model, dummy_input)
    mod.save(args.output_name)
    logger.info("generated torchscript model named {}".format(args.output_name))
# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|