File size: 3,209 Bytes
042a946
 
 
 
 
04f0bde
042a946
 
 
 
0e573d0
042a946
 
 
 
 
 
 
 
 
 
 
0e573d0
042a946
04f0bde
 
042a946
 
0e573d0
 
042a946
0e573d0
 
042a946
 
04f0bde
042a946
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e573d0
042a946
0e573d0
042a946
 
0e573d0
 
042a946
0e573d0
042a946
 
 
 
 
 
 
 
0e573d0
042a946
 
0e573d0
 
04f0bde
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import json
try:
    from sqlalchemy.orm import declarative_base
except ImportError:
    from sqlalchemy.ext.declarative import declarative_base
from langchain.schema import BaseChatMessageHistory
from datetime import datetime
from sqlalchemy import Column, Text, orm, create_engine
from .schemas import create_message_model, create_session_table


def get_sessions(engine, model_class, user_id):
    """Return the rows of ``model_class`` matching ``user_id``, newest first.

    Rows are converted to plain dicts (mapped attribute name -> value) while
    the database session is still open, so the result stays usable after the
    session is closed.

    Args:
        engine: SQLAlchemy engine to open a session on.
        model_class: Mapped ORM class to query; it must expose ``session_id``
            and ``create_by`` attributes.
        user_id: Value compared against ``model_class.session_id``.

    Returns:
        list[dict]: One dict per matching row, ordered by ``create_by``
        descending.
    """
    # Attribute keys of every mapped column, used to turn each ORM instance
    # into a plain dict.
    attr_names = [prop.key for prop in orm.class_mapper(model_class).column_attrs]
    with orm.sessionmaker(engine)() as session:
        # NOTE(review): this filters on ``session_id`` although the parameter is
        # named ``user_id`` (SessionManager.list_sessions filters on a
        # ``user_id`` column instead) -- confirm which column is intended.
        query = (
            session.query(model_class)
            .filter(model_class.session_id == user_id)
            .order_by(model_class.create_by.desc())
        )
        # Iterate the lazy Query inside the session. Passing the Query object
        # to json.loads() can never work: it is not a str/bytes payload.
        return [{name: getattr(row, name) for name in attr_names} for row in query]


class SessionManager:
    """Store chat sessions and their system prompts in a ClickHouse database.

    On construction, two ORM models are built (a session table and a
    chat-message table). Both tables are created on the server via
    ``metadata.create_all``.
    """

    def __init__(self, session_state, host, port, username, password,
                 db='chat', sess_table='sessions', msg_table='chat_memory') -> None:
        """Connect to ClickHouse and make sure both tables exist.

        Args:
            session_state: UI state container. ``remove_session`` checks it
                for ``"agent"`` / ``"file_analyzer"`` by key and then reads
                them by attribute (presumably Streamlit's
                ``st.session_state`` -- confirm with the caller).
            host: ClickHouse server host.
            port: ClickHouse server port.
            username: ClickHouse user name.
            password: ClickHouse password.
            db: Database name.
            sess_table: Name of the session table.
            msg_table: Name of the chat-message table.
        """
        # NOTE(review): credentials are interpolated into the URL unescaped, so
        # a password containing '@', ':' or '/' yields an invalid URL.
        conn_str = f'clickhouse://{username}:{password}@{host}:{port}/{db}?protocol=https'
        self.engine = create_engine(conn_str, echo=False)
        self.sess_model_class = create_session_table(
            sess_table, declarative_base())
        self.sess_model_class.metadata.create_all(self.engine)
        self.msg_model_class = create_message_model(
            msg_table, declarative_base())
        self.msg_model_class.metadata.create_all(self.engine)
        self.Session = orm.sessionmaker(self.engine)
        self.session_state = session_state

    def list_sessions(self, user_id):
        """List the sessions owned by ``user_id``, newest first.

        Returns:
            list[dict]: One dict per session, with keys ``session_id`` and
            ``system_prompt``. ``session_id`` is the part of the stored id
            after its last ``"?"`` (the whole id if it contains none).
        """
        with self.Session() as session:
            result = (
                session.query(self.sess_model_class)
                .filter(self.sess_model_class.user_id == user_id)
                .order_by(self.sess_model_class.create_by.desc())
            )
            return [
                {
                    "session_id": r.session_id.split("?")[-1],
                    "system_prompt": r.system_prompt,
                }
                for r in result
            ]

    def modify_system_prompt(self, session_id, sys_prompt):
        """Replace the system prompt of the session whose id is ``session_id``."""
        with self.Session() as session:
            # Bulk UPDATE filtered on the session_id *column*. Session itself has
            # no ``update`` method, and comparing the model class to a string
            # does not build a WHERE clause.
            session.query(self.sess_model_class).filter(
                self.sess_model_class.session_id == session_id
            ).update({"system_prompt": sys_prompt})
            session.commit()

    def add_session(self, user_id, session_id, system_prompt, **kwargs):
        """Insert a new session row.

        Any extra keyword arguments are stored JSON-encoded in the
        ``additionals`` column. ``create_by`` is set to ``datetime.now()``.
        """
        with self.Session() as session:
            elem = self.sess_model_class(
                user_id=user_id, session_id=session_id, system_prompt=system_prompt,
                create_by=datetime.now(), additionals=json.dumps(kwargs)
            )
            session.add(elem)
            session.commit()

    def remove_session(self, session_id):
        """Delete the session row for ``session_id`` and reset in-memory state.

        Stored chat-message rows for the session are not deleted. If present
        in ``session_state``, the agent's chat memory and the file
        analyzer's files are cleared.
        """
        with self.Session() as session:
            session.query(self.sess_model_class).filter(
                self.sess_model_class.session_id == session_id).delete()
            # Deleting the session's messages is intentionally disabled:
            # session.query(self.msg_model_class).where(self.msg_model_class.session_id==session_id).delete()
            # Commit explicitly, like the other write methods. On a
            # transactional backend the DELETE would otherwise be rolled back
            # when the session closes.
            session.commit()
        if "agent" in self.session_state:
            self.session_state.agent.memory.chat_memory.clear()
        if "file_analyzer" in self.session_state:
            self.session_state.file_analyzer.clear_files()