# hrnet_quantized_onnx_eval.py: evaluate a quantized HRNet ONNX model on Cityscapes.
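# Example invocation (all paths below are placeholders, not shipped files):
#   python hrnet_quantized_onnx_eval.py -m HRNet_int.onnx -r ./data/cityscapes/ \
#       -l val.lst --ipu --provider_config vaip_config.json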
import os
import argparse
import onnxruntime
import numpy as np
import torch
from torch.nn import functional as F
from torch.utils import data
import cv2
from tqdm import tqdm
from utils import input_transform, pad_image, resize_image, preprocess, get_confusion_matrix
parser = argparse.ArgumentParser(description='HRNet')
parser.add_argument('-m', '--onnx-model', default='',
type=str, help='Path to onnx model.')
parser.add_argument('-r', '--root', default='',
type=str, help='Path to dataset root.')
parser.add_argument('-l', '--list_path', default='',
type=str, help='Path to dataset list.')
parser.add_argument("--ipu", action="store_true", help="Use IPU for inference.")
parser.add_argument("--provider_config", type=str,
default="vaip_config.json", help="Path of the config file for seting provider_options.")
args = parser.parse_args()
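# Network input resolution (H, W); Cityscapes scores 19 classes and reserves
# label id 255 for pixels excluded from evaluation.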
INPUT_SIZE = [512, 1024]
NUM_CLASSES = 19
IGNORE_LABEL = 255
class Cityscapes(data.Dataset):
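    """Cityscapes validation set: loads image/label pairs from a list file and
    remaps raw label ids to the 19 training classes."""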
def __init__(self,
root,
list_path,
num_classes=19,
downsample_rate=8,
ignore_label=-1):
self.root = root
self.list_path = list_path
self.num_classes = num_classes
self.downsample_rate = downsample_rate
        with open(os.path.join(root, list_path)) as f:
            self.img_list = [line.strip().split() for line in f]
self.files = self.read_files()
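        # Map raw Cityscapes label ids to the 19 train ids; all other ids are ignored.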
self.label_mapping = {-1: ignore_label, 0: ignore_label,
1: ignore_label, 2: ignore_label,
3: ignore_label, 4: ignore_label,
5: ignore_label, 6: ignore_label,
7: 0, 8: 1, 9: ignore_label,
10: ignore_label, 11: 2, 12: 3,
13: 4, 14: ignore_label, 15: ignore_label,
16: ignore_label, 17: 5, 18: ignore_label,
19: 6, 20: 7, 21: 8, 22: 9, 23: 10, 24: 11,
25: 12, 26: 13, 27: 14, 28: 15,
29: ignore_label, 30: ignore_label,
31: 16, 32: 17, 33: 18}
def read_files(self):
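        """Build a list of {img, label, name} records from the list file."""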
files = []
for item in self.img_list:
image_path, label_path = item
name = os.path.splitext(os.path.basename(label_path))[0]
files.append({
"img": image_path,
"label": label_path,
"name": name,
})
return files
def __len__(self):
return len(self.files)
def convert_label(self, label, inverse=False):
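        """Apply label_mapping to a label map; inverse=True maps train ids back
        to the raw Cityscapes ids."""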
temp = label.copy()
if inverse:
for v, k in self.label_mapping.items():
label[temp == k] = v
else:
for k, v in self.label_mapping.items():
label[temp == k] = v
return label
def __getitem__(self, index):
item = self.files[index]
image = cv2.imread(os.path.join(self.root, item["img"]),
cv2.IMREAD_COLOR)
label = cv2.imread(os.path.join(self.root, item["label"]),
cv2.IMREAD_GRAYSCALE)
label = self.convert_label(label)
image, label = self.gen_sample(image, label)
return image.copy(), label.copy()
def gen_sample(self, image, label):
label = self.label_transform(label)
# image = image.transpose((2, 0, 1))
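        # NOTE: downsample_rate is passed to cv2.resize as a scale factor (fx/fy);
        # testval() passes 1, so labels keep full resolution during evaluation.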
if self.downsample_rate != 1:
label = cv2.resize(
label,
None,
fx=self.downsample_rate,
fy=self.downsample_rate,
interpolation=cv2.INTER_NEAREST
)
return image, label
def label_transform(self, label):
return np.array(label).astype('int32')
def run_onnx_inference(ort_session, img):
    """Run a single image through an ONNX Runtime session.
    Args:
        ort_session: ONNX Runtime inference session.
        img (ndarray): HWC BGR image to be inferred.
    Returns:
        ndarray: Model output with the padded region cropped away.
    """
pre_img, pad_h, pad_w = preprocess(img)
    # preprocess() returns a CHW tensor; add a batch dim, then NCHW -> NHWC
    img = np.expand_dims(pre_img, 0)
    img = np.transpose(img, (0, 2, 3, 1))
ort_inputs = {ort_session.get_inputs()[0].name: img}
o1 = ort_session.run(None, ort_inputs)[0]
    # the model emits NHWC (see testval below), so the spatial dims are axes 1 and 2
    h, w = o1.shape[1:3]
    # crop away the padding added by preprocess(), scaled to the output resolution
    h_cut = int(h / INPUT_SIZE[0] * pad_h)
    w_cut = int(w / INPUT_SIZE[1] * pad_w)
    o1 = o1[:, :h - h_cut, :w - w_cut, :]
return o1
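# Usage sketch (the model and image paths are placeholders):
#   session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
#   logits = run_onnx_inference(session, cv2.imread("sample_leftImg8bit.png"))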
def testval(ort_session, root, list_path):
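    """Evaluate the model over the Cityscapes validation list and return
    (mean_IoU, per-class IoU array, pixel accuracy, mean class accuracy)."""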
test_dataset = Cityscapes(
root=root,
list_path=list_path,
num_classes=NUM_CLASSES,
ignore_label=IGNORE_LABEL,
downsample_rate=1)
testloader = torch.utils.data.DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=4,
pin_memory=True)
confusion_matrix = np.zeros(
(NUM_CLASSES, NUM_CLASSES))
for index, batch in enumerate(tqdm(testloader)):
image, label = batch
image = image.numpy()[0]
out = run_onnx_inference(ort_session, image)
size = label.size()
        # output is NHWC; convert to NCHW for interpolation and scoring
        out = out.transpose(0, 3, 1, 2)
        pred = torch.from_numpy(out)
        # upsample logits to the label resolution when the model output is smaller
        if pred.shape[2] != size[1] or pred.shape[3] != size[2]:
            pred = F.interpolate(
                pred, size=size[1:],
                mode='bilinear'
            )
confusion_matrix += get_confusion_matrix(
label,
pred,
size,
NUM_CLASSES,
IGNORE_LABEL)
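    # rows sum to ground-truth pixel counts per class, columns to predicted counts,
    # and the diagonal holds true positives, so IoU = tp / (gt + pred - tp)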
pos = confusion_matrix.sum(1)
res = confusion_matrix.sum(0)
tp = np.diag(confusion_matrix)
pixel_acc = tp.sum()/pos.sum()
mean_acc = (tp/np.maximum(1.0, pos)).mean()
IoU_array = (tp / np.maximum(1.0, pos + res - tp))
mean_IoU = IoU_array.mean()
return mean_IoU, IoU_array, pixel_acc, mean_acc
if __name__ == "__main__":
onnx_path = args.onnx_model
root = args.root
list_path = args.list_path
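    # Pick execution providers: VitisAI when --ipu is set, otherwise CUDA with CPU fallback.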
if args.ipu:
providers = ["VitisAIExecutionProvider"]
provider_options = [{"config_file": args.provider_config}]
else:
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
provider_options = None
ort_session = onnxruntime.InferenceSession(onnx_path, providers=providers, provider_options=provider_options)
mean_IoU, IoU_array, pixel_acc, mean_acc = testval(ort_session, root, list_path)
    msg = 'MeanIoU: {:.4f}, Pixel_Acc: {:.4f}, Mean_Acc: {:.4f}'.format(
        mean_IoU, pixel_acc, mean_acc)
print(msg)