import os
import importlib.metadata
import importlib.util
import json

from os import getcwd, path, environ
from dotenv import load_dotenv


def check_additional_requirements():
    """Install additional runtime dependencies: detectron2, the pinned gradio version and the
    command stored in the DD_ADDONS environment variable (expected to make dd_addons importable)."""
    if importlib.util.find_spec("detectron2") is None:
        os.system("pip install detectron2@git+https://github.com/facebookresearch/detectron2.git")
    if importlib.util.find_spec("gradio") is not None:
        if importlib.metadata.version("gradio") != "3.44.3":
            os.system("pip uninstall -y gradio")
            os.system("pip install gradio==3.44.3")
    else:
        os.system("pip install gradio==3.44.3")
    os.system(os.environ["DD_ADDONS"])
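
# Environment variables (Hugging Face repo ids, AWS credentials, DD_ADDONS) are loaded from a
# local .env file or provided as Space secrets. The additional requirements have to be installed
# before deepdoctection, dd_addons and gradio are imported below.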
load_dotenv()
check_additional_requirements()
import deepdoctection as dd
from deepdoctection.dataflow.serialize import DataFromList
import time
from dd_addons.extern import PdfTextDetector, PostProcessor, get_xsl_path
from dd_addons.pipe.conn import PostProcessorService

# gradio is pinned to 3.44.3, see the work around described in
# https://discuss.huggingface.co/t/how-to-install-a-specific-version-of-gradio-in-spaces/13552
import gradio as gr
from botocore.config import Config
_DD_ONE = "conf_dd_one.yaml"
_XSL_PATH = get_xsl_path()
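
# Register custom model profiles so that the ModelDownloadManager can pull the (not publicly
# released) weights and configs from the Hugging Face repos referenced by the HF_REPO_*
# environment variables.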
dd.ModelCatalog.register("xrf_layout/model_final_inf_only.pt",dd.ModelProfile(
name="xrf_layout/model_final_inf_only.pt",
description="layout_detection/morning-dragon-114",
config="xrf_dd/layout/CASCADE_RCNN_R_50_FPN_GN.yaml",
size=[274632215],
tp_model=False,
hf_repo_id=environ.get("HF_REPO_LAYOUT"),
hf_model_name="model_final_inf_only.pt",
hf_config_file=["Base-RCNN-FPN.yaml", "CASCADE_RCNN_R_50_FPN_GN.yaml"],
categories={"1": dd.LayoutType.text,
"2": dd.LayoutType.title,
"3": dd.LayoutType.list,
"4": dd.LayoutType.table,
"5": dd.LayoutType.figure},
model_wrapper="D2FrcnnDetector",
))
dd.ModelCatalog.register("xrf_cell/model_final_inf_only.pt", dd.ModelProfile(
name="xrf_cell/model_final_inf_only.pt",
description="cell_detection/restful-eon-6",
config="xrf_dd/cell/CASCADE_RCNN_R_50_FPN_GN.yaml",
size=[274583063],
tp_model=False,
hf_repo_id=environ.get("HF_REPO_CELL"),
hf_model_name="model_final_inf_only.pt",
hf_config_file=["Base-RCNN-FPN.yaml", "CASCADE_RCNN_R_50_FPN_GN.yaml"],
categories={"1": dd.LayoutType.cell},
model_wrapper="D2FrcnnDetector",
))
dd.ModelCatalog.register("xrf_item/model_final_inf_only.pt", dd.ModelProfile(
name="xrf_item/model_final_inf_only.pt",
description="item_detection/firm_plasma_14",
config="xrf_dd/item/CASCADE_RCNN_R_50_FPN_GN.yaml",
size=[274595351],
tp_model=False,
hf_repo_id=environ.get("HF_REPO_ITEM"),
hf_model_name="model_final_inf_only.pt",
hf_config_file=["Base-RCNN-FPN.yaml", "CASCADE_RCNN_R_50_FPN_GN.yaml"],
categories={"1": dd.LayoutType.row, "2": dd.LayoutType.column},
model_wrapper="D2FrcnnDetector",
))
# Set up the configuration and logging. The models are defined globally so that they are not
# re-loaded every time the input updates.
cfg = dd.set_config_by_yaml(path.join(getcwd(), _DD_ONE))
cfg.freeze(freezed=False)
cfg.DEVICE = "cpu"
cfg.freeze()
# layout detector
layout_config_path = dd.ModelCatalog.get_full_path_configs(cfg.CONFIG.D2LAYOUT)
layout_weights_path = dd.ModelDownloadManager.maybe_download_weights_and_configs(cfg.WEIGHTS.D2LAYOUT)
categories_layout = dd.ModelCatalog.get_profile(cfg.WEIGHTS.D2LAYOUT).categories
assert categories_layout is not None
assert layout_weights_path is not None
d_layout = dd.D2FrcnnDetector(layout_config_path, layout_weights_path, categories_layout, device=cfg.DEVICE)
# cell detector
cell_config_path = dd.ModelCatalog.get_full_path_configs(cfg.CONFIG.D2CELL)
cell_weights_path = dd.ModelDownloadManager.maybe_download_weights_and_configs(cfg.WEIGHTS.D2CELL)
categories_cell = dd.ModelCatalog.get_profile(cfg.WEIGHTS.D2CELL).categories
assert categories_cell is not None
d_cell = dd.D2FrcnnDetector(cell_config_path, cell_weights_path, categories_cell, device=cfg.DEVICE)
# row/column detector
item_config_path = dd.ModelCatalog.get_full_path_configs(cfg.CONFIG.D2ITEM)
item_weights_path = dd.ModelDownloadManager.maybe_download_weights_and_configs(cfg.WEIGHTS.D2ITEM)
categories_item = dd.ModelCatalog.get_profile(cfg.WEIGHTS.D2ITEM).categories
assert categories_item is not None
d_item = dd.D2FrcnnDetector(item_config_path, item_weights_path, categories_item, device=cfg.DEVICE)
# pdf miner
pdf_text = PdfTextDetector(_XSL_PATH)
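# The PDF text detector reads text directly from the PDF text layer; pages without such a layer
# fall back to Textract OCR below (skip_if_text_extracted=True).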
# text detector
credentials_kwargs={"aws_access_key_id": os.environ["ACCESS_KEY"],
"aws_secret_access_key": os.environ["SECRET_KEY"],
"config": Config(region_name=os.environ["REGION"])}
tex_text = dd.TextractOcrDetector(**credentials_kwargs)


def build_gradio_analyzer():
    """Build the Detectron2/Textract analyzer pipeline based on the given config"""
cfg.freeze(freezed=False)
cfg.TAB = True
cfg.TAB_REF = True
cfg.OCR = True
cfg.freeze()
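    # The pipeline is a sequence of page processing services: layout detection, NMS, table
    # recognition, text extraction/OCR, word-to-layout matching, reading order and post processing.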
pipe_component_list = []
layout = dd.ImageLayoutService(d_layout, to_image=True, crop_image=True)
pipe_component_list.append(layout)
nms_service = dd.AnnotationNmsService(nms_pairs=cfg.LAYOUT_NMS_PAIRS.COMBINATIONS,
thresholds=cfg.LAYOUT_NMS_PAIRS.THRESHOLDS)
pipe_component_list.append(nms_service)
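    # Table recognition: detect cells and rows/columns inside each table crop, then segment
    # every table into a grid with row and column assignments.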
if cfg.TAB:
detect_result_generator = dd.DetectResultGenerator(categories_cell)
cell = dd.SubImageLayoutService(d_cell, dd.LayoutType.table, {1: 6}, detect_result_generator)
pipe_component_list.append(cell)
detect_result_generator = dd.DetectResultGenerator(categories_item)
item = dd.SubImageLayoutService(d_item, dd.LayoutType.table, {1: 7, 2: 8}, detect_result_generator)
pipe_component_list.append(item)
table_segmentation = dd.TableSegmentationService(
cfg.SEGMENTATION.ASSIGNMENT_RULE,
cfg.SEGMENTATION.THRESHOLD_ROWS,
cfg.SEGMENTATION.THRESHOLD_COLS,
cfg.SEGMENTATION.FULL_TABLE_TILING,
cfg.SEGMENTATION.REMOVE_IOU_THRESHOLD_ROWS,
cfg.SEGMENTATION.REMOVE_IOU_THRESHOLD_COLS,
cfg.SEGMENTATION.STRETCH_RULE
)
pipe_component_list.append(table_segmentation)
if cfg.TAB_REF:
table_segmentation_refinement = dd.TableSegmentationRefinementService()
pipe_component_list.append(table_segmentation_refinement)
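    # Text extraction: take the PDF text layer first and run Textract OCR only on pages where
    # no text could be extracted.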
if cfg.OCR:
d_text = dd.TextExtractionService(pdf_text)
pipe_component_list.append(d_text)
t_text = dd.TextExtractionService(tex_text,skip_if_text_extracted=True)
pipe_component_list.append(t_text)
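        # Assign the extracted words to layout blocks and derive a reading order over text
        # blocks and words.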
match_words = dd.MatchingService(
parent_categories=cfg.WORD_MATCHING.PARENTAL_CATEGORIES,
child_categories=cfg.WORD_MATCHING.CHILD_CATEGORIES,
matching_rule=cfg.WORD_MATCHING.RULE,
threshold=cfg.WORD_MATCHING.THRESHOLD,
max_parent_only=cfg.WORD_MATCHING.MAX_PARENT_ONLY
)
pipe_component_list.append(match_words)
order = dd.TextOrderService(
text_container=cfg.TEXT_ORDERING.TEXT_CONTAINER,
floating_text_block_categories=cfg.TEXT_ORDERING.FLOATING_TEXT_BLOCK,
text_block_categories=cfg.TEXT_ORDERING.TEXT_BLOCK,
include_residual_text_container=cfg.TEXT_ORDERING.TEXT_CONTAINER_TO_TEXT_BLOCK)
pipe_component_list.append(order)
    # The post processor must be part of the component list before the pipeline is built
    post_processor = PostProcessor("deepdoctection", **credentials_kwargs)
    post_service = PostProcessorService(post_processor)
    pipe_component_list.append(post_service)
    pipe = dd.DoctectionPipe(pipeline_component_list=pipe_component_list)
    return pipe
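

# For reference, a minimal sketch of running deepdoctection outside of this demo, assuming the
# publicly available default models rather than the private ones registered above:
#
#     import deepdoctection as dd
#     analyzer = dd.get_dd_analyzer()
#     df = analyzer.analyze(path="sample_3.pdf")
#     df.reset_state()
#     for page in df:
#         print(page.text)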


def analyze_image(img, pdf, max_datapoints):
    # Create an image object (or read the PDF) and pass it to the analyzer via a dataflow
analyzer = build_gradio_analyzer()
    if img is not None:
        # gradio delivers RGB arrays; the channel flip converts to the BGR order deepdoctection expects
        image = dd.Image(file_name=str(time.time()).replace(".", "") + ".png", location="")
        image.image = img[:, :, ::-1]
        df = DataFromList(lst=[image])
        df = analyzer.analyze(dataset_dataflow=df)
    elif pdf:
        df = analyzer.analyze(path=pdf.name, max_datapoints=max_datapoints)
    else:
        raise ValueError("Please upload an image or a PDF document")
df.reset_state()
layout_items_str = ""
jsonl_out = []
dpts = []
html_list = []
for dp in df:
dpts.append(dp)
        out = dp.as_dict()
        jsonl_out.append(out)
        out.pop("_image")  # drop the embedded image data from the JSON output
layout_items = [layout for layout in dp.layouts if layout.reading_order is not None]
layout_items.sort(key=lambda x: x.reading_order)
layout_items_str += f"\n\n -------- PAGE NUMBER: {dp.page_number+1} ------------- \n"
for item in layout_items:
layout_items_str += f"\n {item.category_name}: {item.text}"
html_list.extend([table.html for table in dp.tables])
if html_list:
html = ("<br /><br /><br />").join(html_list)
else:
html = None
    json_object = json.dumps(jsonl_out, indent=4)
return [dp.viz(show_cells=False) for dp in dpts], layout_items_str, html, json_object
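

# Gradio UI: upload an image or a PDF and choose the number of pages; outputs are the contiguous
# text, the layout visualisation, the recognized tables as HTML and the raw JSON.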
demo = gr.Blocks(css="scrollbar.css")
with demo:
with gr.Box():
gr.Markdown("<h1><center>deepdoctection - A Document AI Package</center></h1>")
gr.Markdown("<strong>deep</strong>doctection is a Python library that orchestrates document extraction"
" and document layout analysis tasks using deep learning models. It does not implement models"
" but enables you to build pipelines using highly acknowledged libraries for object detection,"
" OCR and selected NLP tasks and provides an integrated frameworks for fine-tuning, evaluating"
" and running models.<br />"
"This pipeline consists of a stack of models powered by <strong>Detectron2"
"</strong> for layout analysis and table recognition. OCR will be provided as well. You can process"
"an image or even a PDF-document. Up to nine pages can be processed. <br />")
gr.Markdown("<center>Please note: The models for layout detection and table recognition are not open sourced.
When you start using deepdoctection you will get models that have been trained on less diversified data and that will perform worse.
OCR isn't open sourced either: It uses AWS Textract, which is a commercial service. Keep this in mind, before you get started with
your installation and observe dissapointing results. Thanks. </center>")
gr.Markdown("[https://github.com/deepdoctection/deepdoctection](https://github.com/deepdoctection/deepdoctection)")
with gr.Box():
gr.Markdown("<h2><center>Upload a document and choose setting</center></h2>")
with gr.Row():
with gr.Column():
with gr.Tab("Image upload"):
with gr.Column():
inputs = gr.Image(type='numpy', label="Original Image")
with gr.Tab("PDF upload (only first image will be processed) *"):
with gr.Column():
inputs_pdf = gr.File(label="PDF")
gr.Markdown("<sup>* If an image is cached in tab, remove it first</sup>")
with gr.Column():
                gr.Examples(
                    examples=[path.join(getcwd(), "sample_1.jpg"), path.join(getcwd(), "sample_2.png")],
                    inputs=inputs)
                gr.Examples(examples=[path.join(getcwd(), "sample_3.pdf")], inputs=inputs_pdf)
with gr.Row():
max_imgs = gr.Slider(1, 8, value=2, step=1, label="Number of pages in multi page PDF",
info="Will stop after 9 pages")
with gr.Row():
btn = gr.Button("Run model", variant="primary")
with gr.Box():
gr.Markdown("<h2><center>Outputs</center></h2>")
with gr.Row():
with gr.Column():
with gr.Box():
gr.Markdown("<center><strong>Contiguous text</strong></center>")
image_text = gr.Textbox()
with gr.Column():
with gr.Box():
gr.Markdown("<center><strong>Layout detection</strong></center>")
gallery = gr.Gallery(
label="Output images", show_label=False, elem_id="gallery"
).style(grid=2)
with gr.Row():
with gr.Box():
gr.Markdown("<center><strong>Table</strong></center>")
html = gr.HTML()
with gr.Row():
with gr.Box():
gr.Markdown("<center><strong>JSON</strong></center>")
json_output = gr.JSON()
btn.click(fn=analyze_image, inputs=[inputs, inputs_pdf, max_imgs],
outputs=[gallery, image_text, html, json_output])
demo.launch()