# Spaces:
# Sleeping
# Sleeping
import html
import re
import time
from typing import List

import gradio as gr
from colorama import Fore, Style
from pydantic import BaseModel, Field

import prompts
import ui
from llm_config import call_llm, get_llm_usage
# Add these imports at the top
from improve_content import ImproveContent
from search import fetch_search_results

# logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# logger = logging.getLogger(__name__)
# One section of the report outline. It is nested inside `Sections`, which is
# passed to call_llm as a response_model, so field names, order and
# descriptions form part of the LLM's output schema. `#` comments are used
# here instead of a class docstring because pydantic copies a class docstring
# into the JSON-schema "description".
class Section(BaseModel):
    name: str = Field(description="Name for this section of the report.")
    description: str = Field(
        description="Brief overview of the main topics and concepts to be covered in this section."
    )
    questions: List[str] = Field(description="Key Questions to answer in this section.")
    content: str = Field(description="The content of the section.")
# The report outline: an ordered list of `Section`s. Used as the response_model
# when the outline is generated, so no class docstring is added (pydantic would
# copy it into the JSON-schema "description" sent to the LLM).
class Sections(BaseModel):
    sections: List[Section] = Field(
        description="Sections of the report.",
    )

    def as_str(self) -> str:
        """Render the outline as Markdown.

        Each section becomes a "## <name>" heading followed by bullet lines for
        its description, its questions (separated by blank lines) and its
        content. Sections are separated by blank lines.
        """
        rendered = []
        for section in self.sections or []:
            # Joined outside the f-string: a backslash inside an f-string
            # replacement field is a SyntaxError before Python 3.12.
            questions = "\n\n".join(section.questions)
            rendered.append(
                f"## {section.name}\n\n- {section.description}\n\n"
                f"- Questions: {questions}\n\n- Content: {section.content}\n"
            )
        return "\n\n".join(rendered)

    def print_sections(self) -> str:
        """Return the content of every section, separated by blank lines."""
        return '\n\n'.join(s.content for s in self.sections or [])
# A research area together with a search query for exploring it.
class ResearchArea(BaseModel):
    area: str = Field(..., title="Research Area")
    search_terms: str = Field(
        ...,
        title="Search Term",
        description="Search query that will help you find information",
    )
# A collection of research areas.
class ResearchFocus(BaseModel):
    areas: List[ResearchArea] = Field(
        ...,
        title="Research Areas",
    )
# Positions of the search results judged relevant, plus the reasoning for
# selecting them.
class RelevantSearchResults(BaseModel):
    relevant_search_results: List[int] = Field(
        ...,
        title="Relevant Search Results",
        description="The position of the search result in the search results list",
    )
    reasoning: List[str] = Field(
        ...,
        title="Reasoning",
        description="Reasoning for selecting the search results",
    )
# A single search query.
class SearchTerm(BaseModel):
    query: str = Field(..., title="Search Query")
    # Disabled optional field (values "d/w/m/y/none"):
    # time_range : str = Field(..., title="Time Range", description="d/w/m/y/none")
# Search queries returned by the LLM; the response_model for the
# search-term step of ResearchManager._generate_report_outline.
class SearchTermsList(BaseModel):
    queries: List[str] = Field(
        ...,
        title="Search Terms as a list",
    )
# One editor/persona taking part in the discussions (the element type of
# Perspectives.editors).
class Editor(BaseModel):
    name: str = Field(description="Name of the editor.")
    affiliation: str = Field(description="Primary affiliation of the editor.")
    role: str = Field(description="Role of the editor in the context of the topic.")
    focus: str = Field(
        description="Description of the editor's focus area, concerns and how they will help.",
    )

    def persona(self) -> str:
        """Return a short text block with this editor's role, affiliation and focus.

        The block starts and ends with a newline. The ``name`` field is not
        included.
        """
        role, affiliation, focus = self.role, self.affiliation, self.focus
        return f"\nRole: {role}\nAffiliation: {affiliation}\nDescription: {focus}\n"
# The generated set of editors/personas; used as the response_model for
# persona generation.
class Perspectives(BaseModel):
    editors: List[Editor] = Field(
        description="Comprehensive list of editors with their roles and affiliations."
    )
# LLM output schema for the report synopsis (see
# ResearchManager.create_report_synopsis).
class ReportSynopsis(BaseModel):
    synopsis: str = Field(
        ...,
        title="Report Synopsis",
        description="A synopsis talking about what the reader can expect",
    )
# LLM output schema for the written body of one report section.
class SectionContent(BaseModel):
    content: str = Field(
        ...,
        title="Section Content",
        description="The content of the section",
    )
class ResearchManager:
    """Manages the research process including analysis, search, and documentation.

    ``start_research`` is the entry point. It is a generator that drafts a
    synopsis, gathers context, builds an outline and then writes each section.
    After every step it yields the Gradio update produced by ``update_gradio``:
    ``[button update, progress HTML, outline text]``.
    """

    # Citation markers in generated text: "[3]" or grouped "[2, 3, 4]".
    _CITATION_GROUP_RE = re.compile(r'\[(\d+(?:,\s*\d+)*)\]')
    # A single, already-normalised citation marker: "[3]".
    _SINGLE_CITATION_RE = re.compile(r'\[(\d+)\]')

    def __init__(self, research_task):
        """Create a manager for one research task.

        Args:
            research_task: mapping describing the task. This class reads the
                keys 'topic', 'report_type', 'key_questions' and
                'section_length'.
        """
        self.use_existing_outline = True
        self.research_task = research_task
        # Replaced in start_research() by the result of create_report_synopsis().
        self.report_synopsis = ''
        # Replaced in start_research() by the persona-generation call_llm result.
        self.personas = ''
        # Last outline text sent to the UI (kept by update_gradio).
        self.gradio_report_outline = ''
        # Filled by _generate_report_outline() / start_research(); initialised
        # here so the attributes always exist.
        self.context = ''
        self.report_outline = None
        # Ordered map: task id -> {"name": label, "status": "pending"|"running"|"done"}.
        # Rendered by update_ui(); start_research() adds one entry per section.
        self.task_status = {
            'synopsis_draft': {"name": "Creating synopsis of the report...", "status": "pending"},
            'gathering_info': {"name": "Gathering Info on the topic...", "status": "pending"},
            'running_searches': {"name": "Run search...", "status": "pending"},
            'mock_discussion': {"name": "Conducting mock discussions...", "status": "pending"},
            'generating_outline': {"name": "Generating a draft outline...", "status": "pending"},
        }

    def extract_citation_info(self, text):
        """Extract citation numbers and URLs from reference strings.

        Args:
            text: iterable of reference strings, each expected to contain a
                "[n]" marker and optionally "URL: http(s)://...".

        Returns:
            dict mapping the citation number (as a string) to
            ``{'url': str or None, 'reference_text': original string}``.
            Strings without a "[n]" marker (e.g. blank chunks) are skipped,
            because they can never be looked up by citation number.
        """
        references = {}
        for ref in text:
            citation_match = re.search(r'\[(\d+)\]', ref)
            if citation_match is None:
                continue
            url_match = re.search(r'URL: (https?://\S+)', ref)
            references[citation_match.group(1)] = {
                'url': url_match.group(1) if url_match else None,
                'reference_text': ref,
            }
        return references

    @staticmethod
    def _link_citations(text, references_dict):
        """Normalise citation markers in ``text`` and link the known ones.

        First, grouped markers are split ("[2, 3]" -> "[2][3]"). Then each
        "[n]" is rewritten in a single pass:

        - a number missing from ``references_dict`` becomes "[!]", and is
          reported once;
        - a number with a URL becomes the Markdown link "[[n]](url)";
        - a number without a URL is left unchanged.
        """
        text = ResearchManager._CITATION_GROUP_RE.sub(
            lambda m: ''.join(f'[{n.strip()}]' for n in m.group(1).split(',')),
            text,
        )
        reported = set()

        def replace_marker(match):
            ref_no = match.group(1)
            reference = references_dict.get(ref_no)
            if not reference:
                if ref_no not in reported:
                    reported.add(ref_no)
                    print(f"Reference {ref_no} not found")
                return "[!]"
            if reference["url"]:
                return f"[[{ref_no}]]({reference['url']})"
            return match.group(0)

        # A single pass means a link inserted for one citation is never
        # rewritten by a later replacement.
        return ResearchManager._SINGLE_CITATION_RE.sub(replace_marker, text)

    def section_writer(self, section: Section):
        """Given an outline of a section, generate search queries,
        perform searches and generate the section content.

        Generator. It runs the persona interview for the section and asks the
        slow LLM to write it. It then normalises and links citations, stores
        the result in ``section.content``, and returns ``section``.
        Yields Gradio updates along the way. After writing, it sleeps 5 seconds
        before returning.
        """
        improve_content = ImproveContent(section.name,
                                         section.description,
                                         section.questions,
                                         self.personas.editors)
        # The interview's return value is not used here. The material for the
        # section comes from generate_final_section below.
        yield from improve_content.create_and_run_interview(self.task_status, self.update_gradio)
        content, references = improve_content.generate_final_section(self.report_synopsis)

        self.task_status[section.name]["name"] = "Writing Section: " + section.name
        yield from self.update_gradio()
        ui.system_update(f"Writing Section: {section.name}")
        section_content = call_llm(
            instructions=prompts.WRITE_SECTION_INSTRUCTIONS,
            model_type='slow',
            context={
                "section_description": section.description,
                "gathered_info": '\n\n'.join(content),
                "topic": self.research_task['topic'],
                "section_title": section.name,
                "synopsis": self.report_synopsis,
                "section_questions": '\n'.join(section.questions),
                'report_type': self.research_task['report_type'],
                'section_length': self.research_task['section_length'],
            },
            response_model=SectionContent,
            logging_fn='write_section_instructions',
        )

        # `references` is split into one string per reference (blank-line separated).
        references_dict = self.extract_citation_info(references.split('\n\n'))
        section.content = self._link_citations(section_content.content, references_dict)
        print(section.content)

        self.task_status[section.name]["status"] = "done"
        yield from self.update_gradio(report_outline_str=self.report_outline.print_sections(), button_disable=False)
        ui.system_update("Waiting for 5 seconds before next section")
        time.sleep(5)
        return section

    def _generate_report_outline(self):
        """Use LLM to generate focus areas for research based on the original query.

        Generator. The steps are:

        1. find search terms;
        2. run the searches (the formatted results are stored in ``self.context``);
        3. generate roundtable personas;
        4. run a warm-start mock discussion;
        5. generate the outline from the context and the discussion.

        It yields Gradio updates throughout and returns the outline produced
        with ``response_model=Sections``.
        """
        ui.system_update("\nGathering Context..")
        self.task_status['gathering_info']["status"] = "running"
        yield from self.update_gradio()
        queries = call_llm(
            instructions=prompts.FIND_SEARCH_TERMS_INSTRUCTIONS,
            model_type='fast',
            context={
                "report_type": self.research_task['report_type'],
                "original_query": self.research_task['topic'],
                "report_synopsis": self.report_synopsis,
            },
            response_model=SearchTermsList,
            logging_fn='find_search_terms_instructions',
        )

        self.task_status['running_searches']["status"] = "running"
        yield from self.update_gradio()
        formatted_results, _ = yield from fetch_search_results(
            query=queries.queries,
            task_status=self.task_status,
            task_name='running_searches',
            fn=self.update_gradio,
        )
        self.context = formatted_results
        self.task_status['running_searches']["status"] = "done"
        self.task_status['gathering_info']["status"] = "done"

        self.task_status['mock_discussion']["status"] = "running"
        yield from self.update_gradio()
        personas = call_llm(
            instructions=prompts.GENERATE_ROUNDTABLE_PERSONAS_INSTRUCTIONS,
            model_type='slow',
            context={
                "context": self.context,
                "topic": self.research_task['topic'],
                "report_synopsis": self.report_synopsis,
                'type_of_report': self.research_task['report_type'],
                'num_personas': 5,
            },
            response_model=Perspectives,
            logging_fn='generate_roundtable_personas_instructions',
        )
        self.task_status['mock_discussion']["name"] = "Started discussions..."
        print(personas)
        yield from self.update_gradio()
        topic = self.research_task['topic']
        improve_content = ImproveContent(
            topic,
            # Describe the actual research topic rather than a fixed subject.
            f"This section will focus on a comprehensive overview of {topic}",
            self.research_task['key_questions'],
            personas.editors,
        )
        warm_start_discussion = improve_content.warm_start_discussion()
        self.task_status['mock_discussion']["name"] = "Mock discussions complete"
        self.task_status['mock_discussion']["status"] = "done"

        self.task_status['generating_outline']["status"] = "running"
        yield from self.update_gradio()
        ui.system_update("\nGenerating Report Outline..")
        report_outline = call_llm(
            instructions=prompts.GENERATE_REPORT_OUTLINE_INSTRUCTIONS,
            model_type='slow',
            context={
                "report_type": self.research_task['report_type'],
                "topic": self.research_task['topic'],
                "context": self.context,
                "discussion": '\n'.join(warm_start_discussion),
                'num_sections': 3,
            },
            response_model=Sections,
            logging_fn='generate_report_outline_instructions',
        )
        self.task_status['generating_outline']["status"] = "done"
        # as_str is a method: call it, so the UI receives the outline text and
        # not a bound-method object.
        outline_str = report_outline.as_str()
        yield from self.update_gradio(report_outline_str=outline_str)
        print(outline_str)
        return report_outline

    def validate_outline_with_human(self, report_outline: Sections) -> Sections:
        """Ask the human feedback and improve the report outline till they say 'OK'.

        The reply is compared case-insensitively, and surrounding whitespace is
        ignored.

        NOTE(review): this relies on ``self.llm``, which this class never sets,
        so it raises AttributeError unless ``self.llm`` is assigned elsewhere.
        Its call in start_research() is currently commented out.
        """
        while True:
            ui.system_update("\nPlease provide feedback on the generated report outline")
            feedback = ui.get_multiline_input()
            if feedback.strip().lower() == 'ok':
                return report_outline
            ui.system_update("\nImproving the report outline based on your feedback")
            extract_sections_chain = prompts.IMPROVE_REPORT_OUTLINE_PROMPT | self.llm.with_structured_output(Sections)
            report_outline = extract_sections_chain.invoke({
                "topic": self.research_task['topic'],
                "feedback": feedback,
                "report_outline": report_outline.as_str(),
            })
            ui.system_output(report_outline.as_str())

    def create_report_synopsis(self):
        """Ask the fast LLM for a synopsis of the report.

        Returns:
            The ``call_llm`` result for ``response_model=ReportSynopsis``,
            which is presumably a ReportSynopsis instance.
        """
        return call_llm(
            instructions=prompts.CREATE_SYNOPSIS_INSTRUCTIONS,
            model_type='fast',
            context={
                "report_type": self.research_task['report_type'],
                "topic": self.research_task['topic'],
                "key_questions": self.research_task['key_questions'],
            },
            response_model=ReportSynopsis,
            logging_fn='create_synopsis_instructions',
        )

    def update_gradio(self, report_outline_str='', button_disable=False):
        """Yield a single Gradio update: [button update, progress HTML, outline text].

        Args:
            report_outline_str: new outline text to display. An empty string
                keeps the previously displayed outline.
            button_disable: passed directly as the button's ``interactive`` flag.
                NOTE(review): the name reads inverted. ``False`` (the default,
                and the value every call in this class uses) makes the button
                non-interactive. Confirm the intended semantics.
        """
        if report_outline_str != '':
            self.gradio_report_outline = report_outline_str
        yield [gr.update(interactive=button_disable), self.update_ui(), self.gradio_report_outline]

    def start_research(self):
        """Main research loop with comprehensive functionality.

        Generator that drives the whole pipeline: synopsis, then context and
        outline, then personas, then one ``section_writer`` run per outline
        section. Each yielded value is an ``update_gradio`` update.
        """
        self.task_status['synopsis_draft']["status"] = "running"
        yield from self.update_gradio()
        ui.system_update(f"Starting research on: {self.research_task['topic']}")
        ui.system_update("\nGenerating report outline")
        # NOTE(review): this stores the whole call_llm result for
        # ReportSynopsis, not its `.synopsis` string. That object is what the
        # prompts and ImproveContent receive; confirm this is intended.
        self.report_synopsis = self.create_report_synopsis()
        self.task_status['synopsis_draft']["status"] = "done"
        yield from self.update_gradio()

        self.report_outline = yield from self._generate_report_outline()
        # Human review of the outline is disabled:
        # self.report_outline = self.validate_outline_with_human(self.report_outline)
        for section in self.report_outline.sections:
            self.task_status[section.name] = {"name": f"Starting Section: {section.name}", "status": "pending"}
        yield from self.update_gradio()

        ui.system_update("\nGenerating personas for writing sections")
        self.personas = call_llm(
            instructions=prompts.GENERATE_PERSONAS_INSTRUCTIONS,
            model_type='slow',
            context={
                "topic": self.research_task['topic'],
                "report_synopsis": self.report_synopsis,
                'type_of_report': self.research_task['report_type'],
                'num_personas': 2,
            },
            response_model=Perspectives,
            logging_fn='generate_personas_instructions',
        )

        ui.system_update("\nWriting Sections....")
        for section in self.report_outline.sections:
            self.task_status[section.name]["status"] = "running"
            yield from self.update_gradio()
            ui.system_sub_update(f"\nWriting Section: {section.name}")
            # section_writer sets section.content on this same object.
            yield from self.section_writer(section)
        for section in self.report_outline.sections:
            print(section.content)

    def update_ui(self):
        """Render ``self.task_status`` as an HTML progress panel.

        Returns:
            str: HTML containing:

            - a percentage bar;
            - one milestone dot per task, where the first ``completed`` dots
              are filled;
            - one row per task, showing a spinner ("running"), a check mark
              ("done") or an empty box (any other status).
        """
        tasks = list(self.task_status.values())
        completed_tasks = sum(1 for task in tasks if task["status"] == "done")
        total_tasks = len(tasks)
        # Guard the division so an empty task map renders 0% instead of raising.
        progress_percentage = int((completed_tasks / total_tasks) * 100) if total_tasks else 0
        milestones = ''.join(
            f"<div class='milestone {'completed' if i < completed_tasks else ''}'></div>"
            for i in range(total_tasks)
        )
        html_output = f"""
        <style>
        .progress-bar-container {{
            width: 100%;
            background-color: #f3f3f3;
            border-radius: 5px;
            overflow: hidden;
            margin-bottom: 20px;
        }}
        .progress-bar {{
            height: 20px;
            width: {progress_percentage}%;
            background-color: #3498db;
            transition: width 0.3s;
            display: flex;
            align-items: center;
            justify-content: center;
            color: white;
            font-weight: bold;
            font-size: 12px;
        }}
        .progress-task {{
            display: flex;
            align-items: center;
            gap: 10px;
            font-family: 'Helvetica Neue', Arial, sans-serif;
            margin: 5px 0;
            font-size: 14px;
            font-weight: 500;
            color: #333;
        }}
        .progress-task .task-name {{
            flex-grow: 1;
        }}
        .progress-task .icon {{
            width: 20px;
            height: 20px;
        }}
        .loading-circle {{
            width: 15px;
            height: 15px;
            border: 3px solid #ccc;
            border-top: 3px solid #3498db;
            border-radius: 50%;
            animation: spin 1s linear infinite;
        }}
        @keyframes spin {{
            0% {{ transform: rotate(0deg); }}
            100% {{ transform: rotate(360deg); }}
        }}
        .done-icon {{
            color: #2ecc71;
            font-size: 16px;
        }}
        .checkbox {{
            width: 15px;
            height: 15px;
            border: 1px solid #ccc;
            display: inline-block;
            margin-right: 10px;
        }}
        .milestone {{
            display: inline-block;
            width: 10px;
            height: 10px;
            background-color: #ccc;
            border-radius: 50%;
            margin: 0 5px;
        }}
        .milestone.completed {{
            background-color: #2ecc71;
        }}
        </style>
        <div class='progress-bar-container'>
        <div class='progress-bar'>{progress_percentage}%</div>
        </div>
        <div style='display: flex; justify-content: center; margin-bottom: 20px;'>
        {milestones}
        </div>
        """
        status_icons = {
            "running": "<div class='loading-circle'></div>",
            "done": "<span class='done-icon'>✓</span>",
        }
        rows = []
        for task in tasks:
            icon = status_icons.get(task["status"], "<div class='checkbox'></div>")
            # Task names include LLM-generated section titles. Escape them so
            # characters such as '<' or '&' cannot break the surrounding markup.
            name = html.escape(task['name'])
            rows.append(
                f"<div class='progress-task'><span class='icon'>{icon}</span>"
                f"<span class='task-name'>{name}</span></div>"
            )
        return html_output + ''.join(rows)