File size: 4,529 Bytes
ff1446e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import os
import argparse
import onnxruntime as ort
from utils import *
# Model/anchor configuration handed to postprocess() together with the raw
# network outputs.
CFG = dict(
    name="mobilenet0.25",
    # presumably anchor sizes per feature level (RetinaFace convention) — TODO confirm in utils
    min_sizes=[[16, 32], [64, 128], [256, 512]],
    # presumably feature-map strides matching min_sizes — TODO confirm in utils
    steps=[8, 16, 32],
    variance=[0.1, 0.2],
    clip=False,
)

# Resize target passed to preprocess(); whether this is (h, w) or (w, h)
# is decided inside preprocess — TODO confirm.
INPUT_SIZE = [608, 640]

# Device handed to preprocess()/postprocess() for their tensor work.
DEVICE = torch.device("cpu")
def vis(img_raw, dets, vis_thres, save_path="./results/result.jpg"):
    """Draw detections onto an image and write the result to disk.

    Args:
        img_raw: original image as returned by cv2.imread; it is drawn on
            in place.
        dets: detections; each row is read as
            [x1, y1, x2, y2, score, lx1, ly1, ..., lx5, ly5]
            (columns 0-14: box corners, score, five landmark points).
        vis_thres: rows whose score (column 4) is below this are skipped.
        save_path: output image path. Its parent directory is created if
            missing. The default is the path previously hard-coded.

    Returns:
        None. The annotated image is written to ``save_path``.

    Raises:
        OSError: if cv2.imwrite reports that the image could not be written.
    """
    # One BGR colour per landmark, in the same order as the original code.
    landmark_colors = (
        (0, 0, 255),
        (0, 255, 255),
        (255, 0, 255),
        (0, 255, 0),
        (255, 0, 0),
    )
    for det in dets:
        score = det[4]
        if score < vis_thres:
            continue
        # Format the score before integer conversion truncates it.
        text = "{:.4f}".format(score)
        b = list(map(int, det))
        cv2.rectangle(img_raw, (b[0], b[1]), (b[2], b[3]), (0, 0, 255), 2)
        # Score label placed 12 px below the box's top-left corner.
        cv2.putText(img_raw, text, (b[0], b[1] + 12),
                    cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))
        # Landmarks are (x, y) pairs at columns 5-6, 7-8, ..., 13-14.
        for k, color in enumerate(landmark_colors):
            x_idx = 5 + 2 * k
            cv2.circle(img_raw, (b[x_idx], b[x_idx + 1]), 1, color, 4)

    # exist_ok avoids the exists()/makedirs() race; skip when save_path has
    # no directory component (makedirs("") would raise).
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # cv2.imwrite signals failure via its return value, not an exception.
    if not cv2.imwrite(save_path, img_raw):
        raise OSError("Failed to write visualization to {}".format(save_path))
def Retinaface_inference(run_ort, args):
    """Run RetinaFace detection on one image with an ONNX Runtime session.

    Args:
        run_ort: ONNX Runtime session; the preprocessed image is fed to its
            first input.
        args: namespace providing ``image_path``, ``confidence_threshold``,
            ``nms_threshold``, ``save_image`` and ``vis_thres``.

    Returns:
        (boxes_list, confidence_list, landm_list) where
            boxes_list = [[left, top, right, bottom], ...]
            confidence_list = [[confidence], ...]
            landm_list = [[landms (10 values)], ...]

    Raises:
        FileNotFoundError: if ``args.image_path`` cannot be read as an image.
    """
    img_raw = cv2.imread(args.image_path, cv2.IMREAD_COLOR)
    # cv2.imread returns None instead of raising on a missing/unreadable
    # file; fail here with a clear message rather than inside preprocess.
    if img_raw is None:
        raise FileNotFoundError(
            "Could not read image: {}".format(args.image_path))

    # preprocess
    img, scale, resize = preprocess(img_raw, INPUT_SIZE, DEVICE)
    # forward
    outputs = run_ort.run(None, {run_ort.get_inputs()[0].name: img})
    # postprocess: dets rows are [box(4), score(1), landmarks(10)]
    dets = postprocess(CFG, img, outputs, scale, resize,
                       args.confidence_threshold, args.nms_threshold, DEVICE)

    # Split detection columns into per-detection Python lists.
    boxes_list = dets[:, :4].tolist()
    confidence_list = dets[:, 4:5].tolist()
    landm_list = dets[:, 5:].tolist()

    # Optionally draw the detections on the original image and save it.
    if args.save_image:
        vis(img_raw, dets, args.vis_thres)
    return boxes_list, confidence_list, landm_list
def main():
    """Parse CLI arguments, create an ONNX Runtime session and run detection."""
    parser = argparse.ArgumentParser(description="Retinaface")
    parser.add_argument(
        "-m",
        "--trained_model",
        default="./weights/RetinaFace_int.onnx",
        type=str,
        help="Path of the ONNX model file to load",
    )
    parser.add_argument(
        "--image_path",
        default="./data/widerface/val/images/18--Concerts/18_Concerts_Concerts_18_38.jpg",
        type=str,
        help="image path",
    )
    parser.add_argument(
        "--confidence_threshold",
        default=0.4,
        type=float,
        help="confidence_threshold"
    )
    parser.add_argument(
        "--nms_threshold",
        default=0.4,
        type=float,
        help="nms_threshold"
    )
    parser.add_argument(
        "-s",
        "--save_image",
        action="store_true",
        default=False,
        help="save visualized detection results to ./results/result.jpg",
    )
    parser.add_argument(
        "--vis_thres",
        default=0.5,
        type=float,
        help="visualization_threshold"
    )
    parser.add_argument(
        "--ipu",
        action="store_true",
        help="Use IPU for inference.",
    )
    parser.add_argument(
        "--provider_config",
        type=str,
        default="vaip_config.json",
        help="Path of the config file for setting provider_options.",
    )
    args = parser.parse_args()

    if args.ipu:
        # The Vitis AI provider takes its settings from the JSON config file.
        providers = ["VitisAIExecutionProvider"]
        provider_options = [{"config_file": args.provider_config}]
    else:
        # Listed in order of preference.
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        provider_options = None

    print("Loading pretrained model from {}".format(args.trained_model))
    run_ort = ort.InferenceSession(args.trained_model, providers=providers,
                                   provider_options=provider_options)
    boxes_list, confidence_list, landm_list = Retinaface_inference(run_ort, args)
    print('inference done!')
    print(boxes_list, confidence_list, landm_list)


if __name__ == '__main__':
    main()
|