demoPOC commited on
Commit
4c329c5
·
verified ·
1 Parent(s): 3adb0e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -4
app.py CHANGED
@@ -59,6 +59,9 @@ uploads_dir = os.path.join(app.root_path,'static', 'uploads')
59
 
60
  os.makedirs(uploads_dir, exist_ok=True)
61
 
 
 
 
62
  defaultEmbeddingModelID = 3
63
  defaultLLMID=0
64
 
@@ -201,6 +204,26 @@ def loadKB(fileprovided, urlProvided, uploads_dir, request):
201
 
202
 
203
  def getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,llmID):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  chain = RetrievalQA.from_chain_type(
205
  llm=getLLMModel(llmID),
206
  chain_type='stuff',
@@ -210,10 +233,7 @@ def getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,llm
210
  chain_type_kwargs={
211
  "verbose": False,
212
  "prompt": createPrompt(customerName, customerDistrict, custDetailsPresent),
213
- "memory": ConversationBufferWindowMemory(
214
- k=3,
215
- memory_key="history",
216
- input_key="question"),
217
  }
218
  )
219
  return chain
@@ -307,6 +327,10 @@ def aisearch():
307
  def process_json():
308
  print(f"\n{'*' * 100}\n")
309
  print("Request Received >>>>>>>>>>>>>>>>>>", datetime.now().strftime("%H:%M:%S"))
 
 
 
 
310
  content_type = request.headers.get('Content-Type')
311
  if content_type == 'application/json':
312
  requestQuery = request.get_json()
@@ -322,6 +346,14 @@ def process_json():
322
  selectedLLMID=defaultLLMID
323
  if "llmID" in requestQuery:
324
  selectedLLMID=(int) (requestQuery['llmID'])
 
 
 
 
 
 
 
 
325
  print("chain initiation")
326
  chainRAG = getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,selectedLLMID)
327
  print("chain created")
@@ -332,6 +364,7 @@ def process_json():
332
  # message = answering(query)
333
 
334
  relevantDoc = vectordb.similarity_search_with_score(query, distance_metric="cos", k=3)
 
335
  print("Printing Retriever Docs")
336
  for doc in getRetriever(vectordb).get_relevant_documents(query):
337
  searchResult = {}
 
59
 
60
  os.makedirs(uploads_dir, exist_ok=True)
61
 
62
+ # Initialize global variables for conversation history
63
+ conversation_history = []
64
+
65
  defaultEmbeddingModelID = 3
66
  defaultLLMID=0
67
 
 
204
 
205
 
206
  def getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,llmID):
207
+
208
+ # Retrieve conversation history if available
209
+ memory = ConversationBufferWindowMemory(k=3, memory_key="history", input_key="question")
210
+ memory.load_history(conversation_history)
211
+
212
+ # chain = RetrievalQA.from_chain_type(
213
+ # llm=getLLMModel(llmID),
214
+ # chain_type='stuff',
215
+ # retriever=getRetriever(vectordb),
216
+ # #retriever=vectordb.as_retriever(),
217
+ # verbose=False,
218
+ # chain_type_kwargs={
219
+ # "verbose": False,
220
+ # "prompt": createPrompt(customerName, customerDistrict, custDetailsPresent),
221
+ # "memory": ConversationBufferWindowMemory(
222
+ # k=3,
223
+ # memory_key="history",
224
+ # input_key="question"),
225
+ # }
226
+ # )
227
  chain = RetrievalQA.from_chain_type(
228
  llm=getLLMModel(llmID),
229
  chain_type='stuff',
 
233
  chain_type_kwargs={
234
  "verbose": False,
235
  "prompt": createPrompt(customerName, customerDistrict, custDetailsPresent),
236
+ "memory": memory,
 
 
 
237
  }
238
  )
239
  return chain
 
327
  def process_json():
328
  print(f"\n{'*' * 100}\n")
329
  print("Request Received >>>>>>>>>>>>>>>>>>", datetime.now().strftime("%H:%M:%S"))
330
+
331
+ # Retrieve conversation ID from the request (use any suitable ID)
332
+ conversation_id = request.json.get('conversation_id', None)
333
+
334
  content_type = request.headers.get('Content-Type')
335
  if content_type == 'application/json':
336
  requestQuery = request.get_json()
 
346
  selectedLLMID=defaultLLMID
347
  if "llmID" in requestQuery:
348
  selectedLLMID=(int) (requestQuery['llmID'])
349
+
350
+ # Create a conversation ID-specific history list if not exists
351
+ conversation_history_id = f"{conversation_id}_history"
352
+ if conversation_history_id not in globals():
353
+ globals()[conversation_history_id] = []
354
+ conversation_history = globals()[conversation_history_id]
355
+
356
+
357
  print("chain initiation")
358
  chainRAG = getRAGChain(customerName, customerDistrict, custDetailsPresent, vectordb,selectedLLMID)
359
  print("chain created")
 
364
  # message = answering(query)
365
 
366
  relevantDoc = vectordb.similarity_search_with_score(query, distance_metric="cos", k=3)
367
+ conversation_history.append(query)
368
  print("Printing Retriever Docs")
369
  for doc in getRetriever(vectordb).get_relevant_documents(query):
370
  searchResult = {}