Commit 25f8353 · revert 2
Parent(s): 9467a8a
app.py CHANGED
@@ -1,6 +1,5 @@
 import torch
 from torchvision import transforms as T
-from transformers import AutoTokenizer
 import gradio as gr
 
 class App:
@@ -15,7 +14,6 @@ class App:
             T.ToTensor(),
             T.Normalize(0.5, 0.5)
         ])
-        self._tokenizer_cache = {}
 
     def _get_model(self, name):
         if name in self._model_cache:
@@ -24,33 +22,21 @@ class App:
         self._model_cache[name] = model
         return model
 
-    def _get_tokenizer(self, name):
-        if name in self._tokenizer_cache:
-            return self._tokenizer_cache[name]
-        tokenizer = AutoTokenizer.from_pretrained(name)
-        self._tokenizer_cache[name] = tokenizer
-        return tokenizer
-
     @torch.inference_mode()
     def __call__(self, model_name, image):
         if image is None:
             return '', []
         model = self._get_model(model_name)
-        tokenizer = self._get_tokenizer(model_name)
-
         image = self._preprocess(image.convert('RGB')).unsqueeze(0)
         # Greedy decoding
         pred = model(image).softmax(-1)
-
-
-        label = tokenizer.decode(pred.argmax(-1)[0].tolist(), skip_special_tokens=True)
-        raw_label, raw_confidence = tokenizer.decode(pred.argmax(-1)[0].tolist(), raw=True)
-
+        label, _ = model.tokenizer.decode(pred)
+        raw_label, raw_confidence = model.tokenizer.decode(pred, raw=True)
         # Format confidence values
-        max_len = 25 if model_name == 'crnn' else len(label) + 1
-        conf = list(map('{:0.1f}'.format,
-
-
+        max_len = 25 if model_name == 'crnn' else len(label[0]) + 1
+        conf = list(map('{:0.1f}'.format, raw_confidence[0][:max_len].tolist()))
+        return label[0], [raw_label[0][:max_len], conf]
+
 
 def main():
     app = App()
@@ -63,12 +49,14 @@ def main():
             read_upload = gr.Button('Read Text')
 
         output = gr.Textbox(max_lines=1, label='Model output')
+        #adv_output = gr.Checkbox(label='Show detailed output')
         raw_output = gr.Dataframe(row_count=2, col_count=0, label='Raw output with confidence values ([0, 1] interval; [B] - BLANK token; [E] - EOS token)')
 
         read_upload.click(app, inputs=[model_name, image_upload], outputs=[output, raw_output])
-
+        #adv_output.change(lambda x: gr.update(visible=x), inputs=adv_output, outputs=raw_output)
     demo.queue(max_size=20)
     demo.launch()
 
+
 if __name__ == '__main__':
-    main()
+    main()
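For reference, a minimal usage sketch of the reverted App.__call__ path shown in this diff. It only relies on what the commit itself shows; 'crnn' is taken from the max_len check in the hunk, while the sample image path is a placeholder, so treat the exact model names the demo accepts as an assumption.

# Usage sketch (assumptions: app.py from this commit is importable and
# 'crnn' is one of the model names _get_model accepts, per the diff above).
from PIL import Image

from app import App

app = App()
image = Image.open('word.png')   # placeholder path to a cropped word image
text, raw = app('crnn', image)   # __call__ -> (decoded string, [raw tokens, confidence strings])
print(text)                      # recognized text
print(raw)                       # raw tokens (with [B]/[E]) plus per-step confidences, as shown in the Dataframe

The same two return values are what the Gradio click handler routes into the 'Model output' textbox and the raw-output Dataframe.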