Spaces:
Sleeping
Sleeping
File size: 6,564 Bytes
0743bb0 d57efd6 b39c0ba 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 b39c0ba 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 d57efd6 0743bb0 d57efd6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 |
import redis
import os
import json
from fastapi.responses import JSONResponse
from typing import Optional, List, Dict
from llama_index.storage.chat_store.redis import RedisChatStore
from pymongo.mongo_client import MongoClient
from llama_index.core.memory import ChatMemoryBuffer
from service.dto import ChatMessage
class ChatStore:
    """Chat-history store that keeps messages in Redis and reads archives from MongoDB.

    Each session's messages live in a Redis list keyed by the session id, one
    JSON-encoded message per list element. MongoDB database ``bot_database``
    holds one collection per session id; its documents can be copied back into
    Redis when the Redis key is missing.
    """

    # Connection settings for the Redis instance (password comes from the
    # REDIS_PASSWORD environment variable).
    REDIS_HOST = "redis-10365.c244.us-east-1-2.ec2.redns.redis-cloud.com"
    REDIS_PORT = 10365
    # MongoDB database that holds one collection per session id.
    DB_NAME = "bot_database"
    # Lifetime of a session's Redis key: 86400 seconds = 24 hours.
    SESSION_TTL_SECONDS = 86400
    # Token budget handed to ChatMemoryBuffer.
    TOKEN_LIMIT = 3000

    def __init__(self):
        """Create the Redis and MongoDB clients.

        Reads ``REDIS_PASSWORD`` and ``MONGO_URI`` from the environment.
        """
        self.redis_client = redis.Redis(
            host=self.REDIS_HOST,
            port=self.REDIS_PORT,
            password=os.environ.get("REDIS_PASSWORD"),
        )
        uri = os.getenv("MONGO_URI")
        self.client = MongoClient(uri)

    def initialize_memory_bot(self, session_id):
        """Return a ``ChatMemoryBuffer`` bound to ``session_id``.

        If Redis has no key for the session but MongoDB has a collection named
        after it, the archived history is first copied into Redis so the
        buffer starts with the earlier conversation.
        """
        chat_store = RedisChatStore(
            redis_client=self.redis_client,
            ttl=self.SESSION_TTL_SECONDS,
        )

        # EXISTS is a direct key lookup. The previous `session_id in keys()`
        # check compared a str against bytes keys, so it never matched and the
        # archive was re-pushed (duplicated) on every call.
        if not self.redis_client.exists(session_id):
            db = self.client[self.DB_NAME]
            if session_id in db.list_collection_names():
                self.add_chat_history_to_redis(session_id)

        # A brand-new session and a restored one use the same buffer settings.
        return ChatMemoryBuffer.from_defaults(
            token_limit=self.TOKEN_LIMIT,
            chat_store=chat_store,
            chat_store_key=session_id,
        )

    def get_messages(self, session_id: str) -> List[dict]:
        """Return every message stored in Redis for ``session_id`` as dicts.

        Returns an empty list when the key is missing or the list is empty.
        """
        items = self.redis_client.lrange(session_id, 0, -1)
        return [json.loads(raw.decode("utf-8")) for raw in items]

    def get_last_message(self, session_id: str) -> Optional[Dict]:
        """Return the newest message for ``session_id``, or ``None`` if there is none."""
        last_message = self.redis_client.lindex(session_id, -1)
        if last_message is None:
            return None
        return json.loads(last_message.decode("utf-8"))

    def delete_last_message(self, session_id: str) -> Optional["ChatMessage"]:
        """Remove the newest message for ``session_id`` from Redis.

        NOTE(review): the value returned is whatever ``rpop`` yields (the raw
        stored element, or ``None`` when the list is empty), not a parsed
        ``ChatMessage`` as the annotation suggests.
        """
        return self.redis_client.rpop(session_id)

    def delete_messages(self, session_id: str) -> Optional[List["ChatMessage"]]:
        """Delete all messages for ``session_id`` from Redis and MongoDB.

        Always returns ``None``.
        """
        self.redis_client.delete(session_id)
        db = self.client[self.DB_NAME]
        # Index by the variable: attribute access (`db.session_id`) would drop
        # a collection literally named "session_id".
        db[session_id].drop()
        return None

    def clean_message(self, session_id: str) -> Optional["ChatMessage"]:
        """Remove tool messages and content-less assistant messages from Redis.

        A message is removed when its role is ``"tool"``, or when its role is
        ``"assistant"`` and its ``content`` is ``None``. Returns ``None``.
        """
        # Work on a snapshot; each LREM removes one occurrence of the exact
        # stored element, so duplicates are removed once per occurrence.
        for raw in self.redis_client.lrange(session_id, 0, -1):
            data = json.loads(raw)
            role = data.get("role")
            is_empty_assistant = role == "assistant" and data.get("content") is None
            if role == "tool" or is_empty_assistant:
                self.redis_client.lrem(session_id, 1, raw)

    def get_keys(self) -> List[str]:
        """Return all Redis keys decoded as UTF-8 strings.

        On failure the error is printed and a ``JSONResponse`` with status 400
        is returned instead of a list.
        """
        try:
            # Fetch once: KEYS scans the whole keyspace.
            keys = self.redis_client.keys("*")
            return [key.decode("utf-8") for key in keys]
        except Exception as e:
            print(f"An error occurred while getting keys: {e}")
            return JSONResponse(status_code=400, content="the error when get keys")

    def add_message(self, session_id: str, message: "ChatMessage") -> None:
        """Append ``message`` (JSON-encoded) to the Redis list for ``session_id``."""
        item = json.dumps(self._message_to_dict(message))
        self.redis_client.rpush(session_id, item)

    def _message_to_dict(self, message: "ChatMessage") -> dict:
        """Convert a ``ChatMessage`` to a plain dict via ``model_dump()``."""
        return message.model_dump()

    def add_chat_history_to_redis(self, session_id: str) -> None:
        """Copy the MongoDB history for ``session_id`` into its Redis list.

        ``_id``, ``timestamp`` and ``None``-valued fields are dropped from each
        document before it is turned into a ``ChatMessage``. The Redis key is
        given the session TTL afterwards. On failure the error is printed and
        a ``JSONResponse`` with status 500 is returned; nothing is raised.
        """
        db = self.client[self.DB_NAME]
        collection = db[session_id]
        try:
            items = []
            for document in collection.find():
                fields = {
                    key: value
                    for key, value in document.items()
                    if key not in ("_id", "timestamp") and value is not None
                }
                items.append(
                    json.dumps(self._message_to_dict(ChatMessage(**fields)))
                )

            # Convert everything first, then push in one command so a bad
            # document cannot leave a partial history in Redis.
            if items:
                self.redis_client.rpush(session_id, *items)
                self.redis_client.expire(session_id, time=self.SESSION_TTL_SECONDS)
        except Exception as e:
            print(f"An error occurred while adding chat history to Redis: {e}")
            return JSONResponse(status_code=500, content="Add Database Error")

    def get_all_messages_mongodb(self, session_id):
        """Return all MongoDB documents for ``session_id`` as a list of dicts.

        The ``_id`` field and ``None``-valued fields are omitted. On failure
        the error is printed and a ``JSONResponse`` with status 500 is
        returned instead of a list.
        """
        try:
            db = self.client[self.DB_NAME]
            collection = db[session_id]
            return [
                {
                    key: value
                    for key, value in doc.items()
                    if key != "_id" and value is not None
                }
                for doc in collection.find()
            ]
        except Exception as e:
            print(f"An error occurred while retrieving messages: {e}")
            return JSONResponse(
                status_code=500,
                content=f"An error occurred while retrieving messages: {e}",
            )