# OCR_Application / app.py
# Divyansh12's picture
# Update app.py
# 77415fc verified
# raw
# history blame
# 4.56 kB
import os
import streamlit as st
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import base64
import uuid
import time
from pathlib import Path
# Force the use of CPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


# Streamlit re-executes this script on every widget interaction, so the
# expensive model load is wrapped in a resource cache and performed once.
# show_spinner=False is intended to keep this call from emitting any UI
# element before st.set_page_config() runs later in the script.
@st.cache_resource(show_spinner=False)
def _load_got_model():
    """Load the GOT tokenizer and model and return them as a pair.

    Returns:
        A ``(tokenizer, model)`` tuple; the model has been put in eval mode.
    """
    got_tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
    got_model = AutoModel.from_pretrained(
        'srimanth-d/GOT_CPU',
        trust_remote_code=True,
        use_safetensors=True,
        pad_token_id=got_tokenizer.eos_token_id,
    )
    got_model.eval()
    return got_tokenizer, got_model


# Load tokenizer and model on CPU
tokenizer, model = _load_got_model()

# Define folders for uploads and results
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"
for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
    # exist_ok avoids the check-then-create race between concurrent reruns.
    os.makedirs(folder, exist_ok=True)
# Function to run the GOT model
def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
    """Run the GOT model on an image in the requested mode.

    The image is written to a uniquely named PNG in UPLOAD_FOLDER because the
    model is invoked with a file path. "format" modes additionally ask the
    model to render an HTML file into RESULTS_FOLDER, which is embedded in an
    iframe as a base64 data URI. Both temporary files are removed before
    returning.

    Args:
        image: Object with a ``save(path)`` method (the app passes a PIL image).
        got_mode: One of "plain texts OCR", "format texts OCR",
            "plain multi-crop OCR", "format multi-crop OCR",
            "plain fine-grained OCR", "format fine-grained OCR".
        fine_grained_mode: Accepted for caller compatibility; not used here.
        ocr_color: Forwarded to the model in the fine-grained modes.
        ocr_box: Forwarded to the model in the fine-grained modes.

    Returns:
        A ``(text, iframe_html)`` tuple. ``iframe_html`` is ``None`` for plain
        modes and when no rendered file was produced. On any failure the
        tuple is ``("Error: <message>", None)``.
    """
    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")

    try:
        # Saving inside the try means a failed write is reported like any
        # other error and a partially written file is still cleaned up.
        image.save(image_path)

        # Plain modes return the recognised text directly.
        if got_mode == "plain texts OCR":
            return model.chat(tokenizer, image_path, ocr_type='ocr'), None
        if got_mode == "plain multi-crop OCR":
            return model.chat_crop(tokenizer, image_path, ocr_type='ocr'), None
        if got_mode == "plain fine-grained OCR":
            return model.chat(tokenizer, image_path, ocr_type='ocr', ocr_box=ocr_box, ocr_color=ocr_color), None

        # Format modes also render an HTML file at result_path.
        if got_mode == "format texts OCR":
            res_markdown = model.chat(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
        elif got_mode == "format multi-crop OCR":
            res_markdown = model.chat_crop(tokenizer, image_path, ocr_type='format', render=True, save_render_file=result_path)
        elif got_mode == "format fine-grained OCR":
            res_markdown = model.chat(tokenizer, image_path, ocr_type='format', ocr_box=ocr_box, ocr_color=ocr_color, render=True, save_render_file=result_path)
        else:
            # Explicit message instead of tripping over an unbound local.
            return f"Error: unsupported GOT mode {got_mode!r}", None

        if not os.path.exists(result_path):
            return res_markdown, None

        # Read raw bytes: base64 needs bytes anyway, and this avoids decoding
        # the file with a platform-dependent default encoding.
        with open(result_path, 'rb') as f:
            encoded_html = base64.b64encode(f.read()).decode('ascii')
        iframe_src = f"data:text/html;base64,{encoded_html}"
        iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
        return res_markdown, iframe
    except Exception as e:
        # UI boundary: surface the failure as text rather than crash the app.
        return f"Error: {str(e)}", None
    finally:
        # The rendered HTML has already been embedded in the iframe string,
        # so neither temporary file is needed once we return.
        for temp_path in (image_path, result_path):
            try:
                os.remove(temp_path)
            except FileNotFoundError:
                # Never created (plain mode / early failure) or already gone.
                pass
# Function to clean up old files
def cleanup_old_files(max_age_seconds=3600, folders=None):
    """Delete regular files older than ``max_age_seconds`` from the work folders.

    Args:
        max_age_seconds: Files whose modification time is more than this many
            seconds in the past are removed. Defaults to one hour.
        folders: Iterable of directory paths to sweep. When ``None``, the
            module-level UPLOAD_FOLDER and RESULTS_FOLDER are used.

    Missing folders and sub-directories are skipped, and files that vanish
    while the sweep is running are ignored.
    """
    if folders is None:
        folders = (UPLOAD_FOLDER, RESULTS_FOLDER)

    cutoff = time.time() - max_age_seconds
    for folder in folders:
        directory = Path(folder)
        if not directory.is_dir():
            continue
        for file_path in directory.glob('*'):
            try:
                # Only regular files are removed; unlink() fails on directories.
                if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                    file_path.unlink()
            except FileNotFoundError:
                # Another session removed the file between listing and unlink.
                continue
# Streamlit App
st.set_page_config(page_title="GOT-OCR-2.0 Demo", layout="wide")

uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])

# Create two columns for layout
col1, col2 = st.columns(2)

if uploaded_image:
    try:
        image = Image.open(uploaded_image)
    except OSError as exc:
        # A corrupt or mislabelled upload should produce a message, not a
        # traceback; stop this run so nothing below uses a missing image.
        st.error(f"Could not read the uploaded image: {exc}")
        st.stop()

    # Left column: preview of the upload.
    with col1:
        st.image(image, caption='Uploaded Image', use_column_width=True)

    # Right column: mode selection, optional fine-grained inputs, and output.
    with col2:
        got_mode = st.selectbox("Choose one mode of GOT", [
            "plain texts OCR",
            "format texts OCR",
            "plain multi-crop OCR",
            "format multi-crop OCR",
            "plain fine-grained OCR",
            "format fine-grained OCR",
        ])

        # Only one of ocr_box / ocr_color is filled in; the other stays "".
        fine_grained_mode = None
        ocr_color = ""
        ocr_box = ""
        if "fine-grained" in got_mode:
            fine_grained_mode = st.selectbox("Fine-grained type", ["box", "color"])
            if fine_grained_mode == "box":
                ocr_box = st.text_input("Input box: [x1,y1,x2,y2]", value="[0,0,100,100]")
            elif fine_grained_mode == "color":
                ocr_color = st.selectbox("Color list", ["red", "green", "blue"])

        if st.button("Submit"):
            with st.spinner("Processing..."):
                result_text, html_result = run_GOT(image, got_mode, fine_grained_mode, ocr_color, ocr_box)
                st.text_area("GOT Output", result_text, height=200)
                # html_result is only set when a rendered file was produced.
                if html_result:
                    st.markdown(html_result, unsafe_allow_html=True)

# Cleanup old files
cleanup_old_files()