# web-researcher / improve_content.py
# anirudhs: added researcher files (commit 8fd59af)
from typing import List, TypedDict
from llm_config import get_llm_instructor, call_llm
from pydantic import BaseModel, Field
import ui
import prompts
from search import fetch_search_results, format_search_results
import random
import time
from dotenv import load_dotenv
import re
load_dotenv()
class RoundtableMessage(BaseModel):
response: str = Field(..., title="Your response")
follow_up: str = Field(..., title="Your follow-up question")
next_persona: str = Field(..., title="Who you are asking the question to")
class ContentState(TypedDict):
    """Mutable state threaded through one persona's interview loop.

    Fixed: the last key was misspelled ``refernces`` while every runtime
    access reads/writes ``state['references']`` — the declaration now
    matches actual usage.
    """
    previous_messages: List[dict]  # chat history passed back to the LLM
    content: str                   # most recent answer text
    expert_question: str           # question currently being answered
    iteration: int                 # completed question/answer rounds
    full_messages: List[str]       # every question and answer, in order
    references: str                # accumulated, renumbered search results
class Queries(BaseModel):
queries : List[str] = Field(..., title="List of queries to search for")
class PersonaQuestion(BaseModel):
question: str = Field(..., title="Your question for the expert")
class StrucutredAnswer(BaseModel):
answer_response: str = Field(..., title="The response to the question with citations")
references_used: List[int] = Field(..., title="The references used to answer the question")
class ImproveContent:
    """Improve a report section by simulating multi-persona expert interviews.

    Each persona repeatedly asks a question about the section topic; answers
    are generated from live web-search results, and every citation is
    renumbered into one global sequence so references stay unique across
    rounds and personas.

    Fixes in this revision:
    - ``create_initial_state`` now includes the ``content`` key declared by
      ``ContentState``.
    - ``ContentState`` annotations are quoted so the class can be defined
      before/without the TypedDict being in scope (lazy evaluation).
    """

    def __init__(self, section_topic, section_description, section_key_questions, personas):
        """Store section metadata, the persona list, and an LLM client."""
        self.section_topic = section_topic
        self.section_description = section_description
        self.section_key_questions = section_key_questions
        self.client = get_llm_instructor()
        # Running counter assigning globally unique citation numbers across
        # every search performed during the interview.
        self.num_search_result = 1
        self.num_interview_rounds = 3
        self.personas = personas
        self.warm_start_rounds = 10

    def create_initial_state(self) -> "ContentState":
        """Return a fresh per-persona interview state.

        Fixed: the original omitted the 'content' key required by
        ContentState; it is now initialised to an empty string.
        """
        return {
            "expert_question": "",
            "iteration": 0,
            "content": "",
            "previous_messages": [],
            "full_messages": [],
            "references": "",
        }

    def expert_question_generator(self, persona, state: "ContentState") -> "ContentState":
        """Ask the LLM, in the persona's voice, for the next expert question.

        Appends the question to the chat history and full transcript, and
        stores it in state['expert_question'] for answer_question to use.
        """
        response = call_llm(
            instructions=prompts.QUALITY_CHECKER_INSTRUCTIONS,
            additional_messages=state['previous_messages'],
            context={
                "title_description": self.section_description + ":" + self.section_topic,
                "key_questions": self.section_key_questions,
                'persona': persona.persona
            },
            response_model=PersonaQuestion,
            logging_fn="quality_checker"
        )
        ui.system_sub_update("-------------------")
        ui.system_sub_update(f'{persona.name} ({persona.role},{persona.affiliation}):')
        ui.system_sub_update(response.question)
        ui.system_sub_update("-------------------")
        state["expert_question"] = response.question
        state['previous_messages'].append({'role': 'assistant', 'content': response.question})
        state['full_messages'].append(response.question)
        return state

    def replace_references(self, text: str, references_list: List[int]) -> str:
        """Replace each bracketed reference in *text* with the next globally
        unique citation number, advancing ``self.num_search_result``.

        NOTE(review): if a freshly assigned number equals a later index in
        references_list the second replacement can collide with the first —
        confirm callers only pass indices below the current counter.
        """
        for idx in references_list:
            text = text.replace(f"[{idx}]", f"[{self.num_search_result}]")
            self.num_search_result += 1
        return text

    def answer_question(self, persona, state: "ContentState"):
        """Answer the pending expert question from fresh web-search results.

        Generator method: delegates to ``fetch_search_results`` via
        ``yield from`` so search-progress events stream to the caller/UI.
        The updated state is the generator's return value.
        """
        queries = call_llm(
            instructions=prompts.IMPROVE_CONTENT_CREATE_QUERY_INSTRUCTIONS,
            model_type='fast',
            context={
                "section_topic": self.section_topic,
                "expert_question": state["expert_question"],
                'persona': persona.persona
            },
            response_model=Queries,
            logging_fn="improve_content_create_query"
        )
        # Hit the search engine to fetch relevant documents.
        search_results, search_results_list = yield from fetch_search_results(
            queries.queries, self.task_status, self.section_topic, self.update_ui_fn)
        if search_results_list == []:
            # First search came back empty: regenerate queries once and retry.
            queries = call_llm(
                instructions=prompts.IMPROVE_CONTENT_CREATE_QUERY_INSTRUCTIONS,
                model_type='fast',
                context={
                    "section_topic": self.section_topic,
                    "expert_question": state["expert_question"],
                    'persona': persona.persona
                },
                response_model=Queries,
                logging_fn="improve_content_create_query_fallback"
            )
            search_results, search_results_list = yield from fetch_search_results(
                queries.queries, self.task_status, self.section_topic, self.update_ui_fn)
        response = call_llm(
            instructions=prompts.IMPORVE_CONTENT_ANSWER_QUERY_INSTRUCTION,
            model_type='rag',
            context={
                "section_topic": self.section_topic,
                "expert_question": state["expert_question"],
                "search_results": search_results,
                'persona': persona.persona
            },
            response_model=StrucutredAnswer,
            logging_fn="improve_content_answer_query"
        )
        state["content"] = response.answer_response
        # NOTE(review): this formatted-references value is computed but never
        # used below — confirm whether it can be dropped or should be returned.
        references_used = format_search_results(
            [search_results_list[i - 1] for i in response.references_used])
        # Collect every bracketed reference number in the search results; each
        # entry is expected to look like "[n] Title: ...".
        bracketed_refs = re.findall(r'\[(\d+)\](?=\s*Title:)', search_results)
        # Split grouped citations like "[2,3,4]" into "[2][3][4]" so each
        # number can be renumbered independently below.
        cited_references_raw = re.findall(r'\[(\d+(?:,\s*\d+)*)\]', response.answer_response)
        for group in cited_references_raw:
            nums_list = group.split(',')
            new_string = ''.join(f'[{n.strip()}]' for n in nums_list)
            old_string = f'[{group}]'
            response.answer_response = response.answer_response.replace(old_string, new_string)
        # Renumber per-search references into the global citation sequence,
        # keeping the answer text and the reference list consistent.
        for ref in bracketed_refs:
            search_results = search_results.replace(f'[{ref}]', f"[{self.num_search_result}]")
            response.answer_response = response.answer_response.replace(f'[{ref}]', f"[{self.num_search_result}]")
            self.num_search_result += 1
        ui.system_sub_update("-------------------")
        ui.system_sub_update('Content:')
        ui.system_sub_update(response.answer_response)
        ui.system_sub_update("-------------------")
        state['previous_messages'].append({'role': 'user', 'content': response.answer_response})
        state['full_messages'].append(response.answer_response)
        state['references'] = state['references'] + '\n\n' + search_results
        state["iteration"] += 1
        return state

    def create_and_run_interview(self, task_status, update_ui_fn):
        """Runs an iterative process of generating questions and answers
        until the iteration limit is reached.

        Generator method: drives answer_question with ``yield from`` so
        search/UI events stream outward. Returns the combined discussion
        messages and stores the last persona's state in ``self.final_state``
        for ``generate_final_section``.
        """
        self.task_status = task_status
        self.update_ui_fn = update_ui_fn
        discussion_messages = []
        for persona in self.personas:
            ui.system_update(f"Starting discussion with : {persona.name}: {persona.role}, {persona.affiliation}")
            state = self.create_initial_state()
            # NOTE(review): '<=' runs num_interview_rounds + 1 rounds
            # (iteration counts 0..num_interview_rounds) — confirm intended.
            while state["iteration"] <= self.num_interview_rounds:
                state = self.expert_question_generator(persona, state)
                state = yield from self.answer_question(persona, state)
            discussion_messages.extend(state['previous_messages'])
            self.final_state = state
        return discussion_messages

    def generate_final_section(self, synopsis):
        """Return (section_text, references) accumulated during the interview.

        Must be called after create_and_run_interview has completed, since it
        reads ``self.final_state``. NOTE(review): *synopsis* is currently
        unused; kept for interface compatibility.
        """
        return '\n\n'.join(self.final_state['full_messages']), self.final_state['references']

    def warm_start_discussion(self):
        """Warm start the discussion with existing personas.

        Simulates a short roundtable: each generated message names the next
        speaker, who takes the following turn. Returns the transcript.
        """
        messages = [f"{self.personas[0].name}: Hi! Let's get started!"]
        selected_persona = random.choice(self.personas)
        for _ in range(self.warm_start_rounds):
            # Keep the prompt short: only the last 5 messages.
            recent_messages = messages[-5:] if len(messages) > 5 else messages
            message = call_llm(
                instructions=prompts.ROUNDTABLE_DISCUSSION_INSTRUCTIONS,
                model_type='fast',
                context={
                    "persona_name": selected_persona.name,
                    "persona_role": selected_persona.role,
                    "persona_affiliation": selected_persona.affiliation,
                    "persona_focus": selected_persona.focus,
                    "personas":
                        "\n\n".join([p.name + '\n' + p.persona for p in self.personas if p != selected_persona]),
                    "discussion": "\n\n".join(recent_messages)
                },
                response_model=RoundtableMessage,
                logging_fn="roundtable_discussion"
            )
            ui.system_sub_update("\n\n" + selected_persona.name + ": " + message.response + '\n' + message.follow_up)
            messages.append(selected_persona.name + ": " + message.response + '\n' + message.follow_up)
            # NOTE(review): raises IndexError when next_persona matches no
            # persona's name exactly — consider a fallback speaker.
            selected_persona = [p for p in self.personas if p.name == message.next_persona][0]
            time.sleep(3)  # crude rate limiting between LLM calls
        return messages
if __name__ == "__main__":
    # Smoke test. Fixed: the original passed plain strings as personas (the
    # class dereferences persona.name/.role/.affiliation/.persona), called
    # create_and_run_interview() and generate_final_section() without their
    # required arguments (TypeError), and never drove the returned generator,
    # so no work actually ran. Requires live LLM/search credentials.
    from types import SimpleNamespace

    section_name = 'Glean Search in the Enterprise Search Market'
    section_description = 'Positioning and Competition'
    section_key_questions = [
        'how is glean positioned in the enterprise search market?',
        "who are the main competitors in this space?",
    ]
    personas = [
        SimpleNamespace(
            name='Business Analyst',
            role='Business Analyst',
            affiliation='Enterprise Software Consultant',
            focus='AI-powered productivity and knowledge-management tooling',
            persona=(
                '\nRole: Business Analyst\nAffiliation: Enterprise Software Consultant\n'
                'Description: Specializes in helping organizations implement and optimize '
                'AI-powered tools for improved productivity and knowledge management. '
                'Will analyze Glean and Copilot from a business user perspective.\n'
            ),
        )
    ]
    improve_content = ImproveContent(section_name, section_description, section_key_questions, personas)
    # create_and_run_interview is a generator (it streams search/UI events);
    # drive it to completion and capture its return value from StopIteration.
    interview = improve_content.create_and_run_interview(
        task_status=None, update_ui_fn=lambda *args, **kwargs: None)
    try:
        while True:
            next(interview)
    except StopIteration as done:
        discussion_messages = done.value
    improved_content, references = improve_content.generate_final_section(None)
    print(improved_content)