Spaces:
Sleeping
Sleeping
import redis | |
import os | |
import json | |
from fastapi.responses import JSONResponse | |
from typing import Optional, List | |
from llama_index.storage.chat_store.redis import RedisChatStore | |
from pymongo.mongo_client import MongoClient | |
from llama_index.core.memory import ChatMemoryBuffer | |
from service.dto import ChatMessage | |
class ChatStore:
    """Chat-history store using Redis as the working cache and MongoDB as a fallback source.

    Each session's messages are kept in a Redis list whose key is the session id;
    every element is a JSON-encoded message dict. MongoDB history is read from a
    collection named after the session id inside ``DATABASE_NAME``.
    """

    # Default Redis endpoint (overridable through __init__).
    DEFAULT_REDIS_HOST = "redis-10365.c244.us-east-1-2.ec2.redns.redis-cloud.com"
    DEFAULT_REDIS_PORT = 10365
    # MongoDB database holding one collection per session.
    DATABASE_NAME = "bot_database"
    # Expiry applied to session keys in Redis: 86400 s = 24 hours.
    SESSION_TTL_SECONDS = 86400
    # Token budget handed to ChatMemoryBuffer.
    MEMORY_TOKEN_LIMIT = 3000

    def __init__(
        self,
        redis_host: str = DEFAULT_REDIS_HOST,
        redis_port: int = DEFAULT_REDIS_PORT,
        mongo_uri: Optional[str] = None,
    ):
        """Create the Redis and MongoDB clients.

        Args:
            redis_host: Redis hostname; defaults to the original hard-coded endpoint.
            redis_port: Redis port; defaults to the original hard-coded port.
            mongo_uri: MongoDB connection string; when omitted, the ``MONGO_URI``
                environment variable is used (the original behavior).

        The Redis password is always read from the ``REDIS_PASSWORD`` environment
        variable. The Redis client is created without ``decode_responses``, so keys
        and values come back as bytes.
        """
        self.redis_client = redis.Redis(
            host=redis_host,
            port=redis_port,
            password=os.environ.get("REDIS_PASSWORD"),
        )
        self.client = MongoClient(
            mongo_uri if mongo_uri is not None else os.getenv("MONGO_URI")
        )

    def _database(self):
        """Return the MongoDB database that stores per-session collections."""
        return self.client[self.DATABASE_NAME]

    @staticmethod
    def _clean_document(document, excluded_keys) -> dict:
        """Return a copy of *document* without *excluded_keys* and without None-valued fields."""
        return {
            key: value
            for key, value in document.items()
            if key not in excluded_keys and value is not None
        }

    @staticmethod
    def _is_tool_artifact(data) -> bool:
        """Return True for messages that clean_message removes.

        These are messages with role "tool" and "assistant" messages whose
        "content" is null (presumably tool-call requests — verify against the
        producer of these messages).
        """
        role = data.get("role")
        return role == "tool" or (role == "assistant" and data.get("content") is None)

    def initialize_memory_bot(self, session_id):
        """Return a ChatMemoryBuffer for *session_id* backed by the Redis chat store.

        When Redis has no list for the session but MongoDB has a collection for
        it, the MongoDB history is first copied into Redis so the buffer starts
        with the earlier conversation.

        Args:
            session_id: Redis key of the session and name of its MongoDB collection.
        """
        chat_store = RedisChatStore(
            redis_client=self.redis_client,
            ttl=self.SESSION_TTL_SECONDS,  # 24 hours
        )
        # Use EXISTS rather than `session_id in keys()`: KEYS returns bytes from this
        # client, so a str membership test never matches (which re-appended the
        # Mongo history on every call), and KEYS scans the entire keyspace.
        if not self.redis_client.exists(session_id) and (
            session_id in self._database().list_collection_names()
        ):
            self.add_chat_history_to_redis(session_id)
        return ChatMemoryBuffer.from_defaults(
            token_limit=self.MEMORY_TOKEN_LIMIT,
            chat_store=chat_store,
            chat_store_key=session_id,
        )

    def get_messages(self, session_id: str) -> List[dict]:
        """Return the messages cached in Redis for *session_id*, in list (append) order.

        Each element is decoded from UTF-8 JSON into a dict. A missing session
        yields an empty list.
        """
        raw_items = self.redis_client.lrange(session_id, 0, -1)
        return [json.loads(raw.decode("utf-8")) for raw in raw_items]

    def delete_last_message(self, session_id: str) -> Optional["ChatMessage"]:
        """Pop and return the last element of the Redis list for *session_id*.

        NOTE(review): the value returned is the raw Redis payload (JSON bytes), or
        None when the list is empty or missing — not a ChatMessage as annotated.
        """
        return self.redis_client.rpop(session_id)

    def delete_messages(self, session_id: str) -> Optional[List["ChatMessage"]]:
        """Delete the session's Redis list and drop its MongoDB collection.

        Returns:
            Always None.
        """
        self.redis_client.delete(session_id)
        # Index by the session id: attribute access (``db.session_id``) would
        # address a collection literally named "session_id".
        self._database()[session_id].drop()
        return None

    def clean_message(self, session_id: str) -> Optional["ChatMessage"]:
        """Remove tool-related messages from the Redis list of *session_id*.

        Every "tool" message and every "assistant" message with null content is
        deleted; all other messages keep their order. Returns None.
        """
        flagged = {
            raw
            for raw in self.redis_client.lrange(session_id, 0, -1)
            if self._is_tool_artifact(json.loads(raw))
        }
        for raw in flagged:
            # count=0 removes every occurrence of this exact payload in one call.
            self.redis_client.lrem(session_id, 0, raw)

    def get_keys(self) -> List[str]:
        """Return every Redis key decoded as UTF-8.

        On failure the error is printed and a 400 JSONResponse is returned instead
        of a list (kept for existing FastAPI callers).
        """
        try:
            # NOTE(review): KEYS blocks Redis on large keyspaces; consider scan_iter.
            return [key.decode("utf-8") for key in self.redis_client.keys("*")]
        except Exception as e:
            print(f"An error occurred while getting keys: {e}")
            return JSONResponse(status_code=400, content="the error when get keys")

    def add_message(self, session_id: str, message: "ChatMessage") -> None:
        """Append *message*, JSON-encoded, to the Redis list of *session_id*."""
        item = json.dumps(self._message_to_dict(message))
        self.redis_client.rpush(session_id, item)

    def _message_to_dict(self, message: "ChatMessage") -> dict:
        """Serialize a ChatMessage to a plain dict via its ``model_dump``."""
        return message.model_dump()

    def add_chat_history_to_redis(self, session_id: str) -> None:
        """Copy the MongoDB history of *session_id* into its Redis list.

        Each document has ``_id``, ``timestamp`` and null-valued fields stripped,
        is validated through ``ChatMessage``, JSON-encoded, and appended in the
        order MongoDB returns it. The Redis key then expires after
        SESSION_TTL_SECONDS.

        All documents are converted before anything is pushed, so a document that
        fails validation leaves Redis untouched rather than half-filled. On
        failure the error is printed and a 500 JSONResponse is returned
        (initialize_memory_bot ignores the return value).
        """
        try:
            payloads = [
                json.dumps(
                    self._message_to_dict(
                        ChatMessage(**self._clean_document(doc, ("_id", "timestamp")))
                    )
                )
                for doc in self._database()[session_id].find()
                if doc is not None
            ]
            if payloads:
                # One RPUSH for the whole history instead of a round trip per message;
                # RPUSH requires at least one value, hence the guard.
                self.redis_client.rpush(session_id, *payloads)
                self.redis_client.expire(session_id, time=self.SESSION_TTL_SECONDS)
        except Exception as e:
            print(f"An error occurred while adding chat history to Redis: {e}")
            return JSONResponse(status_code=500, content="Add Database Error")

    def get_all_messages_mongodb(self, session_id):
        """Return every document of the MongoDB collection for *session_id*.

        ``_id`` and null-valued fields are removed from each document. On failure
        the error is printed and a 500 JSONResponse is returned instead of a list.
        """
        try:
            return [
                self._clean_document(doc, ("_id",))
                for doc in self._database()[session_id].find()
            ]
        except Exception as e:
            print(f"An error occurred while retrieving messages: {e}")
            # NOTE(review): echoing the exception text to the client may leak internals.
            return JSONResponse(status_code=500, content=f"An error occurred while retrieving messages: {e}")