File size: 3,628 Bytes
67be838
 
 
 
 
 
01d0c35
67be838
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
01d0c35
67be838
 
 
 
 
 
 
 
 
 
 
 
01d0c35
 
 
 
 
67be838
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import pickle
import tempfile
import time

import requests
import streamlit as st

st.set_page_config(page_title="Psychedelics GPT")

# Centered, bold page heading; it is raw HTML/CSS, hence unsafe_allow_html=True.
_TITLE_HTML = """
    <style>
        .title {
            text-align: center;
            font-size: 2em;
            font-weight: bold;
        }
    </style>
    <div class="title">  Psychedelics Chatbot </div>
    """
st.markdown(_TITLE_HTML, unsafe_allow_html=True)

# Load and Save Conversations
# Pickle file (relative to the working directory) that stores all saved chats.
conversations_file = "conversations.pkl"


def load_conversations(path=None):
    """Return the saved conversations, or ``[]`` when nothing usable is stored.

    Args:
        path: Pickle file to read; defaults to the module-level
            ``conversations_file``.

    Returns:
        The object pickled by ``save_conversations`` (the app stores a list of
        conversations), or an empty list when the file is missing, empty, or
        not a valid pickle (``pickle.UnpicklingError``).
    """
    # Deliberately not cached with @st.cache_data: the file is rewritten after
    # every answer, and a cached first snapshot would hand stale history to
    # later sessions, whose next save would then overwrite newer chats.
    target = conversations_file if path is None else path
    try:
        with open(target, "rb") as f:
            # SECURITY: unpickling can execute code; only load files this app wrote.
            return pickle.load(f)
    except (FileNotFoundError, EOFError, pickle.UnpicklingError):
        return []


def save_conversations(conversations, path=None):
    """Write *conversations* to the pickle file without ever leaving it half-written.

    The data is pickled into a temporary file in the target's directory and
    then moved over the target with ``os.replace``. If pickling or the move
    fails, the temporary file is removed, the previous target file is left
    untouched, and the exception propagates.

    Args:
        conversations: Object to persist (the app passes its list of conversations).
        path: Destination file; defaults to the module-level ``conversations_file``.
    """
    target = conversations_file if path is None else path
    # Keep the temp file next to the target so os.replace is a same-filesystem
    # rename (atomic on POSIX) rather than a copy.
    directory = os.path.dirname(os.path.abspath(target))
    fd, temp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(conversations, f)
        os.replace(temp_path, target)
    except BaseException:
        # Best-effort cleanup of the partial temp file, then re-raise.
        try:
            os.remove(temp_path)
        except OSError:
            pass
        raise


# Per-browser-session state, created once: the saved chats from disk and the
# conversation currently shown (seeded with an assistant greeting).
if "conversations" not in st.session_state:
    st.session_state["conversations"] = load_conversations()

if "current_conversation" not in st.session_state:
    _opening_message = {"role": "assistant", "content": "How may I assist you today?"}
    st.session_state["current_conversation"] = [_opening_message]


def truncate_string(s, length=30):
    """Shorten *s* for display, marking any cut with an ellipsis.

    Strings of at most *length* characters are returned unchanged. Longer ones
    keep their first *length* characters, minus trailing whitespace, followed
    by ``"..."``.
    """
    if len(s) <= length:
        return s
    head = s[:length].rstrip()
    return f"{head}..."


def display_chats_sidebar():
    """Render the sidebar: chat-management buttons plus one button per saved chat.

    "Start New Chat" makes a fresh empty conversation current and registers it
    in ``st.session_state.conversations``. "Clear All Chats" empties the
    in-memory state and also writes the empty list to the pickle file. Every
    non-empty conversation gets a button labelled with its first user message
    (truncated, or "New Chat"); clicking it makes that conversation current.
    """
    with st.sidebar.container():
        st.header('Settings')
        col1, col2 = st.columns([1, 1])

        with col1:
            if col1.button('Start New Chat', key="new_chat"):
                st.session_state.current_conversation = []
                st.session_state.conversations.append(st.session_state.current_conversation)

        with col2:
            if col2.button('Clear All Chats', key="clear_all"):
                st.session_state.conversations = []
                st.session_state.current_conversation = []
                # Persist the reset; otherwise the old chats would be reloaded
                # from the pickle file by the next session.
                save_conversations(st.session_state.conversations)

    with st.sidebar.container():
        st.header('Conversations')
        for idx, conversation in enumerate(st.session_state.conversations):
            if not conversation:
                continue  # newly started chats have nothing to show yet
            chat_title_raw = next(
                (msg["content"] for msg in conversation if msg["role"] == "user"), "New Chat")
            # Buttons are placed inside this container, under its header.
            if st.button(truncate_string(chat_title_raw), key=f"chat_button_{idx}"):
                st.session_state.current_conversation = conversation


def main_app():
    """Render the current conversation and answer a newly submitted prompt.

    Replays every message of ``st.session_state.current_conversation``. When
    the user submits text through the chat input, the prompt is sent to the
    generation service. The answer and its numbered references are shown, both
    turns are appended to the conversation, and all conversations are saved.
    Service or response-format errors are shown in the chat instead of
    crashing the script.
    """
    for message in st.session_state.current_conversation:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    def generate_response(prompt_input):
        """POST *prompt_input* to the generation service and return the decoded JSON body.

        Raises ``requests.RequestException`` on connection errors, timeouts or
        non-2xx statuses, and ``ValueError`` if the body is not valid JSON.
        """
        payload = {
            "user_prompt": prompt_input,
            # NOTE(review): earlier turns are not forwarded; the service sees only this prompt.
            "chat_history": [],
        }
        response = requests.post(
            'http://3.223.163.181:8090/generate',
            json=payload,
            # (connect, read) seconds; without a timeout a stalled server blocks the run forever.
            timeout=(10, 300),
        )
        response.raise_for_status()
        return response.json()

    if prompt := st.chat_input('Send a Message'):
        conversation = st.session_state.current_conversation
        # The start-up conversation (and the one created by "Clear All Chats")
        # is not in the saved list yet; register it so this exchange is
        # persisted and shows up in the sidebar.
        if not any(conv is conversation for conv in st.session_state.conversations):
            st.session_state.conversations.append(conversation)

        conversation.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)

        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                try:
                    response = generate_response(prompt)
                    answer = response['response']
                    sources = response.get('sources') or []
                except (requests.RequestException, ValueError, KeyError, TypeError) as err:
                    st.error(f"Could not get an answer from the server: {err}")
                    return
                st.markdown(answer)
                sources_str = 'References:\n' + '\n'.join(
                    f'{idx}. {source}' for idx, source in enumerate(sources, start=1))
                st.markdown(sources_str)
                # Blank line keeps the references a separate block when replayed,
                # matching the two separate st.markdown calls above.
                conversation.append(
                    {"role": "assistant", "content": answer + "\n\n" + sources_str})
                save_conversations(st.session_state.conversations)


# Render order matters: the sidebar may switch the current conversation, and
# the main area must show whichever conversation is current in this same run.
for _render in (display_chats_sidebar, main_app):
    _render()