File size: 11,842 Bytes
e2cc14b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 |
import torch
import cv2
import pytesseract
from PIL import Image, ImageDraw, ImageFont
from collections import deque
import numpy as np
import os
# pytesseract.pytesseract.tesseract_cmd = 'Tesseract\\tesseract.exe'
def get_full_img_path(src_dir):
"""
input: Đường dẫn đền folder chứa ảnh
output: Danh sách tên của tất cả các ảnh
"""
list_img_names = []
for dirname, _, filenames in os.walk(src_dir):
for filename in filenames:
path = os.path.join(dirname, filename).replace(src_dir, '')
if path[0] == '/':
path = path[1:]
list_img_names.append(path)
return list_img_names
def create_text_mask(src_img, detect_text_model, kernel_size=5, iterations=3):
"""
input: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]
output: Mask đánh dấu text trong ảnh gốc, 0 là chữ, 1 là nền; shape: [H, W]
"""
img = torch.from_numpy(src_img).to(torch.uint8).to(detect_text_model.device)
imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2)
detect_text_model.eval()
with torch.no_grad():
result = detect_text_model(imgT).squeeze()
result = (result >= 0.5).detach().cpu().numpy()
mask = ((1-result) * 255).astype(np.uint8)
kernel = np.ones((kernel_size, kernel_size), np.uint8)
mask = cv2.erode(mask, kernel, iterations=iterations)
mask = cv2.dilate(mask, kernel, iterations=2*iterations)
mask = cv2.erode(mask, kernel, iterations=iterations)
mask = (1 - mask // 255).astype(np.uint8)
return mask
def create_wordball_mask(src_img, detect_wordball_model, kernel_size=5, iterations=3):
"""
input: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]
output: Mask đánh dấu text trong ảnh gốc, 0 là chữ, 1 là nền; shape: [H, W]
"""
img = torch.from_numpy(src_img).to(torch.uint8).to(detect_wordball_model.device)
imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2)
detect_wordball_model.eval()
with torch.no_grad():
result = detect_wordball_model(imgT).squeeze()
result = (result >= 0.5).detach().cpu().numpy()
mask = ((1-result) * 255).astype(np.uint8)
kernel = np.ones((kernel_size, kernel_size), np.uint8)
mask = cv2.erode(mask, kernel, iterations=iterations)
mask = cv2.dilate(mask, kernel, iterations=2*iterations)
mask = cv2.erode(mask, kernel, iterations=iterations)
mask = (1 - mask // 255).astype(np.uint8)
return mask
def clear_text(src_img, text_msk, wordball_msk, text_value=0, non_text_value=1, r=5):
"""
input: src_img: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]
text_msk: Mask đánh dấu text trong ảnh gốc; shape: [H, W]
text_value: Giá trị mà trong mặt nạ nó là text
non_text_value: Giá trị mà trong mặt nạ nó là nền
r: Bán kính để sử dụng cho việc xoá text và vẽ lại phần bị xoá
output: Ảnh sau khi xoá text, để dưới định dạng là np.array, shape: [H, W, C]
"""
MAX = max(text_value, non_text_value)
MIN = min(text_value, non_text_value)
scale_text_value = (text_value - MIN) / (MAX - MIN)
scale_non_text_value = (non_text_value - MIN) / (MAX - MIN)
text_msk[text_msk==text_value] = scale_text_value
text_msk[text_msk==non_text_value] = scale_non_text_value
wordball_msk[wordball_msk==text_value] = scale_text_value
wordball_msk[wordball_msk==non_text_value] = scale_non_text_value
if scale_text_value == 0:
text_msk = 1 - text_msk
wordball_msk = 1 - wordball_msk
text_msk = text_msk * 255
remove_txt = cv2.inpaint(src_img, text_msk, r, cv2.INPAINT_TELEA)
remove_wordball = remove_txt.copy()
remove_wordball[wordball_msk==1] = 255
return remove_wordball
def dfs(grid, y, x, visited, value):
"""
Thuật toán tìm miền liên thông, xem thêm về đồ thị nếu không biết nó là gì
Output: Một HCN bao phủ miền liên thông + Diện tích của miền liên thông
"""
max_y, max_x = y, x
min_y, min_x = y+1, x+1
area = 0
stack = deque([(y, x)])
while stack:
y, x = stack.pop()
max_x = max(max_x, x)
max_y = max(max_y, y)
min_x = min(min_x, x)
min_y = min(min_y, y)
if (y, x) not in visited:
visited.add((y, x))
area += 1
# Kiểm tra các ô liền kề
for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]:
nx, ny = x + dx, y + dy
if 0 <= ny < grid.shape[0] and 0 <= nx < grid.shape[1] and grid[ny, nx] == value and (ny, nx) not in visited:
stack.append((ny, nx))
return (min_x, min_y, max_x, max_y), area
def find_clusters(grid, value):
"""
Thuật toán tìm danh sách các miền liên thông
"""
visited = set()
clusters = []
areas = []
for y in range(grid.shape[0]):
for x in range(grid.shape[1]):
if grid[y, x] == value and (y, x) not in visited:
cluster, area = dfs(grid, y, x, visited, value)
clusters.append(cluster)
areas.append(area)
return clusters, areas
def get_text_positions(text_msk, text_value=0):
"""
input: text_msk: Mask đánh dấu text trong ảnh gốc; shape: [H, W]
text_value: Giá trị mà trong mặt nạ nó là text
min_area: Giả trị tối thiểu của vùng có thể có text
output: Danh sách các cùng chứa text, định dạng (min_x, min_y, max_x, max_y)
"""
clusters, areas = find_clusters(text_msk, value=text_value)
return clusters, areas
def filter_text_positions(clusters, areas, min_area=1200, max_area=10000):
clusters = clusters[(areas >= min_area) & (areas <= max_area)]
return clusters
def get_list_texts(src_img, text_positions, lang='eng'):
"""
input: src_img: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]
text_positions: Danh sách các cùng chứa text, định dạng (min_x, min_y, max_x, max_y)
lang: Ngôn ngữ của text
output: Danh sách các câu text
"""
list_texts = []
for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
crop_img = src_img[min_y:max_y+1, min_x:max_x+1]
img_rgb = cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB)
img = Image.fromarray(img_rgb)
text = pytesseract.image_to_string(img, lang=lang).replace('\n', ' ').strip()
while ' ' in text:
text = text.replace(' ', ' ')
list_texts.append(text)
return list_texts
def translate(list_texts, translator):
translated_texts = []
for text in list_texts:
if not text:
text = 'a'
translated_text = translator.translate(text, src='en', dest='vi').text
translated_texts.append(translated_text)
return translated_texts
def add_centered_multiline_text(image, text, box, font_path="arial.ttf", font_size=36, pad=5, text_color=0):
# Mở ảnh
draw = ImageDraw.Draw(image)
# Giải nén box (min_x, min_y, max_x, max_y)
min_x, min_y, max_x, max_y = box
# Tạo font
font = ImageFont.truetype(font_path, font_size)
# Chia văn bản thành nhiều dòng nếu cần
wrapped_lines = wrap_text(text, font, draw, max_x - min_x)
# Tính chiều cao của tất cả các dòng cộng lại
total_text_height = sum(get_text_height(line, draw, font) for line in wrapped_lines)
# Tính toạ độ y bắt đầu để căn giữa theo chiều dọc
start_y = min_y + (max_y - min_y - total_text_height) // 2
# Vẽ từng dòng và căn giữa theo chiều ngang
current_y = start_y
for line in wrapped_lines:
text_width, text_height = get_text_dimensions(line, draw, font)
text_x = min_x + (max_x - min_x - text_width) // 2 # Căn giữa theo chiều ngang
draw.text((text_x, current_y), line, fill=text_color, font=font)
current_y += text_height + pad # Di chuyển y xuống để vẽ dòng tiếp theo
# Lưu ảnh mới
return image
def get_text_dimensions(text, draw, font):
"""Trả về (width, height) của văn bản."""
bbox = draw.textbbox((0, 0), text, font=font)
width = bbox[2] - bbox[0]
height = bbox[3] - bbox[1]
return width, height
def get_text_height(text, draw, font):
"""Trả về chiều cao của văn bản."""
_, _, _, height = draw.textbbox((0, 0), text, font=font)
return height
def wrap_text(text, font, draw, max_width):
"""Chia văn bản thành nhiều dòng dựa trên chiều rộng tối đa."""
words = text.split()
lines = []
current_line = ""
for word in words:
# Thử thêm từ vào dòng hiện tại
test_line = f"{current_line} {word}".strip()
test_width, _ = get_text_dimensions(test_line, draw, font)
if test_width <= max_width:
current_line = test_line
else:
# Nếu quá rộng, lưu dòng hiện tại và bắt đầu dòng mới
lines.append(current_line)
current_line = word
# Thêm dòng cuối cùng
if current_line:
lines.append(current_line)
return lines
def insert_text(non_text_src_img, list_translated_texts, text_positions, font=['MTO Astro City.ttf'], font_size=[20], pad=[5], text_color=0, stroke=[3]):
# Copy ảnh không chữ
img_bgr = non_text_src_img.copy()
# Thêm text vào măt nạ 1
for idx, text in enumerate(list_translated_texts):
# Tạo mặt nạ trắng
mask1 = Image.new("L", img_bgr.shape[:2][::-1], 255)
mask2 = Image.new("L", img_bgr.shape[:2][::-1], 255)
mask1 = add_centered_multiline_text(mask1, text, text_positions[idx], f'MTO Font/{font[idx]}', font_size[idx], pad=pad[idx], text_color=text_color)
# Chuyển ảnh từ PIL sang cv2
mask1 = (np.array(mask1) >= 127).astype(np.uint8) * 255
mask1 = cv2.cvtColor(mask1, cv2.COLOR_RGB2BGR)
if stroke[idx] > 0:
mask2 = np.array(mask2).astype(np.uint8)
mask2 = cv2.cvtColor(mask2, cv2.COLOR_RGB2BGR)
mask2 = mask2 - mask1
kernel = np.ones((stroke[idx]+1, stroke[idx]+1), np.uint8)
mask2 = cv2.dilate(mask2, kernel, iterations=1)
img_bgr[mask2==255] = 255
img_bgr[mask1==text_color] = text_color
return img_bgr
def save_img(path, translated_text_src_img):
"""
input: path: Đường dẫn đến ảnh gốc ban đầu
translated_text_src_img: Ảnh sau khi được dịch
output: Ảnh sau dịch được lưu lại, trong tên có thêm "translated-"
"""
dot = path.rfind('.')
last_slash = -1
if '/' in path:
last_slash = path.rfind('/')
ext = path[dot:]
parent_path = path[:last_slash+1]
name = path[last_slash+1:dot]
if parent_path and not os.path.exists(parent_path):
os.mkdir(parent_path)
cv2.imwrite(f'{parent_path}translated-{name}{ext}', translated_text_src_img)
print(f'Image saved at {parent_path}translated-{name}{ext}') |