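"""RAG PDF chatbot.

A Gradio app that loads uploaded PDFs, splits them into chunks, indexes
them in a FAISS vector store, and answers questions through a LangChain
ConversationalRetrievalChain backed by a Hugging Face inference endpoint.
"""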
import gradio as gr
import os
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
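
# Assumed dependencies (the original does not pin them):
#   pip install gradio langchain langchain-community faiss-cpu pypdf sentence-transformers
# HuggingFaceEndpoint also expects a Hugging Face API token, typically
# provided via the HUGGINGFACEHUB_API_TOKEN environment variable.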

# List of models: lightweight, CPU-friendly seq2seq models that can
# generate text (an encoder-only model such as distilbert-base-uncased
# cannot produce answers and would fail inside HuggingFaceEndpoint)
list_llm = ["google/flan-t5-small", "google/flan-t5-base"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

# Load the PDF documents and split them into chunks
def load_doc(list_file_path):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=512,  # smaller chunks for faster processing
        chunk_overlap=32
    )
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

# Create the vector database
def create_db(splits):
    embeddings = HuggingFaceEmbeddings()
    vectordb = FAISS.from_documents(splits, embeddings)
    return vectordb
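
# Note: HuggingFaceEmbeddings falls back to the library's default embedding
# model (sentence-transformers/all-MiniLM-L6-v2); to pin it explicitly:
#   HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")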

# Initialize the LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k
    )

    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True
    )

    retriever = vector_db.as_retriever()
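    # Note: top_k above tunes the LLM's sampling, not retrieval; the number
    # of retrieved chunks stays at the retriever's default (k=4) unless set
    # explicitly, e.g. vector_db.as_retriever(search_kwargs={"k": 3}).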
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False
    )
    return qa_chain

# Initialize the database
def initialize_database(list_file_obj):
    list_file_path = [x.name for x in list_file_obj if x is not None]
    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully!"

# Initialize the LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain, "LLM initialized successfully! The chatbot is ready."

def format_chat_history(chat_history):
    # Flatten Gradio's messages-style history (dicts with "role" and
    # "content" keys) into plain strings for the retrieval chain.
    formatted_chat_history = []
    for turn in chat_history:
        speaker = "User" if turn["role"] == "user" else "Assistant"
        formatted_chat_history.append(f"{speaker}: {turn['content']}")
    return formatted_chat_history

# Chat function
def conversation(qa_chain, message, history):
    formatted_chat_history = format_chat_history(history)
    response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]
    # Strip the prompt scaffolding that some models echo back
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]
    # Append the new turn in the format expected by gr.Chatbot(type="messages")
    # and clear the input box
    new_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response_answer},
    ]
    return qa_chain, gr.update(value=""), new_history

# Build the Gradio app
def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.HTML("<center><h1>RAG PDF Chatbot</h1></center>")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Schritt 1: Lade PDF-Dokument hoch")
                document = gr.Files(height=300, file_count="multiple", file_types=[".pdf"], interactive=True)
                db_btn = gr.Button("Erstelle Vektordatenbank")
                db_progress = gr.Textbox(value="Nicht initialisiert", show_label=False)
                gr.Markdown("### Schritt 2: Wähle LLM und Einstellungen")
                llm_btn = gr.Radio(list_llm_simple, label="Verfügbare Modelle", value=list_llm_simple[0], type="index")
                slider_temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Temperature")
                slider_maxtokens = gr.Slider(64, 512, value=256, step=64, label="Max Tokens")
                slider_topk = gr.Slider(1, 10, value=3, step=1, label="Top-k")
                qachain_btn = gr.Button("Initialize QA chatbot")
                llm_progress = gr.Textbox(value="Not initialized", show_label=False)

            with gr.Column():
                gr.Markdown("### Schritt 3: Stelle Fragen an dein Dokument")
                chatbot = gr.Chatbot(height=400, type="messages")
                msg = gr.Textbox(placeholder="Ask a question...")
                submit_btn = gr.Button("Submit")

        db_btn.click(initialize_database, [document], [vector_db, db_progress])
        qachain_btn.click(initialize_LLM, [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], [qa_chain, llm_progress])
        msg.submit(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])
        submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])
    demo.launch(debug=True)

if __name__ == "__main__":
    demo()
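
# To try it locally (the filename is an assumption, not fixed by the code):
#   python app.py
# then open the printed local URL, upload one or more PDFs, build the
# vector database, initialize the LLM, and start asking questions.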