File size: 5,644 Bytes
d657bf8
 
 
 
 
 
 
8551568
 
 
df3681e
 
d657bf8
 
 
 
 
 
8551568
6372edc
8551568
 
 
 
 
 
 
6372edc
8551568
d657bf8
6372edc
01d6880
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b1ac83
6372edc
 
 
 
d657bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45b88db
 
2b1ac83
d657bf8
 
45b88db
8976d30
d657bf8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45b88db
d657bf8
 
 
 
 
 
 
45b88db
8976d30
d657bf8
 
 
 
45b88db
8976d30
d657bf8
 
 
 
 
 
 
 
 
 
 
 
01d6880
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import streamlit as st
from transformers import AutoModel, AutoTokenizer, MarianMTModel, MarianTokenizer
from PIL import Image
import tempfile
import os
import easyocr
import re
import torch

# Check if GPU is available, else default to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
st.write(f"Using device: {device}")


# Streamlit re-executes this whole script on every widget interaction, so the
# heavyweight objects below are built inside st.cache_resource loaders: they
# are constructed once per server process and reused on later reruns.
@st.cache_resource
def _load_easyocr_reader():
    """Return an EasyOCR reader for English ('en') and Hindi ('hi')."""
    return easyocr.Reader(['en', 'hi'])


@st.cache_resource
def _load_got_ocr():
    """Load the GOT-OCR2 tokenizer and model and place the model on ``device``.

    Returns:
        A ``(tokenizer, model)`` tuple; the model is on ``device`` and in
        eval mode.
    """
    # NOTE(review): trust_remote_code=True lets transformers run Python code
    # shipped in the 'stepfun-ai/GOT-OCR2_0' repository -- confirm that this
    # is acceptable, and consider pinning a revision.
    got_tokenizer = AutoTokenizer.from_pretrained(
        'stepfun-ai/GOT-OCR2_0',
        trust_remote_code=True,
    )
    got_model = AutoModel.from_pretrained(
        'stepfun-ai/GOT-OCR2_0',
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        pad_token_id=got_tokenizer.eos_token_id,
    )
    got_model = got_model.to(device)  # Move model to appropriate device
    return got_tokenizer, got_model.eval()


# Load EasyOCR reader with English and Hindi language support
reader = _load_easyocr_reader()

# Load the GOT-OCR2 model and tokenizer
tokenizer, model = _load_got_ocr()

# Override the chat function to remove hardcoded .cuda()
def modified_chat(tokenizer, temp_file_path, ocr_type='ocr', *args, **kwargs):
    """Stand-in for the model's ``chat`` method that never calls ``.cuda()``.

    NOTE: this is a stub. It does not run OCR on the image; it tokenizes a
    fixed placeholder string and reports the resulting token ids.

    Args:
        tokenizer: Callable tokenizer returning a mapping with an
            ``'input_ids'`` tensor when called with ``return_tensors="pt"``.
        temp_file_path: Path of the image file. The file is opened and read,
            so a missing or unreadable path raises, but the bytes are unused.
        ocr_type: Label echoed back in the returned string.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored (e.g. ``ocr_box``, ``render``).

    Returns:
        A string of the form ``"Processed input: <input_ids>, OCR Type: <ocr_type>"``.
    """
    # Read the whole file; the contents are discarded (placeholder behaviour).
    with open(temp_file_path, 'rb') as image_stream:
        image_stream.read()

    # Placeholder, replace with actual OCR result
    placeholder_text = "some OCR processed text"

    encoded = tokenizer(
        placeholder_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )

    # Use the module-level device instead of a hardcoded .cuda()
    token_ids = encoded['input_ids'].to(device)

    return f"Processed input: {token_ids}, OCR Type: {ocr_type}"

# Replace the model's chat method with the modified version
model.chat = modified_chat


# Built inside st.cache_resource so the weights are loaded once per server
# process instead of on every Streamlit rerun.
@st.cache_resource
def _load_translation_model():
    """Load the MarianMT Hindi-to-English tokenizer and model.

    The model is moved to ``device`` because the translation step moves its
    input tensors there; leaving the model on CPU while inputs sit on CUDA
    makes ``generate`` fail with a device-mismatch error.

    Returns:
        A ``(tokenizer, model)`` tuple with the model placed on ``device``.
    """
    hi_en_tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-hi-en')
    hi_en_model = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-hi-en')
    return hi_en_tokenizer, hi_en_model.to(device)


# Load MarianMT translation model for Hindi to English translation
translation_tokenizer, translation_model = _load_translation_model()

# Define a function for keyword highlighting
def highlight_keywords(text, keyword):
    """Wrap each case-insensitive occurrence of ``keyword`` in Markdown bold.

    Args:
        text: The text to search.
        keyword: Literal string to look for; regex metacharacters in it are
            escaped, so it is matched verbatim.

    Returns:
        ``text`` with every match surrounded by ``**``, preserving the
        original casing of the matched characters. If ``keyword`` is empty,
        ``text`` is returned unchanged.
    """
    # An empty pattern matches at every position and would insert "****"
    # between all characters, so treat it as "nothing to highlight".
    if not keyword:
        return text
    pattern = re.compile(re.escape(keyword), re.IGNORECASE)
    return pattern.sub(lambda match: f"**{match.group(0)}**", text)

# Streamlit App Title
st.title("OCR with GOT-OCR2 (English & Hindi Translation) and Keyword Search")


def _run_ocr_pipeline(image_path):
    """Run every OCR pass and the Hindi-to-English translation on one image.

    Args:
        image_path: Filesystem path of the image to process.

    Returns:
        A dict with the keys ``plain``, ``format``, ``fine_grained`` and
        ``render`` (the values returned by ``model.chat`` for each mode),
        ``extracted_text`` (EasyOCR fragments joined with spaces) and
        ``translated_text`` (MarianMT output per fragment, joined with spaces).
    """
    with torch.no_grad():
        # Plain text OCR (structured documents)
        res_plain = model.chat(tokenizer, image_path, ocr_type='ocr')
        # Formatted text OCR
        res_format = model.chat(tokenizer, image_path, ocr_type='format')

    # EasyOCR for both English and Hindi; detail=0 is passed as in the
    # original call and the result is consumed as a list of strings.
    result_easyocr = reader.readtext(image_path, detail=0)

    translated_text = []
    for sentence in result_easyocr:
        if not sentence:  # Only non-empty fragments are translated
            continue
        tokenized_text = translation_tokenizer([sentence], return_tensors="pt", truncation=True)
        # Follow the translation model's own device so inputs and weights
        # can never end up on different devices.
        tokenized_text = {
            key: val.to(translation_model.device)
            for key, val in tokenized_text.items()
        }
        translation = translation_model.generate(**tokenized_text)
        translated_text.append(
            translation_tokenizer.decode(translation[0], skip_special_tokens=True)
        )

    with torch.no_grad():
        # Additional OCR types using GOT-OCR2
        res_fine_grained = model.chat(tokenizer, image_path, ocr_type='ocr', ocr_box='')
        # Render formatted OCR to HTML
        res_render = model.chat(
            tokenizer, image_path, ocr_type='format',
            render=True, save_render_file='./demo.html',
        )

    return {
        "plain": res_plain,
        "format": res_format,
        "fine_grained": res_fine_grained,
        "render": res_render,
        "extracted_text": " ".join(result_easyocr),
        "translated_text": " ".join(translated_text),
    }


# File uploader for image input
image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if image_file is not None:
    # Display the uploaded image
    image = Image.open(image_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Identifies the current upload so results computed for a previous image
    # are not shown next to a new one.
    upload_key = (image_file.name, image_file.size)

    # Button to run OCR
    if st.button("Run OCR"):
        # The temp file is created only when OCR actually runs (not on every
        # rerun) and keeps the upload's real extension.
        suffix = os.path.splitext(image_file.name)[1] or '.jpg'
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            temp_file.write(image_file.getvalue())
            temp_file_path = temp_file.name
        try:
            ocr_results = _run_ocr_pipeline(temp_file_path)
        finally:
            # Clean up the temporary file even if a model call raised
            os.remove(temp_file_path)

        # st.button is True only on the rerun triggered by the click, so the
        # results are kept in session state to survive later reruns (e.g. the
        # one caused by typing a search keyword).
        st.session_state["ocr_upload_key"] = upload_key
        st.session_state["ocr_results"] = ocr_results

    ocr_results = st.session_state.get("ocr_results")
    if ocr_results is not None and st.session_state.get("ocr_upload_key") == upload_key:
        # Display the results
        st.subheader("Plain Text OCR Results (English):")
        st.write(ocr_results["plain"])

        st.subheader("Formatted Text OCR Results:")
        st.write(ocr_results["format"])

        st.subheader("Detected Text using EasyOCR (English and Hindi):")
        extracted_text = ocr_results["extracted_text"]
        st.write(extracted_text)

        st.subheader("Translated Hindi Text to English:")
        st.write(ocr_results["translated_text"])

        st.subheader("Fine-Grained OCR Results:")
        st.write(ocr_results["fine_grained"])

        st.subheader("Rendered OCR Results (HTML):")
        st.write(ocr_results["render"])

        # Search functionality
        keyword = st.text_input("Enter keyword to search in extracted text:")

        if keyword:
            st.subheader("Search Results:")
            highlighted_text = highlight_keywords(extracted_text, keyword)
            st.markdown(highlighted_text)