from typing import List, TypedDict | |
from llm_config import get_llm_instructor, call_llm | |
from pydantic import BaseModel, Field | |
import ui | |
import prompts | |
from search import fetch_search_results, format_search_results | |
import random | |
import time | |
from dotenv import load_dotenv | |
import re | |
load_dotenv() | |
class RoundtableMessage(BaseModel):
    """Structured LLM output for one turn of the warm-start roundtable:
    the speaker's reply, a follow-up question, and who it is addressed to."""
    response: str = Field(..., title="Your response")
    follow_up: str = Field(..., title="Your follow-up question")
    # Should match a persona's .name — warm_start_discussion looks the next
    # speaker up by exact name equality.
    next_persona: str = Field(..., title="Who you are asking the question to")
class ContentState(TypedDict):
    """State dict threaded through one persona's interview loop.

    Fix: the original declared the last key as ``refernces`` (typo), but the
    runtime dict is keyed ``'references'`` everywhere it is read and written,
    so the annotation never described the real state.
    """
    previous_messages: List[dict]  # chat history as {'role': ..., 'content': ...} dicts
    content: str                   # latest generated answer text
    expert_question: str           # most recent question posed to the expert
    iteration: int                 # number of completed Q&A rounds
    full_messages: List[str]       # every question and answer, in order
    references: str                # accumulated formatted search results
class Queries(BaseModel):
    """Structured LLM output: search-engine queries for the current expert question."""
    queries: List[str] = Field(..., title="List of queries to search for")
class PersonaQuestion(BaseModel):
    """Structured LLM output: the persona's next question for the expert."""
    question: str = Field(..., title="Your question for the expert")
class StrucutredAnswer(BaseModel):
    """Structured LLM output: a cited answer plus the 1-based indices of the
    search results it cites.

    NOTE(review): class name misspells "Structured"; kept as-is because the
    name is used as a response_model elsewhere in this file.
    """
    answer_response: str = Field(..., title="The response to the question with citations")
    references_used: List[int] = Field(..., title="The references used to answer the question")
class ImproveContent:
    """Iteratively improve a report section by interviewing LLM personas.

    Each persona asks questions about the section topic; answers are grounded
    in web-search results and accumulated (with renumbered citations) into the
    final section text plus a references appendix.
    """

    def __init__(self, section_topic, section_description, section_key_questions, personas):
        self.section_topic = section_topic
        self.section_description = section_description
        self.section_key_questions = section_key_questions
        self.client = get_llm_instructor()
        # Running counter used to give every citation a globally unique number.
        self.num_search_result = 1
        self.num_interview_rounds = 3
        self.personas = personas
        self.warm_start_rounds = 10

    def create_initial_state(self) -> 'ContentState':
        """Return a fresh interview state for one persona.

        Fix: the original omitted the ``content`` key declared in ContentState.
        """
        return {
            "expert_question": "",
            "iteration": 0,
            "content": "",
            'previous_messages': [],
            'full_messages': [],
            'references': '',
        }

    def expert_question_generator(self, persona, state: 'ContentState') -> 'ContentState':
        """Have `persona` generate the next question for the expert and log it."""
        response = call_llm(
            instructions=prompts.QUALITY_CHECKER_INSTRUCTIONS,
            additional_messages=state['previous_messages'],
            context={
                "title_description": self.section_description + ":" + self.section_topic,
                "key_questions": self.section_key_questions,
                'persona': persona.persona,
            },
            response_model=PersonaQuestion,
            logging_fn="quality_checker",
        )
        ui.system_sub_update("-------------------")
        ui.system_sub_update(f'{persona.name} ({persona.role},{persona.affiliation}):')
        ui.system_sub_update(response.question)
        ui.system_sub_update("-------------------")
        state["expert_question"] = response.question
        state['previous_messages'].append({'role': 'assistant', 'content': response.question})
        state['full_messages'].append(response.question)
        return state

    def replace_references(self, text: str, references_list: List[int]) -> str:
        """Replace each bracketed reference in `text` with the next globally
        unique citation number, advancing ``self.num_search_result``."""
        for idx in references_list:
            text = text.replace(f"[{idx}]", f"[{self.num_search_result}]")
            self.num_search_result += 1
        return text

    def _generate_search_queries(self, persona, state, logging_fn):
        """Ask a fast model for search queries answering the pending question."""
        return call_llm(
            instructions=prompts.IMPROVE_CONTENT_CREATE_QUERY_INSTRUCTIONS,
            model_type='fast',
            context={
                "section_topic": self.section_topic,
                "expert_question": state["expert_question"],
                'persona': persona.persona,
            },
            response_model=Queries,
            logging_fn=logging_fn,
        )

    def answer_question(self, persona, state: 'ContentState'):
        """Answer the pending expert question, grounded in web-search results.

        Generator: delegates to ``fetch_search_results`` (which yields progress
        updates) and returns the updated state via StopIteration.
        """
        queries = self._generate_search_queries(persona, state, "improve_content_create_query")
        # Hit the search engine to fetch relevant documents.
        search_results, search_results_list = yield from fetch_search_results(
            queries.queries, self.task_status, self.section_topic, self.update_ui_fn)
        if not search_results_list:
            # Retry once with freshly generated queries if the search came back empty.
            queries = self._generate_search_queries(
                persona, state, "improve_content_create_query_fallback")
            search_results, search_results_list = yield from fetch_search_results(
                queries.queries, self.task_status, self.section_topic, self.update_ui_fn)
        response = call_llm(
            instructions=prompts.IMPORVE_CONTENT_ANSWER_QUERY_INSTRUCTION,
            model_type='rag',
            context={
                "section_topic": self.section_topic,
                "expert_question": state["expert_question"],
                "search_results": search_results,
                'persona': persona.persona,
            },
            response_model=StrucutredAnswer,
            logging_fn="improve_content_answer_query",
        )
        state["content"] = response.answer_response
        # NOTE(review): this formatted subset is computed but never used below —
        # confirm whether it should be stored. Indices are guarded so that
        # out-of-range citation numbers from the model no longer raise IndexError.
        references_used = format_search_results([
            search_results_list[i - 1]
            for i in response.references_used
            if 1 <= i <= len(search_results_list)
        ])
        # All bracketed numbers that label a search result ("[n] Title: ...").
        bracketed_refs = re.findall(r'\[(\d+)\](?=\s*Title:)', search_results)
        # Expand grouped citations like [2,3,4] into [2][3][4] so each number
        # can be renumbered independently below.
        cited_references_raw = re.findall(r'\[(\d+(?:,\s*\d+)*)\]', response.answer_response)
        for group in cited_references_raw:
            expanded = ''.join(f'[{n.strip()}]' for n in group.split(','))
            response.answer_response = response.answer_response.replace(f'[{group}]', expanded)
        # Renumber every reference so citation numbers are unique across the run.
        for ref in bracketed_refs:
            search_results = search_results.replace(f'[{ref}]', f"[{self.num_search_result}]")
            response.answer_response = response.answer_response.replace(
                f'[{ref}]', f"[{self.num_search_result}]")
            self.num_search_result += 1
        ui.system_sub_update("-------------------")
        ui.system_sub_update('Content:')
        ui.system_sub_update(response.answer_response)
        ui.system_sub_update("-------------------")
        state['previous_messages'].append({'role': 'user', 'content': response.answer_response})
        state['full_messages'].append(response.answer_response)
        state['references'] = state['references'] + '\n\n' + search_results
        state["iteration"] += 1
        return state

    def create_and_run_interview(self, task_status, update_ui_fn):
        """Run the question/answer loop with every persona.

        Generator: yields search/UI progress events and returns the combined
        discussion messages. Fix: the original ``<=`` ran
        ``num_interview_rounds + 1`` rounds per persona.
        """
        self.task_status = task_status
        self.update_ui_fn = update_ui_fn
        discussion_messages = []
        for persona in self.personas:
            ui.system_update(f"Starting discussion with : {persona.name}: {persona.role}, {persona.affiliation}")
            state = self.create_initial_state()
            while state["iteration"] < self.num_interview_rounds:
                state = self.expert_question_generator(persona, state)
                state = yield from self.answer_question(persona, state)
            discussion_messages.extend(state['previous_messages'])
            self.final_state = state
        return discussion_messages

    def generate_final_section(self, synopsis):
        """Return (section_text, references) from the last persona's state.

        ``synopsis`` is accepted for interface compatibility but unused here.
        """
        return '\n\n'.join(self.final_state['full_messages']), self.final_state['references']

    def warm_start_discussion(self):
        """Warm start the discussion with the existing personas."""
        messages = [f"{self.personas[0].name}: Hi! Let's get started!"]
        selected_persona = random.choice(self.personas)
        for _ in range(self.warm_start_rounds):
            # Keep only the last five turns as conversational context.
            recent_messages = messages[-5:]
            message = call_llm(
                instructions=prompts.ROUNDTABLE_DISCUSSION_INSTRUCTIONS,
                model_type='fast',
                context={
                    "persona_name": selected_persona.name,
                    "persona_role": selected_persona.role,
                    "persona_affiliation": selected_persona.affiliation,
                    "persona_focus": selected_persona.focus,
                    "personas": "\n\n".join(
                        p.name + '\n' + p.persona
                        for p in self.personas if p != selected_persona),
                    "discussion": "\n\n".join(recent_messages),
                },
                response_model=RoundtableMessage,
                logging_fn="roundtable_discussion",
            )
            turn = selected_persona.name + ": " + message.response + '\n' + message.follow_up
            ui.system_sub_update("\n\n" + turn)
            messages.append(turn)
            # Hand off to whoever the model addressed; fall back to a random
            # persona when the name matches nobody (was an IndexError).
            matches = [p for p in self.personas if p.name == message.next_persona]
            selected_persona = matches[0] if matches else random.choice(self.personas)
            time.sleep(3)
        return messages
if __name__ == "__main__":
    from types import SimpleNamespace

    section_name = 'Glean Search in the Enterprise Search Market'
    section_description = 'Positioning and Competition'
    section_key_questions = ['how is glean positioned in the enterprise search market?', "who are the main competitors in this space?"]
    # The interview code reads persona.name/.role/.affiliation/.focus/.persona,
    # so a bare string (as originally passed) would crash; build a minimal
    # persona object instead.
    persona_text = '\nRole: Business Analyst\nAffiliation: Enterprise Software Consultant\nDescription: Specializes in helping organizations implement and optimize AI-powered tools for improved productivity and knowledge management. Will analyze Glean and Copilot from a business user perspective.\n'
    personas = [SimpleNamespace(
        name='Business Analyst',
        role='Business Analyst',
        affiliation='Enterprise Software Consultant',
        focus='AI-powered productivity and knowledge management tools',
        persona=persona_text,
    )]
    improve_content = ImproveContent(section_name, section_description, section_key_questions, personas)
    # create_and_run_interview is a generator (it yields search/UI progress) and
    # requires task_status and update_ui_fn — the original called it with no
    # arguments, which raised TypeError and would not have executed anyway.
    # Drive it to completion and take the return value from StopIteration.
    interview = improve_content.create_and_run_interview(task_status=None, update_ui_fn=print)
    improved_content = None
    try:
        while True:
            next(interview)
    except StopIteration as stop:
        improved_content = stop.value
    # generate_final_section requires a synopsis argument (unused internally).
    section_text, references = improve_content.generate_final_section(synopsis=None)
    print(improved_content)