# demo / app.py
# Last commit: b4890b7 "Update app.py" (ashwani21)
import base64
import os
import re
from io import BytesIO
from pathlib import Path
import gradio as gr
import pandas as pd
import json
from langchain.schema.output_parser import OutputParserException
from PIL import Image
from openpyxl import load_workbook
from openpyxl.utils import get_column_letter
import categories
from categories import Category
from main import process_image, process_pdf
from forex_python.converter import CurrencyRates
# Hugging Face token passed to the dataset-saver flagging callbacks below
# (None when the HF_TOKEN environment variable is not set).
HF_TOKEN = os.environ.get("HF_TOKEN")

# HTML snippet that embeds a PDF inline; "{0}" receives the base64 payload.
PDF_IFRAME = """
<div style="border-radius: 10px; width: 100%; overflow: hidden;">
<iframe
src="data:application/pdf;base64,{0}"
width="100%"
height="400"
type="application/pdf">
</iframe>
</div>"""


def _dataset_saver(repo_name):
    """Build a HuggingFaceDatasetSaver for ``repo_name`` using ``HF_TOKEN``."""
    return gr.HuggingFaceDatasetSaver(HF_TOKEN, repo_name, separate_dirs=False)


# One saver for results shared after each run, one for results users flag
# as incorrect or irrelevant.
hf_writer_normal = _dataset_saver("automatic-reimbursement-tool-demo")
hf_writer_incorrect = _dataset_saver("automatic-reimbursement-tool-demo-incorrect")
# with open("examples/example1.pdf", "rb") as pdf_file:
# base64_pdf = base64.b64encode(pdf_file.read())
# example_paths = []
# current_file_path = None
# def ignore_examples(function):
# def new_function(*args, **kwargs):
# global example_paths, current_file_path
# if current_file_path not in example_paths:
# return function(*args, **kwargs)
def display_file(input_files):
    """Render a preview of the uploaded receipt(s).

    Records the uploaded paths in the module-level ``current_file_paths`` and
    returns updates for the PDF preview (HTML) and image preview components.
    If any uploaded file is a PDF, the first PDF is embedded with
    ``PDF_IFRAME``. Otherwise the first file is opened as an image.

    Args:
        input_files: List of uploaded file objects (each exposing ``.name``,
            the path on disk), or ``None``/empty when the upload is cleared.

    Returns:
        A ``(pdf_preview_update, image_preview_update)`` pair.
    """
    global current_file_paths
    # Check emptiness *before* iterating: a cleared upload arrives as None,
    # and iterating None raises TypeError.
    if not input_files:
        current_file_paths = []
        return gr.HTML.update(visible=False), gr.Image.update(visible=False)
    current_file_paths = [file.name for file in input_files]

    # Preview the first PDF, if any (extension match is case-insensitive).
    pdf_path = next(
        (f.name for f in input_files if f.name.lower().endswith(".pdf")), None
    )
    if pdf_path is not None:
        with open(pdf_path, "rb") as pdf_file:
            pdf_base64 = base64.b64encode(pdf_file.read()).decode()
        return (
            gr.HTML.update(PDF_IFRAME.format(pdf_base64), visible=True),
            gr.Image.update(visible=False),
        )

    # No PDF: show the first uploaded file as an image.
    image = Image.open(input_files[0].name)
    return gr.HTML.update(visible=False), gr.Image.update(image, visible=True)
def show_intermediate_outputs(show_intermediate):
    """Show or hide the "Intermediate outputs" accordion.

    Args:
        show_intermediate: Checkbox value; truthy shows the accordion.

    Returns:
        A ``gr.Accordion.update`` with ``visible`` set to True or False.
    """
    return gr.Accordion.update(visible=bool(show_intermediate))
def show_share_contact(share_result):
    """Make the contact textbox visible exactly when ``share_result`` is set.

    Returns:
        A ``gr.Textbox.update`` whose ``visible`` is ``share_result``.
    """
    update = gr.Textbox.update(visible=share_result)
    return update
def clear_inputs():
    """Reset the file upload component.

    Returns:
        A ``gr.File.update`` that sets the component's value to ``None``.
    """
    empty_value = None
    return gr.File.update(value=empty_value)
def clear_outputs(input_file):
    """Reset the result components before a new submission.

    Args:
        input_file: Current value of the upload component. Accepted to keep
            the event-handler signature; the outputs are cleared regardless.

    Returns:
        Four ``None`` values, one per output component being reset.
    """
    # Always return one value per output. A bare ``None`` (the implicit
    # return when nothing was uploaded) does not match the four output
    # components this handler is wired to.
    return None, None, None, None
def extract_text(input_file):
    """Extract the text of a single uploaded receipt.

    Args:
        input_file: An uploaded file object whose ``.name`` is the path on
            disk. Files ending in ``.pdf`` (case-insensitive) go through
            ``process_pdf``; everything else goes through ``process_image``.

    Returns:
        Whatever ``process_pdf``/``process_image`` return with
        ``extract_only=True``.

    Raises:
        gr.Error: If no file was provided.
    """
    if not input_file:
        # gr.Error only reaches the user when it is *raised*.
        raise gr.Error("Please upload a file to continue!")
    path = Path(input_file.name)
    if input_file.name.lower().endswith(".pdf"):
        return process_pdf(path, extract_only=True)
    return process_image(path, extract_only=True)
def find_currency_symbol(text):
    """Detect the currency used in ``text``.

    Looks for known currency symbols and names (e.g. ``"₹"``, ``"HK$"``,
    ``"Euro"``) as plain substrings.

    Longer markers are tried first, so prefixed signs such as ``"A$"``,
    ``"HK$"``, ``"US$"`` or ``"CN¥"`` are not swallowed by the bare
    ``"$"``/``"¥"`` entries (or by ``"S$"`` inside ``"US$"``). Among markers
    of equal length, the table order below decides.

    Args:
        text: Text extracted from a receipt; may be ``None`` or empty.

    Returns:
        The currency code (e.g. ``"USD"``), or ``None`` if nothing matched.
    """
    if not text:
        return None
    currency_symbols = {
        'USD': ['$', 'US$', 'US Dollar', 'United States Dollar'],
        'EUR': ['€', 'Euro'],
        'GBP': ['£', 'British Pound', 'Pound Sterling'],
        'JPY': ['¥', 'Japanese Yen'],
        'AUD': ['A$', 'AU$', 'Australian Dollar'],
        'CAD': ['C$', 'CA$', 'Canadian Dollar'],
        'CHF': ['Swiss Franc'],
        'CNY': ['CN¥', 'Chinese Yuan', 'Renminbi'],
        'HKD': ['HK$', 'Hong Kong Dollar'],
        'NZD': ['NZ$', 'New Zealand Dollar'],
        'SEK': ['Swedish Krona'],
        'KRW': ['₩', 'South Korean Won'],
        'SGD': ['S$', 'Singapore Dollar'],
        'NOK': ['Norwegian Krone'],
        'MXN': ['Mexican Peso'],
        'INR': ['₹', 'Indian Rupee'],
        'RUB': ['₽', 'Russian Ruble'],
        'ZAR': ['South African Rand'],
        'BRL': ['R$', 'Brazilian Real'],
    }
    # Flatten to (marker, code) pairs in table order, then sort longest-first.
    # sorted() is stable even with reverse=True, so equal-length markers keep
    # the table's priority (e.g. "$" still beats "₹" when both appear).
    markers = sorted(
        (
            (symbol, currency)
            for currency, symbols in currency_symbols.items()
            for symbol in symbols
        ),
        key=lambda pair: len(pair[0]),
        reverse=True,
    )
    return next((currency for symbol, currency in markers if symbol in text), None)
def get_exchange_rate_to_inr(currency):
    """Return the rate that converts one unit of ``currency`` into INR.

    Args:
        currency: Currency code such as ``"USD"``, or ``None`` when no
            currency was detected (treated as INR).

    Returns:
        ``1`` for ``"INR"``/``None``. Otherwise, the value of
        ``CurrencyRates().get_rate(currency, "INR")``, or ``None`` if that
        lookup raised an exception.
    """
    if currency is None or currency == "INR":
        # No conversion needed, so skip building a CurrencyRates client.
        return 1
    try:
        return CurrencyRates().get_rate(currency, "INR")
    except Exception:
        # Best-effort lookup: log and return None so the caller decides how
        # to handle a missing rate. ``Exception`` (not a bare ``except``)
        # lets KeyboardInterrupt/SystemExit propagate.
        import logging

        logging.getLogger(__name__).warning(
            "Could not fetch the %s -> INR exchange rate", currency, exc_info=True
        )
        return None
def categorize_text(text):
    """Classify extracted receipt text.

    Delegates to ``categories.categorize_text`` and returns its result
    unchanged. Callers in this file use the result as a ``Category`` member
    (they read ``.name`` and index ``categories.category_modules`` with it).
    """
    return categories.categorize_text(text)
def query(category, text):
    """Prompt the category's LLM chain and stream the exchange to a chatbot.

    Yields two ``gr.Chatbot.update`` payloads, each holding a single
    ``[question, answer]`` pair:

    1. The formatted prompt with a ``"Generating..."`` placeholder answer.
    2. The same prompt with the model's raw answer text.

    The question is laid out as ``**System:**\\n...\\n\\n**Human:**\\n...``
    (the system part only when the prompt has more than one message). This
    is the layout ``PARSING_REGEXP`` matches.

    Args:
        category: Key into ``categories.category_modules`` (as returned by
            ``categorize_text``).
        text: Extracted receipt text.
    """
    chain = categories.category_modules[category].chain
    # Computed once and reused for the displayed prompt and the real call
    # (assumes get_format_instructions() is a pure getter).
    format_instructions = chain.output_parser.get_format_instructions()
    formatted_prompt = chain.prompt.format_prompt(
        text=text,
        format_instructions=format_instructions,
    )
    messages = formatted_prompt.messages

    question = ""
    if len(messages) > 1:
        question += f"**System:**\n{messages[0].content}"
    question += f"\n\n**Human:**\n{messages[-1].content}"
    yield gr.Chatbot.update([[question, "Generating..."]])

    result = chain.generate(
        input_list=[{"text": text, "format_instructions": format_instructions}]
    )
    answer = result.generations[0][0].text
    yield gr.Chatbot.update([[question, answer]])
PARSING_REGEXP = r"\*\*System:\*\*\n([\s\S]+)\n\n\*\*Human:\*\*\n([\s\S]+)"


def parse(category, chatbot):
    """Parse the latest chatbot answer into structured information.

    This is a generator, so every result (including early-exit statuses) is
    *yielded*. A ``return value`` inside a generator never reaches the
    consumer.

    Args:
        category: Key into ``categories.category_modules``.
        chatbot: Chatbot history, a sequence of ``[prompt, response]`` pairs
            (or ``None``). Only the last response is parsed.

    Yields:
        ``{"status": ...}`` dicts for progress or early exits. Then either
        the parsed information (``information.json()``, or ``{}`` when the
        parser returns nothing), or ``{"details": ..., "output": ...}`` when
        the output parser rejects the answer.
    """
    responses = [response[1] for response in chatbot or []]
    if not responses:
        yield {"status": "No responses available"}
        return
    answer = responses[-1]

    # Look the chain up the same way query() does. Report unknown categories
    # instead of crashing.
    try:
        chain = categories.category_modules[category].chain
    except KeyError:
        yield {"status": f"Unknown category: {category}"}
        return

    yield {"status": "Parsing response..."}
    try:
        information = chain.output_parser.parse(answer)
        information = information.json() if information else {}
    except OutputParserException as e:
        information = {
            "details": str(e),
            "output": e.llm_output,
        }
    yield information
def activate_flags():
    """Enable two buttons (presumably the two flag buttons; not wired here).

    Returns:
        A pair of ``gr.Button.update`` payloads with ``interactive=True``.
    """
    enabled = {"interactive": True}
    return gr.Button.update(**enabled), gr.Button.update(**enabled)
def deactivate_flags():
    """Disable the two flag buttons (run after the inputs are cleared).

    Returns:
        A pair of ``gr.Button.update`` payloads with ``interactive=False``.
    """
    disabled = {"interactive": False}
    return gr.Button.update(**disabled), gr.Button.update(**disabled)
def flag_if_shared(flag_method):
    """Wrap a flagging callback so it only runs when sharing is enabled.

    The returned handler takes the "Share results" checkbox value first,
    then the ``gr.Request``, then the remaining event inputs. When sharing is
    off, it does nothing and returns ``None``. Otherwise it forwards
    ``(request, *args, **kwargs)`` to ``flag_method`` and returns its result.
    """

    def proxy(share_result, request: gr.Request, *args, **kwargs):
        # Keep this exact signature: Gradio injects the request object into
        # the parameter annotated with ``gr.Request``.
        if not share_result:
            return None
        return flag_method(request, *args, **kwargs)

    return proxy
def save_df_to_excel_with_autowidth(df, filename, padding=2):
    """Write ``df`` to an Excel file and size each column to fit its contents.

    The frame is first written with ``DataFrame.to_excel`` (no index,
    openpyxl engine). The workbook is then reopened and each column's width
    is set to the length of its longest non-empty value (header included),
    measured as ``str(value)``, plus ``padding``.

    Args:
        df: The pandas DataFrame to save.
        filename: Destination ``.xlsx`` path.
        padding: Extra characters added to every column width (default 2).
    """
    df.to_excel(filename, index=False, engine='openpyxl')

    book = load_workbook(filename)
    sheet = book.active
    for column_cells in sheet.columns:
        # Measure str(value) so numbers and dates count too; skip empty cells.
        max_length = max(
            (len(str(cell.value)) for cell in column_cells if cell.value is not None),
            default=0,
        )
        letter = get_column_letter(column_cells[0].column)
        sheet.column_dimensions[letter].width = max_length + padding
    book.save(filename)
def _expense_row(input_file, item_no):
    """Extract, categorize and price one receipt; return ``(row, amount_inr)``."""
    text = extract_text(input_file)
    currency = find_currency_symbol(text)
    category = categorize_text(text)
    # query() yields two chatbot updates; the second holds the model's answer,
    # which is expected to be a JSON document.
    chats = list(query(category, text))

    exchange_rate = get_exchange_rate_to_inr(currency)
    if exchange_rate is None:
        raise gr.Error(
            f"Could not fetch the {currency} to INR exchange rate for "
            f"{Path(input_file.name).name}. Please try again later."
        )
    exchange_rate = round(float(exchange_rate), 2)

    response_dict = json.loads(chats[1]["value"][0][1])
    amount = response_dict.get("total") * exchange_rate
    # TRAVEL_CAB receipts are reported with "NA" as the bill/invoice number.
    invoice_no = "NA" if category.name == "TRAVEL_CAB" else response_dict.get("uids")
    row = {
        "S.No.": item_no,
        "Nature of Expenditure": response_dict.get("summary"),
        "Billing Date": response_dict.get("issue_date"),
        "Bill/Invoice No.": invoice_no,
        "Amount(Rs.)": amount,
    }
    return row, amount


def process_and_output_files(input_files):
    """Build the reimbursement table for all uploaded receipts.

    For each file: extract its text, detect the currency, categorize it, ask
    the category's LLM chain for structured fields (parsed as JSON), and
    convert the total to INR. The result is one row per receipt plus a final
    "Total Amount" row. It is written to an auto-sized Excel file and
    rendered as an HTML table.

    Args:
        input_files: Uploaded file objects (each with ``.name`` = path).

    Returns:
        ``(scrollable_table_html, excel_path)``.

    Raises:
        gr.Error: If no files were uploaded or an exchange rate is unavailable.
    """
    import tempfile

    if not input_files:
        raise gr.Error("Please upload a file to continue!")

    data = []
    total_amount = 0
    for item_no, input_file in enumerate(input_files, start=1):
        row, amount = _expense_row(input_file, item_no)
        data.append(row)
        total_amount += amount
    data.append({
        "S.No.": "",
        "Nature of Expenditure": "Total Amount",
        "Billing Date": "",
        "Bill/Invoice No.": "",
        "Amount(Rs.)": total_amount,
    })

    # Every cell is converted with str() before building the frame.
    df = pd.DataFrame(
        [{key: str(value) for key, value in row.items()} for row in data]
    )

    # Use a fresh directory per request: the queue runs handlers concurrently,
    # so a shared "output.xlsx" in the working directory could be overwritten
    # by another request before it is downloaded. (Not cleaned up here.)
    filename = os.path.join(tempfile.mkdtemp(), "output.xlsx")
    save_df_to_excel_with_autowidth(df, filename)

    table_html = df.to_html(classes="table table-bordered", index=True)
    scrollable_table = f'<div style="overflow-x: auto;">{table_html}</div>'
    return scrollable_table, filename
# --- UI layout and event wiring ---------------------------------------------
# Statement order matters: components must be created before they are wired,
# and for each flag button the "Saving..." handler is registered before the
# handler that writes the flag.
with gr.Blocks(title="Automatic Reimbursement Tool Demo") as page:
    gr.Markdown("<center><h1>Automatic Reimbursement Tool Demo</h1></center>")
    gr.Markdown("<h2>Description</h2>")
    gr.Markdown(
        "The reimbursement filing process can be time-consuming and cumbersome, causing "
        "frustration for faculty members and finance departments. Our project aims to "
        "automate the information extraction involved in the process by feeding "
        "extracted text to language models such as ChatGPT. This demo showcases the "
        "categorization and extraction parts of the pipeline. Categorization is done "
        "to identify the relevant details associated with the text, after which "
        "extraction is done for those details using a language model."
    )
    gr.Markdown("<h2>Try it out!</h2>")
    with gr.Box() as demo:
        with gr.Row():
            # Left column: upload, preview and submission options.
            with gr.Column(variant="panel"):
                gr.HTML(
                    '<div><center style="color:rgb(200, 200, 200);">Input</center></div>'
                )
                pdf_preview = gr.HTML(label="Preview", show_label=True, visible=False)
                image_preview = gr.Image(
                    label="Preview", show_label=True, visible=False, height=350
                )
                input_file = gr.File(
                    label="Input receipt",
                    show_label=True,
                    type="file",
                    file_count="multiple",
                    file_types=["image", ".pdf"],
                )
                # Refresh the PDF/image preview whenever the upload changes.
                input_file.change(
                    display_file, input_file, [pdf_preview, image_preview]
                )
                with gr.Row():
                    clear = gr.Button("Clear", variant="secondary")
                    submit_button = gr.Button("Submit", variant="primary")
                show_intermediate = gr.Checkbox(
                    False,
                    label="Show intermediate outputs",
                    info="There are several intermediate steps in the process such as "
                    "preprocessing, OCR, chatbot interaction. You can choose to "
                    "show their results here.",
                    visible=False, # Hidden: effectively removes this option from the UI.
                )
                share_result = gr.Checkbox(
                    True,
                    label="Share results",
                    info="Sharing your result with us will help us improve this tool.",
                    interactive=True,
                )
                contact = gr.Textbox(
                    type="email",
                    label="Contact",
                    interactive=True,
                    placeholder="Enter your email address",
                    info="Optionally, enter your email address to allow us to contact "
                    "you regarding your result.",
                    visible=True,
                )
                # Show the contact field only while "Share results" is checked.
                share_result.change(show_share_contact, share_result, [contact])
            # Right column: recognized category, intermediate outputs, results.
            with gr.Column(variant="panel"):
                gr.HTML(
                    '<div><center style="color:rgb(200, 200, 200);">Output</center></div>'
                )
                category = gr.Dropdown(
                    value=None,
                    choices=Category.__members__.keys(),
                    label=f"Recognized category ({', '.join(Category.__members__.keys())})",
                    show_label=True,
                    interactive=False,
                )
                intermediate_outputs = gr.Accordion(
                    "Intermediate outputs", open=True, visible=False
                )
                with intermediate_outputs:
                    extracted_text = gr.Textbox(
                        label="Extracted text",
                        show_label=True,
                        max_lines=5,
                        show_copy_button=True,
                        lines=5,
                        interactive=False,
                    )
                    chatbot = gr.Chatbot(
                        None,
                        label="Chatbot interaction",
                        show_label=True,
                        interactive=False,
                        height=240,
                    )
                # Replaced by the HTML table + Excel download below.
                #information = gr.JSON(label="Extracted information")
                table_display = gr.HTML(label="Table Display")
                excel_download = gr.File(label="Download Excel", type="file")
                with gr.Row():
                    flag_incorrect_button = gr.Button(
                        "Flag as incorrect", variant="stop", interactive=True
                    )
                    flag_irrelevant_button = gr.Button(
                        "Flag as irrelevant", variant="stop", interactive=True
                    )
    show_intermediate.change(
        show_intermediate_outputs, show_intermediate, [intermediate_outputs]
    )
    # Clear the upload, then disable both flag buttons.
    # NOTE(review): activate_flags is not referenced in this file, so after
    # Clear the flag buttons appear to stay disabled; confirm this is intended.
    clear.click(clear_inputs, None, [input_file]).then(
        deactivate_flags,
        None,
        [flag_incorrect_button, flag_irrelevant_button],
    )
    hf_writer_normal.setup(
        [input_file, extracted_text, category, chatbot, table_display, contact],
        flagging_dir="flagged",
    )
    # Flag callback used for the automatic "shared result" record after a run.
    flag_method = gr.flagging.FlagMethod(
        hf_writer_normal, "", "", visual_feedback=False
    )
    # Submit pipeline: clear previous outputs -> build table + Excel file ->
    # record the result via hf_writer_normal if "Share results" is checked.
    # NOTE(review): extracted_text, category and chatbot are cleared here, but
    # nothing in this chain repopulates them, so the shared record carries
    # them empty.
    submit_button.click(
        clear_outputs,
        [input_file],
        [extracted_text, category, chatbot, table_display],
    ).then(
        process_and_output_files,
        [input_file],
        [table_display, excel_download], # Adding excel_download here
    ).then(
        flag_if_shared(flag_method),
        [
            share_result,
            input_file,
            extracted_text,
            category,
            chatbot,
            table_display,
            contact,
        ],
        None,
        preprocess=False,
    )
    # Both the "incorrect" and "irrelevant" flags go to hf_writer_incorrect.
    hf_writer_incorrect.setup(
        [input_file, extracted_text, category, chatbot, table_display, contact],
        flagging_dir="flagged_incorrect",
    )
    flag_incorrect_method = gr.flagging.FlagMethod(
        hf_writer_incorrect,
        "Flag as incorrect",
        "Incorrect",
        visual_feedback=True,
    )
    # First handler: immediately show "Saving..." and disable the button.
    flag_incorrect_button.click(
        lambda: gr.Button.update(value="Saving...", interactive=False),
        None,
        flag_incorrect_button,
        queue=False,
    )
    # Second handler: write the flag; its return value updates the button.
    flag_incorrect_button.click(
        flag_incorrect_method,
        inputs=[
            input_file,
            extracted_text,
            category,
            chatbot,
            table_display,
            contact,
        ],
        outputs=[flag_incorrect_button],
        preprocess=False,
        queue=False,
    )
    flag_irrelevant_method = gr.flagging.FlagMethod(
        hf_writer_incorrect,
        "Flag as irrelevant",
        "Irrelevant",
        visual_feedback=True,
    )
    # Same two-step pattern as the "incorrect" button above.
    flag_irrelevant_button.click(
        lambda: gr.Button.update(value="Saving...", interactive=False),
        None,
        flag_irrelevant_button,
        queue=False,
    )
    flag_irrelevant_button.click(
        flag_irrelevant_method,
        inputs=[
            input_file,
            extracted_text,
            category,
            chatbot,
            table_display,
            contact,
        ],
        outputs=[flag_irrelevant_button],
        preprocess=False,
        queue=False,
    )
# Enable the request queue: up to 20 concurrent workers, and at most one
# request waiting in the queue at a time.
page.queue(max_size=1, concurrency_count=20)
page.launch(debug=True, show_error=True, show_api=True)