import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import numpy as np
import torch
# Load the primary OCR model (DeepDiveDev/transformodocs-ocr)
processor1 = TrOCRProcessor.from_pretrained("DeepDiveDev/transformodocs-ocr")
model1 = VisionEncoderDecoderModel.from_pretrained("DeepDiveDev/transformodocs-ocr")
# Load the fallback model (microsoft/trocr-base-handwritten) for handwritten text
processor2 = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model2 = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
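
# Hedged addition (not in the original script): both checkpoints are used purely for
# inference, so switching them to eval mode, which disables dropout, is a safe default.
model1.eval()
model2.eval()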
# Function to extract text from a handwritten document image
def extract_text(image):
    try:
        # Normalize the input to an RGB PIL image
        if isinstance(image, np.ndarray):
            if image.ndim == 2:  # Grayscale (H, W) -> replicate channels to RGB
                image = np.stack([image] * 3, axis=-1)
            image = Image.fromarray(image)
        elif isinstance(image, str):  # If a file path is given, open the image
            image = Image.open(image)
        image = image.convert("RGB")

        # Downscale in place while preserving the aspect ratio (better for OCR)
        image.thumbnail((800, 800))

        # Run the primary model
        pixel_values = processor1(images=image, return_tensors="pt").pixel_values.to(torch.float32)
        generated_ids = model1.generate(pixel_values)
        extracted_text = processor1.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # If the output is suspiciously short, retry with the handwritten-text fallback model
        if len(extracted_text.strip()) < 2:
            pixel_values = processor2(images=image, return_tensors="pt").pixel_values.to(torch.float32)
            generated_ids = model2.generate(pixel_values)
            extracted_text = processor2.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return extracted_text if extracted_text.strip() else "No text detected."
    except Exception as e:
        return f"Error: {str(e)}"
# Gradio UI for OCR extraction
iface = gr.Interface(
    fn=extract_text,
    inputs=gr.Image(type="pil"),  # Gradio delivers the upload as a PIL image
    outputs="text",
    title="Handwritten OCR Extraction",
    description="Upload a handwritten image to extract text using AI OCR.",
)
iface.launch()
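
# Hedged usage note: the bare launch() above is all a Hugging Face Space needs; when
# running locally you could instead use iface.launch(share=True) for a temporary
# public link, or iface.launch(server_port=7860) to pin the port.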