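"""License-plate recognition pipeline: GroundingDINO proposes plate boxes,
Segment-Anything refines them to masks, the plate crop is rectified with a
perspective warp, and recognize_char reads the characters."""
# silence third-party warnings so only the pipeline's own output is printed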
import warnings
warnings.filterwarnings("ignore")
from transformers import logging
logging.set_verbosity_error()
import cv2
import numpy as np
from PIL import Image
from glob import glob
from typing import Union
import termcolor
import os
import torch
import torchvision
from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor
from utils.recognize_characters import recognize_char
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# GroundingDINO config and checkpoint
GROUNDING_DINO_CONFIG_PATH = "utils/GroundingDINO_SwinB_cfg.py"
GROUNDING_DINO_CHECKPOINT_PATH = "checkpoints/groundingdino_swinb_cogcoor.pth"
# Segment-Anything checkpoint
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "checkpoints/sam_vit_h_4b8939.pth"
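# both checkpoints are the standard release files (assumed to be downloaded
# into checkpoints/ beforehand): groundingdino_swinb_cogcoor.pth from the
# GroundingDINO release and sam_vit_h_4b8939.pth from Segment-Anything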
# Building GroundingDINO inference model
grounding_dino_model = Model(
    model_config_path=GROUNDING_DINO_CONFIG_PATH,
    model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH,
    device=DEVICE
)
print(f"Using device: {termcolor.colored(DEVICE, 'green')}, model: {termcolor.colored('GroundingDINO', 'green')}, model path: {termcolor.colored(GROUNDING_DINO_CHECKPOINT_PATH, 'green')}")
# Building SAM Model and SAM Predictor
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(DEVICE)
sam_predictor = SamPredictor(sam)
print(f"Using device: {termcolor.colored(DEVICE, 'green')}, model: {termcolor.colored('Segment-Anything', 'green')}, model path: {termcolor.colored(SAM_CHECKPOINT_PATH, 'green')}")
# Detection thresholds and post-processing hyper-parameters for GroundingDINO
BOX_THRESHOLD = 0.25
TEXT_THRESHOLD = 0.25
NMS_THRESHOLD = 0.8
RECTIFIED_W, RECTIFIED_H = 600, 200
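# the 3:1 output roughly matches the aspect ratio of a standard single-row
# plate (e.g. the Chinese 440 x 140 mm format that the 7-character check
# further below assumes)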
# Prompting SAM with detected boxes
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
    """Prompt SAM with each detected box and keep the highest-scoring mask per box."""
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        masks, scores, logits = sam_predictor.predict(
            box=box,
            multimask_output=True
        )
        # multimask_output=True yields three candidate masks; keep the best-scored one
        index = np.argmax(scores)
        result_masks.append(masks[index])
    return np.array(result_masks)
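# Note: set_image() computes the image embedding once, so prompting additional
# boxes on the same image only pays the lightweight mask-decoder cost per box.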
def recognize_plate(image_path: Union[np.ndarray, str], cut_ratio=0.15, save_image=False, print_probs=False):
    """Detect, segment, rectify, and read a license plate from a BGR image array or a file path."""
    if isinstance(image_path, str):
        image = cv2.imread(image_path)
        image_name = os.path.basename(image_path)
    else:
        image = image_path
        image_name = "image.jpg"  # fallback filename so save_image also works for array inputs
    # the extra classes give the detector alternatives for non-plate regions
    CLASSES = ['license plate', 'sky', 'person']
    # detect objects
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=CLASSES,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD
    )
    # NMS post process
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy),
        torch.from_numpy(detections.confidence),
        NMS_THRESHOLD
    ).numpy().tolist()
    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]
    # convert detections to masks (SAM expects RGB, OpenCV loads BGR)
    detections.mask = segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy
    )
    # keep only the 'license plate' masks (class_id == 0)
    result_masks = detections.mask[detections.class_id == 0, :, :]
    result_masks = result_masks.astype(np.uint8)
    # without this guard, result_masks[0] below raises IndexError when no plate is found
    if len(result_masks) == 0:
        raise ValueError("no license plate detected")
    # keep the smallest plate mask, i.e. the tightest fit around the plate
    min_area, min_mask = np.inf, np.zeros_like(result_masks[0])
    for mask in result_masks:
        area = np.sum(mask)
        if area < min_area:
            min_area = area
            min_mask = mask
    # fit a minimum-area rectangle to the mask; np.argwhere yields (row, col) = (y, x) points
    minrect = cv2.minAreaRect(np.argwhere(min_mask))
    box = cv2.boxPoints(minrect)
    box = box.astype(np.intp)  # integer pixel coordinates (np.intp replaces the removed np.int0 alias)
    box[:, [0, 1]] = box[:, [1, 0]]  # swap back to (x, y) order for OpenCV
    # draw the plate outline on the input image
    cv2.drawContours(image, [box], 0, (0, 0, 255), 2)
    if save_image:
        os.makedirs("contours", exist_ok=True)
        cv2.imwrite(f"contours/{image_name}", image)
    # order the corners as top-left, bottom-left, bottom-right, top-right
    # to match the destination quad of the perspective transform below
    box = box[np.argsort(box[:, 0])]
    if box[0, 1] > box[1, 1]:
        box[[0, 1], :] = box[[1, 0], :]
    if box[2, 1] < box[3, 1]:
        box[[2, 3], :] = box[[3, 2], :]
    # make the traversal short-long-short-long so the plate's long side maps to the output width
    if np.linalg.norm(box[0] - box[1]) > np.linalg.norm(box[1] - box[2]):
        box[[1, 3], :] = box[[3, 1], :]
    # cut out the license plate and rectify it with a perspective warp
    dst_corners = np.array([[0, 0], [0, RECTIFIED_H], [RECTIFIED_W, RECTIFIED_H], [RECTIFIED_W, 0]], dtype=np.float32)
    warp = cv2.getPerspectiveTransform(box.astype(np.float32), dst_corners)
    rectified_plate = cv2.warpPerspective(image, warp, (RECTIFIED_W, RECTIFIED_H))
    # the corner ordering can mistake top for bottom, which mirrors the crop
    # vertically, so keep a flipped copy as a second hypothesis
    rectified_plate_flip = cv2.flip(rectified_plate, 0)
    if save_image:
        os.makedirs("rectified_plate", exist_ok=True)
        cv2.imwrite(f"rectified_plate/{image_name}", rectified_plate)
        stem, ext = os.path.splitext(image_name)
        cv2.imwrite(f"rectified_plate/{stem}_flip{ext}", rectified_plate_flip)
    # recognize characters on both orientations
    result = recognize_char(Image.fromarray(rectified_plate), cut_ratio=cut_ratio, print_probs=print_probs)
    result['rectified_plate'] = rectified_plate
    result_flip = recognize_char(Image.fromarray(rectified_plate_flip), cut_ratio=cut_ratio, print_probs=print_probs)
    result_flip['rectified_plate'] = rectified_plate_flip
    # prefer the flipped read only when it is a full 7-character plate with higher confidence
    if len(result_flip['plate']) == 7 and result_flip["confidence"] > result["confidence"]:
        result = result_flip
    result['detection'] = image
    return result
if __name__ == "__main__":
    image_dir = "images"
    image_list = glob(f"{image_dir}/*.jpg") + glob(f"{image_dir}/*.png") + glob(f"{image_dir}/*.jpeg")
    for image_path in image_list:
        result = recognize_plate(image_path, save_image=True, print_probs=True)
        print(f"Image path: {termcolor.colored(os.path.basename(image_path), 'green')} Recognized: {termcolor.colored(result['plate'], 'blue')}")