File size: 5,205 Bytes
244a9ba
80396ad
d82bfa1
244a9ba
f1c2bc3
 
23cbcf8
 
2c20468
f1c2bc3
 
2c20468
f1c2bc3
80396ad
62d5470
f1c2bc3
80396ad
f1c2bc3
 
2c20468
f1c2bc3
2c20468
f1c2bc3
80396ad
f1c2bc3
80396ad
f1c2bc3
80396ad
f1c2bc3
80396ad
 
f1c2bc3
 
 
 
 
 
 
 
 
 
80396ad
f1c2bc3
 
15da3c5
 
80396ad
 
f1c2bc3
80396ad
 
 
 
 
 
f1c2bc3
80396ad
 
2c20468
f1c2bc3
80396ad
f1c2bc3
6dedc06
80396ad
 
 
62d5470
f1c2bc3
8ca77ad
f1c2bc3
 
 
 
80396ad
f1c2bc3
6dedc06
f1c2bc3
8ca77ad
f1c2bc3
2c20468
812f60c
80396ad
 
f1c2bc3
80396ad
d14d249
 
 
 
 
 
 
 
15da3c5
d14d249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80396ad
f1c2bc3
2c20468
 
812f60c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Hugging Face API token, taken from the HF_TOKEN environment variable (None if unset).
api_token = os.environ.get("HF_TOKEN")

# Selectable model repository IDs; the position must match the radio labels in demo(),
# because initialize_LLM indexes this list with the radio's selected index.
list_llm = [
    "google/flan-t5-small",
    "google/flan-t5-base",
]

# Dokumente laden und aufteilen
def load_doc(list_file_path):
    if not list_file_path:
        return [], "Fehler: Keine Dokumente gefunden!"
    loaders = [PyPDFLoader(x) for x in list_file_path]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return text_splitter.split_documents(documents)

# Build the vector store
def create_db(splits):
    """Embed ``splits`` and index them in a new FAISS vector store.

    Embeddings come from the ``sentence-transformers/all-MiniLM-L6-v2`` model
    via ``HuggingFaceEmbeddings``.

    Args:
        splits: Document chunks, as produced by ``load_doc``.

    Returns:
        The store built by ``FAISS.from_documents``.
    """
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2",
    )
    store = FAISS.from_documents(splits, embedder)
    return store

# Datenbank initialisieren
def initialize_database(list_file_obj):
    if not list_file_obj:
        return None, "Fehler: Keine Dateien hochgeladen!"
    list_file_path = [x.name for x in list_file_obj if x is not None]
    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Datenbank erfolgreich erstellt!"

# LLM-Kette initialisieren
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    if vector_db is None:
        return None, "Fehler: Keine Vektordatenbank verfügbar."
    if max_tokens > 250:
        max_tokens = 250  # Begrenze max_new_tokens, um Fehler zu vermeiden
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True
    )

# Initialize the LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Select a model from ``list_llm`` and build the QA chain for it.

    Args:
        llm_option: Index into ``list_llm`` (the radio button's selected index).
        llm_temperature: Sampling temperature.
        max_tokens: Maximum number of new tokens to generate.
        top_k: Top-k sampling value.
        vector_db: The vector store created by ``initialize_database``, or ``None``.

    Returns:
        ``(qa_chain, status_message)``. ``qa_chain`` is ``None`` when no
        database exists yet.
    """
    if vector_db is not None:
        chain = initialize_llmchain(
            list_llm[llm_option], llm_temperature, max_tokens, top_k, vector_db
        )
        return chain, "QA-Kette initialisiert. Chatbot ist bereit!"
    return None, "Fehler: Datenbank wurde nicht erstellt!"

# Konversation
def conversation(qa_chain, message, history):
    if qa_chain is None:
        return None, [{"role": "system", "content": "Die QA-Kette wurde nicht initialisiert."}], history
    if not message.strip():
        return qa_chain, [{"role": "system", "content": "Bitte eine Frage eingeben!"}], history
    response = qa_chain.invoke({"question": message, "chat_history": history})
    response_text = response.get("answer", "Keine Antwort verfügbar.")
    formatted_response = history + [{"role": "user", "content": message}, {"role": "assistant", "content": response_text}]
    return qa_chain, formatted_response, formatted_response

# Demo erstellen
def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("<center><h1>PDF-Chatbot mit kostenlosen Modellen</h1></center>")

        with gr.Row():
            with gr.Column():
                document = gr.Files(label="PDF-Dokument hochladen")
                db_btn = gr.Button("Erstelle Vektordatenbank")
                db_status = gr.Textbox(label="Datenbankstatus", value="Nicht erstellt", interactive=False)

                llm_btn = gr.Radio(["Flan-T5 Small", "Flan-T5 Base"], label="Verfügbare LLMs", value="Flan-T5 Small", type="index")
                slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
                slider_maxtokens = gr.Slider(1, 250, 128, label="Max Tokens")  # Begrenzung auf 250
                slider_topk = gr.Slider(1, 10, 3, label="Top-k")
                qachain_btn = gr.Button("Initialisiere QA-Chatbot")
                llm_status = gr.Textbox(label="Chatbot-Status", value="Nicht initialisiert", interactive=False)

            with gr.Column():
                chatbot = gr.Chatbot(label="Chatbot", height=400, type="messages")
                msg = gr.Textbox(label="Frage stellen")
                submit_btn = gr.Button("Absenden")

        # Event-Handling
        db_btn.click(
            initialize_database,
            inputs=[document],
            outputs=[vector_db, db_status]
        )
        qachain_btn.click(
            initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_status]
        )
        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, chatbot, chatbot]
        )

    demo.launch(debug=True)

# Launch the UI only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    demo()