File size: 2,959 Bytes
7d9087b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import streamlit as st
# from services.llm import process_answer
import time
import re

# Custom CSS for chat styling.
# Injected via st.markdown(..., unsafe_allow_html=True) in display_chat();
# defines a right-aligned blue bubble for the user and a left-aligned pink
# bubble for the assistant. The HTML below is emitted verbatim at runtime.
CHAT_CSS: str = """
    <style>
        .user-message {
            text-align: right;
            background-color: #3c8ce7;
            color: white;
            padding: 10px;
            border-radius: 10px;
            margin-bottom: 10px;
            display: inline-block;
            width: fit-content;
            max-width: 70%;
            margin-left: auto;
            box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
        }
        .assistant-message {
            text-align: left;
            background-color: #d16ba5;
            color: white;
            padding: 10px;
            border-radius: 10px;
            margin-bottom: 10px;
            display: inline-block;
            width: fit-content;
            max-width: 70%;
            margin-right: auto;
            box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
        }
    </style>
"""
# Compiled once at import time; DOTALL lets the reasoning span newlines.
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def extract_thoughts(response_text):
    """Split a model response into hidden reasoning and the visible answer.

    The original implementation extracted only the FIRST ``<think>`` section
    while stripping ALL of them from the answer, silently discarding any
    later reasoning; this version keeps every section.

    Args:
        response_text: Raw model output, possibly containing one or more
            ``<think>...</think>`` sections.

    Returns:
        tuple: ``(thinking_part, main_answer)`` where ``thinking_part`` is
        the stripped reasoning text (multiple sections joined by a blank
        line) or ``None`` when no section exists, and ``main_answer`` is the
        response with every ``<think>`` section removed, whitespace-stripped.
    """
    thoughts = [part.strip() for part in _THINK_RE.findall(response_text)]
    main_answer = _THINK_RE.sub("", response_text).strip()
    thinking_part = "\n\n".join(thoughts) if thoughts else None
    return thinking_part, main_answer

# Streamed response emulator
def response_generator(response):
    """Yield *response* one word at a time to simulate a streaming reply.

    Each yielded chunk is a word followed by a single space; a short pause
    after every chunk gives the typewriter effect in st.write_stream.
    """
    for token in response.split():
        yield f"{token} "
        time.sleep(0.05)

def display_chat(qa_chain, mode):
    """Render the chat history and handle one round of user interaction.

    Args:
        qa_chain: Chain whose ``invoke`` method produces the reply. When
            ``mode`` is truthy it is called with ``{"input": prompt}`` and
            is expected to return a dict with an ``'answer'`` key; when
            falsy it is called with ``{'context': prompt}`` and is expected
            to return the answer text directly.
        mode: Selects the invocation style above (treated by truthiness).

    Side effects:
        Reads/writes ``st.session_state.messages`` and renders Streamlit
        chat widgets.
    """
    st.markdown(CHAT_CSS, unsafe_allow_html=True)

    # Persist the conversation across Streamlit reruns.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay prior turns so the history survives each rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    if prompt := st.chat_input("Ask something..."):
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.markdown(prompt)

        # Normalize both chain styles to {'answer': str} BEFORE the
        # empty-response fallback. The previous code applied the fallback
        # first and then re-wrapped when mode was False, which yielded
        # {'answer': {'answer': ...}} and crashed extract_thoughts on
        # empty responses. Truthiness is used consistently (the old
        # `mode is False` check let mode=None slip through unwrapped).
        if mode:
            response = qa_chain.invoke({"input": prompt})
        else:
            response = {'answer': qa_chain.invoke({'context': prompt})}
        if not response or not response.get('answer'):
            response = {'answer': "I don't know."}

        # Separate hidden <think> reasoning from the visible answer.
        thinking_part, main_answer = extract_thoughts(response['answer'])

        # Display assistant response
        with st.chat_message("assistant"):
            if thinking_part:
                with st.expander("💭 Thought Process"):
                    st.markdown(thinking_part)  # Hidden by default, expandable

            # write_stream returns the full concatenated streamed text.
            streamed_text = st.write_stream(response_generator(main_answer))

        st.session_state.messages.append({"role": "assistant", "content": streamed_text})