import os
import json
import shutil
import gradio as gr
import pandas as pd
from tempfile import NamedTemporaryFile
from typing import List

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceHub
from langchain_text_splitters import RecursiveCharacterTextSplitter

huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")


def load_and_split_document_basic(file):
    """Load a PDF and split it into one Document per page."""
    loader = PyPDFLoader(file.name)
    data = loader.load_and_split()
    return data


def load_and_split_document_recursive(file: NamedTemporaryFile) -> List[Document]:
    """Load a PDF and split it into overlapping character chunks."""
    loader = PyPDFLoader(file.name)
    pages = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
    )
    chunks = text_splitter.split_documents(pages)
    return chunks


def get_embeddings():
    return HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")


def create_or_update_database(data, embeddings):
    # FAISS.save_local writes a directory ("faiss_database/") containing
    # index.faiss and index.pkl; extend it if it already exists.
    if os.path.exists("faiss_database"):
        db = FAISS.load_local("faiss_database", embeddings, allow_dangerous_deserialization=True)
        db.add_documents(data)
    else:
        db = FAISS.from_documents(data, embeddings)
    db.save_local("faiss_database")


def clear_cache():
    # save_local creates a directory, so it must be removed with
    # shutil.rmtree (os.remove only works on files).
    if os.path.exists("faiss_database"):
        shutil.rmtree("faiss_database")
        return "Cache cleared successfully."
    else:
        return "No cache to clear."


prompt = """
Answer the question based only on the following context:
{context}

Question: {question}

Provide a concise and direct answer to the question:
"""


def get_model(temperature, top_p, repetition_penalty):
    return HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
        model_kwargs={
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "max_length": 512,
        },
        huggingfacehub_api_token=huggingface_token,
    )


def generate_chunked_response(model, prompt, max_tokens=500, max_chunks=5):
    # The model can stop mid-sentence at its length limit, so feed the partial
    # answer back in and keep generating until the response ends in terminal
    # punctuation (or max_chunks is reached).
    full_response = ""
    for i in range(max_chunks):
        chunk = model(prompt + full_response, max_new_tokens=max_tokens)
        full_response += chunk
        if chunk.strip().endswith((".", "!", "?")):
            break
    return full_response.strip()


def response(database, model, question):
    prompt_val = ChatPromptTemplate.from_template(prompt)
    retriever = database.as_retriever()
    context = retriever.get_relevant_documents(question)
    context_str = "\n".join([doc.page_content for doc in context])
    formatted_prompt = prompt_val.format(context=context_str, question=question)
    ans = generate_chunked_response(model, formatted_prompt)
    return ans


def update_vectors(files, use_recursive_splitter):
    if not files:
        return "Please upload at least one PDF file."
    embed = get_embeddings()
    total_chunks = 0
    for file in files:
        if use_recursive_splitter:
            data = load_and_split_document_recursive(file)
        else:
            data = load_and_split_document_basic(file)
        create_or_update_database(data, embed)
        total_chunks += len(data)
    return f"Vector store updated successfully. Processed {total_chunks} chunks from {len(files)} files."


def ask_question(question, temperature, top_p, repetition_penalty):
    if not question:
        return "Please enter a question."
    if not os.path.exists("faiss_database"):
        return "The vector store is empty. Please upload and process PDF files first."
    embed = get_embeddings()
    database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
    model = get_model(temperature, top_p, repetition_penalty)
    return response(database, model, question)


def extract_db_to_excel():
    if not os.path.exists("faiss_database"):
        return None  # nothing to export yet
    embed = get_embeddings()
    database = FAISS.load_local("faiss_database", embed, allow_dangerous_deserialization=True)
    # Reaches into the in-memory docstore (a private attribute) to dump every
    # stored chunk along with its metadata.
    documents = database.docstore._dict.values()
    data = [{"page_content": doc.page_content, "metadata": json.dumps(doc.metadata)} for doc in documents]
    df = pd.DataFrame(data)

    with NamedTemporaryFile(delete=False, suffix=".xlsx") as tmp:
        excel_path = tmp.name
    df.to_excel(excel_path, index=False)
    return excel_path


# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Chat with your PDF documents")

    with gr.Row():
        file_input = gr.Files(label="Upload your PDF documents", file_types=[".pdf"])
        update_button = gr.Button("Update Vector Store")
        use_recursive_splitter = gr.Checkbox(label="Use Recursive Text Splitter", value=False)

    update_output = gr.Textbox(label="Update Status")
    update_button.click(update_vectors, inputs=[file_input, use_recursive_splitter], outputs=update_output)

    with gr.Row():
        question_input = gr.Textbox(label="Ask a question about your documents")
        temperature_slider = gr.Slider(label="Temperature", minimum=0.0, maximum=1.0, value=0.5, step=0.1)
        top_p_slider = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, value=0.9, step=0.1)
        repetition_penalty_slider = gr.Slider(label="Repetition Penalty", minimum=1.0, maximum=2.0, value=1.0, step=0.1)
        submit_button = gr.Button("Submit")

    answer_output = gr.Textbox(label="Answer")
    submit_button.click(
        ask_question,
        inputs=[question_input, temperature_slider, top_p_slider, repetition_penalty_slider],
        outputs=answer_output,
    )

    extract_button = gr.Button("Extract Database to Excel")
    excel_output = gr.File(label="Download Excel File")
    extract_button.click(extract_db_to_excel, inputs=[], outputs=excel_output)

    clear_button = gr.Button("Clear Cache")
    clear_output = gr.Textbox(label="Cache Status")
    clear_button.click(clear_cache, inputs=[], outputs=clear_output)

if __name__ == "__main__":
    demo.launch()
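# A minimal sketch of how to run this app locally. The package list below is
# an assumption inferred from the imports above (standard PyPI names, no
# pinned versions), not a requirements file shipped with the original code:
#
#   pip install gradio pandas langchain-core langchain-community \
#       langchain-text-splitters faiss-cpu pypdf sentence-transformers openpyxl
#   export HUGGINGFACE_TOKEN=<your Hugging Face API token>
#   python app.py   # assuming the script is saved as app.py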