File size: 4,137 Bytes
2c20468
244a9ba
23cbcf8
244a9ba
23cbcf8
244a9ba
23cbcf8
 
244a9ba
 
2c20468
244a9ba
 
 
2c20468
244a9ba
 
2c20468
244a9ba
2c20468
244a9ba
23cbcf8
244a9ba
2c20468
 
244a9ba
 
 
 
 
2c20468
244a9ba
 
 
 
 
 
 
 
812f60c
244a9ba
0dda7f4
812f60c
244a9ba
 
2c20468
 
 
 
244a9ba
2c20468
 
 
244a9ba
812f60c
2c20468
244a9ba
2c20468
812f60c
2c20468
244a9ba
 
 
2c20468
244a9ba
2c20468
244a9ba
 
 
 
2c20468
244a9ba
2c20468
812f60c
2c20468
 
244a9ba
 
2c20468
812f60c
244a9ba
812f60c
244a9ba
23cbcf8
244a9ba
812f60c
 
 
231e3ba
244a9ba
812f60c
 
244a9ba
 
812f60c
244a9ba
812f60c
2c20468
 
812f60c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import os
import gradio as gr
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.llms import HuggingFacePipeline
from transformers import pipeline

# **Embeddings model (no API key needed, loaded locally)**
EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL_NAME = "google/flan-t5-small"  # Alternatively: "google/flan-t5-base", etc.

# **Load documents and split them into chunks**
def load_and_split_docs(list_file_path):
    """Load every PDF in ``list_file_path`` and split the content into chunks.

    Args:
        list_file_path: Iterable of PDF file paths; each one is handed to
            ``PyPDFLoader``.

    Returns:
        Whatever ``RecursiveCharacterTextSplitter.split_documents`` produces
        for the loaded documents (chunk size 512, overlap 32).
    """
    # Build all loaders up front, then gather their documents in order.
    pdf_loaders = [PyPDFLoader(pdf_path) for pdf_path in list_file_path]
    pages = [page for pdf_loader in pdf_loaders for page in pdf_loader.load()]

    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
    return splitter.split_documents(pages)

# **Create the FAISS vector database**
def create_db(docs):
    """Embed ``docs`` and index them in a FAISS vector store.

    Args:
        docs: Documents passed unchanged to ``FAISS.from_documents``.

    Returns:
        The FAISS index built with embeddings from ``EMBEDDINGS_MODEL_NAME``.
    """
    embedder = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
    return FAISS.from_documents(docs, embedder)

# **Initialise the LLM chain**
def initialize_llm_chain(llm_model, temperature, max_tokens, vector_db):
    """Build a conversational retrieval QA chain around a local HF pipeline.

    Args:
        llm_model: Model identifier handed to ``transformers.pipeline``.
        temperature: Sampling temperature forwarded to the pipeline.
        max_tokens: Used as the pipeline's ``max_length``; coerced to ``int``.
        vector_db: Vector store whose ``as_retriever()`` supplies the retriever.

    Returns:
        The ``ConversationalRetrievalChain`` created by ``from_llm``, with
        conversation memory attached and source documents returned.
    """
    # Use a Hugging Face pipeline locally.
    local_pipeline = pipeline(
        "text2text-generation",
        model=llm_model,
        # UI sliders may deliver floats; a generation length must be an int.
        max_length=int(max_tokens),
        # Temperature only takes effect when sampling is enabled.
        do_sample=True,
        temperature=temperature,
    )
    llm = HuggingFacePipeline(pipeline=local_pipeline)
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        # The chain emits both "answer" and "source_documents"; the memory
        # must be told which one to store or saving the turn fails.
        output_key="answer",
        return_messages=True,
    )
    retriever = vector_db.as_retriever()

    # Retrieval-augmented QA chain.
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        memory=memory,
        return_source_documents=True,
    )
    return qa_chain

# **Initialise the database and the chain**
def initialize_database(list_file_obj):
    """Create the vector database from the uploaded files.

    Args:
        list_file_obj: Uploaded file objects (their path is read from
            ``.name``) or plain path strings. May be ``None`` or contain
            ``None`` entries, which are ignored.

    Returns:
        Tuple ``(vector_db, status_message)``. ``vector_db`` is ``None`` when
        no usable file was supplied.
    """
    list_file_path = [
        getattr(x, "name", x) for x in (list_file_obj or []) if x is not None
    ]
    # Nothing uploaded: report it instead of failing on an empty index.
    if not list_file_path:
        return None, "Bitte zuerst mindestens eine PDF-Datei hochladen."

    doc_splits = load_and_split_docs(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Datenbank erfolgreich erstellt!"

def initialize_llm_chain_wrapper(llm_temperature, max_tokens, vector_db):
    """Create the QA chain for ``LLM_MODEL_NAME`` from the UI settings.

    Args:
        llm_temperature: Temperature forwarded to ``initialize_llm_chain``.
        max_tokens: Maximum length forwarded to ``initialize_llm_chain``.
        vector_db: Vector store to retrieve from; ``None`` if not built yet.

    Returns:
        Tuple ``(qa_chain, status_message)``. ``qa_chain`` is ``None`` when
        no vector database is available.
    """
    # The chain needs a retriever, so the database must exist first.
    if vector_db is None:
        return None, "Bitte zuerst die Vektordatenbank erstellen."

    qa_chain = initialize_llm_chain(LLM_MODEL_NAME, llm_temperature, max_tokens, vector_db)
    return qa_chain, "QA-Chatbot ist bereit!"

# **Run one conversation turn through the QA chain**
def conversation(qa_chain, message, history):
    """Ask the QA chain a question and append the exchange to the history.

    Args:
        qa_chain: Callable chain taking ``{"question", "chat_history"}`` and
            returning a mapping with an ``"answer"`` entry; ``None`` if the
            chatbot has not been initialised.
        message: The user's question.
        history: Previous ``(question, answer)`` pairs; ``None`` counts as
            empty.

    Returns:
        Tuple ``(qa_chain, textbox_value, new_history)``. ``textbox_value`` is
        an empty string after a successful turn so the input box is cleared;
        the question is kept when no chain is available.
    """
    history = history or []

    # Without a chain there is nothing to call; show a hint in the chat.
    if qa_chain is None:
        hint = "Bitte zuerst den QA-Chatbot initialisieren."
        return qa_chain, message, history + [(message, hint)]

    response = qa_chain({"question": message, "chat_history": history})
    response_text = response["answer"]
    # The second value feeds the question textbox: clear it rather than
    # overwriting the user's input with the answer.
    return qa_chain, "", history + [(message, response_text)]

# **Gradio user interface**
def demo():
    """Build the Gradio Blocks UI, wire up the event handlers and launch it."""
    # Named ``app`` so the Blocks object does not shadow this function.
    with gr.Blocks() as app:
        # Per-session state shared between the event handlers.
        vector_db = gr.State()
        qa_chain = gr.State()

        gr.HTML("<center><h1>RAG Chatbot mit FAISS und lokalen Modellen</h1></center>")
        with gr.Row():
            with gr.Column():
                document = gr.Files(file_types=[".pdf"], label="PDF hochladen")
                db_btn = gr.Button("Erstelle Vektordatenbank")
                db_status = gr.Textbox(value="Status: Nicht initialisiert", show_label=False)
                slider_temperature = gr.Slider(0.01, 1.0, value=0.5, label="Temperature")
                slider_max_tokens = gr.Slider(64, 512, value=256, label="Max Tokens")
                qachain_btn = gr.Button("Initialisiere QA-Chatbot")

            with gr.Column():
                chatbot = gr.Chatbot(height=400)
                msg = gr.Textbox(placeholder="Frage eingeben...")
                submit_btn = gr.Button("Absenden")

        db_btn.click(initialize_database, [document], [vector_db, db_status])
        # The handler returns (chain, status). With a single output the whole
        # tuple would be stored as the chain state, so bind both values.
        qachain_btn.click(
            initialize_llm_chain_wrapper,
            [slider_temperature, slider_max_tokens, vector_db],
            [qa_chain, db_status],
        )
        submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])

    app.launch(debug=True)

# Start the UI only when this file is executed as a script.
if __name__ == "__main__":
    demo()