neoguojing committed on
Commit
ac510cd
·
1 Parent(s): fdd016a

finish ocr

Browse files
Files changed (4) hide show
  1. app.py +15 -12
  2. ocr.py +61 -0
  3. requirements.txt +4 -1
  4. sam_everything.py +2 -3
app.py CHANGED
@@ -1,18 +1,11 @@
1
- import json
2
- from functools import partial
3
- from pathlib import Path
4
 
5
  import gradio as gr
6
- from PIL import Image
7
- import torch
8
  import numpy as np
9
  from gradio_image_prompter import ImagePrompter
10
- import sys
11
- sys.path.append("..")
12
-
13
  from inference import ModelFactory
14
  from face import FaceAlgo
15
  from sam_everything import SamAnything
 
16
 
17
 
18
  components = {}
@@ -113,8 +106,8 @@ def create_ui():
113
  with gr.Tab("OCR"):
114
  with gr.Row():
115
  with gr.Column(scale=2):
116
- components["algo_type"] = gr.Dropdown(
117
- ["OCR","DoNut"],value="DoNut",
118
  label="算法类别",interactive=True
119
  )
120
  with gr.Column(scale=2):
@@ -124,11 +117,14 @@ def create_ui():
124
  with gr.Column(scale=2):
125
  with gr.Row(elem_id=''):
126
  with gr.Group():
127
- components["ocr_input"] = gr.Gallery(elem_id='ocr-input',label='输入',columns=2,type="pil")
128
  with gr.Column(scale=2):
129
  with gr.Row():
130
  with gr.Group():
131
- components["ocr_output"] = gr.Gallery(elem_id='ocr_output',label='输出',columns=2,interactive=False)
 
 
 
132
 
133
  create_event_handlers()
134
  return demo
@@ -172,6 +168,10 @@ def create_event_handlers():
172
  do_sam_everything,gradio('sam_input'),gradio("sam_output")
173
  )
174
 
 
 
 
 
175
  def do_refernce(algo_type,input_image):
176
  # def do_refernce():
177
  print("input image",input_image)
@@ -243,6 +243,9 @@ def do_sam_everything(im):
243
 
244
  return images
245
 
 
 
 
246
  def point_to_mask(pil_image):
247
  # 遍历每个像素
248
  width, height = pil_image.size
 
 
 
 
1
 
2
  import gradio as gr
 
 
3
  import numpy as np
4
  from gradio_image_prompter import ImagePrompter
 
 
 
5
  from inference import ModelFactory
6
  from face import FaceAlgo
7
  from sam_everything import SamAnything
8
+ from ocr import do_ocr
9
 
10
 
11
  components = {}
 
106
  with gr.Tab("OCR"):
107
  with gr.Row():
108
  with gr.Column(scale=2):
109
+ components["ocr_type"] = gr.Dropdown(
110
+ ["OCR","Easy"],value="Easy",
111
  label="算法类别",interactive=True
112
  )
113
  with gr.Column(scale=2):
 
117
  with gr.Column(scale=2):
118
  with gr.Row(elem_id=''):
119
  with gr.Group():
120
+ components["ocr_input"] = gr.Image(elem_id='ocr-input',label='输入',type="pil")
121
  with gr.Column(scale=2):
122
  with gr.Row():
123
  with gr.Group():
124
+ components["ocr_output"] = gr.Image(elem_id='ocr_output',label='输出',interactive=False,type="pil")
125
+ with gr.Row():
126
+ with gr.Group():
127
+ components["ocr_json_output"] = gr.JSON(label="推理结果")
128
 
129
  create_event_handlers()
130
  return demo
 
168
  do_sam_everything,gradio('sam_input'),gradio("sam_output")
169
  )
170
 
171
+ components["submit_ocr_btn"].click(
172
+ do_ocr,gradio('ocr_type','ocr_input'),gradio("ocr_output","ocr_json_output")
173
+ )
174
+
175
  def do_refernce(algo_type,input_image):
176
  # def do_refernce():
177
  print("input image",input_image)
 
243
 
244
  return images
245
 
246
+
247
+
248
+
249
  def point_to_mask(pil_image):
250
  # 遍历每个像素
251
  width, height = pil_image.size
ocr.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from detectron2.data.detection_utils import read_image,pil_image_to_numpy
3
+ from detectron2.utils.visualizer import Visualizer
4
+ from sam_everything import visimage_to_pil
5
+ import numpy as np
6
def do_ocr(ocr_type,input):
    """Run text recognition on an image and draw the detected boxes on it.

    Args:
        ocr_type: Backend selector. "OCR" runs PaddleOCR, "Easy" runs EasyOCR.
        input: The image to analyse, converted with ``pil_image_to_numpy``.
            May be None when no image was supplied. (The name shadows the
            builtin but is kept so existing callers keep working.)

    Returns:
        A tuple ``(annotated_image, detections)``. ``detections`` is the list
        of ``{'box', 'text', 'confidence'}`` dicts built by the parse helpers,
        and ``annotated_image`` is the first image returned by
        ``visimage_to_pil``. Returns ``(None, [])`` when ``input`` is None.

    Raises:
        ValueError: If ``ocr_type`` is not one of the supported backends.
    """
    print(ocr_type)

    # Nothing to analyse: return an empty result instead of crashing while
    # converting a missing image.
    if input is None:
        return None, []

    # Reject unknown backends up front; previously this fell through with
    # result=None and failed later with an opaque TypeError in the draw loop.
    if ocr_type not in ("OCR", "Easy"):
        raise ValueError(f"unsupported ocr_type: {ocr_type!r}")

    np_image = pil_image_to_numpy(input)

    if ocr_type == "OCR":
        # Imported lazily so the dependency is only loaded when selected.
        from paddleocr import PaddleOCR
        ocr = PaddleOCR(lang='ch', use_angle_cls=True)
        result = ocr.ocr(np_image)
        print(result)
        result = parse_paddle_result(result)
    else:
        import easyocr
        # Recognise both English and simplified Chinese text.
        reader = easyocr.Reader(['en','ch_sim'])
        result = reader.readtext(np_image)
        result = parse_esay_result(result)

    # NOTE(review): visimage_to_pil reverses the channel order of the rendered
    # image; confirm the channel order produced by pil_image_to_numpy matches.
    view = Visualizer(np_image)
    for item in result:
        polygon = np.array(item['box'])
        view.draw_polygon(polygon, "k")

    vis_image = view.get_output()
    pil_images = visimage_to_pil([vis_image])
    return pil_images[0],result
32
+
33
def parse_esay_result(data):
    """Convert raw EasyOCR detections into a list of plain dictionaries.

    Args:
        data: Iterable of detections where each entry is indexable as
            ``(box, text, confidence)``. None is treated as "no detections".

    Returns:
        A list of ``{'box', 'text', 'confidence'}`` dicts, one per entry and
        in the same order. Values exposing ``tolist()`` (NumPy arrays and
        scalars) are converted to built-in Python types so the result can be
        serialised as JSON; tuples inside a box become lists.
    """
    def _to_builtin(value):
        # NumPy arrays and scalars both expose tolist(), which yields
        # built-in Python containers and numbers.
        if hasattr(value, "tolist"):
            return value.tolist()
        if isinstance(value, (list, tuple)):
            return [_to_builtin(element) for element in value]
        return value

    if data is None:
        return []

    return [
        {
            'box': _to_builtin(entry[0]),
            'text': entry[1],
            'confidence': _to_builtin(entry[2]),
        }
        for entry in data
    ]
46
+
47
def parse_paddle_result(data):
    """Convert a raw PaddleOCR result into a list of plain dictionaries.

    Args:
        data: Sequence whose first element holds the detections for the
            image. Each detection is indexable as ``[box, (text, confidence)]``.

    Returns:
        A list of ``{'box', 'text', 'confidence'}`` dicts, one per detection
        and in the same order. Returns an empty list when ``data`` is None or
        empty, or when its first element is None.
    """
    # A None first element (reported by PaddleOCR when an image has no
    # detections - see library docs) would make the iteration below raise
    # TypeError; missing data is handled the same way.
    if not data or data[0] is None:
        return []

    return [
        {
            'box': entry[0],
            'text': entry[1][0],
            'confidence': entry[1][1],
        }
        for entry in data[0]
    ]
60
+
61
+
requirements.txt CHANGED
@@ -11,4 +11,7 @@ omegaconf==2.3.0
11
  pycocotools==2.0.7
12
  gradio_image_prompter==0.1.0
13
  cloudpickle==2.2.1
14
- segment_anything @ git+https://github.com/facebookresearch/segment-anything.git
 
 
 
 
11
  pycocotools==2.0.7
12
  gradio_image_prompter==0.1.0
13
  cloudpickle==2.2.1
14
+ segment_anything @ git+https://github.com/facebookresearch/segment-anything.git
15
+ paddlepaddle==2.6.1
16
+ paddleocr==2.7.3
17
+ easyocr==1.7.1
sam_everything.py CHANGED
@@ -89,11 +89,10 @@ def bitmask_to_polygon(mask):
89
  return contour
90
 
91
  # VIS图片转换为pil
92
- def visimage_to_pil(visimages,need_save=True,idx=0):
93
  pil_images = []
94
  for i,visimage in enumerate(visimages):
95
- visualized_image = visimage.get_image()
96
- # [:, :, ::-1]
97
  pil_image = Image.fromarray(visualized_image)
98
  if need_save:
99
  pil_image.save(f"{idx}_{i}.jpg")
 
89
  return contour
90
 
91
  # VIS图片转换为pil
92
+ def visimage_to_pil(visimages,need_save=False,idx=0):
93
  pil_images = []
94
  for i,visimage in enumerate(visimages):
95
+ visualized_image = visimage.get_image()[:, :, ::-1]
 
96
  pil_image = Image.fromarray(visualized_image)
97
  if need_save:
98
  pil_image.save(f"{idx}_{i}.jpg")