File size: 10,181 Bytes
8fd59af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
from typing import List, TypedDict
from llm_config import get_llm_instructor, call_llm
from pydantic import BaseModel, Field
import ui
import prompts
from search import fetch_search_results, format_search_results
import random
import time
from dotenv import load_dotenv
import re


load_dotenv()

class RoundtableMessage(BaseModel):
    # Structured LLM output for one round-table turn: the speaking persona's
    # reply, the follow-up question it poses, and the name of the persona it
    # addresses next. (Documented with comments rather than a docstring: a model
    # docstring would presumably be picked up as the schema description that is
    # sent to the LLM.)
    response: str = Field(default=..., title="Your response")
    follow_up: str = Field(default=..., title="Your follow-up question")
    next_persona: str = Field(default=..., title="Who you are asking the question to")

class ContentState(TypedDict):
    # Per-persona interview state threaded through the question/answer loop.
    # Chat-style turns ({'role': ..., 'content': ...}) fed back to the LLM.
    previous_messages: List[dict]
    # Most recent answer text.
    content: str
    # Question currently being answered.
    expert_question: str
    # Number of completed question/answer rounds.
    iteration: int
    # Plain text of every question and answer, in order.
    full_messages: List[str]
    # Concatenated search results backing the answers' citations.
    # (Key spelled to match how the state is actually built and read.)
    references: str

class Queries(BaseModel):
    # Structured LLM output: the search-engine queries to run for one expert question.
    queries: List[str] = Field(default=..., title="List of queries to search for")

class PersonaQuestion(BaseModel):
    # Structured LLM output: the single question a persona puts to the expert.
    question: str = Field(default=..., title="Your question for the expert")

class StrucutredAnswer(BaseModel):
    # Structured LLM output for a researched answer. The misspelled class name
    # is kept because other code refers to it by this name.
    # Answer text containing bracketed numeric citations such as [1] or [2, 3].
    answer_response: str = Field(default=..., title="The response to the question with citations")
    # Indices of the search results the answer cites (presumably 1-based — confirm).
    references_used: List[int] = Field(default=..., title="The references used to answer the question")

        



class ImproveContent:
    """Run persona-driven interviews that research and draft one report section.

    For every persona, the LLM (speaking as that persona) asks questions about
    the section, and each question is answered from fresh search results.
    Citation numbers are drawn from ``self.num_search_result``, which is shared
    across all personas, so the answers and references of every interview can
    be concatenated into one section without clashing numbers.
    """

    def __init__(self, section_topic, section_description, section_key_questions, personas):
        """Store the section spec and interview settings.

        Args:
            section_topic: Topic/title of the section being written.
            section_description: Short description of the section.
            section_key_questions: Questions the section should address.
            personas: Persona objects; the methods read their ``name``, ``role``,
                ``affiliation`` and ``persona`` attributes (and ``focus`` in
                ``warm_start_discussion``).
        """
        self.section_topic = section_topic
        self.section_description = section_description
        self.section_key_questions = section_key_questions

        self.client = get_llm_instructor()
        # Next citation number to hand out; never reset, so numbers stay unique
        # across all rounds and personas.
        self.num_search_result = 1
        # Question/answer rounds per persona.
        self.num_interview_rounds = 3
        self.personas = personas
        self.warm_start_rounds = 10
        # Set by create_and_run_interview; placeholders keep attribute access safe.
        self.task_status = None
        self.update_ui_fn = None
        # Aggregated result of the last interview run; empty until one completes.
        self.final_state = self.create_initial_state()

    # Define the initial state
    def create_initial_state(self) -> "ContentState":
        """Return a fresh, empty interview state (new lists on every call)."""
        return {
            "expert_question": "",
            "iteration": 0,
            "content": "",
            "previous_messages": [],
            "full_messages": [],
            "references": "",
        }

    def expert_question_generator(self, persona, state: "ContentState") -> "ContentState":
        """Have the LLM, speaking as *persona*, ask the next question about the section.

        The question is shown in the UI, stored as ``state["expert_question"]``,
        and appended to ``previous_messages`` (as an assistant turn) and
        ``full_messages``. Returns the same, mutated *state*.
        """
        response = call_llm(
            instructions=prompts.QUALITY_CHECKER_INSTRUCTIONS,
            additional_messages=state['previous_messages'],
            context={
                "title_description": self.section_description + ":" + self.section_topic,
                "key_questions": self.section_key_questions,
                'persona': persona.persona
            },
            response_model=PersonaQuestion,
            logging_fn="quality_checker"
        )
        ui.system_sub_update("-------------------")
        ui.system_sub_update(f'{persona.name} ({persona.role},{persona.affiliation}):')
        ui.system_sub_update(response.question)
        ui.system_sub_update("-------------------")
        state["expert_question"] = response.question
        state['previous_messages'].append({'role': 'assistant', 'content': response.question})
        state['full_messages'].append(response.question)
        return state

    # ----------------------------------------------------------------- citations

    def _allocate_citation_numbers(self, refs) -> dict:
        """Map each distinct local reference (as a string) to the next global number.

        Advances ``self.num_search_result`` once per distinct reference, in order
        of first appearance.
        """
        mapping = {}
        for ref in refs:
            key = str(ref)
            if key not in mapping:
                mapping[key] = self.num_search_result
                self.num_search_result += 1
        return mapping

    @staticmethod
    def _apply_citation_numbers(text: str, mapping: dict) -> str:
        """Rewrite every ``[n]`` marker whose ``n`` is in *mapping*, in one pass.

        A single regex pass is required: chained ``str.replace`` calls would
        re-number a marker that an earlier replacement had just produced
        (e.g. ``[1] -> [4]`` followed by ``[4] -> [7]``).
        """
        if not mapping:
            return text
        return re.sub(
            r'\[(\d+)\]',
            lambda m: f"[{mapping[m.group(1)]}]" if m.group(1) in mapping else m.group(0),
            text,
        )

    @staticmethod
    def _split_grouped_citations(text: str) -> str:
        """Rewrite grouped citations such as ``[2, 3,4]`` into ``[2][3][4]``."""
        return re.sub(
            r'\[(\d+(?:,\s*\d+)+)\]',
            lambda m: ''.join(f'[{n.strip()}]' for n in m.group(1).split(',')),
            text,
        )

    def replace_references(self, text: str, references_list: List[int]) -> str:
        """Replace bracketed references in *text* with globally unique numbers.

        Each distinct index in *references_list* gets the next value of
        ``self.num_search_result`` (which is advanced); all markers are then
        rewritten in a single pass so renumbered markers are never touched again.
        """
        mapping = self._allocate_citation_numbers(references_list)
        return self._apply_citation_numbers(text, mapping)

    # ------------------------------------------------------------------ answering

    def _generate_queries(self, persona, state: "ContentState", logging_fn: str):
        """Ask the fast model for search queries targeting the current expert question."""
        return call_llm(
            instructions=prompts.IMPROVE_CONTENT_CREATE_QUERY_INSTRUCTIONS,
            model_type='fast',
            context={
                "section_topic": self.section_topic,
                "expert_question": state["expert_question"],
                'persona': persona.persona
            },
            response_model=Queries,
            logging_fn=logging_fn
        )

    def answer_question(self, persona, state: "ContentState"):
        """Generator: research and answer ``state["expert_question"]``.

        Yields whatever ``fetch_search_results`` yields while searching, and
        finally returns the updated *state*. The answer's citations are
        renumbered to globally unique numbers, and the same renumbering is
        applied to the search results appended to ``state["references"]``.
        """
        # Hit the search engine to fetch relevant documents.
        queries = self._generate_queries(persona, state, "improve_content_create_query")
        search_results, search_results_list = yield from fetch_search_results(
            queries.queries, self.task_status, self.section_topic, self.update_ui_fn)

        # Nothing found: retry once with freshly generated queries.
        if not search_results_list:
            queries = self._generate_queries(persona, state, "improve_content_create_query_fallback")
            search_results, search_results_list = yield from fetch_search_results(
                queries.queries, self.task_status, self.section_topic, self.update_ui_fn)

        response = call_llm(
            instructions=prompts.IMPORVE_CONTENT_ANSWER_QUERY_INSTRUCTION,
            model_type='rag',
            context={
                "section_topic": self.section_topic,
                "expert_question": state["expert_question"],
                "search_results": search_results,
                'persona': persona.persona
            },
            response_model=StrucutredAnswer,
            logging_fn="improve_content_answer_query"
        )

        # Local reference numbers as they label entries in the search results
        # (a bracketed number followed by "Title:").
        local_refs = re.findall(r'\[(\d+)\](?=\s*Title:)', search_results)
        answer = self._split_grouped_citations(response.answer_response)
        mapping = self._allocate_citation_numbers(local_refs)
        search_results = self._apply_citation_numbers(search_results, mapping)
        answer = self._apply_citation_numbers(answer, mapping)

        ui.system_sub_update("-------------------")
        ui.system_sub_update('Content:')
        ui.system_sub_update(answer)
        ui.system_sub_update("-------------------")
        # Store the renumbered answer so it agrees with state['references'].
        state["content"] = answer
        state['previous_messages'].append({'role': 'user', 'content': answer})
        state['full_messages'].append(answer)
        state['references'] = state['references'] + '\n\n' + search_results
        state["iteration"] += 1

        return state

    # ------------------------------------------------------------------ interview

    def create_and_run_interview(self, task_status, update_ui_fn):
        """Generator: interview every persona and collect the discussion.

        Each persona gets exactly ``self.num_interview_rounds`` question/answer
        rounds. Yields the search progress items produced by ``answer_question``
        and returns the chat messages of all interviews, in persona order. The
        combined questions, answers and references of *all* personas are kept
        in ``self.final_state`` for ``generate_final_section``.
        """
        self.task_status = task_status
        self.update_ui_fn = update_ui_fn
        discussion_messages = []
        combined = self.create_initial_state()
        for persona in self.personas:
            ui.system_update(f"Starting discussion with : {persona.name}: {persona.role}, {persona.affiliation}")
            state = self.create_initial_state()
            while state["iteration"] < self.num_interview_rounds:
                state = self.expert_question_generator(persona, state)
                state = yield from self.answer_question(persona, state)
            discussion_messages.extend(state['previous_messages'])
            # Citation numbers are globally unique, so the per-persona outputs
            # can simply be concatenated.
            combined['previous_messages'].extend(state['previous_messages'])
            combined['full_messages'].extend(state['full_messages'])
            combined['references'] += state['references']
            combined['iteration'] += state['iteration']
            combined['expert_question'] = state['expert_question']
            combined['content'] = state.get('content', '')
        self.final_state = combined
        return discussion_messages

    def generate_final_section(self, synopsis=None):
        """Return ``(section_text, references)`` from the last interview run.

        ``section_text`` joins every recorded question and answer with blank
        lines. *synopsis* is currently unused and optional.
        """
        return '\n\n'.join(self.final_state['full_messages']), self.final_state['references']

    def warm_start_discussion(self):
        """Warm start the discussion with a simulated round-table between the personas.

        Runs ``self.warm_start_rounds`` turns; each turn the current speaker's
        reply plus follow-up is shown in the UI and recorded, and the speaker
        chosen by the model talks next. If the model names someone not at the
        table, a random other persona takes over. Returns the list of
        ``"Name: response\\nfollow-up"`` strings (after an opening line), or an
        empty list when there are no personas.
        """
        if not self.personas:
            return []

        # First occurrence wins for duplicate names, matching a list search.
        personas_by_name = {p.name: p for p in reversed(self.personas)}
        messages = [f"{self.personas[0].name}: Hi! Let's get started!"]
        speaker = random.choice(self.personas)
        for _ in range(self.warm_start_rounds):
            message = call_llm(
                instructions=prompts.ROUNDTABLE_DISCUSSION_INSTRUCTIONS,
                model_type='fast',
                context={
                    "persona_name": speaker.name,
                    "persona_role": speaker.role,
                    "persona_affiliation": speaker.affiliation,
                    "persona_focus": speaker.focus,
                    "personas":
                        "\n\n".join([p.name + '\n' + p.persona for p in self.personas if p != speaker]),
                    # Only the last 5 messages are given as context.
                    "discussion": "\n\n".join(messages[-5:])
                },
                response_model=RoundtableMessage,
                logging_fn="roundtable_discussion"
            )
            turn = speaker.name + ": " + message.response + '\n' + message.follow_up
            ui.system_sub_update("\n\n" + turn)
            messages.append(turn)

            next_speaker = personas_by_name.get(message.next_persona.strip())
            if next_speaker is None:
                # Unknown name from the model: hand over instead of crashing.
                others = [p for p in self.personas if p != speaker] or self.personas
                next_speaker = random.choice(others)
            speaker = next_speaker
            time.sleep(3)
        return messages



if __name__ == "__main__":
    # Ad-hoc manual run of a single-persona interview.
    from types import SimpleNamespace

    section_name = 'Glean Search in the Enterprise Search Market'
    section_description = 'Positioning and Competition'
    section_key_questions = ['how is glean positioned in the enterprise search market?', "who are the main competitors in this space?"]
    persona_description = '\nRole: Business Analyst\nAffiliation: Enterprise Software Consultant\nDescription: Specializes in helping organizations implement and optimize AI-powered tools for improved productivity and knowledge management. Will analyze Glean and Copilot from a business user perspective.\n'
    # ImproveContent reads .name/.role/.affiliation/.persona (and .focus) from each
    # persona, so a bare description string cannot be used directly.
    personas = [
        SimpleNamespace(
            name='Business Analyst',
            role='Business Analyst',
            affiliation='Enterprise Software Consultant',
            focus='AI-powered productivity and knowledge-management tools',
            persona=persona_description,
        )
    ]
    improve_content = ImproveContent(section_name, section_description, section_key_questions, personas)

    # create_and_run_interview is a generator (it `yield from`s the search step):
    # drive it to completion; its return value arrives on StopIteration.
    # NOTE(review): None / no-op stand-ins for task_status and update_ui_fn are
    # assumed acceptable to fetch_search_results — confirm against search.py.
    interview = improve_content.create_and_run_interview(None, lambda *args, **kwargs: None)
    discussion_messages = []
    while True:
        try:
            next(interview)
        except StopIteration as stop:
            discussion_messages = stop.value
            break

    improved_content, references = improve_content.generate_final_section(None)
    print(improved_content)
    print(references)