import gradio as gr
import os

from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Load the API token from an environment variable
api_token = os.getenv("HF_Token")

# Models available for selection
list_llm = [
    "google/flan-t5-base",                             # Lightweight instruction model
    "sentence-transformers/all-MiniLM-L6-v2",          # Embedding-optimized model (not meant for generation)
    "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",  # Pythia 12B
    "bigscience/bloom-3b",                             # Multilingual BLOOM
    "bigscience/bloom-1b7",                            # Lightweight BLOOM model
]

# Document processing: load the PDFs and split them into overlapping chunks
def load_doc(list_file_path):
    if not list_file_path:
        return []
    loaders = [PyPDFLoader(x) for x in list_file_path]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
    return text_splitter.split_documents(documents)

# Create the vector database
def create_db(splits):
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(splits, embeddings)

# Initialize the database
def initialize_database(list_file_obj):
    if not list_file_obj:
        return None, "Error: No files uploaded!"
    list_file_path = list_file_obj  # File paths of the uploaded files
    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully!"

# Initialize the LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, vector_db):
    if vector_db is None:
        return None
    # Cap max_new_tokens depending on the model family
    if "pythia" in llm_model or "bloom" in llm_model:
        max_tokens = min(max_tokens, 2048)
    else:
        max_tokens = min(max_tokens, 1024)
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history", output_key="answer", return_messages=True
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
    )
    return qa_chain

# Initialize the LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, vector_db):
    if vector_db is None:
        return None, "Error: Database has not been created!"
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, vector_db)
    return qa_chain, "QA chain initialized. The chatbot is ready!"

# Conversation handler
def conversation(qa_chain, message, history):
    if qa_chain is None:
        return qa_chain, history + [{"role": "assistant", "content": "The QA chain has not been initialized."}]
    if not message.strip():
        return qa_chain, history + [{"role": "assistant", "content": "Please enter a question!"}]
    # The chain's memory supplies the chat history, so only the question is passed in
    response = qa_chain.invoke({"question": message})
    response_text = response.get("answer", "No answer available.")
    formatted_response = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response_text},
    ]
    return qa_chain, formatted_response
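
# Optional smoke test for the pipeline above (an illustrative sketch, not part of
# the app): "sample.pdf" is a placeholder path you would have to supply yourself.
#
#   splits = load_doc(["sample.pdf"])
#   db = create_db(splits)
#   chain = initialize_llmchain(list_llm[0], 0.5, 512, db)
#   print(chain.invoke({"question": "What is this document about?"})["answer"])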

# Gradio UI
def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("## RAG Chatbot with Pythia and BLOOM (CPU-compatible)")
") with gr.Row(): with gr.Column(): document = gr.Files(label="PDF-Dokument hochladen", type="filepath", file_types=[".pdf"], file_count="multiple") db_btn = gr.Button("Erstelle Vektordatenbank") db_status = gr.Textbox(label="Datenbankstatus", value="Nicht erstellt", interactive=False) llm_btn = gr.Radio( ["Flan-T5 Base", "MiniLM", "Pythia 12B", "BLOOM 3B", "BLOOM 1.7B"], label="Verfügbare LLMs", value="Flan-T5 Base", type="index" ) slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature") slider_maxtokens = gr.Slider(1, 2048, 512, label="Max Tokens") qachain_btn = gr.Button("Initialisiere QA-Chatbot") llm_status = gr.Textbox(label="Chatbot-Status", value="Nicht initialisiert", interactive=False) with gr.Column(): chatbot = gr.Chatbot(label="Chatbot", height=400, type="messages") msg = gr.Textbox(label="Frage stellen") submit_btn = gr.Button("Absenden") # Events verknüpfen db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_status]) qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, vector_db], outputs=[qa_chain, llm_status]) submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, chatbot, chatbot]) demo.launch(debug=True) if __name__ == "__main__": demo()