receipt-parser / app.py
laverdes's picture
feat: new flow and new Unstructured receipt parser
2294783
raw
history blame
6.65 kB
import torch
import streamlit as st
from PIL import Image
from io import BytesIO
from transformers import VisionEncoderDecoderModel, VisionEncoderDecoderConfig , DonutProcessor
def run_prediction(sample):
    """Run Donut inference on one sample and return the parsed result.

    The function reads the module-level ``pretrained_model``, ``processor``,
    ``task_prompt`` and ``device``; the caller must set them beforehand.

    Args:
        sample: Either a dict holding ``"pixel_values"`` (and
            ``"target_sequence"`` for the reference), or an image that
            ``processor`` can encode.

    Returns:
        A ``(prediction, target)`` tuple. ``prediction`` is the output of
        ``processor.token2json`` on the decoded sequence. ``target`` is
        ``processor.token2json(sample["target_sequence"])`` for dict samples
        and the string ``"<not_provided>"`` otherwise.
    """
    if isinstance(sample, dict):
        # Pre-computed pixel values: just add the batch dimension.
        pixel_values = torch.tensor(sample["pixel_values"]).unsqueeze(0)
    else:
        # Encode the image that was actually passed in. The previous code read
        # the module-level ``image`` here and silently ignored ``sample``.
        if isinstance(sample, Image.Image):
            # Uploaded PNGs may be RGBA / palette images; normalise the
            # channel layout before handing the image to the processor.
            sample = sample.convert("RGB")
        pixel_values = processor(sample, return_tensors="pt").pixel_values

    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids

    # Greedy decoding (num_beams=1), with the <unk> token banned from output.
    outputs = pretrained_model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=pretrained_model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    prediction = processor.batch_decode(outputs.sequences)[0]

    # For CORD prompts, strip EOS / PAD markers before converting to JSON.
    # The leading task-start token is left in place (that step is disabled).
    if "cord" in task_prompt:
        prediction = prediction.replace(processor.tokenizer.eos_token, "")
        prediction = prediction.replace(processor.tokenizer.pad_token, "")
    prediction = processor.token2json(prediction)

    # A reference target only exists for dict (dataset-style) samples.
    if isinstance(sample, dict):
        target = processor.token2json(sample["target_sequence"])
    else:
        target = "<not_provided>"
    return prediction, target
# Default decoder prompt; it is reassigned when a parsing mode is chosen.
task_prompt = "<s>"

# Page banner.
logo = Image.open("./img/rsz_unstructured_logo.png")
st.image(logo)

# Introductory text rendered at the top of the page.
_INTRO_MARKDOWN = '''
### Receipt Parser
This is an OCR-free Document Understanding Transformer nicknamed 🍩. It was fine-tuned with 1000 receipt images -> SROIE dataset.
The original 🍩 implementation can be found on [here](https://github.com/clovaai/donut).
At [Unstructured.io](https://github.com/Unstructured-IO/unstructured) we are on a mission to build custom preprocessing pipelines for labeling, training, or production ML-ready pipelines 🤩.
Come and join us in our public repos and contribute! Each of your contributions and feedback holds great value and is very significant to the community 😊.
'''
st.markdown(_INTRO_MARKDOWN)
# Image sources chosen in the sidebar; both stay None when unused.
image_upload = None
photo = None
with st.sidebar:
    # Extraction mode and the bundled sample receipt to fall back on.
    information = st.radio(
        "What information inside the 🧾s are you interested in extracting?",
        ('Receipt Summary', 'Receipt Menu Details', 'Extract all', 'Unstructured.io Parser'))
    receipt = st.selectbox('Pick one 🧾', ['1', '2', '3', '4', '5', '6'], index=1)

    # Optional file upload: decode the raw bytes into a PIL image.
    uploaded_file = st.file_uploader("Upload a 🧾")
    if uploaded_file is not None:
        image_bytes_data = uploaded_file.getvalue()
        image_upload = Image.open(BytesIO(image_bytes_data))

    # st.button is True only on the rerun triggered by the click. Taking a
    # picture triggers another rerun, on which the button is False again, so
    # the camera widget (and its photo) used to disappear before it could be
    # read. Remember the choice in session state so the widget persists.
    camera_click = st.button('Use my camera')
    if camera_click:
        st.session_state["camera_enabled"] = True

    img_file_buffer = None
    if st.session_state.get("camera_enabled", False):
        img_file_buffer = st.camera_input("Take a picture of your receipt!")

    if img_file_buffer:
        # Read the camera buffer as a PIL image.
        photo = Image.open(img_file_buffer)
        st.info("picture taken!")
# Echo the current selection back to the user.
st.text(f'{information} mode is ON!\nTarget 🧾: {receipt}')

# Left column shows the receipt, right column is used for the parsed output.
col1, col2 = st.columns(2)

# Source priority: camera photo, then uploaded file, then bundled sample.
if photo:
    image = photo
    st.info("photo loaded to image")
else:
    image = image_upload if image_upload else Image.open(f"./img/receipt-{receipt}.jpg")

col1.image(image, caption='Your target receipt')
# Checkpoint name and decoder task prompt for every single-model mode.
_SINGLE_MODEL_MODES = {
    'Receipt Summary': ("unstructuredio/donut-base-sroie", "<s>"),
    'Receipt Menu Details': ("naver-clova-ix/donut-base-finetuned-cord-v2", "<s_cord-v2>"),
    'Unstructured.io Parser': ("unstructuredio/donut-base-labelstudio-A1.0", "<s>"),
}

if st.button('Parse receipt! 🐍'):
    # NOTE(review): the checkpoints are loaded again on every click.
    with st.spinner('baking the 🍩s...'):
        _mode = _SINGLE_MODEL_MODES.get(information)
        if _mode is not None:
            # One model: set the module-level names run_prediction reads.
            _checkpoint, task_prompt = _mode
            processor = DonutProcessor.from_pretrained(_checkpoint)
            pretrained_model = VisionEncoderDecoderModel.from_pretrained(_checkpoint)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            pretrained_model.to(device)
        else:  # Extract all: load the SROIE and CORD checkpoints up front
            processor_a = DonutProcessor.from_pretrained("unstructuredio/donut-base-sroie")
            processor_b = DonutProcessor.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
            pretrained_model_a = VisionEncoderDecoderModel.from_pretrained("unstructuredio/donut-base-sroie")
            pretrained_model_b = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base-finetuned-cord-v2")
            device = "cuda" if torch.cuda.is_available() else "cpu"

        with col2:
            if information == 'Extract all':
                st.info('parsing 🧾 (extracting all)...')
                _extracted = []
                for _model, _processor, _prompt in (
                    (pretrained_model_a, processor_a, "<s>"),
                    (pretrained_model_b, processor_b, "<s_cord-v2>"),
                ):
                    # Rebind the module-level names before each run, since
                    # run_prediction picks up model/processor/prompt from them.
                    pretrained_model, processor, task_prompt = _model, _processor, _prompt
                    pretrained_model.to(device)
                    _parsed, _ = run_prediction(image)
                    _extracted.append(_parsed)
                parsed_receipt_info_a, parsed_receipt_info_b = _extracted

                st.text('\nReceipt Summary:')
                st.json(parsed_receipt_info_a)
                st.text('\nReceipt Menu Details:')
                st.json(parsed_receipt_info_b)
            else:
                st.info('parsing 🧾...')
                parsed_receipt_info, _ = run_prediction(image)
                st.text(f'\n{information}')
                st.json(parsed_receipt_info)