File size: 4,800 Bytes
cef4f97
 
22c7b5b
cef4f97
 
 
d1d0907
1a87a19
02ff46f
 
 
ee4b3d0
02ff46f
ed456c1
22c7b5b
cef4f97
02ff46f
cef4f97
 
00ee90b
cef4f97
02ff46f
 
 
 
1a87a19
ee4b3d0
 
cef4f97
ee4b3d0
02ff46f
 
ee4b3d0
02ff46f
ee4b3d0
02ff46f
ee4b3d0
02ff46f
ee4b3d0
02ff46f
ee4b3d0
02ff46f
ee4b3d0
40d5755
ee4b3d0
 
 
02ff46f
cef4f97
ee4b3d0
 
cef4f97
 
 
 
 
 
 
 
 
 
 
 
71a766f
02ff46f
ee4b3d0
 
 
 
 
 
 
02ff46f
ee4b3d0
 
02ff46f
 
cef4f97
 
ee4b3d0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02ff46f
ee4b3d0
 
 
 
02ff46f
cef4f97
 
 
ee4b3d0
cef4f97
 
 
 
 
ee4b3d0
cef4f97
 
f34dca6
ee4b3d0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import gradio as gr
import torch
from transformers import AutoModel, AutoTokenizer, AutoConfig
import os
import base64
import spaces
import io
from PIL import Image
import numpy as np
import yaml
from pathlib import Path
from globe import title, description, modelinfor, joinus

# Hugging Face Hub id of the GOT-OCR 2.0 checkpoint. Every load below uses this
# one name so the tokenizer, config and weights always come from the same repo.
model_name = 'ucaslcl/GOT-OCR2_0'

# trust_remote_code=True runs the custom modelling code shipped in the repo;
# that code provides the `chat` / `chat_crop` methods used by process_image.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# NOTE(review): `config` is not used anywhere in the visible code.
config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    device_map='cuda',
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
# Inference only: switch to eval mode. The model is already placed on CUDA via
# device_map, so .cuda() is a no-op safeguard.
model = model.eval().cuda()
# Use EOS as the padding id so generation has a defined pad token.
model.config.pad_token_id = tokenizer.eos_token_id

def image_to_base64(image):
    """Serialize *image* as PNG and return the bytes as a base64 ASCII string.

    *image* only needs a ``save(fp, format=...)`` method (e.g. a PIL image).
    """
    with io.BytesIO() as sink:
        image.save(sink, format="PNG")
        png_bytes = sink.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()

# File the model writes its rendered HTML to (via save_render_file) and that
# process_image reads back for the HTML output pane.
html_file = './demo.html'

@spaces.GPU
def process_image(image, task, ocr_type=None, ocr_box=None, ocr_color=None):
    """Run GOT-OCR on *image* for the selected *task*.

    Args:
        image: Forwarded to ``model.chat`` / ``model.chat_crop``; the UI
            supplies a file path (``gr.Image(type="filepath")``).
        task: One of the task names offered by the UI's task dropdown.
        ocr_type: ``"ocr"`` or ``"format"``; used only by the fine-grained
            tasks. Defaults to ``"ocr"`` when empty/None.
        ocr_box: Region string from the UI textbox, forwarded unchanged.
            NOTE(review): this is raw user text; confirm how the remote model
            code parses it before trusting it.
        ocr_color: Color name; used only by the fine-grained color task.

    Returns:
        ``(text, html)``. *html* is the HTML document rendered to
        ``html_file``. It is ``None`` for plain-text OCR or when no render
        file was produced.

    Raises:
        ValueError: if *task* is not a recognized task name.
    """
    if task == "Plain Text OCR":
        res = model.chat(tokenizer, image, ocr_type='ocr')
        return res, None

    # The fine-grained OCR-type dropdown has no initial value, so it can
    # arrive here as None; fall back to its first choice.
    fine_type = ocr_type or 'ocr'
    render = {"render": True, "save_render_file": html_file}
    runners = {
        "Format Text OCR": lambda: model.chat(
            tokenizer, image, ocr_type='format', **render),
        "Fine-grained OCR (Box)": lambda: model.chat(
            tokenizer, image, ocr_type=fine_type, ocr_box=ocr_box, **render),
        "Fine-grained OCR (Color)": lambda: model.chat(
            tokenizer, image, ocr_type=fine_type, ocr_color=ocr_color, **render),
        "Multi-crop OCR": lambda: model.chat_crop(
            tokenizer, image, ocr_type='format', **render),
        "Render Formatted OCR": lambda: model.chat(
            tokenizer, image, ocr_type='format', **render),
    }
    try:
        run = runners[task]
    except KeyError:
        # Fail clearly instead of reaching the return with `res` unbound.
        raise ValueError(f"Unknown task: {task!r}") from None

    render_path = Path(html_file)
    # Drop any render left by an earlier request. A run that writes no file
    # then yields no HTML, instead of a stale (possibly another user's) page.
    render_path.unlink(missing_ok=True)

    res = run()

    try:
        # The HTML is display-only, so undecodable bytes are replaced rather
        # than failing the whole request.
        html_content = render_path.read_text(encoding="utf-8", errors="replace")
    except FileNotFoundError:
        html_content = None
    return res, html_content

def update_inputs(task):
    """Return visibility updates for the task-specific input widgets.

    The three updates apply, in order, to: the OCR type dropdown, the OCR box
    textbox and the OCR color dropdown. This order matches the ``outputs=``
    list of the task dropdown's change event.

    Only the fine-grained tasks reveal extra inputs. Every other selection
    hides all three widgets, including tasks with no extra inputs and an
    empty/unknown value. The handler therefore always returns three updates
    and never ``None``.
    """
    if task == "Fine-grained OCR (Box)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=True),
            gr.update(visible=False),
        ]
    if task == "Fine-grained OCR (Color)":
        return [
            gr.update(visible=True, choices=["ocr", "format"]),
            gr.update(visible=False),
            gr.update(visible=True, choices=["red", "green", "blue"]),
        ]
    # Plain Text / Format Text / Multi-crop / Render Formatted OCR, plus any
    # unrecognized or cleared selection.
    return [gr.update(visible=False) for _ in range(3)]
def ocr_demo(image, task, ocr_type, ocr_box, ocr_color):
    """Gradio click handler: run OCR and format both outputs for display.

    Returns:
        ``(markdown_text, iframe_html)``. *iframe_html* is ``None`` when
        ``process_image`` produced no rendered HTML.
    """
    from html import escape  # local import keeps this block self-contained

    res, html_content = process_image(image, task, ocr_type, ocr_box, ocr_color)

    # Wrap the model output in `$...$`, then remove that wrapping around
    # tabular environments and strip `\(` / `\)` inline-math delimiters.
    res = f"${res}$"
    res = res.replace("$\\begin{tabular}", "\\begin{tabular}")
    res = res.replace("\\end{tabular}$", "\\end{tabular}")
    res = res.replace("\\(", "")
    res = res.replace("\\)", "")

    if html_content:
        # srcdoc is a double-quoted attribute, so the embedded document must
        # be escaped. Otherwise its first `"` ends the attribute and `&`
        # entities are misread.
        html_string = (
            f'<iframe srcdoc="{escape(html_content, quote=True)}" '
            'width="100%" height="600px"></iframe>'
        )
        return res, html_string
    return res, None
import gradio as gr

# Gradio UI. Header text goes on top, then one column holding the image input,
# the task selector, the task-specific inputs (hidden until a fine-grained task
# is chosen) and the two result panes. Model info goes below.
# Components render in the order they are created here.
with gr.Blocks() as demo:
    # Header text comes from the local `globe` module.
    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(joinus)
    
    with gr.Column():
        # type="filepath": the click handler receives a path to the uploaded file.
        image_input = gr.Image(type="filepath", label="Input Image")
        # These task names must match the strings checked in process_image
        # and update_inputs.
        task_dropdown = gr.Dropdown(
            choices=[
                "Plain Text OCR",
                "Format Text OCR",
                "Fine-grained OCR (Box)",
                "Fine-grained OCR (Color)",
                "Multi-crop OCR",
                "Render Formatted OCR"
            ],
            label="Select Task",
            value="Plain Text OCR"
        )
        # The next three inputs start hidden; update_inputs toggles them
        # when the task changes. None of them sets an initial value.
        ocr_type_dropdown = gr.Dropdown(
            choices=["ocr", "format"],
            label="OCR Type",
            visible=False
        )
        ocr_box_input = gr.Textbox(
            label="OCR Box (x1,y1,x2,y2)",
            placeholder="e.g., 100,100,200,200",
            visible=False
        )
        ocr_color_dropdown = gr.Dropdown(
            choices=["red", "green", "blue"],
            label="OCR Color",
            visible=False
        )
        submit_button = gr.Button("Process")

        # Result panes: ocr_demo returns (markdown_text, iframe_html) in this order.
        output_markdown = gr.Markdown(label="🫴🏻📸GOT-OCR")
        output_html = gr.HTML(label="🫴🏻📸GOT-OCR")

    gr.Markdown(modelinfor)

    # Show/hide the task-specific inputs whenever the selected task changes;
    # the outputs list order must match update_inputs' return list.
    task_dropdown.change(
        update_inputs,
        inputs=[task_dropdown],
        outputs=[ocr_type_dropdown, ocr_box_input, ocr_color_dropdown]
    )
    
    # Run OCR; inputs are passed positionally to ocr_demo's parameters.
    submit_button.click(
        ocr_demo,
        inputs=[image_input, task_dropdown, ocr_type_dropdown, ocr_box_input, ocr_color_dropdown],
        outputs=[output_markdown, output_html]
    )

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()