# Spaces:
# Build error
# Build error
from signboard_detect import inference_signboard | |
import os | |
import argparse | |
import tqdm | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
def compose(output, mask):
    """Paint the foreground pixels of ``mask`` white in ``output``.

    Every position where ``mask`` is strictly greater than 0.5 is set to 255
    in ``output``; all other positions keep their current value. The mask is
    aligned with the top-left corner of ``output``.

    Args:
        output: Array indexed as ``output[row, col]``. It is modified in place.
        mask: 2-D array of per-pixel scores.

    Returns:
        The same ``output`` object that was passed in.

    Raises:
        ValueError: If ``mask`` is not 2-D (from unpacking ``mask.shape``).
        IndexError: If ``mask`` is larger than ``output``.
    """
    h, w = mask.shape
    # Boolean-mask assignment replaces the per-pixel Python loop with a single
    # NumPy operation. Slicing to the mask's extent keeps a smaller mask
    # confined to the top-left corner of `output`.
    output[:h, :w][mask > 0.5] = 255
    return output
def get_parser():
    """Parse the command-line options of the signboard detection script.

    Despite its name, this function returns the *parsed arguments*, not the
    parser object.

    Returns:
        argparse.Namespace with the string attributes ``input``, ``output``
        and ``checkpoint``, each falling back to its default when the
        corresponding flag is not given.
    """
    # One (flag, default, help) row per option; every option is a plain string.
    option_table = (
        ("--input",
         "./images",
         "A list of space separated input images"),
        ("--output",
         "./output/output_signboard",
         "A list of array of segmentation"),
        ("--checkpoint",
         "./checkpoints/ss/ss.ckpt",
         "File path to best model checkpoint"),
    )

    cli = argparse.ArgumentParser(description="Signboard Detection")
    for flag, fallback, help_text in option_table:
        cli.add_argument(flag, type=str, default=fallback, help=help_text)
    return cli.parse_args()
def handle(args):
    """Run signboard detection on the input image(s) and save the results.

    For every readable image two files are written into ``args.output``:

    * ``<name>_output<ext>``: the image with one rectangle drawn per entry of
      ``result['rois']``;
    * ``<name>_mask<ext>``: a single-channel image that is 255 wherever any
      entry of ``result['masks']`` exceeds 0.5 and 0 elsewhere.

    Args:
        args: Namespace with ``input`` (path to an image file or to a
            directory of images), ``output`` (destination directory) and
            ``checkpoint`` (forwarded to ``inference_signboard``).
            ``args.input`` is replaced by the list of resolved file paths.

    Raises:
        FileNotFoundError: If ``args.input`` is neither an existing directory
            nor an existing file.
    """
    if not args.input:
        return

    if os.path.isdir(args.input):
        args.input = [os.path.join(args.input, fname)
                      for fname in os.listdir(args.input)]
    elif os.path.isfile(args.input):
        args.input = [args.input]
    else:
        # Without this, the loop below would iterate over the characters of
        # the path string.
        raise FileNotFoundError("Input path does not exist: " + args.input)

    # cv2.imwrite does not create missing directories; it just fails.
    os.makedirs(args.output, exist_ok=True)

    for path in tqdm.tqdm(args.input):
        print(path)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports an unreadable / non-image file by returning
            # None; skip it instead of crashing in cvtColor.
            print("Skipping unreadable image: " + path)
            continue

        hei, wid = img.shape[0], img.shape[1]
        print(hei, wid)

        # The detector is given an RGB PIL image; `img` itself stays in
        # OpenCV's BGR order so it can be written back with cv2.imwrite.
        im_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        result = inference_signboard(im_pil, args.checkpoint)
        print(" **************** Result **************** ")
        print(result['rois'].shape)
        print(result['masks'].shape)
        print(result['class_ids'].shape)
        print(result['scores'].shape)
        print(" **************************************** ")

        annotated = img
        for box in result['rois']:
            box = box.tolist()
            # NOTE(review): the box is used as (x1, y1, x2, y2); confirm that
            # this matches the coordinate order inference_signboard returns.
            annotated = cv2.rectangle(
                annotated,
                (int(box[0]), int(box[1])),
                (int(box[2]), int(box[3])),
                (255, 0, 0), 2)

        # Build the output names from the file name only: joining the full
        # input path would add the input's directory below args.output (or
        # discard args.output altogether for an absolute input path).
        stem, ext = os.path.splitext(os.path.basename(path))
        output_path = os.path.join(args.output, stem + "_output" + ext)
        if not cv2.imwrite(output_path, annotated):
            print("Failed to write: " + output_path)

        # Union of all instance masks as one binary image.
        img_output = np.zeros((hei, wid), dtype="uint8")
        for mask in result['masks']:
            img_output = compose(img_output, np.array(mask))
        output_path = os.path.join(args.output, stem + "_mask" + ext)
        if not cv2.imwrite(output_path, img_output):
            print("Failed to write: " + output_path)
def main():
    """Script entry point: parse the command line, then process the inputs."""
    handle(get_parser())


# Only run the pipeline when executed as a script, not when imported.
if __name__ == "__main__":
    main()