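# PDF Chatbot: a Streamlit app that builds a RAG pipeline over uploaded PDFs.
# Documents are chunked, embedded into a FAISS vector store, and queried through
# a ConversationalRetrievalChain backed by a Hugging Face hosted LLM.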
import streamlit as st
import os
import tempfile
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

api_token = os.getenv("HF_TOKEN")
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

def load_doc(uploaded_files):
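    """Write the uploaded PDFs to temp files, load them page by page, and
    split the pages into overlapping chunks ready for embedding."""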
    try:
        temp_files = []
        for uploaded_file in uploaded_files:
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
            temp_file.write(uploaded_file.read())
            temp_file.close()
            temp_files.append(temp_file.name)
        
        loaders = [PyPDFLoader(x) for x in temp_files]
        pages = []
        for loader in loaders:
            pages.extend(loader.load())
        
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
        doc_splits = text_splitter.split_documents(pages)

        for temp_file in temp_files:
            os.remove(temp_file)  # Clean up temporary files

        return doc_splits
    except Exception as e:
        st.error(f"Error loading document: {e}")
        return []

def create_db(splits):
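    """Embed the document chunks and index them in an in-memory FAISS store."""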
    try:
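        # No model_name is passed, so HuggingFaceEmbeddings falls back to its
        # default sentence-transformers model.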
        embeddings = HuggingFaceEmbeddings()
        vectordb = FAISS.from_documents(splits, embeddings)
        return vectordb
    except Exception as e:
        st.error(f"Error creating vector database: {e}")
        return None

def initialize_llmchain(llm_model, vector_db):
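    """Build a ConversationalRetrievalChain around a Hugging Face endpoint LLM,
    with buffer memory so follow-up questions keep their context."""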
    try:
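        # temperature, max_new_tokens, and top_k are generation parameters
        # forwarded to the hosted text-generation endpoint.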
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            huggingfacehub_api_token=api_token,
            temperature=0.5,
            max_new_tokens=4096,
            top_k=3,
        )
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key='answer',
            return_messages=True
        )

        retriever = vector_db.as_retriever()
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm,
            retriever=retriever,
            chain_type="stuff",
            memory=memory,
            return_source_documents=True,
            verbose=False,
        )
        return qa_chain
    except Exception as e:
        st.error(f"Error initializing LLM chain: {e}")
        return None

def initialize_database(uploaded_files):
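    """Load, split, and index the uploaded files; return (vector_db, status message)."""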
    try:
        doc_splits = load_doc(uploaded_files)
        if not doc_splits:
            return None, "Failed to load documents."
        vector_db = create_db(doc_splits)
        if vector_db is None:
            return None, "Failed to create vector database."
        return vector_db, "Database created!"
    except Exception as e:
        st.error(f"Error initializing database: {e}")
        return None, "Failed to initialize database."

def initialize_LLM(llm_option, vector_db):
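    """Resolve the selected model name and set up the QA chain; return (qa_chain, status message)."""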
    try:
        llm_name = list_llm[llm_option]
        qa_chain = initialize_llmchain(llm_name, vector_db)
        if qa_chain is None:
            return None, "Failed to initialize QA chain."
        return qa_chain, "QA chain initialized. Chatbot is ready!"
    except Exception as e:
        st.error(f"Error initializing LLM: {e}")
        return None, "Failed to initialize LLM."

def format_chat_history(chat_history):
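    """Flatten (user, assistant) tuples into plain-text conversation turns."""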
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}\nAssistant: {bot_message}\n")
    return formatted_chat_history

def conversation(qa_chain, message, history):
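    """Run one question through the QA chain and collect the answer, the
    updated history, and the source passages that were retrieved."""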
    try:
        formatted_chat_history = format_chat_history(history)
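        # The chain carries its own ConversationBufferMemory; at invoke time the
        # memory's chat_history takes precedence over the value passed here.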
        response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
        response_answer = response["answer"]
        response_sources = response["source_documents"]

        sources = []
        for doc in response_sources:
            sources.append({
                "content": doc.page_content.strip(),
                "page": doc.metadata.get("page", 0) + 1,  # PyPDFLoader pages are 0-indexed
            })

        new_history = history + [(message, response_answer)]
        return qa_chain, new_history, response_answer, sources
    except Exception as e:
        st.error(f"Error in conversation: {e}")
        return qa_chain, history, "", []

def main():
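    """Streamlit entry point: the sidebar handles setup (upload, indexing,
    model choice); the main pane hosts the chat interface."""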
    st.sidebar.title("PDF Chatbot")

    st.sidebar.markdown("### Step 1 - Upload PDF documents and initialize the RAG pipeline")
    uploaded_files = st.sidebar.file_uploader("Upload PDF documents", type="pdf", accept_multiple_files=True)

    if uploaded_files:
        if st.sidebar.button("Create vector database"):
            with st.spinner("Creating vector database..."):
                vector_db, db_message = initialize_database(uploaded_files)
                st.session_state['vector_db'] = vector_db
                if vector_db is not None:
                    st.sidebar.success(db_message)
                else:
                    st.sidebar.error(db_message)

    if 'vector_db' not in st.session_state:
        st.session_state['vector_db'] = None

    if 'qa_chain' not in st.session_state:
        st.session_state['qa_chain'] = None

    if 'chat_history' not in st.session_state:
        st.session_state['chat_history'] = []

    st.sidebar.markdown("### Step 2 - Select Large Language Model (LLM)")
    llm_option = st.sidebar.radio("Available LLMs", list_llm_simple)

    if st.sidebar.button("Initialize Question Answering Chatbot"):
        if st.session_state['vector_db'] is None:
            st.sidebar.error("Please create the vector database first.")
        else:
            with st.spinner("Initializing QA chatbot..."):
                qa_chain, llm_message = initialize_LLM(list_llm_simple.index(llm_option), st.session_state['vector_db'])
                st.session_state['qa_chain'] = qa_chain
                if qa_chain is not None:
                    st.sidebar.success(llm_message)
                else:
                    st.sidebar.error(llm_message)

    st.title("Chat with your Document")

    if st.session_state['qa_chain']:
        st.markdown("### Chatbot Response")

        # Display the chat history in a chat-like interface
        for user_msg, bot_msg in st.session_state['chat_history']:
            st.markdown(f"**User:** {user_msg}")
            st.markdown(f"**Assistant:** {bot_msg}")

        with st.expander("Relevant context from the source document"):
            if 'sources' in st.session_state:
                for i, source in enumerate(st.session_state['sources']):
                    st.text_area(f"Source {i + 1} - Page {source['page']}", value=source["content"], height=100)

        with st.form(key="question_form"):
            message = st.text_input("Ask a question", key="message")
            submit_button = st.form_submit_button(label="Submit")

        if submit_button and message.strip():
            with st.spinner("Generating response..."):
                qa_chain, chat_history, response_answer, sources = conversation(st.session_state['qa_chain'], message, st.session_state['chat_history'])
                st.session_state['qa_chain'] = qa_chain
                st.session_state['chat_history'] = chat_history
                st.session_state['sources'] = sources

if __name__ == "__main__":
    main()
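
# Local usage (the file name is illustrative):
#   export HF_TOKEN=<your Hugging Face API token>
#   streamlit run app.py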