Upload 7 files
- .gitattributes +35 -36
- README.md +13 -13
- TransUnet.py +320 -0
- TransUnet_Config.py +7 -0
- app.py +428 -0
- edit_func.py +311 -0
- requirements.txt +13 -0
.gitattributes
CHANGED
@@ -1,36 +1,35 @@
 *.7z filter=lfs diff=lfs merge=lfs -text
 *.arrow filter=lfs diff=lfs merge=lfs -text
 *.bin filter=lfs diff=lfs merge=lfs -text
 *.bz2 filter=lfs diff=lfs merge=lfs -text
 *.ckpt filter=lfs diff=lfs merge=lfs -text
 *.ftz filter=lfs diff=lfs merge=lfs -text
 *.gz filter=lfs diff=lfs merge=lfs -text
 *.h5 filter=lfs diff=lfs merge=lfs -text
 *.joblib filter=lfs diff=lfs merge=lfs -text
 *.lfs.* filter=lfs diff=lfs merge=lfs -text
 *.mlmodel filter=lfs diff=lfs merge=lfs -text
 *.model filter=lfs diff=lfs merge=lfs -text
 *.msgpack filter=lfs diff=lfs merge=lfs -text
 *.npy filter=lfs diff=lfs merge=lfs -text
 *.npz filter=lfs diff=lfs merge=lfs -text
 *.onnx filter=lfs diff=lfs merge=lfs -text
 *.ot filter=lfs diff=lfs merge=lfs -text
 *.parquet filter=lfs diff=lfs merge=lfs -text
 *.pb filter=lfs diff=lfs merge=lfs -text
 *.pickle filter=lfs diff=lfs merge=lfs -text
 *.pkl filter=lfs diff=lfs merge=lfs -text
 *.pt filter=lfs diff=lfs merge=lfs -text
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
 saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
 *.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
-MTO[[:space:]]Font/MTO[[:space:]]Getting[[:space:]]Angry.ttf filter=lfs diff=lfs merge=lfs -text
README.md
CHANGED
@@ -1,13 +1,13 @@
----
-title: TheEditor
-emoji: 📉
-colorFrom: pink
-colorTo: yellow
-sdk: gradio
-sdk_version: 5.3.0
-app_file: app.py
-pinned: false
-license: apache-2.0
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+---
+title: TheEditor
+emoji: 📉
+colorFrom: pink
+colorTo: yellow
+sdk: gradio
+sdk_version: 5.3.0
+app_file: app.py
+pinned: false
+license: apache-2.0
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
TransUnet.py
ADDED
@@ -0,0 +1,320 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from positional_encodings.torch_encodings import PositionalEncoding2D


class LayerNorm2D(nn.Module):
    """LayerNorm over the channel dimension of [N, C, H, W] tensors."""
    def __init__(self, embed_dim):
        super().__init__()
        self.layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)   # [N, H, W, C]
        x = self.layer_norm(x)
        x = x.permute(0, 3, 1, 2)   # [N, C, H, W]
        return x

class Image_Adaptor(nn.Module):
    def __init__(self, in_channels, adp_channels, dropout=0.1):
        super().__init__()

        self.adaptor = nn.Sequential(
            nn.Conv2d(in_channels, adp_channels // 4, kernel_size=4, padding='same'),
            LayerNorm2D(adp_channels // 4),
            nn.GELU(),
            nn.Conv2d(adp_channels // 4, adp_channels // 4, kernel_size=2, padding='same'),
            LayerNorm2D(adp_channels // 4),
            nn.GELU(),
            nn.Conv2d(adp_channels // 4, adp_channels, kernel_size=2, padding='same')
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, images):
        """
        input:  [N, in_channels, H, W]
        output: [N, adp_channels, H, W]
        """
        adapt_imgs = self.adaptor(images)
        return self.dropout(adapt_imgs)

class Positional_Encoding(nn.Module):
    def __init__(self, adp_channels):
        super().__init__()
        self.pe = PositionalEncoding2D(adp_channels)

    def forward(self, adapt_imgs):
        """
        input:  [N, adp_channels, H, W]
        output: [N, adp_channels, H, W]
        """
        x = adapt_imgs.permute(0, -2, -1, -3)   # [N, H, W, C]: PositionalEncoding2D expects channels last
        encode = self.pe(x)
        encode = encode.permute(0, -1, -3, -2)  # back to [N, C, H, W]
        return encode

class GeGLU(nn.Module):
    """Gated GELU linear unit: GELU(W0 x) * (W1 x)."""
    def __init__(self, emb_channels, ffn_size):
        super().__init__()
        self.wi_0 = nn.Linear(emb_channels, ffn_size, bias=False)
        self.wi_1 = nn.Linear(emb_channels, ffn_size, bias=False)
        self.act = nn.GELU()

    def forward(self, x):
        x_gelu = self.act(self.wi_0(x))
        x_linear = self.wi_1(x)
        x = x_gelu * x_linear
        return x

class Feed_Forward(nn.Module):
    def __init__(self, in_channels, ffw_channels, dropout=0.1):
        super().__init__()

        self.ln1 = GeGLU(in_channels, ffw_channels)
        self.dropout = nn.Dropout(dropout)
        self.ln2 = GeGLU(ffw_channels, in_channels)

    def forward(self, x):
        """
        input:  [N, H, W, channels]
        output: [N, H, W, channels]
        """
        x = self.ln1(x)
        x = self.dropout(x)
        x = self.ln2(x)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, channels, num_attn_heads, dropout=0.1):
        super().__init__()

        self.head_size = num_attn_heads
        self.channels = channels
        self.attn_size = channels // num_attn_heads
        self.scale = self.attn_size ** -0.5
        assert num_attn_heads * self.attn_size == channels, "Input channels of attention must be divisible by the number of attention heads!"

        self.lq = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lk = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lv = nn.Linear(channels, self.head_size * self.attn_size, bias=False)
        self.lout = nn.Linear(self.head_size * self.attn_size, channels, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v):
        """
        input:  [N, H, W, channels] for each of q, k, v
        output: [N, H, W, channels]
        """
        bz, H, W, C = q.shape

        # Flatten the spatial dimensions first
        q = q.view(bz, -1, C)  # [N, H*W, C]
        k = k.view(bz, -1, C)  # [N, H*W, C]
        v = v.view(bz, -1, C)  # [N, H*W, C]

        q = self.lq(q).view(bz, -1, self.head_size, self.attn_size)  # [N, H*W, hz, az]
        k = self.lk(k).view(bz, -1, self.head_size, self.attn_size)  # [N, H*W, hz, az]
        v = self.lv(v).view(bz, -1, self.head_size, self.attn_size)  # [N, H*W, hz, az]

        q = q.transpose(1, 2)                    # [N, hz, H*W, az]
        k = k.transpose(1, 2).transpose(-1, -2)  # [N, hz, az, H*W]
        v = v.transpose(1, 2)                    # [N, hz, H*W, az]

        q = q * self.scale

        x = torch.matmul(q, k)        # [N, hz, H*W, H*W]
        x = torch.softmax(x, dim=-1)
        x = self.dropout(x)
        x = x.matmul(v)               # [N, hz, H*W, az]

        x = x.transpose(1, 2).contiguous()  # [N, H*W, hz, az]
        x = x.view(bz, H, W, C)             # [N, H, W, C]

        x = self.lout(x)  # [N, H, W, C]

        return x

class Transformer_Encoder_Layer(nn.Module):
    def __init__(self, channels, num_attn_heads, ffw_channels, dropout=0.1):
        super().__init__()

        self.attn_norm = nn.LayerNorm(channels)
        self.attn_layer = MultiHeadAttention(channels, num_attn_heads, dropout)
        self.attn_dropout = nn.Dropout(dropout)

        self.ffw_norm = nn.LayerNorm(channels)
        self.ffw_layer = Feed_Forward(channels, ffw_channels, dropout)
        self.ffw_dropout = nn.Dropout(dropout)

    def forward(self, adp_pos_imgs):
        """
        input:  [N, H, W, channels]
        output: [N, H, W, channels]
        """
        # Pre-norm self-attention block with residual connection
        _x = adp_pos_imgs
        x = self.attn_norm(adp_pos_imgs)
        x = self.attn_layer(x, x, x)
        x = self.attn_dropout(x)
        x = x + _x

        # Pre-norm feed-forward block with residual connection
        _x = x
        x = self.ffw_norm(x)
        x = self.ffw_layer(x)
        x = self.ffw_dropout(x)
        x = x + _x
        return x

class Transformer_Encoder(nn.Module):
    def __init__(self, in_channels, out_channels, num_layers, num_attn_heads, ffw_channels, dropout=0.1):
        super().__init__()

        self.encoder_layers = nn.ModuleList([
            Transformer_Encoder_Layer(in_channels, num_attn_heads, ffw_channels, dropout) for _ in range(num_layers)
        ])
        self.linear = nn.Linear(in_channels, out_channels)
        self.last_norm = LayerNorm2D(out_channels)
        self.dropout = nn.Dropout(dropout)

    def forward(self, adp_pos_imgs):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H, W]
        """
        x = adp_pos_imgs.permute(0, -2, -1, -3)  # [N, H, W, in_channels]

        for layer in self.encoder_layers:
            x = layer(x)

        x = self.linear(x)            # [N, H, W, out_channels]
        x = x.permute(0, -1, -3, -2)  # [N, out_channels, H, W]
        x = self.last_norm(x)
        out = self.dropout(x)
        return out

class Double_Conv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, X):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H, W]   (spatial size is preserved; Down does the halving)
        """
        return self.double_conv(X)

class Down(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.down = nn.Sequential(
            nn.MaxPool2d(2),
            Double_Conv(in_channels, out_channels)
        )

    def forward(self, X):
        """
        input:  [N, in_channels, H, W]
        output: [N, out_channels, H//2, W//2]
        """
        return self.down(X)

class Up(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = Double_Conv(in_channels, out_channels)

    def forward(self, X1, X2):
        """
        input:  X1: [N, in_channels, H // 2, W // 2]   (coarser feature map)
                X2: [N, in_channels // 2, H, W]        (skip connection)
        output: X:  [N, out_channels, H, W]
        """
        X1 = self.up(X1)

        # Pad X1 so its spatial size matches the skip connection X2
        diffY = X2.shape[-2] - X1.shape[-2]
        diffX = X2.shape[-1] - X1.shape[-1]

        pad_top = diffY // 2
        pad_bottom = diffY - pad_top
        pad_left = diffX // 2
        pad_right = diffX - pad_left

        X1 = F.pad(X1, (pad_left, pad_right, pad_top, pad_bottom))

        X = torch.cat((X2, X1), dim=-3)  # concatenate along the channel dimension
        return self.conv(X)

class Out_Conv(nn.Module):
    def __init__(self, adp_channels, out_channels):
        super().__init__()

        self.out_conv = nn.Conv2d(adp_channels, out_channels, kernel_size=1)

    def forward(self, X):
        return self.out_conv(X)

class Trans_UNet(nn.Module):
    def __init__(self,
                 in_channels,
                 adp_channels,
                 out_channels,
                 trans_num_layers=5,
                 trans_num_attn_heads=8,
                 trans_ffw_channels=1024,
                 dropout=0.1):
        super().__init__()

        self.img_adaptor = Image_Adaptor(in_channels, adp_channels, dropout)
        self.pos_encoding = Positional_Encoding(adp_channels)

        self.down1 = Down(adp_channels * 1, adp_channels * 2)
        self.down2 = Down(adp_channels * 2, adp_channels * 4)
        self.down3 = Down(adp_channels * 4, adp_channels * 8)
        self.down4 = Down(adp_channels * 8, adp_channels * 16)
        self.down5 = Down(adp_channels * 16, adp_channels * 32)

        self.trans_encoder = Transformer_Encoder(adp_channels * 32, adp_channels * 32, trans_num_layers, trans_num_attn_heads, trans_ffw_channels, dropout)

        self.up5 = Up(adp_channels * 32, adp_channels * 16)
        self.up4 = Up(adp_channels * 16, adp_channels * 8)
        self.up3 = Up(adp_channels * 8, adp_channels * 4)
        self.up2 = Up(adp_channels * 4, adp_channels * 2)
        self.up1 = Up(adp_channels * 2, adp_channels * 1)

        self.out_conv = Out_Conv(adp_channels, out_channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, images):
        adp_imgs = self.img_adaptor(images)
        pos_enc = self.pos_encoding(adp_imgs)
        adp_imgs = adp_imgs + pos_enc

        d1 = self.down1(adp_imgs)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)

        x = self.trans_encoder(d5)

        u5 = self.up5(x, d4)
        u4 = self.up4(u5, d3)
        u3 = self.up3(u4, d2)
        u2 = self.up2(u3, d1)
        u1 = self.up1(u2, adp_imgs)

        x = self.out_conv(u1)
        out = self.sigmoid(x)
        return out
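For orientation, a minimal shape check of Trans_UNet — a sketch, not part of the upload, with illustrative sizes. The five Down stages each halve the spatial resolution, so H and W should be divisible by 2^5 = 32 for the skip connections to line up without padding:

import torch
from TransUnet import Trans_UNet

# Illustrative sizes only: five Down stages halve H and W five times,
# so pick dimensions divisible by 32.
model = Trans_UNet(in_channels=3, adp_channels=32, out_channels=1)
model.eval()
x = torch.randn(1, 3, 320, 320)   # [N, in_channels, H, W]
with torch.no_grad():
    y = model(x)
print(y.shape)  # torch.Size([1, 1, 320, 320]) — per-pixel probabilities from the final sigmoid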
TransUnet_Config.py
ADDED
@@ -0,0 +1,7 @@
in_channels = 3
adp_channels = 32
out_channels = 1
trans_num_layers = 5
trans_num_attn_heads = 8
trans_ffw_channels = 512
dropout = 0.1
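These constants feed straight into the model constructor; a minimal sketch of the wiring, mirroring what the DTM/DWBM classes in app.py below do. Note that the config's trans_ffw_channels = 512 overrides the constructor default of 1024:

from TransUnet import Trans_UNet
import TransUnet_Config as config

model = Trans_UNet(
    config.in_channels, config.adp_channels, config.out_channels,
    config.trans_num_layers, config.trans_num_attn_heads,
    config.trans_ffw_channels, config.dropout,
)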
app.py
ADDED
@@ -0,0 +1,428 @@
import cv2
import gradio as gr
import os
from edit_func import *
from TransUnet import Trans_UNet
import TransUnet_Config as config2
from huggingface_hub import hf_hub_download
from googletrans import Translator
import random
import torch  # used directly below; previously reached only through the star import
import torch.nn as nn
import spaces

@spaces.GPU
class DTM(nn.Module):
    """Text detector: averages the predictions of an ensemble of Trans_UNet checkpoints."""
    def __init__(self):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.detect_text_model = Trans_UNet(
            config2.in_channels, config2.adp_channels, config2.out_channels,
            config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels,
            config2.dropout
        ).to(self.device)
        self.repo_name = 'SS3M/detect-text-model'
        files = ['detect-text-v3-0.pt', 'detect-text-v3-1.pt',
                 'detect-text-v3-2.pt', 'detect-text-v3-3.pt',
                 'detect-text-v3-4.pt', 'detect-text-v3-5.pt',
                 'detect-text-v3-6.pt', 'detect-text-v3-7.pt']
        self.files = []
        for file in files:
            self.files.append(hf_hub_download(repo_id=self.repo_name, filename=file))

    def forward(self, X):
        X = X.to(self.device)
        N, C, H, W = X.shape
        # Accumulate on the model's device to avoid a CPU/GPU tensor mismatch
        result = torch.zeros((N, 1, H, W), device=self.device)
        for model_path in self.files:
            best_model_state = torch.load(
                model_path,
                weights_only=True,
                map_location=self.device
            )
            self.detect_text_model.load_state_dict(best_model_state)
            result += self.detect_text_model(X)
        # Average the ensemble's predictions
        result /= len(self.files)
        return result

@spaces.GPU
class DWBM(nn.Module):
    """Speech-bubble ("wordball") detector: same ensemble scheme as DTM."""
    def __init__(self):
        super().__init__()
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.detect_wordball_model = Trans_UNet(
            config2.in_channels, config2.adp_channels, config2.out_channels,
            config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels,
            config2.dropout
        ).to(self.device)
        self.repo_name = 'SS3M/detect-wordball-model'
        files = ['detect-text-v3-0.pt', 'detect-text-v3-1.pt',
                 'detect-text-v3-2.pt', 'detect-text-v3-3.pt',
                 'detect-text-v3-4.pt', 'detect-text-v3-5.pt',
                 'detect-text-v3-6.pt', 'detect-text-v3-7.pt']
        self.files = []
        for file in files:
            self.files.append(hf_hub_download(repo_id=self.repo_name, filename=file))

    def forward(self, X):
        X = X.to(self.device)
        N, C, H, W = X.shape
        result = torch.zeros((N, 1, H, W), device=self.device)
        for model_path in self.files:
            best_model_state = torch.load(
                model_path,
                weights_only=True,
                map_location=self.device
            )
            self.detect_wordball_model.load_state_dict(best_model_state)
            result += self.detect_wordball_model(X)
        result /= len(self.files)
        return result

detect_text_model = DTM()
detect_wordball_model = DWBM()

translator = Translator()

def down1(src_img):
    src_img = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
    text_msk = create_text_mask(src_img, detect_text_model)
    wordball_msk = create_wordball_mask(src_img, detect_wordball_model)

    text_positions, areas = get_text_positions(text_msk, text_value=0)
    # One random box colour per detected text region
    rgbs = []
    for _ in range(len(areas)):
        rgbs.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))

    # Serialise everything into '; '-separated strings for the hidden textboxes
    idx = '; '.join(str(i) for i in range(len(areas)))
    text_positions = '; '.join([', '.join(str(i) for i in pos) for pos in text_positions])
    areas = '; '.join(str(i) for i in areas)
    rgbs = '; '.join([', '.join(str(i) for i in rgb) for rgb in rgbs])
    src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
    return text_msk*255, wordball_msk*255, idx, text_positions, areas, rgbs, 'Xong'

def idx_txt_change(src_img, idx_txt, pos_txt, rgb_txt):
    try:
        src_img2 = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
        text_positions = [tuple(int(i) for i in pos.split(', ')) for pos in pos_txt.split('; ')]
        rgbs = [tuple(int(i) for i in rgb.split(', ')) for rgb in rgb_txt.split('; ')]
        idxes = [int(idx) for idx in idx_txt.split('; ')]

        # Draw a coloured rectangle around every selected text region
        for idx, ((min_x, min_y, max_x, max_y), (r, g, b)) in enumerate(zip(text_positions, rgbs)):
            if idx in idxes:
                cv2.rectangle(src_img2, (min_x, min_y), (max_x, max_y), (b, g, r), thickness=4)
        src_img2 = cv2.cvtColor(src_img2, cv2.COLOR_BGR2RGB)
        return src_img2
    except Exception:
        return src_img

def scale_area_change(min_area, max_area, area_txt):
    areas = [int(area) for area in area_txt.split('; ')]
    idxes = []
    for idx, area in enumerate(areas):
        if min_area <= area <= max_area:
            idxes.append(idx)
    idxes = '; '.join(str(i) for i in idxes)
    return idxes

def position_block_change(X, Y, W, H, ID, pos_txt_value):
    text_positions = [tuple(int(i) for i in pos.split(', ')) for pos in pos_txt_value.split('; ')]

    # Replace the box with index ID by the freshly edited coordinates
    text_positions2 = []
    for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
        if idx == ID:
            text_positions2.append((X, Y, X+W, Y+H))
        else:
            text_positions2.append((min_x, min_y, max_x, max_y))
    text_positions2 = '; '.join([', '.join(str(i) for i in pos) for pos in text_positions2])
    return text_positions2

def ID_block_change(ID_value, checkbox_value, ID_txt_value):
    ID_txt_value = [int(i) for i in ID_txt_value.split('; ')]
    if checkbox_value and ID_value not in ID_txt_value:
        ID_txt_value.append(ID_value)
    if not checkbox_value and ID_value in ID_txt_value:
        ID_txt_value.remove(ID_value)
    ID_txt_value = sorted(ID_txt_value)
    ID_txt_value = '; '.join([str(i) for i in ID_txt_value])
    return ID_txt_value

def down2(src_img_value, txt_mask_value, wordball_mask_value, idx_txt_value, pos_txt_value):
    src_img_value = cv2.cvtColor(src_img_value, cv2.COLOR_RGB2BGR)
    text_positions = [tuple(int(i) for i in pos.split(', ')) for pos in pos_txt_value.split('; ')]
    idxes = [int(i) for i in idx_txt_value.split('; ')]

    # Whiten the mask everywhere except the regions the user kept selected
    for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
        if idx not in idxes:
            txt_mask_value[min_y:max_y+1, min_x:max_x+1] = 255
    txt_mask_value = txt_mask_value[:, :, 0].astype(np.uint8)
    non_text_src_img = clear_text(src_img_value, txt_mask_value, wordball_mask_value, text_value=0, non_text_value=255, r=5)

    # OCR the selected regions, translate them, and seed default styling
    list_texts = get_list_texts(src_img_value, [tuple(map(int, pos.split(', '))) for idx, pos in enumerate(pos_txt_value.split('; ')) if idx in idxes])
    list_translated_texts = translate(list_texts, translator)
    list_fonts = '; '.join(['MTO Astro City.ttf' for _ in range(len(list_translated_texts))])
    list_sizes = '; '.join(['20' for _ in range(len(list_translated_texts))])
    list_strokes = '; '.join(['3' for _ in range(len(list_translated_texts))])
    list_pads = '; '.join(['5' for _ in range(len(list_translated_texts))])
    list_translated_texts = '; '.join(list_translated_texts)
    # A random value whose change event re-renders the per-textbox controls
    switch = str(random.random())

    return non_text_src_img, list_translated_texts, list_fonts, list_sizes, list_strokes, list_pads, switch, 'Xong'

def text_info_change(non_txt_img_value, translated_txt_value, pos_txt_value, idx_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
    non_txt_img_value = non_txt_img_value.copy()
    idxes = [int(i) for i in idx_txt_value.split('; ')]

    translated_text_src_img = insert_text(non_txt_img_value,
                                          translated_txt_value.split('; '),
                                          [tuple(map(int, pos.split(', '))) for idx, pos in enumerate(pos_txt_value.split('; ')) if idx in idxes],
                                          font=font_txt_value.split('; '),
                                          font_size=[int(i) for i in size_txt_value.split('; ')],
                                          pad=[int(i) for i in pad_txt_value.split('; ')],
                                          stroke=[int(i) for i in stroke_txt_value.split('; ')])
    return translated_text_src_img

def value2_change(value, ID2_value, txt_value):
    txt_value = txt_value.split('; ')

    # Replace the entry with index ID2_value, keep the rest unchanged
    txt_value2 = []
    for idx, text in enumerate(txt_value):
        if idx == ID2_value:
            txt_value2.append(str(value))
        else:
            txt_value2.append(str(text))
    txt_value2 = '; '.join(txt_value2)
    return txt_value2

# Build the Gradio interface
with gr.Blocks() as demo:
    # Layout
    src_img = gr.Image(type="numpy", label="Upload Image")

    down_bttn_1 = gr.Button("↓", elem_classes="arrow-button")

    with gr.Row():
        txt_mask = gr.Image(type="numpy", label="Upload Image", visible=True)
        wordball_mask = gr.Image(type="numpy", label="Upload Image", visible=True)
    complete = gr.Textbox()
    with gr.Row():
        idx_txt = gr.Textbox(label='ID', interactive=False, visible=False)
        pos_txt = gr.Textbox(label='Pos', interactive=False, visible=False)
        area_txt = gr.Textbox(label='Area', interactive=False, visible=False)
        rgb_txt = gr.Textbox(label='rgb', interactive=False, visible=False)
    with gr.Row():
        boxed_txt_img = gr.Image(type="numpy", label="Upload Image")
        with gr.Column() as down_1_column:
            @gr.render(inputs=[pos_txt, rgb_txt], triggers=[rgb_txt.change])
            def create_box(pos_txt_value, rgb_txt_value):
                text_positions = [tuple(int(i) for i in pos.split(', ')) for pos in pos_txt_value.split('; ')]
                rgbs = [tuple(int(i) for i in rgb.split(', ')) for rgb in rgb_txt_value.split('; ')]

                elements = []
                for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
                    with gr.Group() as box:
                        r, g, b = rgbs[idx]
                        with gr.Row():
                            gr.Markdown(
                                f"""
                                <div style="margin-left: 20px; display: flex; align-items: center;">
                                    <div style="width: 10px; height: 10px; background-color: rgb({r}, {g}, {b}); margin-right: 5px;"></div>
                                    <span style="font-size: 20px;">Textbox {idx+1}</span>
                                </div>
                                """
                            )
                            checkbox = gr.Checkbox(value=True, label='', min_width=50, interactive=True)
                        with gr.Row():
                            X = gr.Number(label="X", value=min_x, interactive=True)
                            Y = gr.Number(label="Y", value=min_y, interactive=True)
                            W = gr.Number(label="W", value=max_x-min_x, interactive=True)
                            H = gr.Number(label="H", value=max_y-min_y, interactive=True)
                            ID = gr.Number(label="ID", value=idx, interactive=True, visible=False)
                        elements.append((X, Y, W, H, ID))

                    checkbox.change(
                        fn=ID_block_change,
                        inputs=[ID, checkbox, idx_txt],
                        outputs=idx_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                    )
                    X.change(
                        fn=position_block_change,
                        inputs=[X, Y, W, H, ID, pos_txt],
                        outputs=pos_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                        show_progress=False
                    )
                    Y.change(
                        fn=position_block_change,
                        inputs=[X, Y, W, H, ID, pos_txt],
                        outputs=pos_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                        show_progress=False
                    )
                    W.change(
                        fn=position_block_change,
                        inputs=[X, Y, W, H, ID, pos_txt],
                        outputs=pos_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                        show_progress=False
                    )
                    H.change(
                        fn=position_block_change,
                        inputs=[X, Y, W, H, ID, pos_txt],
                        outputs=pos_txt,
                        show_progress=False
                    ).then(
                        fn=idx_txt_change,
                        inputs=[src_img, idx_txt, pos_txt, rgb_txt],
                        outputs=boxed_txt_img,
                        show_progress=False
                    )
    down_bttn_2 = gr.Button("↓", elem_classes="arrow-button")

    non_txt_img = gr.Image(type="numpy", label="Upload Image", visible=False)
    complete2 = gr.Textbox()
    with gr.Row():
        translated_txt = gr.Textbox(label='translated', interactive=False, visible=False)
        font_txt = gr.Textbox(label='font', interactive=False, visible=False)
        size_txt = gr.Textbox(label='size', interactive=False, visible=False)
        stroke_txt = gr.Textbox(label='stroke', interactive=False, visible=False)
        pad_txt = gr.Textbox(label='pad', interactive=False, visible=False)
        switch_txt = gr.Textbox(label='switch', value='1', interactive=False, visible=False)
    with gr.Row():
        boxed_inserted_non_txt_img = gr.Image(type="numpy", label="Upload Image")
        with gr.Column():
            @gr.render(inputs=[translated_txt, font_txt, size_txt, stroke_txt, pad_txt], triggers=[switch_txt.change])
            def create_box2(translated_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
                translated_txt_value = translated_txt_value.split('; ')
                font_txt_value = font_txt_value.split('; ')
                size_txt_value = size_txt_value.split('; ')
                stroke_txt_value = stroke_txt_value.split('; ')
                pad_txt_value = pad_txt_value.split('; ')

                elements = []
                for idx in range(len(font_txt_value)):
                    with gr.Group():
                        gr.Markdown(
                            f"""
                            <div style="margin-left: 20px; display: flex; align-items: center;">
                                <div style="width: 10px; height: 10px; background-color: rgb(255, 255, 255); margin-right: 5px;"></div>
                                <span style="font-size: 20px;">Text box {idx}</span>
                            </div>
                            """
                        )
                        translated_text_box = gr.Textbox(label="Translate", value=translated_txt_value[idx], interactive=True)
                        with gr.Row():
                            font = gr.Dropdown(choices=os.listdir('MTO Font'), label="Phông chữ", value=font_txt_value[idx], interactive=True, scale=7)
                            size = gr.Number(label="Size", value=int(size_txt_value[idx]), interactive=True, minimum=1)
                            stroke = gr.Number(label="Stroke", value=int(stroke_txt_value[idx]), interactive=True, minimum=0, maximum=5)
                            pad = gr.Number(label="Pad", value=int(pad_txt_value[idx]), interactive=True, minimum=1, maximum=10)
                            ID2 = gr.Number(label="ID", value=int(idx), interactive=True, visible=False)

                    translated_text_box.submit(
                        fn=value2_change,
                        inputs=[translated_text_box, ID2, translated_txt],
                        outputs=translated_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
                    font.change(
                        fn=value2_change,
                        inputs=[font, ID2, font_txt],
                        outputs=font_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
                    size.change(
                        fn=value2_change,
                        inputs=[size, ID2, size_txt],
                        outputs=size_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
                    stroke.change(
                        fn=value2_change,
                        inputs=[stroke, ID2, stroke_txt],
                        outputs=stroke_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )
                    pad.change(
                        fn=value2_change,
                        inputs=[pad, ID2, pad_txt],
                        outputs=pad_txt,
                        show_progress=False
                    ).then(
                        fn=text_info_change,
                        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
                        outputs=boxed_inserted_non_txt_img,
                    )

    # CSS
    demo.css = """
    .arrow-button {
        font-size: 40px; /* Font size */
    }
    .group-elem {
        height: 70px;
    }
    """

    # Event wiring
    down_bttn_1.click(
        fn=down1,
        inputs=src_img,
        outputs=[txt_mask, wordball_mask, idx_txt, pos_txt, area_txt, rgb_txt, complete],
    )
    down_bttn_2.click(
        fn=down2,
        inputs=[src_img, txt_mask, wordball_mask, idx_txt, pos_txt],
        outputs=[non_txt_img, translated_txt, font_txt, size_txt, stroke_txt, pad_txt, switch_txt, complete2],
    ).then(
        fn=text_info_change,
        inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
        outputs=boxed_inserted_non_txt_img,
    )

demo.launch()
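The two arrow buttons drive a two-stage pipeline: down1 detects text and speech bubbles and returns the serialised box metadata, and down2 erases, OCRs, translates, and seeds the styling defaults. Below is a hedged sketch of driving those same functions headlessly, outside Gradio — it assumes the SS3M checkpoints download successfully, Tesseract is installed, the spaces decorator is a no-op locally, and uses a hypothetical input path:

import cv2

img = cv2.imread('page.png')                 # hypothetical input file
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)   # down1 expects RGB, as gr.Image provides

txt_mask, wb_mask, idx, pos, areas, rgbs, status = down1(img)

# In the UI the masks round-trip through gr.Image as 3-channel arrays, and
# down2 relies on that (it indexes channel 0), so mimic the round trip here.
txt_mask3 = cv2.cvtColor(txt_mask.astype('uint8'), cv2.COLOR_GRAY2RGB)
wb_mask3 = cv2.cvtColor(wb_mask.astype('uint8'), cv2.COLOR_GRAY2RGB)

clean, texts, fonts, sizes, strokes, pads, switch, status2 = down2(
    img, txt_mask3, wb_mask3, idx, pos)

# Render the translated text back into the cleaned page
final = text_info_change(clean, texts, pos, idx, fonts, sizes, strokes, pads)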
edit_func.py
ADDED
@@ -0,0 +1,311 @@
import torch
import cv2
import pytesseract
from PIL import Image, ImageDraw, ImageFont
from collections import deque
import numpy as np
import os

# pytesseract.pytesseract.tesseract_cmd = 'Tesseract\\tesseract.exe'

def get_full_img_path(src_dir):
    """
    input:  path to the folder containing the images
    output: list of the (relative) names of all images
    """
    list_img_names = []
    for dirname, _, filenames in os.walk(src_dir):
        for filename in filenames:
            path = os.path.join(dirname, filename).replace(src_dir, '')
            if path[0] == '/':
                path = path[1:]
            list_img_names.append(path)
    return list_img_names


def create_text_mask(src_img, detect_text_model, kernel_size=5, iterations=3):
    """
    input:  source image as an np.array, shape [H, W, C]
    output: mask marking the text in the source image, 0 = text, 1 = background; shape [H, W]
    """
    img = torch.from_numpy(src_img).to(torch.uint8).to(detect_text_model.device)
    imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2)  # [1, C, H, W]

    detect_text_model.eval()
    with torch.no_grad():
        result = detect_text_model(imgT).squeeze()
        result = (result >= 0.5).detach().cpu().numpy()

    mask = ((1-result) * 255).astype(np.uint8)

    # Morphological erode/dilate/erode to remove speckles and fill small holes
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=iterations)
    mask = cv2.dilate(mask, kernel, iterations=2*iterations)
    mask = cv2.erode(mask, kernel, iterations=iterations)

    mask = (1 - mask // 255).astype(np.uint8)
    return mask


def create_wordball_mask(src_img, detect_wordball_model, kernel_size=5, iterations=3):
    """
    input:  source image as an np.array, shape [H, W, C]
    output: mask marking the speech bubbles ("wordballs") in the source image,
            0 = bubble, 1 = background; shape [H, W]
    """
    img = torch.from_numpy(src_img).to(torch.uint8).to(detect_wordball_model.device)
    imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2)

    detect_wordball_model.eval()
    with torch.no_grad():
        result = detect_wordball_model(imgT).squeeze()
        result = (result >= 0.5).detach().cpu().numpy()

    mask = ((1-result) * 255).astype(np.uint8)

    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    mask = cv2.erode(mask, kernel, iterations=iterations)
    mask = cv2.dilate(mask, kernel, iterations=2*iterations)
    mask = cv2.erode(mask, kernel, iterations=iterations)

    mask = (1 - mask // 255).astype(np.uint8)
    return mask


def clear_text(src_img, text_msk, wordball_msk, text_value=0, non_text_value=1, r=5):
    """
    input:  src_img: source image as an np.array, shape [H, W, C]
            text_msk: mask marking the text in the source image; shape [H, W]
            text_value: value that marks text in the mask
            non_text_value: value that marks background in the mask
            r: radius used when erasing the text and repainting the erased area
    output: image with the text removed, as an np.array, shape [H, W, C]
    """
    MAX = max(text_value, non_text_value)
    MIN = min(text_value, non_text_value)

    # Rescale both masks so that they only contain 0 and 1
    scale_text_value = (text_value - MIN) / (MAX - MIN)
    scale_non_text_value = (non_text_value - MIN) / (MAX - MIN)

    text_msk[text_msk==text_value] = scale_text_value
    text_msk[text_msk==non_text_value] = scale_non_text_value

    wordball_msk[wordball_msk==text_value] = scale_text_value
    wordball_msk[wordball_msk==non_text_value] = scale_non_text_value

    if scale_text_value == 0:
        text_msk = 1 - text_msk
        wordball_msk = 1 - wordball_msk
    text_msk = text_msk * 255

    # Inpaint the text regions, then white out the speech bubbles
    remove_txt = cv2.inpaint(src_img, text_msk, r, cv2.INPAINT_TELEA)
    remove_wordball = remove_txt.copy()
    remove_wordball[wordball_msk==1] = 255

    return remove_wordball


def dfs(grid, y, x, visited, value):
    """
    Connected-component search over the mask (iterative DFS).
    Output: the bounding rectangle covering the component + the component's area.
    """
    max_y, max_x = y, x
    min_y, min_x = y+1, x+1
    area = 0

    stack = deque([(y, x)])
    while stack:
        y, x = stack.pop()

        max_x = max(max_x, x)
        max_y = max(max_y, y)
        min_x = min(min_x, x)
        min_y = min(min_y, y)

        if (y, x) not in visited:
            visited.add((y, x))
            area += 1
            # Check the 8 neighbouring cells
            for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]:
                nx, ny = x + dx, y + dy
                if 0 <= ny < grid.shape[0] and 0 <= nx < grid.shape[1] and grid[ny, nx] == value and (ny, nx) not in visited:
                    stack.append((ny, nx))

    return (min_x, min_y, max_x, max_y), area


def find_clusters(grid, value):
    """
    Find all connected components of `value` in the grid.
    """
    visited = set()
    clusters = []
    areas = []

    for y in range(grid.shape[0]):
        for x in range(grid.shape[1]):
            if grid[y, x] == value and (y, x) not in visited:
                cluster, area = dfs(grid, y, x, visited, value)
                clusters.append(cluster)
                areas.append(area)

    return clusters, areas

def get_text_positions(text_msk, text_value=0):
    """
    input:  text_msk: mask marking the text in the source image; shape [H, W]
            text_value: value that marks text in the mask
    output: list of regions containing text, as (min_x, min_y, max_x, max_y), plus their areas
    """
    clusters, areas = find_clusters(text_msk, value=text_value)
    return clusters, areas

def filter_text_positions(clusters, areas, min_area=1200, max_area=10000):
    # Boolean indexing needs arrays, so convert the plain lists first
    clusters = np.asarray(clusters)
    areas = np.asarray(areas)
    clusters = clusters[(areas >= min_area) & (areas <= max_area)]
    return clusters


def get_list_texts(src_img, text_positions, lang='eng'):
    """
    input:  src_img: source image as an np.array, shape [H, W, C]
            text_positions: list of regions containing text, as (min_x, min_y, max_x, max_y)
            lang: language of the text
    output: list of the recognised text strings
    """
    list_texts = []
    for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
        crop_img = src_img[min_y:max_y+1, min_x:max_x+1]
        img_rgb = cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB)
        img = Image.fromarray(img_rgb)
        text = pytesseract.image_to_string(img, lang=lang).replace('\n', ' ').strip()
        # Collapse runs of spaces left behind by the OCR
        while '  ' in text:
            text = text.replace('  ', ' ')
        list_texts.append(text)
    return list_texts


def translate(list_texts, translator):
    translated_texts = []
    for text in list_texts:
        # Substitute a placeholder so empty OCR results do not break the translator
        if not text:
            text = 'a'
        translated_text = translator.translate(text, src='en', dest='vi').text
        translated_texts.append(translated_text)
    return translated_texts


def add_centered_multiline_text(image, text, box, font_path="arial.ttf", font_size=36, pad=5, text_color=0):
    draw = ImageDraw.Draw(image)

    # Unpack the box (min_x, min_y, max_x, max_y)
    min_x, min_y, max_x, max_y = box

    # Create the font
    font = ImageFont.truetype(font_path, font_size)

    # Split the text over several lines where necessary
    wrapped_lines = wrap_text(text, font, draw, max_x - min_x)

    # Total height of all lines together
    total_text_height = sum(get_text_height(line, draw, font) for line in wrapped_lines)

    # Starting y coordinate that centres the block vertically
    start_y = min_y + (max_y - min_y - total_text_height) // 2

    # Draw each line, centred horizontally
    current_y = start_y
    for line in wrapped_lines:
        text_width, text_height = get_text_dimensions(line, draw, font)
        text_x = min_x + (max_x - min_x - text_width) // 2  # centre horizontally
        draw.text((text_x, current_y), line, fill=text_color, font=font)
        current_y += text_height + pad  # move y down for the next line

    return image

def get_text_dimensions(text, draw, font):
    """Return the (width, height) of the text."""
    bbox = draw.textbbox((0, 0), text, font=font)
    width = bbox[2] - bbox[0]
    height = bbox[3] - bbox[1]
    return width, height

def get_text_height(text, draw, font):
    """Return the height of the text."""
    _, _, _, height = draw.textbbox((0, 0), text, font=font)
    return height

def wrap_text(text, font, draw, max_width):
    """Split the text into multiple lines based on the maximum width."""
    words = text.split()
    lines = []
    current_line = ""

    for word in words:
        # Try appending the word to the current line
        test_line = f"{current_line} {word}".strip()
        test_width, _ = get_text_dimensions(test_line, draw, font)

        if test_width <= max_width:
            current_line = test_line
        else:
            # Too wide: store the current line and start a new one
            lines.append(current_line)
            current_line = word

    # Append the final line
    if current_line:
        lines.append(current_line)

    return lines

def insert_text(non_text_src_img, list_translated_texts, text_positions, font=['MTO Astro City.ttf'], font_size=[20], pad=[5], text_color=0, stroke=[3]):
    # Copy the text-free image
    img_bgr = non_text_src_img.copy()

    # Draw each text onto mask 1
    for idx, text in enumerate(list_translated_texts):
        # Create white masks
        mask1 = Image.new("L", img_bgr.shape[:2][::-1], 255)
        mask2 = Image.new("L", img_bgr.shape[:2][::-1], 255)
        mask1 = add_centered_multiline_text(mask1, text, text_positions[idx], f'MTO Font/{font[idx]}', font_size[idx], pad=pad[idx], text_color=text_color)

        # Convert the single-channel PIL mask to a 3-channel cv2 image
        # (the masks are "L" mode, so GRAY2BGR is the correct conversion)
        mask1 = (np.array(mask1) >= 127).astype(np.uint8) * 255
        mask1 = cv2.cvtColor(mask1, cv2.COLOR_GRAY2BGR)

        if stroke[idx] > 0:
            mask2 = np.array(mask2).astype(np.uint8)
            mask2 = cv2.cvtColor(mask2, cv2.COLOR_GRAY2BGR)

            # The stroke is the dilated difference between the blank mask and the text mask
            mask2 = mask2 - mask1
            kernel = np.ones((stroke[idx]+1, stroke[idx]+1), np.uint8)
            mask2 = cv2.dilate(mask2, kernel, iterations=1)
            img_bgr[mask2==255] = 255

        img_bgr[mask1==text_color] = text_color
    return img_bgr


def save_img(path, translated_text_src_img):
    """
    input:  path: path to the original source image
            translated_text_src_img: image after translation
    output: the translated image is saved, with "translated-" prefixed to its name
    """
    dot = path.rfind('.')
    last_slash = -1
    if '/' in path:
        last_slash = path.rfind('/')

    ext = path[dot:]
    parent_path = path[:last_slash+1]
    name = path[last_slash+1:dot]

    if parent_path and not os.path.exists(parent_path):
        os.makedirs(parent_path)  # create intermediate directories as well
    cv2.imwrite(f'{parent_path}translated-{name}{ext}', translated_text_src_img)
    print(f'Image saved at {parent_path}translated-{name}{ext}')
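The layout helpers can be exercised in isolation on a blank canvas — a sketch with illustrative values; font_path must point at a locally available .ttf (the "arial.ttf" here is just a placeholder):

from PIL import Image

canvas = Image.new("L", (400, 200), 255)  # white grayscale canvas
boxed = add_centered_multiline_text(
    canvas,
    "A sentence long enough that wrap_text has to break it onto several lines",
    box=(20, 20, 380, 180),               # (min_x, min_y, max_x, max_y)
    font_path="arial.ttf",                # placeholder; replace with any local .ttf
    font_size=24,
)
boxed.save("layout-demo.png")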
requirements.txt
ADDED
@@ -0,0 +1,13 @@
chardet==3.0.4
googletrans==4.0.0rc1
h2==3.2.0
hstspreload==2024.10.1
opencv-python==4.10.0.84
pip==24.2
positional-encodings==6.0.3
pytesseract==0.3.13
rfc3986==1.5.0
setuptools==75.1.0
spaces==0.30.4
torch==2.5.0
wheel==0.44.0