File size: 7,698 Bytes
0743bb0
 
 
0767396
 
 
 
d57efd6
b39c0ba
0743bb0
d57efd6
0743bb0
 
 
0767396
 
0743bb0
 
 
 
0767396
 
 
 
 
0743bb0
d57efd6
 
 
 
 
0767396
 
 
0743bb0
d57efd6
0743bb0
d57efd6
0767396
 
 
 
 
 
 
 
 
 
 
d57efd6
 
 
 
0767396
d57efd6
 
 
0767396
0743bb0
d57efd6
0743bb0
 
 
 
 
d57efd6
0743bb0
 
0767396
b39c0ba
 
 
0767396
b39c0ba
 
0767396
b39c0ba
 
 
9f7b904
 
 
 
 
 
0767396
 
9f7b904
0767396
 
 
 
0743bb0
 
 
 
d57efd6
 
 
 
 
 
0743bb0
d57efd6
0743bb0
 
 
9f7b904
0743bb0
 
 
 
 
9f7b904
 
 
0743bb0
 
 
 
9f7b904
 
 
0743bb0
 
 
d57efd6
0743bb0
d57efd6
0743bb0
d57efd6
0743bb0
0767396
0743bb0
 
 
 
0767396
 
 
 
 
 
 
d57efd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0767396
d57efd6
 
 
 
0767396
d57efd6
0767396
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
import redis
import os
import json

from datetime import datetime
from dotenv import load_dotenv

from fastapi.responses import JSONResponse
from typing import Optional, List, Dict
from llama_index.storage.chat_store.redis import RedisChatStore
from pymongo.mongo_client import MongoClient
from llama_index.core.memory import ChatMemoryBuffer
from service.dto import ChatMessage

# Load variables from a .env file into the process environment at import time,
# so ChatStore can read REDIS_HOST/REDIS_PORT/REDIS_USERNAME/REDIS_PASSWORD and
# MONGO_URI via os.getenv when it is constructed.
load_dotenv()


class ChatStore:
    """Chat-history store backed by Redis (working copy) and MongoDB (history source).

    A session is addressed by ``session_id`` in both stores:

    * Redis: a list stored under the key ``session_id`` whose items are
      JSON-encoded message dicts (appended with RPUSH, so oldest first).
    * MongoDB: the collection named ``session_id`` inside ``DB_NAME``.
    """

    # MongoDB database that holds one collection per session.
    DB_NAME = "bot_database"
    # Expiry applied to a session's Redis list, in seconds (86400 s = 24 hours).
    SESSION_TTL_SECONDS = 86400
    # Token budget handed to ChatMemoryBuffer.from_defaults.
    MEMORY_TOKEN_LIMIT = 3000

    def __init__(self):
        """Connect to Redis and MongoDB using settings from the environment.

        Reads REDIS_HOST, REDIS_PORT, REDIS_USERNAME, REDIS_PASSWORD and
        MONGO_URI. Redis host/port fall back to redis-py's own defaults
        (localhost:6379) when the variables are unset.
        """
        self.redis_client = redis.Redis(
            host=os.getenv("REDIS_HOST", "localhost"),
            port=int(os.getenv("REDIS_PORT", "6379")),
            username=os.getenv("REDIS_USERNAME"),
            password=os.getenv("REDIS_PASSWORD"),
        )
        self.client = MongoClient(os.getenv("MONGO_URI"))

    @property
    def _db(self):
        """Handle to the MongoDB database holding one collection per session."""
        return self.client[self.DB_NAME]

    def initialize_memory_bot(self, session_id: str):
        """Return a ChatMemoryBuffer for ``session_id`` that reads/writes Redis.

        If the session is not in Redis but has a MongoDB collection, its
        history is first copied into Redis (see ``add_chat_history_to_redis``)
        so the buffer starts with the earlier conversation. Otherwise the buffer
        starts from whatever Redis already holds, possibly nothing.

        Returns:
            A ``ChatMemoryBuffer`` with ``token_limit=MEMORY_TOKEN_LIMIT`` whose
            chat store is a ``RedisChatStore`` with a ``SESSION_TTL_SECONDS`` TTL.
        """
        chat_store = RedisChatStore(
            redis_client=self.redis_client, ttl=self.SESSION_TTL_SECONDS
        )

        # EXISTS is O(1); listing every key just to test one would be O(N).
        if not self.redis_client.exists(session_id):
            if session_id in self._db.list_collection_names():
                # Warm the Redis cache from MongoDB. The error response this
                # may return is ignored: the buffer then simply starts empty.
                self.add_chat_history_to_redis(session_id)

        return ChatMemoryBuffer.from_defaults(
            token_limit=self.MEMORY_TOKEN_LIMIT,
            chat_store=chat_store,
            chat_store_key=session_id,
        )

    def get_messages(self, session_id: str) -> List[dict]:
        """Return every Redis message for ``session_id`` in insertion order.

        Each list item is UTF-8 decoded and parsed as JSON. Returns ``[]`` when
        the key is missing or empty.
        """
        return [
            json.loads(raw.decode("utf-8"))
            for raw in self.redis_client.lrange(session_id, 0, -1)
        ]

    def get_last_message(self, session_id: str) -> Optional[Dict]:
        """Return the newest Redis message for ``session_id``, or None if there is none."""
        last_message = self.redis_client.lindex(session_id, -1)
        if last_message is None:
            return None
        return json.loads(last_message.decode("utf-8"))

    def get_last_message_mongodb(self, session_id: str):
        """Return the ``content`` of the MongoDB document with the highest ``_id``.

        Returns:
            ``str(content)`` of that document, or ``""`` when the collection is
            empty or the document has no (or a null) ``content`` field.
        """
        doc = self._db[session_id].find_one(sort=[("_id", -1)])
        if doc is None:
            return ""
        content = doc.get("content")
        # A null content means "nothing to show", not the literal text "None".
        return "" if content is None else str(content)

    def delete_last_message(self, session_id: str) -> Optional[bytes]:
        """Pop and return the newest Redis item for ``session_id``.

        The item is returned exactly as stored (raw JSON bytes, not decoded),
        or None when the list is empty or missing.
        """
        return self.redis_client.rpop(session_id)

    def delete_messages(self, session_id: str) -> Optional[List[ChatMessage]]:
        """Delete a session's history from both Redis and MongoDB.

        Returns:
            Always None.
        """
        self.redis_client.delete(session_id)
        # Index by name: attribute access (db.session_id) would target a
        # collection literally called "session_id".
        self._db[session_id].drop()
        return None

    @staticmethod
    def _is_disposable(message: dict) -> bool:
        """True for "tool" items and "assistant" items whose content is null."""
        role = message.get("role")
        return role == "tool" or (role == "assistant" and message.get("content") is None)

    def clean_message(self, session_id: str) -> Optional[ChatMessage]:
        """Remove "tool" items and null-content "assistant" items from Redis.

        Presumably these are tool-call turns that should not be replayed to the
        model (confirm with the caller). Every other item keeps its relative
        order.

        Returns:
            Always None.
        """
        raw_items = self.redis_client.lrange(session_id, 0, -1)
        to_remove = [raw for raw in raw_items if self._is_disposable(json.loads(raw))]
        if not to_remove:
            return None

        # One round-trip instead of one per item. LREM with count=1 removes one
        # equal-valued item per queued call; every equal copy matches the same
        # predicate, so the resulting list is the same as index-based removal.
        pipe = self.redis_client.pipeline()
        for raw in to_remove:
            pipe.lrem(session_id, 1, raw)
        pipe.execute()
        return None

    def get_keys(self) -> List[str]:
        """Return every Redis key as a str.

        On any error a 400 ``JSONResponse`` is returned instead of raising.

        NOTE(review): KEYS * is O(N) and blocks Redis while it runs; consider
        SCAN (``scan_iter``) if the keyspace is large.
        """
        try:
            return [key.decode("utf-8") for key in self.redis_client.keys("*")]
        except Exception:
            return JSONResponse(status_code=400, content="the error when get keys")

    def add_message(self, session_id: str, message: Optional[ChatMessage]) -> None:
        """Append ``message`` (serialized via ``_message_to_dict``) to the session's Redis list."""
        item = json.dumps(self._message_to_dict(message))
        self.redis_client.rpush(session_id, item)

    def _message_to_dict(self, message: Optional[ChatMessage]) -> dict:
        """Convert a ChatMessage into a JSON-serializable dict.

        Uses ``model_dump()``; a ``datetime`` ``timestamp`` is converted to an
        ISO-8601 string so ``json.dumps`` can encode it.
        """
        message_dict = message.model_dump()
        if isinstance(message_dict.get("timestamp"), datetime):
            message_dict["timestamp"] = message_dict["timestamp"].isoformat()
        return message_dict

    @staticmethod
    def _mongo_doc_to_fields(doc: dict) -> dict:
        """Drop ``_id``, ``timestamp`` and null-valued fields from a MongoDB document."""
        return {
            key: value
            for key, value in doc.items()
            if key not in ("_id", "timestamp") and value is not None
        }

    def add_chat_history_to_redis(self, session_id: str) -> None:
        """Copy a session's MongoDB history into its Redis list.

        Each document of the ``session_id`` collection is processed in cursor
        order (no sort is applied):

        1. ``_id``, ``timestamp`` and null fields are removed.
        2. The rest is validated through ``ChatMessage``.
        3. The result is serialized with ``_message_to_dict``.

        All items are built before Redis is touched. They are then appended,
        and the key gets a ``SESSION_TTL_SECONDS`` expiry, in one pipeline. As a
        result, a document that fails validation leaves Redis unchanged.

        Returns:
            None on success. On any error a 500 ``JSONResponse`` is returned
            instead of raising.
        """
        try:
            items = [
                json.dumps(self._message_to_dict(ChatMessage(**self._mongo_doc_to_fields(doc))))
                for doc in self._db[session_id].find()
            ]
            pipe = self.redis_client.pipeline()
            if items:  # RPUSH with no values is a Redis error.
                pipe.rpush(session_id, *items)
            pipe.expire(session_id, self.SESSION_TTL_SECONDS)
            pipe.execute()
        except Exception:
            return JSONResponse(status_code=500, content="Add Database Error")
        return None

    def get_all_messages_mongodb(self, session_id):
        """Return every document of the session's MongoDB collection.

        ``_id`` and null-valued fields are removed from each document. On any
        error a 500 ``JSONResponse`` carrying the exception text is returned
        instead of raising.

        NOTE(review): echoing the exception text to the client may leak
        internal details; consider logging it and returning a generic message.
        """
        try:
            documents = self._db[session_id].find()
            return [
                {key: value for key, value in doc.items() if key != "_id" and value is not None}
                for doc in documents
            ]
        except Exception as e:
            return JSONResponse(status_code=500, content=f"An error occurred while retrieving messages: {e}")