Spaces:
Sleeping
Sleeping
import os | |
import gradio as gr | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline | |
from langchain_community.vectorstores import FAISS | |
from langchain.chains import ConversationalRetrievalChain | |
from langchain.memory import ConversationBufferMemory | |
from transformers import pipeline | |
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2" | |
LLM_MODEL_NAME = "google/flan-t5-small" | |
# **Dokumente laden und aufteilen** | |
def load_and_split_docs(list_file_path): | |
if not list_file_path: | |
return [], "Fehler: Keine Dokumente gefunden!" | |
loaders = [PyPDFLoader(x) for x in list_file_path] | |
documents = [] | |
for loader in loaders: | |
documents.extend(loader.load()) | |
text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32) | |
return text_splitter.split_documents(documents) | |
# **Vektor-Datenbank mit FAISS erstellen** | |
def create_db(docs): | |
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME) | |
return FAISS.from_documents(docs, embeddings) | |
# **Datenbank initialisieren** | |
def initialize_database(list_file_obj): | |
if not list_file_obj or all(x is None for x in list_file_obj): | |
return None, "Fehler: Keine Dateien hochgeladen!" | |
list_file_path = [x.name for x in list_file_obj if x is not None] | |
doc_splits = load_and_split_docs(list_file_path) | |
vector_db = create_db(doc_splits) | |
print("Vektordatenbank erfolgreich erstellt!") | |
return vector_db, "Datenbank erfolgreich erstellt!" | |
# **QA-Kette initialisieren (Wrapper)** | |
def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db): | |
if vector_db is None: | |
print("Fehler: Vektordatenbank nicht vorhanden!") | |
return None, "Fehler: Die Vektordatenbank wurde nicht erstellt! Bitte lade ein PDF hoch und klicke 'Erstelle Vektordatenbank'." | |
try: | |
print("Initialisiere QA-Chatbot...") | |
qa_chain = initialize_llm_chain(temperature, max_tokens, vector_db) | |
print("QA-Chatbot erfolgreich initialisiert!") | |
return qa_chain, "QA-Chatbot ist bereit!" | |
except Exception as e: | |
print(f"Fehler bei der Initialisierung: {str(e)}") | |
return None, f"Fehler bei der Initialisierung: {str(e)}" | |
# **LLM-Kette erstellen** | |
def initialize_llm_chain(temperature, max_tokens, vector_db): | |
print("Lade Modellpipeline...") | |
local_pipeline = pipeline( | |
"text2text-generation", | |
model=LLM_MODEL_NAME, | |
max_length=max_tokens, | |
temperature=temperature | |
) | |
print(f"Modell {LLM_MODEL_NAME} erfolgreich geladen.") | |
llm = HuggingFacePipeline(pipeline=local_pipeline) | |
memory = ConversationBufferMemory(memory_key="chat_history") | |
retriever = vector_db.as_retriever() | |
return ConversationalRetrievalChain.from_llm( | |
llm, | |
retriever=retriever, | |
memory=memory, | |
return_source_documents=True | |
) | |
# **Konversation mit QA-Kette führen** | |
def conversation(qa_chain, message, history): | |
if qa_chain is None: | |
return None, [{"role": "system", "content": "Der QA-Chain wurde nicht initialisiert!"}], history | |
if not message.strip(): | |
return qa_chain, [{"role": "system", "content": "Bitte eine Frage eingeben!"}], history | |
try: | |
print(f"Frage: {message}") | |
history = history[-5:] # Beschränke den Verlauf auf die letzten 5 Nachrichten | |
response = qa_chain.invoke({"question": message, "chat_history": history}) | |
response_text = response["answer"] | |
sources = [doc.metadata["source"] for doc in response["source_documents"]] | |
sources_text = "\n".join(sources) if sources else "Keine Quellen verfügbar" | |
# Strukturierte Rückgabe an `gr.Chatbot` | |
formatted_response = history + [ | |
{"role": "user", "content": message}, | |
{"role": "assistant", "content": f"{response_text}\n\n**Quellen:**\n{sources_text}"} | |
] | |
print("Antwort erfolgreich generiert.") | |
return qa_chain, formatted_response, formatted_response | |
except Exception as e: | |
print(f"Fehler während der Konversation: {str(e)}") | |
return qa_chain, [{"role": "system", "content": f"Fehler: {str(e)}"}], history | |
# **Gradio-Demo erstellen** | |
def demo(): | |
with gr.Blocks() as demo: | |
vector_db = gr.State() # Zustand für die Vektordatenbank | |
qa_chain = gr.State() # Zustand für den QA-Chain | |
chat_history = gr.State([]) # Chatverlauf speichern | |
gr.HTML("<center><h1>RAG Chatbot mit FAISS und lokalen Modellen</h1></center>") | |
with gr.Row(): | |
with gr.Column(): | |
document = gr.Files(file_types=[".pdf"], label="PDF hochladen") | |
db_btn = gr.Button("Erstelle Vektordatenbank") | |
db_status = gr.Textbox(value="Status: Nicht initialisiert", show_label=False) | |
slider_temperature = gr.Slider(0.01, 1.0, value=0.5, label="Temperature") | |
slider_max_tokens = gr.Slider(64, 512, value=256, label="Max Tokens") | |
qachain_btn = gr.Button("Initialisiere QA-Chatbot") | |
with gr.Column(): | |
chatbot = gr.Chatbot(label="Chatbot", type='messages', height=400) | |
msg = gr.Textbox(label="Deine Frage:", placeholder="Frage eingeben...") | |
submit_btn = gr.Button("Absenden") | |
db_btn.click(initialize_database, [document], [vector_db, db_status]) | |
qachain_btn.click(initialize_llm_chain_wrapper, [slider_temperature, slider_max_tokens, vector_db], [qa_chain, db_status]) | |
submit_btn.click(conversation, [qa_chain, msg, chat_history], [qa_chain, chatbot, chat_history]) | |
demo.launch(debug=True) | |
if __name__ == "__main__": | |
demo() | |