Commit 0c80503 · add files
Parent: 204969e

app.py CHANGED
@@ -111,12 +111,33 @@ def patch_resize_transform(patch_image_size=480, is_document=False):
     return _patch_resize_transform


+reader = ReaderLite()
+overrides={"eval_cider":False, "beam":8, "max_len_b":128, "patch_image_size":480,
+           "orig_patch_image_size":224, "no_repeat_ngram_size":0, "seed":7}
+models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
+    utils.split_paths('checkpoints/ocr.pt'),
+    arg_overrides=overrides
+)
+
+# Move models to GPU
+for model in models:
+    model.eval()
+    if use_fp16:
+        model.half()
+    if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
+        model.cuda()
+    model.prepare_for_inference_(cfg)
+
+# Initialize generator
+generator = task.build_generator(models, cfg.generation)
+
+bos_item = torch.LongTensor([task.src_dict.bos()])
+eos_item = torch.LongTensor([task.src_dict.eos()])
+pad_idx = task.src_dict.pad()
+
+
 # Construct input for caption task
 def construct_sample(task, image: Image, patch_image_size=480):
-    bos_item = torch.LongTensor([task.src_dict.bos()])
-    eos_item = torch.LongTensor([task.src_dict.eos()])
-    pad_idx = task.src_dict.pad()
-
     patch_image = patch_resize_transform(patch_image_size)(image).unsqueeze(0)
     patch_mask = torch.tensor([True])
     src_text = encode_text(task, "图片上的文字是什么?", append_bos=True, append_eos=True).unsqueeze(0)
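This hunk hoists the expensive setup — ReaderLite, the fairseq checkpoint ensemble, and the dictionary symbols — out of the old per-call `ocr` body to module scope, so it runs once at import rather than on every request; the prompt `图片上的文字是什么?` means "What is the text in the image?". The hunk cuts off inside `construct_sample`. Purely for orientation, here is a sketch of how OFA demos typically finish assembling the fairseq sample; the `src_length` line and the dict layout are assumptions from common OFA demo code, not shown in this commit:

    # Sketch only: plausible remainder of construct_sample, reusing the
    # names defined above (patch_image, patch_mask, src_text, pad_idx).
    # Not part of this diff.
    src_length = torch.LongTensor([s.ne(pad_idx).long().sum() for s in src_text])
    sample = {
        "id": np.array(["0"]),
        "net_input": {
            "src_tokens": src_text,       # encoded prompt tokens
            "src_lengths": src_length,    # count of non-pad tokens
            "patch_images": patch_image,  # resized image tensor
            "patch_masks": patch_mask,    # image-present flag
        },
    }
    return sample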
@@ -141,35 +162,11 @@ def apply_half(t):
     return t


-def ocr(ckpt, img, out_img):
-    reader = ReaderLite()
-    overrides={"eval_cider":False, "beam":8, "max_len_b":128, "patch_image_size":480, "orig_patch_image_size":224, "no_repeat_ngram_size":0, "seed":7}
-    models, cfg, task = checkpoint_utils.load_model_ensemble_and_task(
-        utils.split_paths(ckpt),
-        arg_overrides=overrides
-    )
-
-    # Move models to GPU
-    for model in models:
-        model.eval()
-        if use_fp16:
-            model.half()
-        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
-            model.cuda()
-        model.prepare_for_inference_(cfg)
-
-    # Initialize generator
-    generator = task.build_generator(models, cfg.generation)
-
-    bos_item = torch.LongTensor([task.src_dict.bos()])
-    eos_item = torch.LongTensor([task.src_dict.eos()])
-    pad_idx = task.src_dict.pad()
-
+def ocr(img):
     orig_image = Image.open(img)
     results = get_images(img, reader)
     box_list, image_list = zip(*results)
     draw_boxes(orig_image, box_list)
-    orig_image.save(out_img)

     ocr_result = []
     for box, image in zip(box_list, image_list):
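With loading hoisted to module scope, `ocr` shrinks to pure inference: open the image, detect text boxes with `reader`, draw them on the original, then recognize each crop. The loop body (old lines 176–182) is elided from this diff; a hedged reconstruction of what it plausibly does, using fairseq helpers — the `eval_step` call in particular is an assumption, with only `construct_sample`, `apply_half`, and the final `append` confirmed by the surrounding diff:

    # Hypothetical reconstruction of the elided per-box loop; not part
    # of this diff.
    for box, image in zip(box_list, image_list):
        sample = construct_sample(task, image)
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
        with torch.no_grad():
            result, scores = eval_step(task, generator, models, sample)
        ocr_result.append(result[0]['ocr'].replace(' ', ''))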
@@ -183,7 +180,8 @@ def ocr(ckpt, img, out_img):
         ocr_result.append(result[0]['ocr'].replace(' ', ''))

     result = '\n'.join(ocr_result)
-
+
+    return orig_image, result


 title = "OFA-OCR"
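Instead of writing the annotated image to `out_img` as a side effect, `ocr` now returns the `(PIL.Image, str)` pair that Gradio renders directly. A hypothetical local check of the new calling convention (not in the commit; the filenames are examples):

    # Hypothetical usage of the new one-argument signature.
    image_with_boxes, text = ocr('chinese.jpg')
    print(text)                             # one recognized string per detected box
    image_with_boxes.save('annotated.jpg')  # replaces the old out_img side effect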
@@ -192,7 +190,8 @@ description = "Gradio Demo for OFA-OCR. Upload your own image or click any one o
 article = "<p style='text-align: center'><a href='https://github.com/OFA-Sys/OFA' target='_blank'>OFA Github " \
           "Repo</a></p> "
 examples = [['lihe.png'], ['chinese.jpg'], ['paibian.jpeg'], ['shupai.png'], ['zuowen.jpg']]
-io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='pil'),
+io = gr.Interface(fn=ocr, inputs=gr.inputs.Image(type='pil'),
+                  outputs=[gr.outputs.Image(type='pil'), gr.outputs.Textbox(label="OCR result")],
                   title=title, description=description, article=article, examples=examples,
                   allow_flagging=False, allow_screenshot=False)
 io.launch(cache_examples=True)
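The two entries in `outputs` line up with the tuple `ocr` now returns. Note that `gr.inputs` and `gr.outputs` are the legacy pre-3.0 Gradio namespaces; on Gradio 3.x and later the same wiring would look roughly like the sketch below (an assumption about porting, not part of this commit):

    # Sketch: equivalent interface on Gradio 3.x+, where the
    # gr.inputs/gr.outputs modules and allow_screenshot were removed.
    io = gr.Interface(
        fn=ocr,
        inputs=gr.Image(type='pil'),
        outputs=[gr.Image(type='pil'), gr.Textbox(label="OCR result")],
        title=title, description=description, article=article,
        examples=examples, allow_flagging='never',
    )
    io.launch()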