import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
# API token from the environment
api_token = os.getenv("HF_TOKEN")
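# Optional sanity check: the HuggingFaceEndpoint calls below assume the token
# is provided via the HF_TOKEN environment variable (a secret on HF Spaces).
if api_token is None:
    print("Warning: HF_TOKEN is not set; Hugging Face Inference API calls will fail.")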
# LLM options
list_llm = ["google/flan-t5-small", "google/flan-t5-base"]
# Load and split documents
def load_doc(list_file_path):
    if not list_file_path:
        return []
    loaders = [PyPDFLoader(x) for x in list_file_path]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    # Split into overlapping chunks so retrieval returns passage-sized context
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return text_splitter.split_documents(documents)
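# Rough intuition for the chunking parameters above: chunk_size and
# chunk_overlap are character counts, so a ~5,000-character page yields about
# five ~1,024-character chunks whose edges share 64 characters, which keeps
# sentences from being cut off between chunks.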
# Create the vector database
def create_db(splits):
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(splits, embeddings)
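# Optional sketch (commented out; "faiss_index" is an example path): given the
# same embeddings object, a FAISS index can be persisted and reloaded to avoid
# re-embedding the documents on every restart.
#
#   db = create_db(splits)
#   db.save_local("faiss_index")
#   db = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)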
# Initialize the database
def initialize_database(list_file_obj):
    if not list_file_obj:
        return None, "Error: No files uploaded!"
    list_file_path = [x.name for x in list_file_obj if x is not None]
    doc_splits = load_doc(list_file_path)
    if not doc_splits:
        return None, "Error: No documents found!"
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully!"
# Initialize the LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    if vector_db is None:
        return None
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        task="text2text-generation",  # flan-t5 models are seq2seq
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    # The chain keeps its own chat history in memory and returns source documents
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True
    )
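# Optional tweak (an assumption, not used above): the number of retrieved
# chunks can be pinned explicitly instead of using the retriever default.
#
#   retriever = vector_db.as_retriever(search_kwargs={"k": 3})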
# Initialize the LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    if vector_db is None:
        return None, "Error: The database has not been created!"
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
    return qa_chain, "QA chain initialized. The chatbot is ready!"
# Conversation
def conversation(qa_chain, message, history):
    history = history or []
    if qa_chain is None:
        return qa_chain, message, history + [{"role": "assistant", "content": "The QA chain has not been initialized."}]
    if not message.strip():
        return qa_chain, message, history + [{"role": "assistant", "content": "Please enter a question!"}]
    # The chain's memory already tracks the chat history, so only the question is passed
    response = qa_chain.invoke({"question": message})
    response_text = response.get("answer", "No answer available.")
    sources = [doc.metadata.get("source", "unknown") for doc in response.get("source_documents", [])]
    if sources:
        response_text += "\n\nSources: " + ", ".join(sorted(set(sources)))
    formatted_response = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response_text},
    ]
    # Clear the input box and update the chat display
    return qa_chain, "", formatted_response
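# Sketch of how the pieces compose without the UI (commented out;
# "example.pdf" is an assumed file name):
#
#   splits = load_doc(["example.pdf"])
#   db = create_db(splits)
#   chain = initialize_llmchain(list_llm[0], 0.5, 512, 3, db)
#   print(chain.invoke({"question": "What is this document about?"})["answer"])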
# Build the demo
def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("<center><h1>PDF chatbot with free models</h1></center>")
        document = gr.Files(label="Upload PDF documents")
        db_btn = gr.Button("Create vector database")
        db_status = gr.Textbox(label="Database status", interactive=False)
        llm_btn = gr.Radio(["Flan-T5 Small", "Flan-T5 Base"], label="Available LLMs", value="Flan-T5 Small", type="index")
        slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
        slider_maxtokens = gr.Slider(128, 2048, 512, label="Max tokens")
        slider_topk = gr.Slider(1, 10, 3, label="Top-k")
        qachain_btn = gr.Button("Initialize QA chatbot")
        llm_status = gr.Textbox(label="Chatbot status", interactive=False)
        # type="messages" matches the {"role": ..., "content": ...} history format
        chatbot = gr.Chatbot(label="Chatbot", height=400, type="messages")
        msg = gr.Textbox(label="Ask a question")
        submit_btn = gr.Button("Submit")
        db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_status])
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, llm_status])
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot])
        # Pressing Enter in the textbox behaves like the submit button
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, msg, chatbot])
    demo.launch(debug=True)
if __name__ == "__main__":
    demo()