import gradio as gr
from indicnlp.transliterate.unicode_transliterate import UnicodeIndicTransliterator
from transformers import VisionEncoderDecoderModel, AutoProcessor, AutoTokenizer
from PIL import Image
import torch
from huggingface_hub import snapshot_download
# Pre-download the fine-tuned TrOCR-Indic checkpoint into the local HF cache.
snapshot_download(repo_id="QuickHawk/trocr-indic")

# Image processor comes from the DeiT encoder, tokenizer from the IndicBART decoder;
# the fine-tuned encoder-decoder weights are loaded from the downloaded checkpoint.
ENCODER_MODEL_NAME = "facebook/deit-base-distilled-patch16-224"
DECODER_MODEL_NAME = "ai4bharat/IndicBART"

processor = AutoProcessor.from_pretrained(ENCODER_MODEL_NAME, use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(DECODER_MODEL_NAME, use_fast=True)
model = VisionEncoderDecoderModel.from_pretrained("QuickHawk/trocr-indic")

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
# IndicBART language codes mapped to display names for the UI.
LANG_MAP = {
    "as": "Assamese",
    "bn": "Bengali",
    "gu": "Gujarati",
    "hi": "Hindi",
    "kn": "Kannada",
    "ml": "Malayalam",
    "mr": "Marathi",
    "or": "Odia",
    "pa": "Punjabi",
    "ta": "Tamil",
    "te": "Telugu",
    "ur": "Urdu",
}
# IndicBART's special tokens live in the added vocabulary, so resolve their ids explicitly.
bos_id = tokenizer._convert_token_to_id_with_added_voc("<s>")
eos_id = tokenizer._convert_token_to_id_with_added_voc("</s>")
pad_id = tokenizer._convert_token_to_id_with_added_voc("<pad>")
def predict(image):
    """Run OCR on a PIL image and return (recognized text, language name)."""
    with torch.no_grad():
        pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
        output_ids = model.generate(
            pixel_values,
            use_cache=True,
            num_beams=4,
            max_length=128,
            min_length=1,
            early_stopping=True,
            pad_token_id=pad_id,
            bos_token_id=bos_id,
            eos_token_id=eos_id,
            decoder_start_token_id=tokenizer._convert_token_to_id_with_added_voc("<2en>"),
        )

    # Position 0 is the forced "<2en>" start token; position 1 is the language tag
    # the model emits (e.g. "<2hi>"), so strip the leading "<2" and trailing ">".
    lang_token = tokenizer.decode(output_ids[0][1])
    lang = lang_token[2:-1]

    caption = tokenizer.decode(output_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)

    # IndicBART generates text in its shared Devanagari script, so transliterate
    # the output into the native script of the detected language.
    return UnicodeIndicTransliterator.transliterate(caption, "hi", lang), LANG_MAP[lang]
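# A minimal, commented-out smoke test (not part of the original Space): "sample.png"
# is a placeholder path assumed to point at an image of Indic-script text on disk.
# Uncomment to exercise predict() once from the command line before the UI launches.
# if __name__ == "__main__":
#     sample_image = Image.open("sample.png").convert("RGB")
#     recognized_text, language_name = predict(sample_image)
#     print(f"{language_name}: {recognized_text}")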
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Text(label="Predicted Text"), gr.Text(label="Predicted Language")],
).launch()