ai / app.py
neoguojing
finish ocr
ac510cd
raw
history blame
9.38 kB
import gradio as gr
import numpy as np
from gradio_image_prompter import ImagePrompter
from inference import ModelFactory
from face import FaceAlgo
from sam_everything import SamAnything
from ocr import do_ocr
# Registry of the Gradio components built in create_ui(), keyed by a short name.
components = {}
# gr.State handles; the values below are placeholders that
# create_event_handlers() replaces with gr.State objects.
params = dict(algo_type=None, input_image=None)
def gradio(*keys):
    """Look up registered UI components by name.

    Accepts either several names (``gradio('a', 'b')``) or a single
    list/tuple of names (``gradio(['a', 'b'])``).

    Returns:
        A list of the matching entries from ``components``, in the order given.

    Raises:
        KeyError: if a name has not been registered in ``components``.
    """
    # Unwrap a single list/tuple argument so both call styles work; isinstance
    # also accepts subclasses such as namedtuples.
    if len(keys) == 1 and isinstance(keys[0], (list, tuple)):
        keys = keys[0]
    return [components[k] for k in keys]
# Dropdown label (basic-algorithm tab) -> task_type passed to ModelFactory.predict.
algo_map = dict([
    ("目标检测", "detect"),               # object detection
    ("单阶段目标检测", "onestep_detect"),  # single-stage object detection
    ("分类", "classification"),           # classification
    ("特征提取", "feature"),              # feature extraction
    ("语义分割", "semantic"),             # semantic segmentation
    ("实例分割", "instance"),             # instance segmentation
    ("关键点检测", "keypoint"),           # keypoint detection
    ("全景分割", "panoptic"),             # panoptic segmentation
    ("YOLO", "yolo"),
])
# Dropdown label (face tab) -> algo_type passed to FaceAlgo.predict.
face_algo_map = dict([
    ("人脸检测", "detect"),     # face detection
    ("人脸识别", "recognize"),  # face recognition
    ("人脸比对", "compare"),    # face comparison (handler requires two images)
    ("特征提取", "feature"),    # feature extraction
    ("属性分析", "attr"),       # attribute analysis
])
def create_ui():
    """Build the Gradio Blocks app and register its event handlers.

    The app has four tabs: basic vision algorithms ("基础算法"), face algorithms
    ("人脸算法"), "SAM everything" and "OCR". Each widget is stored in the
    module-level ``components`` dict under a short key so the ``gradio``
    helper and ``create_event_handlers`` can look it up by name.

    Returns:
        The constructed ``gr.Blocks`` instance; the caller launches it.
    """
    with gr.Blocks() as demo:
        # --- Tab 1: basic algorithms ---
        with gr.Tab("基础算法"):
            with gr.Row():
                with gr.Column(scale=2):
                    # Algorithm selector; its value is looked up in algo_map by do_refernce.
                    components["algo_type"] = gr.Dropdown(
                        ["目标检测","单阶段目标检测", "分类", "特征提取","语义分割","实例分割","关键点检测","全景分割","YOLO"],value="全景分割",
                        label="算法类别",interactive=True
                    )
                with gr.Column(scale=2):
                    # "Analyze" button.
                    components["submit_btn"] = gr.Button(value="解析")
            with gr.Row():
                with gr.Column(scale=2):
                    # NOTE(review): elem_id 'audio-container' wraps an image input here; likely a leftover name.
                    with gr.Row(elem_id='audio-container'):
                        with gr.Group():
                            components["image_input"] = gr.Image(type="pil",elem_id='image-input',label='输入')
                with gr.Column(scale=2):
                    # Right column: rendered output image above the JSON inference result.
                    with gr.Row():
                        with gr.Group():
                            components["image_output"] = gr.Image(type="pil",elem_id='image-output',label='输出',interactive=False)
                    with gr.Row():
                        with gr.Group():
                            components["result_output"] = gr.JSON(label="推理结果")
        # --- Tab 2: face algorithms ---
        with gr.Tab("人脸算法"):
            with gr.Row():
                with gr.Column(scale=2):
                    # Face-task selector; its value is looked up in face_algo_map by do_face_refernce.
                    components["face_type"] = gr.Dropdown(
                        ["人脸检测","人脸识别", "人脸比对", "特征提取","属性分析"],value="人脸检测",
                        label="算法类别",interactive=True
                    )
                with gr.Column(scale=2):
                    components["face_submit_btn"] = gr.Button(value="解析")
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Row(elem_id=''):
                        with gr.Group():
                            # Multi-image input; the face handler reads the first item (and the second for comparison).
                            components["face_input"] = gr.Gallery(elem_id='face-input',label='输入',columns=2,type="pil")
                with gr.Column(scale=2):
                    with gr.Row():
                        with gr.Group():
                            components["face_image_output"] = gr.Gallery(elem_id='face_image_output',label='输出',columns=2,interactive=False)
                    with gr.Row():
                        with gr.Group():
                            components["face_output"] = gr.JSON(label="推理结果")
        # --- Tab 3: SAM segmentation, optionally guided by point/box prompts ---
        with gr.Tab("SAM everything"):
            with gr.Row():
                with gr.Column(scale=2):
                    components["sam_submit_btn"] = gr.Button(value="解析")
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Group():
                        # Earlier variant using the built-in image editor:
                        # components["sam_input"] = gr.ImageEditor(elem_id='sam-input',label='输入',type="pil")
                        # ImagePrompter: an image plus user-drawn prompts (read as im['image'] / im['points']).
                        components["sam_input"] = ImagePrompter(elem_id='sam-input',label='输入',type="pil")
                with gr.Column(scale=2):
                    with gr.Group():
                        components["sam_output"] = gr.Gallery(elem_id='sam_output',label='输出',columns=1,interactive=False)
        # --- Tab 4: OCR ---
        with gr.Tab("OCR"):
            with gr.Row():
                with gr.Column(scale=2):
                    # OCR backend selector; passed to do_ocr as its first argument.
                    components["ocr_type"] = gr.Dropdown(
                        ["OCR","Easy"],value="Easy",
                        label="算法类别",interactive=True
                    )
                with gr.Column(scale=2):
                    components["submit_ocr_btn"] = gr.Button(value="解析")
            with gr.Row():
                with gr.Column(scale=2):
                    with gr.Row(elem_id=''):
                        with gr.Group():
                            components["ocr_input"] = gr.Image(elem_id='ocr-input',label='输入',type="pil")
                with gr.Column(scale=2):
                    with gr.Row():
                        with gr.Group():
                            components["ocr_output"] = gr.Image(elem_id='ocr_output',label='输出',interactive=False,type="pil")
                    with gr.Row():
                        with gr.Group():
                            components["ocr_json_output"] = gr.JSON(label="推理结果")
        # Event listeners are attached while the Blocks context is still open.
        create_event_handlers()
    return demo
def create_event_handlers():
    """Wire the components registered by ``create_ui`` to their callbacks.

    Intended to be called inside the ``gr.Blocks`` context opened by
    ``create_ui``. It creates the ``gr.State`` holders in ``params`` and
    attaches upload/change/click listeners to the entries in ``components``.
    """
    # State holders mirroring the dropdowns and the uploaded image.
    # NOTE(review): the submit callbacks below read the components directly
    # (e.g. gradio('algo_type','image_input')), so nothing visible here reads these states.
    params["algo_type"] = gr.State("全景分割")
    params["input_image"] = gr.State()
    params["face_type"] = gr.State("人脸检测")
    # Identity passthroughs: copy the uploaded image / selected algorithm into state.
    components["image_input"].upload(
        lambda x: x, gradio('image_input'), params["input_image"]
    )
    components["algo_type"].change(
        lambda x: x, gradio('algo_type'), params["algo_type"]
    )
    # Basic tab: do_refernce(label, image) -> (JSON result, output image).
    components["submit_btn"].click(
        do_refernce,gradio('algo_type','image_input'),gradio("result_output",'image_output')
    )
    # The return value of ui_by_facetype is what gets stored in params["face_type"].
    components["face_type"].change(
        ui_by_facetype, gradio('face_type'), params["face_type"]
    )
    # Face tab: do_face_refernce(label, gallery) -> (JSON result, output gallery).
    components["face_submit_btn"].click(
        do_face_refernce,gradio('face_type','face_input'),gradio("face_output",'face_image_output')
    )
    # Disabled alternatives that would run SAM automatically on upload / on every edit;
    # segmentation is only triggered by the submit button instead.
    # components["sam_input"].upload(
    # do_sam_everything,gradio('sam_input'),gradio("sam_output")
    # )
    # components["sam_input"].change(
    # do_sam_everything,gradio('sam_input'),gradio("sam_output")
    # )
    # SAM tab: do_sam_everything(prompter value) -> gallery of results.
    components["sam_submit_btn"].click(
        do_sam_everything,gradio('sam_input'),gradio("sam_output")
    )
    # OCR tab: do_ocr(backend, image) -> (annotated image, JSON result).
    components["submit_ocr_btn"].click(
        do_ocr,gradio('ocr_type','ocr_input'),gradio("ocr_output","ocr_json_output")
    )
def do_refernce(algo_type,input_image):
    """Run a basic vision algorithm on one image (submit handler of the first tab).

    Args:
        algo_type: Dropdown label; must be a key of ``algo_map``.
        input_image: Image from the input component, or None when nothing was uploaded.

    Returns:
        A ``(result, image)`` pair, one value per output component
        (``result_output``, ``image_output``). Both are None when no image was
        given. ``image`` is None when the model returns no rendered image.
    """
    print("input image",input_image)
    print(algo_type)
    if input_image is None:
        gr.Warning('请上传图片')
        # The event has two outputs, so return one value for each.
        return None, None
    task_type = algo_map[algo_type]
    factory = ModelFactory()
    output, output_images = factory.predict(pil_image=input_image, task_type=task_type)
    if output_images is None or len(output_images) == 0:
        return output, None
    print("output image", output_images[0])
    # Only the first rendered image is displayed.
    return output, output_images[0]
def ui_by_facetype(face_type):
    """Handle a change of the face-algorithm dropdown.

    Logs the selection and returns it unchanged. The ``change`` listener
    writes this function's result into a ``gr.State``, so returning the value
    keeps that state in sync instead of overwriting it with None.
    """
    print("ui_by_facetype",face_type)
    return face_type
def do_face_refernce(algo_type,input_images):
    """Run a face algorithm on the gallery input (submit handler of the face tab).

    Args:
        algo_type: Dropdown label; must be a key of ``face_algo_map``.
        input_images: Gallery value. Items are indexed as ``item[0]`` to get the
            image (presumably ``(image, caption)`` pairs; confirm against the Gallery version).

    Returns:
        ``(out, faces)`` from ``FaceAlgo.predict``, or ``(None, None)`` when the
        input is missing, or when comparison is requested with fewer than two images.
    """
    print("input image",input_images)
    print(algo_type)
    # Treat both None and an empty gallery as "no input"; an empty list would
    # otherwise fail on input_images[0].
    if not input_images:
        gr.Warning('请上传图片')
        return None,None
    task = face_algo_map[algo_type]
    if task == "compare" and len(input_images) < 2:
        gr.Warning('请上传两张图片')
        return None,None
    first_image = input_images[0][0]
    # Only comparison uses a second image.
    second_image = input_images[1][0] if task == "compare" else None
    m = FaceAlgo()
    out,faces = m.predict(pil_image=first_image,pil_image1=second_image,algo_type=task)
    return out,faces
def _split_prompts(points):
    """Split ImagePrompter prompt rows into click-point coordinates and a box.

    Assumes each row is ``[x1, y1, kind, x2, y2, ...]`` with ``kind == 1`` for
    a point, as the gradio_image_prompter format suggests (TODO confirm). Any
    other row is treated as a box ``[x1, y1, x2, y2]``; when several boxes are
    present, only the last one is kept.

    Returns:
        ``(point_coords, box)``: a list of ``[x, y]`` pairs, and a numpy array
        or None.
    """
    point_coords = []
    box = None
    for item in points or []:
        if item[2] == 1:
            point_coords.append([item[0], item[1]])
        else:
            box = np.array([item[0], item[1], item[3], item[4]])
    return point_coords, box


def do_sam_everything(im):
    """Segment the ImagePrompter value with SAM (submit handler of the SAM tab).

    Args:
        im: ImagePrompter value, read as a dict with ``'image'`` and ``'points'``;
            None when nothing has been uploaded.

    Returns:
        The images from ``SamAnything`` for the gallery, or None when no image
        was provided. Without prompts, the whole image is segmented. If a box is
        present it takes precedence and click points are ignored; otherwise the
        click points are used.
    """
    print(im)
    # Validate before constructing the (potentially heavy) model.
    if im is None or im.get('image') is None:
        gr.Warning('请上传图片')
        return None
    image_pil = im['image']
    point_coords, box = _split_prompts(im.get('points'))
    sam_anything = SamAnything()
    if box is not None:
        _, images = sam_anything.seg_with_promp(image_pil,box=box)
    elif point_coords:
        coords = np.array(point_coords)
        print("point_coords:",coords.shape)
        _, images = sam_anything.seg_with_promp(image_pil,point_coords=coords)
    else:
        _, images = sam_anything.seg_all(image_pil)
    return images
def point_to_mask(pil_image):
    """Return the (x, y) coordinates of every pixel whose R, G and B are all non-zero.

    Args:
        pil_image: A PIL image. Images with fewer than three bands (e.g. "L",
            "P") are converted to RGB first, so palette and grayscale values are
            compared as colors instead of crashing.

    Returns:
        ``np.ndarray`` of shape ``(N, 2)``; each row is ``(x, y)``. Rows are
        ordered by x, then y (the order of an x-outer / y-inner scan). ``N`` is
        0 when no pixel qualifies.
    """
    width, height = pil_image.size
    print(width, height)
    # Single/dual-band images have no per-pixel RGB triple; convert them first.
    if len(pil_image.getbands()) < 3:
        pil_image = pil_image.convert("RGB")
    pixels = np.asarray(pil_image)  # (height, width, bands)
    # NOTE(review): a pixel qualifies only when *all* of R, G, B are non-zero,
    # so e.g. pure red (255, 0, 0) is excluded. If "any non-black pixel" was
    # intended, use np.any here.
    mask = np.all(pixels[..., :3] != 0, axis=-1)
    # Transposing to (width, height) makes np.nonzero scan x-major, keeping the
    # x-outer / y-inner ordering of the original per-pixel loop.
    xs, ys = np.nonzero(mask.T)
    return np.stack([xs, ys], axis=1)
if __name__ == "__main__":
    # Build the app and serve it. To listen on a specific interface, put it in
    # launch_kwargs, e.g. {"server_name": "10.151.124.137"}.
    demo = create_ui()
    launch_kwargs = {}
    demo.launch(**launch_kwargs)