QuickHawk committed on
Commit
9f3bed9
·
1 Parent(s): f5f7120

git status
# This is a combination of 2 commits.

Files changed (2)
  1. app.py +63 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,63 @@
+ import gradio as gr
+ from indicnlp.transliterate.unicode_transliterate import UnicodeIndicTransliterator
+ from transformers import VisionEncoderDecoderModel, AutoProcessor, AutoTokenizer
+ from PIL import Image
+ import torch
+ from huggingface_hub import snapshot_download
+
+ snapshot_download(repo_id="QuickHawk/trocr-indic")
+
+ ENCODER_MODEL_NAME = "facebook/deit-base-distilled-patch16-224"
+ DECODER_MODEL_NAME = "ai4bharat/IndicBART"
+
+ processor = AutoProcessor.from_pretrained(ENCODER_MODEL_NAME, use_fast=True)
+ tokenizer = AutoTokenizer.from_pretrained(DECODER_MODEL_NAME, use_fast=True)
+
+ model = VisionEncoderDecoderModel.from_pretrained("QuickHawk/trocr-indic")
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(device)
+
+ LANG_MAP = {
+     "as": "Assamese",
+     "bn": "Bengali",
+     "gu": "Gujarati",
+     "hi": "Hindi",
+     "kn": "Kannada",
+     "ml": "Malayalam",
+     "mr": "Marathi",
+     "or": "Odia",
+     "pa": "Punjabi",
+     "ta": "Tamil",
+     "te": "Telugu",
+     "ur": "Urdu"
+ }
+
+ bos_id = tokenizer._convert_token_to_id_with_added_voc("<s>")
+ eos_id = tokenizer._convert_token_to_id_with_added_voc("</s>")
+ pad_id = tokenizer._convert_token_to_id_with_added_voc("<pad>")
+
+ def predict(image):
+
+     with torch.no_grad():
+         pixel_values = processor(images=image, return_tensors="pt").pixel_values.to(device)
+         outputs_ids = model.generate(
+             pixel_values,
+             use_cache=True,
+             num_beams=4,
+             max_length=128,
+             min_length=1,
+             early_stopping=True,
+             pad_token_id=pad_id,
+             bos_token_id=bos_id,
+             eos_token_id=eos_id,
+             decoder_start_token_id=tokenizer._convert_token_to_id_with_added_voc("<2en>")
+         )
+
+     lang_token = tokenizer.decode(outputs_ids[0][1])
+     lang = lang_token[2:-1]
+
+     caption = tokenizer.decode(outputs_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
+     return UnicodeIndicTransliterator.transliterate(caption, "hi", lang), LANG_MAP[lang]
+
+ gr.Interface(fn=predict, inputs=gr.Image(type="pil"), outputs=[gr.Text(label="Predicted Text"), gr.Text(label="Predicted Language")]).launch()
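Once this app is running as a Space, the same endpoint can be queried programmatically. A minimal sketch using gradio_client; the Space id and image path below are placeholders for illustration and are not confirmed by this commit:

from gradio_client import Client, handle_file  # requires a recent gradio_client

# Hypothetical Space id; replace with the actual deployed Space.
client = Client("QuickHawk/trocr-indic-demo")

# gr.Interface exposes its function under the default endpoint "/predict";
# the two outputs come back as a (text, language) tuple.
predicted_text, predicted_language = client.predict(
    handle_file("line_image.png"),  # path to an image of Indic-script text
    api_name="/predict",
)
print(predicted_language, predicted_text)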
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ gradio
+ torch
+ torchvision
+ transformers
+ pillow
+ indicnlp
+ indic-nlp-library
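Note: the indicnlp module imported in app.py is distributed on PyPI as indic-nlp-library, so the separate indicnlp entry is likely redundant; if pip cannot resolve indicnlp, indic-nlp-library alone should satisfy the import.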