import os
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline
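
# Assumed dependencies for the imports above (unpinned; adjust to your environment):
#   pip install gradio langchain langchain-community transformers \
#       sentence-transformers faiss-cpu pypdf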

# Embedding and LLM model names
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL_NAME = "google/flan-t5-small"
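# Note: all-MiniLM-L6-v2 returns 384-dimensional sentence embeddings; flan-t5-small
# is a small (~80M-parameter) seq2seq model that runs comfortably on CPU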

# **Load and split documents**
def load_and_split_docs(list_file_path):
    if not list_file_path:
        # Return an empty list so the caller always receives a list of documents
        return []
    loaders = [PyPDFLoader(x) for x in list_file_path]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
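    # 512-character chunks with a 32-character overlap: small enough for the
    # embedding model, with overlap so context is preserved across chunk boundaries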
    return text_splitter.split_documents(documents)

# **Build the FAISS vector database**
def create_db(docs):
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
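    # The index lives in memory only; FAISS.save_local()/load_local() can persist it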
    return FAISS.from_documents(docs, embeddings)

# **Initialize the database**
def initialize_database(list_file_obj):
    if not list_file_obj or all(x is None for x in list_file_obj):
        return None, "Error: no files uploaded!"
    # Recent Gradio versions pass file paths as strings; older ones pass objects with .name
    list_file_path = [x.name if hasattr(x, "name") else x for x in list_file_obj if x is not None]
    doc_splits = load_and_split_docs(list_file_path)
    if not doc_splits:
        return None, "Error: no text could be extracted from the documents!"
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully!"

# **Initialize the LLM chain (wrapper)**
def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db):
    if vector_db is None:
        return None, "Fehler: Vektordatenbank nicht initialisiert!"
    qa_chain = initialize_llm_chain(temperature, max_tokens, vector_db)
    return qa_chain, "QA-Chatbot ist bereit!"

# **Create the LLM chain**
def initialize_llm_chain(temperature, max_tokens, vector_db):
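    # flan-t5 is an encoder-decoder (seq2seq) model, hence the "text2text-generation" task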
    local_pipeline = pipeline(
        "text2text-generation",
        model=LLM_MODEL_NAME,
        max_length=max_tokens,
        temperature=temperature,
        do_sample=True  # temperature only takes effect when sampling is enabled
    )
    llm = HuggingFacePipeline(pipeline=local_pipeline)
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key="answer"  # required because the chain also returns source_documents
    )
    retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        memory=memory,
        return_source_documents=True
    )

# **Run one conversation turn through the QA chain**
def conversation(qa_chain, message, history):
    history = history or []
    if qa_chain is None:
        return qa_chain, history + [{"role": "assistant", "content": "The QA chain has not been initialized!"}], history
    if not message.strip():
        return qa_chain, history + [{"role": "assistant", "content": "Please enter a question!"}], history
    try:
        # The chain's memory supplies the chat history, so only the question is passed
        response = qa_chain({"question": message})
        response_text = response["answer"]
        sources = [doc.metadata.get("source", "unknown") for doc in response["source_documents"]]
        sources_text = "\n".join(sources) if sources else "No sources available"
        answer = f"{response_text}\n\n**Sources:**\n{sources_text}"
        # gr.Chatbot(type='messages') expects a list of role/content dicts, not a plain string
        new_history = history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": answer},
        ]
        return qa_chain, new_history, new_history
    except Exception as e:
        error_msgs = history + [
            {"role": "user", "content": message},
            {"role": "assistant", "content": f"Error: {e}"},
        ]
        return qa_chain, error_msgs, history

# **Build the Gradio demo**
def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State()  # state for the vector database
        qa_chain = gr.State()  # state for the QA chain
        chat_history = gr.State([])  # stores the chat history

        gr.HTML("<center><h1>RAG Chatbot mit FAISS und lokalen Modellen</h1></center>")
        with gr.Row():
            with gr.Column():
                document = gr.Files(file_types=[".pdf"], label="Upload PDFs")
                db_btn = gr.Button("Create vector database")
                db_status = gr.Textbox(value="Status: not initialized", show_label=False)
                slider_temperature = gr.Slider(0.01, 1.0, value=0.5, label="Temperature")
                slider_max_tokens = gr.Slider(64, 512, value=256, label="Max Tokens")
                qachain_btn = gr.Button("Initialize QA chatbot")

            with gr.Column():
                chatbot = gr.Chatbot(label="Chatbot", type='messages', height=400)
                msg = gr.Textbox(label="Your question:", placeholder="Enter your question...")
                submit_btn = gr.Button("Submit")

        # **Wire up button events**
        db_btn.click(
            initialize_database,
            inputs=[document],  # the uploaded documents
            outputs=[vector_db, db_status]  # outputs: the vector database and a status message
        )
        
        qachain_btn.click(
            initialize_llm_chain_wrapper,
            inputs=[slider_temperature, slider_max_tokens, vector_db],
            outputs=[qa_chain, db_status]
        )

        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chat_history],  # QA chain, user question, chat history
            outputs=[qa_chain, chatbot, chat_history]  # chain state, chatbot display, updated history
        )
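
        # Optional: pressing Enter in the textbox mirrors the submit button
        msg.submit(
            conversation,
            inputs=[qa_chain, msg, chat_history],
            outputs=[qa_chain, chatbot, chat_history]
        )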

    demo.queue().launch(debug=True)  # enable_queue was removed in Gradio 4; use queue() instead

if __name__ == "__main__":
    demo()