Spaces:
Sleeping
Sleeping
File size: 2,030 Bytes
3a8de33 d1bb7e2 c6111b8 3a8de33 e6b9318 c6111b8 e6b9318 c6111b8 d1bb7e2 30abd6a d1bb7e2 e6b9318 c6111b8 e6b9318 d1bb7e2 e6b9318 d1bb7e2 e6b9318 d1bb7e2 e6b9318 d1bb7e2 fa36a00 e6b9318 3a8de33 e6b9318 c6111b8 9164d6d c6111b8 9164d6d c6111b8 9164d6d c6111b8 d1bb7e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import numpy as np
import torch
# Hugging Face Hub identifiers of the two TrOCR checkpoints used below.
PRIMARY_MODEL_ID = "DeepDiveDev/transformodocs-ocr"
FALLBACK_MODEL_ID = "microsoft/trocr-base-handwritten"

# Primary processor/model pair (DeepDiveDev/transformodocs-ocr).
processor1 = TrOCRProcessor.from_pretrained(PRIMARY_MODEL_ID)
model1 = VisionEncoderDecoderModel.from_pretrained(PRIMARY_MODEL_ID)

# Fallback processor/model pair (microsoft/trocr-base-handwritten).
processor2 = TrOCRProcessor.from_pretrained(FALLBACK_MODEL_ID)
model2 = VisionEncoderDecoderModel.from_pretrained(FALLBACK_MODEL_ID)
# Fixed working size applied before OCR (kept from the original pipeline).
# NOTE(review): this ignores the aspect ratio of the upload -- confirm that
# the distortion is intentional.
_WORKING_SIZE = (640, 640)


def _to_rgb_image(image):
    """Return *image* as an RGB ``PIL.Image.Image``.

    Args:
        image: A NumPy array shaped (H, W), (H, W, 1) or (H, W, C), an
            already opened PIL image, or anything ``Image.open`` accepts
            (a file path or a binary file object).

    Returns:
        A PIL image in "RGB" mode.
    """
    if isinstance(image, np.ndarray):
        # Drop a trailing singleton channel so Pillow sees plain grayscale.
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]
        # convert("RGB") expands grayscale and discards any alpha channel.
        return Image.fromarray(image).convert("RGB")
    if isinstance(image, Image.Image):
        return image.convert("RGB")
    # Path or file object: make sure the underlying file handle is released.
    with Image.open(image) as opened:
        return opened.convert("RGB")


def _recognize(processor, model, image):
    """Run one processor/model pair on *image* and return the decoded text."""
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def extract_text(image) -> str:
    """Extract text from an image, using the fallback model when needed.

    The primary model (``processor1`` / ``model1``) is tried first. If its
    output is shorter than two non-whitespace-trimmed characters, the
    fallback model (``processor2`` / ``model2``) is run on the same image.

    Args:
        image: The uploaded image as a NumPy array, a PIL image, a file path
            or file object, or ``None`` when nothing was uploaded.

    Returns:
        The recognized text, or a string starting with ``"Error: "`` when no
        image was supplied or processing failed. Exceptions are never
        propagated so the UI always receives a displayable string.
    """
    if image is None:
        return "Error: No image provided."

    try:
        rgb_image = _to_rgb_image(image).resize(_WORKING_SIZE)

        text = _recognize(processor1, model1, rgb_image)
        # An (almost) empty result is treated as a failed recognition.
        if len(text.strip()) < 2:
            text = _recognize(processor2, model2, rgb_image)
        return text
    except Exception as e:  # UI boundary: report the failure as text.
        return f"Error: {e}"
# Gradio interface: a single image input mapped to a plain-text output.
iface = gr.Interface(
    fn=extract_text,
    inputs="image",
    outputs="text",
    title="TransformoDocs - AI OCR",
    description="Upload a handwritten document and get the extracted text.",
)

# Start the web server only when this file is executed as a script, so that
# importing the module (e.g. to reuse `extract_text` or `iface`) does not
# launch the app as a side effect.
if __name__ == "__main__":
    iface.launch()