|
|
|
|
|
import numpy as np |
|
from model import BiSeNet |
|
|
|
import torch |
|
|
|
import os |
|
import os.path as osp |
|
|
|
from PIL import Image |
|
import torchvision.transforms as transforms |
|
import cv2 |
|
from pathlib import Path |
|
import configargparse |
|
import tqdm |
|
|
|
|
|
|
|
def vis_parsing_maps(im, parsing_anno, stride, save_im=False, save_path='vis_results/parsing_map_on_im.jpg',
                     img_size=(512, 512)):
    """Render a class-index parsing map as a flat colour-coded image.

    Each label value is mapped to a fixed 3-channel triplet:

    * ``0``              -> ``(255, 255, 255)``
    * ``1`` .. ``13``    -> ``(255, 0, 0)``
    * ``14`` .. ``15``   -> ``(0, 255, 0)``
    * ``16``             -> ``(0, 0, 255)``
    * ``17`` and above   -> ``(255, 0, 0)``

    The triplets are stored in array channel order and handed to OpenCV
    unchanged, so ``cv2.imwrite`` interprets them as B, G, R.

    Args:
        im: Source image.  It does not influence the rendered output; the
            parameter is kept so existing call sites keep working.
        parsing_anno: Array of integer class labels, expected to be 2-D
            (one label per pixel).  Values are cast to ``uint8``.
        stride: Scale factor applied to both axes of ``parsing_anno``
            (nearest-neighbour) before colouring.
        save_im: When true, write the rendered image to ``save_path``.
        save_path: Destination file passed to ``cv2.imwrite``.
        img_size: Output size passed to ``cv2.resize`` as ``dsize``, i.e.
            ``(width, height)``.

    Returns:
        The rendered ``uint8`` image of shape ``(height, width, 3)``,
        whether or not it was also written to disk.
    """
    # astype() copies, so the caller's array is never modified.
    labels = np.asarray(parsing_anno).astype(np.uint8)
    labels = cv2.resize(
        labels, None, fx=stride, fy=stride, interpolation=cv2.INTER_NEAREST)

    # One row per possible uint8 label; a single fancy-index lookup replaces
    # the per-class np.where passes and cannot overflow on the label maximum.
    palette = np.empty((256, 3), dtype=np.uint8)
    palette[:] = (255, 0, 0)
    palette[0] = (255, 255, 255)
    palette[14:16] = (0, 255, 0)
    palette[16] = (0, 0, 255)

    colored = palette[labels]

    # Nearest-neighbour keeps the output restricted to the palette colours.
    vis_im = cv2.resize(colored, img_size, interpolation=cv2.INTER_NEAREST)

    if save_im:
        cv2.imwrite(save_path, vis_im)

    return vis_im
|
|
|
|
|
def evaluate(respth='./res/test_res', dspth='./data', cp='model_final_diss.pth'):
    """Run BiSeNet face parsing on every image in a directory.

    Each ``.jpg`` / ``.png`` file directly inside ``dspth`` is resized to
    512x512, converted to RGB, normalised and passed through the network.
    The per-pixel arg-max label map is rendered by ``vis_parsing_maps`` at
    the image's original size and written to ``respth`` as ``<stem>.png``.

    A purely numeric file stem is normalised through ``int()`` (so
    ``0007.jpg`` is saved as ``7.png``); any other stem is kept unchanged.

    Args:
        respth: Output directory; created (with parents) if missing.
        dspth: Directory containing the input images.
        cp: Path to the checkpoint holding the network's state dict.
    """
    Path(respth).mkdir(parents=True, exist_ok=True)

    print('[INFO] loading model...')
    # Use the GPU when one is present, otherwise fall back to the CPU so the
    # script still runs on machines without CUDA.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_classes = 19
    net = BiSeNet(n_classes=n_classes)
    net.to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    net.load_state_dict(torch.load(cp, map_location=device))
    net.eval()

    # Per-channel mean / std normalisation applied after tensor conversion.
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])

    # Filter up front so the progress bar counts images only.
    image_names = [
        name for name in os.listdir(dspth)
        if name.endswith(('.jpg', '.png'))
    ]

    with torch.no_grad():
        for image_name in tqdm.tqdm(image_names):
            # The context manager closes the underlying file; resize() and
            # convert() return new in-memory images that outlive it.
            with Image.open(osp.join(dspth, image_name)) as img:
                ori_size = img.size
                image = img.resize((512, 512), Image.BILINEAR)
                image = image.convert("RGB")

            inputs = torch.unsqueeze(to_tensor(image), 0).to(device)
            outputs = net(inputs)
            # assumes outputs is (batch, n_classes, H, W) so that mean(0)
            # drops the batch axis and argmax(0) picks the class — TODO confirm
            parsing = outputs.mean(0).cpu().numpy().argmax(0)

            stem = osp.splitext(image_name)[0]
            try:
                # Numeric frame names are normalised (leading zeros dropped).
                out_stem = str(int(stem))
            except ValueError:
                # Non-numeric names used to abort the whole run; keep them.
                out_stem = stem

            vis_parsing_maps(image, parsing, stride=1, save_im=True,
                             save_path=osp.join(respth, out_stem + '.png'),
                             img_size=ori_size)
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: collect the three paths and run evaluation.
    cli = configargparse.ArgumentParser()

    # (flag, default, help text) for the options that carry a help string.
    documented_options = (
        ('--respath', './result/', 'result path for label'),
        ('--imgpath', './imgs/', 'path for input images'),
    )
    for flag, default_value, help_text in documented_options:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    # The checkpoint option is registered without a help string.
    cli.add_argument('--modelpath', type=str, default='data_utils/face_parsing/79999_iter.pth')

    options = cli.parse_args()
    evaluate(
        respth=options.respath,
        dspth=options.imgpath,
        cp=options.modelpath,
    )
|
|