import os
import streamlit as st
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint
from langchain_community.document_loaders import PyPDFLoader, DirectoryLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
# Load environment variables
from dotenv import load_dotenv, find_dotenv
load_dotenv(find_dotenv())
# Constants
DATA_PATH = "data/"
DB_FAISS_PATH = "vectorstore/db_faiss"
HUGGINGFACE_REPO_ID = "mistralai/Mistral-7B-Instruct-v0.3"
HF_TOKEN = os.environ.get("HF_TOKEN")
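# A minimal .env sketch (the variable name matches the lookup above; the value
# shown is a placeholder, not a real token):
#   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx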
# Custom prompt template
CUSTOM_PROMPT_TEMPLATE = """
Use the pieces of information provided in the context to answer the user's question.
If you don't know the answer, just say that you don't know; don't try to make up an answer.
Don't provide anything outside of the given context.
Context: {context}
Question: {question}
Start the answer directly. No small talk please.
"""
def load_pdf_files(data_path):
    """Load every PDF in data_path; returns a list of per-page Documents."""
    try:
        loader = DirectoryLoader(data_path,
                                 glob='*.pdf',
                                 loader_cls=PyPDFLoader)
        documents = loader.load()
        return documents
    except Exception as e:
        st.error(f"Error loading PDF files: {e}")
        return []
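# Note: PyPDFLoader emits one Document per page, with the page number kept in
# doc.metadata["page"]; the chat view below relies on that key when citing sources.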
def create_chunks(extracted_data):
    """Split page-level documents into overlapping chunks for retrieval."""
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500,
                                                   chunk_overlap=50)
    text_chunks = text_splitter.split_documents(extracted_data)
    return text_chunks
def get_embedding_model():
    # all-MiniLM-L6-v2 is a small, fast sentence-transformer model.
    embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return embedding_model
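# A quick sketch of what the embedder yields (illustrative call, standard
# LangChain Embeddings API):
#   vecs = get_embedding_model().embed_documents(["hello world"])
#   len(vecs[0])  # 384 dimensions for all-MiniLM-L6-v2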
def create_vectorstore():
    """Build a FAISS index from the PDFs in DATA_PATH and save it to disk."""
    if not os.path.exists(DATA_PATH):
        os.makedirs(DATA_PATH)
        st.warning(f"Created empty data directory at {DATA_PATH}. Please upload PDF files.")
        return None
    documents = load_pdf_files(data_path=DATA_PATH)
    if not documents:
        st.warning("No PDF files found in data directory. Please upload some PDFs.")
        return None
    st.info(f"Loaded {len(documents)} PDF pages")
    text_chunks = create_chunks(extracted_data=documents)
    st.info(f"Created {len(text_chunks)} text chunks")
    embedding_model = get_embedding_model()
    os.makedirs(os.path.dirname(DB_FAISS_PATH), exist_ok=True)
    db = FAISS.from_documents(text_chunks, embedding_model)
    db.save_local(DB_FAISS_PATH)
    get_vectorstore.clear()  # drop any cached (possibly None) store so the new index is served
    st.success(f"Created vector store at {DB_FAISS_PATH}")
    return db
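# save_local() writes two files under DB_FAISS_PATH (index.faiss and
# index.pkl); load_local() in get_vectorstore() expects that same pair.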
@st.cache_resource
def get_vectorstore():
    if os.path.exists(DB_FAISS_PATH):
        embedding_model = get_embedding_model()
        try:
            # allow_dangerous_deserialization is required because FAISS stores
            # are pickled; only load an index this app created itself.
            db = FAISS.load_local(DB_FAISS_PATH, embedding_model, allow_dangerous_deserialization=True)
            return db
        except Exception as e:
            st.error(f"Error loading vector store: {e}")
            return None
    else:
        st.warning("Vector store not found. Please create it first.")
        return None
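# st.cache_resource keeps a single loaded index per process across Streamlit
# reruns, which is why create_vectorstore() clears this cache after rebuilding.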
def set_custom_prompt():
    prompt = PromptTemplate(template=CUSTOM_PROMPT_TEMPLATE, input_variables=["context", "question"])
    return prompt
def load_llm():
    if not HF_TOKEN:
        st.error("HF_TOKEN not found. Please set it in your environment variables.")
        return None
    try:
        # The token and generation length are explicit HuggingFaceEndpoint
        # parameters; passing them via model_kwargs is rejected by
        # langchain_huggingface.
        llm = HuggingFaceEndpoint(
            repo_id=HUGGINGFACE_REPO_ID,
            task="text-generation",
            temperature=0.5,
            max_new_tokens=512,
            huggingfacehub_api_token=HF_TOKEN
        )
        return llm
    except Exception as e:
        st.error(f"Error loading LLM: {e}")
        return None
def upload_pdf():
    uploaded_files = st.file_uploader("Upload PDF files", type="pdf", accept_multiple_files=True)
    if uploaded_files:
        os.makedirs(DATA_PATH, exist_ok=True)  # ensure the target directory exists before writing
        for uploaded_file in uploaded_files:
            with open(os.path.join(DATA_PATH, uploaded_file.name), "wb") as f:
                f.write(uploaded_file.getbuffer())
        st.success(f"Uploaded {len(uploaded_files)} files to {DATA_PATH}")
        return True
    return False
def main():
    st.title("PDF Question Answering System")

    # Sidebar
    st.sidebar.title("Settings")
    page = st.sidebar.radio("Choose an action", ["Upload PDFs", "Create Vector Store", "Chat with Documents"])

    if page == "Upload PDFs":
        st.header("Upload PDF Files")
        st.info("Upload PDF files that will be used for question answering")
        if upload_pdf():
            st.info("Now go to 'Create Vector Store' to process your documents")

    elif page == "Create Vector Store":
        st.header("Create Vector Store")
        st.info("This will process your PDF files and create embeddings")
        if st.button("Create Vector Store"):
            with st.spinner("Processing documents..."):
                create_vectorstore()

    elif page == "Chat with Documents":
        st.header("Ask Questions About Your Documents")
        if 'messages' not in st.session_state:
            st.session_state.messages = []

        # Replay the conversation so far
        for message in st.session_state.messages:
            st.chat_message(message['role']).markdown(message['content'])

        prompt = st.chat_input("Ask a question about your documents")
        if prompt:
            st.chat_message('user').markdown(prompt)
            st.session_state.messages.append({'role': 'user', 'content': prompt})

            vectorstore = get_vectorstore()
            if vectorstore is None:
                st.error("Vector store not available. Please create it first.")
                return
            llm = load_llm()
            if llm is None:
                return

            try:
                with st.spinner("Thinking..."):
                    qa_chain = RetrievalQA.from_chain_type(
                        llm=llm,
                        chain_type="stuff",
                        retriever=vectorstore.as_retriever(search_kwargs={'k': 3}),
                        return_source_documents=True,
                        chain_type_kwargs={'prompt': set_custom_prompt()}
                    )
                    response = qa_chain.invoke({'query': prompt})
                    result = response["result"]
                    source_documents = response["source_documents"]

                    # Format source documents more cleanly
                    source_docs_text = "\n\n**Source Documents:**\n"
                    for i, doc in enumerate(source_documents, 1):
                        source_docs_text += f"{i}. Page {doc.metadata.get('page', 'N/A')}: {doc.page_content[:200]}...\n\n"

                    result_to_show = f"{result}\n{source_docs_text}"
                    st.chat_message('assistant').markdown(result_to_show)
                    st.session_state.messages.append({'role': 'assistant', 'content': result_to_show})
            except Exception as e:
                st.error(f"Error: {str(e)}")
                st.error("Please check your HuggingFace token and model access permissions")


if __name__ == "__main__":
    main()
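# To run locally (a sketch; assumes this file is saved as app.py -- the
# filename is not given in the source):
#   pip install streamlit python-dotenv langchain langchain-community \
#       langchain-huggingface langchain-text-splitters sentence-transformers \
#       faiss-cpu pypdf
#   streamlit run app.py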