File size: 5,500 Bytes
244a9ba
80396ad
d82bfa1
244a9ba
f1c2bc3
 
23cbcf8
 
2c20468
fd3e3c6
04dd8cd
2c20468
90d6700
 
 
 
 
 
 
 
62d5470
90d6700
80396ad
f1c2bc3
 
2c20468
f1c2bc3
2c20468
f1c2bc3
90d6700
f1c2bc3
80396ad
90d6700
80396ad
f1c2bc3
80396ad
 
90d6700
f1c2bc3
 
 
a0ac2ce
f1c2bc3
 
 
 
90d6700
 
f1c2bc3
 
90d6700
 
 
 
 
80396ad
 
f1c2bc3
80396ad
90d6700
80396ad
 
 
90d6700
 
80396ad
 
2c20468
90d6700
 
 
 
f1c2bc3
6dedc06
80396ad
90d6700
80396ad
62d5470
f1c2bc3
8ca77ad
f1c2bc3
 
 
 
80396ad
f1c2bc3
6dedc06
f1c2bc3
8ca77ad
90d6700
2c20468
812f60c
80396ad
 
90d6700
80396ad
d14d249
 
a0ac2ce
d14d249
 
 
90d6700
 
 
 
 
 
d14d249
90d6700
d14d249
 
 
 
 
 
 
 
90d6700
 
 
 
 
3b7742d
2c20468
 
812f60c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Hugging Face API token read from the "HF_Token" environment variable;
# evaluates to None when that variable is not set.
api_token = os.environ.get("HF_Token")

# Hugging Face repository IDs offered in the UI. demo() selects from this list
# through a Radio with type="index", so the order here must stay in sync with
# the Radio labels there.
list_llm = [
    "google/flan-t5-base",  # lightweight instruction-tuned model
    # NOTE(review): an embeddings-oriented model; presumably unusable for text
    # generation through HuggingFaceEndpoint -- confirm before relying on it.
    "sentence-transformers/all-MiniLM-L6-v2",
    "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",  # Pythia 12B
    "bigscience/bloom-3b",  # multilingual BLOOM
    "bigscience/bloom-1b7",  # lightweight BLOOM model
]

# Dokumentenverarbeitung
def load_doc(list_file_path, chunk_size=512, chunk_overlap=32):
    """Load PDF files and split their text into overlapping chunks.

    Args:
        list_file_path: Iterable of PDF file paths.
        chunk_size: Maximum chunk length passed to
            RecursiveCharacterTextSplitter (measured with the splitter's
            default length function).
        chunk_overlap: Overlap between consecutive chunks.

    Returns:
        A list of document chunks; an empty list when no paths are given.
    """
    if not list_file_path:
        # Always return a list (never an error tuple) so callers can pass the
        # result straight on or simply test it for emptiness.
        return []
    documents = []
    for path in list_file_path:
        documents.extend(PyPDFLoader(path).load())
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    return text_splitter.split_documents(documents)

# Erstelle Vektordatenbank
def create_db(splits, model_name="sentence-transformers/all-MiniLM-L6-v2"):
    """Build a FAISS vector store from document chunks.

    Args:
        splits: Document chunks, e.g. the output of load_doc().
        model_name: Hugging Face model used by HuggingFaceEmbeddings.

    Returns:
        A FAISS vector store indexing ``splits``.

    Raises:
        ValueError: If ``splits`` is empty.
    """
    if not splits:
        # Fail fast with a clear message instead of first loading the
        # embedding model only to have FAISS.from_documents fail on empty input.
        raise ValueError("Keine Textabschnitte zum Indexieren vorhanden.")
    embeddings = HuggingFaceEmbeddings(model_name=model_name)
    return FAISS.from_documents(splits, embeddings)

# Initialisiere Datenbank
def initialize_database(list_file_obj):
    """Create the FAISS vector store from the uploaded PDF files.

    Args:
        list_file_obj: Value of the ``gr.Files`` component (file paths, since
            the component is configured with ``type="filepath"``).

    Returns:
        ``(vector_db, status_message)``; ``vector_db`` is None on failure.
    """
    if not list_file_obj:
        return None, "Fehler: Keine Dateien hochgeladen!"
    doc_splits = load_doc(list_file_obj)
    if not doc_splits:
        # Nothing to index (presumably e.g. PDFs without extractable text);
        # report it in the status box instead of crashing inside FAISS.
        return None, "Fehler: Keine Textinhalte in den Dokumenten gefunden!"
    vector_db = create_db(doc_splits)
    return vector_db, "Datenbank erfolgreich erstellt!"

# Initialisiere LLM-Kette
def initialize_llmchain(llm_model, temperature, max_tokens, vector_db):
    """Build a conversational retrieval QA chain on top of a vector store.

    Args:
        llm_model: Hugging Face repository ID of the generation model.
        temperature: Sampling temperature forwarded to HuggingFaceEndpoint.
        max_tokens: Requested maximum number of new tokens; capped at 2048
            for Pythia/BLOOM models and 1024 for all others, and at least 1.
        vector_db: Vector store used as the retriever.

    Returns:
        The ConversationalRetrievalChain, or None when ``vector_db`` is None.
    """
    if vector_db is None:
        # Return None rather than a (None, message) tuple so the result is
        # always either a usable chain or None.
        return None
    token_cap = 2048 if ("pythia" in llm_model or "bloom" in llm_model) else 1024
    # Gradio sliders pass floats; max_new_tokens is an integer setting.
    max_tokens = max(1, min(int(max_tokens), token_cap))

    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
    )
    # The chain returns both "answer" and source documents, so the memory is
    # told which output to record via output_key.
    memory = ConversationBufferMemory(
        memory_key="chat_history", output_key="answer", return_messages=True
    )
    retriever = vector_db.as_retriever()

    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
    )

# Initialisiere LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, vector_db):
    """Create the QA chain for the model selected in the UI.

    Args:
        llm_option: Index into ``list_llm`` (from a Radio with ``type="index"``).
        llm_temperature: Sampling temperature for the model.
        max_tokens: Requested maximum number of new tokens.
        vector_db: Vector store built by initialize_database, or None.

    Returns:
        ``(qa_chain, status_message)``; ``qa_chain`` is None on failure.
    """
    if vector_db is None:
        return None, "Fehler: Datenbank wurde nicht erstellt!"
    # The Radio may deliver None when nothing is selected; also reject indices
    # outside list_llm (a negative index would silently pick a model from the end).
    if llm_option is None or not 0 <= llm_option < len(list_llm):
        return None, "Fehler: Kein gültiges Modell ausgewählt!"
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, vector_db)
    return qa_chain, "QA-Kette initialisiert. Chatbot ist bereit!"

# Konversation
def conversation(qa_chain, message, history):
    """Answer one user question and append the exchange to the chat history.

    Args:
        qa_chain: Chain returned by initialize_LLM, or None if not initialized.
        message: The user's question (may be None or blank).
        history: Chatbot value as a list of ``{"role", "content"}`` dicts, or None.

    Returns:
        ``(qa_chain, chatbot_value, chatbot_value)``. The UI routes both trailing
        values to the same Chatbot, so both carry the identical list; notices
        are appended to the history instead of replacing it.
    """
    history = list(history or [])
    if qa_chain is None:
        shown = history + [{"role": "system", "content": "Die QA-Kette wurde nicht initialisiert."}]
        return None, shown, shown
    if not message or not message.strip():
        shown = history + [{"role": "system", "content": "Bitte eine Frage eingeben!"}]
        return qa_chain, shown, shown
    # The chain built by initialize_llmchain carries a ConversationBufferMemory
    # that supplies "chat_history" itself (memory variables override passed
    # inputs), so only the question is sent; presumably the Gradio message
    # dicts are not a history format the chain could consume anyway.
    response = qa_chain.invoke({"question": message})
    response_text = response.get("answer", "Keine Antwort verfügbar.")
    updated = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response_text},
    ]
    return qa_chain, updated, updated

# Gradio UI
def demo():
    """Build the Gradio interface, wire its events and launch it in debug mode."""
    with gr.Blocks() as app:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("<center><h1>RAG-Chatbot mit Pythia und BLOOM (CPU-kompatibel)</h1></center>")

        with gr.Row():
            with gr.Column():
                document = gr.Files(label="PDF-Dokument hochladen", type="filepath", file_types=[".pdf"], file_count="multiple")
                db_btn = gr.Button("Erstelle Vektordatenbank")
                db_status = gr.Textbox(label="Datenbankstatus", value="Nicht erstellt", interactive=False)

                # type="index" hands initialize_LLM the position of the chosen
                # label, which is used to index list_llm -- keep orders in sync.
                llm_btn = gr.Radio(
                    ["Flan-T5 Base", "MiniLM", "Pythia 12B", "BLOOM 3B", "BLOOM 1.7B"],
                    label="Verfügbare LLMs",
                    value="Flan-T5 Base",
                    type="index",
                )
                slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
                slider_maxtokens = gr.Slider(1, 2048, 512, label="Max Tokens")
                qachain_btn = gr.Button("Initialisiere QA-Chatbot")
                llm_status = gr.Textbox(label="Chatbot-Status", value="Nicht initialisiert", interactive=False)

            with gr.Column():
                chatbot = gr.Chatbot(label="Chatbot", height=400, type="messages")
                msg = gr.Textbox(label="Frage stellen")
                submit_btn = gr.Button("Absenden")

        # Wire events.
        db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_status])
        qachain_btn.click(
            initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, vector_db],
            outputs=[qa_chain, llm_status],
        )
        # Send a question via the button or by pressing Enter in the textbox.
        # conversation() returns three values; the last two both go to the Chatbot.
        # Once it has completed successfully, clear the question box.
        gr.on(
            triggers=[submit_btn.click, msg.submit],
            fn=conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, chatbot, chatbot],
        ).success(lambda: "", inputs=None, outputs=[msg])

    app.launch(debug=True)

# Build and launch the Gradio app only when this file is run as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo()