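"""Gradio PDF chatbot: loads PDF documents, splits them into chunks, builds a
FAISS vector store, and answers questions via a LangChain conversational
retrieval chain backed by a Hugging Face endpoint model."""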
import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Hugging Face API token (read from the environment)
api_token = os.getenv("HF_TOKEN")

# Available LLM options
list_llm = ["google/flan-t5-small", "google/flan-t5-base"]

# Load PDF documents and split them into chunks
def load_doc(list_file_path):
    # Return a plain list in every branch (the original error branch returned
    # a (list, message) tuple, while callers expect just the chunk list)
    if not list_file_path:
        return []
    loaders = [PyPDFLoader(x) for x in list_file_path]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return text_splitter.split_documents(documents)

# Build the FAISS vector database from document chunks
def create_db(splits):
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(splits, embeddings)

# Initialize the vector database from uploaded files
def initialize_database(list_file_obj):
    if not list_file_obj:
        return None, "Error: no files uploaded!"
    # gr.Files may yield plain path strings or file objects, depending on the Gradio version
    list_file_path = [x if isinstance(x, str) else x.name for x in list_file_obj if x is not None]
    doc_splits = load_doc(list_file_path)
    if not doc_splits:
        return None, "Error: no documents found!"
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully!"

# Initialize the conversational retrieval chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    # Return a single value in every branch (the original error branch returned
    # a (None, message) tuple, while the success branch returned just the chain)
    if vector_db is None:
        return None
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    # The memory stores the chat history under "chat_history", which the chain reads back
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True
    )

# Initialize the LLM and return the QA chain plus a status message
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    if vector_db is None:
        return None, "Error: the database has not been created!"
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain, "QA chain initialized. Chatbot is ready!"

# Handle one conversation turn
def conversation(qa_chain, message, history):
    # Error turns keep the chain state and render as assistant messages
    # (gr.Chatbot's "messages" format only displays "user"/"assistant" roles)
    if qa_chain is None:
        return qa_chain, history + [{"role": "assistant", "content": "The QA chain has not been initialized."}], ""
    if not message.strip():
        return qa_chain, history + [{"role": "assistant", "content": "Please enter a question!"}], ""
    # The chain's memory supplies the chat history, so only the question is passed
    response = qa_chain.invoke({"question": message})
    response_text = response.get("answer", "No answer available.")
    sources = {doc.metadata.get("source", "unknown") for doc in response.get("source_documents", [])}
    if sources:
        response_text += "\n\nSources: " + ", ".join(sorted(sources))
    new_history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response_text},
    ]
    # Third output clears the input textbox
    return qa_chain, new_history, ""

# Build the Gradio demo
def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("<center><h1>PDF chatbot with free models</h1></center>")
        document = gr.Files(label="Upload PDF documents")
        db_btn = gr.Button("Create vector database")
        llm_btn = gr.Radio(["Flan-T5 Small", "Flan-T5 Base"], label="Available LLMs", value="Flan-T5 Small", type="index")
        slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
        slider_maxtokens = gr.Slider(128, 2048, 512, label="Max tokens")
        slider_topk = gr.Slider(1, 10, 3, label="Top-k")
        qachain_btn = gr.Button("Initialize QA chatbot")
        # Status textbox receives the message returned by both init functions
        status = gr.Textbox(label="Status", interactive=False)
        # "messages" format matches the role/content dicts built in conversation()
        chatbot = gr.Chatbot(label="Chatbot", height=400, type="messages")
        msg = gr.Textbox(label="Ask a question")
        submit_btn = gr.Button("Submit")

        # Output lists now match the return values of each handler
        db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, status])
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, status])
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, chatbot, msg])

    demo.launch(debug=True)

if __name__ == "__main__":
    demo()
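
# Running locally (a sketch; package list and the file name "app.py" are
# assumptions, not part of the original source):
#   pip install gradio langchain langchain-community langchain-huggingface \
#       faiss-cpu pypdf sentence-transformers
#   export HF_TOKEN=<your Hugging Face access token>
#   python app.py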