File size: 5,418 Bytes
9947575
250467d
9947575
fa9edbf
 
 
 
 
 
 
f07599d
fa9edbf
 
 
 
f07599d
fa9edbf
 
 
9947575
fa9edbf
 
 
f07599d
fa9edbf
 
 
 
 
 
 
f07599d
fa9edbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f07599d
fa9edbf
f07599d
fa9edbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130

import streamlit as st
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import os
import base64
import uuid
import time
import shutil
from pathlib import Path

@st.cache_resource(show_spinner=False)
def _load_got_model():
    """Load the GOT-OCR2.0 tokenizer and model once and reuse them.

    Streamlit re-executes this whole script on every widget interaction, so an
    uncached module-level load would re-read the model weights on each rerun.
    ``st.cache_resource`` keeps a single shared instance instead.
    ``show_spinner=False`` keeps the cached call from emitting a UI element
    before ``st.set_page_config`` runs further down the script.

    Returns:
        A ``(tokenizer, model)`` tuple with the model switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    mdl = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True)
    # No device_map / .to() is passed, so the weights stay on transformers'
    # default device (CPU). eval() only switches off training-time behaviour.
    return tok, mdl.eval()


# Load tokenizer and model (cached across reruns).
# NOTE(review): the cached model is shared by concurrent sessions; confirm
# model.chat / model.chat_crop are safe to call concurrently.
tokenizer, model = _load_got_model()

# Define folders for uploads and results
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"

# exist_ok=True avoids the check-then-create race when several sessions start
# at the same time.
for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
    os.makedirs(folder, exist_ok=True)

# Function to run the GOT model
def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
    """Run the GOT-OCR2.0 model on a PIL image in the requested mode.

    The image is written to a temporary PNG under UPLOAD_FOLDER, because the
    model's ``chat`` / ``chat_crop`` methods are called with a file path. The
    "format" modes also ask the model to render its output to an HTML file
    under RESULTS_FOLDER. That file is read back and embedded as a base64
    ``data:`` iframe. Both temporary files are removed before returning.

    Args:
        image: PIL image to recognize.
        got_mode: One of "plain texts OCR", "format texts OCR",
            "plain multi-crop OCR", "format multi-crop OCR",
            "plain fine-grained OCR", "format fine-grained OCR".
        fine_grained_mode: Accepted for caller compatibility but unused here.
            The region is selected via ``ocr_color`` / ``ocr_box``.
        ocr_color: Region colour for the fine-grained modes ("" if unused).
        ocr_box: Region box string for the fine-grained modes, as typed by
            the user, e.g. "[x1,y1,x2,y2]" ("" if unused).

    Returns:
        A ``(text, iframe_html)`` tuple. ``iframe_html`` is None for plain
        modes, when no rendered file was produced, or on error. Failures are
        reported as ``("Error: ...", None)`` rather than raised.
    """
    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")

    try:
        # PNG cannot store every PIL mode (e.g. CMYK, common in JPEG uploads),
        # so convert those to RGB first. Saving inside the try reports a
        # failure to the user instead of crashing the whole Streamlit run.
        if image.mode not in ("1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"):
            image = image.convert("RGB")
        image.save(image_path)

        # Plain modes: text only, nothing is rendered.
        if got_mode == "plain texts OCR":
            return model.chat(tokenizer, image_path, ocr_type='ocr'), None
        if got_mode == "plain multi-crop OCR":
            return model.chat_crop(tokenizer, image_path, ocr_type='ocr'), None
        if got_mode == "plain fine-grained OCR":
            res = model.chat(tokenizer, image_path, ocr_type='ocr', ocr_box=ocr_box, ocr_color=ocr_color)
            return res, None

        # Format modes: the model additionally renders its output to result_path.
        if got_mode == "format texts OCR":
            res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
        elif got_mode == "format multi-crop OCR":
            res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
        elif got_mode == "format fine-grained OCR":
            res = model.chat(tokenizer, image_path, ocr_type='format', ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
        else:
            # Without this branch an unknown mode fell through with `res` unbound.
            return f"Error: unknown GOT mode {got_mode!r}", None

        if not os.path.exists(result_path):
            return res, None

        # NOTE(review): read with the locale default encoding, as before; pin
        # an explicit encoding only once the renderer's write encoding is known.
        with open(result_path, 'r') as f:
            html_content = f.read()
        encoded_html = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
        iframe_src = f"data:text/html;base64,{encoded_html}"
        iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
        return res, iframe
    except Exception as e:
        # UI boundary: surface any model or I/O failure as text in the output box.
        return f"Error: {str(e)}", None
    finally:
        # The uploaded image and the rendered HTML (already embedded above)
        # are no longer needed. Tolerate either never having been written.
        for temp_path in (image_path, result_path):
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                pass

# Function to clean up old files
def cleanup_old_files(max_age_seconds=3600, folders=None):
    """Delete stale files from the app's upload and result folders.

    Args:
        max_age_seconds: Files whose modification time is more than this many
            seconds in the past are deleted. Defaults to one hour.
        folders: Folders to sweep. Defaults to UPLOAD_FOLDER and RESULTS_FOLDER.

    Only regular files directly inside each folder are considered, and
    subdirectories are left alone. This runs at the end of every script rerun,
    so several sessions may sweep at once. A file that disappears between
    listing and deletion is skipped instead of raising.
    """
    if folders is None:
        folders = (UPLOAD_FOLDER, RESULTS_FOLDER)
    current_time = time.time()
    for folder in folders:
        for file_path in Path(folder).glob('*'):
            try:
                if not file_path.is_file():
                    continue
                if current_time - file_path.stat().st_mtime > max_age_seconds:
                    file_path.unlink()
            except FileNotFoundError:
                # Removed concurrently (another session or run_GOT); nothing to do.
                continue

# Streamlit App
st.set_page_config(page_title="GOT-OCR-2.0 Demo", layout="wide")

# Header with links to the model card, paper and repository.
st.markdown("""
<h2> <span style="color: #ff6600">General OCR Theory</span>: Towards OCR-2.0 via a Unified End-to-end Model</h2>
<a href="https://huggingface.co/ucaslcl/GOT-OCR2_0">[😊 Hugging Face]</a> 
<a href="https://arxiv.org/abs/2409.01704">[📜 Paper]</a> 
<a href="https://github.com/Ucas-HaoranWei/GOT-OCR2.0/">[🌟 GitHub]</a> 
""", unsafe_allow_html=True)

# Usage guidelines for the six GOT modes offered below.
st.markdown("""
"🔥🔥🔥This is the official online demo of the GOT-OCR-2.0 model!!!"
### Demo Guidelines
- You need to upload your image below and choose one mode of GOT, then click "Submit" to run the GOT model. More characters will result in longer wait times.
- **plain texts OCR & format texts OCR**: The two modes are for the image-level OCR.
- **plain multi-crop OCR & format multi-crop OCR**: For images with more complex content, you can achieve higher-quality results with these modes.
- **plain fine-grained OCR & format fine-grained OCR**: In these modes, you can specify fine-grained regions on the input image for more flexible OCR. Fine-grained regions can be coordinates of the box, red color, blue color, or green color.
""")

uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])

# Everything below needs an image; Streamlit reruns the script once one is uploaded.
if uploaded_image:
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # These labels must match the mode strings run_GOT dispatches on.
    got_mode = st.selectbox("Choose one mode of GOT", [
        "plain texts OCR",
        "format texts OCR",
        "plain multi-crop OCR",
        "format multi-crop OCR",
        "plain fine-grained OCR",
        "format fine-grained OCR",
    ])

    # Region selectors stay empty unless a fine-grained mode is chosen.
    fine_grained_mode = None
    ocr_color = ""
    ocr_box = ""

    if "fine-grained" in got_mode:
        fine_grained_mode = st.selectbox("Fine-grained type", ["box", "color"])
        if fine_grained_mode == "box":
            ocr_box = st.text_input("Input box: [x1,y1,x2,y2]", value="[0,0,100,100]")
        elif fine_grained_mode == "color":
            ocr_color = st.selectbox("Color list", ["red", "green", "blue"])

    if st.button("Submit"):
        with st.spinner("Processing..."):
            result_text, html_result = run_GOT(image, got_mode, fine_grained_mode, ocr_color, ocr_box)
            st.text_area("GOT Output", result_text, height=200)

            # Format modes return an <iframe> embedding the rendered output.
            if html_result:
                st.markdown(html_result, unsafe_allow_html=True)

# Cleanup old files (runs at the end of every rerun)
cleanup_old_files()