import streamlit as st
from transformers import AutoTokenizer, T5ForConditionalGeneration
import post_ocr
# Sidebar information
info = '''Welcome to the demo of the [swedish-ocr-correction](https://huggingface.co/viklofg/swedish-ocr-correction) model.
Enter or upload OCR output and the model will attempt to correct it.
:clock2: Slow generation? Try a shorter input.
'''
# Example inputs
examples = {
    'Examples': None,
    'Example 1': 'En Gosse fur plats nu genast ! inetallyrkc, JU 83 Drottninggatan.',
    'Example 2': '— Storartad gåfva till Göteborgs Museum. Den i HandelstidniDgens g&rdagsnnmmer omtalade hvalfisken, sorn fångats i Frölnndaviken, har i dag af hr brukspatronen James Dickson blifvit inköpt för 1,500 rdr och skänkt till härvarande Museum.',
    'Example 3': 'Sn underlig race att ſtudera, desfa uppſinnare! utropar en Londontidnings fronifôr. Wet ni hur ſtort antalet är af patenter, ſom ſiſtlidet är utfärdades i British Patent Office? Jo, 14,000 ſty>en !! Det kan man ju fkalla en rif rd! Fjorton tuſen uppfinninnar! Herre Gud, hwilfet märkrwoärdigt tidehrvarf wi lefroa i!'
}
# Load model
@st.cache_resource
def load_model():
    return T5ForConditionalGeneration.from_pretrained('viklofg/swedish-ocr-correction')
model = load_model()
# Load tokenizer
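# ByT5 operates on raw UTF-8 bytes, so the generic google/byt5-small tokenizer can be
# reused here (assuming the correction model was fine-tuned from ByT5, as the code implies)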
@st.cache_resource
def load_tokenizer():
    return AutoTokenizer.from_pretrained('google/byt5-small')
tokenizer = load_tokenizer()
# Set model and tokenizer
post_ocr.set_model(model, tokenizer)
# Title
st.title(':memo: Swedish OCR correction')
# Input and output areas
tab1, tab2 = st.tabs(["Text input", "From file"])
# Initialize session state
def clean_inputs():
    st.session_state.inputs = {'tab1': None, 'tab2': None}

if 'inputs' not in st.session_state:
    clean_inputs()

def clean_outputs():
    st.session_state.outputs = {'tab1': None, 'tab2': None}

if 'outputs' not in st.session_state:
    clean_outputs()
# Sidebar (information and settings)
with st.sidebar:
    st.header('Welcome')
    st.markdown(info)

    st.header('Settings')
    overlap2candidates = {'None': 1, 'Little': 3, 'Much': 5}
    overlap_help = '''Long texts are processed in chunks using a sliding window technique.
Here you can choose how much overlap the sliding window should have with the previously
processed chunk. No overlap is the fastest, but some overlap may increase accuracy.'''
    overlap = st.selectbox(
        'Overlap',
        options=overlap2candidates,
        help=overlap_help,
        on_change=clean_inputs)
    n_candidates = overlap2candidates[overlap]

    st.subheader('Output')
    show_changes = st.toggle('Show changes')
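# Illustration only: a minimal sketch of the sliding-window chunking that the 'Overlap'
# setting above refers to. The real chunking happens inside post_ocr and may differ;
# the chunk size and overlap values here are assumptions, and this helper is never
# called by the app.
def _sliding_window_sketch(text, size=120, overlap=30):
    """Yield chunks of `text` where consecutive chunks share `overlap` characters."""
    step = size - overlap
    for start in range(0, max(len(text) - overlap, 1), step):
        yield text[start:start + size]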
def handle_input(input_, id_):
"""Generate and display output"""
with st.container(border=True):
st.caption('Output')
# Only update the output if the input has been updated
if input_ and st.session_state.inputs[id_] != input_:
st.session_state.inputs[id_] = input_
with st.spinner('Generating...'):
output = post_ocr.process(input_, n_candidates)
st.session_state.outputs[id_] = output
# Display output
output = st.session_state.outputs[id_]
if output is not None:
st.write(post_ocr.diff(input_, output) if show_changes else output)
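# Illustration only: post_ocr.diff() renders the model's changes when 'Show changes'
# is enabled. The hypothetical helper below shows one way such a diff view could be
# built with difflib; it is not the app's actual implementation and is never called.
def _diff_sketch(original, corrected):
    """Return `corrected` with inserted or replaced segments wrapped in **bold** markdown."""
    import difflib
    parts = []
    for op, _, _, j1, j2 in difflib.SequenceMatcher(None, original, corrected).get_opcodes():
        segment = corrected[j1:j2]
        parts.append(f'**{segment}**' if op in ('replace', 'insert') else segment)
    return ''.join(parts)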
# Manual entry tab
with tab1:
    col1, col2 = st.columns([4, 1])

    with col2:
        example_title = st.selectbox('Examples', options=examples,
                                     label_visibility='collapsed')

    with col1:
        text = st.text_area(
            label='Input text',
            value=examples[example_title],
            height=200,
            label_visibility='collapsed',
            placeholder='Enter OCR generated text or choose an example')

    if text is not None:
        handle_input(text, 'tab1')
# File upload tab
with tab2:
    uploaded_file = st.file_uploader('Choose a file', type='.txt')

    # Display file content
    if uploaded_file is not None:
        file_content = uploaded_file.getvalue().decode('utf-8')
        text = st.text_area('File content', value=file_content, height=300)
        handle_input(text, 'tab2')