File size: 6,537 Bytes
94d6273
 
 
 
 
 
 
 
 
b6bc1e1
94d6273
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import streamlit as st
from ragchatbot import RAGChatBot
from pydantic_models import RequestModel, ChatHistoryItem


def validate_chat_history_item(chat_history_item: ChatHistoryItem):
    """Return a freshly validated copy of *chat_history_item*.

    The item is serialised to a plain dict with ``model_dump`` and then
    rebuilt through ``ChatHistoryItem.model_validate``. As a result, the
    model's validators run again on its current field values.
    """
    plain_fields = chat_history_item.model_dump()
    revalidated = ChatHistoryItem.model_validate(plain_fields)
    return revalidated

st.set_page_config(page_title="RAG-Chatbot", page_icon=":mag:", layout="wide")
st.title("Test Contextual Retrieval - KCS10")

# One header column per retrieval variant, left to right.
col1, col2, col3 = st.columns(3)
for header_col, header_text in (
    (col1, "Contextual Chunking"),
    (col2, "Current Model"),
    (col3, "Formatted Text"),
):
    header_col.title(header_text)

# Create each chatbot only if it is not already stored in st.session_state,
# so later script reruns reuse the existing instance. Creation order is kept
# as: context, formatted, just.
for state_key, store_path in (
    ("context_ragchatbot", "context_vectorstore"),
    ("formatted_ragchatbot", "formatted_vectorstore"),
    ("just_ragchatbot", "just_vectorstore"),
):
    if state_key not in st.session_state:
        st.session_state[state_key] = RAGChatBot(vectorstore_path=store_path)

# Each bot keeps its own list of per-turn dicts, initialised empty once.
for state_key in ("context_chat_history", "formatted_chat_history", "just_chat_history"):
    if state_key not in st.session_state:
        st.session_state[state_key] = []
def _extract_chunk_content(page_content):
    """Return the text between ``<chunk_content>`` and the next ``</chunk_content>``.

    For well-formed input this matches
    ``text.split("<chunk_content>")[1].split("</chunk_content>")[0]``.
    When the opening tag is absent, the whole text is returned instead of
    raising ``IndexError``.
    """
    parts = page_content.split("<chunk_content>")
    if len(parts) < 2:
        return page_content
    return parts[1].split("</chunk_content>")[0]


# (session_state key, column label) in on-screen column order.
_HISTORY_COLUMNS = (
    ("context_chat_history", "Contextual Chunking"),
    ("just_chat_history", "Current Model"),
    ("formatted_chat_history", "Formatted Text"),
)
_histories = [st.session_state[key] for key, _ in _HISTORY_COLUMNS]
_labels = [label for _, label in _HISTORY_COLUMNS]

# Turn i of every column must be the same user query. Check once, up front,
# with an explicit error (an `assert` would be stripped under `python -O`).
if len({len(history) for history in _histories}) > 1:
    raise RuntimeError(
        "Chat histories are out of sync: "
        + ", ".join(
            f"{key}={len(history)}"
            for (key, _), history in zip(_HISTORY_COLUMNS, _histories)
        )
    )

for turn in zip(*_histories):
    for col, chat, sources_text in zip(st.columns(3, vertical_alignment="top"), turn, _labels):
        with col.chat_message("user"):
            st.write(chat.get("user_message").replace("\n", "\n\n"))
        with col.chat_message("assistant"):
            st.write(chat.get("assistant_message").replace("\n", "\n\n"))
            st.write(chat.get("search_phrase"))
            for i, doc in enumerate(chat.get("sources_documents"), start=1):
                with st.expander(f"{sources_text} Sources - {i}"):
                    st.subheader(f"{doc.get('heading')} - {doc.get('relevance_score')}")
                    page_content = doc.get("page_content").replace("\n", "\n\n")
                    if sources_text == "Contextual Chunking":
                        # Show only the original chunk inside the
                        # <chunk_content> tags, not the surrounding text.
                        page_content = _extract_chunk_content(page_content)
                    st.write(page_content)

def _ask_bot(bot, chat_history, user_question):
    """Send *user_question* to *bot*, replaying *chat_history* as prior turns.

    *chat_history* is a list of the per-turn dicts stored in
    ``st.session_state``. Only their ``user_message`` and
    ``assistant_message`` are forwarded, as ``ChatHistoryItem`` objects.
    """
    request = RequestModel(
        user_question=user_question,
        chat_history=[
            ChatHistoryItem(
                user_message=chat.get("user_message"),
                assistant_message=chat.get("assistant_message"),
            )
            for chat in chat_history
        ],
    )
    return bot.get_response(request)


def _to_history_entry(user_message, response):
    """Convert a bot *response* into the dict stored in a chat history.

    Reads ``answer``, ``search_phrase`` and ``sources_documents``
    (``heading`` / ``page_content`` / ``relevance_score``) from *response*.
    """
    return {
        "user_message": user_message,
        "assistant_message": response.answer,
        "search_phrase": response.search_phrase,
        "sources_documents": [
            {
                "heading": doc.heading,
                "page_content": doc.page_content,
                "relevance_score": doc.relevance_score,
            }
            for doc in response.sources_documents
        ],
    }


# (chatbot key, history key) pairs in st.session_state, in query order.
_BOT_HISTORY_KEYS = (
    ("context_ragchatbot", "context_chat_history"),
    ("just_ragchatbot", "just_chat_history"),
    ("formatted_ragchatbot", "formatted_chat_history"),
)

if user_query := st.chat_input("Enter your query"):
    # Echo the query immediately in all three columns while responses load.
    for col in st.columns(3, vertical_alignment="top"):
        with col.chat_message("user"):
            st.write(user_query.replace("\n", "\n\n"))
    with st.spinner("Generating response..."):
        # Query every bot before touching any history. If one call raises,
        # no history is appended to, so all three stay the same length.
        # The history-rendering code relies on that to pair turns.
        new_entries = []
        for bot_key, history_key in _BOT_HISTORY_KEYS:
            response = _ask_bot(
                st.session_state[bot_key],
                st.session_state[history_key],
                user_query,
            )
            new_entries.append((history_key, _to_history_entry(user_query, response)))
        for history_key, entry in new_entries:
            st.session_state[history_key].append(entry)

        # Rerun so the script redraws the full history, including this turn.
        st.rerun()