Thomas Dehaene commited on
Commit
af1842e
·
1 Parent(s): e948df6

Bump steamlit

Browse files
Files changed (1) hide show
  1. app.py +16 -16
app.py CHANGED
@@ -5,30 +5,30 @@ import nlpaug.augmenter.char as nac
5
  import subprocess
6
  import sys
7
  import logging
 
8
 
9
  def install():
10
  subprocess.check_call([sys.executable, "-m", "pip", "install", "streamlit==0.89.0"])
11
 
12
  install()
13
 
14
- import streamlit as st
15
-
16
  logging.warning(st.__version__)
17
 
18
 
19
- st.markdown('# ByT5 Dutch OCR Corrector :pill:')
20
- st.write('This app corrects common dutch OCR mistakes, to showcase how this could be used in an OCR post-processing pipeline.')
21
 
22
- st.markdown("""
23
  To use this:
24
  - Enter a text with OCR mistakes and hit 'unscramble':point_down:
25
  - Or enter a normal text, scramble it :twisted_rightwards_arrows: and then hit 'unscramble' :point_down:""")
26
 
27
- @st.cache(allow_output_mutation=True,
28
  suppress_st_warning=True,
29
  show_spinner=False)
30
  def load_model():
31
- with st.spinner('Please wait for the model to load...'):
32
  ocr_pipeline=pipeline(
33
  'text2text-generation',
34
  model='ml6team/byt5-base-dutch-ocr-correction',
@@ -39,10 +39,10 @@ def load_model():
39
  ocr_pipeline = load_model()
40
 
41
 
42
- if 'text' not in st.session_state:
43
- st.session_state.text = ""
44
 
45
- left_area, right_area = st.beta_columns(2)
46
 
47
  # Format the left area
48
  left_area.header("Input")
@@ -58,16 +58,16 @@ right_area.header("Output")
58
 
59
  if scramble_button:
60
  aug = nac.OcrAug()
61
- st.session_state.text = st.session_state.input_text
62
- base_text = st.session_state.text
63
  augmented_data = aug.augment(base_text)
64
- st.session_state.text = augmented_data
65
- del st.session_state.input_text
66
  placeholder.empty()
67
- input_text = placeholder.text_area(value=st.session_state.text, label='Insert text:', key='input_text')
68
 
69
  if submit_button:
70
- base_text = st.session_state.input_text
71
  output_text = " ".join([x['generated_text'] for x in ocr_pipeline(wrap(base_text, 128))])
72
  right_area.markdown('#####')
73
  right_area.text_area(value=output_text, label="Corrected text:")
 
5
  import subprocess
6
  import sys
7
  import logging
8
+ import importlib
9
 
10
  def install():
11
  subprocess.check_call([sys.executable, "-m", "pip", "install", "streamlit==0.89.0"])
12
 
13
  install()
14
 
15
+ importlib.reload(streamlit)
 
16
  logging.warning(st.__version__)
17
 
18
 
19
+ streamlit.markdown('# ByT5 Dutch OCR Corrector :pill:')
20
+ streamlit.write('This app corrects common dutch OCR mistakes, to showcase how this could be used in an OCR post-processing pipeline.')
21
 
22
+ streamlit.markdown("""
23
  To use this:
24
  - Enter a text with OCR mistakes and hit 'unscramble':point_down:
25
  - Or enter a normal text, scramble it :twisted_rightwards_arrows: and then hit 'unscramble' :point_down:""")
26
 
27
+ @streamlit.cache(allow_output_mutation=True,
28
  suppress_st_warning=True,
29
  show_spinner=False)
30
  def load_model():
31
+ with streamlit.spinner('Please wait for the model to load...'):
32
  ocr_pipeline=pipeline(
33
  'text2text-generation',
34
  model='ml6team/byt5-base-dutch-ocr-correction',
 
39
  ocr_pipeline = load_model()
40
 
41
 
42
+ if 'text' not in streamlit.session_state:
43
+ streamlit.session_state.text = ""
44
 
45
+ left_area, right_area = streamlit.beta_columns(2)
46
 
47
  # Format the left area
48
  left_area.header("Input")
 
58
 
59
  if scramble_button:
60
  aug = nac.OcrAug()
61
+ streamlit.session_state.text = streamlit.session_state.input_text
62
+ base_text = streamlit.session_state.text
63
  augmented_data = aug.augment(base_text)
64
+ streamlit.session_state.text = augmented_data
65
+ del streamlit.session_state.input_text
66
  placeholder.empty()
67
+ input_text = placeholder.text_area(value=streamlit.session_state.text, label='Insert text:', key='input_text')
68
 
69
  if submit_button:
70
+ base_text = streamlit.session_state.input_text
71
  output_text = " ".join([x['generated_text'] for x in ocr_pipeline(wrap(base_text, 128))])
72
  right_area.markdown('#####')
73
  right_area.text_area(value=output_text, label="Corrected text:")