File size: 2,274 Bytes
d010bf6
d1bb7e2
c6111b8
3a8de33
e6b9318
c6111b8
a3df3f5
e6b9318
 
c6111b8
a3df3f5
30abd6a
d1bb7e2
e6b9318
a3df3f5
c6111b8
e6b9318
d010bf6
d1bb7e2
a3df3f5
d1bb7e2
e6b9318
d010bf6
 
e6b9318
a3df3f5
 
e6b9318
a3df3f5
2653a83
e6b9318
 
 
a3df3f5
d1bb7e2
2653a83
fa36a00
e6b9318
 
a3df3f5
3a8de33
e6b9318
 
c6111b8
a3df3f5
c6111b8
9164d6d
d010bf6
9164d6d
a3df3f5
 
c6111b8
 
2653a83
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import numpy as np
import torch

# Primary OCR checkpoint: queried first for every image by extract_text.
_PRIMARY_REPO = "DeepDiveDev/transformodocs-ocr"
processor1 = TrOCRProcessor.from_pretrained(_PRIMARY_REPO)
model1 = VisionEncoderDecoderModel.from_pretrained(_PRIMARY_REPO)

# Fallback checkpoint (Microsoft's handwritten TrOCR base model), consulted
# by extract_text when the primary model yields (almost) no text.
_FALLBACK_REPO = "microsoft/trocr-base-handwritten"
processor2 = TrOCRProcessor.from_pretrained(_FALLBACK_REPO)
model2 = VisionEncoderDecoderModel.from_pretrained(_FALLBACK_REPO)

def _to_rgb_image(image):
    """Normalise a supported input into a new RGB ``PIL.Image.Image``.

    Accepts a NumPy array (2-D grayscale or H x W x C), a filesystem path
    (``str`` or ``os.PathLike``), or an existing PIL image. The result is
    always a fresh object, so later in-place operations (e.g. ``thumbnail``)
    never modify the caller's image.
    """
    from os import PathLike

    if isinstance(image, np.ndarray):
        if image.ndim == 2:  # grayscale (H, W) -> replicate into 3 channels
            image = np.stack([image] * 3, axis=-1)
        # convert() also folds RGBA and other multi-channel modes down to RGB.
        return Image.fromarray(image).convert("RGB")
    if isinstance(image, (str, PathLike)):
        # The context manager releases the file handle once pixels are loaded.
        with Image.open(image) as opened:
            return opened.convert("RGB")
    # Otherwise expect a PIL image; convert() returns a copy even when the
    # image is already RGB, so the caller's object is left untouched.
    return image.convert("RGB")


def _run_ocr(processor, model, image):
    """Run one processor/model pair on *image* and return the decoded text."""
    pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(torch.float32)
    generated_ids = model.generate(pixel_values)
    return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]


def extract_text(image):
    """Extract text from a (handwritten) image; used as the Gradio handler.

    The primary model (``processor1``/``model1``) is tried first. If its
    stripped output is shorter than two characters, the fallback model
    (``processor2``/``model2``) is run on the same image instead.

    Args:
        image: A PIL image, a NumPy array, a file path, or ``None``
            (presumably what Gradio passes when nothing was uploaded).

    Returns:
        The recognised text with surrounding whitespace removed;
        ``"No text detected."`` if the chosen model produced only whitespace;
        ``"No image provided."`` if *image* is ``None``; or
        ``"Error: <message>"`` if any processing step raised.
    """
    if image is None:
        return "No image provided."

    try:
        image = _to_rgb_image(image)

        # Shrink (never enlarge) the private copy so the longer side is at
        # most 800 px, preserving aspect ratio.
        # NOTE(review): the TrOCR processor presumably resizes to its own
        # fixed input size afterwards; confirm this step is still useful.
        image.thumbnail((800, 800))

        text = _run_ocr(processor1, model1, image).strip()
        if len(text) < 2:  # primary output is empty/too short -> fallback
            text = _run_ocr(processor2, model2, image).strip()

        return text if text else "No text detected."

    except Exception as e:
        # UI boundary: report the failure to the user, but keep the traceback
        # in the server log instead of discarding it.
        import logging

        logging.getLogger(__name__).exception("OCR extraction failed")
        return f"Error: {e}"

# Gradio front end: a single image input wired to extract_text, with the
# recognised text shown as plain-text output.
_ui_text = {
    "title": "Handwritten OCR Extraction",
    "description": "Upload a handwritten image to extract text using AI OCR.",
}
_image_input = gr.Image(type="pil")  # type="pil": the handler receives a PIL image

iface = gr.Interface(extract_text, _image_input, "text", **_ui_text)

iface.launch()