Thomas Dehaene
Update streamlit
4230a5d
raw
history blame
2.45 kB
from textwrap import wrap
from transformers import pipeline
import nlpaug.augmenter.char as nac
import subprocess
import sys
import logging
import importlib
import streamlit
def install():
    """Install the pinned Streamlit release with pip.

    Runs ``pip install streamlit==0.89.0`` through the interpreter that is
    executing this script; ``check_call`` raises ``CalledProcessError`` if pip
    exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", "streamlit==0.89.0"]
    subprocess.check_call(pip_command)
# Install the pinned Streamlit release, then reload the module object that was
# already imported above.
install()
importlib.reload(streamlit)
# Log the version of the reloaded module. The package is imported under the
# name `streamlit`, so the version must be read from that name.
logging.warning(streamlit.__version__)

# Page title, one-line description and usage instructions.
streamlit.markdown('# ByT5 Dutch OCR Corrector :pill:')
streamlit.write('This app corrects common dutch OCR mistakes, to showcase how this could be used in an OCR post-processing pipeline.')
streamlit.markdown("""
To use this:
- Enter a text with OCR mistakes and hit 'unscramble':point_down:
- Or enter a normal text, scramble it :twisted_rightwards_arrows: and then hit 'unscramble' :point_down:""")
@streamlit.cache(
    allow_output_mutation=True,
    suppress_st_warning=True,
    show_spinner=False,
)
def load_model():
    """Build the text2text-generation pipeline used for OCR correction.

    The same Hub identifier is used for both the model and the tokenizer. A
    spinner with a wait message is shown while the pipeline is constructed.

    Returns:
        The object returned by ``transformers.pipeline``.
    """
    model_id = 'ml6team/byt5-base-dutch-ocr-correction'
    with streamlit.spinner('Please wait for the model to load...'):
        return pipeline(
            'text2text-generation',
            model=model_id,
            tokenizer=model_id,
        )
# Build (or fetch from the cache) the correction pipeline.
ocr_pipeline = load_model()

# `text` holds the value shown in the input box; start with an empty string
# the first time the key is missing from the session state.
if 'text' not in streamlit.session_state:
    streamlit.session_state.text = ""

left_area, right_area = streamlit.beta_columns(2)

# Format the left area
left_area.header("Input")
form = left_area.form(key='ocrcorrector')
# The text area lives in a placeholder so it can be replaced after scrambling.
placeholder = form.empty()
placeholder.empty()
# The module is imported as `streamlit`, so session state is read via that name.
input_text = placeholder.text_area(
    value=streamlit.session_state.text,
    label='Insert text:',
    key='input_text',
)
scramble_button = form.form_submit_button(label='Scramble')
submit_button = form.form_submit_button(label='Unscramble')

# Right area
right_area.header("Output")

if scramble_button:
    # Apply simulated OCR character errors to the text currently in the box.
    aug = nac.OcrAug()
    streamlit.session_state.text = streamlit.session_state.input_text
    base_text = streamlit.session_state.text
    augmented_data = aug.augment(base_text)
    # NOTE(review): some nlpaug releases appear to return a list of strings
    # from augment() instead of a single string - confirm against the
    # installed version. Normalise to a string for the text area either way.
    if isinstance(augmented_data, (list, tuple)):
        augmented_data = " ".join(augmented_data)
    streamlit.session_state.text = augmented_data
    # Drop the widget's stored value and re-create the text area in the same
    # placeholder so it is rendered with the scrambled text.
    del streamlit.session_state.input_text
    placeholder.empty()
    input_text = placeholder.text_area(
        value=streamlit.session_state.text,
        label='Insert text:',
        key='input_text',
    )

if submit_button:
    base_text = streamlit.session_state.input_text
    # The text is split into pieces of at most 128 characters before being
    # passed to the pipeline; the generated pieces are joined with spaces.
    chunks = wrap(base_text, 128)
    if chunks:
        output_text = " ".join(
            [x['generated_text'] for x in ocr_pipeline(chunks)]
        )
    else:
        # wrap() yields no chunks for empty or whitespace-only input, so there
        # is nothing to send to the model.
        output_text = ""
    right_area.markdown('#####')
    right_area.text_area(value=output_text, label="Corrected text:")