Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -59,6 +59,9 @@ uploads_dir = os.path.join(app.root_path,'static', 'uploads')
 
 os.makedirs(uploads_dir, exist_ok=True)
 
+# Initialize global variables for conversation history
+conversation_history = []
+
 defaultEmbeddingModelID = 3
 defaultLLMID=0
 
@@ -201,6 +204,26 @@ def loadKB(fileprovided, urlProvided, uploads_dir, request):
 
 
 def getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,llmID):
+
+    # Retrieve conversation history if available
+    memory = ConversationBufferWindowMemory(k=3, memory_key="history", input_key="question")
+    memory.load_history(conversation_history)
+
+    # chain = RetrievalQA.from_chain_type(
+    #     llm=getLLMModel(llmID),
+    #     chain_type='stuff',
+    #     retriever=getRetriever(vectordb),
+    #     #retriever=vectordb.as_retriever(),
+    #     verbose=False,
+    #     chain_type_kwargs={
+    #         "verbose": False,
+    #         "prompt": createPrompt(customerName, customerDistrict, custDetailsPresent),
+    #         "memory": ConversationBufferWindowMemory(
+    #             k=3,
+    #             memory_key="history",
+    #             input_key="question"),
+    #     }
+    # )
     chain = RetrievalQA.from_chain_type(
         llm=getLLMModel(llmID),
         chain_type='stuff',
@@ -210,10 +233,7 @@ def getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,llmID):
         chain_type_kwargs={
             "verbose": False,
             "prompt": createPrompt(customerName, customerDistrict, custDetailsPresent),
-            "memory": ConversationBufferWindowMemory(
-                k=3,
-                memory_key="history",
-                input_key="question"),
+            "memory": memory,
         }
     )
     return chain
@@ -307,6 +327,10 @@ def aisearch():
 def process_json():
     print(f"\n{'*' * 100}\n")
     print("Request Received >>>>>>>>>>>>>>>>>>", datetime.now().strftime("%H:%M:%S"))
+
+    # Retrieve conversation ID from the request (use any suitable ID)
+    conversation_id = request.json.get('conversation_id', None)
+
     content_type = request.headers.get('Content-Type')
     if content_type == 'application/json':
         requestQuery = request.get_json()
@@ -322,6 +346,14 @@ def process_json():
         selectedLLMID=defaultLLMID
         if "llmID" in requestQuery:
             selectedLLMID=(int) (requestQuery['llmID'])
+
+        # Create a conversation ID-specific history list if not exists
+        conversation_history_id = f"{conversation_id}_history"
+        if conversation_history_id not in globals():
+            globals()[conversation_history_id] = []
+        conversation_history = globals()[conversation_history_id]
+
+
         print("chain initiation")
         chainRAG = getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,selectedLLMID)
         print("chain created")
@@ -332,6 +364,7 @@ def process_json():
         # message = answering(query)
 
         relevantDoc = vectordb.similarity_search_with_score(query, distance_metric="cos", k=3)
+        conversation_history.append(query)
         print("Printing Retriever Docs")
         for doc in getRetriever(vectordb).get_relevant_documents(query):
             searchResult = {}
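A note on the memory wiring in this commit: LangChain's ConversationBufferWindowMemory has no load_history() method, so the memory.load_history(conversation_history) call added to getRAGChain raises AttributeError at request time, which would match the Space's Runtime error status. The supported way to pre-fill the window is save_context(). A minimal sketch, assuming the history is stored as (question, answer) pairs rather than the bare query strings this commit appends:

# Sketch: seeding ConversationBufferWindowMemory from stored turns.
# Assumes history holds (question, answer) pairs; the commit only appends
# bare query strings, so that pairing is an assumption, not the Space's code.
from langchain.memory import ConversationBufferWindowMemory

def build_memory(history):
    memory = ConversationBufferWindowMemory(k=3, memory_key="history", input_key="question")
    for question, answer in history:
        # save_context() records one past turn; there is no load_history().
        memory.save_context({"question": question}, {"output": answer})
    return memory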
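The globals()-based bookkeeping in process_json also deserves a second look. Because conversation_history = globals()[conversation_history_id] only binds a local name inside process_json, getRAGChain still reads the module-level conversation_history initialised in the first hunk, so the per-conversation list never actually reaches the chain. A module-level dict keyed by conversation_id, handed to getRAGChain explicitly, is the idiomatic alternative; the names below are illustrative, not the Space's own:

# Sketch: per-conversation histories in one dict instead of dynamically
# named globals; conversation_histories and history_for are made-up names.
from collections import defaultdict

conversation_histories = defaultdict(list)

def history_for(conversation_id):
    # defaultdict creates the empty list on first access, replacing the
    # "if conversation_history_id not in globals()" check in the commit.
    return conversation_histories[conversation_id]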
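Finally, the last hunk appends only the user query, so even with a working loader the memory would hold questions and no answers. Recording both sides of each exchange keeps the sketches above usable; record_turn and the (question, answer) layout are assumptions rather than code from this Space:

# Sketch: store both sides of each exchange so build_memory() has pairs to load.
def record_turn(conversation_id, question, answer):
    conversation_histories[conversation_id].append((question, answer))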