Spaces:
Sleeping
Sleeping
File size: 4,104 Bytes
2d8087a b5ef879 fa9edbf b5ef879 fa9edbf f07599d b5ef879 2d8087a b5ef879 f07599d 3d71449 b5ef879 fa9edbf 3d71449 fa9edbf b5ef879 f07599d fa9edbf b5ef879 fa9edbf f604f09 fa9edbf 3d71449 fa9edbf 3d71449 b5ef879 fa9edbf b5ef879 fa9edbf 2d8087a f604f09 2d8087a 77415fc b5ef879 2d8087a b5ef879 2d8087a b5ef879 2d8087a b5ef879 2d8087a b5ef879 2d8087a b5ef879 2d8087a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import html
import os
import re
import tempfile
import uuid

import streamlit as st
from PIL import Image
from transformers import AutoModel, AutoTokenizer
@st.cache_resource
def load_model(model_name):
    """Load the GOT-OCR model and tokenizer that match a UI model label.

    Cached with ``st.cache_resource`` so repeated calls with the same
    ``model_name`` reuse the already-loaded objects instead of reloading them.

    Args:
        model_name: one of ``"OCR for English or Hindi (CPU)"`` or
            ``"OCR for English (GPU)"``.

    Returns:
        A ``(model, tokenizer)`` tuple; the model is put in eval mode.

    Raises:
        ValueError: if ``model_name`` is not one of the labels above
            (previously this surfaced as an ``UnboundLocalError``).
    """
    if model_name == "OCR for English or Hindi (CPU)":
        tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
        model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
        model = model.eval()
    elif model_name == "OCR for English (GPU)":
        tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
        model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
        model = model.eval().to('cuda')
    else:
        raise ValueError(f"Unknown model option: {model_name!r}")
    return model, tokenizer


# Load the model and tokenizer only once per session and keep them in
# session state so Streamlit reruns of this script reuse them.
if "model" not in st.session_state or "tokenizer" not in st.session_state:
    # NOTE(review): this default is always what gets loaded; the "Select Model"
    # choice made later in the UI is never passed to load_model.
    model_option = "OCR for English or Hindi (CPU)"
    model, tokenizer = load_model(model_option)
    st.session_state["model"] = model
    st.session_state["tokenizer"] = tokenizer
else:
    model = st.session_state["model"]
    tokenizer = st.session_state["tokenizer"]
# Function to run the GOT model for multilingual OCR
def run_ocr(image, model, tokenizer):
    """Run GOT OCR on an image and return the recognised text.

    ``model.chat`` is called with a file path, so the image is first written
    as a PNG into a private temporary directory. The directory (and the file)
    is removed afterwards whether or not OCR succeeds.

    Args:
        image: the image to read; anything with a ``save(path)`` method
            (a PIL ``Image`` in this app).
        model: the loaded OCR model, exposing
            ``chat(tokenizer, path, ocr_type=...)``.
        tokenizer: passed straight through to ``model.chat``.

    Returns:
        The OCR result as a string (non-string results are ``str()``-ed).
        If saving the image or running the model raises, returns
        ``"Error: <message>"`` instead of propagating the exception.
    """
    try:
        # TemporaryDirectory guarantees cleanup and avoids littering the
        # working directory; the .png suffix tells PIL which format to write.
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_path = os.path.join(tmp_dir, "ocr_input.png")
            # Saving is inside the try so unsupported image modes are
            # reported through the same "Error: ..." channel as model failures.
            image.save(image_path)
            res = model.chat(tokenizer, image_path, ocr_type='ocr')
    except Exception as e:  # app boundary: show the error to the user instead of crashing
        return f"Error: {e}"
    return res if isinstance(res, str) else str(res)
# Function to highlight keyword in text
def highlight_text(text, search_term):
    """Return *text* as HTML with case-insensitive matches of *search_term* highlighted.

    The result is meant to be rendered with ``unsafe_allow_html=True``, so all
    text (outside and inside matches) is HTML-escaped; only the highlight
    ``<span>`` wrappers are emitted as markup. Matched text keeps its original
    case.

    Args:
        text: plain text to render (e.g. OCR output).
        search_term: literal phrase to highlight; empty/None disables
            highlighting. Regex metacharacters are matched literally.

    Returns:
        An HTML-safe string.
    """
    if not search_term:
        return html.escape(text, quote=False)
    pattern = re.compile(re.escape(search_term), re.IGNORECASE)
    pieces = []
    last_end = 0
    for match in pattern.finditer(text):
        # Escape the plain segment before the match, then the wrapped match itself.
        pieces.append(html.escape(text[last_end:match.start()], quote=False))
        pieces.append(
            f'<span style="background-color: yellow;">{html.escape(match.group(), quote=False)}</span>'
        )
        last_end = match.end()
    pieces.append(html.escape(text[last_end:], quote=False))
    return "".join(pieces)
# Streamlit App
st.title("GOT-OCR Multilingual Demo")
st.write("Upload an image for OCR")

# Two-column layout: uploaded image on the left, controls and results on the right.
col1, col2 = st.columns(2)

# Left column - display the uploaded image
with col1:
    uploaded_image = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
    if uploaded_image:
        image = Image.open(uploaded_image)
        st.image(image, caption='Uploaded Image', use_column_width=True)

# Right column - model selection, OCR trigger, and display of extracted text
with col2:
    # NOTE(review): model_option is not used below; OCR always runs with the
    # module-level `model`/`tokenizer` loaded at startup.
    model_option = st.selectbox("Select Model", ["OCR for English or Hindi (CPU)", "OCR for English (GPU)"])
    if st.button("Run OCR"):
        with st.spinner("Processing..."):
            if uploaded_image:
                result_text = run_ocr(image, model, tokenizer)
                # run_ocr reports failures as "Error: <message>". Check only the
                # prefix so OCR text that merely contains the word "Error" is
                # still accepted as a result.
                if result_text.startswith("Error: "):
                    st.error(result_text)
                    # Drop any earlier result so stale text is not displayed
                    # beneath this failure as if it came from the current image.
                    if "extracted_text" in st.session_state:
                        del st.session_state["extracted_text"]
                else:
                    st.session_state["extracted_text"] = result_text
            else:
                st.error("Please upload an image before running OCR.")

    # Display the most recent successful OCR result, if any.
    if "extracted_text" in st.session_state:
        extracted_text = st.session_state["extracted_text"]
        search_term = st.text_input("Enter a word or phrase to highlight:")
        highlighted_text = highlight_text(extracted_text, search_term)
        st.subheader("Extracted Text:")
        # pre-wrap preserves the OCR line breaks; unsafe_allow_html is required
        # so the highlight <span> elements are rendered as markup.
        st.markdown(f'<div style="white-space: pre-wrap;">{highlighted_text}</div>', unsafe_allow_html=True)
|