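"""Streamlit RAG app: upload PDFs, build a FAISS index over their text, and answer
questions against it with a HuggingFace-hosted Mistral model.

Run with: streamlit run <this_file>.py  (requires HF_TOKEN in the environment or a .env file)
"""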
import os
import streamlit as st
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Load environment variables
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())

# Constants
DATA_PATH = "data/"
DB_FAISS_PATH = "vectorstore/db_faiss"
HUGGINGFACE_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.3"
HF_TOKEN = os.environ.get("HF_TOKEN")

# Custom prompt template
CUSTOM_PROMPT_TEMPLATE = """
Use the pieces of information provided in the context to answer the user's question.
If you don't know the answer, just say that you don't know; don't try to make up an answer.

Don't provide anything outside of the given context.

Context: {context}
Question: {question}

Start the answer directly. No small talk please.
"""

def load_pdf_files(data_path):
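    """Load every PDF in data_path and return its pages as LangChain documents."""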
    try:
        loader = DirectoryLoader(data_path,
                                glob='*.pdf',
                                loader_cls=PyPDFLoader)
        documents = loader.load()
        return documents
    except Exception as e:
        st.error(f"Error loading PDF files: {e}")
        return []

def create_chunks(extracted_data):
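    """Split documents into overlapping chunks sized for embedding and retrieval."""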
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500,
                                                chunk_overlap=50)
    text_chunks = text_splitter.split_documents(extracted_data)
    return text_chunks

def get_embedding_model():
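    """Return the sentence-transformers model used for both indexing and querying."""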
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return embedding_model

def create_vectorstore():
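    """Build a FAISS vector store from the PDFs in DATA_PATH and persist it to disk."""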
    if not os.path.exists(DATA_PATH):
        os.makedirs(DATA_PATH)
        st.warning(f"Created empty data directory at {DATA_PATH}. Please upload PDF files.")
        return None
    
    documents = load_pdf_files(data_path=DATA_PATH)
    if not documents:
        st.warning("No PDF files found in data directory. Please upload some PDFs.")
        return None
    
    st.info(f"Loaded {len(documents)} PDF pages")
    text_chunks = create_chunks(extracted_data=documents)
    st.info(f"Created {len(text_chunks)} text chunks")
    
    embedding_model = get_embedding_model()
    
    os.makedirs(os.path.dirname(DB_FAISS_PATH), exist_ok=True)
    
    db = FAISS.from_documents(text_chunks, embedding_model)
    db.save_local(DB_FAISS_PATH)
    # Clear the cached copy so a stale (possibly None) store is not served after rebuilding
    get_vectorstore.clear()
    st.success(f"Created vector store at {DB_FAISS_PATH}")
    return db

@st.cache_resource
def get_vectorstore():
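    """Load the persisted FAISS vector store, caching it across Streamlit reruns."""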
    if os.path.exists(DB_FAISS_PATH):
        embedding_model = get_embedding_model()
        try:
            db = FAISS.load_local(DB_FAISS_PATH, embedding_model, allow_dangerous_deserialization=True)
            return db
        except Exception as e:
            st.error(f"Error loading vector store: {e}")
            return None
    else:
        st.warning("Vector store not found. Please create it first.")
        return None

def set_custom_prompt():
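    """Wrap CUSTOM_PROMPT_TEMPLATE in a PromptTemplate taking context and question."""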
    prompt = PromptTemplate(template=CUSTOM_PROMPT_TEMPLATE, input_variables=["context", "question"])
    return prompt

def load_llm():
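    """Create the HuggingFace Inference endpoint client for the Mistral model."""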
    if not HF_TOKEN:
        st.error("HF_TOKEN not found. Please set it in your environment variables.")
        return None
    
    try:
        # The token and generation length are first-class HuggingFaceEndpoint parameters;
        # passing them through model_kwargs raises a validation error in langchain_huggingface.
        llm = HuggingFaceEndpoint(
            repo_id=HUGGINGFACE_REPO_ID,
            task="text-generation",
            temperature=0.5,
            max_new_tokens=512,
            huggingfacehub_api_token=HF_TOKEN,
        )
        return llm
    except Exception as e:
        st.error(f"Error loading LLM: {e}")
        return None

def upload_pdf():
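    """Render a file uploader and save any uploaded PDFs into DATA_PATH."""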
    uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
    if uploaded_files:
        os.makedirs(DATA_PATH, exist_ok=True)  # The data directory may not exist on first run
        for uploaded_file in uploaded_files:
            with open(os.path.join(DATA_PATH, uploaded_file.name), "wb") as f:
                f.write(uploaded_file.getbuffer())
        st.success(f"Uploaded {len(uploaded_files)} files to {DATA_PATH}")
        return True
    return False

def main():
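    """Streamlit entry point: route between the upload, indexing, and chat pages."""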
    st.title("PDF Question Answering System")
    
    # Sidebar
    st.sidebar.title("Settings")
    page = st.sidebar.radio("Choose an action", ["Upload PDFs", "Create Vector Store", "Chat with Documents"])
    
    if page == "Upload PDFs":
        st.header("Upload PDF Files")
        st.info("Upload PDF files that will be used for question answering")
        if upload_pdf():
            st.info("Now go to 'Create Vector Store' to process your documents")
    
    elif page == "Create Vector Store":
        st.header("Create Vector Store")
        st.info("This will process your PDF files and create embeddings")
        if st.button("Create Vector Store"):
            with st.spinner("Processing documents..."):
                create_vectorstore()
    
    elif page == "Chat with Documents":
        st.header("Ask Questions About Your Documents")
        
        if 'messages' not in st.session_state:
            st.session_state.messages = []
        
        for message in st.session_state.messages:
            st.chat_message(message['role']).markdown(message['content'])
        
        prompt = st.chat_input("Ask a question about your documents")
        
        if prompt:
            st.chat_message('user').markdown(prompt)
            st.session_state.messages.append({'role': 'user', 'content': prompt})
            
            vectorstore = get_vectorstore()
            if vectorstore is None:
                st.error("Vector store not available. Please create it first.")
                return
            
            llm = load_llm()
            if llm is None:
                return
            
            try:
                with st.spinner("Thinking..."):
                    qa_chain = RetrievalQA.from_chain_type(
                        llm=llm,
                        chain_type="stuff",
                        retriever=vectorstore.as_retriever(search_kwargs={'k': 3}),
                        return_source_documents=True,
                        chain_type_kwargs={'prompt': set_custom_prompt()}
                    )
                    
                    response = qa_chain.invoke({'query': prompt})
                    
                    result = response["result"]
                    source_documents = response["source_documents"]
                    
                    # Format source documents more cleanly
                    source_docs_text = "\n\n**Source Documents:**\n"
                    for i, doc in enumerate(source_documents, 1):
                        source_docs_text += f"{i}. Page {doc.metadata.get('page', 'N/A')}: {doc.page_content[:200]}...\n\n"
                    
                    result_to_show = f"{result}\n{source_docs_text}"
                    
                    st.chat_message('assistant').markdown(result_to_show)
                    st.session_state.messages.append({'role': 'assistant', 'content': result_to_show})
                    
            except Exception as e:
                st.error(f"Error: {str(e)}")
                st.error("Please check your HuggingFace token and model access permissions")

if __name__ == "__main__":
    main()