File size: 2,449 Bytes
d5101c4
 
 
 
630dd76
 
7b11265
af1842e
4230a5d
630dd76
 
 
 
 
 
af1842e
e948df6
7b11265
d5101c4
af1842e
 
d5101c4
af1842e
d5101c4
 
 
 
af1842e
d5101c4
 
 
af1842e
d5101c4
 
 
 
 
 
 
 
 
630dd76
af1842e
 
d5101c4
af1842e
d5101c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
af1842e
 
d5101c4
af1842e
 
d5101c4
af1842e
d5101c4
 
af1842e
d5101c4
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
from textwrap import wrap 

from transformers import pipeline
import nlpaug.augmenter.char as nac
import subprocess
import sys
import logging
import importlib
import streamlit

def install():
    """Install the pinned Streamlit release (0.89.0) with pip.

    pip is run through the interpreter that is executing this script, so the
    package lands in the same environment.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", "streamlit==0.89.0"]
    subprocess.check_call(pip_command)

# Install the pinned Streamlit release, then re-execute the already-imported
# top-level `streamlit` module object.
install()

importlib.reload(streamlit)
# The module is imported under its full name; no `st` alias is defined, so the
# version has to be read from `streamlit` directly.
logging.warning(streamlit.__version__)


# Page header: title, one-sentence description, then usage instructions.
streamlit.markdown('# ByT5 Dutch OCR Corrector :pill:')
streamlit.write(
    'This app corrects common dutch OCR mistakes, to showcase how this could be used in an OCR post-processing pipeline.'
)

usage_instructions = """
To use this:
- Enter a text with OCR mistakes and hit 'unscramble':point_down:
- Or enter a normal text, scramble it :twisted_rightwards_arrows: and then hit 'unscramble' :point_down:"""
streamlit.markdown(usage_instructions)

@streamlit.cache(
    allow_output_mutation=True,
    suppress_st_warning=True,
    show_spinner=False,
)
def load_model():
    """Build the OCR-correction pipeline while showing a loading spinner.

    The model and tokenizer are both loaded from the same
    'ml6team/byt5-base-dutch-ocr-correction' checkpoint. The function is
    wrapped in `streamlit.cache` with its built-in spinner disabled; the
    explicit spinner below is shown instead.

    Returns:
        The 'text2text-generation' pipeline object created by `pipeline`.
    """
    checkpoint = 'ml6team/byt5-base-dutch-ocr-correction'
    with streamlit.spinner('Please wait for the model to load...'):
        return pipeline(
            'text2text-generation',
            model=checkpoint,
            tokenizer=checkpoint,
        )

# Obtain the correction pipeline used by the 'Unscramble' handler.
ocr_pipeline = load_model()

# Seed the stored text with an empty string when the session has none yet.
if 'text' not in streamlit.session_state:
    streamlit.session_state['text'] = ""

# Two side-by-side columns: input form on the left, corrected output on the right.
left_area, right_area = streamlit.beta_columns(2)

# Format the left area
left_area.header("Input")
form = left_area.form(key='ocrcorrector')
# The text area lives inside a placeholder so the 'Scramble' handler can
# replace it with a new one later in the same run.
placeholder = form.empty()
placeholder.empty()
# The module is imported as `streamlit` (no `st` alias is defined), so the
# session state must be read through that name.
input_text = placeholder.text_area(value=streamlit.session_state.text, label='Insert text:', key='input_text')
scramble_button = form.form_submit_button(label='Scramble')
submit_button = form.form_submit_button(label='Unscramble')

# Right area
right_area.header("Output")

if scramble_button:
    # Corrupt the submitted text with nlpaug's OCR character augmenter.
    aug = nac.OcrAug()
    base_text = streamlit.session_state.input_text
    streamlit.session_state.text = base_text
    augmented_data = aug.augment(base_text)
    # Depending on the installed nlpaug release, `augment` returns either a
    # str or a list of str for a single input string. Normalise to str so the
    # text area always receives text (an empty list becomes "").
    if isinstance(augmented_data, (list, tuple)):
        augmented_data = " ".join(augmented_data)
    streamlit.session_state.text = augmented_data
    # Remove the widget's stored state before re-creating it under the same
    # key -- presumably so the new text area shows `value` instead of the old
    # widget content; confirm against the Streamlit session-state docs.
    del streamlit.session_state.input_text
    placeholder.empty()
    input_text = placeholder.text_area(value=streamlit.session_state.text, label='Insert text:', key='input_text')

if submit_button:
    base_text = streamlit.session_state.input_text
    # Split the input into lines of at most 128 characters, run each through
    # the correction pipeline, and join the generated texts with spaces.
    text_chunks = wrap(base_text, 128)
    corrected_chunks = [result['generated_text'] for result in ocr_pipeline(text_chunks)]
    output_text = " ".join(corrected_chunks)
    right_area.markdown('#####')
    right_area.text_area(value=output_text, label="Corrected text:")