import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import numpy as np
import torch

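# Select a device: use the GPU when available, otherwise fall back to CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
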
# Load the primary model (DeepDiveDev/transformodocs-ocr)
processor1 = TrOCRProcessor.from_pretrained("DeepDiveDev/transformodocs-ocr")
model1 = VisionEncoderDecoderModel.from_pretrained("DeepDiveDev/transformodocs-ocr").to(device)

# Load the fallback model (microsoft/trocr-base-handwritten)
processor2 = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model2 = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten").to(device)
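# Both checkpoints are TrOCR-style VisionEncoderDecoder models; weights are
# fetched from the Hugging Face Hub on first use and cached locally.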

# Run OCR with the primary model, falling back to the base handwritten model
def extract_text(image):
    try:
        if image is None:
            return "Error: no image provided."

        # Normalize the input to an RGB PIL image
        if isinstance(image, np.ndarray):
            if image.ndim == 2:  # Grayscale (H, W): stack into three channels
                image = np.stack([image] * 3, axis=-1)
            image = Image.fromarray(image).convert("RGB")  # Drops any alpha channel
        else:
            image = Image.open(image).convert("RGB")  # Filepath input

        # Process with the primary model; the TrOCR processor handles resizing
        # and normalization to the resolution the model expects
        pixel_values = processor1(images=image, return_tensors="pt").pixel_values.to(device)
        with torch.no_grad():
            generated_ids = model1.generate(pixel_values)
        extracted_text = processor1.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # If the primary output is suspiciously short, retry with the fallback model
        if len(extracted_text.strip()) < 2:
            pixel_values = processor2(images=image, return_tensors="pt").pixel_values.to(device)
            with torch.no_grad():
                generated_ids = model2.generate(pixel_values)
            extracted_text = processor2.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return extracted_text

    except Exception as e:
        return f"Error: {str(e)}"

# Gradio Interface
iface = gr.Interface(
    fn=extract_text,
    inputs=gr.Image(type="numpy"),
    outputs="text",
    title="TransformoDocs - AI OCR",
    description="Upload a handwritten document and get the extracted text.",
)

if __name__ == "__main__":
    iface.launch()