"""Gradio demo: OCR for Indic-language text images.

A DeiT image encoder + IndicBART decoder (``VisionEncoderDecoderModel``)
fine-tuned as ``QuickHawk/trocr-indic``. IndicBART decodes everything in
Devanagari (its internal script), so the predicted text is transliterated
back into the detected language's native script before display.
"""

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from indicnlp.transliterate.unicode_transliterate import UnicodeIndicTransliterator
from PIL import Image
from transformers import AutoProcessor, AutoTokenizer, VisionEncoderDecoderModel

# Pre-fetch the fine-tuned checkpoint so from_pretrained() below hits the
# local cache instead of downloading lazily mid-startup.
snapshot_download(repo_id="QuickHawk/trocr-indic")

ENCODER_MODEL_NAME = "facebook/deit-base-distilled-patch16-224"
DECODER_MODEL_NAME = "ai4bharat/IndicBART"

# Image preprocessing comes from the encoder; tokenization from the decoder.
processor = AutoProcessor.from_pretrained(ENCODER_MODEL_NAME, use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(DECODER_MODEL_NAME, use_fast=True)
model = VisionEncoderDecoderModel.from_pretrained("QuickHawk/trocr-indic")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()  # inference only: disable dropout etc.

# Language codes carried by IndicBART's "<2xx>" tags -> human-readable names.
LANG_MAP = {
    "as": "Assamese",
    "bn": "Bengali",
    "gu": "Gujarati",
    "hi": "Hindi",
    "kn": "Kannada",
    "ml": "Malayalam",
    "mr": "Marathi",
    "or": "Odia",
    "pa": "Punjabi",
    "ta": "Tamil",
    "te": "Telugu",
    "ur": "Urdu",
}

# BUG FIX: the original resolved the id of the EMPTY STRING "" for all three
# special tokens (the "<s>"/"</s>"/"<pad>" literals were evidently stripped
# from the source), so bos/eos/pad all collapsed to one bogus id. Use the
# real IndicBART special tokens via the public API (which also covers
# added-vocabulary tokens).
bos_id = tokenizer.convert_tokens_to_ids("<s>")
eos_id = tokenizer.convert_tokens_to_ids("</s>")
pad_id = tokenizer.convert_tokens_to_ids("<pad>")


def predict(image):
    """OCR a PIL image of Indic text.

    Parameters
    ----------
    image : PIL.Image.Image
        Input image (supplied by ``gr.Image(type="pil")``).

    Returns
    -------
    tuple[str, str]
        ``(text in the language's native script, language display name)``.
    """
    with torch.no_grad():
        pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
        output_ids = model.generate(
            pixel_values,
            use_cache=True,
            num_beams=4,
            max_length=128,
            min_length=1,
            early_stopping=True,
            pad_token_id=pad_id,
            bos_token_id=bos_id,
            eos_token_id=eos_id,
            # NOTE(review): the checkpoint was apparently trained to start
            # decoding from the "<2en>" tag — confirm against training code.
            decoder_start_token_id=tokenizer.convert_tokens_to_ids("<2en>"),
        )

    # Position 0 is the decoder start token; position 1 is the language tag,
    # e.g. "<2hi>" -> strip "<2" and ">" to get "hi".
    lang_token = tokenizer.decode(output_ids[0][1])
    lang = lang_token[2:-1]

    caption = tokenizer.decode(
        output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    # IndicBART emits Devanagari regardless of language; transliterate back
    # into the predicted language's native script.
    text = UnicodeIndicTransliterator.transliterate(caption, "hi", lang)
    # .get() guards against an unexpected tag crashing the UI with KeyError.
    return text, LANG_MAP.get(lang, lang)


# Launch only when run as a script, so importing this module (e.g. for tests)
# does not start a server.
if __name__ == "__main__":
    gr.Interface(
        fn=predict,
        inputs=gr.Image(type="pil"),
        outputs=[
            gr.Text(label="Predicted Text"),
            gr.Text(label="Predicted Language"),
        ],
    ).launch()