File size: 2,108 Bytes
d5101c4
 
 
 
032c174
d5101c4
032c174
 
d5101c4
032c174
d5101c4
 
 
 
34eb802
d5101c4
 
 
34eb802
d5101c4
 
 
 
 
 
 
 
 
34eb802
032c174
d5101c4
032c174
d5101c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34eb802
 
d5101c4
34eb802
 
d5101c4
34eb802
d5101c4
 
34eb802
d5101c4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from textwrap import wrap 

from transformers import pipeline
import nlpaug.augmenter.char as nac
import streamlit as st

# Page title followed by a one-sentence description of the demo.
st.markdown('# ByT5 Dutch OCR Corrector :pill:')
st.write(
    'This app corrects common dutch OCR mistakes, '
    'to showcase how this could be used in an OCR post-processing pipeline.'
)

# Usage instructions, rendered as a two-item markdown bullet list.
st.markdown(
    "\n"
    "To use this:\n"
    "- Enter a text with OCR mistakes and hit 'unscramble':point_down:\n"
    "- Or enter a normal text, scramble it :twisted_rightwards_arrows: "
    "and then hit 'unscramble' :point_down:"
)

@st.cache(
    show_spinner=False,
    suppress_st_warning=True,
    allow_output_mutation=True,
)
def load_model():
    """Build the Dutch OCR-correction text2text-generation pipeline.

    The function is wrapped in ``st.cache``. The decorator's own spinner is
    switched off (``show_spinner=False``) and a custom spinner message is
    shown instead while the pipeline is being constructed.

    Returns:
        The object returned by ``transformers.pipeline`` for the
        'text2text-generation' task, using the same checkpoint name for
        both the model and the tokenizer.
    """
    checkpoint = 'ml6team/byt5-base-dutch-ocr-correction'
    with st.spinner('Please wait for the model to load...'):
        return pipeline(
            'text2text-generation',
            model=checkpoint,
            tokenizer=checkpoint,
        )

# Obtain the correction pipeline; it is called by the 'Unscramble' handler
# further down in this script.
ocr_pipeline = load_model()

# 'text' holds the value shown in the input box. Initialise it to an empty
# string when the key is not yet present in the session state.
if 'text' not in st.session_state:
    st.session_state['text'] = ""

# Two columns: the left one for input, the right one for output.
left_area, right_area = st.columns(2)

# Left area: a form with the input text box and two submit buttons.
left_area.header("Input")
form = left_area.form(key='ocrcorrector')
# The text area is created through a placeholder so that the 'Scramble'
# handler below can replace it with a new text area later in the same run.
placeholder = form.empty()
placeholder.empty()
input_text = placeholder.text_area(value=st.session_state.text, label='Insert text:', key='input_text')
# Both buttons submit the same form; the handlers below check which of the
# two flags is set.
scramble_button = form.form_submit_button(label='Scramble')
submit_button = form.form_submit_button(label='Unscramble')

# Right area: only the header is drawn here; the corrected text is added by
# the 'Unscramble' handler.
right_area.header("Output")

if scramble_button:
    # 'Scramble' was pressed: inject simulated OCR errors into the submitted
    # text and show the scrambled version in the input box.
    aug = nac.OcrAug()
    # Store the submitted text first, so 'text' still holds the user's input
    # if the augmentation below raises.
    st.session_state.text = st.session_state.input_text
    base_text = st.session_state.text
    augmented_data = aug.augment(base_text)
    # Depending on the installed nlpaug release, augment() may return a list
    # of strings even for a single string input. Unwrap it so the text area
    # receives a plain string; fall back to the original text when the list
    # is empty.
    if isinstance(augmented_data, list):
        augmented_data = augmented_data[0] if augmented_data else base_text
    st.session_state.text = augmented_data
    # Remove the widget's stored value before re-creating the text area under
    # the same key with the scrambled text as its value.
    del st.session_state.input_text
    placeholder.empty()
    input_text = placeholder.text_area(value=st.session_state.text, label='Insert text:', key='input_text')

if submit_button:
    # 'Unscramble' was pressed: run the submitted text through the correction
    # pipeline and show the result in the right-hand column.
    base_text = st.session_state.input_text
    # The text is split into pieces of at most 128 characters, each piece is
    # corrected separately, and the results are joined with single spaces.
    chunks = wrap(base_text, 128)
    # wrap() returns an empty list for empty or whitespace-only text; skip
    # the model call then instead of handing the pipeline an empty batch.
    corrected_chunks = ocr_pipeline(chunks) if chunks else []
    output_text = " ".join(x['generated_text'] for x in corrected_chunks)
    # NOTE(review): the empty level-5 heading is presumably a vertical spacer
    # to line the output box up with the input box — confirm.
    right_area.markdown('#####')
    right_area.text_area(value=output_text, label="Corrected text:")