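# Spaces: Runtime error / Runtime error
# NOTE(review): the line above appears to be page-status text captured together with this file; it is not program code.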
import numpy as np | |
import cv2 | |
import torch | |
from PIL import Image | |
import termcolor | |
from glob import glob | |
# Folder prefix shared by every template path below (not referenced again in
# the visible code).
template_dir = "character_template"

# Template image path -> the Chinese character that template stands for.
char_info = {
    "character_template/e.png": "鄂",
    "character_template/gui.png": "桂",
    "character_template/hei.png": "黑",
    "character_template/ji.png": "冀",
    "character_template/gui1.png": "贵",
    "character_template/jing.png": "京",
    "character_template/lu.png": "鲁",
    "character_template/min.png": "闽",
    "character_template/su.png": "苏",
    "character_template/wan.png": "皖",
    "character_template/yu.png": "豫",
    "character_template/yue.png": "粤",
    "character_template/xin.png": "新",
    "character_template/chuan.jpg": "川",
    "character_template/ji1.jpg": "吉",
    "character_template/jin.jpg": "津",
    "character_template/liao.jpg": "辽",
    "character_template/shan.jpg": "陕",
    "character_template/zhe.jpg": "浙",
    "character_template/meng.jpg": "蒙",
}

# Labels in dict insertion order: char_list[i] is the label for
# character_image_list[i].
char_list = [*char_info.values()]

# Templates are loaded eagerly at import time, as RGB, in the same order as
# char_list; a missing template file therefore fails the import.
character_image_list = []
for template_path in char_info:
    character_image = Image.open(template_path).convert("RGB")
    character_image_list.append(character_image)

print(f"Support Chinese characters: {termcolor.colored(char_list, 'blue')}")
def calculate_correlation(image1: Image.Image, image2: Image.Image):
    """Return the Pearson correlation of two images' grayscale pixels.

    Both images are converted to grayscale, ``image2`` is resized to the
    width and height of ``image1``, and the flattened pixel arrays are
    correlated with ``np.corrcoef``.

    Args:
        image1: Reference image. Expected to be a 3-channel RGB PIL image
            (the loaders in this file always call ``.convert('RGB')``).
        image2: Image to compare; it is resized to match ``image1``.

    Returns:
        The correlation coefficient, or ``0.0`` when either image has no
        pixel variance (the coefficient is undefined in that case).
    """
    # np.array() of an RGB PIL image yields channels in R, G, B order, so the
    # RGB conversion code must be used; the BGR code would swap the red and
    # blue luma weights.
    gray1 = cv2.cvtColor(np.array(image1), cv2.COLOR_RGB2GRAY)
    gray2 = cv2.cvtColor(np.array(image2), cv2.COLOR_RGB2GRAY)
    # cv2.resize takes (width, height), i.e. (columns, rows).
    gray2 = cv2.resize(gray2, (gray1.shape[1], gray1.shape[0]))

    flat1 = gray1.ravel()
    flat2 = gray2.ravel()
    # corrcoef divides by both standard deviations; a constant image would
    # produce NaN plus a runtime warning, so report "no correlation" instead.
    if flat1.std() == 0 or flat2.std() == 0:
        return 0.0
    return np.corrcoef(flat1, flat2)[0, 1]
def calculate_sift(image1: Image.Image, image2: Image.Image):
    """Score the similarity of two images by counting SIFT matches.

    Both images are converted to grayscale and ``image2`` is resized to the
    size of ``image1``. SIFT descriptors are matched with a brute-force
    2-nearest-neighbour search, filtered with a 0.75 ratio test, and, when at
    least four matches survive, verified with a RANSAC homography.

    Args:
        image1: Query image. Expected to be a 3-channel RGB PIL image
            (the loaders in this file always call ``.convert('RGB')``).
        image2: Template image; it is resized to match ``image1``.

    Returns:
        An ``int`` score: the number of RANSAC inliers, or the number of
        ratio-test matches when fewer than four survive. ``0`` is returned
        when either image yields no descriptors or no homography is found.
    """
    # np.array() of an RGB PIL image yields channels in R, G, B order, so the
    # RGB conversion code must be used rather than the BGR one.
    gray1 = cv2.cvtColor(np.array(image1), cv2.COLOR_RGB2GRAY)
    gray2 = cv2.cvtColor(np.array(image2), cv2.COLOR_RGB2GRAY)
    # cv2.resize takes (width, height), i.e. (columns, rows).
    gray2 = cv2.resize(gray2, (gray1.shape[1], gray1.shape[0]))

    sift = cv2.SIFT_create()
    kp1, des1 = sift.detectAndCompute(gray1, None)
    kp2, des2 = sift.detectAndCompute(gray2, None)
    # A featureless image gives no descriptors (None); knnMatch would raise.
    if des1 is None or des2 is None:
        return 0

    matches = cv2.BFMatcher().knnMatch(des1, des2, k=2)
    # knnMatch can return fewer than k neighbours per query (e.g. when the
    # template has a single descriptor), so only complete pairs can be put
    # through the ratio test.
    good = [
        pair[0]
        for pair in matches
        if len(pair) == 2 and pair[0].distance < 0.75 * pair[1].distance
    ]

    # A homography needs at least four point correspondences; below that the
    # raw match count is used as the score.
    if len(good) < 4:
        return len(good)

    src_pts = np.float32([kp1[m.queryIdx].pt for m in good]).reshape(-1, 1, 2)
    dst_pts = np.float32([kp2[m.trainIdx].pt for m in good]).reshape(-1, 1, 2)
    homography, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
    # When no model can be estimated the mask comes back as None; treat that
    # as zero geometrically consistent matches instead of crashing.
    if mask is None:
        return 0
    # The mask flags each match as inlier (1) or outlier (0).
    return int(np.count_nonzero(mask.ravel() == 1))
def recognize_chinese_char(image: "Image.Image | None", image_path: "str | None" = None, print_probs=False):
    """Return the template character that best matches an image.

    The image is scored against every entry of ``character_image_list`` with
    ``calculate_sift`` and the label of the highest score is returned (the
    first one on ties, as ``argmax`` does).

    Args:
        image: Image to classify. Ignored when ``image_path`` is given; may
            be ``None`` in that case.
        image_path: Optional path of an image file to load instead of
            ``image``.
        print_probs: When true, print the per-character scores.

    Returns:
        The matching character from ``char_list``.

    Raises:
        ValueError: If both ``image`` and ``image_path`` are ``None``.
    """
    if image_path is not None:
        image = Image.open(image_path).convert('RGB')
    elif image is None:
        # Fail early with a clear message instead of an opaque OpenCV error.
        raise ValueError("either image or image_path must be provided")
    else:
        # Give directly supplied images the same RGB normalisation as loaded
        # ones, so grayscale/RGBA inputs do not break the colour conversion.
        image = image.convert('RGB')

    score_list = [
        calculate_sift(image, template) for template in character_image_list
    ]
    char_index = int(np.argmax(score_list))
    if print_probs:
        # These are raw match counts, not normalised probabilities.
        prob_dict = dict(zip(char_list, score_list))
        print(f"Label probs: {termcolor.colored(prob_dict, 'red')}")
    return char_list[char_index]
if __name__ == "__main__":
    # Gather the cropped plate images extension by extension (jpg, then png,
    # then jpeg) and print the recognised character for each file.
    img_paths = []
    for pattern in ("cut_plate/*.jpg", "cut_plate/*.png", "cut_plate/*.jpeg"):
        img_paths.extend(glob(pattern))
    for image_path in img_paths:
        print(image_path, recognize_chinese_char(None, image_path))