Hammad712 committed
Commit 5fb4fa6 · verified · 1 Parent(s): 27bc155

Update chatbot.py

Files changed (1)
  1. chatbot.py +27 -12
chatbot.py CHANGED
@@ -6,6 +6,7 @@ from pymongo import MongoClient
 from langchain.prompts import ChatPromptTemplate
 from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory
 from langchain.chains import ConversationalRetrievalChain
+from langchain.memory import ConversationBufferMemory
 
 from llm_provider import llm
 from vectorstore_manager import get_user_retriever
@@ -49,6 +50,7 @@ db = client[DB_NAME]
 sessions_collection = db[SESSIONS_COLLECTION]
 chains_collection = db[CHAINS_COLLECTION]
 
+
 # === Core Functions ===
 
 def create_new_chat(user_id: str) -> str:
@@ -78,7 +80,7 @@ def create_new_chat(user_id: str) -> str:
     # If the user has no chain/vectorstore registered yet, register it
     if chains_collection.count_documents({"user_id": user_id}, limit=1) == 0:
         # This also creates the vectorstore on disk via vectorstore_manager.ingest_report
-        # you should call ingest_report first elsewhere before chat
+        # You should call ingest_report first elsewhere before chat
         chains_collection.insert_one({
             "user_id": user_id,
             "vectorstore_path": f"user_vectorstores/{user_id}_faiss"
@@ -86,38 +88,47 @@ def create_new_chat(user_id: str) -> str:
 
     return chat_id
 
+
 def get_chain_for_user(user_id: str, chat_id: str) -> ConversationalRetrievalChain:
     """
     Reconstructs (or creates) the user's ConversationalRetrievalChain
     using their vectorstore and the chat-specific memory object.
     """
-    # Load chat history memory
-    chat_history = MongoDBChatMessageHistory(
+    # Step 1: Load raw MongoDB-backed chat history
+    mongo_history = MongoDBChatMessageHistory(
         session_id=chat_id,
         connection_string=MONGO_URI,
         database_name=DB_NAME,
         collection_name=HISTORY_COLLECTION,
     )
 
-    # Look up vectorstore path
+    # Step 2: Wrap it in a ConversationBufferMemory so that LangChain accepts it
+    memory = ConversationBufferMemory(
+        memory_key="chat_history",
+        chat_memory=mongo_history,
+        return_messages=True
+    )
+
+    # Step 3: Look up vectorstore path for this user
     chain_doc = chains_collection.find_one({"user_id": user_id})
     if not chain_doc:
         raise ValueError(f"No vectorstore registered for user {user_id}")
 
-    # Initialize retriever from vectorstore
+    # Step 4: Initialize retriever from vectorstore
     retriever = get_user_retriever(user_id)
 
-    # Create and return the chain
+    # Step 5: Create and return the chain with a valid Memory instance
     return ConversationalRetrievalChain.from_llm(
         llm=llm,
         retriever=retriever,
         return_source_documents=True,
         chain_type="stuff",
         combine_docs_chain_kwargs={"prompt": user_prompt},
-        memory=chat_history,
+        memory=memory,
         verbose=False,
     )
 
+
 def summarize_messages(chat_history: MongoDBChatMessageHistory) -> bool:
     """
     If the chat history grows too long, summarize it to keep the memory concise.
@@ -138,6 +149,7 @@ def summarize_messages(chat_history: MongoDBChatMessageHistory) -> bool:
     chat_history.add_ai_message(summary.content)
     return True
 
+
 def stream_chat_response(user_id: str, chat_id: str, query: str):
     """
     Given a user_id, chat_id, and a query string, streams back the AI response
@@ -145,17 +157,20 @@ def stream_chat_response(user_id: str, chat_id: str, query: str):
     """
     # Ensure the chain and memory are set up
    chain = get_chain_for_user(user_id, chat_id)
-    chat_history = chain.memory  # the MongoDBChatMessageHistory instance
+
+    # Since we used ConversationBufferMemory, the underlying MongoDBChatMessageHistory is accessible at:
+    chat_memory_wrapper = chain.memory  # type: ConversationBufferMemory
+    mongo_history = chat_memory_wrapper.chat_memory  # type: MongoDBChatMessageHistory
 
     # Optionally summarize if too many messages
-    summarize_messages(chat_history)
+    summarize_messages(mongo_history)
 
     # Add the user message to history
-    chat_history.add_user_message(query)
+    mongo_history.add_user_message(query)
 
     # Stream the response
     response_accum = ""
-    for chunk in chain.stream({"question": query, "chat_history": chat_history.messages}):
+    for chunk in chain.stream({"question": query, "chat_history": mongo_history.messages}):
         if "answer" in chunk:
             print(chunk["answer"], end="", flush=True)
             response_accum += chunk["answer"]
@@ -165,4 +180,4 @@ def stream_chat_response(user_id: str, chat_id: str, query: str):
 
     # Persist the AI's final message
     if response_accum:
-        chat_history.add_ai_message(response_accum)
+        mongo_history.add_ai_message(response_accum)
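
For reference, a minimal sketch of the memory wiring this commit introduces, shown outside the diff. It assumes a reachable MongoDB instance and reuses the constant names from chatbot.py (MONGO_URI, DB_NAME, HISTORY_COLLECTION); the literal values and the "demo_chat" session id are placeholders, not values from the repository.

# Illustrative sketch only -- constant values and the session id are assumptions;
# in chatbot.py they come from the app's configuration.
from langchain.memory import ConversationBufferMemory
from langchain_mongodb.chat_message_histories import MongoDBChatMessageHistory

MONGO_URI = "mongodb://localhost:27017"  # assumed local MongoDB for the demo
DB_NAME = "chatbot"                      # assumed database name
HISTORY_COLLECTION = "chat_histories"    # assumed collection name

# Raw, MongoDB-backed message store keyed by the chat/session id.
mongo_history = MongoDBChatMessageHistory(
    session_id="demo_chat",
    connection_string=MONGO_URI,
    database_name=DB_NAME,
    collection_name=HISTORY_COLLECTION,
)

# Wrap the raw history in a ConversationBufferMemory so that chain constructors
# expecting a memory object (e.g. ConversationalRetrievalChain.from_llm(memory=...))
# will accept it; reads and writes still go through MongoDB.
memory = ConversationBufferMemory(
    memory_key="chat_history",
    chat_memory=mongo_history,
    return_messages=True,
)

# The wrapped history stays reachable via memory.chat_memory, which is what the
# updated stream_chat_response uses to summarize and append messages.
memory.chat_memory.add_user_message("Hello!")
print(memory.chat_memory.messages)

The design point the commit leans on is that ConversationBufferMemory delegates persistence to whatever chat_memory it wraps, so the Mongo-backed history keeps working unchanged while the chain sees a standard memory object.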