# (Hugging Face Spaces page header captured along with this source: "Spaces: Sleeping")
import hashlib
import html
import os
import re
import tempfile

import streamlit as st
from PIL import Image
from transformers import AutoModel, AutoTokenizer
# Browser-tab title and icon, plus the wide page layout.
st.set_page_config(
    page_title="OCR Application",
    page_icon="🖼️",
    layout="wide",
)

# Device name used for model placement; this app loads its models on the CPU.
device = "cpu"

# Earlier experiment, kept only as a reference: the handwritten-text model
# "microsoft/trocr-base-handwritten" loaded with TrOCRProcessor and
# VisionEncoderDecoderModel on the CPU.
@st.cache_resource
def load_model():
    """Load the GOT-OCR2.0 tokenizer and model once and reuse them.

    ``st.cache_resource`` is essential here: Streamlit re-executes the whole
    script on every widget interaction (e.g. clicking "Search"). Without the
    cache, the model weights would be reloaded from scratch on each rerun.

    Returns:
        tuple: ``(processor, model)``. ``processor`` is the model's tokenizer,
        exposed under that name because ``extract_text`` takes it as its
        ``processor`` argument. ``model`` is the ``AutoModel`` built from the
        repository's remote code.
    """
    model_id = "ucaslcl/GOT-OCR2_0"

    # device_map only applies to models; the tokenizer has no device.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # NOTE(review): the upstream GOT-OCR2.0 remote code may move tensors to
    # CUDA inside its generation path; confirm it runs on a CPU-only host.
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map=device,
        use_safetensors=True,
        # Passed as in the model card so generation has an explicit pad id.
        pad_token_id=tokenizer.eos_token_id,
    )
    model = model.eval()

    processor = tokenizer
    return processor, model
def extract_text(image, processor, model):
    """Run OCR on *image* and return the recognized text.

    Two model families are supported:

    * Models exposing ``chat(tokenizer, image_file, ocr_type=...)``, such as
      the GOT-OCR2.0 remote code returned by ``load_model``. Their
      ``processor`` is a plain tokenizer that cannot preprocess images, and
      ``chat`` reads the image from a file path. The image is therefore
      written to a temporary PNG that is deleted afterwards.
    * Vision encoder-decoder models such as TrOCR. Here *processor* turns
      the image into ``pixel_values`` and ``model.generate`` produces token
      ids, which are decoded back to text.

    Args:
        image: A PIL image, or any object with ``convert`` and ``save``.
        processor: The tokenizer or processor paired with *model*.
        model: The OCR model.

    Returns:
        The extracted text.
    """
    # OCR models expect 3-channel input. Converting also makes palette,
    # RGBA or CMYK uploads safe to write out as PNG.
    rgb_image = image.convert("RGB")

    if hasattr(model, "chat"):
        # GOT-style API: preprocessing happens inside chat(), from a file path.
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_path = os.path.join(tmp_dir, "upload.png")
            rgb_image.save(image_path, format="PNG")
            return model.chat(processor, image_path, ocr_type="ocr")

    # Encoder-decoder API (e.g. TrOCR): processor -> pixel_values -> generate.
    pixel_values = processor(images=rgb_image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def highlight_matches(text, keywords):
    """Return *text* as HTML, wrapping each match of *keywords* in ``<mark>``.

    Matching is case-insensitive, and *keywords* is treated as one literal
    phrase (regex metacharacters are escaped). Every piece of output text,
    matched or not, is HTML-escaped. This matters because the result is meant
    for ``st.markdown(..., unsafe_allow_html=True)``, and the OCR text comes
    from a user-uploaded image.

    Args:
        text: The text to search.
        keywords: The phrase to highlight. An empty value highlights nothing.

    Returns:
        HTML-safe string with ``<mark>`` tags around each match.
    """
    if not keywords:
        # An empty pattern would match between every pair of characters.
        return html.escape(text, quote=False)

    # With a capturing group, split() puts the matches at odd indices.
    pattern = re.compile(f"({re.escape(keywords)})", re.IGNORECASE)
    pieces = pattern.split(text)
    return "".join(
        f"<mark>{html.escape(piece, quote=False)}</mark>"
        if index % 2
        else html.escape(piece, quote=False)
        for index, piece in enumerate(pieces)
    )
def _cached_extract_text(uploaded_file, image, processor, model):
    """Return OCR text for *uploaded_file*, reusing it across Streamlit reruns.

    The result is stored in ``st.session_state`` under a SHA-256 digest of the
    uploaded bytes. Interacting with other widgets (e.g. "Search") therefore
    does not re-run the model. Returns None if extraction fails; in that case
    the error has already been shown on the page.
    """
    file_digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
    state = st.session_state
    if state.get("ocr_file_digest") != file_digest:
        with st.spinner("Extracting text from the image..."):
            try:
                text = extract_text(image, processor, model)
            except Exception as err:  # UI boundary: report, don't crash the page
                st.error("Text extraction failed.")
                st.exception(err)
                return None
        state["ocr_file_digest"] = file_digest
        state["ocr_text"] = text
    return state["ocr_text"]


def main():
    """Render the OCR page: image upload, text extraction and keyword search."""
    st.title("OCR Text Extractor using Hugging Face Model")

    # Load model and processor (cached by load_model across reruns).
    processor, model = load_model()

    uploaded_file = st.file_uploader("Upload an image for OCR", type=["png", "jpg", "jpeg"])
    if not uploaded_file:
        return

    try:
        image = Image.open(uploaded_file)
    except OSError:  # includes PIL.UnidentifiedImageError for non-image data
        st.error("The uploaded file could not be read as an image.")
        return
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Extract text from the image (only once per distinct upload).
    extracted_text = _cached_extract_text(uploaded_file, image, processor, model)
    if extracted_text is None:
        return

    st.subheader("Extracted Text")
    st.text_area("Text from Image", extracted_text, height=300)

    # Keyword search
    st.subheader("Keyword Search")
    keywords = st.text_input("Enter keywords to search:")
    if st.button("Search"):
        highlighted_text = highlight_matches(extracted_text, keywords)
        st.subheader("Search Results")
        st.markdown(highlighted_text, unsafe_allow_html=True)
if __name__ == "__main__":
    # Script entry point: build the page when this file is executed directly.
    main()