File size: 3,480 Bytes
e931b70
 
 
a796108
 
e931b70
 
 
 
 
 
 
 
 
 
fab8405
 
e931b70
fab8405
e931b70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4853cf
a796108
e931b70
 
 
 
 
 
 
 
45180a0
0e573d0
e931b70
 
 
 
 
 
 
 
 
 
 
 
9061790
 
e931b70
 
 
 
 
 
 
 
 
 
9061790
e931b70
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import os
import time

import streamlit as st

from backend.constants.streamlit_keys import DATA_INITIALIZE_NOT_STATED, DATA_INITIALIZE_COMPLETED, \
    DATA_INITIALIZE_STARTED
from backend.constants.variables import DATA_INITIALIZE_STATUS, JUMP_QUERY_ASK, CHAINS_RETRIEVERS_MAPPING, \
    TABLE_EMBEDDINGS_MAPPING, RETRIEVER_TOOLS, USER_NAME, GLOBAL_CONFIG, update_global_config
from backend.construct.build_all import build_chains_and_retrievers, load_embedding_models, update_retriever_tools
from backend.types.global_config import GlobalConfig
from logger import logger
from ui.chat_page import chat_page
from ui.home import render_home
from ui.retrievers import render_retrievers


# warnings.filterwarnings("ignore", category=UserWarning)

def prepare_environment():
    """Export runtime settings to environment variables and publish the global config.

    Values come from Streamlit's ``st.secrets``. Missing required entries are not
    caught here: the lookup error propagates to the caller. ``MYSCALE_ENABLE_HTTPS``
    is optional and defaults to ``True``.
    """
    secrets = st.secrets

    # Fixed library switches that do not depend on any secret.
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    os.environ["LANGCHAIN_TRACING_V2"] = "false"
    # os.environ["LANGCHAIN_API_KEY"] = ""

    # These secrets are mirrored into environment variables with the same names.
    for secret_name in ("OPENAI_API_BASE", "OPENAI_API_KEY", "AUTH0_CLIENT_ID", "AUTH0_DOMAIN"):
        os.environ[secret_name] = secrets[secret_name]

    config = GlobalConfig(
        openai_api_base=secrets['OPENAI_API_BASE'],
        openai_api_key=secrets['OPENAI_API_KEY'],
        auth0_client_id=secrets['AUTH0_CLIENT_ID'],
        auth0_domain=secrets['AUTH0_DOMAIN'],
        myscale_user=secrets['MYSCALE_USER'],
        myscale_password=secrets['MYSCALE_PASSWORD'],
        myscale_host=secrets['MYSCALE_HOST'],
        myscale_port=secrets['MYSCALE_PORT'],
        query_model="gpt-3.5-turbo-0125",
        chat_model="gpt-3.5-turbo-0125",
        untrusted_api=secrets['UNSTRUCTURED_API'],
        myscale_enable_https=secrets.get('MYSCALE_ENABLE_HTTPS', True),
    )
    update_global_config(config)


# Refreshing the browser clears every session-state key, so defaults must be re-seeded.
def initialize_session_state():
    """Seed default values for session-state keys that are not present yet.

    Keys that already exist keep their current value, so calling this repeatedly
    only fills in what is missing. Each newly seeded key is logged.
    """
    defaults = (
        (DATA_INITIALIZE_STATUS, DATA_INITIALIZE_NOT_STATED),
        (JUMP_QUERY_ASK, False),
    )
    for state_key, default_value in defaults:
        if state_key in st.session_state:
            continue
        st.session_state[state_key] = default_value
        logger.info(f"Initialize session state key: {state_key}")


def initialize_chat_data():
    """Build embedding models, chains/retrievers and retriever tools into session state.

    Does nothing once the session's ``DATA_INITIALIZE_STATUS`` equals
    ``DATA_INITIALIZE_COMPLETED``.

    While building, the status is ``DATA_INITIALIZE_STARTED``. If any builder raises,
    the exception propagates and the status stays at ``DATA_INITIALIZE_STARTED``.
    Only ``COMPLETED`` short-circuits, so the next call retries the whole build.
    """
    if st.session_state[DATA_INITIALIZE_STATUS] == DATA_INITIALIZE_COMPLETED:
        return

    # perf_counter is monotonic and high-resolution. time.time() can jump when
    # the system clock is adjusted, which would corrupt the measured duration.
    start_time = time.perf_counter()
    st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_STARTED
    # NOTE(review): keep this order; later builders may read mappings stored by
    # earlier ones (not visible from this file) — confirm before reordering.
    st.session_state[TABLE_EMBEDDINGS_MAPPING] = load_embedding_models()
    st.session_state[CHAINS_RETRIEVERS_MAPPING] = build_chains_and_retrievers()
    st.session_state[RETRIEVER_TOOLS] = update_retriever_tools()
    # Mark data initialization finished.
    st.session_state[DATA_INITIALIZE_STATUS] = DATA_INITIALIZE_COMPLETED
    elapsed = time.perf_counter() - start_time
    logger.info(f"ChatData initialized finished in {round(elapsed, 3)} seconds, "
                f"session state keys: {list(st.session_state.keys())}")


st.set_page_config(
    page_title="ChatData",
    page_icon="https://myscale.com/favicon.ico",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Prepare environment/config, session defaults and chat data before rendering any page.
prepare_environment()
initialize_session_state()
initialize_chat_data()

# Page routing:
# - USER_NAME in session state  -> chat page
# - JUMP_QUERY_ASK flag set     -> retrievers page
# - otherwise                   -> home page
if USER_NAME in st.session_state:
    chat_page()
elif st.session_state[JUMP_QUERY_ASK]:
    render_retrievers()
else:
    render_home()