File size: 3,916 Bytes
7f96312
2c20468
23cbcf8
 
 
 
 
 
 
2c20468
23cbcf8
2c20468
 
 
 
 
23cbcf8
 
2c20468
 
 
 
23cbcf8
 
2c20468
 
812f60c
7b3bf1d
812f60c
7b3bf1d
 
 
 
 
812f60c
0dda7f4
812f60c
2c20468
 
 
812f60c
2c20468
 
812f60c
2c20468
 
 
812f60c
2c20468
 
 
812f60c
2c20468
812f60c
2c20468
812f60c
 
2c20468
 
23cbcf8
7b3bf1d
2c20468
 
812f60c
2c20468
 
812f60c
2c20468
 
23cbcf8
2c20468
812f60c
 
 
 
23cbcf8
 
 
812f60c
 
 
 
 
 
 
 
23cbcf8
812f60c
 
2c20468
 
812f60c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import gradio as gr
import os
from langchain.vectorstores import Chroma  # Chroma als Vektordatenbank
from langchain.document_loaders import PyPDFLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.llms import HuggingFaceHub
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Hugging Face Hub repo ids that initialize_LLM can select from.
# NOTE(review): the second entry looks like a sentence-embedding model rather
# than a text-generation model; confirm it is usable as an LLM here.
list_llm = [
    "google/flan-t5-small",
    "sentence-transformers/all-MiniLM-L6-v2",
]

def load_doc(list_file_path):
    """Load the given PDF files and split their pages into text chunks.

    Args:
        list_file_path: Iterable of PDF file paths.

    Returns:
        The documents produced by splitting every loaded page with a
        RecursiveCharacterTextSplitter (chunk_size=512, chunk_overlap=32).
    """
    # Build every loader up front, then read them in order.
    pdf_loaders = [PyPDFLoader(pdf_path) for pdf_path in list_file_path]
    all_pages = [page for pdf_loader in pdf_loaders for page in pdf_loader.load()]

    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
    return splitter.split_documents(all_pages)

def create_db(splits):
    """Index the document chunks in a new Chroma vector store.

    Args:
        splits: Document chunks, e.g. the return value of load_doc.

    Returns:
        The Chroma store built from ``splits`` using a default-constructed
        HuggingFaceEmbeddings instance (no persist directory is passed).
    """
    embedding_model = HuggingFaceEmbeddings()
    return Chroma.from_documents(splits, embedding_model)

def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Build a conversational retrieval chain on top of the vector store.

    Args:
        llm_model: Hugging Face Hub repo id of the model to query.
        temperature: Sampling temperature forwarded to the model.
        max_tokens: Forwarded to the model as ``max_length``.
        top_k: Forwarded to the model as ``top_k``.
        vector_db: Vector store; its ``as_retriever()`` supplies the context.

    Returns:
        A ConversationalRetrievalChain ("stuff" chain type) with buffer memory
        that also returns the source documents.
    """
    llm = HuggingFaceHub(
        repo_id=llm_model,
        model_kwargs={
            "temperature": temperature,
            "max_length": max_tokens,
            "top_k": top_k,
        },
    )
    # The chain returns both "answer" and "source_documents". Without an
    # explicit output_key the memory cannot tell which one to store and
    # raises "One output key expected" when saving the first exchange.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        chain_type="stuff",
        memory=memory,
        return_source_documents=True,
        verbose=False,
    )
    return qa_chain

def initialize_database(list_file_obj):
    """Create the vector database from the uploaded PDF files.

    Args:
        list_file_obj: Uploaded files, or None when nothing was uploaded.
            Each entry may be an object exposing the path as ``.name`` or a
            plain path string; None entries are ignored.

    Returns:
        A ``(vector_db, status_message)`` tuple. ``vector_db`` is None when
        there is no file to index.
    """
    # Fall back to the entry itself so plain path strings work as well as
    # file objects that carry their path in ``.name``.
    list_file_path = [
        getattr(x, "name", x) for x in (list_file_obj or []) if x is not None
    ]
    if not list_file_path:
        # Nothing to index: report it instead of failing inside the loaders.
        return None, "Bitte zuerst mindestens eine PDF-Datei hochladen."

    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Datenbank erfolgreich erstellt!"

def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Resolve the selected model and build the QA chain for it.

    Args:
        llm_option: Either an integer index into ``list_llm`` or a string.
            A string is matched case-insensitively as a substring of the
            repo ids in ``list_llm`` (so a display label such as
            "Flan-T5-Small" or a full repo id both work).
        llm_temperature: Sampling temperature for the model.
        max_tokens: Maximum generation length for the model.
        top_k: Top-k sampling parameter for the model.
        vector_db: Vector store used for retrieval; None if not built yet.

    Returns:
        A ``(qa_chain, status_message)`` tuple. ``qa_chain`` is None when the
        prerequisites are missing or the model cannot be resolved.
    """
    if vector_db is None:
        return None, "Bitte zuerst die Vektordatenbank erstellen."
    if llm_option is None:
        return None, "Bitte zuerst ein Modell auswählen."

    if isinstance(llm_option, str):
        # UI components may deliver the label instead of the list position.
        wanted = llm_option.lower()
        candidates = [name for name in list_llm if wanted in name.lower()]
        if not candidates:
            return None, f"Unbekanntes Modell: {llm_option}"
        llm_name = candidates[0]
    else:
        llm_name = list_llm[llm_option]

    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain, "LLM erfolgreich initialisiert! Chatbot ist bereit."

def _history_as_pairs(history):
    """Convert chat history into a list of (user_text, assistant_text) pairs.

    Dict entries are read as role/content messages: each "assistant" message
    is paired with the most recent unpaired "user" message. Any other entry
    is unpacked as an existing (user_text, assistant_text) pair.
    """
    pairs = []
    pending_user = None
    for entry in history:
        if isinstance(entry, dict):
            role = entry.get("role")
            content = entry.get("content")
            if role == "user":
                pending_user = content
            elif role == "assistant" and pending_user is not None:
                pairs.append((pending_user, content))
                pending_user = None
        else:
            user_text, bot_text = entry
            pairs.append((user_text, bot_text))
    return pairs


def conversation(qa_chain, message, history):
    """Answer one user message and append the exchange to the chat history.

    Args:
        qa_chain: The retrieval chain, or None if it was not initialised yet.
        message: The text the user entered.
        history: Current chat history, either role/content dicts (the
            "messages" chatbot format) or legacy (user, assistant) pairs.
            None is treated as an empty history.

    Returns:
        A ``(qa_chain, textbox_value, new_history)`` tuple. The new history
        keeps the format of the incoming one; an empty history is continued
        in the role/content format. ``history`` itself is not mutated.
    """
    history = list(history or [])
    # A non-empty history whose entries are not dicts uses the legacy pairs.
    legacy_pairs = bool(history) and not isinstance(history[0], dict)

    if message is None or not message.strip():
        # Nothing to ask: do not query the chain with a blank question.
        return qa_chain, "", history

    if qa_chain is None:
        # Keep the question in the textbox so it can be resent after setup.
        notice = "Bitte zuerst die Vektordatenbank erstellen und den Chatbot initialisieren."
        if legacy_pairs:
            history.append((None, notice))
        else:
            history.append({"role": "assistant", "content": notice})
        return qa_chain, message, history

    formatted_chat_history = [
        (f"User: {m}", f"Assistant: {r}") for m, r in _history_as_pairs(history)
    ]
    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
    response_answer = response["answer"]

    if legacy_pairs:
        history.append((message, response_answer))
    else:
        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response_answer})
    return qa_chain, gr.update(value=""), history

def demo():
    """Build the Gradio UI for the RAG PDF chatbot and launch it.

    The left column holds the PDF upload, the database button, the model
    choice with its generation sliders and the chain button; the right column
    holds the chat. The vector store and the QA chain are kept in gr.State
    objects between events.
    """
    # Named ``app`` so the Blocks object does not shadow this function's name.
    with gr.Blocks() as app:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.HTML("<center><h1>RAG PDF Chatbot (Kostenlose Version)</h1></center>")
        with gr.Row():
            with gr.Column():
                document = gr.Files(height=300, file_count="multiple", file_types=[".pdf"], interactive=True)
                db_btn = gr.Button("Erstelle Vektordatenbank")
                db_progress = gr.Textbox(value="Nicht initialisiert", show_label=False)
                # type="index" hands the handler the list position, which is
                # what initialize_LLM uses to look up the repo id in list_llm.
                llm_btn = gr.Radio(
                    ["Flan-T5-Small", "MiniLM"],
                    value="Flan-T5-Small",
                    type="index",
                    label="Verfügbare Modelle",
                )
                slider_temperature = gr.Slider(0.01, 1.0, value=0.5, label="Temperature")
                slider_maxtokens = gr.Slider(64, 512, value=256, step=1, label="Max Tokens")
                slider_topk = gr.Slider(1, 10, value=3, step=1, label="Top-K")
                qachain_btn = gr.Button("Initialisiere QA-Chatbot")
                llm_progress = gr.Textbox(value="Nicht initialisiert", show_label=False)

            with gr.Column():
                chatbot = gr.Chatbot(height=400, type="messages")
                msg = gr.Textbox(placeholder="Frage stellen...")
                submit_btn = gr.Button("Absenden")

        db_btn.click(initialize_database, [document], [vector_db, db_progress])
        # initialize_LLM takes five arguments and returns (chain, status), so
        # five inputs and two outputs are wired here.
        qachain_btn.click(
            initialize_LLM,
            [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            [qa_chain, llm_progress],
        )
        submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])
    app.launch(debug=True)

# Build and launch the UI only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo()