# Bot_Development/core/chat/chatstore.py
# Commit 0767396 by dsmultimedika: "fix : update code"
import redis
import os
import json
from datetime import datetime
from dotenv import load_dotenv
from fastapi.responses import JSONResponse
from typing import Optional, List, Dict
from llama_index.storage.chat_store.redis import RedisChatStore
from pymongo.mongo_client import MongoClient
from llama_index.core.memory import ChatMemoryBuffer
from service.dto import ChatMessage
# Load variables from a local .env file into the process environment so the
# os.getenv() calls in ChatStore can read the Redis and MongoDB settings.
load_dotenv()
class ChatStore:
    """Chat-history storage for bot sessions, backed by Redis and MongoDB.

    A session is identified by ``session_id``:

    * In Redis it is a list stored under the key ``session_id``. Each item is a
      JSON-encoded message dict appended with RPUSH, so list order is insertion
      order.
    * In MongoDB it is a collection named ``session_id`` inside the
      ``bot_database`` database. That collection is read to restore a session
      that is no longer present in Redis.
    """

    # MongoDB database that holds one collection per session.
    _DB_NAME = "bot_database"
    # Lifetime of a session's Redis key: 86400 s = 24 hours.
    _SESSION_TTL_SECONDS = 86400
    # Token budget of the ChatMemoryBuffer built by initialize_memory_bot.
    _TOKEN_LIMIT = 3000

    def __init__(self):
        """Open the Redis and MongoDB clients using settings from environment variables."""
        self.redis_client = redis.Redis(
            host=os.getenv("REDIS_HOST"),
            port=os.getenv("REDIS_PORT"),
            username=os.getenv("REDIS_USERNAME"),
            password=os.getenv("REDIS_PASSWORD"),
        )
        self.client = MongoClient(os.getenv("MONGO_URI"))

    def _db(self):
        """Return the MongoDB database that holds one collection per session."""
        return self.client[self._DB_NAME]

    @staticmethod
    def _strip_document(doc, excluded) -> dict:
        """Return a copy of ``doc`` without the ``excluded`` keys and without None-valued fields."""
        return {
            key: value
            for key, value in doc.items()
            if key not in excluded and value is not None
        }

    @staticmethod
    def _is_tool_residue(message: dict) -> bool:
        """Return True for ``tool`` messages and ``assistant`` messages with no content."""
        role = message.get("role")
        # .get() also treats a missing "content" key as None.
        return role == "tool" or (role == "assistant" and message.get("content") is None)

    def initialize_memory_bot(self, session_id):
        """Build a Redis-backed, token-limited chat memory for ``session_id``.

        If Redis has no key for the session but MongoDB has a collection with
        that name, the MongoDB history is copied into Redis first, so the
        memory starts with the earlier conversation. Otherwise the memory uses
        whatever Redis already holds, which may be nothing.

        Returns:
            A ``ChatMemoryBuffer`` with a 3000-token limit whose chat-store key
            is ``session_id``. Its ``RedisChatStore`` uses a 24-hour TTL.
        """
        chat_store = RedisChatStore(
            redis_client=self.redis_client,
            ttl=self._SESSION_TTL_SECONDS,  # 86400 s = 24 hours
        )
        # EXISTS is O(1). Listing every key with KEYS is O(N) and blocks the server.
        if not self.redis_client.exists(session_id):
            if session_id in self._db().list_collection_names():
                self.add_chat_history_to_redis(session_id)
        # Every path uses the same Redis-backed memory configuration.
        return ChatMemoryBuffer.from_defaults(
            token_limit=self._TOKEN_LIMIT,
            chat_store=chat_store,
            chat_store_key=session_id,
        )

    def get_messages(self, session_id: str) -> List[dict]:
        """Return all Redis messages for ``session_id`` in insertion order.

        Each item is decoded as UTF-8 and parsed as JSON. A missing or empty key
        yields an empty list.
        """
        items = self.redis_client.lrange(session_id, 0, -1)
        return [json.loads(item.decode("utf-8")) for item in items]

    def get_last_message(self, session_id: str) -> Optional[Dict]:
        """Return the last message in the session's Redis list, or None if the list is empty."""
        raw = self.redis_client.lindex(session_id, -1)
        if raw is None:
            return None
        return json.loads(raw.decode("utf-8"))

    def get_last_message_mongodb(self, session_id: str):
        """Return the ``content`` of the session's MongoDB document with the highest ``_id``.

        The value is converted with ``str()``. An empty string is returned when
        the collection has no documents or the document has no ``content`` field.
        """
        last_document = self._db()[session_id].find_one(sort=[("_id", -1)])
        if last_document is None:
            return ""
        return str(last_document.get("content", ""))

    def delete_last_message(self, session_id: str) -> Optional[bytes]:
        """Pop the last item of the session's Redis list.

        Returns:
            The raw stored item as returned by RPOP (not a parsed message), or
            None when the list is empty or missing.
        """
        return self.redis_client.rpop(session_id)

    def delete_messages(self, session_id: str) -> None:
        """Delete the session's Redis list and drop its MongoDB collection.

        Always returns None.
        """
        self.redis_client.delete(session_id)
        # Subscript by the variable's value. Attribute access (db.session_id)
        # would target a collection literally named "session_id".
        self._db()[session_id].drop()
        return None

    def clean_message(self, session_id: str) -> None:
        """Remove tool-call residue from the session's Redis list.

        Deletes every item whose role is ``"tool"`` and every ``"assistant"``
        item whose ``content`` is missing or null. All items are parsed before
        any deletion, so a malformed item aborts the call without deleting
        anything.
        """
        stored = self.redis_client.lrange(session_id, 0, -1)
        to_remove = [raw for raw in stored if self._is_tool_residue(json.loads(raw))]
        # Identical payloads always get the same verdict, so each LREM (count=1)
        # removes one copy and together they remove every flagged occurrence.
        for raw in to_remove:
            self.redis_client.lrem(session_id, 1, raw)

    def get_keys(self) -> List[str]:
        """Return every Redis key, decoded as UTF-8.

        NOTE(review): on failure this returns a 400 ``JSONResponse`` instead of
        a list, so callers must check the result type. It also uses KEYS, which
        is O(N) over the whole keyspace.
        """
        try:
            return [key.decode("utf-8") for key in self.redis_client.keys("*")]
        except Exception:
            return JSONResponse(status_code=400, content="the error when get keys")

    def add_message(self, session_id: str, message: Optional[ChatMessage]) -> None:
        """Append ``message``, JSON-encoded, to the tail of the session's Redis list."""
        self.redis_client.rpush(session_id, json.dumps(self._message_to_dict(message)))

    def _message_to_dict(self, message: Optional[ChatMessage]) -> dict:
        """Convert a ``ChatMessage`` into a JSON-serializable dict via ``model_dump()``.

        A ``datetime`` in the ``timestamp`` field is replaced by its ISO-8601 string.
        """
        message_dict = message.model_dump()
        timestamp = message_dict.get("timestamp")
        if isinstance(timestamp, datetime):
            message_dict["timestamp"] = timestamp.isoformat()
        return message_dict

    def add_chat_history_to_redis(self, session_id: str) -> None:
        """Copy the session's MongoDB history into its Redis list.

        Before each document is validated as a ``ChatMessage`` and
        JSON-encoded, these are dropped from it:

        * ``_id``
        * ``timestamp``
        * null-valued fields

        The messages are appended in one RPUSH, after which the key gets a
        24-hour TTL. An empty collection leaves Redis untouched.

        On any error, a 500 ``JSONResponse`` is returned instead of raising.
        NOTE(review): ``initialize_memory_bot`` ignores this return value, so
        failures there are silent.
        """
        collection = self._db()[session_id]
        try:
            items = [
                json.dumps(
                    self._message_to_dict(
                        ChatMessage(**self._strip_document(doc, ("_id", "timestamp")))
                    )
                )
                for doc in collection.find()
            ]
            if items:
                # Push the whole history in one call, then set the TTL once,
                # instead of sending EXPIRE after every message.
                self.redis_client.rpush(session_id, *items)
                self.redis_client.expire(session_id, time=self._SESSION_TTL_SECONDS)
        except Exception:
            return JSONResponse(status_code=500, content="Add Database Error")

    def get_all_messages_mongodb(self, session_id):
        """Return every document of the session's MongoDB collection as a plain dict.

        ``_id`` and None-valued fields are omitted. On error, a 500
        ``JSONResponse`` is returned instead of raising.
        NOTE(review): that response includes the exception text, which is then
        exposed to the client.
        """
        try:
            collection = self._db()[session_id]
            return [self._strip_document(doc, ("_id",)) for doc in collection.find()]
        except Exception as e:
            return JSONResponse(status_code=500, content=f"An error occurred while retrieving messages: {e}")