import os
import streamlit as st
from transformers import AutoModel, AutoTokenizer
from PIL import Image
import uuid
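
# GOT-OCR2.0 Streamlit demo: upload an image, pick the CPU (English/Hindi) or
# GPU (English) model, run OCR, and optionally highlight a keyword in the result.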

# Cache the model loading function using @st.cache_resource
@st.cache_resource
def load_model(model_name):
    if model_name == "OCR for english or hindi (runs on CPU)":
        tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
        model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
        model.eval()  # Evaluation mode; the model stays on the CPU
    elif model_name == "OCR for english (runs on GPU)":
        tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
        model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
        model.eval().cuda()  # Evaluation mode; move the model to the GPU
    else:
        raise ValueError(f"Unknown model option: {model_name}")
    return tokenizer, model

# Function to run the GOT model for multilingual OCR
@st.cache_data
def run_GOT(image_bytes, _tokenizer, _model):
    # image_bytes is part of the st.cache_data key, so each distinct image gets its
    # own cached result; the underscore-prefixed tokenizer and model are not hashed
    unique_id = str(uuid.uuid4())
    image_path = f"{unique_id}.png"
    with open(image_path, "wb") as f:
        f.write(image_bytes)  # Write the uploaded image to a temporary file
    try:
        # Use the model to extract plain text
        res = _model.chat(_tokenizer, image_path, ocr_type='ocr')
        return res
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        # Clean up the temporary image file
        if os.path.exists(image_path):
            os.remove(image_path)

# Function to highlight keyword in text
def highlight_keyword(text, keyword):
    if keyword:
        highlighted_text = text.replace(keyword, f"<mark>{keyword}</mark>")
        return highlighted_text
    return text

# Streamlit App
st.set_page_config(page_title="GOT-OCR Multilingual Demo", layout="wide")

# Creating two columns
left_col, right_col = st.columns(2)

with left_col:
    uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])

with right_col:
    # Model selection in the right column
    model_option = st.selectbox("Select Model", ["OCR for english or hindi (runs on CPU)", "OCR for english (runs on GPU)"])

if uploaded_image:
    image = Image.open(uploaded_image)
    with left_col:
        st.image(image, caption='Uploaded Image', use_column_width=True)
    with right_col:
        if st.button("Run OCR"):
            with st.spinner("Processing..."):
                # Load the selected model (cached using @st.cache_resource)
                tokenizer, model = load_model(model_option)
                # Run OCR and keep the result in session state so it survives the
                # rerun triggered by the keyword text input below
                st.session_state["ocr_result"] = run_GOT(uploaded_image.getvalue(), tokenizer, model)
        result_text = st.session_state.get("ocr_result")
        if result_text:
            if "Error" not in result_text:
                # Keyword input for search
                keyword = st.text_input("Enter a keyword to highlight")
                # Highlight keyword in the extracted text
                highlighted_text = highlight_keyword(result_text, keyword)
                # Display the extracted text
                st.markdown(highlighted_text, unsafe_allow_html=True)
            else:
                st.error(result_text)