# smol_vlm_ocr / app.py
# Akshayram1's picture
# Update app.py
# ef9bfa1 verified
import streamlit as st
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import torch
import time # To simulate progress bar updates
# Load model and processor
@st.cache_resource
def load_model():
    """Load the SmolVLM-Instruct processor and model.

    Both objects are fetched with ``from_pretrained`` from the same
    checkpoint id.  The function is wrapped in ``st.cache_resource`` so
    Streamlit can reuse the loaded pair instead of reloading it each time
    the script is re-executed.

    Returns:
        A ``(processor, model)`` tuple.
    """
    checkpoint = "HuggingFaceTB/SmolVLM-Instruct"
    # The processor is loaded first, then the model; both come from one checkpoint.
    return (
        AutoProcessor.from_pretrained(checkpoint),
        AutoModelForImageTextToText.from_pretrained(checkpoint),
    )
# Function to preprocess image and handle model execution
def extract_text(image, processor, model):
    """Run the model on one image and return the first decoded string.

    Steps, each followed by a progress-bar update:
      1. resize ``image`` to 224x224,
      2. turn it into a ``pixel_values`` tensor with ``processor``,
      3. call ``model.generate`` on that tensor under ``torch.no_grad()``,
      4. decode the generated ids with ``processor.batch_decode``.

    Args:
        image: Image object exposing ``resize((width, height))``; ``main`` in
            this file passes an RGB ``PIL.Image``.
        processor: Callable accepting ``images=``, ``return_tensors=`` and
            ``do_resize=``, whose result supports ``.to(device)`` and
            ``.get("pixel_values")``; must also provide ``batch_decode``.
        model: Object exposing ``generate(pixel_values=...)``.

    Returns:
        The first element of ``processor.batch_decode(...)``.

    Raises:
        RuntimeError: If any step fails.  The triggering exception is chained
            as ``__cause__`` so the original traceback is preserved.
    """
    empty_msg = "Preprocessing failed: Empty tensor generated for image."
    # NOTE(review): this gate on the first two dimensions is hard-coded.
    # Confirm it matches what the processor really emits for a 224x224 input
    # with do_resize=False -- if it does not, every call is rejected here.
    expected_leading_dims = (81, 2048)

    # Initialize progress bar
    progress_bar = st.progress(0)
    time.sleep(0.5)  # short pause so each simulated progress step is visible

    try:
        # Resize the image to fixed dimensions
        required_size = (224, 224)  # Explicit resizing for model input
        image_resized = image.resize(required_size)
        progress_bar.progress(20)  # Step 1: Image resized
        time.sleep(0.5)

        # Preprocess image (extract pixel values); resizing was done above,
        # so the processor is told not to resize again.
        inputs = processor(
            images=image_resized, return_tensors="pt", do_resize=False
        ).to("cpu")
        pixel_values = inputs.get("pixel_values")

        # Check for a missing tensor BEFORE touching ``.shape``; otherwise a
        # None value surfaces as an AttributeError instead of this message.
        if pixel_values is None:
            raise ValueError(empty_msg)

        # Debugging: show the pixel_values tensor shape in the page
        st.write(f"Pixel Values Shape: {pixel_values.shape}")

        if pixel_values.shape[0] == 0:
            raise ValueError(empty_msg)

        # Comparing a slice (rather than indexing shape[1]) also rejects
        # tensors with fewer than two dimensions with the intended error.
        if tuple(pixel_values.shape[:2]) != expected_leading_dims:
            raise ValueError(
                f"Unexpected tensor shape: {pixel_values.shape}. "
                f"Expected leading dimensions: {list(expected_leading_dims)}."
            )
        progress_bar.progress(50)  # Step 2: Image preprocessed
        time.sleep(0.5)

        # Perform inference without tracking gradients.
        # NOTE(review): generate() receives only pixel_values (no text prompt
        # or input_ids) -- confirm against the model's documented usage.
        with torch.no_grad():
            outputs = model.generate(pixel_values=pixel_values)
        progress_bar.progress(80)  # Step 3: Model processing
        time.sleep(0.5)

        # Decode outputs to text
        result = processor.batch_decode(outputs, skip_special_tokens=True)[0]
        progress_bar.progress(100)  # Step 4: Completed
        time.sleep(0.5)
        return result
    except Exception as e:
        # Wrap every failure in one exception type for the caller, keeping
        # the original exception attached as the cause.
        raise RuntimeError(f"Error during text extraction: {str(e)}") from e
# Streamlit UI
def main():
    """Render the OCR page and process an uploaded image.

    Shows the title and instructions, loads the processor/model pair via
    ``load_model``, and offers a file uploader restricted to jpg/jpeg/png.
    When a file is uploaded it is opened as an RGB image, displayed, passed
    to ``extract_text``, and the returned text is written to the page.  Any
    exception raised while handling the upload is shown with ``st.error``
    instead of propagating.
    """
    # The emoji in the two headings below were stored as mis-decoded UTF-8
    # bytes (mojibake); they are written here as the intended characters.
    st.title("🖼️ OCR App using SmolVLM-Instruct")
    st.write("Upload an image, and I will extract the text for you!")

    # Load the model and processor
    processor, model = load_model()

    # File uploader
    uploaded_file = st.file_uploader("Upload an Image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        # Nothing to process until the user provides a file.
        return

    try:
        # Open and display image
        image = Image.open(uploaded_file).convert("RGB")
        st.image(image, caption="Uploaded Image", use_container_width=True)

        # Extract text with progress bar
        with st.spinner("Extracting text... Please wait!"):
            extracted_text = extract_text(image, processor, model)

        st.subheader("📝 Extracted Text:")
        st.write(extracted_text)
    except Exception as e:
        # UI boundary: report the failure in the page rather than crashing.
        st.error(f"An error occurred: {str(e)}")
# Entry point: build the page only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()