AutoEditor / edit_func.py
SS3M's picture
Upload 29 files
9b3edf8 verified
raw
history blame
11.8 kB
import torch
import cv2
import pytesseract
from PIL import Image, ImageDraw, ImageFont
from collections import deque
import numpy as np
import os
# pytesseract.pytesseract.tesseract_cmd = 'Tesseract\\tesseract.exe'
def get_full_img_path(src_dir):
"""
input: Đường dẫn đền folder chứa ảnh
output: Danh sách tên của tất cả các ảnh
"""
list_img_names = []
for dirname, _, filenames in os.walk(src_dir):
for filename in filenames:
path = os.path.join(dirname, filename).replace(src_dir, '')
if path[0] == '/':
path = path[1:]
list_img_names.append(path)
return list_img_names
def create_text_mask(src_img, detect_text_model, kernel_size=5, iterations=3):
"""
input: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]
output: Mask đánh dấu text trong ảnh gốc, 0 là chữ, 1 là nền; shape: [H, W]
"""
img = torch.from_numpy(src_img).to(torch.uint8).to(detect_text_model.device)
imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2)
detect_text_model.eval()
with torch.no_grad():
result = detect_text_model(imgT).squeeze()
result = (result >= 0.5).detach().cpu().numpy()
mask = ((1-result) * 255).astype(np.uint8)
kernel = np.ones((kernel_size, kernel_size), np.uint8)
mask = cv2.erode(mask, kernel, iterations=iterations)
mask = cv2.dilate(mask, kernel, iterations=2*iterations)
mask = cv2.erode(mask, kernel, iterations=iterations)
mask = (1 - mask // 255).astype(np.uint8)
return mask
def create_wordball_mask(src_img, detect_wordball_model, kernel_size=5, iterations=3):
"""
input: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]
output: Mask đánh dấu text trong ảnh gốc, 0 là chữ, 1 là nền; shape: [H, W]
"""
img = torch.from_numpy(src_img).to(torch.uint8).to(detect_wordball_model.device)
imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2)
detect_wordball_model.eval()
with torch.no_grad():
result = detect_wordball_model(imgT).squeeze()
result = (result >= 0.5).detach().cpu().numpy()
mask = ((1-result) * 255).astype(np.uint8)
kernel = np.ones((kernel_size, kernel_size), np.uint8)
mask = cv2.erode(mask, kernel, iterations=iterations)
mask = cv2.dilate(mask, kernel, iterations=2*iterations)
mask = cv2.erode(mask, kernel, iterations=iterations)
mask = (1 - mask // 255).astype(np.uint8)
return mask
def clear_text(src_img, text_msk, wordball_msk, text_value=0, non_text_value=1, r=5):
"""
input: src_img: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]
text_msk: Mask đánh dấu text trong ảnh gốc; shape: [H, W]
text_value: Giá trị mà trong mặt nạ nó là text
non_text_value: Giá trị mà trong mặt nạ nó là nền
r: Bán kính để sử dụng cho việc xoá text và vẽ lại phần bị xoá
output: Ảnh sau khi xoá text, để dưới định dạng là np.array, shape: [H, W, C]
"""
MAX = max(text_value, non_text_value)
MIN = min(text_value, non_text_value)
scale_text_value = (text_value - MIN) / (MAX - MIN)
scale_non_text_value = (non_text_value - MIN) / (MAX - MIN)
text_msk[text_msk==text_value] = scale_text_value
text_msk[text_msk==non_text_value] = scale_non_text_value
wordball_msk[wordball_msk==text_value] = scale_text_value
wordball_msk[wordball_msk==non_text_value] = scale_non_text_value
if scale_text_value == 0:
text_msk = 1 - text_msk
wordball_msk = 1 - wordball_msk
text_msk = text_msk * 255
remove_txt = cv2.inpaint(src_img, text_msk, r, cv2.INPAINT_TELEA)
remove_wordball = remove_txt.copy()
remove_wordball[wordball_msk==1] = 255
return remove_wordball
def dfs(grid, y, x, visited, value):
"""
Thuật toán tìm miền liên thông, xem thêm về đồ thị nếu không biết nó là gì
Output: Một HCN bao phủ miền liên thông + Diện tích của miền liên thông
"""
max_y, max_x = y, x
min_y, min_x = y+1, x+1
area = 0
stack = deque([(y, x)])
while stack:
y, x = stack.pop()
max_x = max(max_x, x)
max_y = max(max_y, y)
min_x = min(min_x, x)
min_y = min(min_y, y)
if (y, x) not in visited:
visited.add((y, x))
area += 1
# Kiểm tra các ô liền kề
for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]:
nx, ny = x + dx, y + dy
if 0 <= ny < grid.shape[0] and 0 <= nx < grid.shape[1] and grid[ny, nx] == value and (ny, nx) not in visited:
stack.append((ny, nx))
return (min_x, min_y, max_x, max_y), area
def find_clusters(grid, value):
"""
Thuật toán tìm danh sách các miền liên thông
"""
visited = set()
clusters = []
areas = []
for y in range(grid.shape[0]):
for x in range(grid.shape[1]):
if grid[y, x] == value and (y, x) not in visited:
cluster, area = dfs(grid, y, x, visited, value)
clusters.append(cluster)
areas.append(area)
return clusters, areas
def get_text_positions(text_msk, text_value=0):
"""
input: text_msk: Mask đánh dấu text trong ảnh gốc; shape: [H, W]
text_value: Giá trị mà trong mặt nạ nó là text
min_area: Giả trị tối thiểu của vùng có thể có text
output: Danh sách các cùng chứa text, định dạng (min_x, min_y, max_x, max_y)
"""
clusters, areas = find_clusters(text_msk, value=text_value)
return clusters, areas
def filter_text_positions(clusters, areas, min_area=1200, max_area=10000):
clusters = clusters[(areas >= min_area) & (areas <= max_area)]
return clusters
def get_list_texts(src_img, text_positions, lang='eng'):
"""
input: src_img: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]
text_positions: Danh sách các cùng chứa text, định dạng (min_x, min_y, max_x, max_y)
lang: Ngôn ngữ của text
output: Danh sách các câu text
"""
list_texts = []
for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
crop_img = src_img[min_y:max_y+1, min_x:max_x+1]
img_rgb = cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB)
img = Image.fromarray(img_rgb)
text = pytesseract.image_to_string(img, lang=lang).replace('\n', ' ').strip()
while ' ' in text:
text = text.replace(' ', ' ')
list_texts.append(text)
return list_texts
def translate(list_texts, translator):
translated_texts = []
for text in list_texts:
if not text:
text = 'a'
translated_text = translator.translate(text, src='en', dest='vi').text
translated_texts.append(translated_text)
return translated_texts
def add_centered_multiline_text(image, text, box, font_path="arial.ttf", font_size=36, pad=5, text_color=0):
# Mở ảnh
draw = ImageDraw.Draw(image)
# Giải nén box (min_x, min_y, max_x, max_y)
min_x, min_y, max_x, max_y = box
# Tạo font
font = ImageFont.truetype(font_path, font_size)
# Chia văn bản thành nhiều dòng nếu cần
wrapped_lines = wrap_text(text, font, draw, max_x - min_x)
# Tính chiều cao của tất cả các dòng cộng lại
total_text_height = sum(get_text_height(line, draw, font) for line in wrapped_lines)
# Tính toạ độ y bắt đầu để căn giữa theo chiều dọc
start_y = min_y + (max_y - min_y - total_text_height) // 2
# Vẽ từng dòng và căn giữa theo chiều ngang
current_y = start_y
for line in wrapped_lines:
text_width, text_height = get_text_dimensions(line, draw, font)
text_x = min_x + (max_x - min_x - text_width) // 2 # Căn giữa theo chiều ngang
draw.text((text_x, current_y), line, fill=text_color, font=font)
current_y += text_height + pad # Di chuyển y xuống để vẽ dòng tiếp theo
# Lưu ảnh mới
return image
def get_text_dimensions(text, draw, font):
"""Trả về (width, height) của văn bản."""
bbox = draw.textbbox((0, 0), text, font=font)
width = bbox[2] - bbox[0]
height = bbox[3] - bbox[1]
return width, height
def get_text_height(text, draw, font):
"""Trả về chiều cao của văn bản."""
_, _, _, height = draw.textbbox((0, 0), text, font=font)
return height
def wrap_text(text, font, draw, max_width):
"""Chia văn bản thành nhiều dòng dựa trên chiều rộng tối đa."""
words = text.split()
lines = []
current_line = ""
for word in words:
# Thử thêm từ vào dòng hiện tại
test_line = f"{current_line} {word}".strip()
test_width, _ = get_text_dimensions(test_line, draw, font)
if test_width <= max_width:
current_line = test_line
else:
# Nếu quá rộng, lưu dòng hiện tại và bắt đầu dòng mới
lines.append(current_line)
current_line = word
# Thêm dòng cuối cùng
if current_line:
lines.append(current_line)
return lines
def insert_text(non_text_src_img, list_translated_texts, text_positions, font=['MTO Astro City.ttf'], font_size=[20], pad=[5], text_color=0, stroke=[3]):
# Copy ảnh không chữ
img_bgr = non_text_src_img.copy()
# Thêm text vào măt nạ 1
for idx, text in enumerate(list_translated_texts):
# Tạo mặt nạ trắng
mask1 = Image.new("L", img_bgr.shape[:2][::-1], 255)
mask2 = Image.new("L", img_bgr.shape[:2][::-1], 255)
mask1 = add_centered_multiline_text(mask1, text, text_positions[idx], f'MTO Font/{font[idx]}', font_size[idx], pad=pad[idx], text_color=text_color)
# Chuyển ảnh từ PIL sang cv2
mask1 = (np.array(mask1) >= 127).astype(np.uint8) * 255
mask1 = cv2.cvtColor(mask1, cv2.COLOR_RGB2BGR)
if stroke[idx] > 0:
mask2 = np.array(mask2).astype(np.uint8)
mask2 = cv2.cvtColor(mask2, cv2.COLOR_RGB2BGR)
mask2 = mask2 - mask1
kernel = np.ones((stroke[idx]+1, stroke[idx]+1), np.uint8)
mask2 = cv2.dilate(mask2, kernel, iterations=1)
img_bgr[mask2==255] = 255
img_bgr[mask1==text_color] = text_color
return img_bgr
def save_img(path, translated_text_src_img):
"""
input: path: Đường dẫn đến ảnh gốc ban đầu
translated_text_src_img: Ảnh sau khi được dịch
output: Ảnh sau dịch được lưu lại, trong tên có thêm "translated-"
"""
dot = path.rfind('.')
last_slash = -1
if '/' in path:
last_slash = path.rfind('/')
ext = path[dot:]
parent_path = path[:last_slash+1]
name = path[last_slash+1:dot]
if parent_path and not os.path.exists(parent_path):
os.mkdir(parent_path)
cv2.imwrite(f'{parent_path}translated-{name}{ext}', translated_text_src_img)
print(f'Image saved at {parent_path}translated-{name}{ext}')