# OCR_Application / app.py
# Hugging Face Space by Divyansh12.
# Last commit: f07599d "Update app.py" (verified); file size 2.43 kB.
import hashlib
import html
import re

import streamlit as st
from PIL import Image
from transformers import AutoModel, AutoTokenizer
# Browser-tab title and icon, and a full-width page layout.
st.set_page_config(
    layout="wide",
    page_icon="🖼️",
    page_title="OCR Application",
)

# Target device for the OCR model (load_model places it on the CPU).
device = "cpu"
@st.cache_resource
def load_model():
    """Load the GOT-OCR2.0 tokenizer and model.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the loaded
    objects instead of downloading and initialising the model again.

    Returns:
        tuple: ``(processor, model)``. ``processor`` is the model's tokenizer
        (``extract_text`` receives it under that name). ``model`` is the
        ``AutoModel`` built from the checkpoint's remote code and placed on
        ``device``.
    """
    model_id = "ucaslcl/GOT-OCR2_0"
    # The checkpoint ships its own modelling code, hence trust_remote_code.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModel.from_pretrained(
        model_id,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        device_map=device,
        use_safetensors=True,
        # Per the model card: give generation an explicit pad token.
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer, model
def extract_text(image, processor, model):
    """Run OCR on a PIL image and return the recognised text.

    Two model APIs are supported:

    * GOT-OCR2.0 style (what ``load_model`` returns). The remote code exposes
      ``model.chat(tokenizer, image, ocr_type=...)``, and ``processor`` is the
      tokenizer. A tokenizer cannot be called with ``images=``, so the
      generate-based path below would fail for this model.
    * TrOCR style (``VisionEncoderDecoderModel`` + ``TrOCRProcessor``).
      ``processor`` turns the image into ``pixel_values``, and
      ``model.generate`` produces token ids that are decoded back to text.

    Args:
        image: PIL image to read. It is converted to RGB first, because PNG
            uploads may be RGBA or palette images.
        processor: tokenizer (GOT) or image processor (TrOCR).
        model: the loaded OCR model.

    Returns:
        str: the extracted text.
    """
    rgb_image = image.convert("RGB")
    if hasattr(model, "chat"):
        # NOTE(review): gradio_input=True asks GOT's chat() to take an
        # in-memory image rather than a file path; confirm against the
        # repository's modeling code if the checkpoint is updated.
        return model.chat(processor, rgb_image, ocr_type="ocr", gradio_input=True)
    pixel_values = processor(images=rgb_image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
def highlight_matches(text, keywords):
    """Return *text* as HTML with each case-insensitive match of *keywords* in ``<mark>``.

    *keywords* is matched literally as a single phrase; regex metacharacters
    are escaped. The result is rendered with ``unsafe_allow_html=True``, and
    the text comes from a user-uploaded image. Every piece of *text* is
    therefore HTML-escaped, and only the ``<mark>`` tags are emitted as raw
    HTML.

    An empty *keywords* string highlights nothing. Otherwise the empty
    pattern would match between every character.
    """
    if not keywords:
        return html.escape(text)
    pattern = re.compile(f"({re.escape(keywords)})", re.IGNORECASE)
    # With one capturing group, re.split alternates unmatched/matched pieces:
    # even indices are surrounding text, odd indices are the matches.
    pieces = pattern.split(text)
    return "".join(
        f"<mark>{html.escape(piece)}</mark>" if index % 2 else html.escape(piece)
        for index, piece in enumerate(pieces)
    )
def main():
    """Render the OCR page: load the model, upload an image, extract text, search it."""
    st.title("OCR Text Extractor using Hugging Face Model")

    # Load model and processor (load_model is wrapped in st.cache_resource).
    processor, model = load_model()

    uploaded_file = st.file_uploader("Upload an image for OCR", type=["png", "jpg", "jpeg"])
    if not uploaded_file:
        return

    try:
        image = Image.open(uploaded_file)
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files
        # that are not decodable images despite their extension.
        st.error("The uploaded file could not be read as an image.")
        return
    st.image(image, caption="Uploaded Image", use_column_width=True)

    # Streamlit re-executes the script on every widget interaction (typing a
    # keyword, pressing Search). Remember the OCR result per upload so the
    # slow model call runs once per image rather than on every rerun.
    image_key = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
    if st.session_state.get("ocr_image_key") != image_key:
        with st.spinner("Extracting text from the image..."):
            st.session_state["ocr_text"] = extract_text(image, processor, model)
        # Record the key only after extraction succeeded, so a failure is retried.
        st.session_state["ocr_image_key"] = image_key
    extracted_text = st.session_state["ocr_text"]

    st.subheader("Extracted Text")
    st.text_area("Text from Image", extracted_text, height=300)

    st.subheader("Keyword Search")
    keywords = st.text_input("Enter keywords to search:")
    if st.button("Search"):
        if not keywords:
            st.warning("Enter a keyword before searching.")
            return
        highlighted_text = highlight_matches(extracted_text, keywords)
        st.subheader("Search Results")
        st.markdown(highlighted_text, unsafe_allow_html=True)
# Script entry point.
# NOTE(review): presumably launched with `streamlit run app.py`, which runs
# this file as __main__; confirm against the Space's start command.
if __name__ == "__main__":
    main()