Spaces:
Sleeping
Sleeping
File size: 2,108 Bytes
d5101c4 032c174 d5101c4 032c174 d5101c4 032c174 d5101c4 34eb802 d5101c4 34eb802 d5101c4 34eb802 032c174 d5101c4 032c174 d5101c4 34eb802 d5101c4 34eb802 d5101c4 34eb802 d5101c4 34eb802 d5101c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from textwrap import wrap
from transformers import pipeline
import nlpaug.augmenter.char as nac
import streamlit as st
# Page title and short usage instructions shown above the input/output columns.
# The instructions quote the exact button labels ('Scramble' / 'Unscramble')
# defined on the form further down so users can find them.
st.markdown('# ByT5 Dutch OCR Corrector :pill:')
st.write('This app corrects common Dutch OCR mistakes, to showcase how this could be used in an OCR post-processing pipeline.')
st.markdown("""
To use this:
- Enter a text with OCR mistakes and hit 'Unscramble' :point_down:
- Or enter a normal text, hit 'Scramble' :twisted_rightwards_arrows: and then hit 'Unscramble' :point_down:""")
@st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
def load_model():
    """Create the ByT5 Dutch OCR-correction pipeline behind a custom spinner.

    The function is wrapped in ``st.cache`` with these options:

    * ``allow_output_mutation=True``: st.cache does not check the returned
      pipeline object for mutation between reruns.
    * ``suppress_st_warning=True``: silences the warning about calling
      ``st.spinner`` inside a cached function.
    * ``show_spinner=False``: hides the default cache spinner, so only the
      message below is displayed.

    Returns:
        A Hugging Face ``text2text-generation`` pipeline whose model and
        tokenizer both come from the ``ml6team/byt5-base-dutch-ocr-correction``
        checkpoint.
    """
    checkpoint = 'ml6team/byt5-base-dutch-ocr-correction'
    with st.spinner('Please wait for the model to load...'):
        corrector = pipeline(
            'text2text-generation',
            model=checkpoint,
            tokenizer=checkpoint,
        )
    return corrector
# Build the correction pipeline (st.cache-backed, see load_model).
ocr_pipeline = load_model()
# 'text' holds the text to display in the input box across reruns; it is
# seeded empty on the first run of a session and overwritten by the
# scramble handler.
if 'text' not in st.session_state:
    st.session_state['text'] = ""
# Two side-by-side columns: input form on the left, corrected output on the right.
left_area, right_area = st.columns(2)
# Format the left area
left_area.header("Input")
# Widgets inside a Streamlit form are submitted together via the form's
# submit buttons below.
form = left_area.form(key='ocrcorrector')
# Single-element slot holding the text area, so the scramble handler can
# clear it and put a new text area (with the scrambled text) in its place.
placeholder = form.empty()
# NOTE(review): this clears a placeholder that was created on the previous
# line and holds nothing yet; it looks redundant.
placeholder.empty()
# key='input_text' exposes the submitted text as st.session_state.input_text,
# which both button handlers read; the returned value bound here is not read
# elsewhere in the visible code.
input_text = placeholder.text_area(value=st.session_state.text, label='Insert text:', key='input_text')
scramble_button = form.form_submit_button(label='Scramble')
submit_button = form.form_submit_button(label='Unscramble')
# Right area
right_area.header("Output")
# "Scramble": inject synthetic OCR-style character errors into the submitted
# text and show the result in the input box so it can then be unscrambled.
if scramble_button:
    aug = nac.OcrAug()
    # Persist the submitted text before replacing it with its scrambled form.
    st.session_state.text = st.session_state.input_text
    base_text = st.session_state.text
    if base_text.strip():
        augmented_data = aug.augment(base_text)
    else:
        # Nothing to scramble; avoid calling the augmenter on blank input.
        augmented_data = base_text
    # nlpaug >= 1.1.11 returns a list of augmentations even for a single
    # string (older releases return a str). text_area needs a plain string,
    # so keep the first augmentation, or the original text if none came back.
    if isinstance(augmented_data, (list, tuple)):
        augmented_data = augmented_data[0] if augmented_data else base_text
    st.session_state.text = augmented_data
    # Drop the widget's stored value; presumably so the re-created text area
    # below shows `value=` (the scrambled text) rather than the old submission.
    del st.session_state.input_text
    placeholder.empty()
    input_text = placeholder.text_area(value=st.session_state.text, label='Insert text:', key='input_text')
# "Unscramble": run the submitted text through the OCR-correction pipeline
# and show the corrected text in the right column.
if submit_button:
    base_text = st.session_state.input_text
    # textwrap.wrap splits on whitespace into chunks of at most 128
    # characters (it also turns newlines/tabs into spaces), presumably to keep
    # each model input within the model's length limit. Note the limit is in
    # characters, not bytes. The corrected chunks are re-joined with single
    # spaces.
    chunks = wrap(base_text, 128)
    if chunks:
        output_text = " ".join(x['generated_text'] for x in ocr_pipeline(chunks))
    else:
        # Blank input produces no chunks; don't send an empty batch to the
        # pipeline (some transformers versions fail on it).
        output_text = ""
    # Empty level-5 heading; presumably a vertical spacer for alignment with
    # the input column.
    right_area.markdown('#####')
    right_area.text_area(value=output_text, label="Corrected text:")
|