import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Load the primary model (DeepDiveDev/transformodocs-ocr)
processor1 = TrOCRProcessor.from_pretrained("DeepDiveDev/transformodocs-ocr")
model1 = VisionEncoderDecoderModel.from_pretrained("DeepDiveDev/transformodocs-ocr")

# Load the fallback model (microsoft/trocr-base-handwritten)
processor2 = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model2 = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
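
# Optional: run generation on a GPU when one is available (CPU otherwise);
# inputs are moved with .to(model.device) inside extract_text, so this stays
# consistent whichever device is picked
device = "cuda" if torch.cuda.is_available() else "cpu"
model1.to(device)
model2.to(device)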

# Extract text with the primary model, retrying with the fallback model
def extract_text(image):
    try:
        # Normalize the input to a PIL image (Gradio passes a numpy array)
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        elif not isinstance(image, Image.Image):
            image = Image.open(image)

        # Preprocessing: grayscale can suppress background noise, but TrOCR
        # expects three channels, so convert back to RGB afterwards
        image = image.convert("L").convert("RGB")
        image = image.resize((640, 640))  # fixed size; the processor rescales internally

        # Run the primary model
        pixel_values = processor1(images=image, return_tensors="pt").pixel_values.to(model1.device)
        with torch.no_grad():
            generated_ids = model1.generate(pixel_values)
        extracted_text = processor1.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # If the output is suspiciously short, retry with the fallback model
        if len(extracted_text.strip()) < 2:
            pixel_values = processor2(images=image, return_tensors="pt").pixel_values.to(model2.device)
            with torch.no_grad():
                generated_ids = model2.generate(pixel_values)
            extracted_text = processor2.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return extracted_text
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio Interface
iface = gr.Interface(
fn=extract_text,
inputs="image",
outputs="text",
title="TransformoDocs - AI OCR",
description="Upload a handwritten document and get the extracted text.",
)

if __name__ == "__main__":
    iface.launch()