# Hugging Face Spaces status: Running
import streamlit as st | |
from transformers import AutoProcessor, AutoModelForImageTextToText | |
from PIL import Image | |
import torch | |
import time # To simulate progress bar updates | |
# Load model and processor
@st.cache_resource
def load_model(model_id="HuggingFaceTB/SmolVLM-Instruct"):
    """Load the SmolVLM processor and model, cached across Streamlit reruns.

    Streamlit re-executes the whole script on every interaction, and
    ``main()`` calls this function on each run; ``st.cache_resource`` keeps a
    single loaded copy instead of reloading the checkpoint every time.

    Args:
        model_id: Hugging Face Hub repo id (or local path) of the checkpoint.
            Defaults to the SmolVLM-Instruct model the app was built for.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = AutoProcessor.from_pretrained(model_id)
    model = AutoModelForImageTextToText.from_pretrained(model_id)
    return processor, model
# Function to preprocess image and handle model execution
def extract_text(
    image,
    processor,
    model,
    prompt="Extract all the text visible in this image. Return only the text.",
    max_new_tokens=500,
):
    """Ask the vision-language model to transcribe the text in *image*.

    Progress is reported through a Streamlit progress bar as each stage
    finishes.

    Args:
        image: PIL image to read (``main`` passes an RGB-converted upload).
        processor: Processor returned by ``load_model``.
        model: Model returned by ``load_model``.
        prompt: Instruction sent to the model together with the image.
        max_new_tokens: Upper bound on the number of generated tokens; without
            it the model's default generation length can cut the text short.

    Returns:
        The model's answer as a string, with the prompt removed.

    Raises:
        RuntimeError: If preprocessing, generation or decoding fails. The
            underlying exception is chained as ``__cause__``.
    """
    progress_bar = st.progress(0)
    try:
        # The model expects a chat-formatted prompt whose image placeholder is
        # expanded into the image tokens that the pixel values fill in;
        # passing pixel_values alone gives it nothing to attach them to.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": prompt},
                ],
            }
        ]
        chat_prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
        progress_bar.progress(20)  # Step 1: prompt built

        # Let the processor do its own resizing; forcing 224x224 beforehand
        # distorts the aspect ratio and makes small text unreadable.
        inputs = processor(text=chat_prompt, images=[image], return_tensors="pt")
        inputs = inputs.to(model.device)

        # Validate before anything dereferences the tensor.
        pixel_values = inputs.get("pixel_values")
        if pixel_values is None or pixel_values.numel() == 0:
            raise ValueError("Preprocessing failed: Empty tensor generated for image.")
        progress_bar.progress(50)  # Step 2: image preprocessed

        with torch.no_grad():
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
        progress_bar.progress(80)  # Step 3: model processing

        # generate() returns prompt + answer; decode only the new tokens.
        prompt_length = inputs["input_ids"].shape[1]
        result = processor.batch_decode(
            outputs[:, prompt_length:], skip_special_tokens=True
        )[0]
        progress_bar.progress(100)  # Step 4: completed
        return result.strip()
    except Exception as e:
        raise RuntimeError(f"Error during text extraction: {str(e)}") from e
# Streamlit UI
def main():
    """Render the OCR page: upload an image, preview it, show extracted text.

    Errors while opening the image or running the model are shown on the
    page with ``st.error`` rather than propagating.
    """
    st.title("🖼️ OCR App using SmolVLM-Instruct")
    st.write("Upload an image, and I will extract the text for you!")

    # Load the model and processor
    processor, model = load_model()

    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        # Convert to RGB so every upload reaches the model as a 3-channel image.
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Extract text with progress bar
        with st.spinner("Extracting text... Please wait!"):
            extracted_text = extract_text(image, processor, model)

        st.subheader("📝 Extracted Text:")
        # st.text shows the string verbatim; st.write would render it as
        # Markdown and mangle OCR output containing "#", "*", "$", etc.
        st.text(extracted_text)
    except Exception as e:
        # Top-level UI boundary: report the failure instead of crashing.
        st.error(f"An error occurred: {str(e)}")
# Build the UI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()