Spaces:
Paused
Paused
Shreyas094
committed on
Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ import requests
|
|
7 |
import random
|
8 |
import urllib.parse
|
9 |
from tempfile import NamedTemporaryFile
|
10 |
-
from typing import List
|
11 |
from bs4 import BeautifulSoup
|
12 |
from langchain.prompts import PromptTemplate
|
13 |
from langchain.chains import LLMChain
|
@@ -17,10 +17,72 @@ from langchain_community.document_loaders import PyPDFLoader
|
|
17 |
from langchain_core.output_parsers import StrOutputParser
|
18 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
19 |
from langchain_community.llms import HuggingFaceHub
|
20 |
-
from langchain_core.documents import Document
|
21 |
|
22 |
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
def load_document(file: NamedTemporaryFile) -> List[Document]:
|
25 |
"""Loads and splits the document into pages."""
|
26 |
loader = PyPDFLoader(file.name)
|
@@ -207,6 +269,8 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
|
207 |
|
208 |
model = get_model(temperature, top_p, repetition_penalty)
|
209 |
embed = get_embeddings()
|
|
|
|
|
210 |
|
211 |
if os.path.exists("faiss_database"):
|
212 |
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
|
@@ -219,16 +283,10 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
|
219 |
for attempt in range(max_attempts):
|
220 |
try:
|
221 |
if web_search:
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
if rephrased_query == original_query:
|
228 |
-
print("Warning: Query was not rephrased. Using original query for search.")
|
229 |
-
|
230 |
-
search_results = google_search(rephrased_query)
|
231 |
-
web_docs = [Document(page_content=result["text"], metadata={"source": result["link"]}) for result in search_results if result["text"]]
|
232 |
|
233 |
if database is None:
|
234 |
database = FAISS.from_documents(web_docs, embed)
|
@@ -237,20 +295,17 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
|
237 |
|
238 |
database.save_local("faiss_database")
|
239 |
|
240 |
-
context_str = "\n".join([f"
|
241 |
|
242 |
prompt_template = """
|
243 |
Answer the question based on the following web search results:
|
244 |
Web Search Results:
|
245 |
{context}
|
246 |
-
Original Question: {
|
247 |
-
Rephrased Search Query: {rephrased_query}
|
248 |
If the web search results don't contain relevant information, state that the information is not available in the search results.
|
249 |
Provide a concise and direct answer to the original question without mentioning the web search or these instructions.
|
250 |
Do not include any source information in your answer.
|
251 |
"""
|
252 |
-
prompt_val = ChatPromptTemplate.from_template(prompt_template)
|
253 |
-
formatted_prompt = prompt_val.format(context=context_str, original_question=question, rephrased_query=rephrased_query)
|
254 |
else:
|
255 |
if database is None:
|
256 |
return "No documents available. Please upload documents or enable web search to answer questions."
|
@@ -259,7 +314,6 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
|
259 |
relevant_docs = retriever.get_relevant_documents(question)
|
260 |
context_str = "\n".join([doc.page_content for doc in relevant_docs])
|
261 |
|
262 |
-
# Reduce context if we're not on the first attempt
|
263 |
if attempt > 0:
|
264 |
words = context_str.split()
|
265 |
context_str = " ".join(words[:int(len(words) * context_reduction_factor)])
|
@@ -273,8 +327,9 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
|
273 |
Provide a concise and direct answer to the question.
|
274 |
Do not include any source information in your answer.
|
275 |
"""
|
276 |
-
|
277 |
-
|
|
|
278 |
|
279 |
full_response = generate_chunked_response(model, formatted_prompt)
|
280 |
|
@@ -294,7 +349,16 @@ def ask_question(question, temperature, top_p, repetition_penalty, web_search):
|
|
294 |
else:
|
295 |
answer = full_response.strip()
|
296 |
|
297 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
298 |
if web_search:
|
299 |
sources = set(doc.metadata['source'] for doc in web_docs)
|
300 |
sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
|
|
|
7 |
import random
|
8 |
import urllib.parse
|
9 |
from tempfile import NamedTemporaryFile
|
10 |
+
from typing import List, Dict
|
11 |
from bs4 import BeautifulSoup
|
12 |
from langchain.prompts import PromptTemplate
|
13 |
from langchain.chains import LLMChain
|
|
|
17 |
from langchain_core.output_parsers import StrOutputParser
|
18 |
from langchain_community.embeddings import HuggingFaceEmbeddings
|
19 |
from langchain_community.llms import HuggingFaceHub
|
20 |
+
from langchain_core.documents import Document
|
21 |
|
22 |
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
|
23 |
|
24 |
class Agent1:
    """Query-rephrasing search agent.

    Turns a free-form user question into one or more concise,
    search-engine-friendly queries and runs each of them through
    ``google_search`` (defined elsewhere in this module).
    """

    def __init__(self, model):
        # Any LangChain-compatible LLM; used only for the rephrasing prompt.
        self.model = model

    def rephrase_and_split(self, user_input: str) -> List[str]:
        """Rephrase *user_input* into search-ready queries.

        The model is instructed to emit one query per line; splitting a
        multi-part question into several queries is allowed. Returns the
        non-empty, stripped lines of the model's response (possibly empty
        if the model returned nothing usable).
        """
        rephrase_prompt = PromptTemplate(
            input_variables=["query"],
            template="""
Your task is to rephrase the given query into one or more concise, search-engine-friendly formats.
If the query contains multiple distinct questions, split them.
Provide ONLY the rephrased queries without any additional text or explanations, one per line.

Query: {query}

Rephrased queries:""",
        )

        chain = LLMChain(llm=self.model, prompt=rephrase_prompt)
        response = chain.run(query=user_input).strip()

        return [q.strip() for q in response.split('\n') if q.strip()]

    def process(self, user_input: str) -> Dict[str, List[Dict[str, str]]]:
        """Rephrase the input and run a web search for every resulting query.

        Returns a mapping of rephrased query -> list of search-result dicts
        (each with at least ``"text"`` and ``"link"`` keys, per the
        ``google_search`` contract assumed by callers — TODO confirm).
        """
        queries = self.rephrase_and_split(user_input)
        if not queries:
            # Fix: an empty rephrase response previously yielded an empty
            # result dict, silently producing zero search results. Fall
            # back to searching the user's original wording verbatim.
            queries = [user_input]
        return {query: google_search(query) for query in queries}
|
52 |
+
|
53 |
+
class Agent2:
    """Answer-validation agent.

    Judges whether a generated answer fully addresses the user's query
    and, when it does not, produces a follow-up search query to gather
    the missing information.
    """

    def __init__(self, model):
        # Any LangChain-compatible LLM; used for validation and follow-up prompts.
        self.model = model

    def validate_response(self, user_query: str, response: str) -> bool:
        """Return True if the model judges *response* to fully answer *user_query*."""
        validation_prompt = PromptTemplate(
            input_variables=["query", "response"],
            template="""
Evaluate if the following response fully answers the user's query.
User query: {query}
Response: {response}

Does the response fully answer the query? Answer with Yes or No:""",
        )

        chain = LLMChain(llm=self.model, prompt=validation_prompt)
        result = chain.run(query=user_query, response=response).strip().lower()
        # Fix: strict equality (result == 'yes') false-negatived on common
        # model outputs such as "Yes." or "Yes, it does". Accept any answer
        # that begins with "yes" after lowercasing.
        return result.startswith('yes')

    def generate_follow_up_query(self, user_query: str, response: str) -> str:
        """Ask the model for a follow-up query covering what *response* missed."""
        follow_up_prompt = PromptTemplate(
            input_variables=["query", "response"],
            template="""
The following response did not fully answer the user's query.
User query: {query}
Response: {response}

Generate a follow-up query to get more relevant information:""",
        )

        chain = LLMChain(llm=self.model, prompt=follow_up_prompt)
        return chain.run(query=user_query, response=response).strip()
|
85 |
+
|
86 |
def load_document(file: NamedTemporaryFile) -> List[Document]:
|
87 |
"""Loads and splits the document into pages."""
|
88 |
loader = PyPDFLoader(file.name)
|
|
|
269 |
|
270 |
model = get_model(temperature, top_p, repetition_penalty)
|
271 |
embed = get_embeddings()
|
272 |
+
agent1 = Agent1(model)
|
273 |
+
agent2 = Agent2(model)
|
274 |
|
275 |
if os.path.exists("faiss_database"):
|
276 |
database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
|
|
|
283 |
for attempt in range(max_attempts):
|
284 |
try:
|
285 |
if web_search:
|
286 |
+
search_results = agent1.process(question)
|
287 |
+
web_docs = []
|
288 |
+
for query, results in search_results.items():
|
289 |
+
web_docs.extend([Document(page_content=result["text"], metadata={"source": result["link"], "query": query}) for result in results if result["text"]])
|
|
|
|
|
|
|
|
|
|
|
|
|
290 |
|
291 |
if database is None:
|
292 |
database = FAISS.from_documents(web_docs, embed)
|
|
|
295 |
|
296 |
database.save_local("faiss_database")
|
297 |
|
298 |
+
context_str = "\n".join([f"Query: {doc.metadata['query']}\nSource: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in web_docs])
|
299 |
|
300 |
prompt_template = """
|
301 |
Answer the question based on the following web search results:
|
302 |
Web Search Results:
|
303 |
{context}
|
304 |
+
Original Question: {question}
|
|
|
305 |
If the web search results don't contain relevant information, state that the information is not available in the search results.
|
306 |
Provide a concise and direct answer to the original question without mentioning the web search or these instructions.
|
307 |
Do not include any source information in your answer.
|
308 |
"""
|
|
|
|
|
309 |
else:
|
310 |
if database is None:
|
311 |
return "No documents available. Please upload documents or enable web search to answer questions."
|
|
|
314 |
relevant_docs = retriever.get_relevant_documents(question)
|
315 |
context_str = "\n".join([doc.page_content for doc in relevant_docs])
|
316 |
|
|
|
317 |
if attempt > 0:
|
318 |
words = context_str.split()
|
319 |
context_str = " ".join(words[:int(len(words) * context_reduction_factor)])
|
|
|
327 |
Provide a concise and direct answer to the question.
|
328 |
Do not include any source information in your answer.
|
329 |
"""
|
330 |
+
|
331 |
+
prompt_val = ChatPromptTemplate.from_template(prompt_template)
|
332 |
+
formatted_prompt = prompt_val.format(context=context_str, question=question)
|
333 |
|
334 |
full_response = generate_chunked_response(model, formatted_prompt)
|
335 |
|
|
|
349 |
else:
|
350 |
answer = full_response.strip()
|
351 |
|
352 |
+
if not agent2.validate_response(question, answer):
|
353 |
+
follow_up_query = agent2.generate_follow_up_query(question, answer)
|
354 |
+
follow_up_results = agent1.process(follow_up_query)
|
355 |
+
follow_up_docs = [Document(page_content=result["text"], metadata={"source": result["link"], "query": follow_up_query}) for results in follow_up_results.values() for result in results if result["text"]]
|
356 |
+
database.add_documents(follow_up_docs)
|
357 |
+
context_str += "\n" + "\n".join([f"Follow-up Query: {doc.metadata['query']}\nSource: {doc.metadata['source']}\nContent: {doc.page_content}" for doc in follow_up_docs])
|
358 |
+
formatted_prompt = prompt_val.format(context=context_str, question=question)
|
359 |
+
full_response = generate_chunked_response(model, formatted_prompt)
|
360 |
+
answer = full_response.strip()
|
361 |
+
|
362 |
if web_search:
|
363 |
sources = set(doc.metadata['source'] for doc in web_docs)
|
364 |
sources_section = "\n\nSources:\n" + "\n".join(f"- {source}" for source in sources)
|