Spaces:
Sleeping
Sleeping
import base64 | |
import os | |
import re | |
from io import BytesIO | |
from pathlib import Path | |
import gradio as gr | |
from langchain.schema.output_parser import OutputParserException | |
from PIL import Image | |
import categories | |
from categories import Category | |
from main import process_image, process_pdf | |
# Hugging Face access token read from the environment (None when HF_TOKEN is unset).
HF_TOKEN = os.getenv("HF_TOKEN") | |
# HTML template that embeds a PDF in an <iframe> through a base64 data URI.
# The ``{0}`` placeholder is filled with ``str.format`` by ``display_file``.
PDF_IFRAME = """ | |
<div style="border-radius: 10px; width: 100%; overflow: hidden;"> | |
<iframe | |
src="data:application/pdf;base64,{0}" | |
width="100%" | |
height="400" | |
type="application/pdf"> | |
</iframe> | |
</div>""" | |
# Dataset saver backing the automatic flagging step at the end of the Submit
# chain (wrapped in ``flag_if_shared`` further down in this file).
hf_writer_normal = gr.HuggingFaceDatasetSaver( | |
HF_TOKEN, "automatic-reimbursement-tool-demo", separate_dirs=False | |
) | |
# Dataset saver shared by the "Flag as incorrect" and "Flag as irrelevant"
# buttons further down in this file.
hf_writer_incorrect = gr.HuggingFaceDatasetSaver( | |
HF_TOKEN, "automatic-reimbursement-tool-demo-incorrect", separate_dirs=False | |
) | |
# with open("examples/example1.pdf", "rb") as pdf_file: | |
# base64_pdf = base64.b64encode(pdf_file.read()) | |
# example_paths = [] | |
# current_file_path = None | |
# def ignore_examples(function): | |
# def new_function(*args, **kwargs): | |
# global example_paths, current_file_path | |
# if current_file_path not in example_paths: | |
# return function(*args, **kwargs) | |
def display_file(input_file):
    """Update the PDF and image previews to match the uploaded file.

    Args:
        input_file: The uploaded file object, exposing its path through a
            ``name`` attribute, or a falsy value when nothing is uploaded.

    Returns:
        A pair of component updates ``(pdf_preview, image_preview)``. Both are
        hidden when there is no file; otherwise exactly one is made visible.
    """
    global current_file_path
    current_file_path = input_file.name if input_file else None

    if not input_file:
        return gr.HTML.update(visible=False), gr.Image.update(visible=False)

    file_path = input_file.name
    # Compare case-insensitively so an upload named "RECEIPT.PDF" is still
    # rendered in the PDF preview instead of being handed to the image preview.
    if file_path.lower().endswith(".pdf"):
        # A dedicated handle name keeps the ``input_file`` argument intact.
        with open(file_path, "rb") as pdf_handle:
            pdf_base64 = base64.b64encode(pdf_handle.read()).decode()
        return (
            gr.HTML.update(PDF_IFRAME.format(pdf_base64), visible=True),
            gr.Image.update(visible=False),
        )

    return (
        gr.HTML.update(visible=False),
        gr.Image.update(file_path, visible=True),
    )
def show_intermediate_outputs(show_intermediate):
    """Return an accordion update whose visibility follows the checkbox.

    Args:
        show_intermediate: Truthy to reveal the accordion, falsy to hide it.
    """
    accordion_visible = bool(show_intermediate)
    return gr.Accordion.update(visible=accordion_visible)
def show_share_contact(share_result):
    """Return a textbox update whose visibility is set to *share_result*."""
    contact_update = gr.Textbox.update(visible=share_result)
    return contact_update
def clear_inputs():
    """Return a file-component update that resets its value to ``None``."""
    emptied_file = gr.File.update(value=None)
    return emptied_file
def clear_outputs(input_file):
    """Produce blank values for the four output components.

    Args:
        input_file: The current upload; only its truthiness is inspected.

    Returns:
        A 4-tuple of ``None`` when *input_file* is truthy, otherwise ``None``.
    """
    if not input_file:
        # NOTE(review): a bare None is returned here although the Submit chain
        # wires this handler to four outputs - confirm Gradio accepts that.
        return None
    return (None,) * 4
def extract_text(input_file):
    """Extract the text of the uploaded receipt.

    Args:
        input_file: The uploaded file object, exposing its path through a
            ``name`` attribute, or a falsy value when nothing is uploaded.

    Returns:
        Whatever ``process_pdf`` (for ``.pdf`` files) or ``process_image``
        (for everything else) returns when called with ``extract_only=True``.

    Raises:
        gr.Error: If no file has been uploaded.
    """
    if not input_file:
        # gr.Error is an exception: merely constructing it shows nothing to
        # the user, so it has to be raised for the message to surface.
        raise gr.Error("Please upload a file to continue!")

    file_path = Path(input_file.name)
    # Compare case-insensitively so "RECEIPT.PDF" is routed to the PDF path.
    if input_file.name.lower().endswith(".pdf"):
        return process_pdf(file_path, extract_only=True)
    return process_image(file_path, extract_only=True)
def categorize_text(text):
    """Return the category that ``categories.categorize_text`` assigns to *text*."""
    return categories.categorize_text(text)
def query(category, text):
    """Run the category's chain on *text*, streaming two chatbot updates.

    The first update shows the rendered prompt with a "Generating..."
    placeholder; the second replaces the placeholder with the model's answer.

    Args:
        category: Name of a ``Category`` member; selects the chain to use.
        text: The extracted receipt text inserted into the prompt.

    Yields:
        Chatbot updates holding a single ``[question, answer]`` pair.
    """
    category = Category[category]
    chain = categories.category_modules[category].chain

    # The format instructions feed both the prompt preview and the actual
    # generation call, so compute them once.
    format_instructions = chain.output_parser.get_format_instructions()
    formatted_prompt = chain.prompt.format_prompt(
        text=text,
        format_instructions=format_instructions,
    )

    messages = formatted_prompt.messages
    question = ""
    # Only index the second message when it exists; prompts with fewer than
    # two messages leave the question empty.
    if len(messages) > 1:
        question = (
            f"**System:**\n{messages[0].content}"
            f"\n\n**Human:**\n{messages[1].content}"
        )

    yield gr.Chatbot.update([[question, "Generating..."]])

    result = chain.generate(
        input_list=[
            {
                "text": text,
                "format_instructions": format_instructions,
            }
        ]
    )
    answer = result.generations[0][0].text
    yield gr.Chatbot.update([[question, answer]])
# Pattern for text laid out as "**System:**\n<system>\n\n**Human:**\n<human>";
# group 1 captures the system part and group 2 the human part.
PARSING_REGEXP = (
    r"\*\*System:\*\*\n([\s\S]+)"
    r"\n\n\*\*Human:\*\*\n([\s\S]+)"
)
def parse(category, chatbot):
    """Parse the chatbot's answer into structured information.

    Args:
        category: Name of a ``Category`` member; selects the chain whose
            output parser is applied.
        chatbot: Chat history; the answer is read from ``chatbot[0][1]``.

    Yields:
        First a ``{"status": ...}`` progress payload, then either the parsed
        result serialised with ``.json()`` (``{}`` when the parser returns a
        falsy value) or an error dictionary when parsing fails.
    """
    llm_answer = chatbot[0][1]
    selected_chain = categories.category_modules[Category[category]].chain

    yield {"status": "Parsing response..."}

    try:
        parsed = selected_chain.output_parser.parse(llm_answer)
        if parsed:
            information = parsed.json()
        else:
            information = {}
    except OutputParserException as parse_error:
        # Surface the failure in the JSON panel instead of raising.
        information = {
            "error": "Unable to parse chatbot output",
            "details": str(parse_error),
            "output": parse_error.llm_output,
        }

    yield information
def activate_flags():
    """Return two button updates that make both flag buttons interactive."""
    incorrect_update = gr.Button.update(interactive=True)
    irrelevant_update = gr.Button.update(interactive=True)
    return incorrect_update, irrelevant_update
def deactivate_flags():
    """Return two button updates that make both flag buttons non-interactive."""
    incorrect_update = gr.Button.update(interactive=False)
    irrelevant_update = gr.Button.update(interactive=False)
    return incorrect_update, irrelevant_update
def flag_if_shared(flag_method):
    """Wrap *flag_method* so it only runs when the first argument is truthy.

    Args:
        flag_method: Callable invoked as ``flag_method(request, *args, **kwargs)``.

    Returns:
        A ``proxy(share_result, request, *args, **kwargs)`` callable that
        forwards to *flag_method* when ``share_result`` is truthy and returns
        ``None`` otherwise.
    """

    def proxy(share_result, request: gr.Request, *args, **kwargs):
        # Nothing is forwarded unless sharing was requested.
        if not share_result:
            return None
        return flag_method(request, *args, **kwargs)

    return proxy
# Build the demo page: a title and description, an input panel, an output
# panel, and the event wiring that connects them.
with gr.Blocks(title="Automatic Reimbursement Tool Demo") as page: | |
gr.Markdown("<center><h1>Automatic Reimbursement Tool Demo</h1></center>") | |
gr.Markdown("<h2>Description</h2>") | |
gr.Markdown( | |
"The reimbursement filing process can be time-consuming and cumbersome, causing " | |
"frustration for faculty members and finance departments. Our project aims to " | |
"automate the information extraction involved in the process by feeding " | |
"extracted text to language models such as ChatGPT. This demo showcases the " | |
"categorization and extraction parts of the pipeline. Categorization is done " | |
"to identify the relevant details associated with the text, after which " | |
"extraction is done for those details using a language model." | |
) | |
gr.Markdown("<h2>Try it out!</h2>") | |
with gr.Box() as demo: | |
with gr.Row(): | |
# --- Input panel: previews, file upload, buttons and sharing options ---
with gr.Column(variant="panel"): | |
gr.HTML( | |
'<div><center style="color:rgb(200, 200, 200);">Input</center></div>' | |
) | |
# Both previews start hidden; display_file reveals the matching one.
pdf_preview = gr.HTML(label="Preview", show_label=True, visible=False) | |
image_preview = gr.Image( | |
label="Preview", show_label=True, visible=False, height=350 | |
) | |
input_file = gr.File( | |
label="Input receipt", | |
show_label=True, | |
type="file", | |
file_count="single", | |
file_types=["image", ".pdf"], | |
) | |
# Refresh the previews whenever the uploaded file changes.
input_file.change( | |
display_file, input_file, [pdf_preview, image_preview] | |
) | |
with gr.Row(): | |
clear = gr.Button("Clear", variant="secondary") | |
submit_button = gr.Button("Submit", variant="primary") | |
show_intermediate = gr.Checkbox( | |
False, | |
label="Show intermediate outputs", | |
info="There are several intermediate steps in the process such as " | |
"preprocessing, OCR, chatbot interaction. You can choose to " | |
"show their results here.", | |
) | |
share_result = gr.Checkbox( | |
True, | |
label="Share results", | |
info="Sharing your result with us will help us improve this tool.", | |
interactive=True, | |
) | |
contact = gr.Textbox( | |
type="email", | |
label="Contact", | |
interactive=True, | |
placeholder="Enter your email address", | |
info="Optionally, enter your email address to allow us to contact " | |
"you regarding your result.", | |
visible=True, | |
) | |
# The contact box's visibility follows the "Share results" checkbox.
share_result.change(show_share_contact, share_result, [contact]) | |
# --- Output panel: category, intermediate outputs, result and flag buttons ---
with gr.Column(variant="panel"): | |
gr.HTML( | |
'<div><center style="color:rgb(200, 200, 200);">Output</center></div>' | |
) | |
category = gr.Dropdown( | |
value=None, | |
choices=Category.__members__.keys(), | |
label=f"Recognized category ({', '.join(Category.__members__.keys())})", | |
show_label=True, | |
interactive=False, | |
) | |
# Hidden by default; toggled by the "Show intermediate outputs" checkbox.
intermediate_outputs = gr.Accordion( | |
"Intermediate outputs", open=True, visible=False | |
) | |
with intermediate_outputs: | |
extracted_text = gr.Textbox( | |
label="Extracted text", | |
show_label=True, | |
max_lines=5, | |
show_copy_button=True, | |
lines=5, | |
interactive=False, | |
) | |
chatbot = gr.Chatbot( | |
None, | |
label="Chatbot interaction", | |
show_label=True, | |
interactive=False, | |
height=240, | |
) | |
information = gr.JSON(label="Extracted information") | |
with gr.Row(): | |
flag_incorrect_button = gr.Button( | |
"Flag as incorrect", variant="stop", interactive=True | |
) | |
flag_irrelevant_button = gr.Button( | |
"Flag as irrelevant", variant="stop", interactive=True | |
) | |
show_intermediate.change( | |
show_intermediate_outputs, show_intermediate, [intermediate_outputs] | |
) | |
# "Clear" resets the file input and then disables both flag buttons.
clear.click(clear_inputs, None, [input_file]).then( | |
deactivate_flags, | |
None, | |
[flag_incorrect_button, flag_irrelevant_button], | |
) | |
# Register the components recorded by the automatic (shared-result) flagging.
hf_writer_normal.setup( | |
[input_file, extracted_text, category, chatbot, information, contact], | |
flagging_dir="flagged", | |
) | |
flag_method = gr.flagging.FlagMethod( | |
hf_writer_normal, "", "", visual_feedback=False | |
) | |
# Submit pipeline, executed as a chain of steps:
# clear outputs -> extract text -> categorize -> query the chain ->
# parse the answer -> enable the flag buttons -> flag the result if shared.
submit_button.click( | |
clear_outputs, | |
[input_file], | |
[extracted_text, category, chatbot, information], | |
).then( | |
extract_text, | |
[input_file], | |
[extracted_text], | |
).then( | |
categorize_text, | |
[extracted_text], | |
[category], | |
).then( | |
query, | |
[category, extracted_text], | |
[chatbot], | |
queue=True, | |
).then( | |
parse, | |
[category, chatbot], | |
[information], | |
).then( | |
activate_flags, | |
None, | |
[flag_incorrect_button, flag_irrelevant_button], | |
).then( | |
flag_if_shared(flag_method), | |
[ | |
share_result, | |
input_file, | |
extracted_text, | |
category, | |
chatbot, | |
information, | |
contact, | |
], | |
None, | |
preprocess=False, | |
) | |
# Register the same components with the saver used by the manual flag buttons.
hf_writer_incorrect.setup( | |
[input_file, extracted_text, category, chatbot, information, contact], | |
flagging_dir="flagged_incorrect", | |
) | |
flag_incorrect_method = gr.flagging.FlagMethod( | |
hf_writer_incorrect, | |
"Flag as incorrect", | |
"Incorrect", | |
visual_feedback=True, | |
) | |
# Two handlers on the same click: the first swaps the button to a disabled
# "Saving..." state, the second performs the flagging and updates the button.
flag_incorrect_button.click( | |
lambda: gr.Button.update(value="Saving...", interactive=False), | |
None, | |
flag_incorrect_button, | |
queue=False, | |
) | |
flag_incorrect_button.click( | |
flag_incorrect_method, | |
inputs=[ | |
input_file, | |
extracted_text, | |
category, | |
chatbot, | |
information, | |
contact, | |
], | |
outputs=[flag_incorrect_button], | |
preprocess=False, | |
queue=False, | |
) | |
# "Irrelevant" flags go through the same hf_writer_incorrect saver, with
# the "Irrelevant" value instead of "Incorrect".
flag_irrelevant_method = gr.flagging.FlagMethod( | |
hf_writer_incorrect, | |
"Flag as irrelevant", | |
"Irrelevant", | |
visual_feedback=True, | |
) | |
# Same two-handler pattern as the "incorrect" button above.
flag_irrelevant_button.click( | |
lambda: gr.Button.update(value="Saving...", interactive=False), | |
None, | |
flag_irrelevant_button, | |
queue=False, | |
) | |
flag_irrelevant_button.click( | |
flag_irrelevant_method, | |
inputs=[ | |
input_file, | |
extracted_text, | |
category, | |
chatbot, | |
information, | |
contact, | |
], | |
outputs=[flag_irrelevant_button], | |
preprocess=False, | |
queue=False, | |
) | |
# Configure the page's event queue with the demo's fixed settings.
page.queue(concurrency_count=20, max_size=1)

# Start the app with the API page, error display and debug mode enabled.
page.launch(
    show_api=True,
    show_error=True,
    debug=True,
)