File size: 3,448 Bytes
0743bb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import redis
import os
import json
from fastapi import HTTPException
from uuid import uuid4
from typing import Optional, List
from llama_index.storage.chat_store.redis import RedisChatStore
from llama_index.core.memory import ChatMemoryBuffer
from service.dto import ChatMessage


class ChatStore:
    """Redis-backed chat history storage, one Redis list per session.

    Messages are stored as JSON strings appended to the tail of a Redis list
    whose key is the session id (see ``add_message`` / ``get_messages``).
    """

    # Connection defaults used when no client is injected.
    DEFAULT_HOST = "redis-10365.c244.us-east-1-2.ec2.redns.redis-cloud.com"
    DEFAULT_PORT = 10365

    def __init__(
        self,
        *,
        redis_client: Optional["redis.Redis"] = None,
        host: str = DEFAULT_HOST,
        port: int = DEFAULT_PORT,
        password: Optional[str] = None,
    ):
        """Create the store.

        Args:
            redis_client: Existing client to use. When omitted, a new
                ``redis.Redis`` client is built from ``host``/``port``/``password``.
            host: Redis host, used only when ``redis_client`` is not given.
            port: Redis port, used only when ``redis_client`` is not given.
            password: Redis password; when ``None`` it is read from the
                ``REDIS_PASSWORD`` environment variable.
        """
        if redis_client is None:
            if password is None:
                password = os.environ.get("REDIS_PASSWORD")
            redis_client = redis.Redis(host=host, port=port, password=password)
        self.redis_client = redis_client

    @staticmethod
    def generate_uuid(use_hex: bool = False) -> str:
        """Return a new random UUID4 string.

        Args:
            use_hex: If True, return the 32-character hex form (no dashes);
                otherwise the canonical 36-character dashed form.
        """
        # staticmethod: no instance state is needed, and it ensures that
        # ``self.generate_uuid()`` does not bind the instance to ``use_hex``.
        value = uuid4()
        return value.hex if use_hex else str(value)

    def initialize_memory_bot(self, session_id: Optional[str] = None, token_limit: int = 3000):
        """Create a token-limited chat memory stored through this Redis client.

        Args:
            session_id: Key under which the conversation is stored. A new
                UUID is generated when omitted (readable afterwards via the
                returned buffer's ``chat_store_key``).
            token_limit: Token budget passed to ``ChatMemoryBuffer``.

        Returns:
            A ``ChatMemoryBuffer`` backed by a ``RedisChatStore`` and keyed by
            ``session_id``.
        """
        if session_id is None:
            session_id = self.generate_uuid()
        chat_store = RedisChatStore(redis_client=self.redis_client)
        return ChatMemoryBuffer.from_defaults(
            token_limit=token_limit, chat_store=chat_store, chat_store_key=session_id
        )

    def get_messages(self, session_id: str) -> List[dict]:
        """Return all messages for ``session_id`` as dicts, in insertion order.

        Returns an empty list when the key holds no messages.
        """
        # json.loads accepts bytes or str, so this works whether or not the
        # client was created with decode_responses=True.
        return [json.loads(raw) for raw in self.redis_client.lrange(session_id, 0, -1)]

    def delete_last_message(self, session_id: str) -> Optional[bytes]:
        """Pop and return the most recently added message for ``session_id``.

        Returns:
            The raw stored JSON payload exactly as returned by ``RPOP``
            (``bytes`` with the default client settings), or ``None`` when
            the list is empty.
        """
        # NOTE(review): unlike get_messages, the payload is not parsed; kept
        # raw so existing callers receive the same value as before.
        return self.redis_client.rpop(session_id)

    def delete_messages(self, key: str) -> None:
        """Delete the whole message list stored under ``key``."""
        self.redis_client.delete(key)
        return None

    def clean_message(self, session_id: str) -> None:
        """Remove tool-related entries from the history of ``session_id``.

        Deletes every stored message whose role is ``"tool"`` and every
        ``"assistant"`` message whose ``content`` is ``None``.
        """
        to_remove = set()
        for raw in self.redis_client.lrange(session_id, 0, -1):
            data = json.loads(raw)
            role = data.get("role")
            if role == "tool" or (role == "assistant" and data.get("content") is None):
                to_remove.add(raw)

        if not to_remove:
            return
        # LREM matches by value, not by index. Every copy of a matching payload
        # also matches, so count=0 (remove all occurrences) once per distinct
        # payload is equivalent; the pipeline batches the round trips.
        pipe = self.redis_client.pipeline()
        for raw in to_remove:
            pipe.lrem(session_id, 0, raw)
        pipe.execute()

    def get_keys(self) -> List[str]:
        """Return every key in the Redis database as strings.

        Raises:
            HTTPException: status 400 if the Redis call or key decoding fails.
        """
        try:
            # NOTE(review): KEYS is O(N) and blocks the Redis server; consider
            # SCAN (scan_iter) if the keyspace can grow large.
            keys = self.redis_client.keys("*")
            return [key.decode("utf-8") if isinstance(key, bytes) else key for key in keys]
        except Exception as e:
            # Log the error and raise HTTPException for FastAPI
            print(f"An error occurred in get keys: {e}")
            raise HTTPException(
                status_code=400, detail="the error when get keys"
            ) from e

    def add_message(self, session_id: str, message: "ChatMessage") -> None:
        """Append ``message`` (serialized as JSON) to the list for ``session_id``."""
        item = json.dumps(self._message_to_dict(message))
        self.redis_client.rpush(session_id, item)

    def _message_to_dict(self, message: "ChatMessage") -> dict:
        """Convert a ``ChatMessage`` to a plain dict via ``model_dump()``."""
        return message.model_dump()