#!/usr/bin/env python3
import streamlit as st
import torch
import os
from PIL import Image, ImageOps
from transformers import DonutProcessor, VisionEncoderDecoderModel
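
# Run Donut on a single sample (a preprocessed dataset dict or a raw PIL image)
# and return the decoded JSON prediction plus the reference target, if one exists.
# "pretrained_model", "processor", "task_prompt" and "device" are set further down,
# once a predictor has been selected in the Streamlit sidebar.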
def run_prediction(sample):
    global pretrained_model, processor, task_prompt

    if isinstance(sample, dict):
        # prepare inputs: a dataset sample already carries preprocessed pixel values
        pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0)
    else:  # sample is a PIL image
        # prepare encoder inputs from the raw image
        pixel_values = processor(sample, return_tensors="pt").pixel_values

    # the task prompt is the decoder's start sequence
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    # run inference: greedy decoding (num_beams=1), with the unknown token banned
    outputs = pretrained_model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=pretrained_model.decoder.config.max_position_embeddings,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    # process output
    prediction = processor.batch_decode(outputs.sequences)[0]

    # post-processing: strip the special tokens before converting the sequence to JSON
    prediction = prediction.replace(processor.tokenizer.eos_token, "").replace(
        processor.tokenizer.pad_token, ""
    )
    # prediction = re.sub(r"<.*?>", "", prediction, count=1).strip()  # remove first task start token
    prediction = processor.token2json(prediction)
    # load the reference target when the sample comes from the dataset
    if isinstance(sample, dict):
        target = processor.token2json(sample["target_sequence"])
    else:
        target = "<not_provided>"

    return prediction, target
# Fix the image orientation from its EXIF metadata (if needed) and resize it to
# the input size expected by the selected model
def preprocess_image(image, size):
    # Rotate first, so the resize is applied to the upright image
    image = ImageOps.exif_transpose(image)
    # Resize to the given (width, height)
    image = image.resize(size)
    return image
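
# e.g. preprocess_image(Image.open("examples/00021.jpg"), (1200, 900)) returns an
# upright 1200x900 image ready for the low-res checkpoint below.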
# Task prompt: the start token that tells the decoder which fine-tuned task to run
task_prompt = "<s_herbarium>"
st.markdown(
    """
### Donut Herbarium Testing
Experimental OCR-free Document Understanding Vision Transformer, fine-tuned on an herbarium dataset of around 1,400 images.
"""
)
with st.sidebar:
    information = st.radio(
        "Choose one predictor:",
        (
            "Low Res (1200 * 900) 5 epochs",
            "Mid res (1600 * 1200) 10 epochs",
            "Mid res (1600 * 1200) 14 epochs",
            "Mid res new 0 epoch",
        ),
    )
    image_choice = st.selectbox("Pick one π", ["1", "2", "3", "4"], index=0)
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    st.text(f"{information} mode is ON!\nTarget π: {image_choice}")
col1, col2 = st.columns(2)
# Choose the image: an uploaded file, or one of the bundled examples
if uploaded_file is not None:
    image = Image.open(uploaded_file)
    # PIL sizes are (width, height); the mid-res checkpoints take portrait 1200x1600
    if information == "Low Res (1200 * 900) 5 epochs":
        image = preprocess_image(image, (1200, 900))
    else:
        image = preprocess_image(image, (1200, 1600))
else:
    image_choice_map = {
        "1": "examples/00021.jpg",
        "2": "examples/00031.jpg",
        "3": "examples/00050.jpg",
        "4": "examples/zero_name.jpg",
    }
    image = Image.open(image_choice_map[image_choice])
with col1:
    st.image(image, caption="Your target sample")
# Run the model
if st.button("Parse sample! π"):
    image = image.convert("RGB")

    # Load the checkpoint that matches the predictor selected in the sidebar
    with st.spinner("Running the model on the target..."):
        # Each predictor option maps to a (repo, revision) checkpoint on the Hub
        checkpoints = {
            "Low Res (1200 * 900) 5 epochs": (
                "Jac-Zac/thesis_test_donut",
                "12900abc6fb551a0ea339950462a6a0462820b75",
            ),
            "Mid res (1600 * 1200) 10 epochs": (
                "Jac-Zac/thesis_test_donut",
                "8c5467cb66685e801ec6ff8de7e7fdd247274ed0",
            ),
            "Mid res (1600 * 1200) 14 epochs": (
                "Jac-Zac/thesis_test_donut",
                "ba396d4b3d39a4eaf7c8d4919b384ebcf6f0360f",
            ),
            # revision "4d64fa9a156908aa3df0e0e39463d401528a15c9" is disabled: track the latest commit
            "Mid res new 0 epoch": ("Jac-Zac/thesis_donut", None),
        }
        repo_id, revision = checkpoints[information]
        processor = DonutProcessor.from_pretrained(
            repo_id, revision=revision, use_auth_token=os.environ["TOKEN"]
        )
        pretrained_model = VisionEncoderDecoderModel.from_pretrained(
            repo_id, revision=revision, use_auth_token=os.environ["TOKEN"]
        )
        # This is the same for every checkpoint
        task_prompt = "<s_herbarium>"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        pretrained_model.to(device)
    with col2:
        st.info("Parsing π...")
        parsed_info, _ = run_prediction(image)
        st.text(f"\n{information}")
        st.json(parsed_info)
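
# To try the app locally (a sketch; assumes this file is saved as app.py and that
# TOKEN holds a Hugging Face token with read access to the checkpoints):
#   TOKEN=hf_... streamlit run app.py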