import json
from datetime import datetime

from sqlalchemy import create_engine, orm

try:
    # SQLAlchemy >= 1.4 exposes declarative_base from sqlalchemy.orm.
    from sqlalchemy.orm import declarative_base
except ImportError:
    # Older SQLAlchemy releases only provide it under sqlalchemy.ext.declarative.
    from sqlalchemy.ext.declarative import declarative_base

from backend.chat_bot.tools import create_session_table, create_message_history_table
from backend.constants.variables import GLOBAL_CONFIG
from logger import logger


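def get_sessions(engine, model_class, user_id):
    """Return the sessions owned by ``user_id``, newest first."""
    with orm.sessionmaker(engine)() as session:
        # Filter on user_id (as SessionManager.list_sessions does) and
        # materialize the rows with .all() while the session is still open.
        result = (
            session.query(model_class)
            .where(
                model_class.user_id == user_id
            )
            .order_by(model_class.create_by.desc())
            .all()
        )
    return result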
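

# Illustrative sketch, not part of the original module: `build_clickhouse_url`
# is a hypothetical helper showing the URL shape that SessionManager.__init__
# builds, clickhouse://user:password@host:port/db?protocol=http|https. It
# assumes the credentials may contain URL-reserved characters ("@", ":", "/")
# and percent-escapes them with urllib.parse.quote so they cannot break the URL.
from urllib.parse import quote


def build_clickhouse_url(username, password, host, port, db, use_https):
    """Hypothetical helper: return a clickhouse:// SQLAlchemy URL string."""
    protocol = 'https' if use_https else 'http'
    # safe='' makes quote() escape "/" as well; "@" and ":" are always escaped.
    user = quote(username, safe='')
    pwd = quote(password, safe='')
    return f'clickhouse://{user}:{pwd}@{host}:{port}/{db}?protocol={protocol}'


# Example: build_clickhouse_url('default', 'p@ss', 'localhost', 8123, 'chat', False)
# returns 'clickhouse://default:p%40ss@localhost:8123/chat?protocol=http'.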


class SessionManager:
    def __init__(
            self,
            session_state,
            host,
            port,
            username,
            password,
            db='chat',
            session_table='sessions',
            msg_table='chat_memory'
    ) -> None:
        # Connect over HTTPS only when it is enabled in the global config.
        protocol = 'https' if GLOBAL_CONFIG.myscale_enable_https else 'http'
        conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol={protocol}'
        self.engine = create_engine(conn_str, echo=False)
        # Give each model its own declarative base and create its table if it
        # does not exist yet (create_all skips tables that are already present).
        self.session_model_class = create_session_table(
            session_table, declarative_base())
        self.session_model_class.metadata.create_all(self.engine)
        self.msg_model_class = create_message_history_table(msg_table, declarative_base())
        self.msg_model_class.metadata.create_all(self.engine)
        self.session_orm = orm.sessionmaker(self.engine)
        self.session_state = session_state

    def list_sessions(self, user_id: str):
        with self.session_orm() as session:
            result = (
                session.query(self.session_model_class)
                .where(
                    self.session_model_class.user_id == user_id
                )
                .order_by(self.session_model_class.create_by.desc())
            )
            sessions = []
            for r in result:
                # Keep only the text after the last '?' in the stored session_id.
                sessions.append({
                    "session_id": r.session_id.split("?")[-1],
                    "system_prompt": r.system_prompt,
                })
            return sessions

    # Update the system prompt of the session with the given session_id.
    def modify_system_prompt(self, session_id, sys_prompt):
        with self.session_orm() as session:
            obj = session.query(self.session_model_class).where(
                self.session_model_class.session_id == session_id).first()
            if obj:
                obj.system_prompt = sys_prompt
                session.commit()
            else:
                logger.warning(f"Session {session_id} not found")

    # Add a session for user_id; extra keyword arguments are stored as JSON in `additionals`.
    def add_session(self, user_id: str, session_id: str, system_prompt: str, **kwargs):
        with self.session_orm() as session:
            elem = self.session_model_class(
                user_id=user_id, session_id=session_id, system_prompt=system_prompt,
                create_by=datetime.now(), additionals=json.dumps(kwargs)
            )
            session.add(elem)
            session.commit()

    # Remove a session and related chat history.
    def remove_session(self, session_id: str):
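        with self.session_orm() as session:
            # Remove the session itself.
            session.query(self.session_model_class).where(self.session_model_class.session_id == session_id).delete()
            # Remove the related chat history.
            session.query(self.msg_model_class).where(self.msg_model_class.session_id == session_id).delete()
            # Commit explicitly, as the other write methods do; otherwise SQLAlchemy
            # rolls back the open transaction when the session closes.
            session.commit()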
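

# Illustrative sketch, not part of the original module: the write methods above
# call session.commit() because SQLAlchemy rolls back an open transaction when
# the session closes. The stdlib sqlite3 module stands in for the ClickHouse
# engine here (an assumption made only for a self-contained demo); the
# hypothetical `count_after_delete` shows that a DELETE which is never
# committed is discarded when the connection is closed.
import os
import sqlite3
import tempfile


def count_after_delete(commit):
    """Hypothetical demo: delete one row, optionally commit, reopen, count rows."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'demo.db')
        conn = sqlite3.connect(path)
        conn.execute('CREATE TABLE sessions (session_id TEXT, system_prompt TEXT)')
        conn.execute("INSERT INTO sessions VALUES ('s1', 'You are helpful.')")
        conn.commit()
        conn.close()

        conn = sqlite3.connect(path)
        conn.execute("DELETE FROM sessions WHERE session_id = 's1'")
        if commit:
            conn.commit()
        conn.close()  # close() does not commit; pending changes are lost

        conn = sqlite3.connect(path)
        (count,) = conn.execute('SELECT COUNT(*) FROM sessions').fetchone()
        conn.close()
        return count


# Example: count_after_delete(commit=True) returns 0, while
# count_after_delete(commit=False) returns 1 because the DELETE was rolled back.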