File size: 11,842 Bytes
e2cc14b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import torch
import cv2
import pytesseract
from PIL import Image, ImageDraw, ImageFont
from collections import deque
import numpy as np
import os

# pytesseract.pytesseract.tesseract_cmd = 'Tesseract\\tesseract.exe'

def get_full_img_path(src_dir):
    """

    input:  Đường dẫn đền folder chứa ảnh

    output: Danh sách tên của tất cả các ảnh

    """
    list_img_names = []
    for dirname, _, filenames in os.walk(src_dir):
        for filename in filenames:
            path = os.path.join(dirname, filename).replace(src_dir, '')
            if path[0] == '/':
                path = path[1:]
            list_img_names.append(path)
    return list_img_names


def create_text_mask(src_img, detect_text_model, kernel_size=5, iterations=3):
    """

    input:  Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]

    output: Mask đánh dấu text trong ảnh gốc, 0 là chữ, 1 là nền; shape: [H, W]

    """
    img = torch.from_numpy(src_img).to(torch.uint8).to(detect_text_model.device)
    imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2)
    
    detect_text_model.eval()
    with torch.no_grad():
        result = detect_text_model(imgT).squeeze()
    result = (result >= 0.5).detach().cpu().numpy()
    
    mask = ((1-result) * 255).astype(np.uint8)

    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=iterations)
    mask = cv2.dilate(mask, kernel, iterations=2*iterations)
    mask = cv2.erode(mask, kernel, iterations=iterations)

    mask = (1 - mask // 255).astype(np.uint8)
    return mask


def create_wordball_mask(src_img, detect_wordball_model, kernel_size=5, iterations=3):
    """

    input:  Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]

    output: Mask đánh dấu text trong ảnh gốc, 0 là chữ, 1 là nền; shape: [H, W]

    """
    img = torch.from_numpy(src_img).to(torch.uint8).to(detect_wordball_model.device)
    imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2)
    
    detect_wordball_model.eval()
    with torch.no_grad():
        result = detect_wordball_model(imgT).squeeze()
    result = (result >= 0.5).detach().cpu().numpy()
    
    mask = ((1-result) * 255).astype(np.uint8)

    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=iterations)
    mask = cv2.dilate(mask, kernel, iterations=2*iterations)
    mask = cv2.erode(mask, kernel, iterations=iterations)

    mask = (1 - mask // 255).astype(np.uint8)
    return mask


def clear_text(src_img, text_msk, wordball_msk, text_value=0, non_text_value=1, r=5):
    """

    input:  src_img: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]

            text_msk: Mask đánh dấu text trong ảnh gốc; shape: [H, W]

            text_value: Giá trị mà trong mặt nạ nó là text

            non_text_value: Giá trị mà trong mặt nạ nó là nền

            r: Bán kính để sử dụng cho việc xoá text và vẽ lại phần bị xoá

    output: Ảnh sau khi xoá text, để dưới định dạng là np.array, shape: [H, W, C]

    """
    MAX = max(text_value, non_text_value)
    MIN = min(text_value, non_text_value)
    
    scale_text_value = (text_value - MIN) / (MAX - MIN)
    scale_non_text_value = (non_text_value - MIN) / (MAX - MIN)
    
    text_msk[text_msk==text_value] = scale_text_value
    text_msk[text_msk==non_text_value] = scale_non_text_value
    
    wordball_msk[wordball_msk==text_value] = scale_text_value
    wordball_msk[wordball_msk==non_text_value] = scale_non_text_value
    
    if scale_text_value == 0:
        text_msk = 1 - text_msk
        wordball_msk = 1 - wordball_msk
    text_msk = text_msk * 255
    
    remove_txt = cv2.inpaint(src_img, text_msk, r, cv2.INPAINT_TELEA)
    remove_wordball = remove_txt.copy()
    remove_wordball[wordball_msk==1] = 255
    
    return remove_wordball


def dfs(grid, y, x, visited, value):
    """

    Thuật toán tìm miền liên thông, xem thêm về đồ thị nếu không biết nó là gì

    Output: Một HCN bao phủ miền liên thông + Diện tích của miền liên thông

    """
    max_y, max_x = y, x
    min_y, min_x = y+1, x+1
    area = 0
    
    stack = deque([(y, x)])
    while stack:
        y, x = stack.pop()

        max_x = max(max_x, x)
        max_y = max(max_y, y)
        min_x = min(min_x, x)
        min_y = min(min_y, y)

        if (y, x) not in visited:
            visited.add((y, x))
            area += 1
            # Kiểm tra các ô liền kề
            for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]:
                nx, ny = x + dx, y + dy
                if 0 <= ny < grid.shape[0] and 0 <= nx < grid.shape[1] and grid[ny, nx] == value and (ny, nx) not in visited:
                    stack.append((ny, nx))

    return (min_x, min_y, max_x, max_y), area


def find_clusters(grid, value):
    """

    Thuật toán tìm danh sách các miền liên thông

    """
    visited = set()
    clusters = []
    areas = []
    
    for y in range(grid.shape[0]):
        for x in range(grid.shape[1]):
            if grid[y, x] == value and (y, x) not in visited:
                cluster, area = dfs(grid, y, x, visited, value)
                clusters.append(cluster)
                areas.append(area)
    
    return clusters, areas

def get_text_positions(text_msk, text_value=0):
    """

    input:  text_msk: Mask đánh dấu text trong ảnh gốc; shape: [H, W]

            text_value: Giá trị mà trong mặt nạ nó là text

            min_area: Giả trị tối thiểu của vùng có thể có text

    output: Danh sách các cùng chứa text, định dạng (min_x, min_y, max_x, max_y)

    """
    
    clusters, areas = find_clusters(text_msk, value=text_value)
    return clusters, areas

def filter_text_positions(clusters, areas, min_area=1200, max_area=10000):
    clusters = clusters[(areas >= min_area) & (areas <= max_area)]
    return clusters


def get_list_texts(src_img, text_positions, lang='eng'):
    """

    input:  src_img: Ảnh gốc, để dưới định dạng là np.array, shape: [H, W, C]

            text_positions: Danh sách các cùng chứa text, định dạng (min_x, min_y, max_x, max_y)

            lang: Ngôn ngữ của text

    output: Danh sách các câu text

    """
    list_texts = []
    for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
        crop_img = src_img[min_y:max_y+1, min_x:max_x+1]
        img_rgb = cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img_rgb)
        text = pytesseract.image_to_string(img, lang=lang).replace('\n', ' ').strip()
        while '  ' in text:
            text = text.replace('  ', ' ')
        list_texts.append(text)
    return list_texts


def translate(list_texts, translator):
    translated_texts = []
    for text in list_texts:
        if not text:
            text = 'a'
        translated_text = translator.translate(text, src='en', dest='vi').text
        translated_texts.append(translated_text)
    return translated_texts


def add_centered_multiline_text(image, text, box, font_path="arial.ttf", font_size=36, pad=5, text_color=0):
    # Mở ảnh
    draw = ImageDraw.Draw(image)

    # Giải nén box (min_x, min_y, max_x, max_y)
    min_x, min_y, max_x, max_y = box

    # Tạo font
    font = ImageFont.truetype(font_path, font_size)

    # Chia văn bản thành nhiều dòng nếu cần
    wrapped_lines = wrap_text(text, font, draw, max_x - min_x)

    # Tính chiều cao của tất cả các dòng cộng lại
    total_text_height = sum(get_text_height(line, draw, font) for line in wrapped_lines)
    
    # Tính toạ độ y bắt đầu để căn giữa theo chiều dọc
    start_y = min_y + (max_y - min_y - total_text_height) // 2

    # Vẽ từng dòng và căn giữa theo chiều ngang
    current_y = start_y
    for line in wrapped_lines:
        text_width, text_height = get_text_dimensions(line, draw, font)
        text_x = min_x + (max_x - min_x - text_width) // 2  # Căn giữa theo chiều ngang
        draw.text((text_x, current_y), line, fill=text_color, font=font)
        current_y += text_height + pad # Di chuyển y xuống để vẽ dòng tiếp theo

    # Lưu ảnh mới
    return image

def get_text_dimensions(text, draw, font):
    """Trả về (width, height) của văn bản."""
    bbox = draw.textbbox((0, 0), text, font=font)
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    return width, height

def get_text_height(text, draw, font):
    """Trả về chiều cao của văn bản."""
    _, _, _, height = draw.textbbox((0, 0), text, font=font)
    return height

def wrap_text(text, font, draw, max_width):
    """Chia văn bản thành nhiều dòng dựa trên chiều rộng tối đa."""
    words = text.split()
    lines = []
    current_line = ""

    for word in words:
        # Thử thêm từ vào dòng hiện tại
        test_line = f"{current_line} {word}".strip()
        test_width, _ = get_text_dimensions(test_line, draw, font)

        if test_width <= max_width:
            current_line = test_line
        else:
            # Nếu quá rộng, lưu dòng hiện tại và bắt đầu dòng mới
            lines.append(current_line)
            current_line = word

    # Thêm dòng cuối cùng
    if current_line:
        lines.append(current_line)

    return lines

def insert_text(non_text_src_img, list_translated_texts, text_positions, font=['MTO Astro City.ttf'], font_size=[20], pad=[5], text_color=0, stroke=[3]):
    # Copy ảnh không chữ
    img_bgr = non_text_src_img.copy()
    
    # Thêm text vào măt nạ 1
    for idx, text in enumerate(list_translated_texts):
        # Tạo mặt nạ trắng
        mask1 = Image.new("L", img_bgr.shape[:2][::-1], 255)
        mask2 = Image.new("L", img_bgr.shape[:2][::-1], 255)
        mask1 = add_centered_multiline_text(mask1, text, text_positions[idx], f'MTO Font/{font[idx]}', font_size[idx], pad=pad[idx], text_color=text_color)
    
        # Chuyển ảnh từ PIL sang cv2
        mask1 = (np.array(mask1) >= 127).astype(np.uint8) * 255
        mask1 = cv2.cvtColor(mask1, cv2.COLOR_RGB2BGR)
        
        if stroke[idx] > 0:
            mask2 = np.array(mask2).astype(np.uint8)
            mask2 = cv2.cvtColor(mask2, cv2.COLOR_RGB2BGR)
            
            mask2 = mask2 - mask1
            kernel = np.ones((stroke[idx]+1, stroke[idx]+1), np.uint8)
            mask2 = cv2.dilate(mask2, kernel, iterations=1)
            img_bgr[mask2==255] = 255
        
        img_bgr[mask1==text_color] = text_color
    return img_bgr


def save_img(path, translated_text_src_img):
    """

    input:  path: Đường dẫn đến ảnh gốc ban đầu

            translated_text_src_img: Ảnh sau khi được dịch

    output: Ảnh sau dịch được lưu lại, trong tên có thêm "translated-"

    """
    dot = path.rfind('.')
    last_slash = -1
    if '/' in path:
        last_slash = path.rfind('/')
        
    ext = path[dot:]
    parent_path = path[:last_slash+1]
    name = path[last_slash+1:dot]
    
    if parent_path and not os.path.exists(parent_path):
        os.mkdir(parent_path)    
    cv2.imwrite(f'{parent_path}translated-{name}{ext}', translated_text_src_img)
    print(f'Image saved at {parent_path}translated-{name}{ext}')