|
import gradio as gr |
|
from PIL import Image |
|
import torch |
|
from transformers import AutoFeatureExtractor, AutoTokenizer, TrOCRProcessor, VisionEncoderDecoderModel |
|
import os |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OCRModel:
    """
    Word-level OCR wrapper around a trained ``VisionEncoderDecoderModel``.

    Builds a ``TrOCRProcessor`` from a separately loaded feature extractor and
    tokenizer, loads the trained encoder-decoder weights, and stores the
    generation settings on the model config.
    """

    def __init__(self, encoder_model, decoder_model, trained_model_path):
        """
        Loads the processor components and the trained model.

        :param encoder_model: Name or path passed to ``AutoFeatureExtractor.from_pretrained``.
        :param decoder_model: Name or path passed to ``AutoTokenizer.from_pretrained``.
        :param trained_model_path: Name or path passed to ``VisionEncoderDecoderModel.from_pretrained``.
        """
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(encoder_model)
        self.decoder_tokenizer = AutoTokenizer.from_pretrained(decoder_model)
        self.processor = TrOCRProcessor(
            feature_extractor=self.feature_extractor,
            tokenizer=self.decoder_tokenizer,
        )
        self.model = VisionEncoderDecoderModel.from_pretrained(trained_model_path)

        # Special-token ids are taken from the decoder tokenizer.
        config = self.model.config
        config.decoder_start_token_id = self.processor.tokenizer.cls_token_id
        config.pad_token_id = self.processor.tokenizer.pad_token_id
        config.vocab_size = config.decoder.vocab_size
        config.eos_token_id = self.processor.tokenizer.sep_token_id

        # Generation settings stored on the model config.
        config.max_length = 64
        config.early_stopping = True
        config.no_repeat_ngram_size = 3
        config.length_penalty = 2.0
        config.num_beams = 4

        # Put the model on the GPU when one is available. ``ocr`` sends its
        # inputs to ``self.model.device``, so weights and inputs always share
        # a device (previously only the inputs were moved, which fails on
        # CUDA hosts).
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(device)

    def read_and_show(self, image_path):
        """
        Reads an image from the provided path and converts it to RGB.

        :param image_path: String, path to the input image.
        :return: PIL Image object
        """
        # ``convert`` returns a fully loaded copy, so the source file can be
        # closed as soon as the conversion is done.
        with Image.open(image_path) as source:
            return source.convert('RGB')

    def ocr(self, image):
        """
        Performs OCR on the given image.

        :param image: PIL Image object.
        :return: Extracted text from the image.
        """
        inputs = self.processor(image, return_tensors='pt')
        # The inputs must live on the same device as the model weights.
        pixel_values = inputs.pixel_values.to(self.model.device)

        generated_ids = self.model.generate(pixel_values)
        # A single image was passed in, so keep the first decoded string.
        decoded = self.processor.batch_decode(generated_ids, skip_special_tokens=True)
        return decoded[0]
|
|
|
|
|
# Module-level model instance, built once at import time and used by `main`.
ocr_model = OCRModel(
    encoder_model="google/vit-base-patch16-224-in21k",
    decoder_model="surajp/RoBERTa-hindi-guj-san",
    trained_model_path="./model/",
)
|
|
|
|
|
|
|
def main(image_path):
    """
    Loads the image at the given path and runs OCR on it.

    :param image_path: Path handed to ``ocr_model.read_and_show``.
    :return: Tuple ``(image, text)`` with the loaded image and the text
        returned by ``ocr_model.ocr``.
    """
    rgb_image = ocr_model.read_and_show(image_path)
    recognized = ocr_model.ocr(rgb_image)
    return rgb_image, recognized
|
|
|
|
|
def gradio_interface(image):
    """
    Gradio callback: runs OCR on the uploaded image.

    The image is processed in memory. Earlier code wrote every upload to the
    same fixed file name before reading it back, so overlapping requests
    could overwrite each other's image.

    :param image: PIL Image from the input component, or ``None`` when
        nothing was uploaded.
    :return: Tuple ``(image, text)``: the RGB image that was recognized and
        the extracted text, or ``(None, "")`` when no image was provided.
    """
    if image is None:
        # Nothing to recognize; return empty outputs instead of crashing.
        return None, ""

    # Same normalization the file-based path applied after reloading.
    processed_image = image.convert('RGB')
    result_text = ocr_model.ocr(processed_image)
    return processed_image, result_text
|
|
|
|
|
# Example inputs offered in the Gradio UI.
sample_images = [
    "./sample/16.jpg",
    "./sample/20.jpg",
    "./sample/21.jpg",
    "./sample/31.jpg",
    "./sample/35.jpg",
]

# Make sure every example path exists so the examples can be loaded: any
# missing sample is replaced by a generated placeholder image.
for i, sample in enumerate(sample_images):
    if not os.path.exists(sample):
        # Create the directory the sample actually lives in. The previous
        # code created "samples" while the paths point into "./sample", so
        # saving failed whenever that directory was missing.
        os.makedirs(os.path.dirname(sample) or ".", exist_ok=True)
        # Placeholder: a flat grey 224x224 image whose shade depends on the index.
        img = Image.new("RGB", (224, 224), color=(i * 50, i * 50, i * 50))
        img.save(sample)
|
|
|
|
|
# Gradio UI: one image input; the outputs are the processed image and the
# recognized text.
gr_interface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Image(type="pil"), gr.Textbox()],
    title="Hindi Handwritten OCR Recognition",
    description=(
        "Upload a cropped image containing a word, or use the sample images "
        "below to recognize the text. This is a word recognition model. "
        "Currently, text detection is not supported."
    ),
    examples=sample_images,
)


# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    gr_interface.launch()
|
|