Spaces:
Runtime error
Runtime error
File size: 5,644 Bytes
d657bf8 8551568 df3681e d657bf8 8551568 6372edc 8551568 6372edc 8551568 d657bf8 6372edc 01d6880 2b1ac83 6372edc d657bf8 45b88db 2b1ac83 d657bf8 45b88db 8976d30 d657bf8 45b88db d657bf8 45b88db 8976d30 d657bf8 45b88db 8976d30 d657bf8 01d6880 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
import streamlit as st
from transformers import AutoModel, AutoTokenizer, MarianMTModel, MarianTokenizer
from PIL import Image
import tempfile
import os
import easyocr
import re
import torch
# Check if GPU is available, else default to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
st.write(f"Using device: {device}")


# Streamlit re-executes this script on every widget interaction. The loaders
# below are wrapped in st.cache_resource so the heavy OCR models are built
# once and reused instead of being re-created on each rerun.
@st.cache_resource
def _load_easyocr_reader():
    """Create the EasyOCR reader for English ('en') and Hindi ('hi')."""
    return easyocr.Reader(['en', 'hi'])


@st.cache_resource
def _load_got_ocr():
    """Load the GOT-OCR2 tokenizer and model.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has been moved to the
        module-level ``device`` and switched to evaluation mode.
    """
    got_tokenizer = AutoTokenizer.from_pretrained(
        'stepfun-ai/GOT-OCR2_0',
        trust_remote_code=True,
    )
    got_model = AutoModel.from_pretrained(
        'stepfun-ai/GOT-OCR2_0',
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        use_safetensors=True,
        pad_token_id=got_tokenizer.eos_token_id,
    )
    # Move model to the selected device and disable training-only behaviour
    got_model = got_model.to(device)
    got_model = got_model.eval()
    return got_tokenizer, got_model


# Load EasyOCR reader with English and Hindi language support
reader = _load_easyocr_reader()

# Load the GOT-OCR2 model and tokenizer
tokenizer, model = _load_got_ocr()
# Stand-in for the model's chat method, written to avoid the hardcoded .cuda()
def modified_chat(tokenizer, temp_file_path, ocr_type='ocr', *args, **kwargs):
    """Placeholder replacement for ``model.chat`` that targets ``device``.

    NOTE: this function does not perform OCR. The image file is read and the
    bytes are discarded; a fixed placeholder string is tokenized instead, and
    the model is never invoked.

    Args:
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors="pt", truncation=True, padding=True)``;
            its result is indexed with ``'input_ids'``.
        temp_file_path: Path of the image file; opened in binary mode and
            read in full.
        ocr_type: Label that is echoed back in the returned string.
        *args: Accepted and ignored.
        **kwargs: Accepted and ignored.

    Returns:
        A string of the form ``"Processed input: <input_ids>, OCR Type: <ocr_type>"``.
    """
    # Read the whole image file; the bytes are not used by this placeholder
    with open(temp_file_path, 'rb') as image_handle:
        image_handle.read()

    # Placeholder, replace with actual OCR result
    placeholder_text = "some OCR processed text"

    # Tokenize the placeholder and move the ids to the module-level device
    encoded = tokenizer(
        placeholder_text,
        return_tensors="pt",
        truncation=True,
        padding=True,
    )
    token_ids = encoded['input_ids'].to(device)

    return f"Processed input: {token_ids}, OCR Type: {ocr_type}"


# Replace the model's chat method with the modified version
model.chat = modified_chat
# Load MarianMT translation model for Hindi to English translation
@st.cache_resource
def _load_translation_pipeline():
    """Load the Helsinki-NLP Hindi-to-English MarianMT tokenizer and model.

    Wrapped in st.cache_resource so the weights are loaded once rather than
    on every Streamlit rerun. The model is moved to the module-level
    ``device`` because the translation inputs are moved there before
    ``generate`` is called; leaving the model on the CPU would produce a
    device mismatch whenever ``device`` is CUDA.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    mt_tokenizer = MarianTokenizer.from_pretrained('Helsinki-NLP/opus-mt-hi-en')
    mt_model = MarianMTModel.from_pretrained('Helsinki-NLP/opus-mt-hi-en')
    return mt_tokenizer, mt_model.to(device)


translation_tokenizer, translation_model = _load_translation_pipeline()
# Define a function for keyword highlighting
def highlight_keywords(text: str, keyword: str) -> str:
    """Wrap every case-insensitive occurrence of ``keyword`` in ``**``.

    The keyword is matched literally (regex metacharacters are escaped) and
    each match keeps its original casing inside the ``**...**`` markers.

    Args:
        text: The text to search.
        keyword: The literal string to highlight.

    Returns:
        ``text`` with each match surrounded by ``**``. If ``keyword`` is
        empty, ``text`` is returned unchanged.
    """
    # An empty pattern matches at every position, which would insert "****"
    # between all characters; treat it as "nothing to highlight".
    if not keyword:
        return text
    pattern = re.compile(re.escape(keyword), re.IGNORECASE)
    return pattern.sub(lambda match: f"**{match.group(0)}**", text)
# Streamlit App Title
st.title("OCR with GOT-OCR2 (English & Hindi Translation) and Keyword Search")


def _translate_segments(segments):
    """Translate each non-empty text segment and join the results with spaces."""
    translated = []
    for sentence in segments:
        if not sentence:
            continue
        batch = translation_tokenizer([sentence], return_tensors="pt", truncation=True)
        # Inputs must sit on the same device as the translation model itself,
        # so follow the model rather than the global device selection.
        batch = {key: val.to(translation_model.device) for key, val in batch.items()}
        output_ids = translation_model.generate(**batch)
        translated.append(
            translation_tokenizer.decode(output_ids[0], skip_special_tokens=True)
        )
    return " ".join(translated)


def _run_ocr_pipeline(image_path):
    """Run every OCR and translation step on ``image_path``.

    Returns:
        A dict with the keys ``plain``, ``formatted``, ``easyocr_text``,
        ``translated``, ``fine_grained`` and ``rendered``.
    """
    # Use GOT-OCR2 model for plain text OCR (structured documents)
    with torch.no_grad():
        res_plain = model.chat(tokenizer, image_path, ocr_type='ocr')

    # Perform formatted text OCR
    with torch.no_grad():
        res_format = model.chat(tokenizer, image_path, ocr_type='format')

    # Use EasyOCR for both English and Hindi text recognition
    result_easyocr = reader.readtext(image_path, detail=0)

    # Translate the recognised segments with the Hindi-to-English model
    translated = _translate_segments(result_easyocr)

    # Additional OCR types using GOT-OCR2
    with torch.no_grad():
        res_fine_grained = model.chat(tokenizer, image_path, ocr_type='ocr', ocr_box='')

    # Render formatted OCR to HTML
    with torch.no_grad():
        res_render = model.chat(
            tokenizer,
            image_path,
            ocr_type='format',
            render=True,
            save_render_file='./demo.html',
        )

    return {
        "plain": res_plain,
        "formatted": res_format,
        "easyocr_text": " ".join(result_easyocr),  # Combine the list of text results
        "translated": translated,
        "fine_grained": res_fine_grained,
        "rendered": res_render,
    }


def _render_results(results):
    """Display stored OCR results followed by the keyword search box."""
    st.subheader("Plain Text OCR Results (English):")
    st.write(results["plain"])

    st.subheader("Formatted Text OCR Results:")
    st.write(results["formatted"])

    st.subheader("Detected Text using EasyOCR (English and Hindi):")
    st.write(results["easyocr_text"])

    st.subheader("Translated Hindi Text to English:")
    st.write(results["translated"])

    st.subheader("Fine-Grained OCR Results:")
    st.write(results["fine_grained"])

    st.subheader("Rendered OCR Results (HTML):")
    st.write(results["rendered"])

    # Search functionality
    keyword = st.text_input("Enter keyword to search in extracted text:")
    if keyword:
        st.subheader("Search Results:")
        st.markdown(highlight_keywords(results["easyocr_text"], keyword))


# File uploader for image input
image_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if image_file is not None:
    # Display the uploaded image
    image = Image.open(image_file)
    st.image(image, caption='Uploaded Image', use_column_width=True)

    # Identifies the current upload so results from a previous image are not
    # shown next to a different one.
    upload_id = (image_file.name, image_file.size)

    # Button to run OCR
    if st.button("Run OCR"):
        # The temporary copy is created only when OCR actually runs and is
        # always removed afterwards, even if one of the OCR steps raises.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
            temp_file.write(image_file.getvalue())
            temp_file_path = temp_file.name
        try:
            ocr_results = _run_ocr_pipeline(temp_file_path)
        finally:
            # Clean up the temporary file after use
            os.remove(temp_file_path)

        ocr_results["upload_id"] = upload_id
        st.session_state["ocr_results"] = ocr_results

    # st.button is True only on the rerun caused by the click. Typing in the
    # keyword box triggers another rerun, so the results are kept in
    # st.session_state and rendered from there; otherwise they (and the
    # search box) would vanish as soon as a keyword is entered.
    saved_results = st.session_state.get("ocr_results")
    if saved_results is not None and saved_results["upload_id"] == upload_id:
        _render_results(saved_results)