# TextExtractor / app.py
# deepsh2207's picture
# Corrected download button
# d840e96
# raw
# history blame
# 4.93 kB
# import cv2
# import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import json
from doctr.io import DocumentFile
from doctr.utils.visualization import visualize_page
from backend.pytorch import DET_ARCHS, RECO_ARCHS, load_predictor #forward_image
# Pick the inference device once at import time: the first CUDA GPU when
# PyTorch reports one as available, the CPU otherwise.
if torch.cuda.is_available():
    forward_device = torch.device("cuda:0")
else:
    forward_device = torch.device("cpu")
def _page_text(page_export):
    """Flatten an exported OCR page into plain text.

    Args:
        page_export: mapping with a ``"blocks"`` list; every block carries a
            ``"lines"`` list, every line a ``"words"`` list, and every word a
            ``"value"`` string.

    Returns:
        One text line per OCR line, words separated by single spaces and each
        line terminated by a newline. Empty string when there are no lines.
    """
    return "".join(
        " ".join(word["value"] for word in line["words"]) + "\n"
        for block in page_export["blocks"]
        for line in block["lines"]
    )


def main(det_archs, reco_archs):
    """Build the Streamlit layout and run OCR on the selected page.

    Args:
        det_archs: options offered in the "Text detection model" select box.
        reco_archs: options offered in the "Text recognition model" select box.
    """
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("Document Text Extraction")
    # For newline
    st.write("\n")
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")

    # One column per view of the page
    cols = st.columns((1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("OCR output")
    cols[2].subheader("Page reconstitution")

    # Sidebar: document selection
    st.sidebar.title("Document selection")
    uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
    if uploaded_file is not None:
        # Compare the extension case-insensitively so that e.g. "SCAN.PDF" is
        # parsed as a PDF instead of being handed to the image decoder.
        if uploaded_file.name.lower().endswith(".pdf"):
            doc = DocumentFile.from_pdf(uploaded_file.read())
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        # The select box shows 1-based page numbers; convert back to an index.
        page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        page = doc[page_idx]
        cols[0].image(page)

    # Sidebar: model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", det_archs)
    reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)

    # Nothing more to render until the user asks for an analysis.
    if not st.sidebar.button("Analyze page"):
        return
    if uploaded_file is None:
        st.sidebar.write("Please upload a document")
        return

    with st.spinner("Loading model..."):
        # The predictor parameters are not exposed in the sidebar; these are
        # the fixed defaults.
        assume_straight_pages, straighten_pages, bin_thresh = True, False, 0.3
        predictor = load_predictor(
            det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
        )

    with st.spinner("Analyzing..."):
        out = predictor([page])
        result_page = out.pages[0]
        # Export once and reuse it for the plot, the text and the JSON views.
        page_export = result_page.export()

        # Plot OCR output
        fig = visualize_page(page_export, result_page.page, interactive=False, add_labels=False)
        cols[1].pyplot(fig)

        # Page reconstitution
        if assume_straight_pages or straighten_pages:
            img = result_page.synthesize()
            cols[2].image(img, clamp=True)

        # Display text
        all_text = _page_text(page_export)
        st.markdown("\n## **Here is your text:**")
        st.write(all_text)

        # Display JSON
        json_string = json.dumps(page_export)
        st.markdown("\n## **Here are your analysis results in JSON format:**")
        st.download_button(label="Download JSON", data=json_string, file_name='data.json', mime='application/json')
        st.json(page_export, expanded=False)
# Script entry point: hand the detection and recognition architecture lists
# imported from the backend to the layout builder.
if __name__ == "__main__":
    main(det_archs=DET_ARCHS, reco_archs=RECO_ARCHS)