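# NOTE(review): "Spaces: Sleeping / Sleeping" - appears to be page residue
# rather than code; kept here as a comment so it does not break parsing.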
import gradio as gr | |
import pandas as pd | |
from extract_answer import extract_endpoint_llama | |
from generate_answers_matrix import generate_answers | |
# Example questions offered as one-click buttons in the UI. The entry at
# index 1 is also used as the initial value of the query textbox.
example_queries = [
    "What is the size, shape, and energy (watt hour) or capacity (Amp hour) of battery discussed in the paper?",
    "What specific mechanical testing methods were used to quantify strength?",
    "What parameters they used to quantify the benefit of their individual design (mass saving, increased run time, etc.)?",
    "What material chemistry combination (on the anode, cathode, separator, and electrolyte) was used in these papers?",
    "What kind of end use application they targeted?",
]

# Size of the fixed pool of category buttons that generate_category_btn()
# pads its result up to with hidden buttons.
MAX_CATEGORIES = 10
def change_button(text):
    """Return a button update that is interactive only when ``text`` has content.

    Used as a ``.change`` callback to enable or disable a button depending on
    whether the watched component currently holds a value.

    Args:
        text: Current value of the watched component (any sized object such
            as a string, list or dict), or ``None`` when it holds nothing.

    Returns:
        A ``gr.Button`` whose ``interactive`` flag is ``True`` for a
        non-empty value and ``False`` for ``None`` or an empty value.
    """
    # Components can report None (e.g. a cleared upload or an unset JSON
    # value); len(None) would raise TypeError, so treat it as empty.
    has_content = text is not None and len(text) > 0
    return gr.Button(interactive=has_content)
def generate_category_btn(cluster_output):
    """Build updates for the fixed pool of ``MAX_CATEGORIES`` category buttons.

    Args:
        cluster_output: Iterable of mappings, each with a ``"categories"``
            entry holding an iterable of category values. ``None`` is
            treated as an empty iterable.

    Returns:
        A list of exactly ``MAX_CATEGORIES`` ``gr.Button`` updates: one
        visible button per distinct category, followed by hidden buttons
        with an empty label filling the rest of the pool.
    """
    unique_categories = set()
    for item in cluster_output or []:
        unique_categories.update(item["categories"])

    # Set iteration order is not stable between runs, so sort for a
    # predictable button order. Cap at the pool size: returning more updates
    # than there are buttons would not match the fixed pool.
    shown = sorted(unique_categories, key=str)[:MAX_CATEGORIES]

    update_show = [gr.Button(visible=True, value=w) for w in shown]
    update_hide = [
        gr.Button(visible=False, value="")
        for _ in range(MAX_CATEGORIES - len(shown))
    ]
    return update_show + update_hide
def add_query(this_query, query_list):
    """Add a query to the running list and rebuild its display table.

    Args:
        this_query: Query text to add. Surrounding whitespace is stripped;
            ``None`` or blank text adds nothing.
        query_list: Queries collected so far, or ``None``/empty when there
            are none yet. The passed-in list is not modified.

    Returns:
        A ``(queries, table)`` tuple: the updated list of queries and a
        DataFrame with a single ``"Queries"`` column, one row per query.
    """
    # Work on a copy so the caller's list is never mutated in place.
    queries = list(query_list) if query_list else []

    cleaned = (this_query or "").strip()
    # Blank input and duplicates leave the list unchanged.
    if cleaned and cleaned not in queries:
        queries.append(cleaned)

    return queries, pd.DataFrame(queries, columns=["Queries"])
def reset_queries():
    """Return a cleared query state: an empty list and an empty table.

    Returns:
        A ``(queries, table)`` tuple holding an empty list and a DataFrame
        that has the ``"Queries"`` column but no rows.
    """
    empty_table = pd.DataFrame(columns=["Queries"])
    return [], empty_table
# Not read anywhere in this section; kept so existing references keep working.
btn_list = []

# Page layout and event wiring for the three-step workflow:
# upload/extract -> build queries -> extract answers.
# NOTE(review): the source this was restored from carried no indentation, so
# the nesting of the Tab/Row/Column containers below is a best reconstruction
# - confirm against the rendered page.
with gr.Blocks() as app:
    gr.Markdown(
        """
# Paper Query Matrix
This app extracts text from papers and then searches for relevant excerpts based on user queries.
### Input
1. A group of research papers that you want to run the queries on.
1. Queries that you would like to know about these papers.
### Output
Table containing the relevant excerpts from the papers for each of the queries.
# 1. Upload + Extract
First, upload the papers you want to analyze. Currently, we only support PDFs. Once they're uploaded, you can extract the text data from the papers.
"""
    )
    file_upload = gr.Files()
    extract_btn = gr.Button("Extract", interactive=False)
    with gr.Tab(label="Table"):
        extract_df = gr.Dataframe(
            datatype="markdown", column_widths=[100, 400], wrap=True
        )
    with gr.Tab(label="JSON"):
        extract_output = gr.JSON(label="Extract Output")
    gr.Markdown(
        """
----------------
# 2. Create Queries
Enter the queries that you would like to know about these papers. This will search the papers to find the most relevant excerpts.
"""
    )
    gr.Markdown(
        """
### Input
"""
    )
    query = gr.Textbox(
        label="Query", value=example_queries[1], lines=3, placeholder="Enter a query"
    )
    add_query_btn = gr.Button("Add Query", interactive=False)
    gr.Markdown(
        """
You can also select some example queries below.
"""
    )
    with gr.Row():
        q0_btn = gr.Button(example_queries[0], interactive=False)
        q1_btn = gr.Button(example_queries[1], interactive=False)
        q2_btn = gr.Button(example_queries[2], interactive=False)
        q3_btn = gr.Button(example_queries[3], interactive=False)
        q4_btn = gr.Button(example_queries[4], interactive=False)
    # Grouped so the shared wiring below is written once instead of per button.
    example_btns = (q0_btn, q1_btn, q2_btn, q3_btn, q4_btn)
    gr.Markdown(
        """
### Output
"""
    )
    with gr.Tab(label="Queries Table"):
        query_df = gr.Dataframe(
            datatype="markdown", column_widths=[100, 100, 300], wrap=True
        )
    with gr.Tab(label="JSON"):
        query_output = gr.JSON(label="Queries")
    reset_query_btn = gr.Button("Clear Queries", interactive=False)
    gr.Markdown(
        """
----------------
# 3. Extract Answers
Gather the relevant excerpts from each of the papers
"""
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
### Input
"""
            )
            generate_answers_btn = gr.Button("Extract Answers", interactive=False)
    gr.Markdown(
        """
### Answer Matrix
"""
    )
    with gr.Tab(label="Output Table"):
        answers_df = gr.Dataframe(
            datatype="markdown", column_widths=[100, 100, 300], wrap=True
        )
    with gr.Tab(label="JSON"):
        answers_output = gr.JSON(label="Answer Output")

    # Event handlers

    # The Extract button unlocks once files have been uploaded.
    file_upload.change(fn=change_button, inputs=[file_upload], outputs=[extract_btn])

    # Every control of the query step unlocks once extraction has produced
    # output. One listener per button, registered in the same order as before.
    for gated_btn in (add_query_btn, *example_btns, reset_query_btn):
        extract_output.change(
            fn=change_button, inputs=[extract_output], outputs=[gated_btn]
        )

    extract_btn.click(
        fn=extract_endpoint_llama,
        inputs=[file_upload],
        outputs=[extract_output, extract_df],
    )

    # Each example button is its own input, so its caption is the query text
    # that add_query receives.
    for example_btn in example_btns:
        example_btn.click(
            fn=add_query,
            inputs=[example_btn, query_output],
            outputs=[query_output, query_df],
        )

    add_query_btn.click(
        fn=add_query,
        inputs=[query, query_output],
        outputs=[query_output, query_df],
    )
    reset_query_btn.click(
        fn=reset_queries,
        inputs=[],
        outputs=[query_output, query_df],
    )

    # Answers can be extracted only while at least one query is listed.
    query_output.change(
        fn=change_button, inputs=[query_output], outputs=[generate_answers_btn]
    )
    generate_answers_btn.click(
        fn=generate_answers,
        inputs=[extract_output, query_output],
        outputs=[answers_output, answers_df],
    )
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    app.launch()