Spaces:
Sleeping
Sleeping
File size: 4,555 Bytes
d7e12cd 250467d 9947575 fa9edbf f07599d d7e12cd fa9edbf ba1bd28 77415fc f07599d fa9edbf 9947575 fa9edbf f07599d fa9edbf f07599d fa9edbf f07599d fa9edbf f07599d fa9edbf 77415fc fa9edbf 77415fc fa9edbf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import os
import streamlit as st
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import base64
import uuid
import time
from pathlib import Path
# Force the use of CPU
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"


# Load tokenizer and model on CPU
@st.cache_resource(show_spinner=False)
def _load_got():
    """Load the GOT tokenizer and model once per server process.

    Streamlit re-executes this script on every widget interaction; caching the
    loaded objects as a resource avoids re-loading the checkpoint on each rerun.
    ``show_spinner=False`` keeps this call from emitting a UI element before
    ``st.set_page_config`` runs further down the script.

    Returns:
        A ``(tokenizer, model)`` tuple with the model switched to eval mode.
    """
    # NOTE(review): trust_remote_code=True executes Python code shipped in the
    # model repository; consider pinning a `revision` that has been reviewed.
    got_tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
    got_model = AutoModel.from_pretrained(
        'srimanth-d/GOT_CPU',
        trust_remote_code=True,
        use_safetensors=True,
        pad_token_id=got_tokenizer.eos_token_id,
    )
    got_model.eval()
    return got_tokenizer, got_model


tokenizer, model = _load_got()
# Define folders for uploads and results
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"
for folder in [UPLOAD_FOLDER, RESULTS_FOLDER]:
    # exist_ok=True creates the folder if needed and is a no-op otherwise,
    # avoiding the race between a separate existence check and the creation.
    os.makedirs(folder, exist_ok=True)
# Maps each UI mode label to (model method name, ocr_type, pass box/color hints).
_GOT_MODES = {
    "plain texts OCR": ("chat", "ocr", False),
    "format texts OCR": ("chat", "format", False),
    "plain multi-crop OCR": ("chat_crop", "ocr", False),
    "format multi-crop OCR": ("chat_crop", "format", False),
    "plain fine-grained OCR": ("chat", "ocr", True),
    "format fine-grained OCR": ("chat", "format", True),
}


# Function to run the GOT model
def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
    """Run the GOT model on ``image`` in the requested mode.

    The image is written to a uniquely named PNG in ``UPLOAD_FOLDER`` because
    the model is invoked with a file path; that PNG is removed before returning.

    Args:
        image: Object with a ``save(path)`` method (a PIL image in this app).
        got_mode: One of the keys of ``_GOT_MODES``.
        fine_grained_mode: Accepted for interface compatibility; not read here,
            since the box/color choice is already reflected in which of
            ``ocr_box`` / ``ocr_color`` is non-empty.
        ocr_color: Forwarded as ``ocr_color`` in the fine-grained modes.
        ocr_box: Forwarded as ``ocr_box`` in the fine-grained modes.

    Returns:
        ``(text, iframe_html)``. ``iframe_html`` is an ``<iframe>`` snippet
        embedding the rendered HTML as a base64 data URI for "format" modes
        when the render file was produced, otherwise ``None``. On any failure
        the text is ``"Error: <message>"`` and the second item is ``None``.
    """
    # Reject unknown modes up front instead of failing later on an unbound result.
    if got_mode not in _GOT_MODES:
        return f"Error: Unsupported GOT mode: {got_mode!r}", None
    method_name, ocr_type, use_hints = _GOT_MODES[got_mode]
    wants_render = ocr_type == "format"

    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")

    call_kwargs = {"ocr_type": ocr_type}
    if use_hints:
        call_kwargs.update(ocr_box=ocr_box, ocr_color=ocr_color)
    if wants_render:
        call_kwargs.update(render=True, save_render_file=result_path)

    try:
        # Saving happens inside the try so a failed write is reported through
        # the same "Error: ..." path as a failed model call.
        image.save(image_path)
        res = getattr(model, method_name)(tokenizer, image_path, **call_kwargs)
        if not wants_render:
            return res, None
        try:
            # Read raw bytes: base64 needs bytes anyway, and this avoids
            # decoding the file with a locale-dependent default encoding.
            html_bytes = Path(result_path).read_bytes()
        except FileNotFoundError:
            return res, None
        encoded_html = base64.b64encode(html_bytes).decode('ascii')
        iframe_src = f"data:text/html;base64,{encoded_html}"
        iframe = f'<iframe src="{iframe_src}" width="100%" height="600px"></iframe>'
        return res, iframe
    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the page.
        return f"Error: {str(e)}", None
    finally:
        # The uploaded copy is only needed for the duration of the model call.
        try:
            os.remove(image_path)
        except FileNotFoundError:
            pass
# Function to clean up old files
def cleanup_old_files(max_age_seconds=3600, folders=None):
    """Delete regular files older than ``max_age_seconds`` from the work folders.

    Cleanup is best-effort: missing folders, subdirectories, and files that
    vanish or cannot be removed (for example because another session deleted
    them first) are skipped rather than raising.

    Args:
        max_age_seconds: Files whose modification time is more than this many
            seconds in the past are removed. Defaults to one hour.
        folders: Iterable of directory paths to sweep. Defaults to
            ``UPLOAD_FOLDER`` and ``RESULTS_FOLDER``.
    """
    if folders is None:
        folders = (UPLOAD_FOLDER, RESULTS_FOLDER)
    cutoff = time.time() - max_age_seconds
    for folder in folders:
        root = Path(folder)
        if not root.is_dir():
            continue
        for file_path in root.glob('*'):
            try:
                # Only plain files are removed; unlink() would fail on a directory.
                if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                    file_path.unlink()
            except OSError:
                # Entry disappeared or is not removable; leave it for a later sweep.
                continue
# Streamlit App
st.set_page_config(page_title="GOT-OCR-2.0 Demo", layout="wide")
uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])

# Create two columns for layout
col1, col2 = st.columns(2)

image = None
if uploaded_image:
    try:
        image = Image.open(uploaded_image)
        # Image.open is lazy; decode now so a corrupt or mislabeled upload is
        # reported here instead of failing later while displaying or saving.
        image.load()
    except OSError as exc:  # also covers PIL.UnidentifiedImageError
        image = None
        st.error(f"Could not read the uploaded image: {exc}")

if image is not None:
    with col1:
        st.image(image, caption='Uploaded Image', use_column_width=True)
    with col2:
        got_mode = st.selectbox("Choose one mode of GOT", [
            "plain texts OCR",
            "format texts OCR",
            "plain multi-crop OCR",
            "format multi-crop OCR",
            "plain fine-grained OCR",
            "format fine-grained OCR",
        ])
        fine_grained_mode = None
        ocr_color = ""
        ocr_box = ""
        # Fine-grained modes take either a box or a color hint, never both;
        # the unused one stays an empty string.
        if "fine-grained" in got_mode:
            fine_grained_mode = st.selectbox("Fine-grained type", ["box", "color"])
            if fine_grained_mode == "box":
                ocr_box = st.text_input("Input box: [x1,y1,x2,y2]", value="[0,0,100,100]")
            elif fine_grained_mode == "color":
                ocr_color = st.selectbox("Color list", ["red", "green", "blue"])
        if st.button("Submit"):
            with st.spinner("Processing..."):
                result_text, html_result = run_GOT(image, got_mode, fine_grained_mode, ocr_color, ocr_box)
            st.text_area("GOT Output", result_text, height=200)
            if html_result:
                st.markdown(html_result, unsafe_allow_html=True)

# Cleanup old files
# Runs on every script rerun, whether or not an image was submitted.
cleanup_old_files()
|