doctr / app.py
osanseviero's picture
Update app.py
a044f10
raw
history blame
3.52 kB
import os
import streamlit as st
import tensorflow as tf
import matplotlib.pyplot as plt
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from doctr.utils.visualization import visualize_page
# Text detection architectures listed in the sidebar "Text detection model" box
DET_ARCHS = [
    "db_resnet50",
    "db_mobilenet_v3_large",
    "linknet16",
]
# Text recognition architectures listed in the sidebar "Text recognition model" box
RECO_ARCHS = [
    "crnn_vgg16_bn",
    "master",
    "sar_resnet31",
]
def _load_document(uploaded_file):
    """Decode an uploaded file into the page collection produced by docTR.

    Args:
        uploaded_file: object returned by ``st.sidebar.file_uploader``; only
            its ``name`` attribute and ``read()`` method are used.

    Returns:
        The result of ``DocumentFile.from_pdf(...).as_images()`` when the file
        name ends in ``.pdf`` (any letter case), otherwise the result of
        ``DocumentFile.from_images(...)``.
    """
    payload = uploaded_file.read()
    # Compare in lower case so that "SCAN.PDF" is not sent to the image decoder
    if uploaded_file.name.lower().endswith('.pdf'):
        return DocumentFile.from_pdf(payload).as_images()
    return DocumentFile.from_images(payload)


def _segmentation_heatmap(predictor, page):
    """Run the detection model on one page and return its heatmap as an array.

    Args:
        predictor: object returned by ``ocr_predictor``; its ``det_predictor``
            pre-processor and model are called directly.
        page: the page image; its first two ``shape`` entries are used as the
            target (height, width) of the heatmap.

    Returns:
        A 2-D NumPy array resized to the spatial size of ``page``.
    """
    processed_batches = predictor.det_predictor.pre_processor([page])
    raw = predictor.det_predictor.model(processed_batches[0], return_model_output=True)
    # First sample of the batch; presumably (height, width, 1) -- TODO confirm
    seg_map = raw["out_map"][0, ...]
    # Bilinear resize with TensorFlow, which this file already imports (the
    # previous code called `cv2`, a name that was never imported here).
    # NOTE: tf.image.resize takes (height, width), unlike cv2's (width, height).
    seg_map = tf.image.resize(seg_map, (page.shape[0], page.shape[1]), method="bilinear")
    # Drop the channel axis so matplotlib receives a plain 2-D map
    return tf.squeeze(seg_map, axis=[2]).numpy()


def main():
    """Build the Streamlit page and run OCR on the selected page on demand.

    The page shows four columns (input page, segmentation heatmap, OCR output,
    page reconstitution) and a sidebar for uploading a document and picking
    the detection / recognition architectures. Nothing is analyzed until the
    "Analyze page" button is pressed.
    """
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("DocTR: Document Text Recognition")
    # For newline
    st.write('\n')
    st.write('Find more info at: https://github.com/mindee/doctr')
    # For newline
    st.write('\n')
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")

    # Set the columns
    # NOTE(review): newer Streamlit releases expose this as `st.columns` --
    # confirm against the pinned Streamlit version before upgrading.
    cols = st.beta_columns((1, 1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("Segmentation heatmap")
    cols[2].subheader("OCR output")
    cols[3].subheader("Page reconstitution")

    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Disabling warning
    st.set_option('deprecation.showfileUploaderEncoding', False)
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload files", type=['pdf', 'png', 'jpeg', 'jpg'])
    page = None
    if uploaded_file is not None:
        doc = _load_document(uploaded_file)
        # The selector is 1-based for the user, the index is 0-based
        page_idx = st.sidebar.selectbox("Page selection", list(range(1, len(doc) + 1))) - 1
        page = doc[page_idx]
        cols[0].image(page)

    # Model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS)
    reco_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS)

    # For newline
    st.sidebar.write('\n')

    if not st.sidebar.button("Analyze page"):
        return

    if uploaded_file is None:
        st.sidebar.write("Please upload a document")
        return

    with st.spinner('Loading model...'):
        predictor = ocr_predictor(det_arch, reco_arch, pretrained=True)

    with st.spinner('Analyzing...'):
        # Plot the raw heatmap
        seg_map = _segmentation_heatmap(predictor, page)
        fig, ax = plt.subplots()
        ax.imshow(seg_map)
        ax.axis('off')
        cols[1].pyplot(fig)
        # Release the figure: the script reruns on every interaction and
        # pyplot would otherwise keep every figure alive in the server process
        plt.close(fig)

        # Plot OCR output
        out = predictor([page])
        # Export once and reuse it for both the overlay and the JSON dump
        page_export = out.pages[0].export()
        fig = visualize_page(page_export, page, interactive=False)
        cols[2].pyplot(fig)
        plt.close(fig)

        # Page reconstitution under input page
        img = out.pages[0].synthesize()
        cols[3].image(img, clamp=True)

        # Display JSON
        st.markdown("\nHere are your analysis results in JSON format:")
        st.json(page_export)
# Launch the demo only when this file is executed as a script
if __name__ == "__main__":
    main()