murtazadahmardeh committed
Commit 25f8353 · 1 Parent(s): 9467a8a
Files changed (1)
  1. app.py +10 -22
app.py CHANGED
@@ -1,6 +1,5 @@
 import torch
 from torchvision import transforms as T
-from transformers import AutoTokenizer
 import gradio as gr
 
 class App:
@@ -15,7 +14,6 @@ class App:
             T.ToTensor(),
             T.Normalize(0.5, 0.5)
         ])
-        self._tokenizer_cache = {}
 
     def _get_model(self, name):
         if name in self._model_cache:
@@ -24,33 +22,21 @@ class App:
         self._model_cache[name] = model
         return model
 
-    def _get_tokenizer(self, name):
-        if name in self._tokenizer_cache:
-            return self._tokenizer_cache[name]
-        tokenizer = AutoTokenizer.from_pretrained(name)
-        self._tokenizer_cache[name] = tokenizer
-        return tokenizer
-
     @torch.inference_mode()
     def __call__(self, model_name, image):
         if image is None:
             return '', []
         model = self._get_model(model_name)
-        tokenizer = self._get_tokenizer(model_name)
-
         image = self._preprocess(image.convert('RGB')).unsqueeze(0)
         # Greedy decoding
         pred = model(image).softmax(-1)
-
-        # Tokenize input data
-        label = tokenizer.decode(pred.argmax(-1)[0].tolist(), skip_special_tokens=True)
-        raw_label, raw_confidence = tokenizer.decode(pred.argmax(-1)[0].tolist(), raw=True)
-
+        label, _ = model.tokenizer.decode(pred)
+        raw_label, raw_confidence = model.tokenizer.decode(pred, raw=True)
         # Format confidence values
-        max_len = 25 if model_name == 'crnn' else len(label) + 1
-        conf = list(map('{:0.1f}'.format, pred[0, :, :max_len].tolist()))
-
-        return label, [raw_label[:max_len], conf]
+        max_len = 25 if model_name == 'crnn' else len(label[0]) + 1
+        conf = list(map('{:0.1f}'.format, raw_confidence[0][:max_len].tolist()))
+        return label[0], [raw_label[0][:max_len], conf]
+
 
 def main():
     app = App()
@@ -63,12 +49,14 @@ def main():
         read_upload = gr.Button('Read Text')
 
         output = gr.Textbox(max_lines=1, label='Model output')
+        #adv_output = gr.Checkbox(label='Show detailed output')
         raw_output = gr.Dataframe(row_count=2, col_count=0, label='Raw output with confidence values ([0, 1] interval; [B] - BLANK token; [E] - EOS token)')
 
         read_upload.click(app, inputs=[model_name, image_upload], outputs=[output, raw_output])
-
+        #adv_output.change(lambda x: gr.update(visible=x), inputs=adv_output, outputs=raw_output)
     demo.queue(max_size=20)
     demo.launch()
 
+
 if __name__ == '__main__':
-    main()
+    main()
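The new decoding path leans on the tokenizer bundled with the model (`model.tokenizer.decode`) instead of a separately cached AutoTokenizer, and `decode` now appears to return batched outputs, which is why the code indexes `label[0]`, `raw_label[0]` and `raw_confidence[0]`. Below is a minimal, self-contained sketch of that assumed interface; ToyTokenizer, its charset and the '[E]' end marker are invented purely for illustration and are not the project's real tokenizer.

# Sketch (not part of the commit): approximates the decode interface the updated __call__ relies on.
import torch


class ToyTokenizer:
    # Index 0 plays the role of the '[E]' end-of-sequence marker in this toy charset.
    charset = ['[E]'] + list('abcdefghijklmnopqrstuvwxyz')

    def decode(self, pred, raw=False):
        """Greedy decode of batched per-step class probabilities of shape (B, T, C)."""
        probs, ids = pred.max(-1)  # most likely class and its probability at each step
        labels, confidences = [], []
        for seq_ids, seq_probs in zip(ids.tolist(), probs):
            chars = [self.charset[i] for i in seq_ids]
            if raw:
                labels.append(chars)            # keep special tokens visible
            else:
                text = ''
                for c in chars:
                    if c == '[E]':              # cut the decoded string at EOS
                        break
                    text += c
                labels.append(text)
            confidences.append(seq_probs)
        return labels, confidences


# Fake "model output": batch of 1, 5 decoding steps, softmax over the toy charset.
pred = torch.randn(1, 5, len(ToyTokenizer.charset)).softmax(-1)

tok = ToyTokenizer()
label, _ = tok.decode(pred)
raw_label, raw_confidence = tok.decode(pred, raw=True)

# Same confidence formatting as the updated __call__ (ignoring the crnn branch).
max_len = len(label[0]) + 1
conf = list(map('{:0.1f}'.format, raw_confidence[0][:max_len].tolist()))
print(label[0], raw_label[0][:max_len], conf)

With raw=True the per-step labels and softmax confidences survive, which is what feeds the formatted conf list returned to the Gradio Dataframe.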
 
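The two adv_output lines arrive commented out. As a rough sketch only, not something this commit enables, this is how that checkbox could toggle the raw-output table using the Gradio 3.x gr.update(visible=...) pattern the comment points at; starting the Dataframe hidden is an assumption of the sketch, not behaviour of app.py.

# Illustrative sketch of the commented-out adv_output idea (assumes Gradio 3.x).
import gradio as gr

with gr.Blocks() as demo:
    adv_output = gr.Checkbox(label='Show detailed output')
    # Starting hidden is an assumption; app.py keeps the table always visible.
    raw_output = gr.Dataframe(row_count=2, col_count=0, visible=False,
                              label='Raw output with confidence values')
    # Show or hide the table whenever the checkbox changes.
    adv_output.change(lambda x: gr.update(visible=x),
                      inputs=adv_output, outputs=raw_output)

if __name__ == '__main__':
    demo.launch()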