import os import gradio as gr from langchain_community.document_loaders import PyPDFLoader from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.embeddings import HuggingFaceEmbeddings from langchain_community.vectorstores import FAISS from langchain.chains import ConversationalRetrievalChain from langchain.memory import ConversationBufferMemory from langchain_community.llms import HuggingFacePipeline from transformers import pipeline # Embeddings- und LLM-Modelle EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" LLM_MODEL_NAME = "google/flan-t5-small" # **Dokumente laden und aufteilen** def load_and_split_docs(list_file_path): if not list_file_path: return [], "Fehler: Keine Dokumente gefunden!" loaders = [PyPDFLoader(x) for x in list_file_path] documents = [] for loader in loaders: documents.extend(loader.load()) text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32) return text_splitter.split_documents(documents) # **Vektor-Datenbank mit FAISS erstellen** def create_db(docs): embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME) return FAISS.from_documents(docs, embeddings) # **Datenbank initialisieren** def initialize_database(list_file_obj): if not list_file_obj or all(x is None for x in list_file_obj): return None, "Fehler: Keine Dateien hochgeladen!" list_file_path = [x.name for x in list_file_obj if x is not None] doc_splits = load_and_split_docs(list_file_path) vector_db = create_db(doc_splits) return vector_db, "Datenbank erfolgreich erstellt!" # **LLM-Kette initialisieren (Wrapper)** def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db): if vector_db is None: return None, "Fehler: Vektordatenbank nicht initialisiert!" qa_chain = initialize_llm_chain(temperature, max_tokens, vector_db) return qa_chain, "QA-Chatbot ist bereit!" # **LLM-Kette erstellen** def initialize_llm_chain(temperature, max_tokens, vector_db): local_pipeline = pipeline( "text2text-generation", model=LLM_MODEL_NAME, max_length=max_tokens, temperature=temperature ) llm = HuggingFacePipeline(pipeline=local_pipeline) memory = ConversationBufferMemory(memory_key="chat_history") retriever = vector_db.as_retriever() return ConversationalRetrievalChain.from_llm( llm, retriever=retriever, memory=memory, return_source_documents=True ) # **Konversation mit QA-Kette führen** def conversation(qa_chain, message, history): if qa_chain is None: return None, "Der QA-Chain wurde nicht initialisiert!", history if not message.strip(): return qa_chain, "Bitte eine Frage eingeben!", history try: response = qa_chain({"question": message, "chat_history": history}) response_text = response["answer"] sources = [doc.metadata["source"] for doc in response["source_documents"]] sources_text = "\n".join(sources) if sources else "Keine Quellen verfügbar" return qa_chain, f"{response_text}\n\n**Quellen:**\n{sources_text}", history + [(message, response_text)] except Exception as e: return qa_chain, f"Fehler: {str(e)}", history # **Gradio-Demo erstellen** def demo(): with gr.Blocks() as demo: vector_db = gr.State() # Zustand für die Vektordatenbank qa_chain = gr.State() # Zustand für den QA-Chain chat_history = gr.State([]) # Chatverlauf speichern gr.HTML("

RAG Chatbot mit FAISS und lokalen Modellen

") with gr.Row(): with gr.Column(): document = gr.Files(file_types=[".pdf"], label="PDF hochladen") db_btn = gr.Button("Erstelle Vektordatenbank") db_status = gr.Textbox(value="Status: Nicht initialisiert", show_label=False) slider_temperature = gr.Slider(0.01, 1.0, value=0.5, label="Temperature") slider_max_tokens = gr.Slider(64, 512, value=256, label="Max Tokens") qachain_btn = gr.Button("Initialisiere QA-Chatbot") with gr.Column(): chatbot = gr.Chatbot(label="Chatbot", type='messages', height=400) msg = gr.Textbox(label="Deine Frage:", placeholder="Frage eingeben...") submit_btn = gr.Button("Absenden") # **Button-Events definieren** db_btn.click( initialize_database, inputs=[document], # Eingabe der hochgeladenen Dokumente outputs=[vector_db, db_status] # Ausgabe: Vektor-Datenbank und Status ) qachain_btn.click( initialize_llm_chain_wrapper, inputs=[slider_temperature, slider_max_tokens, vector_db], outputs=[qa_chain, db_status] ) submit_btn.click( conversation, inputs=[qa_chain, msg, chat_history], # Chatkette, Nutzerfrage, Chatverlauf outputs=[qa_chain, chatbot, chat_history] # Antwort der Kette, Chatbot-Ausgabe, neuer Verlauf ) demo.launch(debug=True, queue=True) # Verwendung von queue=True statt enable_queue=True if __name__ == "__main__": demo()