# Document-AI-GPT / app.py
# Last commit by JaMe76: bc391b2 ("add files"), 6.1 kB
import os
# Install detectron2 from its prebuilt wheel index (cu102 / torch1.9 builds)
# before deepdoctection is imported further down.
os.system(
    "pip install detectron2 -f "
    "https://dl.fbaipublicfiles.com/detectron2/wheels/cu102/torch1.9/index.html"
)

# AWS credentials read from the environment; raises KeyError if either
# variable is unset.
# NOTE(review): this dict is not referenced anywhere else in this file --
# confirm whether it is still needed.
credentials_kwargs = {
    "aws_access_key_id": os.environ["ACCESS_KEY"],
    "aws_secret_access_key": os.environ["SECRET_KEY"],
}

# Replace the preinstalled gradio with the pinned 3.4.1 release.
# Work-around: https://discuss.huggingface.co/t/how-to-install-a-specific-version-of-gradio-in-spaces/13552
os.system("pip uninstall -y gradio")
os.system("pip install gradio==3.4.1")

# Run the shell command stored in DD_ADDONS (presumably installs the
# `dd_addons` package imported below -- confirm).
# NOTE(review): this executes an arbitrary shell string taken from the
# environment; it is only safe while that variable is fully trusted.
os.system(os.environ["DD_ADDONS"])
import time
from os import getcwd, path
import deepdoctection as dd
from deepdoctection.dataflow.serialize import DataFromList
from deepdoctection.utils.settings import get_type
from dd_addons.analyzer.loader import get_loader
from dd_addons.extern.guidance import TOKEN_DEFAULT_INSTRUCTION
from dd_addons.utils.settings import register_llm_token_tag, register_string_categories_from_list
from dd_addons.extern.openai import OpenAiLmmTokenClassifier
import gradio as gr
# Shared analyzer, built once at import time. process_analyzer() reuses this
# single instance for every request and swaps its LLM component in place.
analyzer = get_loader(
    reset_config_file=True,
)

# Root Gradio container; the page layout is attached to it with `with demo:`.
demo = gr.Blocks(
    css="scrollbar.css",  # presumably a stylesheet shipped with the Space
)
def _parse_categories(categories_str):
    """Return the non-empty, whitespace-stripped entries of a comma-separated string.

    ``"name, address,"`` becomes ``["name", "address"]``; ``None`` or ``""``
    becomes ``[]``.
    """
    return [name.strip() for name in (categories_str or "").split(",") if name.strip()]


def process_analyzer(openai_api_key, categories_str, instruction_str, img, pdf, max_datapoints):
    """Run the document pipeline with a ChatGPT token classifier on one upload.

    Args:
        openai_api_key: OpenAI API key entered by the user; handed to the classifier.
        categories_str: Comma-separated entity names to extract.
        instruction_str: Optional prompt guidance; empty means "use the fallback".
        img: Uploaded image from ``gr.Image(type='numpy')``, or None.
            Takes precedence over ``pdf`` when both are given.
        pdf: Uploaded file object from ``gr.File`` (its ``.name`` is used as the
            path), or a falsy value.
        max_datapoints: Maximum number of PDF pages to process (PDF input only).

    Returns:
        A tuple ``(visualisations, json_out)``:
        - ``visualisations``: one ``dp.viz(...)`` result per processed page.
        - ``json_out``: maps ``"page_<idx>"`` to that page's ``dp.get_token()``.

    Raises:
        ValueError: If neither an image nor a PDF was uploaded, or if no
            non-empty category was entered.
    """
    # Validate the request before mutating any shared state (registries, analyzer).
    if img is None and not pdf:
        raise ValueError("Please upload an image or a PDF document.")
    # Strip whitespace so "a, b" registers "b" rather than " b"; drop empty entries.
    categories_list = _parse_categories(categories_str)
    if not categories_list:
        raise ValueError("Please enter at least one comma separated entity category.")

    # Register the user's categories as token classes and as LLM token tags.
    register_string_categories_from_list(categories_list, "custom_token_classes")
    token_classes = list(dd.object_types_registry.get("custom_token_classes"))
    print(token_classes)
    register_llm_token_tag(token_classes)

    # Category ids are 1-based strings, in the order the user typed them.
    categories = {str(idx): get_type(val) for idx, val in enumerate(categories_list, start=1)}
    gpt_token_classifier = OpenAiLmmTokenClassifier(
        model_name="gpt-3.5-turbo",
        categories=categories,
        api_key=openai_api_key,
        # None presumably makes the classifier fall back to its default
        # instruction (the UI shows TOKEN_DEFAULT_INSTRUCTION as placeholder).
        instruction=instruction_str or None,
    )
    # NOTE(review): index 8 is assumed to be the LLM token-classification
    # component of the pipeline built by get_loader(); verify if the pipeline
    # configuration changes.
    # NOTE(review): `analyzer` is module-level and shared by all requests, so
    # this swap (including the user's API key) is visible to any request running
    # concurrently; consider serializing requests if Gradio may run this handler
    # in parallel.
    analyzer.pipe_component_list[8].language_model = gpt_token_classifier

    if img is not None:
        # Time-based, practically unique file name for the uploaded image.
        image = dd.Image(file_name=str(time.time()).replace(".", "") + ".png", location="")
        # Reverse the channel axis: Gradio delivers RGB; presumably the pipeline
        # expects OpenCV-style BGR -- confirm.
        image.image = img[:, :, ::-1]
        df = analyzer.analyze(dataset_dataflow=DataFromList(lst=[image]))
    else:
        # Slider values may arrive as floats; the page limit is a count.
        page_limit = int(max_datapoints) if max_datapoints is not None else None
        df = analyzer.analyze(path=pdf.name, max_datapoints=page_limit)
    df.reset_state()

    pages = []
    json_out = {}
    for idx, dp in enumerate(df):
        pages.append(dp)
        json_out[f"page_{idx}"] = dp.get_token()

    visualisations = [
        dp.viz(
            show_cells=False,
            show_layouts=False,
            show_tables=False,
            show_words=True,
            show_token_class=True,
            ignore_default_token_class=True,
        )
        for dp in pages
    ]
    return visualisations, json_out
# Page layout: header, inputs (document upload + extraction settings) and
# output panels. "Run model" passes the inputs to process_analyzer and renders
# its (visualisations, json) result into the gallery and the JSON viewer.
with demo:
    # Header / description.
    with gr.Box():
        gr.Markdown("<h1><center>Document AI GPT</center></h1>")
        gr.Markdown(
            "<h2 ><center>Zero or few-shot Entity Extraction powered by ChatGPT and <strong>deep</strong>doctection </center></h2>"
            "<center>This pipeline consists of a stack of models for layout analysis and table recognition "
            "to prepare a prompt for ChatGPT. </center>"
            "<center>Be aware! The Space is still very fragile.</center><br />"
        )

    # Inputs: document upload on the left, extraction settings on the right.
    with gr.Box():
        gr.Markdown("<h2><center>Upload a document and choose settings</center></h2>")
        with gr.Row():
            with gr.Column():
                with gr.Tab("Image upload"):
                    with gr.Column():
                        inputs = gr.Image(type='numpy', label="Original Image")
                with gr.Tab("PDF upload *"):
                    with gr.Column():
                        inputs_pdf = gr.File(label="PDF")
                        # process_analyzer uses the image whenever one is set,
                        # so a cached image would shadow the PDF.
                        gr.Markdown("<sup>* If an image is cached in tab, remove it first</sup>")
                with gr.Box():
                    gr.Examples(
                        examples=[path.join(getcwd(), "sample_2.png")],
                        inputs=inputs,
                    )
                with gr.Box():
                    gr.Markdown("Enter your OpenAI API Key* ")
                    user_token = gr.Textbox(value='', placeholder="OpenAI API Key", type="password", show_label=False)
                    # Trailing space after "the" keeps the two literals from
                    # rendering as "theAPI".
                    gr.Markdown(
                        "<sup>* Your API key will not be saved. However, it is always recommended to deactivate the "
                        "API key once it is entered into an unknown source</sup>"
                    )
            with gr.Column():
                with gr.Box():
                    gr.Markdown(
                        "Enter a list of comma separated entities. Use a snake case style. Avoid special characters. "
                        "Best way is to only use `a-z` and `_`"
                    )
                    categories = gr.Textbox(value='', placeholder="mitarbeiter_anzahl", show_label=False)
                with gr.Box():
                    gr.Markdown("Optional: Enter a prompt for additional guidance. Will use the placeholder as fallback")
                    instruction = gr.Textbox(value='', placeholder=TOKEN_DEFAULT_INSTRUCTION, show_label=False)
                with gr.Row():
                    max_imgs = gr.Slider(
                        1, 3, value=1, step=1,
                        label="Number of pages in multi page PDF",
                        info="Will stop after 3 pages",
                    )
                with gr.Row():
                    btn = gr.Button("Run model", variant="primary")

    # Outputs: extracted tokens as JSON and the visualised pages.
    with gr.Box():
        gr.Markdown("<h2><center>Outputs</center></h2>")
        with gr.Row():
            with gr.Column():
                with gr.Box():
                    gr.Markdown("<center><strong>JSON</strong></center>")
                    # Named json_output so the stdlib `json` module is not shadowed.
                    json_output = gr.JSON()
            with gr.Column():
                with gr.Box():
                    gr.Markdown("<center><strong>Layout detection</strong></center>")
                    gallery = gr.Gallery(
                        label="Output images", show_label=False, elem_id="gallery"
                    ).style(grid=2)
        with gr.Row():
            with gr.Box():
                gr.Markdown("<center><strong>Table</strong></center>")
                # NOTE(review): this panel is not listed in btn.click's outputs,
                # so it always stays empty -- wire it up or remove it.
                html = gr.HTML()

    # Argument order must match process_analyzer's signature; outputs match
    # its (visualisations, json_out) return tuple.
    btn.click(
        fn=process_analyzer,
        inputs=[user_token, categories, instruction, inputs, inputs_pdf, max_imgs],
        outputs=[gallery, json_output],
    )

demo.launch()