File size: 5,468 Bytes
d2f149b
 
 
 
919648a
 
 
 
 
 
d2f149b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
919648a
 
 
d2f149b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
919648a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import os
import sys

# Optional bootstrap: when the APP_PATH environment variable is set, run the
# app from that directory and make sure it is on the import path.
try:
    app_path = os.path.abspath(os.environ["APP_PATH"])
except KeyError:
    # APP_PATH is not set: leave the working directory and sys.path alone.
    pass
else:
    if app_path != os.getcwd():
        # Switch the working directory to the app directory.
        os.chdir(app_path)
    if app_path not in sys.path:
        # Make the app directory importable.
        sys.path.append(app_path)

import gradio as gr

from tabled.assignment import assign_rows_columns
from tabled.fileinput import load_pdfs_images
from tabled.formats.markdown import markdown_format
from tabled.inference.detection import detect_tables
from tabled.inference.recognition import get_cells, recognize_tables

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
os.environ["IN_STREAMLIT"] = "true"
import pypdfium2

from PIL import Image

from tabled.inference.models import load_detection_models, load_recognition_models, load_layout_models

def load_models():
    """Load every model set used by the demo.

    Returns:
        A 3-tuple ``(detection, recognition, layout)`` holding the results of
        ``load_detection_models()``, ``load_recognition_models()`` and
        ``load_layout_models()``, in that order.
    """
    detection = load_detection_models()
    recognition = load_recognition_models()
    layout = load_layout_models()
    return detection, recognition, layout

def run_table_rec(image, highres_image, text_line, models, skip_detection=False, detect_boxes=False):
    """Run the table pipeline on one page and return markdown per table.

    Args:
        image: Page image handed to ``detect_tables`` together with
            ``highres_image``.
        highres_image: Second rendering of the same page; its ``size`` is used
            for the table bounding box and for row/column assignment.
        text_line: Text-line data for the page, forwarded unchanged to
            ``get_cells`` once per table.
        models: Indexable collection of models. ``models[0]`` is passed to
            ``get_cells``, ``models[1]`` to ``recognize_tables`` and
            ``models[2]`` to ``detect_tables``.
        skip_detection: When true, no detection is run and the whole
            ``highres_image`` is treated as a single table.
        detect_boxes: Forwarded to ``get_cells`` as ``detect_boxes``.

    Returns:
        A list of ``(markdown, table_image)`` tuples, one per table.
    """
    if skip_detection:
        # The entire high-res page is the one and only table.
        full_size = highres_image.size
        table_imgs = [highres_image]
        table_bboxes = [[0, 0, full_size[0], full_size[1]]]
    else:
        table_imgs, table_bboxes, _ = detect_tables([image], [highres_image], models[2])

    # Every table comes from the same page, so the page-level inputs are
    # simply repeated once per table.
    table_count = len(table_imgs)
    table_text_lines = [text_line] * table_count
    highres_image_sizes = [highres_image.size] * table_count

    cells, needs_ocr = get_cells(
        table_imgs,
        table_bboxes,
        highres_image_sizes,
        table_text_lines,
        models[0],
        detect_boxes=detect_boxes,
    )
    table_rec = recognize_tables(table_imgs, cells, needs_ocr, models[1])

    # Replace the raw cells with the row/column-assigned result per table.
    cells = [
        assign_rows_columns(recognized, page_size)
        for recognized, page_size in zip(table_rec, highres_image_sizes)
    ]

    return [
        (markdown_format(table_cells), table_img)
        for table_cells, _, table_img in zip(cells, table_rec, table_imgs)
    ]

def open_pdf(pdf_file):
    """Open ``pdf_file`` with pypdfium2 and return the ``PdfDocument``."""
    document = pypdfium2.PdfDocument(pdf_file)
    return document

def count_pdf(pdf_file):
    """Return the number of pages in ``pdf_file``.

    The document is opened only for the duration of the call and is closed
    again before returning, so repeated calls do not accumulate open PDF
    handles.
    """
    doc = open_pdf(pdf_file)
    try:
        return len(doc)
    finally:
        # Release the PDFium document even if len() raises.
        doc.close()

def get_page_image(pdf_file, page_num, dpi=96):
    """Render one page of a PDF to an RGB PIL image.

    Args:
        pdf_file: PDF input passed to ``open_pdf``.
        page_num: 1-based page number; converted to a 0-based index for
            pypdfium2.
        dpi: Target resolution. The render scale is ``dpi / 72``.

    Returns:
        The rendered page converted to mode ``"RGB"``.
    """
    doc = open_pdf(pdf_file)
    try:
        renderer = doc.render(
            pypdfium2.PdfBitmap.to_pil,
            page_indices=[page_num - 1],
            scale=dpi / 72,
        )
        # Only one page index is requested; list() fully consumes the
        # renderer before the single result is picked.
        page_image = list(renderer)[0]
        # Convert while the document is still open, then release it.
        return page_image.convert("RGB")
    finally:
        doc.close()

def get_uploaded_image(in_file):
    """Load an uploaded image file and return it as an RGB PIL image.

    ``Image.open`` keeps the underlying file open, so the source image is
    used as a context manager: the pixel data is read by ``convert`` and the
    file is closed before the converted copy is returned.
    """
    with Image.open(in_file) as source_image:
        return source_image.convert("RGB")

# Model loading is skipped when a module-level `models` already exists, which
# is the case when this module is re-executed in reload mode.
if "models" not in globals():
    models = load_models()

# UI definition: inputs (file, page slider, preview, options) on the left,
# results (gallery of table crops + markdown) on the right.
with gr.Blocks(title="Tabled") as demo:
    gr.Markdown("""
    # Tabled Demo

    This app will let you try tabled, a table detection and recognition model.  It will detect and recognize the tables.

    Find the project [here](https://github.com/VikParuchuri/tabled).
    """)

    with gr.Row():
        with gr.Column():
            in_file = gr.File(label="PDF file or image:", file_types=[".pdf", ".png", ".jpg", ".jpeg", ".gif", ".webp"])
            in_num = gr.Slider(label="Page number", minimum=1, maximum=100, value=1, step=1)
            in_img = gr.Image(label="Page of Image", type="pil", sources=None)

            skip_detection_ckb = gr.Checkbox(label="Skip table detection", info="Use this if tables are already cropped (the whole PDF page or image is a table)")
            detect_boxes_ckb = gr.Checkbox(label="Detect cell boxes", info="Detect table cell boxes vs extract from PDF.  Will also run OCR.")

            tabled_btn = gr.Button("Run Tabled")
        with gr.Column():
            result_img = gr.Gallery(label="Result images", show_label=False, columns=[1], rows=[10], object_fit="contain", height="auto")
            result_md = gr.Markdown(label="Result markdown")

        def show_image(file, num=1):
            """Build the preview for the selected file/page.

            Args:
                file: Path of the uploaded file, or ``None`` when nothing has
                    been uploaded yet.
                num: 1-based page number (only used for PDFs). Defaults to 1
                    because the upload event does not pass the slider value.

            Returns:
                A two-item list of component updates: one for the page
                slider and one for the preview image.
            """
            if file is None:
                # The slider can be moved before any file is uploaded;
                # there is nothing to render, so leave both outputs as-is.
                return [gr.update(), gr.update()]

            # Match the extension case-insensitively so that e.g. "scan.PDF"
            # is rendered as a PDF instead of being sent to the image loader.
            if file.lower().endswith('.pdf'):
                count = count_pdf(file)
                img = get_page_image(file, num)
                # NOTE(review): on upload the preview shows page 1 but the
                # slider keeps its previous value - confirm whether the
                # slider should be reset here as well.
                return [
                    gr.update(visible=True, maximum=count),
                    gr.update(value=img)]

            # Plain image: there are no pages, so hide the page slider.
            img = get_uploaded_image(file)
            return [
                gr.update(visible=False),
                gr.update(value=img)]

        in_file.upload(
            fn=show_image,
            inputs=[in_file],
            outputs=[in_num, in_img],
        )
        in_num.change(
            fn=show_image,
            inputs=[in_file, in_num],
            outputs=[in_num, in_img],
        )

        # Run table detection/recognition on the selected page
        def text_rec_img(in_file, page_number, skip_detection, detect_boxes):
            """Run tabled on one page and format the results for the UI.

            Args:
                in_file: Path of the uploaded PDF or image.
                page_number: 1-based page number taken from the slider.
                skip_detection: Treat the whole page as a single table.
                detect_boxes: Forwarded to ``run_table_rec``.

            Returns:
                A tuple ``(gallery_update, markdown)``: the gallery update
                carries the table images, the markdown string holds one
                ``## Table <idx>`` section per table.

            Raises:
                gr.Error: If no file has been uploaded.
            """
            if in_file is None:
                # Fail with a readable message instead of crashing in the loader.
                raise gr.Error("Please upload a PDF or image first.")

            images, highres_images, names, text_lines = load_pdfs_images(in_file, max_pages=1, start_page=page_number - 1)

            out_data = run_table_rec(images[0], highres_images[0], text_lines[0], models, skip_detection=skip_detection, detect_boxes=detect_boxes)

            if not out_data:
                gr.Warning(f"No tables found on page {page_number}", duration=5)

            table_imgs = [table_img for _, table_img in out_data]
            table_md = "".join(
                f"\n## Table {idx}\n" + md
                for idx, (md, _) in enumerate(out_data)
            )
            return gr.update(rows=[len(table_imgs)], value=table_imgs), table_md
        tabled_btn.click(
            fn=text_rec_img,
            inputs=[in_file, in_num, skip_detection_ckb, detect_boxes_ckb],
            outputs=[result_img, result_md]
        )

# Launch the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()