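"""Streamlit PDF chatbot.

Retrieval-augmented generation (RAG) over uploaded PDF documents: pages are
split into chunks, embedded into a FAISS vector store, and queried through a
ConversationalRetrievalChain backed by a Hugging Face Inference endpoint.
"""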
import streamlit as st
import os
import tempfile
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

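# Hugging Face API token; HuggingFaceEndpoint uses it to authenticate, so it
# must be available in the environment (e.g. as a Space secret).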
api_token = os.getenv("HF_TOKEN")
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

def load_doc(uploaded_files):
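    """Persist uploaded PDFs to temp files, load their pages, and split them into chunks."""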
    try:
        temp_files = []
        for uploaded_file in uploaded_files:
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
            temp_file.write(uploaded_file.read())
            temp_file.close()
            temp_files.append(temp_file.name)
        
        loaders = [PyPDFLoader(x) for x in temp_files]
        pages = []
        for loader in loaders:
            pages.extend(loader.load())
        
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
        doc_splits = text_splitter.split_documents(pages)

        for temp_file in temp_files:
            os.remove(temp_file)  # Clean up temporary files

        return doc_splits
    except Exception as e:
        st.error(f"Error loading document: {e}")
        return []

def create_db(splits):
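    """Embed the document chunks and index them in an in-memory FAISS vector store."""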
    try:
        embeddings = HuggingFaceEmbeddings()
        vectordb = FAISS.from_documents(splits, embeddings)
        return vectordb
    except Exception as e:
        st.error(f"Error creating vector database: {e}")
        return None

def initialize_llmchain(llm_model, vector_db):
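    """Wire the endpoint LLM, conversation memory, and retriever into a QA chain."""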
    try:
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            huggingfacehub_api_token=api_token,
            temperature=0.5,
            max_new_tokens=4096,
            top_k=3,
        )
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key='answer',
            return_messages=True
        )

        retriever = vector_db.as_retriever()
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm,
            retriever=retriever,
            chain_type="stuff",
            memory=memory,
            return_source_documents=True,
            verbose=False,
        )
        return qa_chain
    except Exception as e:
        st.error(f"Error initializing LLM chain: {e}")
        return None

def initialize_database(uploaded_files):
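    """Load the uploaded PDFs and build the vector database, returning it with a status message."""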
    try:
        doc_splits = load_doc(uploaded_files)
        if not doc_splits:
            return None, "Failed to load documents."
        vector_db = create_db(doc_splits)
        if vector_db is None:
            return None, "Failed to create vector database."
        return vector_db, "Database created!"
    except Exception as e:
        st.error(f"Error initializing database: {e}")
        return None, "Failed to initialize database."

def initialize_LLM(llm_option, vector_db):
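    """Build the QA chain for the model at index llm_option in list_llm."""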
    try:
        llm_name = list_llm[llm_option]
        qa_chain = initialize_llmchain(llm_name, vector_db)
        if qa_chain is None:
            return None, "Failed to initialize QA chain."
        return qa_chain, "QA chain initialized. Chatbot is ready!"
    except Exception as e:
        st.error(f"Error initializing LLM: {e}")
        return None, "Failed to initialize LLM."

def format_chat_history(chat_history):
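    """Render (user, assistant) turn pairs as plain-text exchanges."""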
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}\nAssistant: {bot_message}\n")
    return formatted_chat_history

def conversation(qa_chain, message, history):
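    """Run one QA turn: query the chain, collect cited sources, and extend the history."""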
    try:
        formatted_chat_history = format_chat_history(history)
        response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
        response_answer = response["answer"]
        response_sources = response["source_documents"]

        sources = []
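        # PyPDF page metadata is 0-indexed; convert to 1-based page numbers for display.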
        for doc in response_sources:
            sources.append({
                "content": doc.page_content.strip(),
                "page": doc.metadata["page"] + 1
            })

        new_history = history + [(message, response_answer)]
        return qa_chain, new_history, response_answer, sources
    except Exception as e:
        st.error(f"Error in conversation: {e}")
        return qa_chain, history, "", []

def main():
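    """Streamlit UI: upload PDFs, build the database, pick an LLM, and chat."""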
    st.sidebar.title("PDF Chatbot")

    st.sidebar.markdown("### Step 1 - Upload PDF documents and Initialize RAG pipeline")
    uploaded_files = st.sidebar.file_uploader("Upload PDF documents", type="pdf", accept_multiple_files=True)

    if uploaded_files:
        if st.sidebar.button("Create vector database"):
            with st.spinner("Creating vector database..."):
                vector_db, db_message = initialize_database(uploaded_files)
                st.session_state['vector_db'] = vector_db
                if vector_db is not None:
                    st.sidebar.success(db_message)
                else:
                    st.sidebar.error(db_message)

    if 'vector_db' not in st.session_state:
        st.session_state['vector_db'] = None

    if 'qa_chain' not in st.session_state:
        st.session_state['qa_chain'] = None

    if 'chat_history' not in st.session_state:
        st.session_state['chat_history'] = []

    st.sidebar.markdown("### Select Large Language Model (LLM)")
    llm_option = st.sidebar.radio("Available LLMs", list_llm_simple)

    # NOTE: the source is truncated at this point; the remainder of main() below
    # is a plausible reconstruction of the usual flow (initialize the chain for
    # the selected model, then chat), not the author's original code.
    if st.session_state['vector_db'] is not None and st.sidebar.button("Initialize QA chain"):
        with st.spinner("Initializing QA chain..."):
            qa_chain, llm_message = initialize_LLM(
                list_llm_simple.index(llm_option), st.session_state['vector_db']
            )
            st.session_state['qa_chain'] = qa_chain
            if qa_chain is not None:
                st.sidebar.success(llm_message)
            else:
                st.sidebar.error(llm_message)

    # Replay prior turns so the conversation persists across Streamlit reruns.
    for user_msg, bot_msg in st.session_state['chat_history']:
        st.chat_message("user").markdown(user_msg)
        st.chat_message("assistant").markdown(bot_msg)

    if message := st.chat_input("Ask a question about your documents"):
        if st.session_state['qa_chain'] is None:
            st.warning("Create the vector database and initialize the QA chain first.")
        else:
            st.chat_message("user").markdown(message)
            _, new_history, answer, sources = conversation(
                st.session_state['qa_chain'], message, st.session_state['chat_history']
            )
            st.session_state['chat_history'] = new_history
            with st.chat_message("assistant"):
                st.markdown(answer)
                for source in sources:
                    with st.expander(f"Source (page {source['page']})"):
                        st.write(source['content'])

if __name__ == "__main__":
    main()