File size: 4,719 Bytes
d812fea
 
640c986
 
 
 
 
 
 
d812fea
640c986
 
 
 
 
 
 
 
 
 
383e64a
640c986
 
 
 
 
383e64a
 
640c986
383e64a
cdd9cb4
 
640c986
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d812fea
 
 
 
 
 
 
 
 
 
 
 
640c986
 
 
 
 
 
 
d812fea
 
 
640c986
 
 
 
 
cdd9cb4
 
 
 
 
 
 
 
 
 
640c986
 
 
 
cdd9cb4
640c986
 
 
 
 
cdd9cb4
640c986
d812fea
 
 
 
84f3cde
 
 
 
 
 
 
 
 
 
 
 
 
 
d812fea
640c986
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# import cv2
# import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch

from doctr.io import DocumentFile
from doctr.utils.visualization import visualize_page

from backend.pytorch import DET_ARCHS, RECO_ARCHS, load_predictor #forward_image

# Device used for model inference: the first CUDA GPU when one is available, the CPU otherwise.
if torch.cuda.is_available():
    forward_device = torch.device("cuda:0")
else:
    forward_device = torch.device("cpu")


def _export_to_text(page_export):
    """Flatten an exported page dictionary into plain text.

    Args:
        page_export: dictionary with a ``"blocks"`` list; each block holds ``"lines"``,
            each line holds ``"words"`` and each word holds a ``"value"`` string
            (these are the only keys read here).

    Returns:
        A string with one text line per block: every word of the block is followed by
        a single space and every block is terminated by a newline.
    """
    return "".join(
        "".join(word["value"] + " " for line in block["lines"] for word in line["words"]) + "\n"
        for block in page_export["blocks"]
    )


def main(det_archs, reco_archs):
    """Build the Streamlit layout and run OCR on the selected page.

    The page shows three columns (input page, OCR output, page reconstitution) and a
    sidebar for uploading a document, picking a page and choosing the models. When the
    "Analyze page" button is pressed, the predictor is loaded, run on the selected page,
    and the extracted text and the JSON export are displayed.

    Args:
        det_archs: options offered in the "Text detection model" select box.
        reco_archs: options offered in the "Text recognition model" select box.
    """
    # Wide mode
    st.set_page_config(layout="wide")

    # Designing the interface
    st.title("Document Text Extraction")
    # For newline
    st.write("\n")
    # Instructions
    st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*")
    # Set the columns (the segmentation heatmap column is currently disabled)
    cols = st.columns((1, 1, 1))
    cols[0].subheader("Input page")
    cols[1].subheader("OCR output")
    cols[2].subheader("Page reconstitution")

    # Sidebar
    # File selection
    st.sidebar.title("Document selection")
    # Choose your own image
    uploaded_file = st.sidebar.file_uploader("Upload files", type=["pdf", "png", "jpeg", "jpg"])
    if uploaded_file is not None:
        # Compare the extension case-insensitively so that e.g. "SCAN.PDF" is not
        # mistakenly handed to the image loader.
        if uploaded_file.name.lower().endswith(".pdf"):
            doc = DocumentFile.from_pdf(uploaded_file.read())
        else:
            doc = DocumentFile.from_images(uploaded_file.read())
        # The select box is 1-based for the user; convert back to a 0-based index
        page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1
        page = doc[page_idx]
        cols[0].image(page)

    # Model selection
    st.sidebar.title("Model selection")
    det_arch = st.sidebar.selectbox("Text detection model", det_archs)
    reco_arch = st.sidebar.selectbox("Text recognition model", reco_archs)

    # NOTE: the "Parameters" sidebar section (straight pages / straighten pages /
    # binarization threshold) is disabled; fixed default values are used below instead.

    if not st.sidebar.button("Analyze page"):
        return

    if uploaded_file is None:
        st.sidebar.write("Please upload a document")
        return

    with st.spinner("Loading model..."):
        # Default Values
        assume_straight_pages, straighten_pages, bin_thresh = True, False, 0.3

        predictor = load_predictor(
            det_arch, reco_arch, assume_straight_pages, straighten_pages, bin_thresh, forward_device
        )

    with st.spinner("Analyzing..."):
        # `page` is always bound here: it is set whenever a file has been uploaded
        out = predictor([page])
        result_page = out.pages[0]
        # Export once and reuse it for the plot, the text and the JSON view
        page_export = result_page.export()

        # Plot OCR output
        fig = visualize_page(page_export, result_page.page, interactive=False, add_labels=False)
        cols[1].pyplot(fig)

        # Page reconstitution under input page
        if assume_straight_pages or straighten_pages:
            img = result_page.synthesize()
            cols[2].image(img, clamp=True)

        all_text = _export_to_text(page_export)

        # Debug trace on the server console
        print('out', out)
        print('\n')
        print('page_export', page_export)
        print('\n')
        print('all_text', all_text)
        print('\n')

        # Display Text
        st.markdown("\nHere is your text:")
        st.write(all_text)

        # Display JSON
        st.markdown("\nHere are your analysis results in JSON format:")
        st.json(page_export, expanded=False)


# Script entry point: launch the app with the architectures exposed by the backend.
if __name__ == "__main__":
    main(det_archs=DET_ARCHS, reco_archs=RECO_ARCHS)