"""Streamlit demo for GOT-OCR-2.0 running on the CPU.

Users upload an image, pick an OCR mode, and see the model's text output and,
for "format" modes, the HTML the model renders.
"""
import base64
import os
import re
import time
import uuid
from pathlib import Path

# Hide all GPUs so the model runs on the CPU. Set before torch/transformers are
# imported so the variable is in place before CUDA could be initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import streamlit as st  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import AutoModel, AutoTokenizer  # noqa: E402

# Must be the first Streamlit command in the script, so it runs before the
# cached model loader below (whose spinner is itself a Streamlit element).
st.set_page_config(page_title="GOT-OCR-2.0 Demo", layout="wide")

MODEL_NAME = "srimanth-d/GOT_CPU"

# Folders for the temporary PNG handed to the model and the HTML it renders.
UPLOAD_FOLDER = "./uploads"
RESULTS_FOLDER = "./results"

# Leftover files older than this many seconds are removed by cleanup_old_files().
MAX_FILE_AGE_SECONDS = 3600

# UI label -> (model method to call, ocr_type, whether ocr_box/ocr_color are passed).
# Dict order is the order shown in the mode selectbox.
GOT_MODES = {
    "plain texts OCR": ("chat", "ocr", False),
    "format texts OCR": ("chat", "format", False),
    "plain multi-crop OCR": ("chat_crop", "ocr", False),
    "format multi-crop OCR": ("chat_crop", "format", False),
    "plain fine-grained OCR": ("chat", "ocr", True),
    "format fine-grained OCR": ("chat", "format", True),
}

# Image modes Pillow can write as PNG; anything else (e.g. CMYK JPEGs) is
# converted to RGB before saving.
_PNG_MODES = frozenset({"1", "L", "LA", "I", "I;16", "P", "RGB", "RGBA"})

# "[x1,y1,x2,y2]" with non-negative integers and optional whitespace.
_BOX_RE = re.compile(r"^\s*\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]\s*$")

for folder in (UPLOAD_FOLDER, RESULTS_FOLDER):
    os.makedirs(folder, exist_ok=True)


@st.cache_resource(show_spinner="Loading GOT model...")
def load_got_model():
    """Load the GOT tokenizer and model once per server process.

    Streamlit re-executes this whole script on every widget interaction; the
    resource cache keeps the model from being reloaded each time.

    Returns:
        A ``(tokenizer, model)`` tuple with the model switched to eval mode.
    """
    # NOTE(security): trust_remote_code executes Python shipped in the model
    # repository. Consider pinning `revision=` to an audited commit.
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    mdl = AutoModel.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        use_safetensors=True,
        pad_token_id=tok.eos_token_id,
    )
    mdl.eval()
    return tok, mdl


tokenizer, model = load_got_model()


def parse_ocr_box(text):
    """Normalise a user-typed box to the ``"[x1,y1,x2,y2]"`` string for the model.

    Only four non-negative integers in square brackets are accepted, so arbitrary
    user text is never forwarded to the remote model code. Leading zeros are
    stripped by the int round-trip.

    Returns:
        The canonical box string, or None if *text* is not in that form.
    """
    match = _BOX_RE.match(text or "")
    if match is None:
        return None
    return "[" + ",".join(str(int(g)) for g in match.groups()) + "]"


def _render_iframe(html_path):
    """Read a rendered HTML file and embed it in an ``<iframe>`` via a data URI.

    Undecodable bytes are replaced rather than raising, so a rendering quirk
    does not discard the OCR text returned alongside it.
    """
    with open(html_path, "r", encoding="utf-8", errors="replace") as f:
        html_content = f.read()
    encoded_html = base64.b64encode(html_content.encode("utf-8")).decode("ascii")
    iframe_src = f"data:text/html;charset=utf-8;base64,{encoded_html}"
    return (
        f'<iframe src="{iframe_src}" width="100%" height="600" '
        'style="border: none;"></iframe>'
    )


# Function to run the GOT model
def run_GOT(image, got_mode, fine_grained_mode="", ocr_color="", ocr_box=""):
    """Run GOT-OCR on *image* in the requested mode.

    The model API takes a file path, so the image is first written as a PNG to
    UPLOAD_FOLDER. Both the temporary PNG and any rendered HTML file are
    removed before returning.

    Args:
        image: PIL image to OCR.
        got_mode: One of the keys of GOT_MODES.
        fine_grained_mode: The UI's "box"/"color" selection. Not used here; the
            choice is conveyed by which of ocr_box / ocr_color is non-empty.
        ocr_color: Color name for fine-grained color OCR ("" when unused).
        ocr_box: Box string "[x1,y1,x2,y2]" for fine-grained box OCR ("" when unused).

    Returns:
        ``(text, iframe_html)``. ``iframe_html`` is an ``<iframe>`` embedding the
        rendered HTML for "format" modes when the model produced a render
        file, else None. On any failure, ``text`` is ``"Error: <message>"`` and
        ``iframe_html`` is None.
    """
    unique_id = str(uuid.uuid4())
    image_path = os.path.join(UPLOAD_FOLDER, f"{unique_id}.png")
    result_path = os.path.join(RESULTS_FOLDER, f"{unique_id}.html")
    try:
        spec = GOT_MODES.get(got_mode)
        if spec is None:
            raise ValueError(f"unsupported GOT mode: {got_mode!r}")
        method_name, ocr_type, fine_grained = spec

        # Saved inside the try so a failing save is reported like any other error.
        if image.mode not in _PNG_MODES:
            image = image.convert("RGB")
        image.save(image_path)

        kwargs = {"ocr_type": ocr_type}
        if fine_grained:
            kwargs.update(ocr_box=ocr_box, ocr_color=ocr_color)
        if ocr_type == "format":
            kwargs.update(render=True, save_render_file=result_path)
        res = getattr(model, method_name)(tokenizer, image_path, **kwargs)

        if ocr_type == "format" and os.path.exists(result_path):
            return res, _render_iframe(result_path)
        return res, None
    except Exception as e:  # UI boundary: surface any failure as text in the app
        return f"Error: {e}", None
    finally:
        # The HTML has already been inlined into the data URI, so neither file is needed.
        for path in (image_path, result_path):
            Path(path).unlink(missing_ok=True)


# Function to clean up old files
def cleanup_old_files(max_age_seconds=MAX_FILE_AGE_SECONDS):
    """Delete regular files older than *max_age_seconds* from the work folders.

    This is a safety net for files left behind by crashed runs. Files that
    vanish mid-scan (e.g. removed by a concurrent session) are skipped.
    """
    cutoff = time.time() - max_age_seconds
    for folder in (UPLOAD_FOLDER, RESULTS_FOLDER):
        for file_path in Path(folder).glob("*"):
            try:
                if file_path.is_file() and file_path.stat().st_mtime < cutoff:
                    file_path.unlink()
            except FileNotFoundError:
                continue


# Streamlit App
uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])

# Create two columns for layout
col1, col2 = st.columns(2)

image = None
if uploaded_image:
    try:
        image = Image.open(uploaded_image)
    except OSError as e:  # PIL.UnidentifiedImageError is an OSError subclass
        st.error(f"Could not read the uploaded image: {e}")

if image is not None:
    with col1:
        st.image(image, caption='Uploaded Image', use_column_width=True)

    with col2:
        got_mode = st.selectbox("Choose one mode of GOT", list(GOT_MODES))

        fine_grained_mode = None
        ocr_color = ""
        ocr_box = ""

        if "fine-grained" in got_mode:
            fine_grained_mode = st.selectbox("Fine-grained type", ["box", "color"])
            if fine_grained_mode == "box":
                ocr_box = st.text_input("Input box: [x1,y1,x2,y2]", value="[0,0,100,100]")
            elif fine_grained_mode == "color":
                ocr_color = st.selectbox("Color list", ["red", "green", "blue"])

        if st.button("Submit"):
            box_arg = parse_ocr_box(ocr_box) if fine_grained_mode == "box" else ocr_box
            if box_arg is None:
                st.error("Box must be four non-negative integers, e.g. [0,0,100,100].")
            else:
                with st.spinner("Processing..."):
                    result_text, html_result = run_GOT(
                        image, got_mode, fine_grained_mode, ocr_color, box_arg
                    )

                st.text_area("GOT Output", result_text, height=200)

                if html_result:
                    st.markdown(html_result, unsafe_allow_html=True)

# Cleanup old files
cleanup_old_files()