"""Run signboard detection over a single image or a directory of images.

For every readable input image two files are written into ``--output``:
``<stem>_output<ext>`` (the image with the predicted boxes drawn on it) and
``<stem>_mask<ext>`` (a single-channel union of all predicted masks).
"""
from signboard_detect import inference_signboard
import os
import argparse
import tqdm
import cv2
import numpy as np
from PIL import Image


def compose(output, mask):
    """Paint 255 into ``output`` wherever ``mask`` is greater than 0.5.

    ``mask`` must be 2-D; only the top-left ``mask.shape`` region of
    ``output`` is considered (a mask larger than ``output`` raises
    IndexError). ``output`` is modified in place and also returned.
    """
    mask = np.asarray(mask)
    h, w = mask.shape
    # Vectorised replacement for a per-pixel Python loop. Basic slicing
    # returns a view, so the masked assignment writes straight into output.
    region = output[:h, :w]
    region[mask > 0.5] = 255
    return output


def get_parser():
    """Parse the command-line arguments and return the resulting namespace."""
    parser = argparse.ArgumentParser(description="Signboard Detection")
    parser.add_argument("--input", type=str, default="./images",
                        help="Path to a single input image or to a directory of images")
    parser.add_argument("--output", type=str, default="./output/output_signboard",
                        help="Directory where the annotated images and masks are written")
    parser.add_argument("--checkpoint", type=str, default="./checkpoints/ss/ss.ckpt",
                        help="File path to best model checkpoint")
    args = parser.parse_args()
    return args


def _collect_inputs(input_arg):
    """Expand the ``--input`` value into a sorted list of file paths.

    Returns an empty list for an empty value, the directory's regular files
    (sorted, sub-directories skipped) for a directory, or ``[input_arg]`` for
    a file. Raises FileNotFoundError if the path does not exist.
    """
    if not input_arg:
        return []
    if os.path.isdir(input_arg):
        # Sorted so the processing order is deterministic across platforms.
        return sorted(
            os.path.join(input_arg, fname)
            for fname in os.listdir(input_arg)
            if os.path.isfile(os.path.join(input_arg, fname))
        )
    if os.path.isfile(input_arg):
        return [input_arg]
    raise FileNotFoundError(f"--input path does not exist: {input_arg}")


def _output_path(out_dir, path, suffix):
    """Return ``<out_dir>/<stem><suffix><ext>`` built from the basename of ``path``."""
    # Use only the basename: joining the full input path would either nest it
    # under out_dir (relative input) or discard out_dir entirely (absolute input).
    stem, ext = os.path.splitext(os.path.basename(path))
    return os.path.join(out_dir, stem + suffix + ext)


def handle(args):
    """Run signboard inference on every input image and save the visualisations.

    ``args.input`` is replaced with the expanded list of input paths.
    ``args.output`` is created if missing. Files that OpenCV cannot decode
    are skipped with a message instead of aborting the run.
    """
    args.input = _collect_inputs(args.input)
    os.makedirs(args.output, exist_ok=True)

    for path in tqdm.tqdm(args.input):
        print(path)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None, not by raising.
            print(f"Skipping unreadable image: {path}")
            continue

        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        im_pil = Image.fromarray(rgb)
        hei, wid = img.shape[:2]
        print(hei, wid)

        result = inference_signboard(im_pil, args.checkpoint)
        print(" **************** Result **************** ")
        print(result['rois'].shape)
        print(result['masks'].shape)
        print(result['class_ids'].shape)
        print(result['scores'].shape)
        print(" **************************************** ")

        # Draw in place on the BGR frame, which is what cv2.imwrite expects;
        # writing the RGB copy would swap the colour channels.
        # NOTE(review): boxes are drawn as (box[0], box[1]) -> (box[2], box[3])
        # in (x, y) order; confirm inference_signboard returns rois in that order.
        for box in result['rois']:
            box = box.tolist()
            cv2.rectangle(img, (int(box[0]), int(box[1])),
                          (int(box[2]), int(box[3])), (255, 0, 0), 2)
        box_path = _output_path(args.output, path, "_output")
        if not cv2.imwrite(box_path, img):
            print(f"Failed to write {box_path}")

        # Union of all predicted masks as a single-channel 0/255 image.
        img_output = np.zeros((hei, wid), dtype="uint8")
        for mask in result['masks']:
            img_output = compose(img_output, np.array(mask))
        mask_path = _output_path(args.output, path, "_mask")
        if not cv2.imwrite(mask_path, img_output):
            print(f"Failed to write {mask_path}")


def main():
    """Command-line entry point: parse arguments and process the inputs."""
    args = get_parser()
    handle(args)


if __name__ == "__main__":
    main()