|
import json |
|
import logging |
|
from datetime import datetime |
|
from typing import List |
|
|
|
from langchain.memory import MongoDBChatMessageHistory |
|
from langchain.schema import AIMessage, BaseMessage, HumanMessage, messages_from_dict, _message_to_dict |
|
from pymongo import errors |
|
|
|
# Per-module logger; records are emitted under this module's dotted name.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
class CustomMongoDBChatMessageHistory(MongoDBChatMessageHistory):
    """Chat history that keeps all of a session's messages in one MongoDB document.

    Each session is a single document matched by ``{"SessionId": session_id}``.
    Its ``messages`` array holds one JSON string per message, produced by
    ``json.dumps(_message_to_dict(message))``.
    """

    @property
    def messages(self) -> List[BaseMessage]:
        """Return the session's messages in the order they were appended.

        Reads the first document whose ``SessionId`` equals ``self.session_id``
        and decodes each JSON string in its ``messages`` array.

        Returns:
            The decoded messages. The list is empty when no document matches,
            when the document has no ``messages`` field, or when the query
            fails with ``pymongo.errors.OperationFailure`` (which is logged).
        """
        try:
            # Single round trip. A separate find() + count_documents() pair
            # doubled the queries and left the second one outside this guard.
            document = self.collection.find_one({"SessionId": self.session_id})
        except errors.OperationFailure as error:
            logger.error(error)
            return []

        if document is None:
            return []

        # Tolerate a matching document that was written without the array.
        items = document.get("messages", [])
        return messages_from_dict([json.loads(item) for item in items])

    def add_user_message(self, message: str) -> None:
        """Store ``message`` as a ``HumanMessage``."""
        self.append(HumanMessage(content=message))

    def add_ai_message(self, message: str) -> None:
        """Store ``message`` as an ``AIMessage``."""
        self.append(AIMessage(content=message))

    def add_message(self, message: BaseMessage) -> None:
        """Store an arbitrary ``BaseMessage`` by delegating to :meth:`append`.

        Routing this entry point through ``append`` keeps every write in the
        single-document ``messages`` layout that the ``messages`` property
        reads. Otherwise, messages added this way would be invisible to it.
        """
        self.append(message)

    def append(self, message: BaseMessage) -> None:
        """Push ``message`` onto this session's ``messages`` array in MongoDB.

        The message is serialized with ``_message_to_dict`` and stored as a
        JSON string, the format the ``messages`` property decodes. The session
        document is created on first write (``upsert=True``).

        ``pymongo.errors.WriteError`` is logged rather than raised, so a failed
        write does not interrupt the caller.
        """
        # NOTE(review): extra metadata (e.g. a timestamp) must not replace
        # this JSON string, or `messages_from_dict` can no longer decode it.
        try:
            self.collection.update_one(
                {"SessionId": self.session_id},
                {"$push": {"messages": json.dumps(_message_to_dict(message))}},
                upsert=True,
            )
        except errors.WriteError as err:
            logger.error(err)
|
|