Create app.py
Commit the app file
app.py
ADDED
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
import streamlit as st
import torch
from PIL import ExifTags
from PIL import Image
from transformers import DonutProcessor
from transformers import VisionEncoderDecoderModel


def run_prediction(sample):
    global pretrained_model, processor, task_prompt
    if isinstance(sample, dict):
        # prepare inputs from a preprocessed dataset sample
        pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0)
    else:  # sample is a PIL image
        # prepare encoder inputs
        pixel_values = processor(sample, return_tensors="pt").pixel_values

    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # run inference
    outputs = pretrained_model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=pretrained_model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # decode the generated token sequence
    prediction = processor.batch_decode(outputs.sequences)[0]

    # post-processing: strip special tokens for CORD-style prompts
    if "cord" in task_prompt:
        prediction = prediction.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
    # prediction = re.sub(r"<.*?>", "", prediction, count=1).strip()  # remove first task start token
    prediction = processor.token2json(prediction)

    # load the reference target when the sample provides one
    if isinstance(sample, dict):
        target = processor.token2json(sample["target_sequence"])
    else:
        target = "<not_provided>"

    return prediction, target


# Image preprocessing: fix the orientation if needed, then resize to match the selected model
def preprocess_image(image, size):
    # Check the EXIF orientation tag and rotate the image if necessary.
    # This must happen before resizing, because resize() drops EXIF metadata.
    for orientation in ExifTags.TAGS.keys():
        if ExifTags.TAGS[orientation] == "Orientation":
            exif = image._getexif() if hasattr(image, "_getexif") else None
            if exif is not None:
                exif = dict(exif.items())
                if exif.get(orientation) == 3:
                    image = image.rotate(180, expand=True)
                elif exif.get(orientation) == 6:
                    image = image.rotate(270, expand=True)
                elif exif.get(orientation) == 8:
                    image = image.rotate(90, expand=True)
            break

    # Resize the image to the resolution the chosen checkpoint was trained on
    return image.resize(size)


# What this model does
task_prompt = "<s_herbarium>"
st.markdown(
    """
### Donut Herbarium Testing
Experimental OCR-free Document Understanding Vision Transformer, fine-tuned on a herbarium dataset of around 1,400 images.
"""
)

with st.sidebar:
    information = st.radio(
        "Choose one predictor:",
        ("Low Res (1200 * 900) 5 epochs", "Mid Res (1600 * 1200) 10 epochs"),
    )
    image_choice = st.selectbox("Pick one", ["1", "2", "3"], index=1)
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

st.text(f"{information} mode is ON!\nTarget: {image_choice}")

col1, col2 = st.columns(2)

# Choose the image to parse; halt the script run if nothing has been uploaded yet
if uploaded_file is not None:
    image = Image.open(uploaded_file)
else:
    st.stop()
    # image_choice_map = {
    #     '1': '../donut_example/copy/img_resized/test/00021.jpg',
    #     '2': '../donut_example/copy/img_resized/test/00031.jpg',
    #     '3': '../donut_example/copy/img_resized/test/00050.jpg',
    # }
    # image = Image.open(image_choice_map[image_choice])


if information == "Low Res (1200 * 900) 5 epochs":
    image = preprocess_image(image, (1200, 900))
else:
    image = preprocess_image(image, (1600, 1200))

with col1:
    st.image(image, caption="Your target sample")

# Run the model
if st.button("Parse sample!"):
    image = image.convert("RGB")

    # Choose which version to run based on the selected predictor
    with st.spinner("Running the model on the target..."):
        if information == "Low Res (1200 * 900) 5 epochs":
            processor = DonutProcessor.from_pretrained(
                "Jac-Zac/thesis_test_donut",
                revision="12900abc6fb551a0ea339950462a6a0462820b75",
            )
            pretrained_model = VisionEncoderDecoderModel.from_pretrained(
                "Jac-Zac/thesis_test_donut",
                revision="12900abc6fb551a0ea339950462a6a0462820b75",
            )

        elif information == "Mid Res (1600 * 1200) 10 epochs":
            processor = DonutProcessor.from_pretrained(
                "Jac-Zac/thesis_test_donut",
                revision="8c5467cb66685e801ec6ff8de7e7fdd247274ed0",
            )
            pretrained_model = VisionEncoderDecoderModel.from_pretrained(
                "Jac-Zac/thesis_test_donut",
                revision="8c5467cb66685e801ec6ff8de7e7fdd247274ed0",
            )

        # this is the same for both models
        task_prompt = "<s_herbarium>"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pretrained_model.to(device)

        with col2:
            st.info("Parsing...")
            parsed_info, _ = run_prediction(image)
            st.text(f"\n{information}")
            st.json(parsed_info)
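The app itself is started with `streamlit run app.py`. For quick testing outside the Space, below is a minimal standalone sketch of the same inference path, with no Streamlit involved. It assumes the `Jac-Zac/thesis_test_donut` checkpoint is publicly accessible; `sample.jpg` is a hypothetical local herbarium image, and the pinned revision is the mid-res checkpoint used by the app.

# Minimal standalone sketch of the same Donut inference path (no Streamlit).
# Assumes the checkpoint is public; "sample.jpg" is a hypothetical local image.
import torch
from PIL import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel

REVISION = "8c5467cb66685e801ec6ff8de7e7fdd247274ed0"  # mid-res checkpoint from the app
processor = DonutProcessor.from_pretrained("Jac-Zac/thesis_test_donut", revision=REVISION)
model = VisionEncoderDecoderModel.from_pretrained("Jac-Zac/thesis_test_donut", revision=REVISION)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Mirror the app's mid-res preprocessing: RGB conversion plus a fixed resize
image = Image.open("sample.jpg").convert("RGB").resize((1600, 1200))
pixel_values = processor(image, return_tensors="pt").pixel_values
decoder_input_ids = processor.tokenizer(
    "<s_herbarium>", add_special_tokens=False, return_tensors="pt"
).input_ids

outputs = model.generate(
    pixel_values.to(device),
    decoder_input_ids=decoder_input_ids.to(device),
    max_length=model.decoder.config.max_position_embeddings,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    return_dict_in_generate=True,
)

# token2json turns the generated tag sequence into a nested dict
print(processor.token2json(processor.batch_decode(outputs.sequences)[0]))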