csxmli committed · Commit 5307ec1 · verified · 1 parent: 3a261c5

Update models/TextEnhancement.py

Files changed (1)
  1. models/TextEnhancement.py +367 -367
models/TextEnhancement.py CHANGED
The rendered change switches the interpolation flag of four cv2.resize calls from cv2.INTER_CUBIC to cv2.INTER_LINEAR; the rest of the file is unchanged and is listed once in full below.

@@ -153 +153 @@ def handle_texts
-                SQ = cv2.resize(SQ, (w_crop * sf, h_crop * sf), interpolation=cv2.INTER_CUBIC)
+                SQ = cv2.resize(SQ, (w_crop * sf, h_crop * sf), interpolation=cv2.INTER_LINEAR)
@@ -226,2 +226,2 @@ def _process_text_line
-        img = cv2.resize(img, (w_norm*4, h_norm*4), interpolation=cv2.INTER_CUBIC)
-        in_img = cv2.resize(img, (w_norm, h_norm), interpolation=cv2.INTER_CUBIC)
+        img = cv2.resize(img, (w_norm*4, h_norm*4), interpolation=cv2.INTER_LINEAR)
+        in_img = cv2.resize(img, (w_norm, h_norm), interpolation=cv2.INTER_LINEAR)
@@ -326 +326 @@ def _process_text_line
-        ShowPrior = cv2.resize(prior128, (SR.shape[1], int(128 * SR.shape[1] / prior128.shape[1])), interpolation=cv2.INTER_CUBIC)
+        ShowPrior = cv2.resize(prior128, (SR.shape[1], int(128 * SR.shape[1] / prior128.shape[1])), interpolation=cv2.INTER_LINEAR)
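For reference, a minimal standalone sketch (not part of the commit) of the two interpolation modes involved; the 32x128 patch size and scale factor are arbitrary stand-ins for a restored text-line crop:

# Sketch only: the same cv2.resize call with the old and the new flag.
import cv2
import numpy as np

patch = np.random.randint(0, 256, (32, 128, 3), dtype=np.uint8)  # stand-in for a text-line crop
sf = 4
up_old = cv2.resize(patch, (128 * sf, 32 * sf), interpolation=cv2.INTER_CUBIC)   # before this commit
up_new = cv2.resize(patch, (128 * sf, 32 * sf), interpolation=cv2.INTER_LINEAR)  # after this commit
print(up_old.shape, up_new.shape)  # both (128, 512, 3)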
models/TextEnhancement.py (full file after this commit):

# -*- coding: utf-8 -*-
import cv2
import os.path as osp
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import logging
logging.getLogger('modelscope').disabled = True

from cnstd import CnStd
from utils.utils_transocr import get_alphabet
from utils.yolo_ocr_xloc import get_yolo_ocr_xloc
from ultralytics import YOLO

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from networks import *
import warnings
warnings.filterwarnings('ignore')

from modelscope import snapshot_download


##########################################################################################
###############Text Restoration Model revised by xiaoming li
##########################################################################################

alphabet_path = './models/benchmark_cvpr23.txt'
CommonWordsForOCR = get_alphabet(alphabet_path)
CommonWords = CommonWordsForOCR[2:-1]



def str2idx(text):
    idx = []
    for t in text:
        idx.append(CommonWords.index(t) if t in CommonWords else 3484) #3955
    return idx

def get_parameter_details(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    return num_params / 1e6

def tensor2numpy(tensor):
    tensor = tensor * 0.5 + 0.5
    tensor = tensor.squeeze(0).permute(1, 2, 0).flip(2)
    return np.clip(tensor.float().cpu().numpy(), 0, 1) * 255.0


class MARCONetPlus(object):
    def __init__(self, WEncoderPath=None, PriorModelPath=None, SRModelPath=None, YoloPath=None, device='cuda'):
        self.device = device

        modelscope_dir = snapshot_download('damo/cv_convnextTiny_ocr-recognition-general_damo', cache_dir='./checkpoints/modelscope_ocr')
        self.modelscope_ocr_recognition = pipeline(Tasks.ocr_recognition, model=modelscope_dir)
        self.yolo_character = YOLO(YoloPath)

        self.modelWEncoder = PSPEncoder() # WEncoder()
        self.modelWEncoder.load_state_dict(torch.load(WEncoderPath)['params'], strict=True)
        self.modelWEncoder.eval()
        self.modelWEncoder.to(device)

        self.modelPrior = TextPriorModel()
        self.modelPrior.load_state_dict(torch.load(PriorModelPath)['params'], strict=True)
        self.modelPrior.eval()
        self.modelPrior.to(device)

        self.modelSR = SRNet()
        self.modelSR.load_state_dict(torch.load(SRModelPath)['params'], strict=True)
        self.modelSR.eval()
        self.modelSR.to(device)


        print('='*128)
        print('{:>25s} : {:.2f} M Parameters'.format('modelWEncoder', get_parameter_details(self.modelWEncoder)))
        print('{:>25s} : {:.2f} M Parameters'.format('modelPrior', get_parameter_details(self.modelPrior)))
        print('{:>25s} : {:.2f} M Parameters'.format('modelSR', get_parameter_details(self.modelSR)))
        print('='*128)

        torch.cuda.empty_cache()
        self.cnstd = CnStd(model_name='db_resnet34', rotated_bbox=True, model_backend='pytorch', box_score_thresh=0.3, min_box_size=10, context=device)
        self.insize = 32


    def handle_texts(self, img, bg=None, sf=4, is_aligned=False, lq_label=None):
        '''
        Parameters:
            img: RGB 0~255.
        '''

        height, width = img.shape[:2]
        bg_height, bg_width = bg.shape[:2]
        print(' ' * 25 + f' ... The input->output image size is {bg_height//sf}*{bg_width//sf}->{bg_height}*{bg_width}')

        full_mask_blur = np.zeros(bg.shape, dtype=np.float32)
        full_mask_noblur = np.zeros(bg.shape, dtype=np.float32)
        full_text_img = np.zeros(bg.shape, dtype=np.float32) #+255

        orig_texts, enhanced_texts, debug_texts, pred_texts = [], [], [], []
        ocr_scores = []

        if not is_aligned:
            box_infos = self.cnstd.detect(img)
            for iix, box_info in enumerate(box_infos['detected_texts']):
                box = box_info['box'].astype(int) # left top, right top, right bottom, left bottom, [width, height]
                score = box_info['score']
                if score < 0.5:
                    continue

                extend_box = box.copy()
                w = int(np.linalg.norm(box[0] - box[1]))
                h = int(np.linalg.norm(box[0] - box[3]))

                # extend the bounding box
                extend_lr = 0.15 * h
                extend_tb = 0.05 * h
                vec_w = (box[1] - box[0]) / w
                vec_h = (box[3] - box[0]) / h

                extend_box[0] = box[0] - vec_w * extend_lr - vec_h * extend_tb
                extend_box[1] = box[1] + vec_w * extend_lr - vec_h * extend_tb
                extend_box[2] = box[2] + vec_w * extend_lr + vec_h * extend_tb
                extend_box[3] = box[3] - vec_w * extend_lr + vec_h * extend_tb
                extend_box = extend_box.astype(int)

                w = int(np.linalg.norm(extend_box[0] - extend_box[1]))
                h = int(np.linalg.norm(extend_box[0] - extend_box[3]))

                if w > h:
                    ref_h = self.insize
                    ref_w = int(ref_h * w / h)
                else:
                    print(' ' * 25 + ' ... Can not handle vertical text temporarily')
                    continue

                ref_point = np.float32([[0,0], [ref_w, 0], [ref_w, ref_h], [0, ref_h]])
                det_point = np.float32(extend_box)

                matrix = cv2.getPerspectiveTransform(det_point, ref_point)
                inv_matrix = cv2.getPerspectiveTransform(ref_point*sf, det_point*sf)

                cropped_img = cv2.warpPerspective(img, matrix, (ref_w, ref_h), borderMode=cv2.BORDER_REPLICATE, flags=cv2.INTER_LINEAR)


                in_img, SQ, save_debug, pred_text, preds_locs_txt = self._process_text_line(cropped_img)
                if in_img is None:
                    continue
                h_crop, w_crop = cropped_img.shape[:2]
                SQ = cv2.resize(SQ, (w_crop * sf, h_crop * sf), interpolation=cv2.INTER_LINEAR)

                debug_texts.append(save_debug)
                orig_texts.append(in_img)
                enhanced_texts.append(SQ)
                pred_texts.append(''.join(pred_text))

                tmp_mask = np.ones(SQ.shape).astype(float)
                warp_mask = cv2.warpPerspective(tmp_mask, inv_matrix, (bg_width, bg_height), flags=3)
                warp_img = cv2.warpPerspective(SQ, inv_matrix, (bg_width, bg_height), flags=3)


                # erode and blur based on the height of text region
                blur_pad = int(h // 6)

                if blur_pad % 2 == 0:
                    blur_pad += 1
                blur_radius = (blur_pad - 1) // 2
                erode_radius = blur_radius + 1
                erode_pad = 2 * erode_radius + 1

                kernel_erode = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (erode_pad, erode_pad))
                warp_mask_erode = cv2.erode(warp_mask, kernel_erode, iterations=1)

                # warp_mask_blur = cv2.GaussianBlur(warp_mask_erode, (blur_pad, blur_pad), 0)
                warp_mask_blur = cv2.blur(warp_mask_erode, (blur_pad, blur_pad))

                full_text_img = full_text_img + warp_img
                full_mask_blur = full_mask_blur + warp_mask_blur
                full_mask_noblur = full_mask_noblur + warp_mask

                ocr_scores.append(score)


            index = full_mask_noblur > 0
            full_text_img[index] = full_text_img[index]/full_mask_noblur[index]

            full_mask_blur = np.clip(full_mask_blur, 0, 1)
            # fuse the text region back to the background
            final_img = full_text_img * full_mask_blur + bg * (1 - full_mask_blur)


            return final_img, orig_texts, enhanced_texts, debug_texts, pred_texts #, ocr_scores

        else: #aligned

            in_img, SQ, save_debug, pred_text, preds_locs_txt = self._process_text_line(img)
            if in_img is not None:
                debug_texts.append(save_debug)
                orig_texts.append(in_img)
                enhanced_texts.append(SQ)
                pred_texts.append(''.join(pred_text))

            return img, orig_texts, enhanced_texts, debug_texts, pred_texts #, preds_locs_txt

    def _process_text_line(self, img):
        """
        Process a single text line region for text enhancement.

        Args:
            img: Input text image

        """


        height, width = img.shape[:2]
        if height > width:
            print(' ' * 25 + ' ... Can not handle vertical text temporarily')
            return (None,) * 5

        w_norm = int(self.insize * width / height) // 4 * 4
        h_norm = self.insize

        img = cv2.resize(img, (w_norm*4, h_norm*4), interpolation=cv2.INTER_LINEAR)
        in_img = cv2.resize(img, (w_norm, h_norm), interpolation=cv2.INTER_LINEAR)
        ShowLQ = img[:,:,::-1]

        LQ_HeightNorm = transforms.ToTensor()(in_img)
        LQ_HeightNorm = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))(LQ_HeightNorm).unsqueeze(0).to(self.device)


        '''
        Step 1: Predicting the character labels, bounding boxes.
        '''

        recognized_boxes, pred_text, char_x_centers = get_yolo_ocr_xloc(
            img, # input image, RGB 0~255
            yolo_model=self.yolo_character, # YOLO model instance for character detection
            ocr_pipeline=self.modelscope_ocr_recognition, # OCR pipeline/model for character recognition
            num_cropped_boxes=5, # Number of adjacent character boxes to include in each cropped segment (window size)
            expand_px=1, # Number of pixels to expand each crop region on all sides (except first/last)
            expand_px_for_first_last_cha=12, # Number of pixels to expand the crop region for the first and last character (left/right respectively)
            yolo_iou=0.1, # IOU threshold for YOLO non-max suppression (NMS)
            yolo_conf=0.07 # Confidence threshold for YOLO detection
        )

        print('{:>25s} ... Recognized chars: {}'.format(' ', ''.join(pred_text)))
        loc_sr = torch.tensor(char_x_centers, device=self.device).unsqueeze(0)


        # show character location
        pad = 1
        ShowPredLoc = ShowLQ.copy()
        for l in range(len(pred_text)):
            center_pred_w = int(loc_sr[0][l].item())
            if center_pred_w > 0:
                ShowPredLoc[:, max(0, center_pred_w-pad):min(center_pred_w+pad, ShowPredLoc.shape[1]), :] = 0
                ShowPredLoc[:, max(0, center_pred_w-pad):min(center_pred_w+pad, ShowPredLoc.shape[1]), 1] = 255


        '''
        Step 2: Character Prior Generation
        '''

        with torch.no_grad():
            w = self.modelWEncoder(LQ_HeightNorm, loc_sr)

        predict_characters128 = []
        predict_characters64 = []
        predict_characters32 = []

        for b in range(w.size(0)):
            w0 = w[b,...].clone() #16*512
            pred_label = str2idx(pred_text)
            pred_label = torch.Tensor(pred_label).type(torch.LongTensor).view(-1, 1) #.to(device)

            with torch.no_grad():
                prior_cha, prior_fea64, prior_fea32 = self.modelPrior(styles=w0[:len(pred_text),:], labels=pred_label, noise=None) #b *n * w * h

            predict_characters128.append(prior_cha)
            predict_characters64.append(prior_fea64)
            predict_characters32.append(prior_fea32)


        '''
        Step 3: Character SR
        '''

        with torch.no_grad():
            extend_right_width = extend_left_width = h_norm // 2
            LQ_HeightNorm_WidthExtend = F.pad(LQ_HeightNorm, (extend_left_width, extend_right_width, 0, 0), mode='replicate')

            preds_locs_txt = ''
            loc_for_extend_sr = loc_sr.clone()
            for i in range(len(pred_text)):
                preds_locs_txt += str(int(loc_for_extend_sr[0][i].cpu().item()))+'_'
                loc_for_extend_sr[0][i] = loc_for_extend_sr[0][i] + extend_left_width * 4

            SR = self.modelSR(LQ_HeightNorm_WidthExtend, predict_characters64, predict_characters32, loc_for_extend_sr)

            SR = tensor2numpy(SR)[:, extend_left_width * 4:extend_left_width * 4 + w_norm*4, ::-1]


        # reduce color inconsistency, use ab channel from in_img
        # sr_lab = cv2.cvtColor(SR.astype(np.uint8), cv2.COLOR_BGR2LAB)
        # target_size = (SR.shape[1], SR.shape[0])
        # in_img_resize = cv2.resize(in_img, target_size, interpolation=cv2.INTER_LINEAR)
        # in_img_lab = cv2.cvtColor(in_img_resize.astype(np.uint8), cv2.COLOR_BGR2LAB)
        # sr_lab[:,:,1:] = in_img_lab[:,:,1:]
        # SR = cv2.cvtColor(sr_lab, cv2.COLOR_LAB2BGR)


        prior128 = []
        pad = 2
        for prior in predict_characters128:
            for ii, p in enumerate(prior):
                prior128.append(p)
        prior128 = torch.cat(prior128, dim=2)
        prior128 = prior128 * 0.5 + 0.5
        prior128 = prior128.permute(1, 2, 0).flip(2)
        prior128 = np.clip(prior128.float().cpu().numpy(), 0, 1) * 255.0
        prior128 = np.repeat(prior128, 3, axis=2)

        ShowPrior = cv2.resize(prior128, (SR.shape[1], int(128 * SR.shape[1] / prior128.shape[1])), interpolation=cv2.INTER_LINEAR)


        #--------Fuse the structure prior to the LR input to show the details of alignment--------------
        fusion_bg = np.zeros_like(SR, dtype=np.float32)
        w4 = w_norm * 4

        for iii, c in enumerate(loc_sr[0].int()):
            current_prior = prior128[:, iii*128:(iii+1)*128, :]
            center_loc = c.item()

            x1 = max(center_loc - 64, 0)
            x2 = min(center_loc + 64, w4)
            y1 = max(64 - center_loc, 0)
            y2 = y1 + (x2 - x1)
            try:
                fusion_bg[:, x1:x2, :] += current_prior[:, y1:y2, :]
            except:
                return (None,) * 5


        mask = fusion_bg / 255.0
        fusion_bg[:,:,0] = 0
        fusion_bg[:,:,2] = 0


        ShowLQ = ShowLQ[:,:,::-1]
        fusion_bg = fusion_bg.astype(ShowLQ.dtype)
        fusion_bg = fusion_bg * 0.3 * mask + ShowLQ * 0.7 * mask + (1-mask) * ShowLQ

        ShowPrior = cv2.normalize(ShowPrior, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)

        save_debug = np.vstack((ShowLQ, ShowPredLoc[:,:,::-1], SR, ShowPrior, fusion_bg))

        return in_img, SR, save_debug, pred_text, preds_locs_txt



if __name__ == '__main__':
    print('Test')
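A hypothetical usage sketch of the class above, assuming the file is imported from the repository root; the checkpoint paths, input file names, and the exact RGB/BGR handling are assumptions and not taken from this commit:

# Hypothetical usage sketch; paths, file names, and color handling are assumptions.
import cv2
import numpy as np
from models.TextEnhancement import MARCONetPlus

enhancer = MARCONetPlus(
    WEncoderPath='./checkpoints/net_w_encoder.pth',   # assumed checkpoint path
    PriorModelPath='./checkpoints/net_prior.pth',     # assumed checkpoint path
    SRModelPath='./checkpoints/net_sr.pth',           # assumed checkpoint path
    YoloPath='./checkpoints/yolo_character.pt',       # assumed checkpoint path
    device='cuda',
)

sf = 4
lq = cv2.cvtColor(cv2.imread('lq_input.png'), cv2.COLOR_BGR2RGB)  # handle_texts expects RGB 0~255
# bg: the sf-times upscaled image onto which the enhanced text regions are blended back
bg = cv2.resize(lq, (lq.shape[1] * sf, lq.shape[0] * sf), interpolation=cv2.INTER_LINEAR)

final_img, orig_texts, enhanced_texts, debug_texts, pred_texts = enhancer.handle_texts(
    lq, bg=bg, sf=sf, is_aligned=False,
)
out = np.clip(final_img, 0, 255).astype(np.uint8)
cv2.imwrite('sr_output.png', cv2.cvtColor(out, cv2.COLOR_RGB2BGR))  # assumes final_img comes back in RGB order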