File size: 6,135 Bytes
0743bb0
 
 
d57efd6
0743bb0
 
d57efd6
0743bb0
 
 
 
 
 
 
 
 
 
 
d57efd6
 
 
 
 
0743bb0
 
d57efd6
0743bb0
 
d57efd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0743bb0
d57efd6
0743bb0
 
 
 
 
d57efd6
0743bb0
 
 
 
 
 
d57efd6
 
 
 
 
 
0743bb0
d57efd6
0743bb0
 
 
d57efd6
0743bb0
 
 
 
 
d57efd6
 
 
0743bb0
 
 
 
d57efd6
 
 
0743bb0
 
 
d57efd6
0743bb0
 
d57efd6
0743bb0
d57efd6
0743bb0
d57efd6
0743bb0
 
 
 
 
 
 
d57efd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import redis
import os
import json
from fastapi.responses import JSONResponse
from typing import Optional, List
from llama_index.storage.chat_store.redis import RedisChatStore
from pymongo.mongo_client import MongoClient
from llama_index.core.memory import ChatMemoryBuffer
from service.dto import ChatMessage


class ChatStore:
    """Chat-history store backed by Redis and MongoDB.

    Redis keeps each session's messages as a list of JSON strings under the
    key ``session_id`` (with a TTL). MongoDB keeps one collection per
    ``session_id`` in the ``bot_database`` database. That collection is read
    to restore a session whose Redis list is missing.
    """

    # MongoDB database holding one collection per chat session.
    _DB_NAME = "bot_database"
    # Time-to-live applied to a session's Redis list: 86400 s = 24 hours.
    _REDIS_TTL_SECONDS = 86400

    def __init__(self):
        # NOTE(review): Redis host/port are hard-coded; only the password is
        # read from the environment (REDIS_PASSWORD).
        self.redis_client = redis.Redis(
            host="redis-10365.c244.us-east-1-2.ec2.redns.redis-cloud.com",
            port=10365,
            password=os.environ.get("REDIS_PASSWORD"),
        )

        uri = os.getenv("MONGO_URI")
        self.client = MongoClient(uri)

    def initialize_memory_bot(self, session_id):
        """Build a Redis-backed ``ChatMemoryBuffer`` for ``session_id``.

        If Redis has no list for the session but MongoDB has a collection for
        it, the history is first copied from MongoDB into Redis. The buffer
        then starts from the persisted conversation.

        Args:
            session_id: Chat session identifier, used both as the Redis key
                and as the MongoDB collection name.

        Returns:
            A ``ChatMemoryBuffer`` with a 3000-token limit over the session.
        """
        chat_store = RedisChatStore(
            redis_client=self.redis_client,
            ttl=self._REDIS_TTL_SECONDS,  # 24 hours
        )

        # Hydrate from MongoDB only when the Redis list is absent. Pushing the
        # history onto an existing list would duplicate every message.
        if not self.redis_client.exists(session_id):
            db = self.client[self._DB_NAME]
            if session_id in db.list_collection_names():
                self.add_chat_history_to_redis(session_id)

        return ChatMemoryBuffer.from_defaults(
            token_limit=3000, chat_store=chat_store, chat_store_key=session_id
        )

    def get_messages(self, session_id: str) -> List[dict]:
        """Return the session's Redis messages, each decoded from JSON to a dict.

        Returns an empty list when the session has no messages in Redis.
        """
        items = self.redis_client.lrange(session_id, 0, -1)
        if not items:
            return []

        return [json.loads(raw.decode("utf-8")) for raw in items]

    def delete_last_message(self, session_id: str) -> Optional[ChatMessage]:
        """Pop and return the last message of the session's Redis list.

        NOTE(review): despite the annotation, this returns the raw value from
        RPOP (JSON-encoded bytes), or ``None`` if the list is empty/missing.
        """
        return self.redis_client.rpop(session_id)

    def delete_messages(self, session_id: str) -> Optional[List[ChatMessage]]:
        """Delete all messages for ``session_id`` from Redis and MongoDB.

        Returns:
            Always ``None``.
        """
        self.redis_client.delete(session_id)
        # Index by the variable's value. Attribute access (``db.session_id``)
        # would target a collection literally named "session_id".
        self.client[self._DB_NAME][session_id].drop()
        return None

    def clean_message(self, session_id: str) -> Optional[ChatMessage]:
        """Remove tool messages and content-less assistant messages from Redis.

        A message is removed when its ``role`` is ``"tool"``, or when its
        ``role`` is ``"assistant"`` and its ``content`` is missing/None (for
        example, an assistant turn that only carried tool calls).
        """
        current_list = self.redis_client.lrange(session_id, 0, -1)

        # Distinct raw values to drop, in first-seen order. LREM with count 0
        # removes every occurrence, which matches removing each matching index.
        to_remove = {}
        for raw in current_list:
            data = json.loads(raw)
            is_empty_assistant = (
                data.get("role") == "assistant" and data.get("content") is None
            )
            if is_empty_assistant or data.get("role") == "tool":
                to_remove[raw] = None

        for raw in to_remove:
            self.redis_client.lrem(session_id, 0, raw)

    def get_keys(self) -> List[str]:
        """Return every Redis key, decoded as UTF-8.

        On failure, returns a 400 ``JSONResponse`` instead of a list.
        NOTE(review): KEYS scans the whole keyspace and blocks Redis while it
        runs; consider SCAN for large databases.
        """
        try:
            keys = self.redis_client.keys("*")
            print(keys)
            return [key.decode("utf-8") for key in keys]

        except Exception as e:
            print(f"An error occurred while getting keys: {e}")
            return JSONResponse(status_code=400, content="the error when get keys")

    def add_message(self, session_id: str, message: ChatMessage) -> None:
        """Append ``message`` (serialised to JSON) to the session's Redis list."""
        item = json.dumps(self._message_to_dict(message))
        self.redis_client.rpush(session_id, item)

    def _message_to_dict(self, message: ChatMessage) -> dict:
        """Convert a ``ChatMessage`` to a plain dict via ``model_dump``."""
        return message.model_dump()

    def add_chat_history_to_redis(self, session_id: str) -> None:
        """Copy a session's chat history from MongoDB into its Redis list.

        Each MongoDB document is stripped of ``_id``, ``timestamp`` and any
        ``None`` fields. It is then validated through ``ChatMessage``,
        serialised to JSON and appended to the Redis list ``session_id``.
        Finally, the list is given a 24-hour TTL.

        Returns:
            ``None`` on success; a 500 ``JSONResponse`` if reading MongoDB,
            building a message, or writing Redis fails.
        """
        collection = self.client[self._DB_NAME][session_id]

        try:
            items = []
            for document in collection.find():
                fields = {
                    key: value
                    for key, value in document.items()
                    if key not in ("_id", "timestamp") and value is not None
                }
                items.append(json.dumps(self._message_to_dict(ChatMessage(**fields))))

            # One RPUSH for the whole history instead of one round trip per
            # message. RPUSH with no values is rejected by Redis, so skip it
            # when the collection is empty.
            if items:
                self.redis_client.rpush(session_id, *items)
            self.redis_client.expire(session_id, time=self._REDIS_TTL_SECONDS)

        except Exception as e:
            print(f"An error occurred while adding chat history to Redis: {e}")
            return JSONResponse(status_code=500, content="Add Database Error")

    def get_all_messages_mongodb(self, session_id):
        """Return all MongoDB documents for ``session_id`` as dicts.

        The ``_id`` field and any ``None``-valued fields are omitted. On
        failure, returns a 500 ``JSONResponse`` instead of a list.
        """
        try:
            collection = self.client[self._DB_NAME][session_id]

            documents_list = [
                {key: value for key, value in doc.items() if key != "_id" and value is not None}
                for doc in collection.find()
            ]

            # Debug output of the retrieved documents.
            print(documents_list)

            return documents_list

        except Exception as e:
            print(f"An error occurred while retrieving messages: {e}")
            return JSONResponse(status_code=500, content=f"An error occurred while retrieving messages: {e}")