OCR_Application / app.py
Divyansh12's picture
Update app.py
fa9edbf verified
raw
history blame
5.42 kB
import streamlit as st
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import os
import base64
import uuid
import time
import shutil
from pathlib import Path
# Load the tokenizer and model once per server process. Streamlit re-executes
# this script from the top on every widget interaction, so an uncached
# module-level load would repeat both from_pretrained() calls on each rerun.
@st.cache_resource(show_spinner=False)
def _load_got():
    """Return the ``(tokenizer, model)`` pair for 'ucaslcl/GOT-OCR2_0'.

    The result is cached by Streamlit, so the pretrained weights are loaded
    only on the first call. ``show_spinner=False`` keeps the cache from
    drawing a UI element before ``st.set_page_config()`` runs further down.
    """
    got_tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
    got_model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True)
    # eval() puts the model in inference mode. No device move is requested
    # anywhere in this file, so the model stays wherever from_pretrained put it.
    return got_tokenizer, got_model.eval()


tokenizer, model = _load_got()
# Working directories: UPLOAD_FOLDER holds the temporary input images and
# RESULTS_FOLDER holds the rendered HTML files.
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"
for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
    # exist_ok=True avoids the check-then-create race: another session may
    # create the directory between an exists() test and makedirs().
    os.makedirs(folder, exist_ok=True)
# Function to run the GOT model
def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
    """Run one GOT OCR mode on ``image`` and return ``(text, iframe_html)``.

    The image is written to a uniquely named PNG in UPLOAD_FOLDER, handed to
    the model by path, and the PNG is removed again before returning.

    Args:
        image: Object with a ``save(path)`` method (the caller passes a PIL
            image).
        got_mode: One of "plain texts OCR", "format texts OCR",
            "plain multi-crop OCR", "format multi-crop OCR",
            "plain fine-grained OCR", "format fine-grained OCR".
        fine_grained_mode: Accepted for caller compatibility; not used here.
        ocr_color: Forwarded to ``model.chat`` in the fine-grained modes.
        ocr_box: Forwarded to ``model.chat`` in the fine-grained modes.

    Returns:
        A 2-tuple. The first item is the model output, or a string starting
        with "Error: " if anything failed. The second item is an ``<iframe>``
        tag embedding the rendered HTML as a base64 data URI for "format"
        modes whose render file exists, otherwise ``None``.
    """
    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")

    try:
        # Saving inside the try block so a failed save is reported through the
        # same "Error: ..." channel as a failed model call.
        image.save(image_path)

        # Plain modes produce text only; there is nothing to render.
        if got_mode == "plain texts OCR":
            return model.chat(tokenizer, image_path, ocr_type='ocr'), None
        if got_mode == "plain multi-crop OCR":
            return model.chat_crop(tokenizer, image_path, ocr_type='ocr'), None
        if got_mode == "plain fine-grained OCR":
            return model.chat(tokenizer, image_path, ocr_type='ocr', ocr_box=ocr_box, ocr_color=ocr_color), None

        # Format modes additionally ask the model to write a rendered HTML file.
        if got_mode == "format texts OCR":
            res = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
        elif got_mode == "format multi-crop OCR":
            res = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
        elif got_mode == "format fine-grained OCR":
            res = model.chat(tokenizer, image_path, ocr_type='format', ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
        else:
            # Previously an unknown mode fell through to an unbound ``res``.
            raise ValueError(f"Unsupported GOT mode: {got_mode!r}")

        if not os.path.exists(result_path):
            return res, None

        with open(result_path, 'r') as f:
            html_content = f.read()
        # Embed the page as a data URI so it needs no separate file serving.
        encoded_html = base64.b64encode(html_content.encode('utf-8')).decode('utf-8')
        iframe_src = f"data:text/html;base64,{encoded_html}"
        iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
        return res, iframe
    except Exception as e:
        # UI boundary: report the failure as text instead of raising.
        return f"Error: {e}", None
    finally:
        # The uploaded image is only needed for the duration of the call. The
        # rendered HTML file, if any, is left in RESULTS_FOLDER.
        if os.path.exists(image_path):
            os.remove(image_path)
# Function to clean up old files
def cleanup_old_files(*, max_age_seconds=3600, folders=None):
    """Delete regular files that have not been modified for a while.

    Args:
        max_age_seconds: A file is removed when its modification time is more
            than this many seconds in the past. Defaults to one hour.
        folders: Iterable of directory paths to scan (top level only).
            Defaults to ``(UPLOAD_FOLDER, RESULTS_FOLDER)``.

    Subdirectories and other non-file entries are skipped, and a file that
    disappears while being examined is ignored.
    """
    if folders is None:
        folders = (UPLOAD_FOLDER, RESULTS_FOLDER)

    current_time = time.time()
    for folder in folders:
        for file_path in Path(folder).glob('*'):
            try:
                # unlink() cannot remove directories, so only touch files.
                if not file_path.is_file():
                    continue
                if current_time - file_path.stat().st_mtime > max_age_seconds:
                    file_path.unlink()
            except FileNotFoundError:
                # Another session removed the file between listing and unlink.
                continue
# Streamlit App
def _normalize_box(text):
    """Return ``text`` as a canonical "[x1,y1,x2,y2]" string, or None.

    The box is free text typed by the user and is forwarded to model code
    loaded with trust_remote_code=True, which is not visible in this file.
    Only the advertised form, four numbers inside square brackets, is let
    through.
    """
    stripped = text.strip()
    if not (stripped.startswith("[") and stripped.endswith("]")):
        return None
    parts = [part.strip() for part in stripped[1:-1].split(",")]
    if len(parts) != 4:
        return None
    try:
        for part in parts:
            float(part)
    except ValueError:
        return None
    return "[" + ",".join(parts) + "]"


st.set_page_config(page_title="GOT-OCR-2.0 Demo", layout="wide")

# Page header with project links (raw HTML, hence unsafe_allow_html).
st.markdown("""
<h2> <span style="color: #ff6600">General OCR Theory</span>: Towards OCR-2.0 via a Unified End-to-end Model</h2>
<a href="https://huggingface.co/ucaslcl/GOT-OCR2_0">[😊 Hugging Face]</a>
<a href="https://arxiv.org/abs/2409.01704">[📜 Paper]</a>
<a href="https://github.com/Ucas-HaoranWei/GOT-OCR2.0/">[🌟 GitHub]</a>
""", unsafe_allow_html=True)

# Usage guidelines.
st.markdown("""
"🔥🔥🔥This is the official online demo of the GOT-OCR-2.0 model!!!"
### Demo Guidelines
- You need to upload your image below and choose one mode of GOT, then click "Submit" to run the GOT model. More characters will result in longer wait times.
- **plain texts OCR & format texts OCR**: The two modes are for the image-level OCR.
- **plain multi-crop OCR & format multi-crop OCR**: For images with more complex content, you can achieve higher-quality results with these modes.
- **plain fine-grained OCR & format fine-grained OCR**: In these modes, you can specify fine-grained regions on the input image for more flexible OCR. Fine-grained regions can be coordinates of the box, red color, blue color, or green color.
""")

uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])
if uploaded_image:
    image = Image.open(uploaded_image)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    got_mode = st.selectbox("Choose one mode of GOT", [
        "plain texts OCR",
        "format texts OCR",
        "plain multi-crop OCR",
        "format multi-crop OCR",
        "plain fine-grained OCR",
        "format fine-grained OCR",
    ])

    # Region selectors are only shown for the fine-grained modes.
    fine_grained_mode = None
    ocr_color = ""
    ocr_box = ""
    if "fine-grained" in got_mode:
        fine_grained_mode = st.selectbox("Fine-grained type", ["box", "color"])
        if fine_grained_mode == "box":
            ocr_box = st.text_input("Input box: [x1,y1,x2,y2]", value="[0,0,100,100]")
        elif fine_grained_mode == "color":
            ocr_color = st.selectbox("Color list", ["red", "green", "blue"])

    if st.button("Submit"):
        box_is_valid = True
        if fine_grained_mode == "box":
            # Reject anything that is not a plain list of four numbers before
            # it reaches the model.
            normalized_box = _normalize_box(ocr_box)
            if normalized_box is None:
                box_is_valid = False
                st.error("Invalid box. Expected four numbers in the form [x1,y1,x2,y2].")
            else:
                ocr_box = normalized_box

        if box_is_valid:
            with st.spinner("Processing..."):
                result_text, html_result = run_GOT(image, got_mode, fine_grained_mode, ocr_color, ocr_box)
            st.text_area("GOT Output", result_text, height=200)
            if html_result:
                st.markdown(html_result, unsafe_allow_html=True)

# Cleanup old files
cleanup_old_files()