File size: 2,064 Bytes
d94f42d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import warnings
warnings.filterwarnings("ignore")

import torch
from PIL import Image
from torchvision import transforms as T
from glob import glob
import os
import re
import termcolor
from utils.iqa_recognize import recognize_chinese_char

# Run on the GPU when CUDA is available; otherwise use the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Pretrained PARSeq model fetched via torch.hub (repo baudm/parseq), put in
# eval mode. NOTE(review): `device` is only forwarded to the hub entry point;
# whether that actually places the weights on it is up to the entry point —
# TODO confirm.
parseq = torch.hub.load(
    "baudm/parseq",
    "parseq",
    pretrained=True,
    device=device,
    verbose=False,
).eval()

# Report the selected device and model name, highlighted in green.
print(
    "Using device: " + termcolor.colored(device, "green")
    + ", model: " + termcolor.colored("parseq", "green")
)

# PARSeq preprocessing: bicubic resize to `parseq.hparams.img_size`, convert
# to a tensor, then normalize every channel with mean=0.5 and std=0.5.
img_transform = T.Compose(
    [
        T.Resize(parseq.hparams.img_size, interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=0.5, std=0.5),
    ]
)

# Single characters removed from the PARSeq label in recognize_char before it
# is appended to the Chinese-character prediction.
filtered_chars = list("-_,.:!() ")

def recognize_char(img: Image.Image, image_path: str=None, cut_ratio=0.15, save_image=False, print_probs=False):
    """Recognize a plate string by splitting the image into two regions.

    The leftmost ``cut_ratio`` fraction of the image width is passed to
    ``recognize_chinese_char``. The remaining right part is read by the PARSeq
    model. The two predictions are concatenated.

    Args:
        img: Input image. Ignored (replaced) when ``image_path`` is given.
        image_path: Optional path to load the image from; it is converted to RGB.
        cut_ratio: Fraction of the width, measured from the left edge, that is
            treated as the Chinese-character region. Must lie in (0, 1).
        save_image: If True and ``image_path`` is given, save the left crop
            to ``cut_plate/<basename of image_path>``.
        print_probs: Forwarded unchanged to ``recognize_chinese_char``.

    Returns:
        dict with:
            "plate": left-region prediction + PARSeq label with every
                character in ``filtered_chars`` removed.
            "confidence": mean of the per-token confidences returned by
                ``parseq.tokenizer.decode`` for the (single) batch item.
            "chinese": the cropped left-region PIL image.

    Raises:
        ValueError: if neither ``img`` nor ``image_path`` is provided, or if
            ``cut_ratio`` is not strictly between 0 and 1 (either crop would
            otherwise be empty or inverted).
    """
    if image_path is not None:
        # The context manager closes the file handle; convert() returns a
        # loaded copy that remains valid after the file is closed.
        with Image.open(image_path) as opened:
            img = opened.convert('RGB')
    if img is None:
        raise ValueError("recognize_char needs either `img` or `image_path`")
    if not 0 < cut_ratio < 1:
        raise ValueError(f"cut_ratio must be in (0, 1), got {cut_ratio!r}")

    width, height = img.size
    split_x = width * cut_ratio

    # Left strip -> dedicated Chinese-character recognizer.
    left_part = img.crop((0, 0, split_x, height))
    if image_path is not None and save_image:
        os.makedirs("cut_plate", exist_ok=True)
        left_part.save(f"cut_plate/{os.path.basename(image_path)}")
    left_char = recognize_chinese_char(left_part, print_probs=print_probs)

    # Right remainder -> PARSeq.
    right_part = img.crop((split_x, 0, width, height))

    # Send the input to whichever device the model weights actually live on,
    # so a CUDA-resident model does not receive a CPU tensor.
    model_device = next(parseq.parameters()).device
    batch = img_transform(right_part).unsqueeze(0).to(model_device)

    # Inference only: no_grad stops autograd from recording a graph per call.
    with torch.no_grad():
        logits = parseq(batch)
        pred = logits.softmax(-1)
        label, confidence = parseq.tokenizer.decode(pred)

    # Delete each filtered character literally. Unlike a regex character
    # class, this is safe for any ordering and for an empty filter list.
    text = label[0].translate(str.maketrans("", "", "".join(filtered_chars)))

    return {
        "plate": left_char + text,
        "confidence": float(confidence[0].mean()),
        "chinese": left_part,
    }
    

if __name__ == "__main__":
    # Gather rectified plate crops. glob returns entries in arbitrary
    # filesystem order, so sort them to make runs reproducible.
    img_paths = sorted(
        path
        for ext in ("jpg", "png", "jpeg")
        for path in glob(f"rectified_plate/*.{ext}")
    )
    for img_path in img_paths:
        result = recognize_char(None, image_path=img_path, save_image=True)
        print(f"Recognized: {termcolor.colored(result, 'blue')}")