import gradio as gr
import os
from langchain.vectorstores import FAISS  # Import für Vektordatenbank FAISS
from langchain.document_loaders import PyPDFLoader  # PDF-Loader zum Laden der Dokumente
from langchain.embeddings import HuggingFaceEmbeddings  # Embeddings-Erstellung mit Hugging Face-Modellen
from langchain.chains import ConversationalRetrievalChain  # Chain für QA-Funktionalität
from langchain.memory import ConversationBufferMemory  # Speichern des Chat-Verlaufs im Speicher
from langchain.llms import HuggingFaceHub  # Für das Laden der Modelle von Hugging Face Hub
from langchain.text_splitter import RecursiveCharacterTextSplitter  # Aufteilen von Dokumenten in Chunks
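
# NOTE (assumed, not stated in the original): typical dependencies are
#   pip install gradio langchain faiss-cpu pypdf sentence-transformers huggingface_hub
# HuggingFaceHub reads the HUGGINGFACEHUB_API_TOKEN environment variable,
# which must be set before launching the app.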

# List of LLM models (lightweight, CPU-friendly seq2seq models; encoder-only
# models such as distilbert-base-uncased cannot generate text via HuggingFaceHub)
list_llm = ["google/flan-t5-small", "google/flan-t5-base"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

# Load PDF documents and split them into chunks
def load_doc(list_file_path):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())  # load all pages from the PDF
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)  # small chunks keep CPU embedding fast
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
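
# Example usage (hypothetical file path):
#   doc_splits = load_doc(["./docs/example.pdf"])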

# Create the vector database
def create_db(splits):
    embeddings = HuggingFaceEmbeddings()  # defaults to sentence-transformers/all-MiniLM-L6-v2
    vectordb = FAISS.from_documents(splits, embeddings)
    return vectordb
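
# The FAISS index is held in memory only; FAISS.save_local / FAISS.load_local
# (available in LangChain) could persist it across sessions if needed.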

# Initialize the ConversationalRetrievalChain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    llm = HuggingFaceHub(
        repo_id=llm_model,
        model_kwargs={
            "temperature": temperature,
            "max_length": max_tokens,
            "top_k": top_k,
        }
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",  # required: the chain returns both "answer" and "source_documents"
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False
    )
    return qa_chain
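
# Example (assumes a vector_db already built with create_db):
#   chain = initialize_llmchain(list_llm[0], 0.5, 256, 3, vector_db)
#   result = chain({"question": "What is this document about?", "chat_history": []})
#   print(result["answer"])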

# Initialize the database
def initialize_database(list_file_obj):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully!"

# Initialize the LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain, "LLM initialized successfully! The chatbot is ready."

# Format the chat history as (user, assistant) tuples, the structure
# ConversationalRetrievalChain accepts for its chat_history input
def format_chat_history(chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append((user_message, bot_message))
    return formatted_chat_history

# Conversation handler
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(history)
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    new_history = history + [(message, response_answer)]
    return qa_chain, gr.update(value=""), new_history
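
# The returned triple maps onto the Gradio outputs [qa_chain, msg, chatbot]:
# the chain state is passed through, the textbox is cleared, and the chat
# history is updated.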

# Gradio front end
def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.HTML("<center><h1>RAG PDF Chatbot</h1></center>")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Step 1: Upload PDF documents")
                document = gr.Files(height=300, file_count="multiple", file_types=[".pdf"], interactive=True)
                db_btn = gr.Button("Create vector database")
                db_progress = gr.Textbox(value="Not initialized", show_label=False)
                gr.Markdown("### Step 2: Choose LLM and settings")
                llm_btn = gr.Radio(list_llm_simple, label="Available models", value=list_llm_simple[0], type="index")
                slider_temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Temperature")
                slider_maxtokens = gr.Slider(64, 512, value=256, step=64, label="Max Tokens")
                slider_topk = gr.Slider(1, 10, value=3, step=1, label="Top-k")
                qachain_btn = gr.Button("Initialize QA chatbot")
                llm_progress = gr.Textbox(value="Not initialized", show_label=False)

            with gr.Column():
                gr.Markdown("### Step 3: Ask questions about your document")
                chatbot = gr.Chatbot(height=400)  # tuple-style history matches the (user, bot) pairs built in conversation()
                msg = gr.Textbox(placeholder="Ask a question...")
                submit_btn = gr.Button("Submit")

        db_btn.click(initialize_database, [document], [vector_db, db_progress])
        qachain_btn.click(initialize_LLM, [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], [qa_chain, llm_progress])
        msg.submit(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])
        submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])
    demo.launch(debug=True)

if __name__ == "__main__":
    demo()
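
# Launching: run this script directly (e.g. `python app.py`, assuming that file
# name); debug=True keeps the process attached and prints errors to the console.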