import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# API token
api_token = os.getenv("HF_TOKEN")

# LLM options
list_llm = ["google/flan-t5-small", "google/flan-t5-base"]

# Load documents and split them into chunks
def load_doc(list_file_path):
    if not list_file_path:
        return []  # consistent return type; callers report the error to the UI
    loaders = [PyPDFLoader(x) for x in list_file_path]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return text_splitter.split_documents(documents)
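# A minimal sanity-check sketch for the loader ("example.pdf" is a hypothetical file):
#   splits = load_doc(["example.pdf"])
#   print(len(splits), splits[0].page_content[:80])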

# Create the vector database from the document chunks
def create_db(splits):
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(splits, embeddings)
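# The FAISS index can also be queried directly, outside the chain, e.g.
# (a sketch assuming the hypothetical splits from above):
#   db = create_db(splits)
#   hits = db.similarity_search("What is this document about?", k=3)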

# Initialize the database from the uploaded files
def initialize_database(list_file_obj):
    if not list_file_obj:
        return None, "Error: No files uploaded!"
    list_file_path = [x.name for x in list_file_obj if x is not None]
    doc_splits = load_doc(list_file_path)
    if not doc_splits:
        return None, "Error: No documents found!"
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully!"
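# Note: gr.Files hands the callback file objects whose .name attribute is the
# path of a temporary copy on disk, which is why only the paths are collected here.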

# Initialize the LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    if vector_db is None:
        return None
    # Cap max_new_tokens to avoid endpoint errors
    max_tokens = min(max_tokens, 250)
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True
    )
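# Note: chain_type="stuff" concatenates ("stuffs") all retrieved chunks into a single
# prompt, and the ConversationBufferMemory keeps the running chat history so that
# follow-up questions can be condensed into standalone queries before retrieval.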

# Initialize the LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    if vector_db is None:
        return None, "Error: The database has not been created!"
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain, "QA chain initialized. The chatbot is ready!"
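# The Radio component below is created with type="index", so llm_option arrives
# here as an integer index into list_llm rather than as the label string.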

# Conversation step: answer one user message
def conversation(qa_chain, message, history):
    # With type="messages", the Chatbot only accepts "user" and "assistant" roles
    if qa_chain is None:
        return qa_chain, history + [{"role": "assistant", "content": "The QA chain has not been initialized."}], message
    if not message.strip():
        return qa_chain, history + [{"role": "assistant", "content": "Please enter a question!"}], message
    response = qa_chain.invoke({"question": message, "chat_history": history})
    response_text = response.get("answer", "No answer available.")
    formatted_response = history + [{"role": "user", "content": message}, {"role": "assistant", "content": response_text}]
    return qa_chain, formatted_response, ""  # empty string clears the input box
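# With type="messages", the Chatbot component expects a list of
# {"role": "user" | "assistant", "content": ...} dicts, which is exactly
# what conversation() returns as its second value.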

# Build the demo
def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("<center><h1>PDF chatbot with free models</h1></center>")
        with gr.Row():
            with gr.Column():
                document = gr.Files(label="Upload PDF documents")
                db_btn = gr.Button("Create vector database")
                db_status = gr.Textbox(label="Database status", value="Not created", interactive=False)
                llm_btn = gr.Radio(["Flan-T5 Small", "Flan-T5 Base"], label="Available LLMs", value="Flan-T5 Small", type="index")
                slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
                slider_maxtokens = gr.Slider(1, 250, 128, label="Max tokens")  # capped at 250 (see initialize_llmchain)
                slider_topk = gr.Slider(1, 10, 3, label="Top-k")
                qachain_btn = gr.Button("Initialize QA chatbot")
                llm_status = gr.Textbox(label="Chatbot status", value="Not initialized", interactive=False)
            with gr.Column():
                chatbot = gr.Chatbot(label="Chatbot", height=400, type="messages")
                msg = gr.Textbox(label="Ask a question")
                submit_btn = gr.Button("Submit")
        # Event handling
        db_btn.click(
            initialize_database,
            inputs=[document],
            outputs=[vector_db, db_status]
        )
        qachain_btn.click(
            initialize_LLM,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_status]
        )
        submit_btn.click(
            conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, chatbot, msg]  # third output clears the textbox
        )
    demo.launch(debug=True)

if __name__ == "__main__":
demo()
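# To run locally (assuming a valid Hugging Face token is available):
#   HF_TOKEN=<your token> python app.py
# Gradio serves the UI at http://127.0.0.1:7860 by default.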