import base64
import os
import re
from io import BytesIO
from pathlib import Path
import gradio as gr
from langchain.schema.output_parser import OutputParserException
from PIL import Image
import categories
from categories import Category
from main import process_image, process_pdf
# Hugging Face access token read from the environment; None when unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# HTML template for the PDF preview. ``{0}`` is replaced with the
# base64-encoded PDF bytes (see ``display_file``); the height matches the
# image preview so both previews occupy the same space.
PDF_IFRAME = """
<iframe src="data:application/pdf;base64,{0}" width="100%" height="350"></iframe>
"""
# Flagging back-ends. ``hf_writer_normal`` receives every result the user
# agrees to share; ``hf_writer_incorrect`` receives results explicitly
# flagged as "incorrect" or "irrelevant".
hf_writer_normal = gr.HuggingFaceDatasetSaver(
    HF_TOKEN, "automatic-reimbursement-tool-demo", separate_dirs=False
)
hf_writer_incorrect = gr.HuggingFaceDatasetSaver(
    HF_TOKEN, "automatic-reimbursement-tool-demo-incorrect", separate_dirs=False
)

# Path of the most recently uploaded file (None when nothing is uploaded).
# Assigned by display_file() through a ``global`` statement.
current_file_path = None
def display_file(input_file):
    """Show a preview of the uploaded receipt.

    PDFs are embedded as a base64 data URI through ``PDF_IFRAME``; every
    other file is handed to the image component by path.

    Args:
        input_file: The uploaded file object from ``gr.File`` (its ``.name``
            is the path on disk), or ``None`` when the upload was cleared.

    Returns:
        A ``(pdf_preview_update, image_preview_update)`` pair. At most one of
        the two previews is made visible; both are hidden when there is no
        file.
    """
    global current_file_path
    current_file_path = input_file.name if input_file else None
    if not input_file:
        return gr.HTML.update(visible=False), gr.Image.update(visible=False)
    # Case-insensitive so that e.g. "RECEIPT.PDF" is not treated as an image.
    if input_file.name.lower().endswith(".pdf"):
        pdf_base64 = base64.b64encode(Path(input_file.name).read_bytes()).decode()
        return (
            gr.HTML.update(PDF_IFRAME.format(pdf_base64), visible=True),
            gr.Image.update(visible=False),
        )
    return (
        gr.HTML.update(visible=False),
        gr.Image.update(input_file.name, visible=True),
    )
def show_intermediate_outputs(show_intermediate):
    """Toggle visibility of the "Intermediate outputs" accordion.

    Args:
        show_intermediate: Value of the checkbox; any truthy value shows the
            accordion.

    Returns:
        A ``gr.Accordion`` update with ``visible`` set to ``True`` or ``False``.
    """
    should_show = bool(show_intermediate)
    return gr.Accordion.update(visible=should_show)
def show_share_contact(share_result):
    """Show the contact e-mail box only while result sharing is enabled.

    Args:
        share_result: Value of the "Share results" checkbox, passed through
            unchanged as the textbox's visibility.

    Returns:
        A ``gr.Textbox`` update for the contact field.
    """
    contact_update = gr.Textbox.update(visible=share_result)
    return contact_update
def clear_inputs():
    """Reset the file upload component.

    Returns:
        A ``gr.File`` update that sets its value to ``None``.
    """
    empty_upload = gr.File.update(value=None)
    return empty_upload
def clear_outputs(input_file):
    """Reset the output components before a new run.

    Args:
        input_file: The uploaded file object, or ``None`` if nothing is
            uploaded.

    Returns:
        Four values, one for each of
        ``[extracted_text, category, chatbot, information]``.
        When a file is present, every value is ``None`` so the previous
        results are cleared. Otherwise each value is a no-op update, which
        leaves the previous results untouched.
    """
    if input_file:
        return None, None, None, None
    # Gradio expects one value per output component. The old implicit
    # ``return None`` supplied a single value for four outputs.
    return tuple(gr.update() for _ in range(4))
def extract_text(input_file):
    """Extract the text of the uploaded receipt.

    PDFs go through ``process_pdf``; any other file goes through
    ``process_image``. Both are called with ``extract_only=True``.

    Args:
        input_file: The uploaded file object from ``gr.File`` (its ``.name``
            is the path on disk), or ``None``.

    Returns:
        The extracted text, shown in the "Extracted text" textbox.

    Raises:
        gr.Error: If no file has been uploaded. Gradio shows this message to
            the user.
    """
    if not input_file:
        # gr.Error only reaches the UI when raised; constructing it alone
        # does nothing.
        raise gr.Error("Please upload a file to continue!")
    file_path = Path(input_file.name)
    # Case-insensitive so that e.g. "RECEIPT.PDF" takes the PDF path.
    if input_file.name.lower().endswith(".pdf"):
        return process_pdf(file_path, extract_only=True)
    return process_image(file_path, extract_only=True)
def categorize_text(text):
    """Classify the extracted receipt text into a category.

    Args:
        text: The text produced by ``extract_text``.

    Returns:
        The result of ``categories.categorize_text(text)``. The UI wiring
        sends it to the category dropdown.
    """
    return categories.categorize_text(text)
def query(category, text):
    """Run the category's LLM chain on the extracted text, streaming to the chatbot.

    Two updates are yielded: first the rendered prompt paired with a
    "Generating..." placeholder, then the prompt paired with the model's
    answer.

    Args:
        category: Name of a ``Category`` member (the dropdown value).
        text: The extracted receipt text.

    Yields:
        ``gr.Chatbot`` updates holding a single ``[question, answer]`` pair.

    Raises:
        KeyError: If ``category`` is not the name of a ``Category`` member.
    """
    chain = categories.category_modules[Category[category]].chain
    # Built once and reused for both rendering the prompt and generating.
    inputs = {
        "text": text,
        "format_instructions": chain.output_parser.get_format_instructions(),
    }
    # to_messages() is available on both chat and plain-string prompt values.
    messages = chain.prompt.format_prompt(**inputs).to_messages()
    if len(messages) > 1:
        question = (
            f"**System:**\n{messages[0].content}"
            f"\n\n**Human:**\n{messages[1].content}"
        )
    elif messages:
        # A single-message prompt used to leave the question blank.
        question = f"**Human:**\n{messages[0].content}"
    else:
        question = ""
    yield gr.Chatbot.update([[question, "Generating..."]])
    result = chain.generate(input_list=[inputs])
    answer = result.generations[0][0].text
    yield gr.Chatbot.update([[question, answer]])
# Splits a rendered "**System:** ... **Human:** ..." prompt into its two
# parts. It is not referenced by the code in this part of the file.
PARSING_REGEXP = r"\*\*System:\*\*\n([\s\S]+)\n\n\*\*Human:\*\*\n([\s\S]+)"


def parse(category, chatbot):
    """Parse the chatbot's answer into structured information.

    Args:
        category: Name of a ``Category`` member. Its chain's output parser is
            used.
        chatbot: Chatbot history. Only ``chatbot[0][1]`` (the first answer) is
            parsed.

    Yields:
        If there is no answer: a single error dict.
        Otherwise: a status dict, then either the parsed information
        serialized with ``.json()`` (or ``{}`` when the parser returned
        nothing), or an error dict when the answer cannot be parsed.
    """
    # Guard against a missing or empty history (e.g. the query step failed)
    # instead of crashing with IndexError/TypeError.
    if not chatbot or not chatbot[0] or len(chatbot[0]) < 2:
        yield {"error": "No chatbot output to parse"}
        return
    answer = chatbot[0][1]
    chain = categories.category_modules[Category[category]].chain
    yield {"status": "Parsing response..."}
    try:
        information = chain.output_parser.parse(answer)
    except OutputParserException as e:
        yield {
            "error": "Unable to parse chatbot output",
            "details": str(e),
            # Not every langchain version sets llm_output on the exception.
            "output": getattr(e, "llm_output", answer),
        }
        return
    yield information.json() if information else {}
def activate_flags():
    """Enable both flag buttons ("incorrect" and "irrelevant").

    Returns:
        A tuple of two independent ``gr.Button`` updates with
        ``interactive=True``, one per button.
    """
    return tuple(gr.Button.update(interactive=True) for _ in range(2))
def deactivate_flags():
    """Disable both flag buttons ("incorrect" and "irrelevant").

    Returns:
        A tuple of two independent ``gr.Button`` updates with
        ``interactive=False``, one per button.
    """
    return tuple(gr.Button.update(interactive=False) for _ in range(2))
def flag_if_shared(flag_method):
    """Wrap a flag callback so it only runs when the user opted in to sharing.

    The wrapper's first positional argument is the "Share results" checkbox
    value. The ``request: gr.Request`` annotation on the second parameter lets
    Gradio inject the request there. ``functools.wraps`` is deliberately NOT
    used: it would make the wrapper's inspected signature look like
    ``flag_method``'s.

    Args:
        flag_method: Callable invoked as ``flag_method(request, *args, **kwargs)``.

    Returns:
        The wrapper function. It returns ``flag_method``'s result when sharing
        is enabled, else ``None``.
    """
    def proxy(share_result, request: gr.Request, *args, **kwargs):
        if not share_result:
            return None
        return flag_method(request, *args, **kwargs)

    return proxy
with gr.Blocks(title="Automatic Reimbursement Tool Demo") as page:
    # Page header. These literals previously spanned two physical lines,
    # which is a SyntaxError for ordinary (non-triple-quoted) strings.
    gr.Markdown("# Automatic Reimbursement Tool Demo")
    gr.Markdown("## Description")
    gr.Markdown(
        "The reimbursement filing process can be time-consuming and cumbersome, causing "
        "frustration for faculty members and finance departments. Our project aims to "
        "automate the information extraction involved in the process by feeding "
        "extracted text to language models such as ChatGPT. This demo showcases the "
        "categorization and extraction parts of the pipeline. Categorization is done "
        "to identify the relevant details associated with the text, after which "
        "extraction is done for those details using a language model."
    )
    gr.Markdown("## Try it out!")
    with gr.Box() as demo:
        with gr.Row():
            # Left column: upload, preview and run options.
            with gr.Column(variant="panel"):
                gr.HTML(
                    'Input'
                )
                pdf_preview = gr.HTML(label="Preview", show_label=True, visible=False)
                image_preview = gr.Image(
                    label="Preview", show_label=True, visible=False, height=350
                )
                input_file = gr.File(
                    label="Input receipt",
                    show_label=True,
                    type="file",
                    file_count="single",
                    file_types=["image", ".pdf"],
                )
                input_file.change(
                    display_file, input_file, [pdf_preview, image_preview]
                )
                with gr.Row():
                    clear = gr.Button("Clear", variant="secondary")
                    submit_button = gr.Button("Submit", variant="primary")
                show_intermediate = gr.Checkbox(
                    False,
                    label="Show intermediate outputs",
                    info="There are several intermediate steps in the process such as "
                    "preprocessing, OCR, chatbot interaction. You can choose to "
                    "show their results here.",
                )
                share_result = gr.Checkbox(
                    True,
                    label="Share results",
                    info="Sharing your result with us will help us improve this tool.",
                    interactive=True,
                )
                contact = gr.Textbox(
                    type="email",
                    label="Contact",
                    interactive=True,
                    placeholder="Enter your email address",
                    info="Optionally, enter your email address to allow us to contact "
                    "you regarding your result.",
                    visible=True,
                )
                share_result.change(show_share_contact, share_result, [contact])
            # Right column: category, intermediate steps and final result.
            with gr.Column(variant="panel"):
                gr.HTML(
                    'Output'
                )
                category = gr.Dropdown(
                    value=None,
                    # A plain list rather than a dict_keys view.
                    choices=list(Category.__members__),
                    label=f"Recognized category ({', '.join(Category.__members__.keys())})",
                    show_label=True,
                    interactive=False,
                )
                intermediate_outputs = gr.Accordion(
                    "Intermediate outputs", open=True, visible=False
                )
                with intermediate_outputs:
                    extracted_text = gr.Textbox(
                        label="Extracted text",
                        show_label=True,
                        max_lines=5,
                        show_copy_button=True,
                        lines=5,
                        interactive=False,
                    )
                    chatbot = gr.Chatbot(
                        None,
                        label="Chatbot interaction",
                        show_label=True,
                        interactive=False,
                        height=240,
                    )
                information = gr.JSON(label="Extracted information")
                with gr.Row():
                    # Start disabled: there is nothing to flag until a run
                    # completes (activate_flags) and after Clear
                    # (deactivate_flags).
                    flag_incorrect_button = gr.Button(
                        "Flag as incorrect", variant="stop", interactive=False
                    )
                    flag_irrelevant_button = gr.Button(
                        "Flag as irrelevant", variant="stop", interactive=False
                    )
    show_intermediate.change(
        show_intermediate_outputs, show_intermediate, [intermediate_outputs]
    )
    clear.click(clear_inputs, None, [input_file]).then(
        deactivate_flags,
        None,
        [flag_incorrect_button, flag_irrelevant_button],
    )
    # Every completed run is also saved here, when the user agreed to share.
    hf_writer_normal.setup(
        [input_file, extracted_text, category, chatbot, information, contact],
        flagging_dir="flagged",
    )
    flag_method = gr.flagging.FlagMethod(
        hf_writer_normal, "", "", visual_feedback=False
    )
    # Processing pipeline. Categorization and querying run only when the
    # previous step succeeded (.success); with .then, a failed extraction
    # re-categorized, re-queried and re-shared stale text from an earlier run.
    submit_button.click(
        clear_outputs,
        [input_file],
        [extracted_text, category, chatbot, information],
    ).then(
        extract_text,
        [input_file],
        [extracted_text],
    ).success(
        categorize_text,
        [extracted_text],
        [category],
    ).success(
        query,
        [category, extracted_text],
        [chatbot],
        queue=True,
    ).then(
        parse,
        [category, chatbot],
        [information],
    ).then(
        activate_flags,
        None,
        [flag_incorrect_button, flag_irrelevant_button],
    ).then(
        flag_if_shared(flag_method),
        [
            share_result,
            input_file,
            extracted_text,
            category,
            chatbot,
            information,
            contact,
        ],
        None,
        preprocess=False,
    )
    # User-initiated flags ("incorrect" / "irrelevant") share one dataset.
    hf_writer_incorrect.setup(
        [input_file, extracted_text, category, chatbot, information, contact],
        flagging_dir="flagged_incorrect",
    )
    flag_incorrect_method = gr.flagging.FlagMethod(
        hf_writer_incorrect,
        "Flag as incorrect",
        "Incorrect",
        visual_feedback=True,
    )
    # Immediate "Saving..." feedback; the second listener performs the flag.
    flag_incorrect_button.click(
        lambda: gr.Button.update(value="Saving...", interactive=False),
        None,
        flag_incorrect_button,
        queue=False,
    )
    flag_incorrect_button.click(
        flag_incorrect_method,
        inputs=[
            input_file,
            extracted_text,
            category,
            chatbot,
            information,
            contact,
        ],
        outputs=[flag_incorrect_button],
        preprocess=False,
        queue=False,
    )
    flag_irrelevant_method = gr.flagging.FlagMethod(
        hf_writer_incorrect,
        "Flag as irrelevant",
        "Irrelevant",
        visual_feedback=True,
    )
    flag_irrelevant_button.click(
        lambda: gr.Button.update(value="Saving...", interactive=False),
        None,
        flag_irrelevant_button,
        queue=False,
    )
    flag_irrelevant_button.click(
        flag_irrelevant_method,
        inputs=[
            input_file,
            extracted_text,
            category,
            chatbot,
            information,
            contact,
        ],
        outputs=[flag_irrelevant_button],
        preprocess=False,
        queue=False,
    )
# Enable the event queue: concurrency_count=20, max_size=1 (see the
# Blocks.queue documentation for the exact semantics of each limit).
page.queue(max_size=1, concurrency_count=20)
# Start the server with the API page enabled and errors surfaced in the UI.
page.launch(debug=True, show_error=True, show_api=True)