import gradio as gr
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import numpy as np
import torch
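
# Assumed dependencies (not pinned in the original): gradio, transformers,
# torch, pillow, and numpy must be installed for this script to run.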

# Load the primary model (DeepDiveDev/transformodocs-ocr)
processor1 = TrOCRProcessor.from_pretrained("DeepDiveDev/transformodocs-ocr")
model1 = VisionEncoderDecoderModel.from_pretrained("DeepDiveDev/transformodocs-ocr")

# Load the fallback model (allenai/olmOCR-7B-0225-preview).
# Caution: olmOCR-7B is a Qwen2-VL-based model, not a TrOCR checkpoint, so
# these TrOCR classes may fail to load it; guard the load so the app can
# still start with the primary model alone.
try:
    processor2 = TrOCRProcessor.from_pretrained("allenai/olmOCR-7B-0225-preview")
    model2 = VisionEncoderDecoderModel.from_pretrained("allenai/olmOCR-7B-0225-preview")
except Exception:
    processor2, model2 = None, None

# Extract text with the primary model, retrying with the fallback if needed
def extract_text(image):
    try:
        # Convert input to a PIL Image (Gradio passes a numpy array by default)
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        else:
            image = Image.open(image).convert("RGB")

        # Preprocessing: grayscale can reduce background noise for OCR, but
        # the ViT encoder expects 3 channels, so convert back to RGB after
        image = image.convert("L").convert("RGB")
        image = image.resize((640, 640))  # The processor resizes again to the model's input size

        # Process with the primary model
        pixel_values = processor1(images=image, return_tensors="pt").pixel_values
        with torch.no_grad():  # Inference only; no gradients needed
            generated_ids = model1.generate(pixel_values)
        extracted_text = processor1.batch_decode(generated_ids, skip_special_tokens=True)[0]

        # If the output is too short to be a real transcription, retry with
        # the fallback model (when it loaded successfully)
        if len(extracted_text.strip()) < 2 and model2 is not None:
            pixel_values = processor2(images=image, return_tensors="pt").pixel_values
            with torch.no_grad():
                generated_ids = model2.generate(pixel_values)
            extracted_text = processor2.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return extracted_text

    except Exception as e:
        return f"Error: {str(e)}"

# Gradio Interface
iface = gr.Interface(
    fn=extract_text,
    inputs="image",
    outputs="text",
    title="TransformoDocs - AI OCR",
    description="Upload a handwritten document and get the extracted text.",
)

iface.launch()
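
# Note: launch() serves the app locally by default; passing share=True would
# create a temporary public Gradio link if remote access is needed.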