"""Gradio demo for word-level Hindi handwritten OCR.

Wraps a TrOCR-style VisionEncoderDecoder model (ViT encoder + Hindi RoBERTa
decoder tokenizer) behind a simple Gradio image-in / text-out interface.
This is a *word recognition* model only — text detection is not supported.
"""

import os

import gradio as gr
import torch
from PIL import Image
from transformers import (
    AutoFeatureExtractor,
    AutoTokenizer,
    TrOCRProcessor,
    VisionEncoderDecoderModel,
)


class OCRModel:
    """TrOCR-style OCR: a vision encoder + autoregressive text decoder.

    Loads the feature extractor, tokenizer, and trained
    VisionEncoderDecoderModel once at construction, configures generation
    parameters, and exposes `ocr()` for single-image inference.
    """

    def __init__(self, encoder_model, decoder_model, trained_model_path):
        """Load processor components and the trained model.

        :param encoder_model: HF hub id or path for the image encoder
            (supplies the feature extractor).
        :param decoder_model: HF hub id or path for the text decoder
            (supplies the tokenizer).
        :param trained_model_path: Path/id of the fine-tuned
            VisionEncoderDecoderModel checkpoint.
        """
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_model)
        self.decoder_tokenizer = AutoTokenizer.from_pretrained(decoder_model)
        self.processor = TrOCRProcessor(
            feature_extractor=self.feature_extractor,
            tokenizer=self.decoder_tokenizer,
        )
        self.model = VisionEncoderDecoderModel.from_pretrained(trained_model_path)

        # BUG FIX: resolve the device once and move the MODEL onto it.
        # The original moved only pixel_values to CUDA inside ocr(), so on any
        # GPU machine generate() raised a CPU-vs-CUDA device-mismatch error.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        # Special-token wiring required by VisionEncoderDecoderModel.
        self.model.config.decoder_start_token_id = self.processor.tokenizer.cls_token_id
        self.model.config.pad_token_id = self.processor.tokenizer.pad_token_id
        self.model.config.vocab_size = self.model.config.decoder.vocab_size
        self.model.config.eos_token_id = self.processor.tokenizer.sep_token_id

        # Beam-search generation settings (short single-word outputs).
        self.model.config.max_length = 64
        self.model.config.early_stopping = True
        self.model.config.no_repeat_ngram_size = 3
        self.model.config.length_penalty = 2.0
        self.model.config.num_beams = 4

    def read_and_show(self, image_path):
        """Read an image from disk and convert it to RGB.

        :param image_path: String, path to the input image.
        :return: PIL Image object in RGB mode.
        """
        return Image.open(image_path).convert('RGB')

    def ocr(self, image):
        """Run OCR on a single image.

        :param image: PIL Image object (RGB).
        :return: Extracted text string.
        """
        # Preprocess onto the same device the model lives on.
        pixel_values = self.processor(image, return_tensors='pt').pixel_values.to(self.device)
        generated_ids = self.model.generate(pixel_values)
        # batch_decode returns a list; we sent a batch of one.
        generated_text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return generated_text


# Initialize the OCR model once at module load (downloads/loads weights).
ocr_model = OCRModel(
    encoder_model="google/vit-base-patch16-224-in21k",
    decoder_model="surajp/RoBERTa-hindi-guj-san",
    trained_model_path="./model/",  # alt: 'sabaridsnfuji/Tamil_Offline_Handwritten_OCR'
)


def main(image_path):
    """Load the image at `image_path` and return (PIL image, recognized text)."""
    image = ocr_model.read_and_show(image_path)
    text = ocr_model.ocr(image)
    return image, text


def gradio_interface(image):
    """Gradio callback: persist the upload, then run the OCR pipeline."""
    image_path = "uploaded_image.png"
    image.save(image_path)
    processed_image, result_text = main(image_path)
    return processed_image, result_text


# Sample images for demonstration (replace with real handwriting crops).
sample_images = [
    "./sample/16.jpg",
    "./sample/20.jpg",
    "./sample/21.jpg",
    "./sample/31.jpg",
    "./sample/35.jpg",
]

# BUG FIX: the original created "samples/" while every sample path lives in
# "./sample/" — img.save() below then failed on a missing directory. Create
# the directory each sample path actually resides in.
for i, sample in enumerate(sample_images):
    os.makedirs(os.path.dirname(sample), exist_ok=True)
    if not os.path.exists(sample):
        # Placeholder grayscale squares so the examples gallery never 404s.
        img = Image.new("RGB", (224, 224), color=(i * 50, i * 50, i * 50))
        img.save(sample)

# Gradio UI setup with clickable example images.
gr_interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Textbox()],
    title="Hindi Handwritten OCR Recognition",
    description="Upload a cropped image containing a word, or use the sample images below to recognize the text. This is a word recognition model. Currently, text detection is not supported.",
    examples=sample_images,
)

if __name__ == "__main__":
    gr_interface.launch()