# sabaridsnfuji's picture
# updated the code and model
# a42c0ed verified
import os
import tempfile

import gradio as gr
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoTokenizer, TrOCRProcessor, VisionEncoderDecoderModel
# def sauvola_thresholding(grayImage_, window_size=15):
# """
# Sauvola thresholds are local thresholding techniques that are
# useful for images where the background is not uniform, especially for text recognition.
# grayImage_ --- Input image should be in 2-Dimension Gray Scale format.
# window_size --- It represents the filter window size.
# """
# # Assert the input conditions
# assert len(grayImage_.shape) == 2, "Input image must be a 2-dimensional gray scale image."
# assert isinstance(window_size, int) and window_size > 0, "Window size must be a positive integer."
# thresh_sauvolavalue = threshold_sauvola(grayImage_, window_size=window_size)
# thresholdImage_ = (grayImage_ > thresh_sauvolavalue)
# return thresholdImage_
class OCRModel:
    """Word-level OCR wrapper around a trained ``VisionEncoderDecoderModel``.

    Combines an image feature extractor and a decoder tokenizer into a
    ``TrOCRProcessor`` and uses it together with the trained model to turn a
    PIL image into text.
    """

    def __init__(self, encoder_model, decoder_model, trained_model_path):
        """
        Loads the processor and the trained model and configures generation.
        :param encoder_model: Name or path passed to ``AutoFeatureExtractor.from_pretrained``.
        :param decoder_model: Name or path passed to ``AutoTokenizer.from_pretrained``.
        :param trained_model_path: Name or path passed to ``VisionEncoderDecoderModel.from_pretrained``.
        """
        # Load processor and model
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_model)
        self.decoder_tokenizer = AutoTokenizer.from_pretrained(decoder_model)
        self.processor = TrOCRProcessor(
            feature_extractor=self.feature_extractor,
            tokenizer=self.decoder_tokenizer,
        )
        self.model = VisionEncoderDecoderModel.from_pretrained(trained_model_path)

        # Put the weights on the GPU when one is available. ``ocr`` sends its
        # inputs to the model's device, so both always end up on the same one.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)

        # Configure model settings: special tokens come from the decoder tokenizer
        self.model.config.decoder_start_token_id = self.processor.tokenizer.cls_token_id
        self.model.config.pad_token_id = self.processor.tokenizer.pad_token_id
        self.model.config.vocab_size = self.model.config.decoder.vocab_size
        self.model.config.eos_token_id = self.processor.tokenizer.sep_token_id
        # Generation settings
        self.model.config.max_length = 64
        self.model.config.early_stopping = True
        self.model.config.no_repeat_ngram_size = 3
        self.model.config.length_penalty = 2.0
        self.model.config.num_beams = 4

    def read_and_show(self, image_path):
        """
        Reads an image from the provided path and converts it to RGB.
        :param image_path: String, path to the input image.
        :return: PIL Image object
        """
        # ``convert`` returns a new, fully loaded image, so the file handle
        # can be released as soon as the conversion is done.
        with Image.open(image_path) as source:
            return source.convert('RGB')

    def ocr(self, image):
        """
        Performs OCR on the given image.
        :param image: PIL Image object.
        :return: Extracted text from the image.
        """
        # Preprocess the image and place the tensor on the model's device;
        # generating with inputs on a different device than the weights fails.
        pixel_values = self.processor(image, return_tensors='pt').pixel_values
        pixel_values = pixel_values.to(self.model.device)
        # Generate text
        generated_ids = self.model.generate(pixel_values)
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return generated_text
# Build the single module-level OCR model used by the functions below.
# Weights are read from the local "./model/" directory; the original code also
# carried 'sabaridsnfuji/Tamil_Offline_Handwritten_OCR' as a commented-out
# alternative for trained_model_path.
ocr_model = OCRModel(
    trained_model_path="./model/",
    decoder_model="surajp/RoBERTa-hindi-guj-san",
    encoder_model="google/vit-base-patch16-224-in21k",
)
def main(image_path):
    """
    Loads the image at the given path and runs OCR on it.
    :param image_path: String, path to the input image.
    :return: Tuple of (RGB PIL Image that was read, text extracted from it).
    """
    loaded = ocr_model.read_and_show(image_path)
    return loaded, ocr_model.ocr(loaded)
# Gradio Interface function
def gradio_interface(image):
    """
    Gradio callback: runs OCR on the uploaded image.
    :param image: PIL Image object from the Gradio image input, or None when
        nothing was uploaded.
    :return: Tuple of (processed image, recognized text); ``(None, "")`` when
        no image was provided.
    """
    # Submitting without an upload passes None; return empty outputs instead
    # of failing on ``None.save``.
    if image is None:
        return None, ""

    # Save the upload into a directory private to this call. A fixed file name
    # in the working directory would be overwritten by concurrent requests and
    # would be left behind afterwards.
    with tempfile.TemporaryDirectory() as work_dir:
        image_path = os.path.join(work_dir, "uploaded_image.png")
        image.save(image_path)
        # Call the main function to process the image and get the result
        processed_image, result_text = main(image_path)
    return processed_image, result_text
# Sample images for demonstration (make sure these image paths exist)
sample_images = [
    "./sample/16.jpg",  # replace with actual image paths
    "./sample/20.jpg",  # replace with actual image paths
    "./sample/21.jpg",  # replace with actual image paths
    "./sample/31.jpg",  # replace with actual image paths
    "./sample/35.jpg",  # replace with actual image paths
]
# Save a dummy gray image for every sample that is missing (you should replace
# these with actual images).
for i, sample in enumerate(sample_images):
    if not os.path.exists(sample):
        # Create the directory the sample path actually points into; creating
        # a differently named directory leaves ``img.save`` without a target.
        os.makedirs(os.path.dirname(sample) or ".", exist_ok=True)
        # Clamp so the gray level stays a valid 8-bit value if the list grows.
        shade = min(i * 50, 255)
        img = Image.new("RGB", (224, 224), color=(shade, shade, shade))
        img.save(sample)
# Gradio UI: one PIL image in, the processed image plus the recognized text
# out, with the sample images offered as clickable examples.
gr_interface = gr.Interface(
    title="Hindi Handwritten OCR Recognition",
    description="Upload a cropped image containing a word, or use the sample images below to recognize the text. This is a word recognition model. Currently, text detection is not supported.",
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(type="pil"),
        gr.Textbox(),
    ],
    examples=sample_images,
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    gr_interface.launch()