Spaces:
Sleeping
Sleeping
yash bhaskar
committed on
Commit
·
d2dbe42
1
Parent(s):
cab49b2
Add application file
Browse files
app.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
|
4 |
+
# Import your modules here
|
5 |
+
from Agents.togetherAIAgent import generate_article_from_query
|
6 |
+
from Agents.wikiAgent import get_wiki_data
|
7 |
+
from Agents.rankerAgent import rankerAgent
|
8 |
+
from Query_Modification.QueryModification import query_Modifier, getKeywords
|
9 |
+
from Ranking.RRF.RRF_implementation import reciprocal_rank_fusion_three, reciprocal_rank_fusion_six
|
10 |
+
from Retrieval.tf_idf import tf_idf_pipeline
|
11 |
+
from Retrieval.bm25 import bm25_pipeline
|
12 |
+
from Retrieval.vision import vision_pipeline
|
13 |
+
from Retrieval.openSource import open_source_pipeline
|
14 |
+
from Baseline.boolean import boolean_pipeline
|
15 |
+
from AnswerGeneration.getAnswer import generate_answer_withContext, generate_answer_zeroShot
|
16 |
+
|
17 |
+
# Load miniWikiCollection
|
18 |
+
# Read the corpus inside a context manager so the file handle is closed as
# soon as parsing finishes; the encoding is pinned to UTF-8 (the JSON
# interchange encoding) instead of depending on the host locale.
with open('Datasets/mini_wiki_collection.json', 'r', encoding='utf-8') as corpus_file:
    miniWikiCollection = json.load(corpus_file)

# Lookup table: wikipedia_id -> the entry's 'text' pieces joined with single spaces.
miniWikiCollectionDict = {wiki['wikipedia_id']: " ".join(wiki['text']) for wiki in miniWikiCollection}
|
20 |
+
|
21 |
+
def process_query(query):
    """Answer ``query`` with every configured strategy and select the best one.

    Steps, in execution order:
      1. Rewrite the query with ``query_Modifier``, generate an article with
         ``generate_article_from_query`` and fetch wiki data for the keywords
         returned by ``getKeywords``.
      2. Run the five retrieval pipelines on the raw query, then on the
         rewritten query.
      3. Fuse the TF-IDF, BM25 and open-source rankings with reciprocal rank
         fusion (raw, rewritten, and both together).
      4. Take the first entry of each ranking as the context document and
         generate one answer per strategy, plus a zero-shot answer.
      5. Pass all answers to ``rankerAgent`` to choose the best one.

    Args:
        query: The user's question as entered in the UI.

    Returns:
        A 34-item tuple: ``best_model``, ``best_answer``, followed by an
        ``(answer, context)`` pair for each of the 16 strategies, in the
        order the interface lists its output components.
    """
    retrievers = (
        boolean_pipeline,
        tf_idf_pipeline,
        bm25_pipeline,
        vision_pipeline,
        open_source_pipeline,
    )

    rewritten = query_Modifier(query)
    article = generate_article_from_query(query)
    wiki_data = get_wiki_data(getKeywords(query))

    # Every retriever runs once on the raw query, then once on the rewritten one.
    bool_hits, tfidf_hits, bm25_hits, vision_hits, open_hits = [
        retrieve(query) for retrieve in retrievers
    ]
    bool_hits_m, tfidf_hits_m, bm25_hits_m, vision_hits_m, open_hits_m = [
        retrieve(rewritten) for retrieve in retrievers
    ]

    # Reciprocal rank fusion over the TF-IDF / BM25 / open-source rankings.
    fused = reciprocal_rank_fusion_three(tfidf_hits, bm25_hits, open_hits)
    fused_m = reciprocal_rank_fusion_three(tfidf_hits_m, bm25_hits_m, open_hits_m)
    fused_all = reciprocal_rank_fusion_six(
        tfidf_hits, bm25_hits, open_hits,
        tfidf_hits_m, bm25_hits_m, open_hits_m,
    )

    corpus = miniWikiCollectionDict

    # One row per context-based strategy: (ranker key, question asked, context).
    # Row order fixes both the ranker-input key order and the output order.
    # BM25 ids are converted with str() before the corpus lookup; the other
    # rankings are used as returned.
    plan = [
        ("agent1", query, wiki_data[0]),
        ("agent2", query, article),
        ("boolean", query, corpus[bool_hits[0]]),
        ("tf_idf", query, corpus[tfidf_hits[0]]),
        ("bm25", query, corpus[str(bm25_hits[0])]),
        ("vision", query, corpus[vision_hits[0]]),
        ("open_source", query, corpus[open_hits[0]]),
        ("boolean_modified", rewritten, corpus[bool_hits_m[0]]),
        ("tf_idf_modified", rewritten, corpus[tfidf_hits_m[0]]),
        ("bm25_modified", rewritten, corpus[str(bm25_hits_m[0])]),
        ("vision_modified", rewritten, corpus[vision_hits_m[0]]),
        ("open_source_modified", rewritten, corpus[open_hits_m[0]]),
        ("tf_idf_bm25_open_RRF_Ranking", query, corpus[fused[0]]),
        ("tf_idf_bm25_open_RRF_Ranking_modified", rewritten, corpus[fused_m[0]]),
        ("tf_idf_bm25_open_RRF_Ranking_combined", query, corpus[fused_all[0]]),
    ]

    # Generate the context-grounded answers, then the zero-shot baseline last.
    answers = {
        key: generate_answer_withContext(question, context)
        for key, question, context in plan
    }
    answers["zeroShot"] = generate_answer_zeroShot(query)

    # The ranker receives the query followed by every candidate answer.
    best_model, best_answer = rankerAgent({"query": query, **answers})

    # Flatten into the positional layout the interface's outputs list expects.
    outputs = [best_model, best_answer]
    for key, _, context in plan:
        outputs.extend((answers[key], context))
    outputs.extend((answers["zeroShot"], "Zero-shot doesn't have a context."))
    return tuple(outputs)
|
138 |
+
|
139 |
+
# Interface creation
|
140 |
+
def create_interface():
    """Build the Gradio Blocks UI for the demo.

    Layout, top to bottom: a query textbox, the "Best Model" and
    "Best Answer" textboxes, one answer row per strategy (answer textbox,
    "Show ... Context" button and an initially hidden context textbox), and
    a Submit button that runs ``process_query``.

    Returns:
        The assembled ``gr.Blocks`` object; the caller is responsible for
        calling ``launch()`` on it.
    """
    with gr.Blocks() as interface:
        query_input = gr.Textbox(label="Enter your query")
        best_model_output = gr.Textbox(label="Best Model", interactive=False)
        best_answer_output = gr.Textbox(label="Best Answer", interactive=False)

        def create_answer_row(label):
            """Create one answer row and return ``(answer_textbox, context_textbox)``."""
            with gr.Row():
                answer_textbox = gr.Textbox(label=f"{label} Answer", interactive=False)
                context_button = gr.Button(f"Show {label} Context")
                context_textbox = gr.Textbox(label=f"{label} Context", visible=False)

            # Reveal the context textbox on click. The handler takes no
            # arguments because no inputs are wired to it; the textbox value
            # itself is filled in by the Submit handler, so only visibility
            # needs to change here.
            context_button.click(
                fn=lambda: gr.update(visible=True),
                inputs=None,
                outputs=context_textbox
            )
            return answer_textbox, context_textbox

        agent1_output, agent1_context_output = create_answer_row("Agent 1")

        agent2_output, agent2_context_output = create_answer_row("Agent 2")
        boolean_output, boolean_context_output = create_answer_row("Boolean")
        tf_idf_output, tf_idf_context_output = create_answer_row("TF-IDF")
        bm25_output, bm25_context_output = create_answer_row("BM25")
        vision_output, vision_context_output = create_answer_row("Vision")
        open_source_output, open_source_context_output = create_answer_row("Open Source")

        boolean_mod_output, boolean_mod_context_output = create_answer_row("Boolean (Modified)")
        tf_idf_mod_output, tf_idf_mod_context_output = create_answer_row("TF-IDF (Modified)")
        bm25_mod_output, bm25_mod_context_output = create_answer_row("BM25 (Modified)")
        vision_mod_output, vision_mod_context_output = create_answer_row("Vision (Modified)")
        # Use a distinct name for the modified context box so it does not
        # shadow the unmodified "Open Source" context textbox created above.
        open_source_mod_output, open_source_mod_context_output = create_answer_row("Open Source (Modified)")

        tf_idf_rrf_output, tf_idf_rrf_context_output = create_answer_row("TF-IDF + BM25 + Open RRF")
        tf_idf_rrf_mod_output, tf_idf_rrf_mod_context_output = create_answer_row("TF-IDF + BM25 + Open RRF (Modified)")
        tf_idf_rrf_combined_output, tf_idf_rrf_combined_context_output = create_answer_row("TF-IDF + BM25 + Open RRF (Combined)")

        zero_shot_output, zero_shot_context_output = create_answer_row("Zero Shot")

        # The order of this list must match the tuple returned by process_query:
        # best model, best answer, then (answer, context) per strategy.
        gr.Button("Submit").click(
            fn=process_query,
            inputs=query_input,
            outputs=[
                best_model_output,
                best_answer_output,
                agent1_output, agent1_context_output,
                agent2_output, agent2_context_output,
                boolean_output, boolean_context_output,
                tf_idf_output, tf_idf_context_output,
                bm25_output, bm25_context_output,
                vision_output, vision_context_output,
                open_source_output, open_source_context_output,
                boolean_mod_output, boolean_mod_context_output,
                tf_idf_mod_output, tf_idf_mod_context_output,
                bm25_mod_output, bm25_mod_context_output,
                vision_mod_output, vision_mod_context_output,
                open_source_mod_output, open_source_mod_context_output,
                tf_idf_rrf_output, tf_idf_rrf_context_output,
                tf_idf_rrf_mod_output, tf_idf_rrf_mod_context_output,
                tf_idf_rrf_combined_output, tf_idf_rrf_combined_context_output,
                zero_shot_output, zero_shot_context_output
            ]
        )

    return interface
|
207 |
+
|
208 |
+
# Launch the interface
|
209 |
+
if __name__ == "__main__":
    # Build the UI and start serving it only when this file is run as a
    # script, not when it is imported.
    create_interface().launch()
|