File size: 3,457 Bytes
d7e12cd
250467d
9947575
fa9edbf
 
f07599d
c1eebf8
3d71449
f604f09
3d71449
f604f09
 
3d71449
 
f604f09
 
3d71449
f604f09
f07599d
3d71449
c1eebf8
f604f09
fa9edbf
3d71449
fa9edbf
 
f7e10cd
f07599d
fa9edbf
3d71449
 
f7e10cd
f604f09
fa9edbf
f7e10cd
f604f09
fa9edbf
3d71449
fa9edbf
 
 
3d71449
 
 
f7e10cd
3d71449
 
fa9edbf
 
3d71449
fa9edbf
3d71449
 
fa9edbf
3d71449
 
77415fc
3d71449
 
 
f604f09
fa9edbf
 
77415fc
3d71449
77415fc
3d71449
 
 
77415fc
3d71449
f604f09
c1eebf8
 
f604f09
3d71449
 
 
 
 
 
 
 
 
 
 
 
f7e10cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import html
import os
import re
import tempfile
import uuid

import streamlit as st
from transformers import AutoModel, AutoTokenizer
from PIL import Image

# Cache the model loading function using @st.cache_resource so each model is
# loaded once per model_name and the same objects are reused across reruns.
@st.cache_resource
def load_model(model_name):
    """Load the tokenizer and GOT OCR model for a "Select Model" option.

    Args:
        model_name: One of the labels offered by the model selectbox.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is switched to eval mode.

    Raises:
        ValueError: If *model_name* is not one of the supported labels
            (previously this surfaced as an UnboundLocalError).
    """
    if model_name == "OCR for english or hindi (runs on CPU)":
        tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
        model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
        # eval() only switches to inference mode; no device_map is given, so
        # the weights stay where from_pretrained put them.
        model.eval()
    elif model_name == "OCR for english (runs on GPU)":
        tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
        model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
        model.eval().cuda()  # inference mode, on the CUDA device
    else:
        raise ValueError(f"Unknown model option: {model_name!r}")
    return tokenizer, model

# Function to run the GOT model for multilingual OCR.
@st.cache_data(show_spinner=False)
def _cached_ocr(image, model_key, _tokenizer, _model):
    """Return the plain text GOT extracts from *image* (cached helper).

    Streamlit builds the cache key from ``image`` and ``model_key``; the
    leading underscores tell it not to hash ``_tokenizer`` / ``_model``
    (large, possibly unhashable objects). Exceptions propagate, so failed
    runs are not stored in the cache.
    """
    # JPEG uploads may be CMYK, a mode PNG cannot store.
    if image.mode == "CMYK":
        image = image.convert("RGB")
    # model.chat takes a file path: write the image to a unique temp file
    # outside the working directory and always remove it afterwards.
    fd, image_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        image.save(image_path, format="PNG")
        return _model.chat(_tokenizer, image_path, ocr_type='ocr')  # plain-text OCR
    finally:
        if os.path.exists(image_path):
            os.remove(image_path)

def run_GOT(image, tokenizer, model):
    """Run GOT OCR on a PIL *image* and return the extracted text.

    Args:
        image: PIL image to recognise.
        tokenizer: Tokenizer returned by ``load_model``.
        model: Model returned by ``load_model``.

    Returns:
        The OCR text, or ``"Error: <message>"`` if saving the image or running
        the model failed. Display of the error is left to the caller.
    """
    try:
        # Assumes *model* comes from load_model (st.cache_resource), which hands
        # back the same object on every rerun, so id(model) is a stable key
        # distinguishing the CPU and GPU models.
        return _cached_ocr(image, id(model), tokenizer, model)
    except Exception as e:  # UI boundary: report any failure as text
        return f"Error: {str(e)}"

# Function to highlight keyword in text
def highlight_keyword(text, keyword, *, case_sensitive=True):
    """Return *text* as HTML-safe markup with *keyword* wrapped in ``<mark>``.

    The result is meant for ``st.markdown(..., unsafe_allow_html=True)``, so
    both the OCR text and the user-supplied keyword are HTML-escaped; markup
    that appears in either is displayed literally instead of being rendered.

    Args:
        text: Text to search (e.g. OCR output).
        keyword: Substring to highlight; falsy means "highlight nothing".
        case_sensitive: When False, match regardless of case; the matched
            text keeps its original casing inside the highlight.

    Returns:
        The escaped text with every non-overlapping match highlighted.
    """
    if not text:
        return text
    if not keyword:
        return html.escape(text, quote=False)
    flags = 0 if case_sensitive else re.IGNORECASE
    # Split on the raw text BEFORE escaping so a keyword like "amp" cannot
    # match inside an inserted entity such as "&amp;". The capturing group
    # keeps the matches, which land at the odd indices of the result.
    parts = re.split(f"({re.escape(keyword)})", text, flags=flags)
    return "".join(
        f"<mark>{html.escape(part, quote=False)}</mark>" if i % 2 else html.escape(part, quote=False)
        for i, part in enumerate(parts)
    )

# Streamlit App
st.set_page_config(page_title="GOT-OCR Multilingual Demo", layout="wide")

# Two columns: the upload and preview on the left, model choice and results on the right.
left_col, right_col = st.columns(2)

with left_col:
    uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])

with right_col:
    # Model selection in the right column
    model_option = st.selectbox("Select Model", ["OCR for english or hindi (runs on CPU)", "OCR for english (runs on GPU)"])

if uploaded_image:
    try:
        image = Image.open(uploaded_image)
    except OSError as e:  # PIL's UnidentifiedImageError subclasses OSError
        st.error(f"Could not read image: {e}")
        st.stop()

    with left_col:
        st.image(image, caption='Uploaded Image', use_column_width=True)

    # Streamlit reruns this script on every widget interaction, and st.button
    # is True only on the rerun triggered by the click itself. Keep the OCR
    # result in session_state so it survives the rerun caused by typing a
    # keyword, and discard it when the image contents or the model change.
    ocr_key = (hash(uploaded_image.getvalue()), model_option)
    if st.session_state.get("ocr_key") != ocr_key:
        st.session_state["ocr_key"] = ocr_key
        st.session_state.pop("ocr_result", None)

    with right_col:
        if st.button("Run OCR"):
            with st.spinner("Processing..."):
                # Load the selected model (cached using @st.cache_resource)
                tokenizer, model = load_model(model_option)
                st.session_state["ocr_result"] = run_GOT(image, tokenizer, model)

        result_text = st.session_state.get("ocr_result")
        if result_text is not None:
            # run_GOT reports failures as "Error: <message>". Check only the
            # prefix so recognised text that merely contains the word "Error"
            # is still displayed.
            if result_text.startswith("Error: "):
                st.error(result_text)
            else:
                # Keyword input for search
                keyword = st.text_input("Enter a keyword to highlight")

                # Highlight keyword in the extracted text
                highlighted_text = highlight_keyword(result_text, keyword)

                # Display the extracted text
                st.markdown(highlighted_text, unsafe_allow_html=True)