# web-researcher / research_manager.py
# anirudhs's picture
# added researcher files
# 8fd59af
import html
import re
import time
from typing import List

import gradio as gr
from colorama import Fore, Style
from pydantic import BaseModel, Field

import prompts
import ui
from improve_content import ImproveContent
from llm_config import call_llm, get_llm_usage
from search import fetch_search_results
# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# logger = logging.getLogger(__name__)
class Section(BaseModel):
    # One section of the report: its title, a short overview, the questions it
    # should answer, and the written content.
    # (Comment, not a docstring, so the generated schema description is unchanged.)
    name: str = Field(description="Name for this section of the report.")
    description: str = Field(
        description="Brief overview of the main topics and concepts to be covered in this section.",
    )
    questions: List[str] = Field(description="Key Questions to answer in this section.")
    content: str = Field(description="The content of the section.")
class Sections(BaseModel):
    # Ordered collection of report sections, with helpers to render them as text.
    # (Comment, not a docstring, so the generated schema description is unchanged.)
    sections: List[Section] = Field(
        description="Sections of the report.",
    )

    @property
    def as_str(self) -> str:
        """Render every section (name, description, questions, content) as one string.

        Sections are separated by a blank line; an empty or missing section
        list yields an empty string.
        """
        rendered = []
        for section in self.sections or []:
            # Joined outside the f-string: a backslash inside an f-string
            # replacement field is a SyntaxError before Python 3.12.
            questions = "\n\n".join(section.questions)
            rendered.append(
                f"## {section.name}\n\n-{section.description}\n\n"
                f"- Questions: {questions}\n\n- Content: {section.content}\n"
            )
        return "\n\n".join(rendered)

    def print_sections(self) -> str:
        """Return only the content of each section, separated by blank lines."""
        # Same `or []` guard as `as_str`, so both helpers tolerate a missing list.
        return "\n\n".join(section.content for section in self.sections or [])
class ResearchArea(BaseModel):
    # A single research area paired with the search query used to explore it.
    # (Comment, not a docstring, so the generated schema description is unchanged.)
    area: str = Field(..., title="Research Area")
    search_terms: str = Field(
        ...,
        title="Search Term",
        description="Search query that will help you find information",
    )
class ResearchFocus(BaseModel):
    # The list of research areas to cover.
    # (Comment, not a docstring, so the generated schema description is unchanged.)
    areas: List[ResearchArea] = Field(
        ...,
        title="Research Areas",
    )
class RelevantSearchResults(BaseModel):
    # Positions of the search results judged relevant, plus the reasoning given
    # for selecting them.
    # (Comment, not a docstring, so the generated schema description is unchanged.)
    relevant_search_results: List[int] = Field(
        ...,
        title="Relevant Search Results",
        description="The position of the search result in the search results list",
    )
    reasoning: List[str] = Field(
        ...,
        title="Reasoning",
        description="Reasoning for selecting the search results",
    )
class SearchTerm(BaseModel):
    # A single search query.
    # (Comment, not a docstring, so the generated schema description is unchanged.)
    query: str = Field(
        ...,
        title="Search Query",
    )
    # Disabled field, kept for reference:
    # time_range : str = Field(..., title="Time Range", description="d/w/m/y/none")
class SearchTermsList(BaseModel):
    # Several search queries returned together as a list of strings.
    # (Comment, not a docstring, so the generated schema description is unchanged.)
    queries: List[str] = Field(
        ...,
        title="Search Terms as a list",
    )
class Editor(BaseModel):
    # An editor persona: who they are, where they belong, and what they focus on.
    # (Comment, not a docstring, so the generated schema description is unchanged.)
    name: str = Field(description="Name of the editor.")
    affiliation: str = Field(description="Primary affiliation of the editor.")
    role: str = Field(description="Role of the editor in the context of the topic.")
    focus: str = Field(
        description="Description of the editor's focus area, concerns and how they will help.",
    )

    @property
    def persona(self) -> str:
        """Return the role, affiliation and focus as a newline-delimited block.

        The editor's name is not part of the rendered text.
        """
        profile_lines = (
            f"Role: {self.role}",
            f"Affiliation: {self.affiliation}",
            f"Description: {self.focus}",
        )
        # Leading and trailing newline frame the block.
        return "\n" + "\n".join(profile_lines) + "\n"
class Perspectives(BaseModel):
    # A group of editor personas.
    # (Comment, not a docstring, so the generated schema description is unchanged.)
    editors: List[Editor] = Field(description="Comprehensive list of editors with their roles and affiliations.")
class ReportSynopsis(BaseModel):
    # Short synopsis describing what the reader can expect from the report.
    # (Comment, not a docstring, so the generated schema description is unchanged.)
    synopsis: str = Field(
        ...,
        title="Report Synopsis",
        description="A synopsis talking about what the reader can expect",
    )
class SectionContent(BaseModel):
    # The written text of one report section.
    # (Comment, not a docstring, so the generated schema description is unchanged.)
    content: str = Field(
        ...,
        title="Section Content",
        description="The content of the section",
    )
class ResearchManager:
    """Manages the research process including analysis, search, and documentation.

    The workflow methods are generators: each ``yield`` hands the UI a
    three-item list built by ``update_gradio`` (button update, progress HTML,
    current outline/report text).

    ``research_task`` is indexed with the keys ``'topic'``, ``'report_type'``,
    ``'section_length'`` and ``'key_questions'``.
    """

    # One bracketed citation group, e.g. "[2]" or "[2, 3,4]".
    _CITATION_GROUP = re.compile(r'\[(\d+(?:,\s*\d+)*)\]')

    def __init__(self, research_task):
        """Store the task description and set up the initial progress tracker."""
        self.use_existing_outline = True
        self.research_task = research_task
        self.report_synopsis = ''
        self.personas = ''
        self.gradio_report_outline = ''
        # Assigned later by start_research / _generate_report_outline; set here
        # so the attributes always exist.
        self.report_outline = None
        self.context = ''
        # Insertion-ordered map of task key -> display name and status
        # ("pending" / "running" / "done"). One entry per report section is
        # added once the outline is known.
        self.task_status = {
            'synopsis_draft': {"name": "Creating synopsis of the report...", "status": "pending"},
            'gathering_info': {"name": "Gathering Info on the topic...", "status": "pending"},
            'running_searches': {"name": "Run search...", "status": "pending"},
            'mock_discussion': {"name": "Conducting mock discussions...", "status": "pending"},
            'generating_outline': {"name": "Generating a draft outline...", "status": "pending"},
        }

    def extract_citation_info(self, text):
        """Extract citation number and URL from citation text.

        Args:
            text: iterable of reference strings. Each is searched for a
                ``[n]`` marker and, optionally, a ``URL: http(s)://...`` part.

        Returns:
            dict mapping the citation number (as a string) to
            ``{'url': str | None, 'reference_text': str}``. Entries without a
            ``[n]`` marker are skipped; previously they were all stored under
            the key ``None``, each overwriting the last.
        """
        references = {}
        for ref in text:
            citation_match = re.search(r'\[(\d+)\]', ref)
            if citation_match is None:
                continue
            url_match = re.search(r'URL: (https?://\S+)', ref)
            references[citation_match.group(1)] = {
                'url': url_match.group(1) if url_match else None,
                'reference_text': ref,
            }
        return references

    @staticmethod
    def _link_citations(content, references_dict):
        """Rewrite bracketed citations in ``content`` in a single pass.

        ``[2,3]`` is expanded to one bracket per number. Each number then
        becomes ``[[n]](url)`` when the reference is known and has a URL,
        stays ``[n]`` when it is known without a URL, and becomes ``[!]``
        (with a console message) when it is unknown.
        """
        def render(match):
            pieces = []
            for raw_no in match.group(1).split(','):
                ref_no = raw_no.strip()
                reference = references_dict.get(ref_no)
                if not reference:
                    print(f"Reference {ref_no} not found")
                    pieces.append("[!]")
                elif reference["url"]:
                    pieces.append(f"[[{ref_no}]]({reference['url']})")
                else:
                    pieces.append(f"[{ref_no}]")
            return ''.join(pieces)

        # A single substitution means text that was just inserted (for example
        # a URL) is never rescanned for citation markers.
        return ResearchManager._CITATION_GROUP.sub(render, content)

    def section_writer(self, section: "Section"):
        """Given an outline of a section, generate search queries,
        perform searches and generate the section content.

        Generator: yields UI updates while working and returns the same
        ``section`` object with ``content`` filled in. Reads
        ``self.personas.editors`` and ``self.report_outline``, so it is meant
        to run after ``start_research`` has set both.
        """
        improve_content = ImproveContent(
            section.name,
            section.description,
            section.questions,
            self.personas.editors,
        )
        # The interview's return value is not used; only its UI updates are forwarded.
        yield from improve_content.create_and_run_interview(self.task_status, self.update_gradio)
        content, references = improve_content.generate_final_section(self.report_synopsis)

        self.task_status[section.name]["name"] = "Writing Section: " + section.name
        yield from self.update_gradio()
        ui.system_update(f"Writing Section: {section.name}")

        section_content = call_llm(
            instructions=prompts.WRITE_SECTION_INSTRUCTIONS,
            model_type='slow',
            context={
                "section_description": section.description,
                "gathered_info": '\n\n'.join(content),
                "topic": self.research_task['topic'],
                "section_title": section.name,
                "synopsis": self.report_synopsis,
                "section_questions": '\n'.join(section.questions),
                'report_type': self.research_task['report_type'],
                'section_length': self.research_task['section_length'],
            },
            response_model=SectionContent,
            logging_fn='write_section_instructions',
        )

        # NOTE(review): assumes `references` is one string with blank-line
        # separated entries - confirm against ImproveContent.generate_final_section.
        references_dict = self.extract_citation_info(references.split('\n\n'))
        section_content.content = self._link_citations(section_content.content, references_dict)

        section.content = section_content.content
        print(section_content.content)
        self.task_status[section.name]["status"] = "done"
        yield from self.update_gradio(
            report_outline_str=self.report_outline.print_sections(),
            button_disable=False,
        )
        ui.system_update("Waiting for 5 seconds before next section")
        time.sleep(5)
        return section

    def _generate_report_outline(self):
        """Use LLM to generate focus areas for research based on the original query.

        Generator: yields UI updates and returns the outline produced by
        ``call_llm`` with ``response_model=Sections``. Stores the formatted
        search results on ``self.context``.
        """
        ui.system_update("\nGathering Context..")
        self.task_status['gathering_info']["status"] = "running"
        yield from self.update_gradio()

        queries = call_llm(
            instructions=prompts.FIND_SEARCH_TERMS_INSTRUCTIONS,
            model_type='fast',
            context={
                "report_type": self.research_task['report_type'],
                "original_query": self.research_task['topic'],
                "report_synopsis": self.report_synopsis,
            },
            response_model=SearchTermsList,
            logging_fn='find_search_terms_instructions',
        )

        self.task_status['running_searches']["status"] = "running"
        yield from self.update_gradio()
        # Only the formatted results are used; the second element is discarded.
        formatted_results, _ = yield from fetch_search_results(
            query=queries.queries,
            task_status=self.task_status,
            task_name='running_searches',
            fn=self.update_gradio,
        )
        self.context = formatted_results

        self.task_status['running_searches']["status"] = "done"
        self.task_status['gathering_info']["status"] = "done"
        self.task_status['mock_discussion']["status"] = "running"
        yield from self.update_gradio()

        personas = call_llm(
            instructions=prompts.GENERATE_ROUNDTABLE_PERSONAS_INSTRUCTIONS,
            model_type='slow',
            context={
                "context": self.context,
                "topic": self.research_task['topic'],
                "report_synopsis": self.report_synopsis,
                'type_of_report': self.research_task['report_type'],
                'num_personas': 5,
            },
            response_model=Perspectives,
            logging_fn='generate_roundtable_personas_instructions',
        )
        self.task_status['mock_discussion']["name"] = "Started discussions..."
        print(personas)
        yield from self.update_gradio()

        # The description names the actual research topic; it used to be a
        # hard-coded string ending in "glean" regardless of the task.
        improve_content = ImproveContent(
            self.research_task['topic'],
            f"This section will focus on a comprehensive overview of {self.research_task['topic']}",
            self.research_task['key_questions'],
            personas.editors,
        )
        warm_start_discussion = improve_content.warm_start_discussion()

        self.task_status['mock_discussion']["name"] = "Mock discussions complete"
        self.task_status['mock_discussion']["status"] = "done"
        self.task_status['generating_outline']["status"] = "running"
        yield from self.update_gradio()

        ui.system_update("\nGenerating Report Outline..")
        report_outline = call_llm(
            instructions=prompts.GENERATE_REPORT_OUTLINE_INSTRUCTIONS,
            model_type='slow',
            context={
                "report_type": self.research_task['report_type'],
                "topic": self.research_task['topic'],
                "context": self.context,
                "discussion": '\n'.join(warm_start_discussion),
                'num_sections': 3,
            },
            response_model=Sections,
            logging_fn='generate_report_outline_instructions',
        )
        self.task_status['generating_outline']["status"] = "done"
        yield from self.update_gradio(report_outline_str=report_outline.as_str)
        print(report_outline.as_str)
        return report_outline

    def validate_outline_with_human(self, report_outline: "Sections") -> "Sections":
        """Ask the human feedback and improve the report outline till they say 'OK'.

        NOTE(review): this method uses ``self.llm``, which is not assigned
        anywhere in this class, and a ``prompt | llm`` chain unlike the
        ``call_llm`` helper used elsewhere. Its only visible call site is
        commented out; confirm before re-enabling.
        """
        while True:
            ui.system_update("\nPlease provide feedback on the generated report outline")
            feedback = ui.get_multiline_input()
            # Strip first so a trailing newline or space does not defeat the check.
            if feedback.strip().lower() == 'ok':
                return report_outline
            ui.system_update("\nImproving the report outline based on your feedback")
            extract_sections_chain = prompts.IMPROVE_REPORT_OUTLINE_PROMPT | self.llm.with_structured_output(Sections)
            report_outline = extract_sections_chain.invoke({
                "topic": self.research_task['topic'],
                "feedback": feedback,
                "report_outline": report_outline.as_str,
            })
            ui.system_output(report_outline.as_str)

    def create_report_synopsis(self):
        """Ask the fast model for a synopsis of the report.

        Returns whatever ``call_llm`` produces for ``response_model=ReportSynopsis``.
        """
        return call_llm(
            instructions=prompts.CREATE_SYNOPSIS_INSTRUCTIONS,
            model_type='fast',
            context={
                "report_type": self.research_task['report_type'],
                "topic": self.research_task['topic'],
                "key_questions": self.research_task['key_questions'],
            },
            response_model=ReportSynopsis,
            logging_fn='create_synopsis_instructions',
        )

    def update_gradio(self, report_outline_str='', button_disable=False):
        """Yield one UI update: ``[button update, progress HTML, outline text]``.

        A non-empty ``report_outline_str`` replaces the remembered outline
        text; otherwise the last one is re-sent. Despite its name,
        ``button_disable`` is passed straight through as the button's
        ``interactive`` flag, so ``False`` leaves the button non-interactive.
        """
        if report_outline_str != '':
            self.gradio_report_outline = report_outline_str
        yield [gr.update(interactive=button_disable), self.update_ui(), self.gradio_report_outline]

    def start_research(self):
        """Main research loop with comprehensive functionality.

        Generator that drives the whole run: synopsis, outline, personas, then
        each section in turn, yielding UI updates along the way.
        """
        self.task_status['synopsis_draft']["status"] = "running"
        yield from self.update_gradio()
        ui.system_update(f"Starting research on: {self.research_task['topic']}")
        ui.system_update("\nGenerating report outline")
        # NOTE(review): this stores the object returned by call_llm (requested
        # as a ReportSynopsis), not a plain string - confirm that downstream
        # prompt contexts expect that.
        self.report_synopsis = self.create_report_synopsis()
        self.task_status['synopsis_draft']["status"] = "done"
        yield from self.update_gradio()

        self.report_outline = yield from self._generate_report_outline()
        #self.report_outline = self.validate_outline_with_human(self.report_outline)

        # Register one progress entry per section, keyed by the section name.
        for section in self.report_outline.sections:
            self.task_status[section.name] = {"name": f"Starting Section: {section.name}", "status": "pending"}
        yield from self.update_gradio()

        ui.system_update("\nGenerating personas for writing sections")
        self.personas = call_llm(
            instructions=prompts.GENERATE_PERSONAS_INSTRUCTIONS,
            model_type='slow',
            context={
                "topic": self.research_task['topic'],
                "report_synopsis": self.report_synopsis,
                'type_of_report': self.research_task['report_type'],
                'num_personas': 2,
            },
            response_model=Perspectives,
            logging_fn='generate_personas_instructions',
        )

        ui.system_update("\nWriting Sections....")
        for section in self.report_outline.sections:
            self.task_status[section.name]["status"] = "running"
            yield from self.update_gradio()
            ui.system_sub_update(f"\nWriting Section: {section.name}")
            # section_writer fills in section.content on the same object.
            yield from self.section_writer(section)

        for section in self.report_outline.sections:
            print(section.content)

    def update_ui(self):
        """Render ``self.task_status`` as an HTML progress panel.

        Returns a string with the stylesheet, a percentage bar, one milestone
        dot per task and one row per task (spinner / check mark / empty box).
        Task names are HTML-escaped because section names are model-generated.
        """
        completed_tasks = sum(1 for task in self.task_status.values() if task["status"] == "done")
        total_tasks = len(self.task_status)
        # Guard against an empty tracker so there is no division by zero.
        progress_percentage = int((completed_tasks / total_tasks) * 100) if total_tasks else 0
        milestones = ''.join(
            f"<div class='milestone {'completed' if i < completed_tasks else ''}'></div>"
            for i in range(total_tasks)
        )

        header = f"""
        <style>
        .progress-bar-container {{
            width: 100%;
            background-color: #f3f3f3;
            border-radius: 5px;
            overflow: hidden;
            margin-bottom: 20px;
        }}
        .progress-bar {{
            height: 20px;
            width: {progress_percentage}%;
            background-color: #3498db;
            transition: width 0.3s;
            display: flex;
            align-items: center;
            justify-content: center;
            color: white;
            font-weight: bold;
            font-size: 12px;
        }}
        .progress-task {{
            display: flex;
            align-items: center;
            gap: 10px;
            font-family: 'Helvetica Neue', Arial, sans-serif;
            margin: 5px 0;
            font-size: 14px;
            font-weight: 500;
            color: #333;
        }}
        .progress-task .task-name {{
            flex-grow: 1;
        }}
        .progress-task .icon {{
            width: 20px;
            height: 20px;
        }}
        .loading-circle {{
            width: 15px;
            height: 15px;
            border: 3px solid #ccc;
            border-top: 3px solid #3498db;
            border-radius: 50%;
            animation: spin 1s linear infinite;
        }}
        @keyframes spin {{
            0% {{ transform: rotate(0deg); }}
            100% {{ transform: rotate(360deg); }}
        }}
        .done-icon {{
            color: #2ecc71;
            font-size: 16px;
        }}
        .checkbox {{
            width: 15px;
            height: 15px;
            border: 1px solid #ccc;
            display: inline-block;
            margin-right: 10px;
        }}
        .milestone {{
            display: inline-block;
            width: 10px;
            height: 10px;
            background-color: #ccc;
            border-radius: 50%;
            margin: 0 5px;
        }}
        .milestone.completed {{
            background-color: #2ecc71;
        }}
        </style>
        <div class='progress-bar-container'>
            <div class='progress-bar'>{progress_percentage}%</div>
        </div>
        <div style='display: flex; justify-content: center; margin-bottom: 20px;'>
            {milestones}
        </div>
        """

        rows = []
        for task in self.task_status.values():
            if task["status"] == "running":
                icon = "<div class='loading-circle'></div>"
            elif task["status"] == "done":
                icon = "<span class='done-icon'>&#10003;</span>"
            else:
                icon = "<div class='checkbox'></div>"
            safe_name = html.escape(task['name'])
            rows.append(
                f"<div class='progress-task'><span class='icon'>{icon}</span>"
                f"<span class='task-name'>{safe_name}</span></div>"
            )
        return header + ''.join(rows)