File size: 5,087 Bytes
a42c0ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
from PIL import Image
import torch
from transformers import AutoFeatureExtractor, AutoTokenizer, TrOCRProcessor, VisionEncoderDecoderModel
import os

# def sauvola_thresholding(grayImage_, window_size=15):
#     """
#     Sauvola thresholds are local thresholding techniques that are 
#     useful for images where the background is not uniform, especially for text recognition.
    
#     grayImage_ --- Input image should be in 2-Dimension Gray Scale format.
#     window_size --- It represents the filter window size.
#     """
#     # Assert the input conditions
#     assert len(grayImage_.shape) == 2, "Input image must be a 2-dimensional gray scale image."
#     assert isinstance(window_size, int) and window_size > 0, "Window size must be a positive integer."
    
#     thresh_sauvolavalue = threshold_sauvola(grayImage_, window_size=window_size)
#     thresholdImage_ = (grayImage_ > thresh_sauvolavalue)
    
#     return thresholdImage_

class OCRModel:
    """TrOCR-style vision-encoder/text-decoder wrapper for single-word OCR.

    Combines a ViT feature extractor with a text tokenizer into a
    ``TrOCRProcessor`` and loads fine-tuned weights from a local path or
    hub checkpoint.
    """

    def __init__(self, encoder_model, decoder_model, trained_model_path):
        """
        :param encoder_model: hub id of the vision encoder (feature extractor).
        :param decoder_model: hub id of the decoder tokenizer.
        :param trained_model_path: local dir or hub id of fine-tuned weights.
        """
        # Load processor and model
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_model)
        self.decoder_tokenizer = AutoTokenizer.from_pretrained(decoder_model)
        self.processor = TrOCRProcessor(feature_extractor=self.feature_extractor, tokenizer=self.decoder_tokenizer)
        self.model = VisionEncoderDecoderModel.from_pretrained(trained_model_path)

        # BUG FIX: choose the device once and move the model onto it.
        # Previously only the input tensors were moved to CUDA (in ocr()),
        # while the model stayed on CPU — a device-mismatch RuntimeError on
        # any GPU machine.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()  # inference only; disables dropout etc.

        # Configure model settings (generation + special tokens)
        self.model.config.decoder_start_token_id = self.processor.tokenizer.cls_token_id
        self.model.config.pad_token_id = self.processor.tokenizer.pad_token_id
        self.model.config.vocab_size = self.model.config.decoder.vocab_size
        self.model.config.eos_token_id = self.processor.tokenizer.sep_token_id
        self.model.config.max_length = 64
        self.model.config.early_stopping = True
        self.model.config.no_repeat_ngram_size = 3
        self.model.config.length_penalty = 2.0
        self.model.config.num_beams = 4

    def read_and_show(self, image_path):
        """
        Reads an image from the provided path and converts it to RGB.
        :param image_path: String, path to the input image.
        :return: PIL Image object
        """
        image = Image.open(image_path).convert('RGB')
        return image

    def ocr(self, image):
        """
        Performs OCR on the given image.
        :param image: PIL Image object.
        :return: Extracted text from the image.
        """
        # Preprocess the image onto the same device the model lives on.
        pixel_values = self.processor(image, return_tensors='pt').pixel_values.to(self.device)

        # Generate text; no_grad avoids building an autograd graph at inference.
        with torch.no_grad():
            generated_ids = self.model.generate(pixel_values)
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return generated_text

# Shared OCR model instance used by the Gradio callbacks below.
# NOTE: trained_model_path may alternatively point at the hosted checkpoint
# 'sabaridsnfuji/Tamil_Offline_Handwritten_OCR' instead of "./model/".
ocr_model = OCRModel(
    encoder_model="google/vit-base-patch16-224-in21k",
    decoder_model="surajp/RoBERTa-hindi-guj-san",
    trained_model_path="./model/",
)


    
def main(image_path):
    """Load the image at *image_path* and run OCR on it.

    :param image_path: path to the input image file.
    :return: (RGB PIL image, recognized text) tuple.
    """
    loaded_image = ocr_model.read_and_show(image_path)
    recognized_text = ocr_model.ocr(loaded_image)
    return loaded_image, recognized_text

# Gradio Interface function
def gradio_interface(image):
    # Save the uploaded image locally
    image_path = "uploaded_image.png"
    image.save(image_path)
    
    # Call the main function to process the image and get the result
    processed_image, result_text = main(image_path)
    
    return processed_image, result_text

# Sample images for demonstration (replace with real handwriting crops).
sample_images = [
    "./sample/16.jpg",
    "./sample/20.jpg",
    "./sample/21.jpg",
    "./sample/31.jpg",
    "./sample/35.jpg",
]

# Save dummy placeholder images for any sample that is missing so the
# examples gallery still renders on a fresh checkout.
# BUG FIX: this block previously ran `os.makedirs("samples", ...)` while the
# paths above live under "./sample/", so Image.save raised FileNotFoundError
# whenever the real directory did not exist. Create the directory each
# sample path actually resides in instead.
for i, sample in enumerate(sample_images):
    if not os.path.exists(sample):
        os.makedirs(os.path.dirname(sample), exist_ok=True)
        # Flat grayscale placeholder; shade varies per sample index.
        img = Image.new("RGB", (224, 224), color=(i * 50, i * 50, i * 50))
        img.save(sample)

# Assemble the Gradio UI: one image input, image + text outputs, plus the
# clickable example gallery built from sample_images.
gr_interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Textbox()],
    title="Hindi Handwritten OCR Recognition",
    description="Upload a cropped image containing a word, or use the sample images below to recognize the text. This is a word recognition model. Currently, text detection is not supported.",
    examples=sample_images,
)

# Start the web server only when executed as a script (not on import).
if __name__ == "__main__":
    gr_interface.launch()