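# Streamlit demo for GOT-OCR2.0: upload an image, choose a CPU (English/Hindi)
# or GPU (English) checkpoint, run OCR, and optionally highlight a keyword in
# the extracted text.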
import io
import os
import uuid

import streamlit as st
from transformers import AutoModel, AutoTokenizer
from PIL import Image

# Cache the model loading function using @st.cache_resource
@st.cache_resource
def load_model(model_name):
    if model_name == "OCR for English or Hindi (runs on CPU)":
        tokenizer = AutoTokenizer.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True)
        model = AutoModel.from_pretrained('srimanth-d/GOT_CPU', trust_remote_code=True, use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
        model.eval()  # Set to evaluation mode (inference runs on the CPU)
    elif model_name == "OCR for English (runs on GPU)":
        tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
        model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda', use_safetensors=True, pad_token_id=tokenizer.eos_token_id)
        model.eval().cuda()  # Set to evaluation mode and make sure the model is on the GPU
    return tokenizer, model

# Function to run the GOT model for multilingual OCR.
# Arguments prefixed with "_" are skipped by Streamlit's hasher, so the raw
# image bytes act as the cache key and each distinct upload is re-processed.
@st.cache_data
def run_GOT(image_bytes, _tokenizer, _model):
    unique_id = str(uuid.uuid4())
    image_path = f"{unique_id}.png"

    # model.chat() expects a file path, so write the uploaded bytes to a
    # temporary image file
    Image.open(io.BytesIO(image_bytes)).save(image_path)

    try:
        # Use the model to extract plain text
        res = _model.chat(_tokenizer, image_path, ocr_type='ocr')
        return res
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        # Clean up the saved image
        if os.path.exists(image_path):
            os.remove(image_path)

# Function to highlight keyword in text
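# e.g. highlight_keyword("GOT OCR demo", "OCR") -> "GOT <mark>OCR</mark> demo"
# (exact, case-sensitive substring replacement)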
def highlight_keyword(text, keyword):
    if keyword:
        highlighted_text = text.replace(keyword, f"<mark>{keyword}</mark>")
        return highlighted_text
    return text

# Streamlit App
st.set_page_config(page_title="GOT-OCR Multilingual Demo", layout="wide")

# Creating two columns
left_col, right_col = st.columns(2)

with left_col:
    uploaded_image = st.file_uploader("Upload your image", type=["png", "jpg", "jpeg"])

with right_col:
    # Model selection in the right column
    model_option = st.selectbox("Select Model", ["OCR for English or Hindi (runs on CPU)", "OCR for English (runs on GPU)"])

if uploaded_image:
    image = Image.open(uploaded_image)
    
    with left_col:
        st.image(image, caption='Uploaded Image', use_column_width=True)
    
    with right_col:
        if st.button("Run OCR"):
            with st.spinner("Processing..."):
                # Load the selected model (cached using @st.cache_resource)
                tokenizer, model = load_model(model_option)

                # Run OCR (cached via @st.cache_data) and keep the result in
                # session_state so it survives the rerun triggered by the
                # keyword input below
                st.session_state["ocr_result"] = run_GOT(uploaded_image.getvalue(), tokenizer, model)

        result_text = st.session_state.get("ocr_result")
        if result_text:
            if "Error" not in result_text:
                # Keyword input for search
                keyword = st.text_input("Enter a keyword to highlight")

                # Highlight keyword in the extracted text
                highlighted_text = highlight_keyword(result_text, keyword)

                # Display the extracted text
                st.markdown(highlighted_text, unsafe_allow_html=True)
            else:
                st.error(result_text)
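
# To launch the demo locally (assuming this script is saved as app.py):
#   streamlit run app.py
# The GPU model option additionally requires a CUDA-capable device.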