File size: 4,555 Bytes
d7e12cd
250467d
9947575
fa9edbf
 
 
 
 
f07599d
d7e12cd
 
 
fa9edbf
ba1bd28
 
77415fc
f07599d
fa9edbf
 
 
9947575
fa9edbf
 
 
f07599d
fa9edbf
 
 
 
 
 
 
f07599d
fa9edbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f07599d
fa9edbf
f07599d
fa9edbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77415fc
 
 
fa9edbf
 
77415fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa9edbf
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import os
import streamlit as st
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import base64
import uuid
import time
from pathlib import Path

# Force the use of CPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


@st.cache_resource(show_spinner=False)
def _load_got_model():
    """Load the GOT-OCR tokenizer and model once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level would re-read the model weights on each rerun.
    ``st.cache_resource`` keeps a single instance across reruns and sessions
    (NOTE(review): that instance is therefore shared by concurrent sessions).
    ``show_spinner=False`` keeps the loader from emitting a UI element before
    ``st.set_page_config`` is called further down the script.

    Returns:
        (tokenizer, model), with the model switched to eval mode.
    """
    tok = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
    mdl = AutoModel.from_pretrained(
        'srimanth-d/GOT_CPU',
        trust_remote_code=True,
        use_safetensors=True,
        pad_token_id=tok.eos_token_id,
    )
    mdl.eval()
    return tok, mdl


# Load tokenizer and model on CPU (cached across Streamlit reruns)
tokenizer, model = _load_got_model()

# Define folders for uploads and results
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"

# exist_ok=True instead of an exists()-then-makedirs() check: Streamlit runs
# this script for every session, and two sessions racing between the check
# and the create would otherwise raise FileExistsError.
for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
    os.makedirs(folder, exist_ok=True)

# GOT mode label (as offered by the UI) -> (use the multi-crop API, ocr_type, is fine-grained)
_GOT_MODES = {
    "plain texts OCR": (False, "ocr", False),
    "format texts OCR": (False, "format", False),
    "plain multi-crop OCR": (True, "ocr", False),
    "format multi-crop OCR": (True, "format", False),
    "plain fine-grained OCR": (False, "ocr", True),
    "format fine-grained OCR": (False, "format", True),
}


def _normalize_ocr_box(ocr_box):
    """Validate a "[x1,y1,x2,y2]" box string and return it in canonical form.

    The box is typed free-hand by the user, so anything other than four
    non-negative integers is rejected here rather than handed to the model's
    remote code (NOTE(review): confirm how GOT's ``chat`` parses ``ocr_box``).
    A blank or missing box yields "" (no box).

    Raises:
        ValueError: if the text is not a bracketed list of four integers.
    """
    text = (ocr_box or "").strip()
    if not text:
        return ""
    coords = []
    if text.startswith("[") and text.endswith("]"):
        coords = [part.strip() for part in text[1:-1].split(",")]
    if len(coords) != 4 or not all(c.isdecimal() for c in coords):
        raise ValueError(f"invalid box {ocr_box!r}; expected [x1,y1,x2,y2]")
    return "[" + ",".join(str(int(c)) for c in coords) + "]"


def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
    """Run GOT-OCR on *image* in the requested mode.

    Args:
        image: PIL image to recognise. It is written to UPLOAD_FOLDER as a
            temporary PNG because the model API takes a file path.
        got_mode: One of the mode labels in ``_GOT_MODES``.
        fine_grained_mode: Accepted for interface compatibility but unused;
            the fine-grained region is selected by ``ocr_box``/``ocr_color``.
        ocr_color: Colour forwarded to the model in fine-grained modes.
        ocr_box: "[x1,y1,x2,y2]" string forwarded in fine-grained modes.

    Returns:
        (text, iframe_html): the model output and, for "format" modes whose
        render file was produced, an ``<iframe>`` embedding the rendered HTML
        (otherwise None). On any failure, text is "Error: <message>" and
        iframe_html is None.
    """
    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")

    try:
        if got_mode not in _GOT_MODES:
            raise ValueError(f"unknown GOT mode: {got_mode!r}")
        use_crop, ocr_type, fine_grained = _GOT_MODES[got_mode]

        chat_kwargs = {"ocr_type": ocr_type}
        if fine_grained:
            chat_kwargs["ocr_box"] = _normalize_ocr_box(ocr_box)
            chat_kwargs["ocr_color"] = ocr_color
        if ocr_type == "format":
            # Have the model render its formatted output into an HTML file.
            chat_kwargs["render"] = True
            chat_kwargs["save_render_file"] = result_path

        # Saving inside the try reports save failures like model errors, and
        # the finally-block still removes any partially written file.
        try:
            image.save(image_path)
        except OSError:
            # PNG cannot store some image modes (e.g. CMYK JPEGs); retry as RGB.
            image.convert("RGB").save(image_path)

        chat = model.chat_crop if use_crop else model.chat
        res = chat(tokenizer, image_path, **chat_kwargs)

        if ocr_type == "format" and os.path.exists(result_path):
            with open(result_path, 'r') as f:
                html_content = f.read()
            encoded_html = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
            iframe_src = f"data:text/html;base64,{encoded_html}"
            iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
            return res, iframe
        return res, None
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the page.
        return f"Error: {str(e)}", None
    finally:
        # The rendered HTML has been inlined into the iframe already, so
        # neither temporary file is needed once the call returns.
        for path in (image_path, result_path):
            if os.path.exists(path):
                os.remove(path)

# Function to clean up old files
def cleanup_old_files(max_age_seconds=3600, folders=None):
    """Delete files older than *max_age_seconds* from the app's work folders.

    Args:
        max_age_seconds: Age threshold in seconds, compared against each
            file's modification time. Defaults to one hour.
        folders: Directories to sweep. Defaults to UPLOAD_FOLDER and
            RESULTS_FOLDER.

    Subdirectories are left alone. Files that disappear mid-sweep (e.g.
    removed by a concurrent run of this sweep from another session) are
    skipped instead of raising.
    """
    if folders is None:
        folders = [UPLOAD_FOLDER, RESULTS_FOLDER]
    cutoff = time.time() - max_age_seconds
    for folder in folders:
        for file_path in Path(folder).glob('*'):
            try:
                if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                    file_path.unlink()
            except FileNotFoundError:
                # Already removed by someone else between glob() and unlink().
                continue

# Streamlit App
st.set_page_config(page_title="GOT-OCR-2.0 Demo", layout="wide")

uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])

# Create two columns for layout
col1, col2 = st.columns(2)

image = None
if uploaded_image:
    try:
        image = Image.open(uploaded_image)
    except (OSError, Image.DecompressionBombError) as exc:
        # A corrupt or non-image upload (PIL.UnidentifiedImageError is an
        # OSError) would otherwise abort this script run with a traceback and
        # skip the file cleanup at the bottom.
        st.error(f"Could not read the uploaded image: {exc}")

if image is not None:
    with col1:
        st.image(image, caption='Uploaded Image', use_column_width=True)

    with col2:
        got_mode = st.selectbox("Choose one mode of GOT", [
            "plain texts OCR",
            "format texts OCR",
            "plain multi-crop OCR",
            "format multi-crop OCR",
            "plain fine-grained OCR",
            "format fine-grained OCR",
        ])

        # Fine-grained options only apply to the fine-grained modes; the
        # unused one stays "" so run_GOT forwards nothing meaningful for it.
        fine_grained_mode = None
        ocr_color = ""
        ocr_box = ""

        if "fine-grained" in got_mode:
            fine_grained_mode = st.selectbox("Fine-grained type", ["box", "color"])
            if fine_grained_mode == "box":
                ocr_box = st.text_input("Input box: [x1,y1,x2,y2]", value="[0,0,100,100]")
            elif fine_grained_mode == "color":
                ocr_color = st.selectbox("Color list", ["red", "green", "blue"])

        if st.button("Submit"):
            with st.spinner("Processing..."):
                result_text, html_result = run_GOT(image, got_mode, fine_grained_mode, ocr_color, ocr_box)
                st.text_area("GOT Output", result_text, height=200)

                # Only "format" modes return an iframe with the rendered output.
                if html_result:
                    st.markdown(html_result, unsafe_allow_html=True)

# Cleanup old files
cleanup_old_files()