Update app.py
Browse files
app.py
CHANGED
@@ -41,7 +41,11 @@ class Agent1:
|
|
41 |
chain = LLMChain(llm=self.model, prompt=rephrase_prompt)
|
42 |
response = chain.run(query=user_input).strip()
|
43 |
|
44 |
-
|
|
|
|
|
|
|
|
|
45 |
|
46 |
def process(self, user_input: str) -> Dict[str, List[Dict[str, str]]]:
|
47 |
queries = self.rephrase_and_split(user_input)
|
@@ -50,39 +54,6 @@ class Agent1:
|
|
50 |
results[query] = google_search(query)
|
51 |
return results
|
52 |
|
53 |
-
class Agent2:
    """Validation agent: judges whether a generated response fully answers the
    user's query, and produces a follow-up search query when it does not.

    Relies on `PromptTemplate` and `LLMChain` imported at module level.
    """

    def __init__(self, model):
        # model: any LangChain-compatible LLM; stored for use by both chains.
        self.model = model

    def validate_response(self, user_query: str, response: str) -> bool:
        """Ask the LLM whether `response` fully answers `user_query`.

        Returns True when the model's (normalized) answer begins with "yes".

        Fix over the original: the previous strict equality `result == 'yes'`
        rejected common model outputs such as "Yes." or "Yes, it does",
        silently marking valid answers as failures. A prefix check after
        lower-casing accepts all of those while still returning False for
        "no"-style answers. Backward compatible: every input that previously
        evaluated True still does.
        """
        validation_prompt = PromptTemplate(
            input_variables=["query", "response"],
            template="""
Evaluate if the following response fully answers the user's query.
User query: {query}
Response: {response}

Does the response fully answer the query? Answer with Yes or No:"""
        )

        chain = LLMChain(llm=self.model, prompt=validation_prompt)
        result = chain.run(query=user_query, response=response).strip().lower()
        # Prefix match tolerates trailing punctuation/elaboration ("yes.", "yes, ...").
        return result.startswith('yes')

    def generate_follow_up_query(self, user_query: str, response: str) -> str:
        """Generate a follow-up search query when `response` was judged
        insufficient for `user_query`.

        Returns the model's suggested query, whitespace-stripped.
        """
        follow_up_prompt = PromptTemplate(
            input_variables=["query", "response"],
            template="""
The following response did not fully answer the user's query.
User query: {query}
Response: {response}

Generate a follow-up query to get more relevant information:"""
        )

        chain = LLMChain(llm=self.model, prompt=follow_up_prompt)
        return chain.run(query=user_query, response=response).strip()
|
85 |
-
|
86 |
def load_document(file: NamedTemporaryFile) -> List[Document]:
|
87 |
"""Loads and splits the document into pages."""
|
88 |
loader = PyPDFLoader(file.name)
|
@@ -270,7 +241,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
|
270 |
model = get_model(temperature, top_p, repetition_penalty)
|
271 |
embed = get_embeddings()
|
272 |
agent1 = Agent1(model)
|
273 |
-
agent2 = Agent2(model)
|
274 |
|
275 |
if os.path.exists("faiss_database"):
|
276 |
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
|
@@ -279,7 +249,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
|
279 |
|
280 |
max_attempts = 3
|
281 |
context_reduction_factor = 0.7
|
282 |
-
agent2_max_attempts = 2
|
283 |
|
284 |
for attempt in range(max_attempts):
|
285 |
try:
|
@@ -350,20 +319,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
|
350 |
else:
|
351 |
answer = full_response.strip()
|
352 |
|
353 |
-
for agent2_attempt in range(agent2_max_attempts):
|
354 |
-
if agent2.validate_response(question, answer):
|
355 |
-
break
|
356 |
-
|
357 |
-
if agent2_attempt < agent2_max_attempts - 1:
|
358 |
-
follow_up_query = agent2.generate_follow_up_query(question, answer)
|
359 |
-
follow_up_results = agent1.process(follow_up_query)
|
360 |
-
follow_up_docs = [Document(page_content=result["text"], metadata={"source": result["link"], "query": follow_up_query}) for results in follow_up_results.values() for result in results if result["text"]]
|
361 |
-
database.add_documents(follow_up_docs)
|
362 |
-
context_str += "\n" + "\n".join([f"Follow-up Query: {doc.metadata['query']}\nSource: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in follow_up_docs])
|
363 |
-
formatted_prompt = prompt_val.format(context=context_str, question=question)
|
364 |
-
full_response = generate_chunked_response(model, formatted_prompt)
|
365 |
-
answer = full_response.strip()
|
366 |
-
|
367 |
if web_search:
|
368 |
sources = set(doc.metadata['source'] for doc in web_docs)
|
369 |
sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
|
|
|
41 |
chain = LLMChain(llm=self.model, prompt=rephrase_prompt)
|
42 |
response = chain.run(query=user_input).strip()
|
43 |
|
44 |
+
# Remove any lines that contain instructions or explanations
|
45 |
+
rephrased_queries = [q.strip() for q in response.split('\n') if q.strip() and not q.startswith("Rephrase") and "query" not in q.lower()]
|
46 |
+
|
47 |
+
# If no valid rephrased queries, return the original input
|
48 |
+
return rephrased_queries if rephrased_queries else [user_input]
|
49 |
|
50 |
def process(self, user_input: str) -> Dict[str, List[Dict[str, str]]]:
|
51 |
queries = self.rephrase_and_split(user_input)
|
|
|
54 |
results[query] = google_search(query)
|
55 |
return results
|
56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
57 |
def load_document(file: NamedTemporaryFile) -> List[Document]:
|
58 |
"""Loads and splits the document into pages."""
|
59 |
loader = PyPDFLoader(file.name)
|
|
|
241 |
model = get_model(temperature, top_p, repetition_penalty)
|
242 |
embed = get_embeddings()
|
243 |
agent1 = Agent1(model)
|
|
|
244 |
|
245 |
if os.path.exists("faiss_database"):
|
246 |
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
|
|
|
249 |
|
250 |
max_attempts = 3
|
251 |
context_reduction_factor = 0.7
|
|
|
252 |
|
253 |
for attempt in range(max_attempts):
|
254 |
try:
|
|
|
319 |
else:
|
320 |
answer = full_response.strip()
|
321 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
322 |
if web_search:
|
323 |
sources = set(doc.metadata['source'] for doc in web_docs)
|
324 |
sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
|