import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from transformers import pipeline
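
# RAG chatbot: PDFs are loaded and split into chunks, embedded into a FAISS
# index, and queried through a ConversationalRetrievalChain backed by a local
# flan-t5-small pipeline. A Gradio Blocks UI wires the steps together.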
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL_NAME = "google/flan-t5-small"
def load_and_split_docs(list_file_path):
    if not list_file_path:
        return []  # callers expect a plain list, not a (list, message) tuple
    loaders = [PyPDFLoader(x) for x in list_file_path]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
    return text_splitter.split_documents(documents)
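
# Embed the chunks with all-MiniLM-L6-v2 and index them in FAISS. The store
# can also be queried directly (sketch; "example.pdf" is a placeholder):
#   db = create_db(load_and_split_docs(["example.pdf"]))
#   hits = db.similarity_search("What does the introduction cover?", k=3)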
def create_db(docs):
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
    return FAISS.from_documents(docs, embeddings)
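
# Gradio callback: map the uploaded file objects to paths, build the FAISS
# index, and return it alongside a status string for the UI.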
def initialize_database(list_file_obj):
    if not list_file_obj or all(x is None for x in list_file_obj):
        return None, "Error: no files uploaded!"
    list_file_path = [x.name for x in list_file_obj if x is not None]
    doc_splits = load_and_split_docs(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully!"
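
# Thin wrapper so the UI can refuse to build a chain before the index exists.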
def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db):
    if vector_db is None:
        return None, "Error: vector database not initialized!"
    qa_chain = initialize_llm_chain(temperature, max_tokens, vector_db)
    return qa_chain, "QA chatbot is ready!"
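
# Build the retrieval chain: a local flan-t5 text2text pipeline wrapped for
# LangChain, with buffer memory and the FAISS store as retriever.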
def initialize_llm_chain(temperature, max_tokens, vector_db):
    local_pipeline = pipeline(
        "text2text-generation",
        model=LLM_MODEL_NAME,
        max_length=int(max_tokens),  # Gradio sliders emit floats; generation expects an int
        temperature=temperature,
        do_sample=True,  # temperature only takes effect when sampling is enabled
    )
    llm = HuggingFacePipeline(pipeline=local_pipeline)
    # output_key tells the memory which of the chain's two outputs ("answer",
    # "source_documents") to store; without it the chain raises at runtime.
    memory = ConversationBufferMemory(
        memory_key="chat_history", output_key="answer", return_messages=True
    )
    retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        memory=memory,
        return_source_documents=True,
    )
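
# Sketch: the chain can also be exercised outside the UI (assumes `db` is a
# FAISS index built by create_db; the memory supplies the chat history):
#   chain = initialize_llm_chain(0.5, 256, db)
#   result = chain.invoke({"question": "Summarize the document."})
#   print(result["answer"], [d.metadata["source"] for d in result["source_documents"]])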
def conversation(qa_chain, message, history):
    # Gradio's "messages" chatbot format only renders user/assistant roles,
    # so status notices are emitted as assistant messages.
    if qa_chain is None:
        return None, [{"role": "assistant", "content": "The QA chain has not been initialized!"}], history
    if not message.strip():
        return qa_chain, [{"role": "assistant", "content": "Please enter a question!"}], history
    try:
        history = history[-5:]  # keep only the last 5 turns in the visible transcript
        # The chain's memory supplies "chat_history", so only the question is passed.
        response = qa_chain.invoke({"question": message})
        response_text = response["answer"]
        sources = [doc.metadata["source"] for doc in response["source_documents"]]
        sources_text = "\n".join(sources) if sources else "No sources available"
        # Rebuild the visible transcript from the stored turns plus the new exchange.
        formatted_response = []
        for user_msg, bot_msg in history:
            formatted_response.append({"role": "user", "content": user_msg})
            formatted_response.append({"role": "assistant", "content": bot_msg})
        formatted_response += [
            {"role": "user", "content": message},
            {"role": "assistant", "content": f"{response_text}\n\n**Sources:**\n{sources_text}"},
        ]
        return qa_chain, formatted_response, history + [(message, response_text)]
    except Exception as e:
        return qa_chain, [{"role": "assistant", "content": f"Error: {e}"}], history
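
# Gradio front end: upload PDFs, build the index, initialize the chain, chat.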
def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        chat_history = gr.State([])
        gr.HTML("<center><h1>RAG Chatbot with FAISS and Local Models</h1></center>")
        with gr.Row():
            with gr.Column():
                document = gr.Files(file_types=[".pdf"], label="Upload PDFs")
                db_btn = gr.Button("Create vector database")
                db_status = gr.Textbox(value="Status: not initialized", show_label=False)
                slider_temperature = gr.Slider(0.01, 1.0, value=0.5, label="Temperature")
                slider_max_tokens = gr.Slider(64, 512, value=256, label="Max tokens")
                qachain_btn = gr.Button("Initialize QA chatbot")
            with gr.Column():
                chatbot = gr.Chatbot(label="Chatbot", type="messages", height=400)
                msg = gr.Textbox(label="Your question:", placeholder="Enter a question...")
                submit_btn = gr.Button("Submit")
        db_btn.click(initialize_database, [document], [vector_db, db_status])
        # The wrapper returns (chain, status), so both outputs are wired;
        # the status message lands in the same status textbox.
        qachain_btn.click(
            initialize_llm_chain_wrapper,
            [slider_temperature, slider_max_tokens, vector_db],
            [qa_chain, db_status],
        )
        submit_btn.click(conversation, [qa_chain, msg, chat_history], [qa_chain, chatbot, chat_history])
    demo.launch(debug=True)

if __name__ == "__main__":
    demo()