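"""License plate detection and recognition pipeline.

GroundingDINO proposes boxes for the prompted classes, Segment-Anything (SAM)
converts the boxes into masks, the smallest plate mask's minimum-area rectangle
is rectified with a perspective warp, and
`utils.recognize_characters.recognize_char` reads the characters from the crop.
"""
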
import warnings
warnings.filterwarnings("ignore")
from transformers import logging
logging.set_verbosity_error()

import cv2
import numpy as np
from PIL import Image
from glob import glob
from typing import Union
import termcolor
import os

import torch
import torchvision

from groundingdino.util.inference import Model
from segment_anything import sam_model_registry, SamPredictor
from utils.recognize_characters import recognize_char

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# GroundingDINO config and checkpoint
GROUNDING_DINO_CONFIG_PATH = "utils/GroundingDINO_SwinB_cfg.py"
GROUNDING_DINO_CHECKPOINT_PATH = "checkpoints/groundingdino_swinb_cogcoor.pth"

# Segment-Anything checkpoint
SAM_ENCODER_VERSION = "vit_h"
SAM_CHECKPOINT_PATH = "checkpoints/sam_vit_h_4b8939.pth"

# Building GroundingDINO inference model
grounding_dino_model = Model(model_config_path=GROUNDING_DINO_CONFIG_PATH, model_checkpoint_path=GROUNDING_DINO_CHECKPOINT_PATH, device=DEVICE)
print(f"Using device: {termcolor.colored(DEVICE, 'green')}, model: {termcolor.colored('GroundingDINO', 'green')}, model path: {termcolor.colored(GROUNDING_DINO_CHECKPOINT_PATH, 'green')}")

# Building SAM Model and SAM Predictor
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(DEVICE)
sam_predictor = SamPredictor(sam)
print(f"Using device: {termcolor.colored(DEVICE, 'green')}, model: {termcolor.colored('Segment-Anything', 'green')}, model path: {termcolor.colored(SAM_CHECKPOINT_PATH, 'green')}")


# Hyper-parameters for GroundingDINO detection and plate rectification
BOX_THRESHOLD = 0.25   # minimum box confidence for GroundingDINO detections
TEXT_THRESHOLD = 0.25  # minimum text-to-box matching score
NMS_THRESHOLD = 0.8    # IoU threshold for non-maximum suppression
RECTIFIED_W, RECTIFIED_H = 600, 200  # canvas size of the rectified plate crop


# Prompting SAM with detected boxes
def segment(sam_predictor: SamPredictor, image: np.ndarray, xyxy: np.ndarray) -> np.ndarray:
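    """Prompt SAM with each box in `xyxy` and keep the highest-scoring mask per box."""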
    sam_predictor.set_image(image)
    result_masks = []
    for box in xyxy:
        masks, scores, _ = sam_predictor.predict(
            box=box,
            multimask_output=True
        )
        index = np.argmax(scores)
        result_masks.append(masks[index])
    return np.array(result_masks)


def recognize_plate(image_path: Union[np.ndarray, str], cut_ratio=0.15, save_image=False, print_probs=False):
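    """Detect, rectify, and read a license plate from a BGR image array or an image path.

    Returns the `recognize_char` result dict, extended with the rectified plate
    crop under 'rectified_plate' and the annotated input under 'detection'.
    """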
    if isinstance(image_path, str):
        image = cv2.imread(image_path)
    else:
        image = image_path
    # 'sky' and 'person' likely serve as contrast prompts so that GroundingDINO
    # does not assign every salient region to 'license plate'
    CLASSES = ['license plate', 'sky', 'person']

    # detect objects
    detections = grounding_dino_model.predict_with_classes(
        image=image,
        classes=CLASSES,
        box_threshold=BOX_THRESHOLD,
        text_threshold=TEXT_THRESHOLD
    )

    # NMS post process
    nms_idx = torchvision.ops.nms(
        torch.from_numpy(detections.xyxy), 
        torch.from_numpy(detections.confidence), 
        NMS_THRESHOLD
    ).numpy().tolist()

    detections.xyxy = detections.xyxy[nms_idx]
    detections.confidence = detections.confidence[nms_idx]
    detections.class_id = detections.class_id[nms_idx]

    # convert detections to masks
    detections.mask = segment(
        sam_predictor=sam_predictor,
        image=cv2.cvtColor(image, cv2.COLOR_BGR2RGB),
        xyxy=detections.xyxy
    )

    # keep only the masks detected as class 0 ('license plate')
    result_masks = detections.mask[(detections.class_id == 0), :, :]
    if len(result_masks) == 0:
        raise ValueError("no license plate detected")
    result_masks = result_masks.astype(np.uint8)

    # pick the mask with the smallest area among the plate detections
    min_area, min_mask = np.inf, np.zeros_like(result_masks[0])
    for mask in result_masks:
        area = np.sum(mask)
        if area < min_area:
            min_area = area
            min_mask = mask

    # fit a minimum-area rotated rectangle to the mask; np.argwhere yields
    # (row, col) points, so the columns are swapped back to (x, y) afterwards
    minrect = cv2.minAreaRect(np.argwhere(min_mask).astype(np.int32))
    box = cv2.boxPoints(minrect)
    box = np.intp(box)  # np.intp replaces np.int0, which was removed in NumPy 2.0
    box[:, [0, 1]] = box[:, [1, 0]]
    # draw box
    cv2.drawContours(image, [box], 0, (0, 0, 255), 2)
    
    if save_image and isinstance(image_path, str):
        os.makedirs("contours", exist_ok=True)
        cv2.imwrite(f"contours/{os.path.basename(image_path)}", image)

    # order the corners top-left, bottom-left, bottom-right, top-right,
    # matching the destination points of the perspective warp below
    box = box[np.argsort(box[:, 0])]
    if box[0, 1] > box[1, 1]:
        box[[0, 1], :] = box[[1, 0], :]
    if box[2, 1] < box[3, 1]:
        box[[2, 3], :] = box[[3, 2], :]

    # make the first side (corner 0 -> 1) the short edge so the plate maps
    # onto the landscape RECTIFIED_W x RECTIFIED_H canvas
    if np.linalg.norm(box[0] - box[1]) > np.linalg.norm(box[1] - box[2]):
        box[[1, 3], :] = box[[3, 1], :]

    
    # cut out the license plate and rectify it with a perspective warp
    dst = np.array([[0, 0], [0, RECTIFIED_H], [RECTIFIED_W, RECTIFIED_H], [RECTIFIED_W, 0]], dtype=np.float32)
    M = cv2.getPerspectiveTransform(box.astype(np.float32), dst)
    rectified_plate = cv2.warpPerspective(image, M, (RECTIFIED_W, RECTIFIED_H))
    # the corner order can be upside down for tilted plates, so a flipped crop is tried too
    rectified_plate_flip = cv2.flip(rectified_plate, 0)

    if save_image and isinstance(image_path, str):
        os.makedirs("rectified_plate", exist_ok=True)
        stem, ext = os.path.splitext(os.path.basename(image_path))
        cv2.imwrite(f"rectified_plate/{stem}{ext}", rectified_plate)
        cv2.imwrite(f"rectified_plate/{stem}_flip{ext}", rectified_plate_flip)

    # recognize characters
    result = recognize_char(Image.fromarray(rectified_plate), cut_ratio=cut_ratio, print_probs=print_probs)
    result['rectified_plate'] = rectified_plate
    result_flip = recognize_char(Image.fromarray(rectified_plate_flip), cut_ratio=cut_ratio, print_probs=print_probs)
    result_flip['rectified_plate'] = rectified_plate_flip

    # prefer the flipped reading only when it is well-formed (7 characters)
    # and more confident than the upright one
    if len(result_flip['plate']) == 7 and result_flip["confidence"] > result["confidence"]:
        result = result_flip
    result['detection'] = image
    return result


if __name__ == "__main__":
    image_dir = "images"
    image_list = glob(f"{image_dir}/*.jpg") + glob(f"{image_dir}/*.png") + glob(f"{image_dir}/*.jpeg")

    for image_path in image_list:
        result = recognize_plate(image_path, save_image=True, print_probs=True)
        print(f"Image path: {termcolor.colored(os.path.basename(image_path), 'green')} Recognized: {termcolor.colored(result, 'blue')}")