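# PDF Chatbot: a Streamlit app that indexes uploaded PDF documents in a FAISS
# vector store and answers questions about them through a LangChain
# ConversationalRetrievalChain backed by a Hugging Face-hosted LLM.
# Requires a Hugging Face API token in the HF_TOKEN environment variable.
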
import streamlit as st
import os
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint  # Updated imports (langchain_community versions are deprecated)
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
import tempfile

api_token = os.getenv("HF_TOKEN")
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

def load_doc(uploaded_files):
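    """Write the uploaded PDFs to temporary files, load their pages, and
    split them into overlapping text chunks for retrieval."""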
    temp_files = []
    try:
        # Persist uploads to disk so PyPDFLoader can read them by path.
        for uploaded_file in uploaded_files:
            with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
                temp_file.write(uploaded_file.read())
                temp_files.append(temp_file.name)

        loaders = [PyPDFLoader(x) for x in temp_files]
        pages = []
        for loader in loaders:
            pages.extend(loader.load())

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
        return text_splitter.split_documents(pages)
    except Exception as e:
        st.error(f"Error loading documents: {e}")
        return []
    finally:
        for temp_file in temp_files:
            os.remove(temp_file)  # Clean up temporary files even if loading failed

def create_db(splits):
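    """Embed the chunks and index them in a FAISS vector store; the default
    HuggingFaceEmbeddings model is sentence-transformers/all-mpnet-base-v2."""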
    try:
        embeddings = HuggingFaceEmbeddings()
        vectordb = FAISS.from_documents(splits, embeddings)
        return vectordb
    except Exception as e:
        st.error(f"Error creating vector database: {e}")
        return None

def initialize_llmchain(llm_model, vector_db):
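    """Build a ConversationalRetrievalChain over the vector store, backed by
    the selected model on the Hugging Face Inference API."""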
    try:
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            huggingfacehub_api_token=api_token,
            temperature=0.5,
            max_new_tokens=4096,
            top_k=3,
        )
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key='answer',
            return_messages=True
        )

        retriever = vector_db.as_retriever()
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm,
            retriever=retriever,
            chain_type="stuff",
            memory=memory,
            return_source_documents=True,
            verbose=False,
        )
        return qa_chain
    except Exception as e:
        st.error(f"Error initializing LLM chain: {e}")
        return None

def initialize_database(uploaded_files):
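    """Load the uploaded PDFs and build the vector database; returns a
    (vector_db, status_message) tuple, with vector_db None on failure."""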
    try:
        doc_splits = load_doc(uploaded_files)
        if not doc_splits:
            return None, "Failed to load documents."
        vector_db = create_db(doc_splits)
        if vector_db is None:
            return None, "Failed to create vector database."
        return vector_db, "Database created!"
    except Exception as e:
        st.error(f"Error initializing database: {e}")
        return None, "Failed to initialize database."

def initialize_LLM(llm_option, vector_db):
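    """Build the QA chain for the selected model; returns a
    (qa_chain, status_message) tuple, with qa_chain None on failure."""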
    try:
        llm_name = list_llm[llm_option]
        qa_chain = initialize_llmchain(llm_name, vector_db)
        if qa_chain is None:
            return None, "Failed to initialize QA chain."
        return qa_chain, "QA chain initialized. Chatbot is ready!"
    except Exception as e:
        st.error(f"Error initializing LLM: {e}")
        return None, "Failed to initialize LLM."

def format_chat_history(chat_history):
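    """Render (user, assistant) pairs from the UI history as plain-text turns."""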
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}\nAssistant: {bot_message}\n")
    return formatted_chat_history

def conversation(qa_chain, message, history):
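    """Run one question through the QA chain and collect the answer plus its
    source passages. Note that the chain's ConversationBufferMemory tracks
    history internally, so the formatted history passed below is effectively
    superseded by the memory's own record."""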
    try:
        formatted_chat_history = format_chat_history(history)
        response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
        response_answer = response["answer"]
        response_sources = response["source_documents"]

        sources = []
        for doc in response_sources:
            sources.append({
                "content": doc.page_content.strip(),
                "page": doc.metadata["page"] + 1,  # PyPDF pages are 0-indexed
            })

        new_history = history + [(message, response_answer)]
        return qa_chain, new_history, response_answer, sources
    except Exception as e:
        st.error(f"Error in conversation: {e}")
        return qa_chain, history, "", []

def main():
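    """Streamlit UI: sidebar for document upload and model selection, main
    pane for the chat itself."""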
    st.sidebar.title("PDF Chatbot")

    st.sidebar.markdown("### Step 1 - Upload PDF documents and create the vector database")
    uploaded_files = st.sidebar.file_uploader("Upload PDF documents", type="pdf", accept_multiple_files=True)

    if uploaded_files:
        if st.sidebar.button("Create vector database"):
            with st.spinner("Creating vector database..."):
                vector_db, db_message = initialize_database(uploaded_files)
                st.session_state['vector_db'] = vector_db
                if vector_db is not None:
                    st.sidebar.success(db_message)
                else:
                    st.sidebar.error(db_message)

    if 'vector_db' not in st.session_state:
        st.session_state['vector_db'] = None

    if 'qa_chain' not in st.session_state:
        st.session_state['qa_chain'] = None

    if 'chat_history' not in st.session_state:
        st.session_state['chat_history'] = []

    st.sidebar.markdown("### Select Large Language Model (LLM)")
    llm_option = st.sidebar.radio("Available LLMs", list_llm_simple)

    if st.sidebar.button("Initialize Question Answering Chatbot"):
        if st.session_state['vector_db'] is None:
            st.sidebar.error("Create the vector database first.")
        else:
            with st.spinner("Initializing QA chatbot..."):
                qa_chain, llm_message = initialize_LLM(list_llm_simple.index(llm_option), st.session_state['vector_db'])
                st.session_state['qa_chain'] = qa_chain
                if qa_chain is not None:
                    st.sidebar.success(llm_message)
                else:
                    st.sidebar.error(llm_message)

    st.title("Chat with your Document")

    sources = []  # Sources from the most recent answer; reset on every rerun
    if st.session_state['qa_chain']:
        message = st.text_input("Ask a question")

        if st.button("Submit") and message.strip():  # Ignore empty questions
            with st.spinner("Generating response..."):
                qa_chain, chat_history, response_answer, sources = conversation(st.session_state['qa_chain'], message, st.session_state['chat_history'])
                st.session_state['qa_chain'] = qa_chain
                st.session_state['chat_history'] = chat_history

        st.markdown("### Chatbot Response")

        # Display the chat history in a chat-like interface
        for user_msg, bot_msg in st.session_state['chat_history']:
            st.markdown(f"**User:** {user_msg}")
            st.markdown(f"**Assistant:** {bot_msg}")

        with st.expander("Relevant context from the source document"):
            for source in sources:
                st.text_area(f"Source - Page {source['page']}", value=source["content"], height=100)

if __name__ == "__main__":
    main()