File size: 5,487 Bytes
6581de9
5279e45
6581de9
 
cd76fe5
43a5321
6581de9
 
43a5321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ebb030c
43a5321
 
 
 
 
 
 
 
 
 
6581de9
 
 
1d18244
d063ab4
fde368a
6527a47
fde368a
5534f9c
1d18244
5534f9c
fde368a
4bf6412
 
a605f4c
4bf6412
 
 
cd76fe5
9ccb695
4bf6412
f9ce3f3
d6a46a2
 
 
 
cd76fe5
f9ce3f3
 
d6a46a2
1f662e3
f7fe7ff
1dc8f91
 
 
 
d6a46a2
 
 
 
df73b43
c91f43f
ef55841
c91f43f
 
 
 
 
 
ef55841
c91f43f
 
 
 
 
 
 
 
 
 
 
 
 
 
df73b43
1dc8f91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import torch
import streamlit as st

from PIL import Image
from io import BytesIO
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig , DonutProcessor


def run_prediction(sample):
    """Run Donut inference on one sample and return the parsed result.

    Relies on the module-level ``pretrained_model``, ``processor``,
    ``task_prompt`` and ``device`` being set before the call.

    Args:
        sample: Either a dict carrying precomputed ``"pixel_values"`` and a
            ``"target_sequence"`` string, or an image that ``processor`` can
            preprocess (this app passes a PIL image).

    Returns:
        A ``(prediction, target)`` tuple. ``prediction`` is the result of
        ``processor.token2json`` applied to the generated sequence. ``target``
        is the parsed ``"target_sequence"`` for dict samples and the string
        ``"<not_provided>"`` otherwise.
    """
    global pretrained_model, processor, task_prompt, device
    has_reference = isinstance(sample, dict)

    if has_reference:
        # Precomputed pixel values: just add the batch dimension.
        pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0)
    else:
        # Preprocess the image that was actually passed in, rather than
        # whatever the module-level ``image`` happens to be.
        pixel_values = processor(sample, return_tensors="pt").pixel_values

    # The decoder is primed with the task prompt of the selected checkpoint.
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # run inference
    outputs = pretrained_model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=pretrained_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Decode the single generated sequence back to text.
    prediction = processor.batch_decode(outputs.sequences)[0]

    # CORD prompts: strip EOS and padding tokens before parsing.
    if "cord" in task_prompt:
        prediction = prediction.replace(processor.tokenizer.eos_token, "").replace(
            processor.tokenizer.pad_token, ""
        )
    prediction = processor.token2json(prediction)

    # A reference target only exists for dict samples.
    if has_reference:
        target = processor.token2json(sample["target_sequence"])
    else:
        target = "<not_provided>"

    return prediction, target
    

# Default decoder prompt; the mode selection further down may replace it.
task_prompt = "<s>"

# Page header: the logo, then a short introduction to the demo.
logo = Image.open("./img/unstructured_logo_header.png")
st.image(logo)

st.markdown(
    """
This is an OCR-free Document Understanding Transformer nicknamed 🍩. It was fine-tuned with 1000 receipt images -> SROIE dataset.
The original 🍩 implementation can be found on [here](https://github.com/clovaai/donut).

At [Unstructured.io](https://github.com/Unstructured-IO/unstructured) we are on a mission to build custom preprocessing pipelines for labeling, training, or production ML-ready pipelines 🤩. 
Come and join us in our public repos and contribute! Each of your contributions and feedback holds great value and is very significant to the community 😊.
"""
)

# Holds the user's own receipt as a PIL image once one has been uploaded.
image_upload = None
with st.sidebar:
    # Extraction mode. The option labels are compared against later in the
    # script, so they must stay exactly as written here.
    information = st.radio(
        "What information inside the receipt are you interested in?",
        ('Receipt Summary', 'Receipt Menu Details', 'Extract all'),
    )
    # Which bundled sample receipt to use; index=1 preselects receipt '2'.
    receipt = st.selectbox('Pick one 🧾', ['1', '2', '3', '4', '5', '6'], index=1)

    # Optional upload of the user's own receipt.
    uploaded_file = st.file_uploader("Upload a 🧾")
    if uploaded_file is not None:
        # Decode the uploaded bytes into a PIL image.
        image_bytes_data = uploaded_file.getvalue()
        image_upload = Image.open(BytesIO(image_bytes_data))

# Echo the current selection back to the user.
st.text(f'{information} mode is ON!\nTarget 🧾: {receipt}')

# `image` is what gets parsed further down. Prefer the user's upload so the
# receipt shown as "target" is also the one that is analysed; otherwise fall
# back to the bundled sample picked in the sidebar.
if image_upload is not None:
    # Uploads may be RGBA, palette or grayscale; convert to RGB.
    # NOTE(review): assumes the Donut processor expects 3-channel input — confirm.
    image = image_upload.convert("RGB")
else:
    image = Image.open(f"./img/receipt-{receipt}.jpg")

col1, col2 = st.columns(2)

# Left column: preview of the receipt being processed.
with col1:
    preview = image_upload if image_upload is not None else image
    st.image(preview, caption='Your target receipt')

st.text('baking the 🍩s...')

# Hub checkpoints used by the modes offered in the sidebar.
sroie_checkpoint = "unstructuredio/donut-base-sroie"
cord_checkpoint = "naver-clova-ix/donut-base-finetuned-cord-v2"

# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

if information in ('Receipt Summary', 'Receipt Menu Details'):
    # Single-model modes: choose the checkpoint and its task prompt, then
    # load the processor and model and move the model to the device.
    if information == 'Receipt Summary':
        checkpoint, task_prompt = sroie_checkpoint, "<s>"
    else:
        checkpoint, task_prompt = cord_checkpoint, "<s_cord-v2>"
    processor = DonutProcessor.from_pretrained(checkpoint)
    pretrained_model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    pretrained_model.to(device)
else:
    # Any other mode: load both checkpoints. They are not moved to the
    # device in this block.
    processor_a = DonutProcessor.from_pretrained(sroie_checkpoint)
    processor_b = DonutProcessor.from_pretrained(cord_checkpoint)
    pretrained_model_a = VisionEncoderDecoderModel.from_pretrained(sroie_checkpoint)
    pretrained_model_b = VisionEncoderDecoderModel.from_pretrained(cord_checkpoint)

# Right column: run the model(s) and show the parsed output.
with col2:
    # The two-model path must be taken whenever the loading step prepared the
    # `_a` / `_b` pairs, i.e. for every mode other than the two single-model
    # ones. The sidebar option is 'Extract all' (no exclamation mark), so
    # comparing against 'Extract all!' never matched and this mode crashed
    # on undefined `processor` / `pretrained_model`.
    if information not in ('Receipt Summary', 'Receipt Menu Details'):
        st.text('parsing 🧾 (extracting all)...')

        # First pass: SROIE checkpoint (receipt summary).
        pretrained_model, processor, task_prompt = pretrained_model_a, processor_a, "<s>"
        pretrained_model.to(device)
        parsed_receipt_info_a, _ = run_prediction(image)

        # Second pass: CORD checkpoint (menu details).
        pretrained_model, processor, task_prompt = pretrained_model_b, processor_b, "<s_cord-v2>"
        pretrained_model.to(device)
        parsed_receipt_info_b, _ = run_prediction(image)

        st.text('\nReceipt Summary:')
        st.json(parsed_receipt_info_a)
        st.text('\nReceipt Menu Details:')
        st.json(parsed_receipt_info_b)
    else:
        # Single-model modes: the model, processor and prompt were already
        # set up by the loading step.
        st.text('parsing 🧾...')
        parsed_receipt_info, _ = run_prediction(image)
        st.text(f'\n{information}')
        st.json(parsed_receipt_info)