SS3M committed commit e2cc14b · verified · 1 parent: 9ef1b02

Upload 7 files

Files changed (7):
  1. .gitattributes +35 -36
  2. README.md +13 -13
  3. TransUnet.py +320 -0
  4. TransUnet_Config.py +7 -0
  5. app.py +428 -0
  6. edit_func.py +311 -0
  7. requirements.txt +13 -0
.gitattributes CHANGED
@@ -1,36 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
- MTO[[:space:]]Font/MTO[[:space:]]Getting[[:space:]]Angry.ttf filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
 
README.md CHANGED
@@ -1,13 +1,13 @@
- ---
- title: TheEditor
- emoji: 📉
- colorFrom: pink
- colorTo: yellow
- sdk: gradio
- sdk_version: 5.3.0
- app_file: app.py
- pinned: false
- license: apache-2.0
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ ---
+ title: TheEditor
+ emoji: 📉
+ colorFrom: pink
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 5.3.0
+ app_file: app.py
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

TransUnet.py ADDED
@@ -0,0 +1,320 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from positional_encodings.torch_encodings import PositionalEncoding2D
+
+
+ class LayerNorm2D(nn.Module):
+     def __init__(self, embed_dim):
+         super().__init__()
+         self.layer_norm = nn.LayerNorm(embed_dim)
+
+     def forward(self, x):
+         x = x.permute(0, 2, 3, 1)
+         x = self.layer_norm(x)
+         x = x.permute(0, 3, 1, 2)
+         return x
+
+ class Image_Adaptor(nn.Module):
+     def __init__(self, in_channels, adp_channels, dropout=0.1):
+         super().__init__()
+
+         self.adaptor = nn.Sequential(
+             nn.Conv2d(in_channels, adp_channels // 4, kernel_size=4, padding='same'),
+             LayerNorm2D(adp_channels // 4),
+             nn.GELU(),
+             nn.Conv2d(adp_channels // 4, adp_channels // 4, kernel_size=2, padding='same'),
+             LayerNorm2D(adp_channels // 4),
+             nn.GELU(),
+             nn.Conv2d(adp_channels // 4, adp_channels, kernel_size=2, padding='same')
+         )
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, images):
+         """
+         input: [N, in_channels, H, W]
+         output: [N, adp_channels, H, W]
+         """
+         adapt_imgs = self.adaptor(images)
+         return self.dropout(adapt_imgs)
+
+ class Positional_Encoding(nn.Module):
+     def __init__(self, adp_channels):
+         super().__init__()
+         self.pe = PositionalEncoding2D(adp_channels)
+
+     def forward(self, adapt_imgs):
+         """
+         input: [N, adp_channels, H, W]
+         output: [N, adp_channels, H, W]
+         """
+         x = adapt_imgs.permute(0, -2, -1, -3)
+         encode = self.pe(x)
+         encode = encode.permute(0, -1, -3, -2)
+         return encode
+
+ class GeGLU(nn.Module):
+     def __init__(self, emb_channels, ffn_size):
+         super().__init__()
+         self.wi_0 = nn.Linear(emb_channels, ffn_size, bias=False)
+         self.wi_1 = nn.Linear(emb_channels, ffn_size, bias=False)
+         self.act = nn.GELU()
+
+     def forward(self, x):
+         x_gelu = self.act(self.wi_0(x))
+         x_linear = self.wi_1(x)
+         x = x_gelu * x_linear
+         return x
+
+ class Feed_Forward(nn.Module):
+     def __init__(self, in_channels, ffw_channels, dropout=0.1):
+         super().__init__()
+
+         self.ln1 = GeGLU(in_channels, ffw_channels)
+         self.dropout = nn.Dropout(dropout)
+         self.ln2 = GeGLU(ffw_channels, in_channels)
+
+     def forward(self, x):
+         '''
+         input: [N, H, W, channels]
+         output: [N, H, W, channels]
+         '''
+         x = self.ln1(x)
+         x = self.dropout(x)
+         x = self.ln2(x)
+         return x
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(self, channels, num_attn_heads, dropout=0.1):
+         super().__init__()
+
+         self.head_size = num_attn_heads
+         self.channels = channels
+         self.attn_size = channels // num_attn_heads
+         self.scale = self.attn_size ** -0.5
+         assert num_attn_heads * self.attn_size == channels, "Input channels of attention must be divisible by the number of attention heads!"
+
+         self.lq = nn.Linear(channels, self.head_size*self.attn_size, bias=False)
+         self.lk = nn.Linear(channels, self.head_size*self.attn_size, bias=False)
+         self.lv = nn.Linear(channels, self.head_size*self.attn_size, bias=False)
+         self.lout = nn.Linear(self.head_size*self.attn_size, channels, bias=False)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, q, k, v):
+         '''
+         input: [N, H, W, channels] for each of q, k, v
+         output: [N, H, W, channels]
+         '''
+         bz, H, W, C = q.shape
+
+         # Flatten the spatial dimensions first
+         q = q.view(bz, -1, C)  # [N, H*W, C]
+         k = k.view(bz, -1, C)  # [N, H*W, C]
+         v = v.view(bz, -1, C)  # [N, H*W, C]
+
+         q = self.lq(q).view(bz, -1, self.head_size, self.attn_size)  # [N, H*W, hz, az]
+         k = self.lk(k).view(bz, -1, self.head_size, self.attn_size)  # [N, H*W, hz, az]
+         v = self.lv(v).view(bz, -1, self.head_size, self.attn_size)  # [N, H*W, hz, az]
+
+         q = q.transpose(1, 2)                    # [N, hz, H*W, az]
+         k = k.transpose(1, 2).transpose(-1, -2)  # [N, hz, az, H*W]
+         v = v.transpose(1, 2)                    # [N, hz, H*W, az]
+
+         q *= self.scale
+
+         x = torch.matmul(q, k)  # [N, hz, H*W, H*W]
+         x = torch.softmax(x, dim=-1)
+         x = self.dropout(x)
+         x = x.matmul(v)  # [N, hz, H*W, az]
+
+         x = x.transpose(1, 2).contiguous()  # [N, H*W, hz, az]
+         x = x.view(bz, -1, C)  # [N, H*W, C]
+         x = x.view(bz, H, W, C)  # [N, H, W, C]
+
+         x = self.lout(x)  # [N, H, W, C]
+
+         return x
+
+ class Transformer_Encoder_Layer(nn.Module):
+     def __init__(self, channels, num_attn_heads, ffw_channels, dropout=0.1):
+         super().__init__()
+
+         self.attn_norm = nn.LayerNorm(channels)
+         self.attn_layer = MultiHeadAttention(channels, num_attn_heads, dropout)
+         self.attn_dropout = nn.Dropout(dropout)
+
+         self.ffw_norm = nn.LayerNorm(channels)
+         self.ffw_layer = Feed_Forward(channels, ffw_channels, dropout)
+         self.ffw_dropout = nn.Dropout(dropout)
+
+     def forward(self, adp_pos_imgs):
+         """
+         input: [N, H, W, channels]
+         output: [N, H, W, channels]
+         """
+         _x = adp_pos_imgs
+         x = self.attn_norm(adp_pos_imgs)
+         x = self.attn_layer(x, x, x)
+         x = self.attn_dropout(x)
+         x = x + _x
+
+         _x = x
+         x = self.ffw_norm(x)
+         x = self.ffw_layer(x)
+         x = self.ffw_dropout(x)
+         x = x + _x
+         return x
+
+ class Transformer_Encoder(nn.Module):
+     def __init__(self, in_channels, out_channels, num_layers, num_attn_heads, ffw_channels, dropout=0.1):
+         super().__init__()
+
+         self.encoder_layers = nn.ModuleList([
+             Transformer_Encoder_Layer(in_channels, num_attn_heads, ffw_channels, dropout) for _ in range(num_layers)
+         ])
+         self.linear = nn.Linear(in_channels, out_channels)
+         self.last_norm = LayerNorm2D(out_channels)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, adp_pos_imgs):
+         """
+         input: [N, in_channels, H, W]
+         output: [N, out_channels, H, W]
+         """
+         x = adp_pos_imgs.permute(0, -2, -1, -3)  # [N, H, W, in_channels]
+
+         for layer in self.encoder_layers:
+             x = layer(x)
+
+         x = self.linear(x)  # [N, H, W, out_channels]
+         x = x.permute(0, -1, -3, -2)
+         x = self.last_norm(x)
+         out = self.dropout(x)
+         return out
+
+ class Double_Conv(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+
+         self.double_conv = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True),
+             nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU(inplace=True)
+         )
+
+     def forward(self, X):
+         """
+         input: [N, in_channels, H, W]
+         output: [N, out_channels, H, W]
+         """
+         return self.double_conv(X)
+
+ class Down(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+
+         self.down = nn.Sequential(
+             nn.MaxPool2d(2),
+             Double_Conv(in_channels, out_channels)
+         )
+
+     def forward(self, X):
+         """
+         input: [N, in_channels, H, W]
+         output: [N, out_channels, H//2, W//2]
+         """
+         return self.down(X)
+
+ class Up(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super().__init__()
+
+         self.up = nn.ConvTranspose2d(in_channels, in_channels//2, kernel_size=2, stride=2)
+         self.conv = Double_Conv(in_channels, out_channels)
+
+     def forward(self, X1, X2):
+         """
+         input: X1 : [N, in_channels, H // 2, W // 2]
+                X2 : [N, in_channels // 2, H, W]
+         output: X : [N, out_channels, H, W]
+         """
+         X1 = self.up(X1)
+
+         diffY = X2.shape[-2] - X1.shape[-2]
+         diffX = X2.shape[-1] - X1.shape[-1]
+
+         pad_top = diffY // 2
+         pad_bottom = diffY - pad_top
+         pad_left = diffX // 2
+         pad_right = diffX - pad_left
+
+         X1 = F.pad(X1, (pad_left, pad_right, pad_top, pad_bottom))
+
+         X = torch.cat((X2, X1), dim=-3)
+         return self.conv(X)
+
+ class Out_Conv(nn.Module):
+     def __init__(self, adp_channels, out_channels):
+         super().__init__()
+
+         self.out_conv = nn.Conv2d(adp_channels, out_channels, kernel_size=1)
+
+     def forward(self, X):
+         return self.out_conv(X)
+
+ class Trans_UNet(nn.Module):
+     def __init__(self,
+                  in_channels,
+                  adp_channels,
+                  out_channels,
+                  trans_num_layers=5,
+                  trans_num_attn_heads=8,
+                  trans_ffw_channels=1024,
+                  dropout=0.1):
+         super().__init__()
+
+         self.img_adaptor = Image_Adaptor(in_channels, adp_channels, dropout)
+         self.pos_encoding = Positional_Encoding(adp_channels)
+
+         self.down1 = Down(adp_channels * 1, adp_channels * 2)
+         self.down2 = Down(adp_channels * 2, adp_channels * 4)
+         self.down3 = Down(adp_channels * 4, adp_channels * 8)
+         self.down4 = Down(adp_channels * 8, adp_channels * 16)
+         self.down5 = Down(adp_channels * 16, adp_channels * 32)
+
+         self.trans_encoder = Transformer_Encoder(adp_channels * 32, adp_channels * 32, trans_num_layers, trans_num_attn_heads, trans_ffw_channels, dropout)
+
+         self.up5 = Up(adp_channels * 32, adp_channels * 16)
+         self.up4 = Up(adp_channels * 16, adp_channels * 8)
+         self.up3 = Up(adp_channels * 8, adp_channels * 4)
+         self.up2 = Up(adp_channels * 4, adp_channels * 2)
+         self.up1 = Up(adp_channels * 2, adp_channels * 1)
+
+         self.out_conv = Out_Conv(adp_channels, out_channels)
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, images):
+         adp_imgs = self.img_adaptor(images)
+         pos_enc = self.pos_encoding(adp_imgs)
+         adp_imgs += pos_enc
+
+         d1 = self.down1(adp_imgs)
+         d2 = self.down2(d1)
+         d3 = self.down3(d2)
+         d4 = self.down4(d3)
+         d5 = self.down5(d4)
+
+         x = self.trans_encoder(d5)
+
+         u5 = self.up5(x, d4)
+         u4 = self.up4(u5, d3)
+         u3 = self.up3(u4, d2)
+         u2 = self.up2(u3, d1)
+         u1 = self.up1(u2, adp_imgs)
+
+         x = self.out_conv(u1)
+         out = self.sigmoid(x)
+         return out
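
For orientation (not part of this commit): a minimal smoke test of the Trans_UNet above, assuming the positional-encodings package from requirements.txt is installed. The 64×64 input size and the constructor values are arbitrary illustrations; shapes follow the docstrings in TransUnet.py.

import torch
from TransUnet import Trans_UNet

model = Trans_UNet(in_channels=3, adp_channels=32, out_channels=1)
model.eval()                        # disable dropout for a deterministic check
x = torch.randn(1, 3, 64, 64)       # [N, in_channels, H, W]
with torch.no_grad():
    y = model(x)                    # per-pixel sigmoid mask, [N, out_channels, H, W]
print(y.shape)                      # torch.Size([1, 1, 64, 64])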
TransUnet_Config.py ADDED
@@ -0,0 +1,7 @@
+ in_channels = 3
+ adp_channels = 32
+ out_channels = 1
+ trans_num_layers = 5
+ trans_num_attn_heads = 8
+ trans_ffw_channels = 512
+ dropout = 0.1
app.py ADDED
@@ -0,0 +1,428 @@
+ import cv2
+ import gradio as gr
+ import os
+ from edit_func import *
+ from TransUnet import Trans_UNet
+ import TransUnet_Config as config2
+ from huggingface_hub import hf_hub_download
+ from googletrans import Translator
+ import random
+ import torch.nn as nn
+ import spaces
+
+ @spaces.GPU
+ class DTM(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.detect_text_model = Trans_UNet(
+             config2.in_channels, config2.adp_channels, config2.out_channels,
+             config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels,
+             config2.dropout
+         ).to(self.device)
+         self.repo_name = 'SS3M/detect-text-model'
+         files = ['detect-text-v3-0.pt', 'detect-text-v3-1.pt',
+                  'detect-text-v3-2.pt', 'detect-text-v3-3.pt',
+                  'detect-text-v3-4.pt', 'detect-text-v3-5.pt',
+                  'detect-text-v3-6.pt', 'detect-text-v3-7.pt']
+         self.files = []
+         for file in files:
+             self.files.append(hf_hub_download(repo_id=self.repo_name, filename=file))
+
+     def forward(self, X):
+         X = X.to(self.device)
+         N, C, H, W = X.shape
+         result = torch.zeros((N, 1, H, W), device=self.device)  # keep the accumulator on the model's device
+         for file in self.files:
+             model_path = file
+             best_model_state = torch.load(
+                 model_path,
+                 weights_only=True,
+                 map_location=self.device
+             )
+             self.detect_text_model.load_state_dict(best_model_state)
+             result += self.detect_text_model(X)
+         result /= len(self.files)
+         return result
+
+ @spaces.GPU
+ class DWBM(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         self.detect_wordball_model = Trans_UNet(
+             config2.in_channels, config2.adp_channels, config2.out_channels,
+             config2.trans_num_layers, config2.trans_num_attn_heads, config2.trans_ffw_channels,
+             config2.dropout
+         ).to(self.device)
+         self.repo_name = 'SS3M/detect-wordball-model'
+         files = ['detect-text-v3-0.pt', 'detect-text-v3-1.pt',
+                  'detect-text-v3-2.pt', 'detect-text-v3-3.pt',
+                  'detect-text-v3-4.pt', 'detect-text-v3-5.pt',
+                  'detect-text-v3-6.pt', 'detect-text-v3-7.pt']
+         self.files = []
+         for file in files:
+             self.files.append(hf_hub_download(repo_id=self.repo_name, filename=file))
+
+     def forward(self, X):
+         X = X.to(self.device)
+         N, C, H, W = X.shape
+         result = torch.zeros((N, 1, H, W), device=self.device)  # keep the accumulator on the model's device
+         for file in self.files:
+             model_path = file
+             best_model_state = torch.load(
+                 model_path,
+                 weights_only=True,
+                 map_location=self.device
+             )
+             self.detect_wordball_model.load_state_dict(best_model_state)
+             result += self.detect_wordball_model(X)
+         result /= len(self.files)
+         return result
+
+ detect_text_model = DTM()
+ detect_wordball_model = DWBM()
+
+ translator = Translator()
+
+ def down1(src_img):
+     src_img = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
+     text_msk = create_text_mask(src_img, detect_text_model)
+     wordball_msk = create_wordball_mask(src_img, detect_wordball_model)
+
+     text_positions, areas = get_text_positions(text_msk, text_value=0)
+     rgbs = []
+     for _ in range(len(areas)):
+         rgbs.append((random.randint(0, 255), random.randint(0, 255), random.randint(0, 255)))
+
+     idx = '; '.join(str(i) for i in range(len(areas)))
+     text_positions = '; '.join([', '.join(str(i) for i in pos) for pos in text_positions])
+     areas = '; '.join(str(i) for i in areas)
+     rgbs = '; '.join([', '.join(str(i) for i in rgb) for rgb in rgbs])
+     src_img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)
+     return text_msk*255, wordball_msk*255, idx, text_positions, areas, rgbs, 'Done'
+
+ def idx_txt_change(src_img, idx_txt, pos_txt, rgb_txt):
+     try:
+         src_img2 = cv2.cvtColor(src_img, cv2.COLOR_RGB2BGR)
+         text_positions = pos_txt.split('; ')
+         for idx in range(len(text_positions)):
+             text_positions[idx] = (int(i) for i in text_positions[idx].split(', '))
+         rgbs = rgb_txt.split('; ')
+         for idx in range(len(rgbs)):
+             rgbs[idx] = (int(i) for i in rgbs[idx].split(', '))
+         idxes = [int(idx) for idx in idx_txt.split('; ')]
+
+         for idx, ((min_x, min_y, max_x, max_y), (r, g, b)) in enumerate(zip(text_positions, rgbs)):
+             if idx in idxes:
+                 cv2.rectangle(src_img2, (min_x, min_y), (max_x, max_y), (b, g, r), thickness=4)
+         src_img2 = cv2.cvtColor(src_img2, cv2.COLOR_BGR2RGB)
+         return src_img2
+     except:
+         return src_img
+
+ def scale_area_change(min_area, max_area, area_txt):
+     areas = [int(area) for area in area_txt.split('; ')]
+     idxes = []
+     for idx, area in enumerate(areas):
+         if min_area <= area <= max_area:
+             idxes.append(idx)
+     idxes = '; '.join(str(i) for i in idxes)
+     return idxes
+
+ def position_block_change(X, Y, W, H, ID, pos_txt_value):
+     text_positions = pos_txt_value.split('; ')
+     for idx in range(len(text_positions)):
+         text_positions[idx] = (int(i) for i in text_positions[idx].split(', '))
+
+     text_positions2 = []
+     for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
+         if idx == ID:
+             text_positions2.append((X, Y, X+W, Y+H))
+         else:
+             text_positions2.append((min_x, min_y, max_x, max_y))
+     text_positions2 = '; '.join([', '.join(str(i) for i in pos) for pos in text_positions2])
+     return text_positions2
+
+ def ID_block_change(ID_value, checkbox_value, ID_txt_value):
+     ID_txt_value = [int(i) for i in ID_txt_value.split('; ')]
+     if checkbox_value and ID_value not in ID_txt_value:
+         ID_txt_value.append(ID_value)
+     if not checkbox_value and ID_value in ID_txt_value:
+         ID_txt_value.remove(ID_value)
+     ID_txt_value = sorted(ID_txt_value)
+     ID_txt_value = '; '.join([str(i) for i in ID_txt_value])
+     return ID_txt_value
+
+ def down2(src_img_value, txt_mask_value, wordball_mask_value, idx_txt_value, pos_txt_value):
+     src_img_value = cv2.cvtColor(src_img_value, cv2.COLOR_RGB2BGR)
+     text_positions = pos_txt_value.split('; ')
+     for idx in range(len(text_positions)):
+         text_positions[idx] = (int(i) for i in text_positions[idx].split(', '))
+     idxes = [int(i) for i in idx_txt_value.split('; ')]
+
+     for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
+         if idx not in idxes:
+             txt_mask_value[min_y:max_y+1, min_x:max_x+1] = 255
+     txt_mask_value = txt_mask_value[:, :, 0].astype(np.uint8)
+     non_text_src_img = clear_text(src_img_value, txt_mask_value, wordball_mask_value, text_value=0, non_text_value=255, r=5)
+
+     list_texts = get_list_texts(src_img_value, [tuple(map(int, pos.split(', '))) for idx, pos in enumerate(pos_txt_value.split('; ')) if idx in idxes])
+     list_translated_texts = translate(list_texts, translator)
+     list_fonts = '; '.join(['MTO Astro City.ttf' for _ in range(len(list_translated_texts))])
+     list_sizes = '; '.join(['20' for _ in range(len(list_translated_texts))])
+     list_strokes = '; '.join(['3' for _ in range(len(list_translated_texts))])
+     list_pads = '; '.join(['5' for _ in range(len(list_translated_texts))])
+     list_translated_texts = '; '.join(list_translated_texts)
+     switch = str(random.random())
+
+     return non_text_src_img, list_translated_texts, list_fonts, list_sizes, list_strokes, list_pads, switch, 'Done'
+
+ def text_info_change(non_txt_img_value, translated_txt_value, pos_txt_value, idx_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
+     non_txt_img_value = non_txt_img_value.copy()
+     idxes = [int(i) for i in idx_txt_value.split('; ')]
+
+     translated_text_src_img = insert_text(non_txt_img_value,
+                                           translated_txt_value.split('; '),
+                                           [tuple(map(int, pos.split(', '))) for idx, pos in enumerate(pos_txt_value.split('; ')) if idx in idxes],
+                                           font=font_txt_value.split('; '),
+                                           font_size=[int(i) for i in size_txt_value.split('; ')],
+                                           pad=[int(i) for i in pad_txt_value.split('; ')],
+                                           stroke=[int(i) for i in stroke_txt_value.split('; ')])
+     return translated_text_src_img
+
+ def value2_change(value, ID2_value, txt_value):
+     txt_value = txt_value.split('; ')
+
+     txt_value2 = []
+     for idx, text in enumerate(txt_value):
+         if idx == ID2_value:
+             txt_value2.append(str(value))
+         else:
+             txt_value2.append(str(text))
+     txt_value2 = '; '.join(txt_value2)
+     return txt_value2
+
+ # Build the Gradio interface
+ with gr.Blocks() as demo:
+     # Layout
+     src_img = gr.Image(type="numpy", label="Upload Image")
+
+     down_bttn_1 = gr.Button("↓", elem_classes="arrow-button")
+
+     with gr.Row():
+         txt_mask = gr.Image(type="numpy", label="Upload Image", visible=True)
+         wordball_mask = gr.Image(type="numpy", label="Upload Image", visible=True)
+     complete = gr.Textbox()
+     with gr.Row():
+         idx_txt = gr.Textbox(label='ID', interactive=False, visible=False)
+         pos_txt = gr.Textbox(label='Pos', interactive=False, visible=False)
+         area_txt = gr.Textbox(label='Area', interactive=False, visible=False)
+         rgb_txt = gr.Textbox(label='rgb', interactive=False, visible=False)
+     with gr.Row():
+         boxed_txt_img = gr.Image(type="numpy", label="Upload Image")
+         with gr.Column() as down_1_column:
+             @gr.render(inputs=[pos_txt, rgb_txt], triggers=[rgb_txt.change])
+             def create_box(pos_txt_value, rgb_txt_value):
+                 text_positions = pos_txt_value.split('; ')
+                 for idx in range(len(text_positions)):
+                     text_positions[idx] = (int(i) for i in text_positions[idx].split(', '))
+                 rgbs = rgb_txt_value.split('; ')
+                 for idx in range(len(rgbs)):
+                     rgbs[idx] = (int(i) for i in rgbs[idx].split(', '))
+
+                 elements = []
+                 for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
+                     with gr.Group() as box:
+                         r, g, b = rgbs[idx]
+                         with gr.Row():
+                             gr.Markdown(
+                                 f"""
+                                 <div style="margin-left: 20px; display: flex; align-items: center;">
+                                     <div style="width: 10px; height: 10px; background-color: rgb({r}, {g}, {b}); margin-right: 5px;"></div>
+                                     <span style="font-size: 20px;">Textbox {idx+1}</span>
+                                 </div>
+                                 """
+                             )
+                             checkbox = gr.Checkbox(value=True, label='', min_width=50, interactive=True)
+                         with gr.Row():
+                             X = gr.Number(label="X", value=min_x, interactive=True)
+                             Y = gr.Number(label="Y", value=min_y, interactive=True)
+                             W = gr.Number(label="W", value=max_x-min_x, interactive=True)
+                             H = gr.Number(label="H", value=max_y-min_y, interactive=True)
+                             ID = gr.Number(label="ID", value=idx, interactive=True, visible=False)
+                     elements.append((X, Y, W, H, ID))
+
+                     checkbox.change(
+                         fn=ID_block_change,
+                         inputs=[ID, checkbox, idx_txt],
+                         outputs=idx_txt,
+                         show_progress=False
+                     ).then(
+                         fn=idx_txt_change,
+                         inputs=[src_img, idx_txt, pos_txt, rgb_txt],
+                         outputs=boxed_txt_img,
+                     )
+                     X.change(
+                         fn=position_block_change,
+                         inputs=[X, Y, W, H, ID, pos_txt],
+                         outputs=pos_txt,
+                         show_progress=False
+                     ).then(
+                         fn=idx_txt_change,
+                         inputs=[src_img, idx_txt, pos_txt, rgb_txt],
+                         outputs=boxed_txt_img,
+                         show_progress=False
+                     )
+                     Y.change(
+                         fn=position_block_change,
+                         inputs=[X, Y, W, H, ID, pos_txt],
+                         outputs=pos_txt,
+                         show_progress=False
+                     ).then(
+                         fn=idx_txt_change,
+                         inputs=[src_img, idx_txt, pos_txt, rgb_txt],
+                         outputs=boxed_txt_img,
+                         show_progress=False
+                     )
+                     W.change(
+                         fn=position_block_change,
+                         inputs=[X, Y, W, H, ID, pos_txt],
+                         outputs=pos_txt,
+                         show_progress=False
+                     ).then(
+                         fn=idx_txt_change,
+                         inputs=[src_img, idx_txt, pos_txt, rgb_txt],
+                         outputs=boxed_txt_img,
+                         show_progress=False
+                     )
+                     H.change(
+                         fn=position_block_change,
+                         inputs=[X, Y, W, H, ID, pos_txt],
+                         outputs=pos_txt,
+                         show_progress=False
+                     ).then(
+                         fn=idx_txt_change,
+                         inputs=[src_img, idx_txt, pos_txt, rgb_txt],
+                         outputs=boxed_txt_img,
+                         show_progress=False
+                     )
+     down_bttn_2 = gr.Button("↓", elem_classes="arrow-button")
+
+     non_txt_img = gr.Image(type="numpy", label="Upload Image", visible=False)
+     complete2 = gr.Textbox()
+     with gr.Row():
+         translated_txt = gr.Textbox(label='translated', interactive=False, visible=False)
+         font_txt = gr.Textbox(label='font', interactive=False, visible=False)
+         size_txt = gr.Textbox(label='size', interactive=False, visible=False)
+         stroke_txt = gr.Textbox(label='stroke', interactive=False, visible=False)
+         pad_txt = gr.Textbox(label='pad', interactive=False, visible=False)
+         switch_txt = gr.Textbox(label='switch', value='1', interactive=False, visible=False)
+     with gr.Row():
+         boxed_inserted_non_txt_img = gr.Image(type="numpy", label="Upload Image")
+         with gr.Column():
+             @gr.render(inputs=[translated_txt, font_txt, size_txt, stroke_txt, pad_txt], triggers=[switch_txt.change])
+             def create_box2(translated_txt_value, font_txt_value, size_txt_value, stroke_txt_value, pad_txt_value):
+                 translated_txt_value = translated_txt_value.split('; ')
+                 font_txt_value = font_txt_value.split('; ')
+                 size_txt_value = size_txt_value.split('; ')
+                 stroke_txt_value = stroke_txt_value.split('; ')
+                 pad_txt_value = pad_txt_value.split('; ')
+
+                 elements = []
+                 for idx in range(len(font_txt_value)):
+                     with gr.Group():
+                         gr.Markdown(
+                             f"""
+                             <div style="margin-left: 20px; display: flex; align-items: center;">
+                                 <div style="width: 10px; height: 10px; background-color: rgb(255, 255, 255); margin-right: 5px;"></div>
+                                 <span style="font-size: 20px;">Text box {idx}</span>
+                             </div>
+                             """
+                         )
+                         translated_text_box = gr.Textbox(label="Translate", value=translated_txt_value[idx], interactive=True)
+                         with gr.Row():
+                             font = gr.Dropdown(choices=os.listdir('MTO Font'), label="Font", value=font_txt_value[idx], interactive=True, scale=7)
+                             size = gr.Number(label="Size", value=int(size_txt_value[idx]), interactive=True, minimum=1)
+                             stroke = gr.Number(label="Stroke", value=int(stroke_txt_value[idx]), interactive=True, minimum=0, maximum=5)
+                             pad = gr.Number(label="Pad", value=int(pad_txt_value[idx]), interactive=True, minimum=1, maximum=10)
+                             ID2 = gr.Number(label="ID", value=int(idx), interactive=True, visible=False)
+
+                     translated_text_box.submit(
+                         fn=value2_change,
+                         inputs=[translated_text_box, ID2, translated_txt],
+                         outputs=translated_txt,
+                         show_progress=False
+                     ).then(
+                         fn=text_info_change,
+                         inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
+                         outputs=boxed_inserted_non_txt_img,
+                     )
+                     font.change(
+                         fn=value2_change,
+                         inputs=[font, ID2, font_txt],
+                         outputs=font_txt,
+                         show_progress=False
+                     ).then(
+                         fn=text_info_change,
+                         inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
+                         outputs=boxed_inserted_non_txt_img,
+                     )
+                     size.change(
+                         fn=value2_change,
+                         inputs=[size, ID2, size_txt],
+                         outputs=size_txt,
+                         show_progress=False
+                     ).then(
+                         fn=text_info_change,
+                         inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
+                         outputs=boxed_inserted_non_txt_img,
+                     )
+                     stroke.change(
+                         fn=value2_change,
+                         inputs=[stroke, ID2, stroke_txt],
+                         outputs=stroke_txt,
+                         show_progress=False
+                     ).then(
+                         fn=text_info_change,
+                         inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
+                         outputs=boxed_inserted_non_txt_img,
+                     )
+                     pad.change(
+                         fn=value2_change,
+                         inputs=[pad, ID2, pad_txt],
+                         outputs=pad_txt,
+                         show_progress=False
+                     ).then(
+                         fn=text_info_change,
+                         inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
+                         outputs=boxed_inserted_non_txt_img,
+                     )
+
+     # CSS
+     demo.css = """
+     .arrow-button {
+         font-size: 40px;  /* font size */
+     }
+     .group-elem {
+         height: 70px;
+     }
+     """
+
+     # Event wiring
+     down_bttn_1.click(
+         fn=down1,
+         inputs=src_img,
+         outputs=[txt_mask, wordball_mask, idx_txt, pos_txt, area_txt, rgb_txt, complete],
+     )
+     down_bttn_2.click(
+         fn=down2,
+         inputs=[src_img, txt_mask, wordball_mask, idx_txt, pos_txt],
+         outputs=[non_txt_img, translated_txt, font_txt, size_txt, stroke_txt, pad_txt, switch_txt, complete2],
+     ).then(
+         fn=text_info_change,
+         inputs=[non_txt_img, translated_txt, pos_txt, idx_txt, font_txt, size_txt, stroke_txt, pad_txt],
+         outputs=boxed_inserted_non_txt_img,
+     )
+
+ demo.launch()
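
For reference (not part of this commit): DTM and DWBM are the same pattern twice, a single Trans_UNet whose weights are swapped through eight downloaded checkpoints while the sigmoid outputs are averaged. A stripped-down sketch of that ensembling idea, with hypothetical checkpoint paths:

import torch

def ensemble_predict(model, checkpoint_paths, X):
    """Average one architecture's predictions over several checkpoints."""
    out = None
    for path in checkpoint_paths:  # e.g. the eight paths returned by hf_hub_download()
        state = torch.load(path, map_location='cpu', weights_only=True)
        model.load_state_dict(state)
        model.eval()
        with torch.no_grad():
            pred = model(X)
        out = pred if out is None else out + pred
    return out / len(checkpoint_paths)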
edit_func.py ADDED
@@ -0,0 +1,311 @@
+ import torch
+ import cv2
+ import pytesseract
+ from PIL import Image, ImageDraw, ImageFont
+ from collections import deque
+ import numpy as np
+ import os
+
+ # pytesseract.pytesseract.tesseract_cmd = 'Tesseract\\tesseract.exe'
+
+ def get_full_img_path(src_dir):
+     """
+     input: path to the folder containing the images
+     output: list of the names of all images
+     """
+     list_img_names = []
+     for dirname, _, filenames in os.walk(src_dir):
+         for filename in filenames:
+             path = os.path.join(dirname, filename).replace(src_dir, '')
+             if path[0] == '/':
+                 path = path[1:]
+             list_img_names.append(path)
+     return list_img_names
+
+
+ def create_text_mask(src_img, detect_text_model, kernel_size=5, iterations=3):
+     """
+     input: source image as an np.array, shape: [H, W, C]
+     output: mask marking text in the source image, 0 = text, 1 = background; shape: [H, W]
+     """
+     img = torch.from_numpy(src_img).to(torch.uint8).to(detect_text_model.device)
+     imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2)
+
+     detect_text_model.eval()
+     with torch.no_grad():
+         result = detect_text_model(imgT).squeeze()
+         result = (result >= 0.5).detach().cpu().numpy()
+
+     mask = ((1-result) * 255).astype(np.uint8)
+
+     kernel = np.ones((kernel_size, kernel_size), np.uint8)
+     mask = cv2.erode(mask, kernel, iterations=iterations)
+     mask = cv2.dilate(mask, kernel, iterations=2*iterations)
+     mask = cv2.erode(mask, kernel, iterations=iterations)
+
+     mask = (1 - mask // 255).astype(np.uint8)
+     return mask
+
+
+ def create_wordball_mask(src_img, detect_wordball_model, kernel_size=5, iterations=3):
+     """
+     input: source image as an np.array, shape: [H, W, C]
+     output: mask marking word balls (speech bubbles) in the source image, 0 = wordball, 1 = background; shape: [H, W]
+     """
+     img = torch.from_numpy(src_img).to(torch.uint8).to(detect_wordball_model.device)
+     imgT = (img / 255).unsqueeze(0).permute(0, -1, -3, -2)
+
+     detect_wordball_model.eval()
+     with torch.no_grad():
+         result = detect_wordball_model(imgT).squeeze()
+         result = (result >= 0.5).detach().cpu().numpy()
+
+     mask = ((1-result) * 255).astype(np.uint8)
+
+     kernel = np.ones((kernel_size, kernel_size), np.uint8)
+     mask = cv2.erode(mask, kernel, iterations=iterations)
+     mask = cv2.dilate(mask, kernel, iterations=2*iterations)
+     mask = cv2.erode(mask, kernel, iterations=iterations)
+
+     mask = (1 - mask // 255).astype(np.uint8)
+     return mask
+
+
+ def clear_text(src_img, text_msk, wordball_msk, text_value=0, non_text_value=1, r=5):
+     """
+     input: src_img: source image as an np.array, shape: [H, W, C]
+            text_msk: mask marking text in the source image; shape: [H, W]
+            text_value: the value that marks text in the mask
+            non_text_value: the value that marks background in the mask
+            r: radius used to erase the text and repaint the erased region
+     output: image after text removal, as an np.array, shape: [H, W, C]
+     """
+     MAX = max(text_value, non_text_value)
+     MIN = min(text_value, non_text_value)
+
+     scale_text_value = (text_value - MIN) / (MAX - MIN)
+     scale_non_text_value = (non_text_value - MIN) / (MAX - MIN)
+
+     text_msk[text_msk==text_value] = scale_text_value
+     text_msk[text_msk==non_text_value] = scale_non_text_value
+
+     wordball_msk[wordball_msk==text_value] = scale_text_value
+     wordball_msk[wordball_msk==non_text_value] = scale_non_text_value
+
+     if scale_text_value == 0:
+         text_msk = 1 - text_msk
+         wordball_msk = 1 - wordball_msk
+     text_msk = text_msk * 255
+
+     remove_txt = cv2.inpaint(src_img, text_msk, r, cv2.INPAINT_TELEA)
+     remove_wordball = remove_txt.copy()
+     remove_wordball[wordball_msk==1] = 255
+
+     return remove_wordball
+
+
+ def dfs(grid, y, x, visited, value):
+     """
+     Connected-component search; see graph traversal (DFS) if this is unfamiliar
+     Output: a rectangle covering the connected component + the component's area
+     """
+     max_y, max_x = y, x
+     min_y, min_x = y+1, x+1
+     area = 0
+
+     stack = deque([(y, x)])
+     while stack:
+         y, x = stack.pop()
+
+         max_x = max(max_x, x)
+         max_y = max(max_y, y)
+         min_x = min(min_x, x)
+         min_y = min(min_y, y)
+
+         if (y, x) not in visited:
+             visited.add((y, x))
+             area += 1
+             # Check the neighboring cells
+             for dx, dy in [(-1, 0), (1, 0), (0, -1), (0, 1), (-1, -1), (-1, 1), (1, -1), (1, 1)]:
+                 nx, ny = x + dx, y + dy
+                 if 0 <= ny < grid.shape[0] and 0 <= nx < grid.shape[1] and grid[ny, nx] == value and (ny, nx) not in visited:
+                     stack.append((ny, nx))
+
+     return (min_x, min_y, max_x, max_y), area
+
+
+ def find_clusters(grid, value):
+     """
+     Finds the list of all connected components with the given value
+     """
+     visited = set()
+     clusters = []
+     areas = []
+
+     for y in range(grid.shape[0]):
+         for x in range(grid.shape[1]):
+             if grid[y, x] == value and (y, x) not in visited:
+                 cluster, area = dfs(grid, y, x, visited, value)
+                 clusters.append(cluster)
+                 areas.append(area)
+
+     return clusters, areas
+
+ def get_text_positions(text_msk, text_value=0):
+     """
+     input: text_msk: mask marking text in the source image; shape: [H, W]
+            text_value: the value that marks text in the mask
+            min_area: minimum area of a region that can contain text (see filter_text_positions)
+     output: list of regions containing text, formatted (min_x, min_y, max_x, max_y)
+     """
+
+     clusters, areas = find_clusters(text_msk, value=text_value)
+     return clusters, areas
+
+ def filter_text_positions(clusters, areas, min_area=1200, max_area=10000):
+     clusters = [c for c, a in zip(clusters, areas) if min_area <= a <= max_area]
+     return clusters
+
+
+ def get_list_texts(src_img, text_positions, lang='eng'):
+     """
+     input: src_img: source image as an np.array, shape: [H, W, C]
+            text_positions: list of regions containing text, formatted (min_x, min_y, max_x, max_y)
+            lang: language of the text
+     output: list of text strings
+     """
+     list_texts = []
+     for idx, (min_x, min_y, max_x, max_y) in enumerate(text_positions):
+         crop_img = src_img[min_y:max_y+1, min_x:max_x+1]
+         img_rgb = cv2.cvtColor(crop_img, cv2.COLOR_BGR2RGB)
+         img = Image.fromarray(img_rgb)
+         text = pytesseract.image_to_string(img, lang=lang).replace('\n', ' ').strip()
+         while '  ' in text:
+             text = text.replace('  ', ' ')
+         list_texts.append(text)
+     return list_texts
+
+
+ def translate(list_texts, translator):
+     translated_texts = []
+     for text in list_texts:
+         if not text:
+             text = 'a'
+         translated_text = translator.translate(text, src='en', dest='vi').text
+         translated_texts.append(translated_text)
+     return translated_texts
+
+
+ def add_centered_multiline_text(image, text, box, font_path="arial.ttf", font_size=36, pad=5, text_color=0):
+     # Create a drawing context on the image
+     draw = ImageDraw.Draw(image)
+
+     # Unpack the box (min_x, min_y, max_x, max_y)
+     min_x, min_y, max_x, max_y = box
+
+     # Load the font
+     font = ImageFont.truetype(font_path, font_size)
+
+     # Wrap the text into multiple lines if needed
+     wrapped_lines = wrap_text(text, font, draw, max_x - min_x)
+
+     # Total height of all lines combined
+     total_text_height = sum(get_text_height(line, draw, font) for line in wrapped_lines)
+
+     # Starting y coordinate so the block is vertically centered
+     start_y = min_y + (max_y - min_y - total_text_height) // 2
+
+     # Draw each line, horizontally centered
+     current_y = start_y
+     for line in wrapped_lines:
+         text_width, text_height = get_text_dimensions(line, draw, font)
+         text_x = min_x + (max_x - min_x - text_width) // 2  # center horizontally
+         draw.text((text_x, current_y), line, fill=text_color, font=font)
+         current_y += text_height + pad  # move y down for the next line
+
+     # Return the image with the text drawn in
+     return image
+
+ def get_text_dimensions(text, draw, font):
+     """Return the (width, height) of the text."""
+     bbox = draw.textbbox((0, 0), text, font=font)
+     width = bbox[2] - bbox[0]
+     height = bbox[3] - bbox[1]
+     return width, height
+
+ def get_text_height(text, draw, font):
+     """Return the height of the text."""
+     _, _, _, height = draw.textbbox((0, 0), text, font=font)
+     return height
+
+ def wrap_text(text, font, draw, max_width):
+     """Split the text into lines based on the maximum width."""
+     words = text.split()
+     lines = []
+     current_line = ""
+
+     for word in words:
+         # Try adding the word to the current line
+         test_line = f"{current_line} {word}".strip()
+         test_width, _ = get_text_dimensions(test_line, draw, font)
+
+         if test_width <= max_width:
+             current_line = test_line
+         else:
+             # Too wide: store the current line and start a new one
+             lines.append(current_line)
+             current_line = word
+
+     # Add the last line
+     if current_line:
+         lines.append(current_line)
+
+     return lines
+
+ def insert_text(non_text_src_img, list_translated_texts, text_positions, font=['MTO Astro City.ttf'], font_size=[20], pad=[5], text_color=0, stroke=[3]):
+     # Copy the text-free image
+     img_bgr = non_text_src_img.copy()
+
+     # Draw the text onto mask 1
+     for idx, text in enumerate(list_translated_texts):
+         # Create white masks
+         mask1 = Image.new("L", img_bgr.shape[:2][::-1], 255)
+         mask2 = Image.new("L", img_bgr.shape[:2][::-1], 255)
+         mask1 = add_centered_multiline_text(mask1, text, text_positions[idx], f'MTO Font/{font[idx]}', font_size[idx], pad=pad[idx], text_color=text_color)
+
+         # Convert from PIL to cv2 (grayscale masks, so replicate to 3 channels)
+         mask1 = (np.array(mask1) >= 127).astype(np.uint8) * 255
+         mask1 = cv2.cvtColor(mask1, cv2.COLOR_GRAY2BGR)
+
+         if stroke[idx] > 0:
+             mask2 = np.array(mask2).astype(np.uint8)
+             mask2 = cv2.cvtColor(mask2, cv2.COLOR_GRAY2BGR)
+
+             mask2 = mask2 - mask1
+             kernel = np.ones((stroke[idx]+1, stroke[idx]+1), np.uint8)
+             mask2 = cv2.dilate(mask2, kernel, iterations=1)
+             img_bgr[mask2==255] = 255
+
+         img_bgr[mask1==text_color] = text_color
+     return img_bgr
+
+
+ def save_img(path, translated_text_src_img):
+     """
+     input: path: path to the original source image
+            translated_text_src_img: image after translation
+     output: the translated image saved to disk, with "translated-" added to the name
+     """
+     dot = path.rfind('.')
+     last_slash = -1
+     if '/' in path:
+         last_slash = path.rfind('/')
+
+     ext = path[dot:]
+     parent_path = path[:last_slash+1]
+     name = path[last_slash+1:dot]
+
+     if parent_path and not os.path.exists(parent_path):
+         os.mkdir(parent_path)
+     cv2.imwrite(f'{parent_path}translated-{name}{ext}', translated_text_src_img)
+     print(f'Image saved at {parent_path}translated-{name}{ext}')
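
End to end, the functions above mirror app.py's down1/down2 handlers. A hedged sketch of that flow (detect_text_model and detect_wordball_model are assumed to be the DTM()/DWBM() ensembles from app.py, 'page.jpg' is a hypothetical input, and the repo's MTO Font/ directory must be present):

import cv2
from googletrans import Translator
from edit_func import (create_text_mask, create_wordball_mask, get_text_positions,
                       clear_text, get_list_texts, translate, insert_text)

src = cv2.imread('page.jpg')                                  # BGR, [H, W, 3]
text_msk = create_text_mask(src, detect_text_model)           # 0 = text, 1 = background
ball_msk = create_wordball_mask(src, detect_wordball_model)
positions, areas = get_text_positions(text_msk, text_value=0)
clean = clear_text(src, text_msk * 255, ball_msk * 255,
                   text_value=0, non_text_value=255, r=5)     # inpaint text, whiten bubbles
texts = get_list_texts(src, positions)                        # OCR each box with pytesseract
viet = translate(texts, Translator())                         # English -> Vietnamese
n = len(viet)
out = insert_text(clean, viet, positions, font=['MTO Astro City.ttf'] * n,
                  font_size=[20] * n, pad=[5] * n, stroke=[3] * n)
cv2.imwrite('translated-page.jpg', out)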
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ chardet==3.0.4
+ googletrans==4.0.0rc1
+ h2==3.2.0
+ hstspreload==2024.10.1
+ opencv-python==4.10.0.84
+ pip==24.2
+ positional-encodings==6.0.3
+ pytesseract==0.3.13
+ rfc3986==1.5.0
+ setuptools==75.1.0
+ spaces==0.30.4
+ torch==2.5.0
+ wheel==0.44.0