File size: 4,984 Bytes
d657bf8
 
 
 
 
 
 
8551568
 
 
df3681e
 
d657bf8
 
 
 
 
 
8551568
df3681e
8551568
 
 
 
 
 
 
45b88db
8551568
d657bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45b88db
 
17130f7
d657bf8
 
45b88db
17130f7
d657bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45b88db
d657bf8
 
 
 
 
 
 
45b88db
17130f7
d657bf8
 
 
 
45b88db
17130f7
d657bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
45b88db
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import streamlit as st
from transformers import AutoModel, AutoTokenizer, MarianMTModel, MarianTokenizer
from PIL import Image
import tempfile
import os
import easyocr
import re
import torch

# Select CUDA when available; every model below must live on this device so
# tensors moved to `device` at inference time match the model's placement.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
st.write(f"Using device: {device}")

# EasyOCR reader for English + Hindi. Pass gpu= explicitly so it agrees with
# the device detected above instead of EasyOCR's own warn-and-fallback default.
reader = easyocr.Reader(['en', 'hi'], gpu=torch.cuda.is_available())  # 'en' for English, 'hi' for Hindi

# Load the GOT-OCR2 tokenizer (checkpoint ships custom code, hence trust_remote_code).
tokenizer = AutoTokenizer.from_pretrained('stepfun-ai/GOT-OCR2_0', trust_remote_code=True)

# Load the GOT-OCR2 model and move it to the selected device for inference.
model = AutoModel.from_pretrained(
    'stepfun-ai/GOT-OCR2_0',
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    use_safetensors=True,
    pad_token_id=tokenizer.eos_token_id,
)
model = model.to(device)  # Move the model to the correct device
model = model.eval()      # Inference only: disables dropout/batch-norm updates.

# MarianMT Hindi -> English translation model.
# BUGFIX: the translation inputs are moved to `device` before generate(), so
# the model must be on the same device — previously it stayed on CPU, which
# raises a device-mismatch RuntimeError whenever CUDA is available.
translation_tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-hi-en')
translation_model = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-hi-en').to(device)

# Define a function for keyword highlighting
def highlight_keywords(text, keyword):
    """Return *text* with every case-insensitive occurrence of *keyword*
    wrapped in Markdown bold markers (``**...**``).

    The keyword is regex-escaped first, so characters like ``+`` or ``.``
    are matched literally rather than as regex metacharacters.
    """
    def embolden(match):
        # Preserve the original casing of the matched span.
        return f"**{match.group(0)}**"

    return re.sub(re.escape(keyword), embolden, text, flags=re.IGNORECASE)

# --- Streamlit UI: runs top-to-bottom on every user interaction (rerun model) ---
st.title("OCR with GOT-OCR2 (English & Hindi Translation) and Keyword Search")

# File uploader restricted to common raster image formats.
image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if image_file is not None:
    # Preview the uploaded image in the page.
    image = Image.open(image_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)
    
    # Persist the upload to a real file on disk, because both GOT-OCR2's
    # chat() API and EasyOCR's readtext() take a file path, not a buffer.
    # NOTE(review): the suffix is always '.jpg' even for PNG uploads — the
    # raw bytes are written unchanged, so this only works if the readers
    # sniff content rather than trusting the extension; confirm.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
        temp_file.write(image_file.getvalue())
        temp_file_path = temp_file.name
    
    # st.button is True only on the rerun triggered by the click itself.
    if st.button("Run OCR"):
        # Plain-text OCR pass with GOT-OCR2 (no grad needed for inference).
        with torch.no_grad():
            res_plain = model.chat(tokenizer, temp_file_path, ocr_type='ocr')  # chat() handles device placement internally

        # Second pass requesting layout-preserving ("format") output.
        with torch.no_grad():
            res_format = model.chat(tokenizer, temp_file_path, ocr_type='format')

        # EasyOCR pass for mixed English/Hindi text; detail=0 returns
        # just the recognized strings (no bounding boxes/confidences).
        result_easyocr = reader.readtext(temp_file_path, detail=0)

        # Render all three OCR results.
        st.subheader("Plain Text OCR Results (English):")
        st.write(res_plain)

        st.subheader("Formatted Text OCR Results:")
        st.write(res_format)

        st.subheader("Detected Text using EasyOCR (English and Hindi):")
        extracted_text = " ".join(result_easyocr)  # Combine the list of text results
        st.write(extracted_text)

        # Translate each EasyOCR line Hindi -> English with MarianMT.
        # NOTE(review): English lines are fed through the hi->en model too,
        # which may produce odd output — no language detection is done here.
        st.subheader("Translated Hindi Text to English:")
        translated_text = []
        for sentence in result_easyocr:
            if sentence:  # Skip empty strings; everything else is translated.
                tokenized_text = translation_tokenizer([sentence], return_tensors="pt", truncation=True)
                # Inputs are moved to `device` — ensure translation_model is
                # on the same device, or generate() will raise on CUDA.
                tokenized_text = {key: val.to(device) for key, val in tokenized_text.items()}  # Move tensors to device
                translation = translation_model.generate(**tokenized_text)
                translated_sentence = translation_tokenizer.decode(translation[0], skip_special_tokens=True)
                translated_text.append(translated_sentence)
        
        st.write(" ".join(translated_text))

        # Extra GOT-OCR2 pass: fine-grained OCR (empty ocr_box = whole image).
        with torch.no_grad():
            res_fine_grained = model.chat(tokenizer, temp_file_path, ocr_type='ocr', ocr_box='')
        st.subheader("Fine-Grained OCR Results:")
        st.write(res_fine_grained)

        # Formatted OCR rendered to an HTML file on disk (./demo.html).
        with torch.no_grad():
            res_render = model.chat(tokenizer, temp_file_path, ocr_type='format', render=True, save_render_file='./demo.html')
        st.subheader("Rendered OCR Results (HTML):")
        st.write(res_render)

        # Keyword search over the EasyOCR text.
        # NOTE(review): BUG — this text_input lives inside the
        # `if st.button(...)` branch. Typing a keyword triggers a rerun in
        # which the button returns False, so this branch (and the search
        # results) never renders after the user types. Persist the OCR
        # output in st.session_state and move the search outside the button
        # branch to make it usable.
        keyword = st.text_input("Enter keyword to search in extracted text:")

        if keyword:
            st.subheader("Search Results:")
            # Bold every case-insensitive match via Markdown.
            highlighted_text = highlight_keywords(extracted_text, keyword)
            st.markdown(highlighted_text)

        # Remove the temp image now that all OCR passes are done.
        os.remove(temp_file_path)

# No __main__ guard needed: Streamlit executes this script directly on each rerun.