Jac-Zac committed on

Commit
78ed16f
·
1 Parent(s): ffe91b0

Create app.py


Commit the app file

Files changed (1)
  1. app.py +157 -0
app.py ADDED
@@ -0,0 +1,157 @@
+ #!/usr/bin/env python3
+ import streamlit as st
+ import torch
+ from PIL import ExifTags
+ from PIL import Image
+ from transformers import DonutProcessor
+ from transformers import VisionEncoderDecoderModel
+
+
+ def run_prediction(sample):
+     global pretrained_model, processor, task_prompt
+     if isinstance(sample, dict):
+         # prepare inputs
+         pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0)
+     else:  # sample is an image
+         # prepare encoder inputs (from the sample argument, not a global)
+         pixel_values = processor(sample, return_tensors="pt").pixel_values
+
+     decoder_input_ids = processor.tokenizer(
+         task_prompt, add_special_tokens=False, return_tensors="pt"
+     ).input_ids
+
+     # run inference
+     outputs = pretrained_model.generate(
+         pixel_values.to(device),
+         decoder_input_ids=decoder_input_ids.to(device),
+         max_length=pretrained_model.decoder.config.max_position_embeddings,
+         pad_token_id=processor.tokenizer.pad_token_id,
+         eos_token_id=processor.tokenizer.eos_token_id,
+         use_cache=True,
+         num_beams=1,
+         bad_words_ids=[[processor.tokenizer.unk_token_id]],
+         return_dict_in_generate=True,
+     )
+
+     # process output
+     prediction = processor.batch_decode(outputs.sequences)[0]
+
+     # post-processing
+     if "cord" in task_prompt:
+         prediction = prediction.replace(processor.tokenizer.eos_token, "").replace(
+             processor.tokenizer.pad_token, ""
+         )
+     # prediction = re.sub(r"<.*?>", "", prediction, count=1).strip()  # remove first task start token
+     prediction = processor.token2json(prediction)
+
+     # load reference target
+     if isinstance(sample, dict):
+         target = processor.token2json(sample["target_sequence"])
+     else:
+         target = "<not_provided>"
+
+     return prediction, target
+
+
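+ # Example (sketch): run_prediction can also be exercised outside Streamlit once
+ # the globals it reads (processor, pretrained_model, task_prompt, device) are
+ # set up as in the code below; "sample.jpg" here is a hypothetical local file:
+ #
+ #     sample = Image.open("sample.jpg").convert("RGB")
+ #     prediction, _ = run_prediction(sample)  # target is "<not_provided>" for plain images
+
+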
+ # Image preprocessing: fix the orientation if needed, then resize according to
+ # the model we use
+ def preprocess_image(image, size):
+     # Check for EXIF orientation metadata and rotate the image if necessary
+     # (before resizing, since resize() returns a new image without EXIF data)
+     for orientation in ExifTags.TAGS.keys():
+         if ExifTags.TAGS[orientation] == "Orientation":
+             if hasattr(image, "_getexif") and image._getexif() is not None:
+                 exif = dict(image._getexif().items())
+                 if exif.get(orientation) == 3:
+                     image = image.rotate(180, expand=True)
+                 elif exif.get(orientation) == 6:
+                     image = image.rotate(270, expand=True)
+                 elif exif.get(orientation) == 8:
+                     image = image.rotate(90, expand=True)
+             break
+
+     # Resize the image to the size expected by the selected model
+     image = image.resize(size)
+
+     return image
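+
+ # Alternative sketch: on Pillow >= 6.0, ImageOps.exif_transpose applies the
+ # EXIF orientation tag (and strips it) in a single call:
+ #
+ #     from PIL import ImageOps
+ #     image = ImageOps.exif_transpose(image).resize(size)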
+
+
+ # What this model does
+ task_prompt = "<s_herbarium>"
+ st.markdown(
+     """
+ ### Donut Herbarium Testing
+ Experimental OCR-free Document Understanding Vision Transformer, fine-tuned on an herbarium dataset of around 1400 images.
+ """
+ )
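+ # The task prompt follows Donut's convention of a task-specific start token
+ # (e.g. <s_cord-v2> for the receipt-parsing checkpoint); <s_herbarium> marks
+ # the herbarium fine-tune and is what the decoder is conditioned on in
+ # run_prediction above.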
+
+ with st.sidebar:
+     information = st.radio(
+         "Choose one predictor:",
+         ("Low Res (1200 * 900) 5 epochs", "Mid res (1600 * 1200) 10 epochs"),
+     )
+     image_choice = st.selectbox("Pick one 📑", ["1", "2", "3"], index=1)
+     uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
+
+ st.text(
+     f"{information} mode is ON!\nTarget 📑: {image_choice}"
+ )  # \n(opening image @:./img/receipt-{receipt}.png)')
+
+ col1, col2 = st.columns(2)
+
+ # Choose image
+ if uploaded_file is not None:
+     image = Image.open(uploaded_file)
+ else:
+     # No sample images ship with the app, so stop instead of crashing on an
+     # undefined image further down
+     st.warning("Upload an image to get started.")
+     st.stop()
+     # image_choice_map = {
+     #     '1': '../donut_example/copy/img_resized/test/00021.jpg',
+     #     '2': '../donut_example/copy/img_resized/test/00031.jpg',
+     #     '3': '../donut_example/copy/img_resized/test/00050.jpg',
+     # }
+     # image = Image.open(image_choice_map[image_choice])
+
+
+ if information == "Low Res (1200 * 900) 5 epochs":
+     image = preprocess_image(image, (1200, 900))
+ else:
+     image = preprocess_image(image, (1600, 1200))
+
+ with col1:
+     st.image(image, caption="Your target sample")
+
+ # Run the model
+ if st.button("Parse sample! 🐍"):
+     image = image.convert("RGB")
+
+     # Choose which version to run based on the selected predictor
+     with st.spinner("Running the model on the target..."):
+         if information == "Low Res (1200 * 900) 5 epochs":
+             processor = DonutProcessor.from_pretrained(
+                 "Jac-Zac/thesis_test_donut",
+                 revision="12900abc6fb551a0ea339950462a6a0462820b75",
+             )
+             pretrained_model = VisionEncoderDecoderModel.from_pretrained(
+                 "Jac-Zac/thesis_test_donut",
+                 revision="12900abc6fb551a0ea339950462a6a0462820b75",
+             )
+
+         elif information == "Mid res (1600 * 1200) 10 epochs":
+             processor = DonutProcessor.from_pretrained(
+                 "Jac-Zac/thesis_test_donut",
+                 revision="8c5467cb66685e801ec6ff8de7e7fdd247274ed0",
+             )
+             pretrained_model = VisionEncoderDecoderModel.from_pretrained(
+                 "Jac-Zac/thesis_test_donut",
+                 revision="8c5467cb66685e801ec6ff8de7e7fdd247274ed0",
+             )
+
+         # this is the same for both models
+         task_prompt = "<s_herbarium>"
+         device = "cuda" if torch.cuda.is_available() else "cpu"
+         pretrained_model.to(device)
+
+         with col2:
+             st.info("Parsing 📑...")
+             parsed_info, _ = run_prediction(image)
+             st.text(f"\n{information}")
+             st.json(parsed_info)
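+
+ # Usage sketch (assuming streamlit, torch, transformers, and pillow are
+ # installed): launch locally with `streamlit run app.py`. Note that the
+ # checkpoints are loaded from the Hub at pinned revisions inside the button
+ # handler, i.e. on every click; caching them (e.g. with st.cache_resource)
+ # would be a natural follow-up.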