# unimer_demo/app.py
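"""Gradio demo for UniMERNet formula recognition.

Downloads the tiny/small/base checkpoints from the Hugging Face Hub, loads
all three models once at startup, and serves an image-to-LaTeX web UI.
"""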
import os
import sys
import shutil
import torch
import argparse
import gradio as gr
import numpy as np
from PIL import Image
from huggingface_hub import snapshot_download

# Make the unimernet package (one directory up) importable.
sys.path.insert(0, os.path.join(os.getcwd(), ".."))

from unimernet.common.config import Config
import unimernet.tasks as tasks
from unimernet.processors import load_processor

def load_model_and_processor(cfg_path):
    """Build a UniMERNet model and its eval-time visual processor from a config file."""
    args = argparse.Namespace(cfg_path=cfg_path, options=None)
    cfg = Config(args)
    task = tasks.setup_task(cfg)
    model = task.build_model(cfg)
    vis_processor = load_processor(
        "formula_image_eval", cfg.config.datasets.formula_rec_eval.vis_processor.eval
    )
    return model, vis_processor
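
# Example call (cfg path as used in __main__ below):
#   model, vis_processor = load_model_and_processor("cfg_tiny.yaml")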

def recognize_image(input_img, model_type):
    """Run formula recognition on a numpy image and return the predicted LaTeX."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Pick the globally loaded model that matches the radio-button choice.
    if model_type == "base":
        model = model_base.to(device)
    elif model_type == "small":
        model = model_small.to(device)
    else:
        model = model_tiny.to(device)
    # Reverse the channel order of 3-channel inputs before wrapping in PIL.
    if len(input_img.shape) == 3:
        input_img = input_img[:, :, ::-1].copy()
    img = Image.fromarray(input_img)
    image = vis_processor(img).unsqueeze(0).to(device)
    output = model.generate({"image": image})
    latex_code = output["pred_str"][0]
    return latex_code
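
# Programmatic use (only works after the models are loaded in __main__;
# the image path below is a hypothetical example):
#   latex = recognize_image(np.array(Image.open("examples/formula.png")), "tiny")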

def gradio_reset():
    """Clear both the image input and the LaTeX output box."""
    return gr.update(value=None), gr.update(value=None)

if __name__ == "__main__":
    root_path = os.path.abspath(os.getcwd())

    # == download weights ==
    tiny_model_dir = snapshot_download('wanderkid/unimernet_tiny')
    small_model_dir = snapshot_download('wanderkid/unimernet_small')
    base_model_dir = snapshot_download('wanderkid/unimernet_base')
    os.makedirs(os.path.join(root_path, "models"), exist_ok=True)
    shutil.move(tiny_model_dir, os.path.join(root_path, "models", "unimernet_tiny"))
    shutil.move(small_model_dir, os.path.join(root_path, "models", "unimernet_small"))
    shutil.move(base_model_dir, os.path.join(root_path, "models", "unimernet_base"))
    # == download weights ==

    # == load models ==
    # vis_processor is overwritten on each call; the demo assumes all three
    # configs yield the same eval-time processor.
    model_tiny, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_tiny.yaml"))
    model_small, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_small.yaml"))
    model_base, vis_processor = load_model_and_processor(os.path.join(root_path, "cfg_base.yaml"))
    print("== all models loaded ==")
    # == load models ==

    with open("header.html", "r") as file:
        header = file.read()

    with gr.Blocks() as demo:
        gr.HTML(header)
        with gr.Row():
            with gr.Column():
                model_type = gr.Radio(
                    choices=["tiny", "small", "base"],
                    value="tiny",
                    label="Model Type",
                    interactive=True,
                )
                input_img = gr.Image(label=" ", interactive=True)
                with gr.Row():
                    clear = gr.Button("Clear")
                    predict = gr.Button(value="Recognize", interactive=True, variant="primary")
                with gr.Accordion("Examples:"):
                    example_root = os.path.join(os.path.dirname(__file__), "examples")
                    gr.Examples(
                        examples=[
                            os.path.join(example_root, name)
                            for name in os.listdir(example_root)
                            if name.endswith(".png")
                        ],
                        inputs=input_img,
                    )
            with gr.Column():
                gr.Button(value="Predict Latex:", interactive=False)
                pred_latex = gr.Textbox(label="Latex", interactive=False)

        clear.click(gradio_reset, inputs=None, outputs=[input_img, pred_latex])
        predict.click(recognize_image, inputs=[input_img, model_type], outputs=[pred_latex])

    demo.launch(server_name="0.0.0.0", server_port=7860, debug=True)
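
# To run locally (assumes cfg_tiny.yaml, cfg_small.yaml, cfg_base.yaml,
# header.html, and an examples/ folder sit next to this script):
#   python app.py
# then open http://localhost:7860 in a browser.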