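# RAG chatbot: Gradio UI + LangChain + FAISS vector store + local Hugging Face
# models (no external API calls). Likely dependencies, inferred from the
# imports below: gradio, langchain, langchain-community, faiss-cpu, pypdf,
# sentence-transformers, transformers.
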
import os
import gradio as gr
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFacePipeline
from transformers import pipeline

EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
LLM_MODEL_NAME = "google/flan-t5-small"

def load_and_split_docs(list_file_path):
    """Load the uploaded PDFs and split them into overlapping text chunks."""
    if not list_file_path:
        # Surface the problem in the UI instead of returning a mismatched tuple.
        raise gr.Error("Error: no documents found!")
    loaders = [PyPDFLoader(x) for x in list_file_path]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
    return text_splitter.split_documents(documents)

def create_db(docs):
    """Embed the chunks and index them in an in-memory FAISS store."""
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
    return FAISS.from_documents(docs, embeddings)

def initialize_llm_chain(llm_model, temperature, max_tokens, vector_db):
    """Build a ConversationalRetrievalChain around a local text2text pipeline."""
    local_pipeline = pipeline(
        "text2text-generation",
        model=llm_model,
        max_new_tokens=max_tokens,
        do_sample=True,  # temperature only takes effect when sampling is enabled
        temperature=temperature
    )
    llm = HuggingFacePipeline(pipeline=local_pipeline)
    # output_key is required here: with return_source_documents=True the chain
    # returns multiple keys, and the memory must know which one to store.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True
    )
    retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        memory=memory,
        return_source_documents=True
    )
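
# The Gradio callbacks below (initialize_database, initialize_llm_chain_wrapper,
# conversation) are referenced in demo() but were missing from the original
# listing. These are minimal sketches of what they plausibly looked like,
# assuming the chain's default "question"/"answer" keys and the messages-style
# history used by gr.Chatbot(type='messages').

def initialize_database(list_file_path):
    """Build the FAISS index from the uploaded PDFs and report status."""
    docs = load_and_split_docs(list_file_path)
    db = create_db(docs)
    return db, "Status: Vector database ready"

def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db):
    """Bind the UI slider values to the chain constructor."""
    return initialize_llm_chain(LLM_MODEL_NAME, temperature, int(max_tokens), vector_db)

def conversation(qa_chain, message, history):
    """Run one retrieval-augmented QA turn and append it to the chat history."""
    result = qa_chain.invoke({"question": message})
    answer = result["answer"]
    # The chain keeps its own memory; this history only drives the UI display.
    history = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": answer},
    ]
    # Return order matches the outputs wired up in demo():
    # updated chain state, cleared textbox, updated chat history.
    return qa_chain, "", history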

def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State()
        qa_chain = gr.State()

        gr.HTML("<center><h1>RAG Chatbot mit FAISS und lokalen Modellen</h1></center>")
        with gr.Row():
            with gr.Column():
                document = gr.Files(file_types=[".pdf"], label="Upload PDF")
                db_btn = gr.Button("Create vector database")
                db_status = gr.Textbox(value="Status: Not initialized", show_label=False)
                slider_temperature = gr.Slider(0.01, 1.0, value=0.5, label="Temperature")
                slider_max_tokens = gr.Slider(64, 512, value=256, label="Max Tokens")
                qachain_btn = gr.Button("Initialize QA chatbot")

            with gr.Column():
                chatbot = gr.Chatbot(type='messages', height=400)
                msg = gr.Textbox(placeholder="Enter a question...")
                submit_btn = gr.Button("Submit")

        db_btn.click(initialize_database, [document], [vector_db, db_status])
        qachain_btn.click(initialize_llm_chain_wrapper, [slider_temperature, slider_max_tokens, vector_db], [qa_chain])
        submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])

    # enable_queue was removed from launch() in Gradio 4; queue() works in 3.x and 4.x.
    demo.queue().launch(debug=True)

if __name__ == "__main__":
    demo()
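
# Usage sketch (assumed, not from the original): upload one or more PDFs,
# click "Create vector database", then "Initialize QA chatbot", and ask
# questions in the textbox. Gradio serves the app locally (default port 7860).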