|
from textwrap import wrap |
|
|
|
from transformers import pipeline |
|
import nlpaug.augmenter.char as nac |
|
import streamlit as st |
|
|
|
# Page title, one-line description, and usage instructions shown at the top
# of the app. Button names in the instructions match the form button labels
# ('Scramble' / 'Unscramble') defined further down.
st.markdown('# ByT5 Dutch OCR Corrector :pill:')

st.write('This app corrects common Dutch OCR mistakes, to showcase how this could be used in an OCR post-processing pipeline.')

st.markdown("""

To use this:

- Enter a text with OCR mistakes and hit 'Unscramble' :point_down:

- Or enter a normal text, scramble it :twisted_rightwards_arrows: and then hit 'Unscramble' :point_down:""")
|
|
|
@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
def load_model():
    """Create the Hugging Face text2text pipeline for Dutch OCR correction.

    The ``ml6team/byt5-base-dutch-ocr-correction`` checkpoint supplies both
    the model and the tokenizer. The function is wrapped in ``st.cache`` so
    Streamlit can reuse the built pipeline across script reruns. Because the
    cache's own spinner is disabled (``show_spinner=False``), a custom
    spinner message is shown while the pipeline is being built.

    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    favour of ``st.cache_resource``; confirm against the pinned version.

    Returns:
        The ``transformers`` pipeline object.
    """
    checkpoint = 'ml6team/byt5-base-dutch-ocr-correction'
    with st.spinner('Please wait for the model to load...'):
        return pipeline('text2text-generation', model=checkpoint, tokenizer=checkpoint)
|
|
|
# Shared correction pipeline (built by the cached loader above).
ocr_pipeline = load_model()

# `text` is the value pre-filled into the input text area; it starts out
# empty on a fresh session and is replaced by the scramble action.
if "text" not in st.session_state:
    st.session_state.text = ""
|
|
|
# Two side-by-side columns: the input form on the left, the corrected
# output on the right.
left_area, right_area = st.columns(2)

left_area.header("Input")
form = left_area.form(key='ocrcorrector')

# Reserve a single-element slot for the text area so the scramble branch
# can later re-render it in place with the scrambled text. The slot is
# freshly created (already empty), so no extra `.empty()` call is needed.
placeholder = form.empty()
input_text = placeholder.text_area(value=st.session_state.text, label='Insert text:', key='input_text')

# Form submit buttons; exactly the clicked one is True on the resulting run.
scramble_button = form.form_submit_button(label='Scramble')
submit_button = form.form_submit_button(label='Unscramble')

right_area.header("Output")
|
|
|
if scramble_button:
    # Inject synthetic OCR errors into the current input with nlpaug's OCR
    # augmenter and show the scrambled text back in the input box.
    aug = nac.OcrAug()
    base_text = st.session_state.input_text

    # Nothing to scramble for empty input; don't call the augmenter on "".
    augmented_data = aug.augment(base_text) if base_text else base_text

    # nlpaug >= 1.1.11 returns a list of augmented strings from `augment`,
    # older releases return a plain str. The text area needs a str.
    if isinstance(augmented_data, list):
        augmented_data = augmented_data[0] if augmented_data else ""
    st.session_state.text = augmented_data

    # Drop the widget's stored value before re-rendering it, presumably so the
    # new text area takes `value=` (the scrambled text) rather than the
    # previously typed input.
    del st.session_state.input_text
    placeholder.empty()
    input_text = placeholder.text_area(value=st.session_state.text, label='Insert text:', key='input_text')
|
|
|
if submit_button:
    base_text = st.session_state.input_text

    # Correct each input line separately so the output keeps the user's line
    # breaks: `wrap` on the whole text would turn every newline into a space.
    # Each line is further split into chunks of at most 128 characters.
    line_chunks = [wrap(line, 128) for line in base_text.splitlines()]
    all_chunks = [chunk for chunks in line_chunks for chunk in chunks]

    # One batched pipeline call for every chunk; skip the model entirely when
    # the input is blank. Results come back in input order.
    corrected = iter(ocr_pipeline(all_chunks) if all_chunks else [])

    # Re-assemble: chunks of one line are joined by spaces (as before), lines
    # by newlines. Whitespace-only lines produce no chunks and stay empty.
    output_lines = [
        " ".join(next(corrected)['generated_text'] for _ in chunks)
        for chunks in line_chunks
    ]
    output_text = "\n".join(output_lines)

    right_area.markdown('#####')
    right_area.text_area(value=output_text, label="Corrected text:")
|
|
|
|