File size: 4,104 Bytes
2d8087a
b5ef879
fa9edbf
b5ef879
 
fa9edbf
f07599d
b5ef879
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d8087a
b5ef879
 
 
 
 
 
f07599d
3d71449
b5ef879
fa9edbf
3d71449
fa9edbf
b5ef879
 
f07599d
fa9edbf
b5ef879
 
 
 
 
 
fa9edbf
f604f09
fa9edbf
3d71449
fa9edbf
 
 
3d71449
b5ef879
 
 
 
 
fa9edbf
 
b5ef879
 
fa9edbf
2d8087a
 
 
 
 
 
 
 
 
f604f09
2d8087a
 
 
77415fc
b5ef879
 
 
2d8087a
 
 
 
 
 
b5ef879
2d8087a
b5ef879
2d8087a
 
 
b5ef879
2d8087a
 
 
 
b5ef879
2d8087a
b5ef879
2d8087a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from transformers import AutoModel, AutoTokenizer
import streamlit as st
from PIL import Image
import re
import os
import uuid

# Load the model and tokenizer only once per Streamlit session.
if "model" not in st.session_state or "tokenizer" not in st.session_state:
    @st.cache_resource
    def load_model(model_name):
        """Load the GOT OCR model and tokenizer for the given option.

        Cached with st.cache_resource so reruns reuse the same weights
        instead of re-downloading them.

        Args:
            model_name: one of "OCR for English or Hindi (CPU)" or
                "OCR for English (GPU)".

        Returns:
            (model, tokenizer) tuple, with the model in eval mode.

        Raises:
            ValueError: for an unrecognized model_name (previously this
                crashed with UnboundLocalError at the return statement).
        """
        if model_name == "OCR for English or Hindi (CPU)":
            tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
            model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
            model = model.eval()
        elif model_name == "OCR for English (GPU)":
            tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
            model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
            model = model.eval().to('cuda')
        else:
            raise ValueError(f"Unknown model option: {model_name!r}")
        return model, tokenizer

    # Load once with the default option and stash in session state.
    # NOTE(review): the model-selection widget rendered later never triggers
    # a reload, so this default model serves every OCR request — confirm
    # that is the intended behavior.
    model_option = "OCR for English or Hindi (CPU)"  # Default value for loading purposes
    model, tokenizer = load_model(model_option)
    st.session_state["model"] = model
    st.session_state["tokenizer"] = tokenizer
else:
    model = st.session_state["model"]
    tokenizer = st.session_state["tokenizer"]

# Function to run the GOT model for multilingual OCR
def run_ocr(image, model, tokenizer):
    """Extract text from a PIL image with the GOT model.

    The model's ``chat`` API takes a file path, so the image is written to
    a uniquely named temporary PNG in the working directory and removed
    afterwards, whatever the outcome.

    Args:
        image: PIL.Image (anything with a ``save(path)`` method).
        model: loaded GOT model exposing ``chat(tokenizer, path, ocr_type)``.
        tokenizer: tokenizer paired with the model.

    Returns:
        The extracted text, or a string starting with "Error: " on failure
        (the caller inspects the result for "Error" — keep that contract).
    """
    # Unique filename avoids collisions between concurrent sessions.
    image_path = f"{uuid.uuid4()}.png"

    # Save image to disk so the model can read it by path.
    image.save(image_path)

    try:
        res = model.chat(tokenizer, image_path, ocr_type='ocr')
        # str() is a no-op for str inputs, so one return covers both the
        # string and non-string cases (the original had a redundant branch).
        return str(res)
    except Exception as e:
        # Broad catch is deliberate: the UI renders any failure as an
        # error message rather than crashing the app.
        return f"Error: {str(e)}"
    finally:
        # Clean up the saved image on success and failure alike.
        if os.path.exists(image_path):
            os.remove(image_path)

# Function to highlight keyword in text
def highlight_text(text, search_term):
    """Wrap each case-insensitive occurrence of *search_term* in *text*
    with a yellow-highlight <span>, preserving the matched text's original
    casing. An empty or None search term returns *text* unchanged.
    """
    if not search_term:
        return text

    def _mark(match):
        return f'<span style="background-color: yellow;">{match.group()}</span>'

    return re.sub(re.escape(search_term), _mark, text, flags=re.IGNORECASE)

# Streamlit App
# Page chrome: title and one-line instruction.
st.title("GOT-OCR Multilingual Demo")
st.write("Upload an image for OCR")

# Create two columns
col1, col2 = st.columns(2)

# Left column - Display the uploaded image
with col1:
    uploaded_image = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg"])
    if uploaded_image:
        image = Image.open(uploaded_image)
        st.image(image, caption='Uploaded Image', use_column_width=True)

# Right column - Model selection, options, and displaying extracted text
with col2:
    # NOTE(review): this selection is never passed to load_model after the
    # initial session bootstrap — the session-cached default model handles
    # every OCR run regardless of the choice here. Confirm intended.
    model_option = st.selectbox("Select Model", ["OCR for English or Hindi (CPU)", "OCR for English (GPU)"])
    
    if st.button("Run OCR"):
        with st.spinner("Processing..."):
            # Run OCR and store the result in session state
            if uploaded_image:
                # `image` is assigned in col1 whenever uploaded_image is
                # truthy, so this cross-column reference is safe within a
                # single Streamlit rerun.
                result_text = run_ocr(image, model, tokenizer)
                # NOTE(review): substring test — OCR output that merely
                # contains the word "Error" would be misclassified as a
                # failure. Consider a structured error signal.
                if "Error" not in result_text:
                    st.session_state["extracted_text"] = result_text  # Store the result in session state
                else:
                    st.error(result_text)
            else:
                st.error("Please upload an image before running OCR.")

    # Display the extracted text if it exists in session state
    # (session state survives the rerun triggered by the text_input below,
    # so the result persists while the user types a search term).
    if "extracted_text" in st.session_state:
        extracted_text = st.session_state["extracted_text"]

        # Keyword input for search
        search_term = st.text_input("Enter a word or phrase to highlight:")
        
        # Highlight keyword in the extracted text
        highlighted_text = highlight_text(extracted_text, search_term)
        
        # Display the highlighted text using markdown
        # pre-wrap preserves the OCR text's line breaks; unsafe_allow_html
        # is required to render the <span> highlight markup.
        st.subheader("Extracted Text:")
        st.markdown(f'<div style="white-space: pre-wrap;">{highlighted_text}</div>', unsafe_allow_html=True)