jojortz commited on
Commit
e874a08
·
1 Parent(s): a50a9fe

add initial llm4research app

Browse files
.gitattributes copy ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .env
README copy.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Llm4research
3
+ emoji: 🏆
4
+ colorFrom: gray
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 4.21.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+
4
+ from extract_answer import extract_endpoint_llama
5
+ from generate_answers_matrix import generate_answers
6
+
7
# Battery-research example queries surfaced as one-click buttons in the UI
# (see the q0..q4 buttons wired up in the Blocks layout below).
example_queries = [
    "What is the size, shape, and energy (watt hour) or capacity (Amp hour) of battery discussed in the paper?",
    "What specific mechanical testing methods were used to quantify strength?",
    "What parameters they used to quantify the benefit of their individual design (mass saving, increased run time, etc.)?",
    "What material chemistry combination (on the anode, cathode, separator, and electrolyte) was used in these papers?",
    "What kind of end use application they targeted?",
]
# Upper bound on the number of category buttons generate_category_btn may show.
MAX_CATEGORIES = 10
15
+
16
+
17
def change_button(text):
    """
    Return a Button update whose interactivity mirrors whether *text* holds
    any content.

    Wired as a generic .change handler: *text* may be a string (Textbox), a
    list of uploaded files (Files), or a JSON payload, so plain truthiness is
    the right test. Unlike ``len(text) > 0``, ``bool(text)`` also tolerates
    ``None`` (a JSON component's initial value) instead of raising TypeError.
    """
    return gr.Button(interactive=bool(text))
22
+
23
+
24
def generate_category_btn(cluster_output):
    """
    Build the fixed-width row of category buttons from clustering output.

    Args:
        cluster_output: Iterable of dicts, each carrying a "categories" list.

    Returns:
        Exactly MAX_CATEGORIES Button updates: one visible button per unique
        category, padded with hidden empty placeholders.
    """
    unique_categories = set()
    for item in cluster_output:
        unique_categories.update(item["categories"])

    # Sort so button order is deterministic across runs; raw set iteration
    # order varies with string hashing.
    update_show = [gr.Button(visible=True, value=w) for w in sorted(unique_categories)]
    update_hide = [
        gr.Button(visible=False, value="")
        for _ in range(MAX_CATEGORIES - len(unique_categories))
    ]
    return update_show + update_hide
35
+
36
+
37
def add_query(this_query, query_list):
    """
    Add *this_query* to the accumulated query list, skipping duplicates.

    Args:
        this_query: Query text to add (button label or textbox value).
        query_list: Current list of queries, or a falsy value on first use.

    Returns:
        (query_list, df): the updated list and a one-column "Queries"
        DataFrame view of it.
    """
    if not query_list:
        # First query (state is None or empty): start a fresh list.
        query_list = [this_query]
    else:
        if this_query not in query_list:
            query_list.append(this_query)

    return query_list, pd.DataFrame(query_list, columns=["Queries"])
45
+
46
+
47
def reset_queries():
    """Clear all queries: return an empty list plus an empty Queries table."""
    empty_table = pd.DataFrame(columns=["Queries"])
    return [], empty_table
49
+
50
+
51
# Collected category buttons; currently unused but kept for callers/future wiring.
btn_list = []


with gr.Blocks() as app:
    # ---- Step 1: upload PDFs and extract their text -------------------------
    gr.Markdown(
        """
# Paper Query Matrix
This app extracts text from papers and then searches for relevant excerpts based on user queries.

### Input
1. A group of research papers that you want to run the queries on.
1. Queries that you would like to know about these papers.

### Output
Table containing the relevant excerpts from the papers for each of the queries.

# 1. Upload + Extract
First, upload the papers you want to analyze. Currently, we only support PDFs. Once they're uploaded, you can extract the text data from the papers.
"""
    )
    file_upload = gr.Files()
    extract_btn = gr.Button("Extract", interactive=False)
    with gr.Tab(label="Table"):
        extract_df = gr.Dataframe(
            datatype="markdown", column_widths=[100, 400], wrap=True
        )
    with gr.Tab(label="JSON"):
        extract_output = gr.JSON(label="Extract Output")

    # ---- Step 2: build the list of queries ----------------------------------
    gr.Markdown(
        """
----------------
# 2. Create Queries
Enter the queries that you would like to know about these papers. This will search the papers to find the most relevant excerpts.
"""
    )

    gr.Markdown(
        """
### Input
"""
    )
    query = gr.Textbox(
        label="Query", value=example_queries[1], lines=3, placeholder="Enter a query"
    )
    add_query_btn = gr.Button("Add Query", interactive=False)
    gr.Markdown(
        """
You can also select some example queries below.
"""
    )
    with gr.Row():
        q0_btn = gr.Button(example_queries[0], interactive=False)
        q1_btn = gr.Button(example_queries[1], interactive=False)
        q2_btn = gr.Button(example_queries[2], interactive=False)
        q3_btn = gr.Button(example_queries[3], interactive=False)
        q4_btn = gr.Button(example_queries[4], interactive=False)

    gr.Markdown(
        """
### Output
"""
    )
    with gr.Tab(label="Queries Table"):
        query_df = gr.Dataframe(
            datatype="markdown", column_widths=[100, 100, 300], wrap=True
        )
    with gr.Tab(label="JSON"):
        query_output = gr.JSON(label="Queries")

    reset_query_btn = gr.Button("Clear Queries", interactive=False)

    # ---- Step 3: run the answer-extraction matrix ---------------------------
    gr.Markdown(
        """
----------------
# 3. Extract Answers
Gather the relevant excerpts from each of the papers
"""
    )
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
### Input
"""
            )
            generate_answers_btn = gr.Button("Extract Answers", interactive=False)

    gr.Markdown(
        """
### Answer Matrix
"""
    )
    with gr.Tab(label="Output Table"):
        answers_df = gr.Dataframe(
            datatype="markdown", column_widths=[100, 100, 300], wrap=True
        )
    with gr.Tab(label="JSON"):
        answers_output = gr.JSON(label="Answer Output")

    # Event handlers
    # Enable Extract only once files have been uploaded.
    file_upload.change(fn=change_button, inputs=[file_upload], outputs=[extract_btn])

    # Enable all query-building controls once extraction produced output.
    extract_output.change(
        fn=change_button, inputs=[extract_output], outputs=[add_query_btn]
    )
    extract_output.change(fn=change_button, inputs=[extract_output], outputs=[q0_btn])
    extract_output.change(fn=change_button, inputs=[extract_output], outputs=[q1_btn])
    extract_output.change(fn=change_button, inputs=[extract_output], outputs=[q2_btn])
    extract_output.change(fn=change_button, inputs=[extract_output], outputs=[q3_btn])
    extract_output.change(fn=change_button, inputs=[extract_output], outputs=[q4_btn])
    extract_output.change(
        fn=change_button, inputs=[extract_output], outputs=[reset_query_btn]
    )

    extract_btn.click(
        fn=extract_endpoint_llama,
        inputs=[file_upload],
        outputs=[extract_output, extract_df],
    )

    # Each example button passes its own label text in as the query to add.
    q0_btn.click(
        fn=add_query,
        inputs=[q0_btn, query_output],
        outputs=[query_output, query_df],
    )

    q1_btn.click(
        fn=add_query,
        inputs=[q1_btn, query_output],
        outputs=[query_output, query_df],
    )

    q2_btn.click(
        fn=add_query,
        inputs=[q2_btn, query_output],
        outputs=[query_output, query_df],
    )

    q3_btn.click(
        fn=add_query,
        inputs=[q3_btn, query_output],
        outputs=[query_output, query_df],
    )

    q4_btn.click(
        fn=add_query,
        inputs=[q4_btn, query_output],
        outputs=[query_output, query_df],
    )

    add_query_btn.click(
        fn=add_query,
        inputs=[query, query_output],
        outputs=[query_output, query_df],
    )

    reset_query_btn.click(
        fn=reset_queries,
        inputs=[],
        outputs=[query_output, query_df],
    )

    # The final step unlocks only after at least one query exists.
    query_output.change(
        fn=change_button, inputs=[query_output], outputs=[generate_answers_btn]
    )

    generate_answers_btn.click(
        fn=generate_answers,
        inputs=[extract_output, query_output],
        outputs=[answers_output, answers_df],
    )

if __name__ == "__main__":
    app.launch()
extract_answer.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import pandas as pd
4
+ from dotenv import load_dotenv
5
+ from llama_index.core import SimpleDirectoryReader
6
+ from llama_parse import LlamaParse
7
+
8
# Load environment variables (API credentials for the parse/LLM services —
# presumably LLAMA_CLOUD/OPENAI keys; confirm against deployment) from .env.
load_dotenv()
# Paragraphs shorter than this many characters are dropped as noise.
MIN_PARAGRAPH_LENGTH = 50
10
+
11
+
12
def extract_paragraphs(markdown_text, min_length=None):
    """
    Extract paragraphs from a markdown text.

    Splits on runs of blank lines, trims whitespace, and keeps only
    paragraphs that are long enough to be useful excerpts and are not
    markdown headings.

    Args:
        markdown_text: Raw markdown string (e.g. LlamaParse output).
        min_length: Minimum paragraph length to keep; defaults to
            MIN_PARAGRAPH_LENGTH when None (backward compatible).

    Returns:
        list[str]: The filtered paragraphs.
    """
    if min_length is None:
        min_length = MIN_PARAGRAPH_LENGTH
    # Paragraphs are separated by one or more blank lines.
    paragraphs = re.split(r"\n\n+", markdown_text)
    # Remove leading and trailing whitespace; drop empty fragments.
    paragraphs = [p.strip() for p in paragraphs if p.strip()]
    # Drop headings ("#...") and fragments shorter than min_length.
    paragraphs = [
        p
        for p in paragraphs
        if len(p) >= min_length and not p.startswith("#")
    ]
    print(f"created {len(paragraphs)} paragraphs\n", paragraphs)

    return paragraphs
28
+
29
+
30
def extract_endpoint_llama(file_paths):
    """
    Extract PDFs using LlamaParse.

    Args:
        file_paths: Paths of the uploaded PDF files.

    Returns:
        [extracted_data, df]: a list of {"paper": file name, "chunks":
        paragraph list} dicts plus the same data as a pandas DataFrame.
    """
    # LlamaParse emits markdown, which extract_paragraphs then chunks.
    pdf_parser = LlamaParse(result_type="markdown")
    documents = SimpleDirectoryReader(
        input_files=file_paths, file_extractor={".pdf": pdf_parser}
    ).load_data()

    extracted_data = []
    for doc in documents:
        # Debug peek at the start of each parsed document.
        print(doc.text[:500])
        extracted_data.append(
            {
                "paper": doc.metadata["file_name"],
                "chunks": extract_paragraphs(doc.text),
            }
        )

    return [extracted_data, pd.DataFrame(extracted_data)]
generate_answers_matrix.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import pandas as pd
3
+ from uniflow.flow.client import TransformClient
4
+ from uniflow.flow.config import TransformOpenAIConfig
5
+ from uniflow.op.prompt import Context
6
+
7
# When True, generate_relevant_chunks prints each paper's raw combined model output.
DEBUG = False
8
+
9
+
10
def generate_relevant_chunks(query, input_data, progress=gr.Progress()):
    """
    Generate relevant chunks of text from a list of papers based on a query.

    Args:
        query: The question to search the papers for.
        input_data: List of {"paper": file name, "chunks": [paragraph, ...]}
            dicts as produced by extraction.
        progress: Gradio progress tracker (currently unused — the tqdm call
            is commented out below).

    Returns:
        [output_answers, df]: list of {"paper", "answer"} dicts and the same
        data as a pandas DataFrame.
    """
    # Pair every paragraph of every paper with the query as a uniflow
    # Context, grouped per paper.
    data_list = []
    for paper in input_data:  # progress.tqdm(input_data, desc="Papers"):
        data = [Context(context=query, paragraph=p) for p in paper["chunks"]]
        data_list.append({"paper": paper["paper"], "data": data})

    # Prompt instructing the model to return, as JSON, the (at most 3)
    # sentences answering the query, or ["None"] if nothing is relevant.
    # NOTE(review): "An paragraph" typo is part of the live LLM prompt;
    # left untouched so model behavior is unchanged.
    instruction = """
    # Task: I am a researcher trying to understand information across several research papers. You are to determine which of the chunks most directly contains information related to the query.
    ## Input:
    1. context: A brief query or description of the information I am looking for.
    2. paragraph: An paragraph from a research paper.
    ## Evaluation Criteria: You should pick which sentence(s) contains directly relevant information to the context. The best answer is the sentences that most directly answer or contain the information specific to the context. If there are no such sentences, you should answer with ["None"].
    ## Response Format: Your response should only include two fields below:
    1. explanation: Reasoning behind your judgment, explaining why the answer is appropriate or not.
    2. answer: The best sentence(s) that meet the Evaluation Criteria as a list of strings. This should be ["None"] if no sentence answers the query. At most, include 3 sentences.
    """

    few_shot_examples = []

    num_thread_batch_size = 16

    # One deterministic (temperature 0) JSON-mode GPT-4 call per paragraph,
    # threaded/batched 16-wide.
    config = TransformOpenAIConfig()
    config.prompt_template.instruction = instruction
    config.prompt_template.few_shot_prompt = few_shot_examples
    config.model_config.model_name = "gpt-4-1106-preview"
    config.model_config.response_format = {"type": "json_object"}
    config.model_config.num_call = 1
    config.model_config.temperature = 0.0
    config.model_config.num_thread = num_thread_batch_size
    config.model_config.batch_size = num_thread_batch_size

    client = TransformClient(config)

    output = []

    # Run the client once per paper, then fold its per-paragraph responses
    # into the first response dict: answers extended into one flat list,
    # explanations collected into a parallel list.
    for paper in data_list:
        init_output = client.run(paper["data"])
        combined_output = init_output[0]
        # Box the first explanation in a list so later ones can be appended.
        combined_output["output"][0]["response"][0]["explanation"] = [
            combined_output["output"][0]["response"][0]["explanation"]
        ]
        if DEBUG:
            print(combined_output)
        for item in init_output[1:]:
            combined_output["output"][0]["response"][0]["answer"].extend(
                item["output"][0]["response"][0]["answer"]
            )
            combined_output["output"][0]["response"][0]["explanation"].append(
                item["output"][0]["response"][0]["explanation"]
            )
        output.append(combined_output)

    output_answers = []

    # Strip the "None" placeholders; restore a single ["None"] when a paper
    # yielded no relevant sentence at all. output and input_data are index-
    # aligned (one combined output per paper).
    for idx, o in enumerate(output):
        filtered_answers = [
            item for item in o["output"][0]["response"][0]["answer"] if item != "None"
        ]
        if len(filtered_answers) == 0:
            filtered_answers = ["None"]
        output_answers.append(
            {"paper": input_data[idx]["paper"], "answer": filtered_answers}
        )

    df = pd.DataFrame(output_answers)

    return [output_answers, df]
+
81
+
82
def generate_answers(papers, queries, progress=gr.Progress()):
    """
    Generate relevant chunks of text from a list of papers based on a list
    of queries.

    Args:
        papers: Extraction output — list of {"paper", "chunks"} dicts.
        queries: Query strings to run against every paper.
        progress: Gradio progress tracker reporting per-query progress.

    Returns:
        (output_data, df): flat list of {"paper", "answer", "query"} dicts
        and the query-by-paper matrix built by create_df.
    """
    print(len(papers), len(queries))
    output_data = []
    for query in progress.tqdm(queries, desc="Queries"):
        # The helper's per-query DataFrame is superseded by the combined
        # matrix built below, so discard it.
        data, _ = generate_relevant_chunks(query, papers)
        # Tag each per-paper record with the query it answers.
        for d in data:
            d["query"] = query
        output_data.extend(data)
    df = create_df(output_data)
    return output_data, df
+
98
+
99
def create_df(data):
    """
    Pivot flat (query, paper, answer) records into a query-by-paper table.

    Args:
        data: List of dicts with "query", "paper", and "answer" (list) keys.

    Returns:
        pd.DataFrame: one row per query ("Queries" first column), one column
        per paper, each cell holding the paper's first answer (or None).
    """
    # Group records by query, preserving first-appearance order of both
    # queries (rows) and papers (columns).
    per_query = {}
    for record in data:
        row = per_query.setdefault(record["query"], {})
        # Only the first answer is shown in the matrix cell.
        row[record["paper"]] = record["answer"][0] if record["answer"] else None

    df = pd.DataFrame.from_dict(per_query, orient="index")
    # Turn the query index into a proper "Queries" column.
    df = df.rename_axis("Queries").reset_index()
    # Keep "Queries" as the leading column.
    ordered = ["Queries"] + [col for col in df.columns if col != "Queries"]
    return df[ordered]
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ uniflow==0.0.25
2
+ python-dotenv==1.0.1
3
+ gradio==4.19.2
4
+ llama-index==0.10.19
5
+ llama-parse==0.3.9
6
+ rapidfuzz==3.6.2
7
+ dataclasses-json==0.6.4