RAG_test_1 / app.py
la04's picture
Update app.py
a344264 verified
raw
history blame
5.84 kB
import os
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from transformers import pipeline
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL_NAME = "google/flan-t5-small"
# **Dokumente laden und aufteilen**
def load_and_split_docs(list_file_path):
if not list_file_path:
return [], "Fehler: Keine Dokumente gefunden!"
loaders = [PyPDFLoader(x) for x in list_file_path]
documents = []
for loader in loaders:
documents.extend(loader.load())
text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
return text_splitter.split_documents(documents)
# **Vektor-Datenbank mit FAISS erstellen**
def create_db(docs):
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
return FAISS.from_documents(docs, embeddings)
# **Datenbank initialisieren**
def initialize_database(list_file_obj):
if not list_file_obj or all(x is None for x in list_file_obj):
return None, "Fehler: Keine Dateien hochgeladen!"
list_file_path = [x.name for x in list_file_obj if x is not None]
doc_splits = load_and_split_docs(list_file_path)
vector_db = create_db(doc_splits)
print("Vektordatenbank erfolgreich erstellt!")
return vector_db, "Datenbank erfolgreich erstellt!"
# **QA-Kette initialisieren (Wrapper)**
def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db):
if vector_db is None:
print("Fehler: Vektordatenbank nicht vorhanden!")
return None, "Fehler: Die Vektordatenbank wurde nicht erstellt! Bitte lade ein PDF hoch und klicke 'Erstelle Vektordatenbank'."
try:
print("Initialisiere QA-Chatbot...")
qa_chain = initialize_llm_chain(temperature, max_tokens, vector_db)
print("QA-Chatbot erfolgreich initialisiert!")
return qa_chain, "QA-Chatbot ist bereit!"
except Exception as e:
print(f"Fehler bei der Initialisierung: {str(e)}")
return None, f"Fehler bei der Initialisierung: {str(e)}"
# **LLM-Kette erstellen**
def initialize_llm_chain(temperature, max_tokens, vector_db):
print("Lade Modellpipeline...")
local_pipeline = pipeline(
"text2text-generation",
model=LLM_MODEL_NAME,
max_length=max_tokens,
temperature=temperature
)
print(f"Modell {LLM_MODEL_NAME} erfolgreich geladen.")
llm = HuggingFacePipeline(pipeline=local_pipeline)
memory = ConversationBufferMemory(memory_key="chat_history")
retriever = vector_db.as_retriever()
return ConversationalRetrievalChain.from_llm(
llm,
retriever=retriever,
memory=memory,
return_source_documents=True
)
# **Konversation mit QA-Kette führen**
def conversation(qa_chain, message, history):
if qa_chain is None:
return None, [{"role": "system", "content": "Der QA-Chain wurde nicht initialisiert!"}], history
if not message.strip():
return qa_chain, [{"role": "system", "content": "Bitte eine Frage eingeben!"}], history
try:
print(f"Frage: {message}")
history = history[-5:] # Beschränke den Verlauf auf die letzten 5 Nachrichten
response = qa_chain.invoke({"question": message, "chat_history": history})
response_text = response["answer"]
sources = [doc.metadata["source"] for doc in response["source_documents"]]
sources_text = "\n".join(sources) if sources else "Keine Quellen verfügbar"
# Strukturierte Rückgabe an `gr.Chatbot`
formatted_response = history + [
{"role": "user", "content": message},
{"role": "assistant", "content": f"{response_text}\n\n**Quellen:**\n{sources_text}"}
]
print("Antwort erfolgreich generiert.")
return qa_chain, formatted_response, formatted_response
except Exception as e:
print(f"Fehler während der Konversation: {str(e)}")
return qa_chain, [{"role": "system", "content": f"Fehler: {str(e)}"}], history
# **Gradio-Demo erstellen**
def demo():
with gr.Blocks() as demo:
vector_db = gr.State() # Zustand für die Vektordatenbank
qa_chain = gr.State() # Zustand für den QA-Chain
chat_history = gr.State([]) # Chatverlauf speichern
gr.HTML("<center><h1>RAG Chatbot mit FAISS und lokalen Modellen</h1></center>")
with gr.Row():
with gr.Column():
document = gr.Files(file_types=[".pdf"], label="PDF hochladen")
db_btn = gr.Button("Erstelle Vektordatenbank")
db_status = gr.Textbox(value="Status: Nicht initialisiert", show_label=False)
slider_temperature = gr.Slider(0.01, 1.0, value=0.5, label="Temperature")
slider_max_tokens = gr.Slider(64, 512, value=256, label="Max Tokens")
qachain_btn = gr.Button("Initialisiere QA-Chatbot")
with gr.Column():
chatbot = gr.Chatbot(label="Chatbot", type='messages', height=400)
msg = gr.Textbox(label="Deine Frage:", placeholder="Frage eingeben...")
submit_btn = gr.Button("Absenden")
db_btn.click(initialize_database, [document], [vector_db, db_status])
qachain_btn.click(initialize_llm_chain_wrapper, [slider_temperature, slider_max_tokens, vector_db], [qa_chain, db_status])
submit_btn.click(conversation, [qa_chain, msg, chat_history], [qa_chain, chatbot, chat_history])
demo.launch(debug=True)
if __name__ == "__main__":
demo()