Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -12,42 +12,90 @@ import logging
|
|
12 |
import numpy as np
|
13 |
from collections import deque
|
14 |
from huggingface_hub import InferenceClient
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
-
# Set up logging
|
17 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
18 |
logger = logging.getLogger(__name__)
|
19 |
|
20 |
-
# Configuration
|
21 |
HF_API_KEY = os.environ.get("HF_API_KEY")
|
22 |
if not HF_API_KEY:
|
23 |
raise ValueError("Please set the HF_API_KEY environment variable.")
|
24 |
|
25 |
-
# Initialize Hugging Face Inference Client
|
26 |
client = InferenceClient(provider="hf-inference", api_key=HF_API_KEY)
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
|
33 |
-
MAX_ITERATIONS =
|
34 |
-
TIMEOUT =
|
35 |
RETRY_DELAY = 5
|
36 |
-
NUM_RESULTS =
|
37 |
SIMILARITY_THRESHOLD = 0.15
|
38 |
-
MAX_CONTEXT_ITEMS =
|
39 |
-
MAX_HISTORY_ITEMS =
|
|
|
|
|
|
|
40 |
|
41 |
-
# Load multiple embedding models for different purposes
|
42 |
try:
|
43 |
-
main_similarity_model = SentenceTransformer('all-mpnet-base-v2')
|
44 |
concept_similarity_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
except Exception as e:
|
46 |
-
logger.error(f"Failed to load
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
for attempt in range(retries):
|
52 |
try:
|
53 |
messages = [{"role": "user", "content": prompt}]
|
@@ -64,22 +112,185 @@ def hf_inference(model_name, prompt, max_tokens=500, retries=5):
|
|
64 |
time.sleep(RETRY_DELAY * (1 + attempt))
|
65 |
return {"error": "Request failed after multiple retries."}
|
66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
def tool_search_web(query: str, num_results: int = NUM_RESULTS, safesearch: str = "moderate",
|
68 |
-
|
69 |
try:
|
70 |
with DDGS() as ddgs:
|
71 |
-
|
72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
if results:
|
74 |
-
return [{"title": r["title"], "snippet": r["body"], "url": r["href"]} for r in results]
|
75 |
else:
|
|
|
|
|
|
|
|
|
|
|
76 |
return []
|
77 |
except Exception as e:
|
78 |
logger.error(f"DuckDuckGo search error: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
return []
|
80 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
def tool_reason(prompt: str, search_results: list, reasoning_context: list = [],
|
82 |
-
|
83 |
if not search_results:
|
84 |
return "No search results to reason about."
|
85 |
|
@@ -89,8 +300,22 @@ def tool_reason(prompt: str, search_results: list, reasoning_context: list = [],
|
|
89 |
if focus_areas:
|
90 |
reasoning_input += f"Focus particularly on these aspects: {', '.join(focus_areas)}\n\n"
|
91 |
|
|
|
92 |
for i, result in enumerate(search_results):
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
if reasoning_context:
|
96 |
recent_context = reasoning_context[-MAX_HISTORY_ITEMS:]
|
@@ -99,9 +324,9 @@ def tool_reason(prompt: str, search_results: list, reasoning_context: list = [],
|
|
99 |
if critique:
|
100 |
reasoning_input += f"\n\nRecent critique to address: {critique}\n"
|
101 |
|
102 |
-
reasoning_input += "\nProvide a thorough, nuanced analysis that builds upon previous reasoning if applicable. Consider multiple perspectives
|
103 |
|
104 |
-
reasoning_output =
|
105 |
|
106 |
if isinstance(reasoning_output, dict) and "generated_text" in reasoning_output:
|
107 |
return reasoning_output["generated_text"].strip()
|
@@ -114,14 +339,27 @@ def tool_summarize(insights: list, prompt: str, contradictions: list = []) -> st
|
|
114 |
return "No insights to summarize."
|
115 |
|
116 |
summarization_input = f"Synthesize the following insights into a cohesive and comprehensive summary regarding: '{prompt}'\n\n"
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
if contradictions:
|
120 |
summarization_input += "\n\nAddress these specific contradictions:\n" + "\n".join(contradictions)
|
121 |
|
122 |
-
summarization_input += "\n\nProvide a well-structured summary that:\n1. Presents the main findings\n2. Acknowledges limitations and uncertainties\n3. Highlights areas of consensus and disagreement\n4. Suggests potential directions for further inquiry"
|
123 |
|
124 |
-
summarization_output =
|
125 |
|
126 |
if isinstance(summarization_output, dict) and "generated_text" in summarization_output:
|
127 |
return summarization_output["generated_text"].strip()
|
@@ -130,7 +368,7 @@ def tool_summarize(insights: list, prompt: str, contradictions: list = []) -> st
|
|
130 |
return "Could not generate a summary due to an error."
|
131 |
|
132 |
def tool_generate_search_query(prompt: str, previous_queries: list = [],
|
133 |
-
|
134 |
query_gen_input = f"Generate an effective search query for the following prompt: {prompt}\n"
|
135 |
|
136 |
if previous_queries:
|
@@ -143,7 +381,7 @@ def tool_generate_search_query(prompt: str, previous_queries: list = [],
|
|
143 |
if focus_areas:
|
144 |
query_gen_input += f"Focus particularly on these aspects: {', '.join(focus_areas)}\n"
|
145 |
|
146 |
-
query_gen_input += "Refine the search query based on previous queries, aiming for more precise results.\n"
|
147 |
query_gen_input += "Search Query:"
|
148 |
|
149 |
query_gen_output = hf_inference(MAIN_LLM_MODEL, query_gen_input)
|
@@ -155,13 +393,13 @@ def tool_generate_search_query(prompt: str, previous_queries: list = [],
|
|
155 |
return ""
|
156 |
|
157 |
def tool_critique_reasoning(reasoning_output: str, prompt: str,
|
158 |
-
|
159 |
critique_input = f"Critically evaluate the following reasoning output in relation to the prompt:\n\nPrompt: {prompt}\n\nReasoning: {reasoning_output}\n\n"
|
160 |
|
161 |
if previous_critiques:
|
162 |
critique_input += "Previous critiques that should be addressed:\n" + "\n".join(previous_critiques[-MAX_HISTORY_ITEMS:]) + "\n\n"
|
163 |
|
164 |
-
critique_input += "Identify any flaws, biases, logical fallacies, unsupported claims, or areas for improvement. Be specific and constructive. Suggest concrete ways to enhance the reasoning."
|
165 |
|
166 |
critique_output = hf_inference(CRITIC_LLM_MODEL, critique_input)
|
167 |
|
@@ -175,8 +413,20 @@ def tool_identify_contradictions(insights: list) -> list:
|
|
175 |
if len(insights) < 2:
|
176 |
return []
|
177 |
|
178 |
-
|
179 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
180 |
|
181 |
contradiction_output = hf_inference(CRITIC_LLM_MODEL, contradiction_input)
|
182 |
|
@@ -192,27 +442,46 @@ def tool_identify_contradictions(insights: list) -> list:
|
|
192 |
return []
|
193 |
|
194 |
def tool_identify_focus_areas(prompt: str, insights: list = [],
|
195 |
-
|
196 |
focus_input = f"Based on this research prompt: '{prompt}'\n\n"
|
197 |
|
198 |
if insights:
|
199 |
-
|
|
|
200 |
|
201 |
if failed_areas:
|
202 |
focus_input += f"These focus areas didn't yield useful results: {', '.join(failed_areas)}\n\n"
|
203 |
|
204 |
-
focus_input += "Identify
|
205 |
|
206 |
focus_output = hf_inference(MAIN_LLM_MODEL, focus_input)
|
207 |
|
208 |
if isinstance(focus_output, dict) and "generated_text" in focus_output:
|
209 |
result = focus_output["generated_text"].strip()
|
210 |
areas = re.findall(r'(?:^|\n)(?:\d+\.|\*|\-)\s*(.*?)(?=(?:\n(?:\d+\.|\*|\-|$))|$)', result)
|
211 |
-
return [area.strip() for area in areas if area.strip()][:
|
212 |
|
213 |
logger.error(f"Failed to identify focus areas: {focus_output}")
|
214 |
return []
|
215 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
def filter_results(search_results, prompt, previous_snippets=None):
|
217 |
if not main_similarity_model or not search_results:
|
218 |
return search_results
|
@@ -233,11 +502,14 @@ def filter_results(search_results, prompt, previous_snippets=None):
|
|
233 |
|
234 |
result_embedding = main_similarity_model.encode(combined_text, convert_to_tensor=True)
|
235 |
cosine_score = util.pytorch_cos_sim(prompt_embedding, result_embedding)[0][0].item()
|
236 |
-
|
237 |
if cosine_score >= SIMILARITY_THRESHOLD:
|
238 |
result['relevance_score'] = cosine_score
|
239 |
filtered_results.append(result)
|
240 |
seen_snippets.add(result['snippet'])
|
|
|
|
|
|
|
241 |
|
242 |
filtered_results.sort(key=lambda x: x.get('relevance_score', 0), reverse=True)
|
243 |
return filtered_results
|
@@ -247,14 +519,14 @@ def filter_results(search_results, prompt, previous_snippets=None):
|
|
247 |
return search_results
|
248 |
|
249 |
def tool_extract_key_entities(prompt: str) -> list:
|
250 |
-
entity_input = f"Extract the key entities (people, organizations, concepts, technologies, etc.) from this research prompt that should be investigated individually:\n\n{prompt}\n\nList
|
251 |
|
252 |
entity_output = hf_inference(MAIN_LLM_MODEL, entity_input)
|
253 |
|
254 |
if isinstance(entity_output, dict) and "generated_text" in entity_output:
|
255 |
result = entity_output["generated_text"].strip()
|
256 |
entities = [e.strip() for e in result.split('\n') if e.strip()]
|
257 |
-
return entities[:
|
258 |
|
259 |
logger.error(f"Failed to extract key entities: {entity_output}")
|
260 |
return []
|
@@ -269,9 +541,9 @@ def tool_meta_analyze(entity_insights: Dict[str, list], prompt: str) -> str:
|
|
269 |
if insights:
|
270 |
meta_input += f"\n--- {entity} ---\n" + insights[-1] + "\n"
|
271 |
|
272 |
-
meta_input += "\nProvide a high-level synthesis that identifies:\n1. Common themes across entities\n2. Important differences\n3. How these entities interact or influence each other\n4. The broader implications for the original research question"
|
273 |
|
274 |
-
meta_output =
|
275 |
|
276 |
if isinstance(meta_output, dict) and "generated_text" in meta_output:
|
277 |
return meta_output["generated_text"].strip()
|
@@ -279,6 +551,41 @@ def tool_meta_analyze(entity_insights: Dict[str, list], prompt: str) -> str:
|
|
279 |
logger.error(f"Failed to perform meta-analysis: {meta_output}")
|
280 |
return "Could not generate a meta-analysis due to an error."
|
281 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
282 |
tools = {
|
283 |
"search_web": {
|
284 |
"function": tool_search_web,
|
@@ -291,6 +598,45 @@ tools = {
|
|
291 |
"language": {"type": "string", "description": "Optional language code."}
|
292 |
},
|
293 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
294 |
"reason": {
|
295 |
"function": tool_reason,
|
296 |
"description": "Analyzes and reasons about information.",
|
@@ -360,6 +706,15 @@ tools = {
|
|
360 |
"entity_insights": {"type": "object", "description": "Dictionary mapping entities to their insights."},
|
361 |
"prompt": {"type": "string", "description": "The original research prompt."}
|
362 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
363 |
}
|
364 |
}
|
365 |
|
@@ -385,6 +740,16 @@ Available Tools:
|
|
385 |
Instructions:
|
386 |
Select the BEST tool and parameters for the current research stage. Output valid JSON. If no tool is appropriate, respond with {}.
|
387 |
Only use provided tools. Be strategic about which tool to use next based on the research progress so far.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
388 |
Example:
|
389 |
{"tool": "search_web", "parameters": {"query": "Eiffel Tower location"}}
|
390 |
Output:
|
@@ -392,28 +757,52 @@ Output:
|
|
392 |
return prompt
|
393 |
|
394 |
def deep_research(prompt):
|
395 |
-
task_description = "You are an advanced research assistant
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
407 |
-
|
408 |
-
|
409 |
-
|
410 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
411 |
if key_entities:
|
412 |
context.append(f"Identified key entities: {key_entities}")
|
413 |
-
intermediate_output += f"Identified key entities for focused research: {
|
414 |
|
415 |
entity_progress = {entity: {'queries': [], 'insights': []} for entity in key_entities}
|
416 |
entity_progress['general'] = {'queries': [], 'insights': []}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
417 |
|
418 |
for i in range(MAX_ITERATIONS):
|
419 |
if key_entities and i > 0:
|
@@ -423,17 +812,39 @@ def deep_research(prompt):
|
|
423 |
current_entity = 'general'
|
424 |
|
425 |
context.append(f"Current focus: {current_entity}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
426 |
|
427 |
if i == 0:
|
428 |
initial_query = tool_generate_search_query(prompt=prompt)
|
429 |
if initial_query:
|
430 |
previous_queries.append(initial_query)
|
431 |
entity_progress['general']['queries'].append(initial_query)
|
432 |
-
search_results = tool_search_web(query=initial_query)
|
433 |
-
filtered_search_results = filter_results(search_results, prompt)
|
434 |
|
435 |
-
|
436 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
437 |
|
438 |
if filtered_search_results:
|
439 |
context.append(f"Initial Search Results: {len(filtered_search_results)} items found")
|
@@ -443,6 +854,7 @@ def deep_research(prompt):
|
|
443 |
entity_progress['general']['insights'].append(reasoning_output)
|
444 |
reasoning_context.append(reasoning_output)
|
445 |
context.append(f"Initial Reasoning: {reasoning_output[:200]}...")
|
|
|
446 |
else:
|
447 |
failed_queries.append(initial_query)
|
448 |
context.append(f"Initial query yielded no relevant results: {initial_query}")
|
@@ -458,14 +870,23 @@ def deep_research(prompt):
|
|
458 |
previous_queries.append(entity_query)
|
459 |
entity_progress[current_entity]['queries'].append(entity_query)
|
460 |
|
461 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
462 |
filtered_search_results = filter_results(search_results,
|
463 |
f"{prompt} {current_entity}",
|
464 |
previous_snippets=seen_snippets)
|
465 |
|
466 |
-
for result in filtered_search_results:
|
467 |
-
seen_snippets.add(result['snippet'])
|
468 |
-
|
469 |
if filtered_search_results:
|
470 |
context.append(f"Entity Search for {current_entity}: {len(filtered_search_results)} results")
|
471 |
|
@@ -485,6 +906,7 @@ def deep_research(prompt):
|
|
485 |
entity_specific_insights[current_entity].append(entity_reasoning)
|
486 |
|
487 |
context.append(f"Reasoning about {current_entity}: {entity_reasoning[:200]}...")
|
|
|
488 |
else:
|
489 |
failed_queries.append(entity_query)
|
490 |
context.append(f"Entity query for {current_entity} yielded no relevant results")
|
@@ -538,6 +960,19 @@ def deep_research(prompt):
|
|
538 |
|
539 |
previous_queries.append(result)
|
540 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
541 |
elif tool_name == "reason":
|
542 |
if current_entity != 'general' and 'reasoning_context' not in parameters:
|
543 |
parameters['reasoning_context'] = entity_progress[current_entity]['insights']
|
@@ -565,25 +1000,9 @@ def deep_research(prompt):
|
|
565 |
entity_specific_insights[current_entity].append(result)
|
566 |
else:
|
567 |
reasoning_context.append(result)
|
568 |
-
|
569 |
all_insights.append(result)
|
570 |
|
571 |
-
elif tool_name == "search_web":
|
572 |
-
result = tool_search_web(**parameters)
|
573 |
-
filtered_result = filter_results(result,
|
574 |
-
prompt if current_entity == 'general' else f"{prompt} {current_entity}",
|
575 |
-
previous_snippets=seen_snippets)
|
576 |
-
|
577 |
-
for r in filtered_result:
|
578 |
-
seen_snippets.add(r['snippet'])
|
579 |
-
|
580 |
-
result = filtered_result
|
581 |
-
|
582 |
-
if not result:
|
583 |
-
query = parameters.get('query', '')
|
584 |
-
if query:
|
585 |
-
failed_queries.append(query)
|
586 |
-
|
587 |
elif tool_name == "critique_reasoning":
|
588 |
if 'previous_critiques' not in parameters:
|
589 |
parameters['previous_critiques'] = previous_critiques
|
@@ -616,6 +1035,16 @@ def deep_research(prompt):
|
|
616 |
failed_areas.extend([area for area in old_focus if area not in result])
|
617 |
context.append(f"New focus areas: {result}")
|
618 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
619 |
elif tool_name == "meta_analyze":
|
620 |
if 'entity_insights' not in parameters:
|
621 |
parameters['entity_insights'] = entity_specific_insights
|
@@ -625,6 +1054,10 @@ def deep_research(prompt):
|
|
625 |
if result:
|
626 |
all_insights.append(result)
|
627 |
context.append(f"Meta-analysis across entities: {result[:200]}...")
|
|
|
|
|
|
|
|
|
628 |
|
629 |
else:
|
630 |
result = tool["function"](**parameters)
|
@@ -646,11 +1079,32 @@ def deep_research(prompt):
|
|
646 |
intermediate_output += f"Iteration {i+1} - Error: {str(e)}\n"
|
647 |
continue
|
648 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
649 |
if len(entity_specific_insights) > 1 and len(all_insights) > 2:
|
650 |
meta_analysis = tool_meta_analyze(entity_insights=entity_specific_insights, prompt=prompt)
|
651 |
if meta_analysis:
|
652 |
all_insights.append(meta_analysis)
|
653 |
intermediate_output += f"Final Meta-Analysis: {meta_analysis[:500]}...\n"
|
|
|
654 |
|
655 |
if all_insights:
|
656 |
final_result = tool_summarize(all_insights, prompt, contradictions)
|
@@ -659,8 +1113,11 @@ def deep_research(prompt):
|
|
659 |
|
660 |
full_output = f"**Research Prompt:** {prompt}\n\n"
|
661 |
|
662 |
-
if
|
663 |
-
|
|
|
|
|
|
|
664 |
|
665 |
full_output += "**Research Process:**\n" + intermediate_output + "\n"
|
666 |
|
@@ -679,7 +1136,8 @@ def deep_research(prompt):
|
|
679 |
|
680 |
return full_output
|
681 |
|
682 |
-
#
|
|
|
683 |
custom_css = """
|
684 |
.gradio-container {
|
685 |
background-color: #f7f9fc;
|
@@ -687,6 +1145,7 @@ custom_css = """
|
|
687 |
.output-box {
|
688 |
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
689 |
line-height: 1.5;
|
|
|
690 |
}
|
691 |
h3 {
|
692 |
color: #2c3e50;
|
@@ -700,7 +1159,6 @@ h3 {
|
|
700 |
}
|
701 |
"""
|
702 |
|
703 |
-
# Create the Gradio interface with enhanced UI
|
704 |
iface = gr.Interface(
|
705 |
fn=deep_research,
|
706 |
inputs=[
|
@@ -708,11 +1166,9 @@ iface = gr.Interface(
|
|
708 |
],
|
709 |
outputs=gr.Textbox(lines=30, placeholder="Research results will appear here...", label="Research Results", elem_classes=["output-box"]),
|
710 |
title="Advanced Multi-Stage Research Assistant",
|
711 |
-
description="""This tool performs deep, multi-faceted research
|
712 |
-
|
713 |
-
|
714 |
-
3. Exploring different perspectives and addressing contradictions
|
715 |
-
4. Synthesizing insights across multiple information sources""",
|
716 |
examples=[
|
717 |
["What are the key factors affecting urban tree survival and how do they vary between developing and developed countries?"],
|
718 |
["Compare and contrast the economic policies of China and the United States over the past two decades, analyzing their impacts on global trade."],
|
@@ -723,18 +1179,8 @@ iface = gr.Interface(
|
|
723 |
theme="default",
|
724 |
cache_examples=False,
|
725 |
css=custom_css,
|
726 |
-
|
727 |
-
analytics_enabled=False,
|
728 |
)
|
729 |
|
730 |
-
|
731 |
-
|
732 |
-
<div class="footer">
|
733 |
-
<p>This research assistant performs advanced multi-stage analysis using natural language processing and web search.</p>
|
734 |
-
<p>Results should be verified with additional sources. Not suitable for medical, legal, or emergency use.</p>
|
735 |
-
</div>
|
736 |
-
"""
|
737 |
-
|
738 |
-
|
739 |
-
# Launch the interface
|
740 |
-
iface.launch(share=False)
|
|
|
12 |
import numpy as np
|
13 |
from collections import deque
|
14 |
from huggingface_hub import InferenceClient
|
15 |
+
import requests
|
16 |
+
import arxiv
|
17 |
+
import scholarly
|
18 |
+
import pymed
|
19 |
+
import wikipedia
|
20 |
+
from newspaper import Article
|
21 |
+
import pickle
|
22 |
+
import faiss
|
23 |
+
import threading
|
24 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
25 |
+
import tiktoken
|
26 |
|
|
|
27 |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
28 |
logger = logging.getLogger(__name__)
|
29 |
|
|
|
30 |
HF_API_KEY = os.environ.get("HF_API_KEY")
|
31 |
if not HF_API_KEY:
|
32 |
raise ValueError("Please set the HF_API_KEY environment variable.")
|
33 |
|
|
|
34 |
client = InferenceClient(provider="hf-inference", api_key=HF_API_KEY)
|
35 |
|
36 |
+
MAIN_LLM_MODEL = "meta-llama/Llama-3-70b-instruct"
|
37 |
+
REASONING_LLM_MODEL = "anthropic/claude-3-opus-20240229"
|
38 |
+
CRITIC_LLM_MODEL = "google/gemini-1.5-pro"
|
39 |
+
ENSEMBLE_MODELS = [MAIN_LLM_MODEL, REASONING_LLM_MODEL, CRITIC_LLM_MODEL]
|
40 |
|
41 |
+
MAX_ITERATIONS = 20
|
42 |
+
TIMEOUT = 120
|
43 |
RETRY_DELAY = 5
|
44 |
+
NUM_RESULTS = 15
|
45 |
SIMILARITY_THRESHOLD = 0.15
|
46 |
+
MAX_CONTEXT_ITEMS = 30
|
47 |
+
MAX_HISTORY_ITEMS = 8
|
48 |
+
MAX_FULL_TEXT_LENGTH = 10000
|
49 |
+
FAISS_INDEX_PATH = "research_index.faiss"
|
50 |
+
RESEARCH_DATA_PATH = "research_data.pkl"
|
51 |
|
|
|
52 |
try:
|
53 |
+
main_similarity_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
|
54 |
concept_similarity_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
|
55 |
+
document_similarity_model = SentenceTransformer('sentence-transformers/multi-qa-mpnet-base-dot-v1')
|
56 |
+
|
57 |
+
embedding_dim = document_similarity_model.get_sentence_embedding_dimension()
|
58 |
+
if os.path.exists(FAISS_INDEX_PATH):
|
59 |
+
index = faiss.read_index(FAISS_INDEX_PATH)
|
60 |
+
logger.info(f"Loaded FAISS index from {FAISS_INDEX_PATH}")
|
61 |
+
else:
|
62 |
+
index = faiss.IndexFlatIP(embedding_dim)
|
63 |
+
logger.info("Created a new FAISS index.")
|
64 |
except Exception as e:
|
65 |
+
logger.error(f"Failed to load models or initialize FAISS: {e}")
|
66 |
+
raise
|
67 |
+
|
68 |
+
def get_token_count(text):
|
69 |
+
try:
|
70 |
+
encoding = tiktoken.get_encoding("cl100k_base")
|
71 |
+
return len(encoding.encode(text))
|
72 |
+
except:
|
73 |
+
return len(text.split()) * 1.3
|
74 |
|
75 |
+
def save_research_data(data, index):
|
76 |
+
try:
|
77 |
+
with open(RESEARCH_DATA_PATH, "wb") as f:
|
78 |
+
pickle.dump(data, f)
|
79 |
+
faiss.write_index(index, FAISS_INDEX_PATH)
|
80 |
+
logger.info(f"Research data and index saved to {RESEARCH_DATA_PATH} and {FAISS_INDEX_PATH}")
|
81 |
+
except Exception as e:
|
82 |
+
logger.error(f"Error saving research data: {e}")
|
83 |
+
|
84 |
+
def load_research_data():
|
85 |
+
if os.path.exists(RESEARCH_DATA_PATH):
|
86 |
+
try:
|
87 |
+
with open(RESEARCH_DATA_PATH, "rb") as f:
|
88 |
+
data = pickle.load(f)
|
89 |
+
logger.info(f"Loaded research data from {RESEARCH_DATA_PATH}")
|
90 |
+
return data
|
91 |
+
except Exception as e:
|
92 |
+
logger.error(f"Error loading research data: {e}")
|
93 |
+
return {}
|
94 |
+
else:
|
95 |
+
logger.info("No existing research data found.")
|
96 |
+
return {}
|
97 |
+
|
98 |
+
def hf_inference(model_name, prompt, max_tokens=2000, retries=5):
|
99 |
for attempt in range(retries):
|
100 |
try:
|
101 |
messages = [{"role": "user", "content": prompt}]
|
|
|
112 |
time.sleep(RETRY_DELAY * (1 + attempt))
|
113 |
return {"error": "Request failed after multiple retries."}
|
114 |
|
115 |
+
def ensemble_inference(prompt, models=ENSEMBLE_MODELS, max_tokens=1500):
|
116 |
+
results = []
|
117 |
+
with ThreadPoolExecutor(max_workers=len(models)) as executor:
|
118 |
+
future_to_model = {executor.submit(hf_inference, model, prompt, max_tokens): model for model in models}
|
119 |
+
for future in as_completed(future_to_model):
|
120 |
+
model = future_to_model[future]
|
121 |
+
try:
|
122 |
+
result = future.result()
|
123 |
+
if "generated_text" in result:
|
124 |
+
results.append({"model": model, "text": result["generated_text"]})
|
125 |
+
except Exception as e:
|
126 |
+
logger.error(f"Error with model {model}: {e}")
|
127 |
+
|
128 |
+
if not results:
|
129 |
+
return {"error": "All models failed to generate responses"}
|
130 |
+
|
131 |
+
if len(results) == 1:
|
132 |
+
return {"generated_text": results[0]["text"]}
|
133 |
+
|
134 |
+
synthesis_prompt = "Synthesize these expert responses into a single coherent answer:\n\n"
|
135 |
+
for result in results:
|
136 |
+
synthesis_prompt += f"Expert {results.index(result) + 1} ({result['model'].split('/')[-1]}):\n{result['text']}\n\n"
|
137 |
+
|
138 |
+
synthesis = hf_inference(MAIN_LLM_MODEL, synthesis_prompt)
|
139 |
+
if "generated_text" in synthesis:
|
140 |
+
return synthesis
|
141 |
+
else:
|
142 |
+
return {"generated_text": max(results, key=lambda x: len(x["text"]))["text"]}
|
143 |
+
|
144 |
def tool_search_web(query: str, num_results: int = NUM_RESULTS, safesearch: str = "moderate",
|
145 |
+
time_filter: Optional[str] = None, region: str = "wt-wt", language: str = "en-us") -> list:
|
146 |
try:
|
147 |
with DDGS() as ddgs:
|
148 |
+
kwargs = {
|
149 |
+
"keywords": query,
|
150 |
+
"max_results": num_results,
|
151 |
+
"safesearch": safesearch,
|
152 |
+
"region": region,
|
153 |
+
"hreflang": language,
|
154 |
+
}
|
155 |
+
if time_filter:
|
156 |
+
if time_filter in ['d', 'w', 'm', 'y']:
|
157 |
+
kwargs["time"] = time_filter
|
158 |
+
|
159 |
+
results = [r for r in ddgs.text(**kwargs)]
|
160 |
if results:
|
161 |
+
return [{"title": r["title"], "snippet": r["body"], "url": r["href"]}] for r in results]
|
162 |
else:
|
163 |
+
if time_filter and "time" in kwargs:
|
164 |
+
del kwargs["time"]
|
165 |
+
results = [r for r in ddgs.text(**kwargs)]
|
166 |
+
if results:
|
167 |
+
return [{"title": r["title"], "snippet": r["body"], "url": r["href"]}] for r in results]
|
168 |
return []
|
169 |
except Exception as e:
|
170 |
logger.error(f"DuckDuckGo search error: {e}")
|
171 |
+
try:
|
172 |
+
with DDGS() as ddgs:
|
173 |
+
results = [r for r in ddgs.text(
|
174 |
+
keywords=query,
|
175 |
+
max_results=num_results,
|
176 |
+
safesearch=safesearch,
|
177 |
+
region=region,
|
178 |
+
hreflang=language
|
179 |
+
)]
|
180 |
+
if results:
|
181 |
+
return [{"title": r["title"], "snippet": r["body"], "url": r["href"]} for r in results]
|
182 |
+
except Exception as e2:
|
183 |
+
logger.error(f"Fallback DuckDuckGo search also failed: {e2}")
|
184 |
return []
|
185 |
|
186 |
+
def tool_search_arxiv(query: str, max_results: int = 5) -> list:
|
187 |
+
try:
|
188 |
+
client = arxiv.Client()
|
189 |
+
search = arxiv.Search(
|
190 |
+
query=query,
|
191 |
+
max_results=max_results,
|
192 |
+
sort_by=arxiv.SortCriterion.Relevance
|
193 |
+
)
|
194 |
+
results = []
|
195 |
+
for paper in client.results(search):
|
196 |
+
results.append({
|
197 |
+
"title": paper.title,
|
198 |
+
"snippet": paper.summary[:500] + "..." if len(paper.summary) > 500 else paper.summary,
|
199 |
+
"url": paper.pdf_url,
|
200 |
+
"authors": ", ".join(author.name for author in paper.authors),
|
201 |
+
"published": paper.published.strftime("%Y-%m-%d") if paper.published else "Unknown",
|
202 |
+
"source": "arXiv"
|
203 |
+
})
|
204 |
+
return results
|
205 |
+
except Exception as e:
|
206 |
+
logger.error(f"arXiv search error: {e}")
|
207 |
+
return []
|
208 |
+
|
209 |
+
def tool_search_pubmed(query: str, max_results: int = 5) -> list:
|
210 |
+
try:
|
211 |
+
pubmed = pymed.PubMed(tool="ResearchAssistant", email="[email protected]")
|
212 |
+
results = list(pubmed.query(query, max_results=max_results))
|
213 |
+
|
214 |
+
output = []
|
215 |
+
for article in results:
|
216 |
+
try:
|
217 |
+
data = article.toDict()
|
218 |
+
output.append({
|
219 |
+
"title": data.get("title", "No title"),
|
220 |
+
"snippet": data.get("abstract", "No abstract")[:500] + "..." if data.get("abstract", "") and len(data.get("abstract", "")) > 500 else data.get("abstract", "No abstract"),
|
221 |
+
"url": f"https://pubmed.ncbi.nlm.nih.gov/{data.get('pubmed_id')}/",
|
222 |
+
"authors": ", ".join(author.get("name", "") for author in data.get("authors", [])),
|
223 |
+
"published": data.get("publication_date", "Unknown"),
|
224 |
+
"source": "PubMed"
|
225 |
+
})
|
226 |
+
except:
|
227 |
+
continue
|
228 |
+
return output
|
229 |
+
except Exception as e:
|
230 |
+
logger.error(f"PubMed search error: {e}")
|
231 |
+
return []
|
232 |
+
|
233 |
+
def tool_search_wikipedia(query: str, max_results: int = 3) -> list:
|
234 |
+
try:
|
235 |
+
search_results = wikipedia.search(query, results=max_results)
|
236 |
+
results = []
|
237 |
+
|
238 |
+
for title in search_results:
|
239 |
+
try:
|
240 |
+
page = wikipedia.page(title)
|
241 |
+
summary = page.summary
|
242 |
+
snippet = summary[:500] + "..." if len(summary) > 500 else summary
|
243 |
+
results.append({
|
244 |
+
"title": page.title,
|
245 |
+
"snippet": snippet,
|
246 |
+
"url": page.url,
|
247 |
+
"source": "Wikipedia"
|
248 |
+
})
|
249 |
+
except (wikipedia.exceptions.DisambiguationError, wikipedia.exceptions.PageError):
|
250 |
+
continue
|
251 |
+
|
252 |
+
return results
|
253 |
+
except Exception as e:
|
254 |
+
logger.error(f"Wikipedia search error: {e}")
|
255 |
+
return []
|
256 |
+
|
257 |
+
def tool_search_scholar(query: str, max_results: int = 5) -> list:
|
258 |
+
try:
|
259 |
+
search_query = scholarly.search_pubs(query)
|
260 |
+
results = []
|
261 |
+
for _ in range(max_results):
|
262 |
+
try:
|
263 |
+
result = next(search_query)
|
264 |
+
results.append({
|
265 |
+
"title": result.get("bib", {}).get("title", "No title"),
|
266 |
+
"snippet": result.get("bib", {}).get("abstract", "No abstract")[:500] + "..." if result.get("bib", {}).get("abstract") else result.get("bib", {}).get("abstract", "No abstract"),
|
267 |
+
"url": result.get("pub_url", "#"),
|
268 |
+
"authors": ", ".join(result.get("bib", {}).get("author", [])),
|
269 |
+
"published": result.get("bib", {}).get("pub_year", "Unknown"),
|
270 |
+
"source": "Google Scholar"
|
271 |
+
})
|
272 |
+
except StopIteration:
|
273 |
+
break
|
274 |
+
except Exception as e:
|
275 |
+
logger.warning(f"Error processing Scholar result: {e}")
|
276 |
+
continue
|
277 |
+
return results
|
278 |
+
except Exception as e:
|
279 |
+
logger.error(f"Google Scholar search error: {e}")
|
280 |
+
return []
|
281 |
+
|
282 |
+
def extract_article_content(url: str) -> str:
|
283 |
+
try:
|
284 |
+
article = Article(url)
|
285 |
+
article.download()
|
286 |
+
article.parse()
|
287 |
+
return article.text
|
288 |
+
except Exception as e:
|
289 |
+
logger.error(f"Failed to extract article content from {url}: {e}")
|
290 |
+
return ""
|
291 |
+
|
292 |
def tool_reason(prompt: str, search_results: list, reasoning_context: list = [],
|
293 |
+
critique: str = "", focus_areas: list = []) -> str:
|
294 |
if not search_results:
|
295 |
return "No search results to reason about."
|
296 |
|
|
|
300 |
if focus_areas:
|
301 |
reasoning_input += f"Focus particularly on these aspects: {', '.join(focus_areas)}\n\n"
|
302 |
|
303 |
+
results_by_source = {}
|
304 |
for i, result in enumerate(search_results):
|
305 |
+
source = result.get('source', 'Web Search')
|
306 |
+
if source not in results_by_source:
|
307 |
+
results_by_source[source] = []
|
308 |
+
results_by_source[source].append((i, result))
|
309 |
+
|
310 |
+
for source, results in results_by_source.items():
|
311 |
+
reasoning_input += f"\n--- {source} Results ---\n"
|
312 |
+
for i, result in results:
|
313 |
+
reasoning_input += f"- Result {i + 1}: Title: {result['title']}\n Snippet: {result['snippet']}\n"
|
314 |
+
if 'authors' in result:
|
315 |
+
reasoning_input += f" Authors: {result['authors']}\n"
|
316 |
+
if 'published' in result:
|
317 |
+
reasoning_input += f" Published: {result['published']}\n"
|
318 |
+
reasoning_input += "\n"
|
319 |
|
320 |
if reasoning_context:
|
321 |
recent_context = reasoning_context[-MAX_HISTORY_ITEMS:]
|
|
|
324 |
if critique:
|
325 |
reasoning_input += f"\n\nRecent critique to address: {critique}\n"
|
326 |
|
327 |
+
reasoning_input += "\nProvide a thorough, nuanced analysis that builds upon previous reasoning if applicable. Consider multiple perspectives, potential contradictions in the search results, and the reliability of different sources."
|
328 |
|
329 |
+
reasoning_output = ensemble_inference(reasoning_input)
|
330 |
|
331 |
if isinstance(reasoning_output, dict) and "generated_text" in reasoning_output:
|
332 |
return reasoning_output["generated_text"].strip()
|
|
|
339 |
return "No insights to summarize."
|
340 |
|
341 |
summarization_input = f"Synthesize the following insights into a cohesive and comprehensive summary regarding: '{prompt}'\n\n"
|
342 |
+
|
343 |
+
max_tokens = 12000
|
344 |
+
selected_insights = []
|
345 |
+
token_count = get_token_count(summarization_input) + get_token_count("\n\n".join(contradictions))
|
346 |
+
|
347 |
+
for insight in reversed(insights):
|
348 |
+
insight_tokens = get_token_count(insight)
|
349 |
+
if token_count + insight_tokens < max_tokens:
|
350 |
+
selected_insights.insert(0, insight)
|
351 |
+
token_count += insight_tokens
|
352 |
+
else:
|
353 |
+
break
|
354 |
+
|
355 |
+
summarization_input += "\n\n".join(selected_insights)
|
356 |
|
357 |
if contradictions:
|
358 |
summarization_input += "\n\nAddress these specific contradictions:\n" + "\n".join(contradictions)
|
359 |
|
360 |
+
summarization_input += "\n\nProvide a well-structured summary that:\n1. Presents the main findings\n2. Acknowledges limitations and uncertainties\n3. Highlights areas of consensus and disagreement\n4. Suggests potential directions for further inquiry\n5. Evaluates the strength of evidence for key claims"
|
361 |
|
362 |
+
summarization_output = ensemble_inference(summarization_input)
|
363 |
|
364 |
if isinstance(summarization_output, dict) and "generated_text" in summarization_output:
|
365 |
return summarization_output["generated_text"].strip()
|
|
|
368 |
return "Could not generate a summary due to an error."
|
369 |
|
370 |
def tool_generate_search_query(prompt: str, previous_queries: list = [],
|
371 |
+
failed_queries: list = [], focus_areas: list = []) -> str:
|
372 |
query_gen_input = f"Generate an effective search query for the following prompt: {prompt}\n"
|
373 |
|
374 |
if previous_queries:
|
|
|
381 |
if focus_areas:
|
382 |
query_gen_input += f"Focus particularly on these aspects: {', '.join(focus_areas)}\n"
|
383 |
|
384 |
+
query_gen_input += "Refine the search query based on previous queries, aiming for more precise results. Consider using advanced search operators like site:, filetype:, intitle:, etc. when appropriate. Make sure the query is well-formed for academic and scientific search engines.\n"
|
385 |
query_gen_input += "Search Query:"
|
386 |
|
387 |
query_gen_output = hf_inference(MAIN_LLM_MODEL, query_gen_input)
|
|
|
393 |
return ""
|
394 |
|
395 |
def tool_critique_reasoning(reasoning_output: str, prompt: str,
|
396 |
+
previous_critiques: list = []) -> str:
|
397 |
critique_input = f"Critically evaluate the following reasoning output in relation to the prompt:\n\nPrompt: {prompt}\n\nReasoning: {reasoning_output}\n\n"
|
398 |
|
399 |
if previous_critiques:
|
400 |
critique_input += "Previous critiques that should be addressed:\n" + "\n".join(previous_critiques[-MAX_HISTORY_ITEMS:]) + "\n\n"
|
401 |
|
402 |
+
critique_input += "Identify any flaws, biases, logical fallacies, unsupported claims, or areas for improvement. Be specific and constructive. Suggest concrete ways to enhance the reasoning. Also evaluate the strength of evidence and whether conclusions are proportionate to the available information."
|
403 |
|
404 |
critique_output = hf_inference(CRITIC_LLM_MODEL, critique_input)
|
405 |
|
|
|
413 |
if len(insights) < 2:
|
414 |
return []
|
415 |
|
416 |
+
max_tokens = 12000
|
417 |
+
selected_insights = []
|
418 |
+
token_count = 0
|
419 |
+
|
420 |
+
for insight in reversed(insights):
|
421 |
+
insight_tokens = get_token_count(insight)
|
422 |
+
if token_count + insight_tokens < max_tokens:
|
423 |
+
selected_insights.insert(0, insight)
|
424 |
+
token_count += insight_tokens
|
425 |
+
else:
|
426 |
+
break
|
427 |
+
|
428 |
+
contradiction_input = "Identify specific contradictions in these insights:\n\n" + "\n\n".join(selected_insights)
|
429 |
+
contradiction_input += "\n\nList each contradiction as a separate numbered point. For each contradiction, cite the specific claims that are in tension and evaluate which claim is better supported. If no contradictions exist, respond with 'No contradictions found.'"
|
430 |
|
431 |
contradiction_output = hf_inference(CRITIC_LLM_MODEL, contradiction_input)
|
432 |
|
|
|
442 |
return []
|
443 |
|
444 |
def tool_identify_focus_areas(prompt: str, insights: list = [],
|
445 |
+
failed_areas: list = []) -> list:
|
446 |
focus_input = f"Based on this research prompt: '{prompt}'\n\n"
|
447 |
|
448 |
if insights:
|
449 |
+
recent_insights = insights[-5:] if len(insights) > 5 else insights
|
450 |
+
focus_input += "And these existing insights:\n" + "\n".join(recent_insights) + "\n\n"
|
451 |
|
452 |
if failed_areas:
|
453 |
focus_input += f"These focus areas didn't yield useful results: {', '.join(failed_areas)}\n\n"
|
454 |
|
455 |
+
focus_input += "Identify 3-5 specific aspects that should be investigated further to get a complete understanding. Be precise and prioritize underexplored areas. For each suggested area, briefly explain why it's important to investigate."
|
456 |
|
457 |
focus_output = hf_inference(MAIN_LLM_MODEL, focus_input)
|
458 |
|
459 |
if isinstance(focus_output, dict) and "generated_text" in focus_output:
|
460 |
result = focus_output["generated_text"].strip()
|
461 |
areas = re.findall(r'(?:^|\n)(?:\d+\.|\*|\-)\s*(.*?)(?=(?:\n(?:\d+\.|\*|\-|$))|$)', result)
|
462 |
+
return [area.strip() for area in areas if area.strip()][:5]
|
463 |
|
464 |
logger.error(f"Failed to identify focus areas: {focus_output}")
|
465 |
return []
|
466 |
|
467 |
+
def add_to_faiss_index(text: str):
|
468 |
+
"""Adds the embedding of the given text to the FAISS index."""
|
469 |
+
embedding = document_similarity_model.encode(text, convert_to_tensor=True)
|
470 |
+
embedding_np = embedding.cpu().numpy().reshape(1, -1) # Ensure 2D array
|
471 |
+
if embedding_np.shape[1] != embedding_dim:
|
472 |
+
logger.error(f"Embedding dimension mismatch: expected {embedding_dim}, got {embedding_np.shape[1]}")
|
473 |
+
return
|
474 |
+
faiss.normalize_L2(embedding_np)
|
475 |
+
index.add(embedding_np)
|
476 |
+
|
477 |
+
def search_faiss_index(query: str, top_k: int = 5) -> List[str]:
|
478 |
+
"""Searches the FAISS index for the most similar texts to the query."""
|
479 |
+
query_embedding = document_similarity_model.encode(query, convert_to_tensor=True)
|
480 |
+
query_embedding_np = query_embedding.cpu().numpy().reshape(1, -1)
|
481 |
+
faiss.normalize_L2(query_embedding_np)
|
482 |
+
distances, indices = index.search(query_embedding_np, top_k)
|
483 |
+
return indices[0].tolist() # Return indices
|
484 |
+
|
485 |
def filter_results(search_results, prompt, previous_snippets=None):
|
486 |
if not main_similarity_model or not search_results:
|
487 |
return search_results
|
|
|
502 |
|
503 |
result_embedding = main_similarity_model.encode(combined_text, convert_to_tensor=True)
|
504 |
cosine_score = util.pytorch_cos_sim(prompt_embedding, result_embedding)[0][0].item()
|
505 |
+
|
506 |
if cosine_score >= SIMILARITY_THRESHOLD:
|
507 |
result['relevance_score'] = cosine_score
|
508 |
filtered_results.append(result)
|
509 |
seen_snippets.add(result['snippet'])
|
510 |
+
# Add snippet to FAISS index
|
511 |
+
add_to_faiss_index(result['snippet'])
|
512 |
+
|
513 |
|
514 |
filtered_results.sort(key=lambda x: x.get('relevance_score', 0), reverse=True)
|
515 |
return filtered_results
|
|
|
519 |
return search_results
|
520 |
|
521 |
def tool_extract_key_entities(prompt: str) -> list:
|
522 |
+
entity_input = f"Extract the key entities (people, organizations, concepts, technologies, events, time periods, locations, etc.) from this research prompt that should be investigated individually:\n\n{prompt}\n\nList the 5-7 most important entities, one per line, with a brief explanation (2-3 sentences) of why each is central to the research question."
|
523 |
|
524 |
entity_output = hf_inference(MAIN_LLM_MODEL, entity_input)
|
525 |
|
526 |
if isinstance(entity_output, dict) and "generated_text" in entity_output:
|
527 |
result = entity_output["generated_text"].strip()
|
528 |
entities = [e.strip() for e in result.split('\n') if e.strip()]
|
529 |
+
return entities[:7]
|
530 |
|
531 |
logger.error(f"Failed to extract key entities: {entity_output}")
|
532 |
return []
|
|
|
541 |
if insights:
|
542 |
meta_input += f"\n--- {entity} ---\n" + insights[-1] + "\n"
|
543 |
|
544 |
+
meta_input += "\nProvide a high-level synthesis that identifies:\n1. Common themes across entities\n2. Important differences and contradictions\n3. How these entities interact or influence each other\n4. The broader implications for the original research question\n5. A systems-level understanding of how these elements fit together"
|
545 |
|
546 |
+
meta_output = ensemble_inference(meta_input)
|
547 |
|
548 |
if isinstance(meta_output, dict) and "generated_text" in meta_output:
|
549 |
return meta_output["generated_text"].strip()
|
|
|
551 |
logger.error(f"Failed to perform meta-analysis: {meta_output}")
|
552 |
return "Could not generate a meta-analysis due to an error."
|
553 |
|
554 |
+
def tool_draft_research_plan(prompt: str, entities: list, focus_areas: list = []) -> str:
|
555 |
+
plan_input = f"Create a detailed research plan for investigating this question: '{prompt}'\n\n"
|
556 |
+
|
557 |
+
if entities:
|
558 |
+
plan_input += "Key entities to investigate:\n" + "\n".join(entities) + "\n\n"
|
559 |
+
|
560 |
+
if focus_areas:
|
561 |
+
plan_input += "Additional focus areas:\n" + "\n".join(focus_areas) + "\n\n"
|
562 |
+
|
563 |
+
plan_input += "The research plan should include:\n"
|
564 |
+
plan_input += "1. Main research questions and sub-questions\n"
|
565 |
+
plan_input += "2. Methodology for investigating each aspect\n"
|
566 |
+
plan_input += "3. Potential sources and databases to consult\n"
|
567 |
+
plan_input += "4. Suggested sequence of investigation\n"
|
568 |
+
plan_input += "5. Potential challenges and how to address them\n"
|
569 |
+
plan_input += "6. Criteria for evaluating the quality of findings"
|
570 |
+
|
571 |
+
plan_output = hf_inference(REASONING_LLM_MODEL, plan_input)
|
572 |
+
|
573 |
+
if isinstance(plan_output, dict) and "generated_text" in plan_output:
|
574 |
+
return plan_output["generated_text"].strip()
|
575 |
+
|
576 |
+
logger.error(f"Failed to generate research plan: {plan_output}")
|
577 |
+
return "Could not generate a research plan due to an error."
|
578 |
+
|
579 |
+
def tool_extract_article(url: str) -> str:
|
580 |
+
content = extract_article_content(url)
|
581 |
+
if not content:
|
582 |
+
return f"Could not extract content from {url}"
|
583 |
+
|
584 |
+
if len(content) > MAX_FULL_TEXT_LENGTH:
|
585 |
+
content = content[:MAX_FULL_TEXT_LENGTH] + "... [content truncated]"
|
586 |
+
|
587 |
+
return content
|
588 |
+
|
589 |
tools = {
|
590 |
"search_web": {
|
591 |
"function": tool_search_web,
|
|
|
598 |
"language": {"type": "string", "description": "Optional language code."}
|
599 |
},
|
600 |
},
|
601 |
+
"search_arxiv": {
|
602 |
+
"function": tool_search_arxiv,
|
603 |
+
"description": "Searches arXiv for scientific papers.",
|
604 |
+
"parameters": {
|
605 |
+
"query": {"type": "string", "description": "The search query for scientific papers."},
|
606 |
+
"max_results": {"type": "integer", "description": "Maximum number of papers to return."}
|
607 |
+
},
|
608 |
+
},
|
609 |
+
"search_pubmed": {
|
610 |
+
"function": tool_search_pubmed,
|
611 |
+
"description": "Searches PubMed for medical and scientific literature.",
|
612 |
+
"parameters": {
|
613 |
+
"query": {"type": "string", "description": "The search query for medical literature."},
|
614 |
+
"max_results": {"type": "integer", "description": "Maximum number of articles to return."}
|
615 |
+
},
|
616 |
+
},
|
617 |
+
"search_wikipedia": {
|
618 |
+
"function": tool_search_wikipedia,
|
619 |
+
"description": "Searches Wikipedia for information.",
|
620 |
+
"parameters": {
|
621 |
+
"query": {"type": "string", "description": "The search query for Wikipedia."},
|
622 |
+
"max_results": {"type": "integer", "description": "Maximum number of articles to return."}
|
623 |
+
},
|
624 |
+
},
|
625 |
+
"search_scholar": {
|
626 |
+
"function": tool_search_scholar,
|
627 |
+
"description": "Searches Google Scholar for academic publications.",
|
628 |
+
"parameters": {
|
629 |
+
"query": {"type": "string", "description": "The search query for Google Scholar."},
|
630 |
+
"max_results": {"type": "integer", "description": "Maximum number of articles to return."}
|
631 |
+
}
|
632 |
+
},
|
633 |
+
"extract_article": {
|
634 |
+
"function": tool_extract_article,
|
635 |
+
"description": "Extracts the main content from a web article URL",
|
636 |
+
"parameters": {
|
637 |
+
"url": {"type": "string", "description": "The URL of the article to extract"}
|
638 |
+
},
|
639 |
+
},
|
640 |
"reason": {
|
641 |
"function": tool_reason,
|
642 |
"description": "Analyzes and reasons about information.",
|
|
|
706 |
"entity_insights": {"type": "object", "description": "Dictionary mapping entities to their insights."},
|
707 |
"prompt": {"type": "string", "description": "The original research prompt."}
|
708 |
},
|
709 |
+
},
|
710 |
+
"draft_research_plan": {
|
711 |
+
"function": tool_draft_research_plan,
|
712 |
+
"description": "Creates a detailed research plan.",
|
713 |
+
"parameters": {
|
714 |
+
"prompt": {"type": "string", "description": "The research question/prompt."},
|
715 |
+
"entities": {"type": "array", "description": "Key entities to investigate."},
|
716 |
+
"focus_areas": {"type": "array", "description": "Additional areas to focus on."}
|
717 |
+
}
|
718 |
}
|
719 |
}
|
720 |
|
|
|
740 |
Instructions:
|
741 |
Select the BEST tool and parameters for the current research stage. Output valid JSON. If no tool is appropriate, respond with {}.
|
742 |
Only use provided tools. Be strategic about which tool to use next based on the research progress so far.
|
743 |
+
|
744 |
+
You MUST be methodical. Think step-by-step:
|
745 |
+
|
746 |
+
1. **Plan:** If it's the very beginning, extract key entities, identify focus areas, and then draft a research plan.
|
747 |
+
2. **Search:** Use a variety of search tools. Start with broad searches, then narrow down. Use specific search tools (arXiv, PubMed, Scholar) for relevant topics.
|
748 |
+
3. **Analyze:** Reason deeply about search results, and critique your reasoning. Identify contradictions. Filter and use FAISS index for relevant information.
|
749 |
+
4. **Refine:** If results are poor, generate *better* search queries. Adjust focus areas.
|
750 |
+
5. **Iterate:** Repeat steps 2-4, focusing on different entities and aspects.
|
751 |
+
6. **Synthesize:** Finally, summarize the findings, addressing contradictions.
|
752 |
+
|
753 |
Example:
|
754 |
{"tool": "search_web", "parameters": {"query": "Eiffel Tower location"}}
|
755 |
Output:
|
|
|
757 |
return prompt
|
758 |
|
759 |
def deep_research(prompt):
|
760 |
+
task_description = "You are an advanced research assistant. Use available tools iteratively, focus on different aspects, follow promising leads, critically evaluate your findings, and build up a comprehensive understanding. Utilize the FAISS index to avoid redundant searches and build a persistent knowledge base."
|
761 |
+
research_data = load_research_data()
|
762 |
+
|
763 |
+
context = research_data.get('context', [])
|
764 |
+
all_insights = research_data.get('all_insights', [])
|
765 |
+
entity_specific_insights = research_data.get('entity_specific_insights', {})
|
766 |
+
intermediate_output = "" # For Gradio display
|
767 |
+
previous_queries = research_data.get('previous_queries', [])
|
768 |
+
failed_queries = research_data.get('failed_queries', [])
|
769 |
+
reasoning_context = research_data.get('reasoning_context', [])
|
770 |
+
previous_critiques = research_data.get('previous_critiques', [])
|
771 |
+
focus_areas = research_data.get('focus_areas', [])
|
772 |
+
failed_areas = research_data.get('failed_areas', [])
|
773 |
+
seen_snippets = set(research_data.get('seen_snippets', []))
|
774 |
+
contradictions = research_data.get('contradictions', [])
|
775 |
+
research_session_id = research_data.get('research_session_id', str(uuid4()))
|
776 |
+
|
777 |
+
# Restore or initialize FAISS index
|
778 |
+
global index
|
779 |
+
if research_data:
|
780 |
+
logger.info("Restoring FAISS Index from loaded data.")
|
781 |
+
else:
|
782 |
+
index.reset() #Start Fresh
|
783 |
+
logger.info("Initialized a fresh FAISS Index")
|
784 |
+
|
785 |
+
key_entities_with_descriptions = tool_extract_key_entities(prompt=prompt)
|
786 |
+
key_entities = [e.split(":")[0].strip() for e in key_entities_with_descriptions]
|
787 |
if key_entities:
|
788 |
context.append(f"Identified key entities: {key_entities}")
|
789 |
+
intermediate_output += f"Identified key entities for focused research: {key_entities_with_descriptions}\n"
|
790 |
|
791 |
entity_progress = {entity: {'queries': [], 'insights': []} for entity in key_entities}
|
792 |
entity_progress['general'] = {'queries': [], 'insights': []}
|
793 |
+
for entity in key_entities + ['general']:
|
794 |
+
if entity in research_data:
|
795 |
+
entity_progress[entity]['queries'] = research_data[entity]['queries']
|
796 |
+
entity_progress[entity]['insights'] = research_data[entity]['insights']
|
797 |
+
|
798 |
+
if i == 0:
|
799 |
+
initial_focus_areas = tool_identify_focus_areas(prompt=prompt)
|
800 |
+
research_plan = tool_draft_research_plan(prompt=prompt, entities=key_entities, focus_areas=initial_focus_areas)
|
801 |
+
context.append(f"Initial Research Plan: {research_plan[:200]}...")
|
802 |
+
intermediate_output += f"Initial Research Plan:\n{research_plan}\n\n"
|
803 |
+
focus_areas = initial_focus_areas
|
804 |
+
elif not focus_areas:
|
805 |
+
focus_areas = tool_identify_focus_areas(prompt=prompt, insights=all_insights, failed_areas=failed_areas)
|
806 |
|
807 |
for i in range(MAX_ITERATIONS):
|
808 |
if key_entities and i > 0:
|
|
|
812 |
current_entity = 'general'
|
813 |
|
814 |
context.append(f"Current focus: {current_entity}")
|
815 |
+
|
816 |
+
# FAISS similarity search before web/arxiv/pubmed searches
|
817 |
+
if i > 0: # Don't do it on first iteration
|
818 |
+
faiss_results_indices = search_faiss_index(prompt if current_entity == 'general' else f"{prompt} {current_entity}")
|
819 |
+
faiss_context = []
|
820 |
+
for idx in faiss_results_indices:
|
821 |
+
if idx < len(all_insights):
|
822 |
+
faiss_context.append(f"Previously found insight: {all_insights[idx]}")
|
823 |
+
if faiss_context:
|
824 |
+
context.extend(faiss_context)
|
825 |
+
intermediate_output += f"Iteration {i+1} - Retrieved {len(faiss_context)} relevant items from FAISS index.\n"
|
826 |
+
|
827 |
|
828 |
if i == 0:
|
829 |
initial_query = tool_generate_search_query(prompt=prompt)
|
830 |
if initial_query:
|
831 |
previous_queries.append(initial_query)
|
832 |
entity_progress['general']['queries'].append(initial_query)
|
|
|
|
|
833 |
|
834 |
+
with ThreadPoolExecutor(max_workers=5) as executor:
|
835 |
+
futures = [
|
836 |
+
executor.submit(tool_search_web, query=initial_query, num_results=NUM_RESULTS),
|
837 |
+
executor.submit(tool_search_arxiv, query=initial_query, max_results=5),
|
838 |
+
executor.submit(tool_search_pubmed, query=initial_query, max_results=5),
|
839 |
+
executor.submit(tool_search_wikipedia, query=initial_query, max_results=3),
|
840 |
+
executor.submit(tool_search_scholar, query=initial_query, max_results=5)
|
841 |
+
]
|
842 |
+
|
843 |
+
search_results = []
|
844 |
+
for future in as_completed(futures):
|
845 |
+
search_results.extend(future.result())
|
846 |
+
|
847 |
+
filtered_search_results = filter_results(search_results, prompt)
|
848 |
|
849 |
if filtered_search_results:
|
850 |
context.append(f"Initial Search Results: {len(filtered_search_results)} items found")
|
|
|
854 |
entity_progress['general']['insights'].append(reasoning_output)
|
855 |
reasoning_context.append(reasoning_output)
|
856 |
context.append(f"Initial Reasoning: {reasoning_output[:200]}...")
|
857 |
+
add_to_faiss_index(reasoning_output) # Add reasoning to FAISS
|
858 |
else:
|
859 |
failed_queries.append(initial_query)
|
860 |
context.append(f"Initial query yielded no relevant results: {initial_query}")
|
|
|
870 |
previous_queries.append(entity_query)
|
871 |
entity_progress[current_entity]['queries'].append(entity_query)
|
872 |
|
873 |
+
with ThreadPoolExecutor(max_workers=5) as executor:
|
874 |
+
futures = [
|
875 |
+
executor.submit(tool_search_web, query=entity_query, num_results=NUM_RESULTS//2),
|
876 |
+
executor.submit(tool_search_arxiv, query=entity_query, max_results=3),
|
877 |
+
executor.submit(tool_search_pubmed, query=entity_query, max_results=3),
|
878 |
+
executor.submit(tool_search_wikipedia, query=entity_query, max_results=2),
|
879 |
+
executor.submit(tool_search_scholar, query=entity_query, max_results=3)
|
880 |
+
]
|
881 |
+
|
882 |
+
search_results = []
|
883 |
+
for future in as_completed(futures):
|
884 |
+
search_results.extend(future.result())
|
885 |
+
|
886 |
filtered_search_results = filter_results(search_results,
|
887 |
f"{prompt} {current_entity}",
|
888 |
previous_snippets=seen_snippets)
|
889 |
|
|
|
|
|
|
|
890 |
if filtered_search_results:
|
891 |
context.append(f"Entity Search for {current_entity}: {len(filtered_search_results)} results")
|
892 |
|
|
|
906 |
entity_specific_insights[current_entity].append(entity_reasoning)
|
907 |
|
908 |
context.append(f"Reasoning about {current_entity}: {entity_reasoning[:200]}...")
|
909 |
+
add_to_faiss_index(entity_reasoning) # Add to FAISS
|
910 |
else:
|
911 |
failed_queries.append(entity_query)
|
912 |
context.append(f"Entity query for {current_entity} yielded no relevant results")
|
|
|
960 |
|
961 |
previous_queries.append(result)
|
962 |
|
963 |
+
elif tool_name in ["search_web", "search_arxiv", "search_pubmed", "search_wikipedia", "search_scholar"]:
|
964 |
+
result = tool["function"](**parameters)
|
965 |
+
search_prompt = prompt
|
966 |
+
if current_entity != 'general':
|
967 |
+
search_prompt = f"{prompt} focusing on {current_entity}"
|
968 |
+
|
969 |
+
filtered_result = filter_results(result, search_prompt, previous_snippets=seen_snippets)
|
970 |
+
|
971 |
+
result = filtered_result
|
972 |
+
|
973 |
+
if not result and 'query' in parameters:
|
974 |
+
failed_queries.append(parameters['query'])
|
975 |
+
|
976 |
elif tool_name == "reason":
|
977 |
if current_entity != 'general' and 'reasoning_context' not in parameters:
|
978 |
parameters['reasoning_context'] = entity_progress[current_entity]['insights']
|
|
|
1000 |
entity_specific_insights[current_entity].append(result)
|
1001 |
else:
|
1002 |
reasoning_context.append(result)
|
1003 |
+
add_to_faiss_index(result) # Add reasoning to FAISS
|
1004 |
all_insights.append(result)
|
1005 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1006 |
elif tool_name == "critique_reasoning":
|
1007 |
if 'previous_critiques' not in parameters:
|
1008 |
parameters['previous_critiques'] = previous_critiques
|
|
|
1035 |
failed_areas.extend([area for area in old_focus if area not in result])
|
1036 |
context.append(f"New focus areas: {result}")
|
1037 |
|
1038 |
+
elif tool_name == "extract_article":
|
1039 |
+
result = tool["function"](**parameters)
|
1040 |
+
if result:
|
1041 |
+
context.append(f"Extracted article content from {parameters['url']}: {result[:200]}...")
|
1042 |
+
reasoning_about_article = tool_reason(prompt=prompt, search_results=[{"title": "Extracted Article", "snippet": result, "url": parameters['url']}])
|
1043 |
+
if reasoning_about_article:
|
1044 |
+
all_insights.append(reasoning_about_article)
|
1045 |
+
add_to_faiss_index(reasoning_about_article) # Add to FAISS
|
1046 |
+
|
1047 |
+
|
1048 |
elif tool_name == "meta_analyze":
|
1049 |
if 'entity_insights' not in parameters:
|
1050 |
parameters['entity_insights'] = entity_specific_insights
|
|
|
1054 |
if result:
|
1055 |
all_insights.append(result)
|
1056 |
context.append(f"Meta-analysis across entities: {result[:200]}...")
|
1057 |
+
add_to_faiss_index(result) # Add to FAISS
|
1058 |
+
|
1059 |
+
elif tool_name == "draft_research_plan":
|
1060 |
+
result = "Research plan already generated."
|
1061 |
|
1062 |
else:
|
1063 |
result = tool["function"](**parameters)
|
|
|
1079 |
intermediate_output += f"Iteration {i+1} - Error: {str(e)}\n"
|
1080 |
continue
|
1081 |
|
1082 |
+
# Save research data after each iteration
|
1083 |
+
research_data = {
|
1084 |
+
'context': context,
|
1085 |
+
'all_insights': all_insights,
|
1086 |
+
'entity_specific_insights': entity_specific_insights,
|
1087 |
+
'previous_queries': previous_queries,
|
1088 |
+
'failed_queries': failed_queries,
|
1089 |
+
'reasoning_context': reasoning_context,
|
1090 |
+
'previous_critiques': previous_critiques,
|
1091 |
+
'focus_areas': focus_areas,
|
1092 |
+
'failed_areas': failed_areas,
|
1093 |
+
'seen_snippets': list(seen_snippets), # Convert set to list for pickling
|
1094 |
+
'contradictions': contradictions,
|
1095 |
+
'research_session_id': research_session_id
|
1096 |
+
}
|
1097 |
+
for entity in entity_progress:
|
1098 |
+
research_data[entity] = entity_progress[entity]
|
1099 |
+
save_research_data(research_data, index)
|
1100 |
+
|
1101 |
+
|
1102 |
if len(entity_specific_insights) > 1 and len(all_insights) > 2:
|
1103 |
meta_analysis = tool_meta_analyze(entity_insights=entity_specific_insights, prompt=prompt)
|
1104 |
if meta_analysis:
|
1105 |
all_insights.append(meta_analysis)
|
1106 |
intermediate_output += f"Final Meta-Analysis: {meta_analysis[:500]}...\n"
|
1107 |
+
add_to_faiss_index(meta_analysis)
|
1108 |
|
1109 |
if all_insights:
|
1110 |
final_result = tool_summarize(all_insights, prompt, contradictions)
|
|
|
1113 |
|
1114 |
full_output = f"**Research Prompt:** {prompt}\n\n"
|
1115 |
|
1116 |
+
if key_entities_with_descriptions:
|
1117 |
+
full_output += f"**Key Entities Identified:**\n"
|
1118 |
+
for entity in key_entities_with_descriptions:
|
1119 |
+
full_output += f"- {entity}\n"
|
1120 |
+
full_output += "\n"
|
1121 |
|
1122 |
full_output += "**Research Process:**\n" + intermediate_output + "\n"
|
1123 |
|
|
|
1136 |
|
1137 |
return full_output
|
1138 |
|
1139 |
+
# Gradio Interface
|
1140 |
+
|
1141 |
custom_css = """
|
1142 |
.gradio-container {
|
1143 |
background-color: #f7f9fc;
|
|
|
1145 |
.output-box {
|
1146 |
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
1147 |
line-height: 1.5;
|
1148 |
+
font-size: 14px; /* Increased font size */
|
1149 |
}
|
1150 |
h3 {
|
1151 |
color: #2c3e50;
|
|
|
1159 |
}
|
1160 |
"""
|
1161 |
|
|
|
1162 |
iface = gr.Interface(
|
1163 |
fn=deep_research,
|
1164 |
inputs=[
|
|
|
1166 |
],
|
1167 |
outputs=gr.Textbox(lines=30, placeholder="Research results will appear here...", label="Research Results", elem_classes=["output-box"]),
|
1168 |
title="Advanced Multi-Stage Research Assistant",
|
1169 |
+
description="""This tool performs deep, multi-faceted research, leveraging multiple search engines,
|
1170 |
+
specialized academic databases, and advanced AI models. It incorporates a persistent knowledge
|
1171 |
+
base using FAISS indexing to avoid redundant searches and build upon previous findings.""",
|
|
|
|
|
1172 |
examples=[
|
1173 |
["What are the key factors affecting urban tree survival and how do they vary between developing and developed countries?"],
|
1174 |
["Compare and contrast the economic policies of China and the United States over the past two decades, analyzing their impacts on global trade."],
|
|
|
1179 |
theme="default",
|
1180 |
cache_examples=False,
|
1181 |
css=custom_css,
|
1182 |
+
allow_flagging="never", # Disable flagging
|
|
|
1183 |
)
|
1184 |
|
1185 |
+
if __name__ == "__main__":
|
1186 |
+
iface.launch(share=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|