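"""RAG chatbot Space: upload PDF documents, split them into chunks, index the
chunks in a FAISS vector store with MiniLM embeddings, and answer questions
through a LangChain ConversationalRetrievalChain backed by a Hugging Face
inference endpoint (Flan-T5, Pythia, or BLOOM)."""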
import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
# Load the API token from the environment (expects a secret/env var named "HF_Token")
api_token = os.getenv("HF_Token")
# Models available for selection
list_llm = [
    "google/flan-t5-base",  # lightweight instruction-tuned model
    "sentence-transformers/all-MiniLM-L6-v2",  # embeddings-optimized model
    "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",  # Pythia 12B
    "bigscience/bloom-3b",  # multilingual BLOOM
    "bigscience/bloom-1b7",  # lightweight BLOOM model
]
# Document processing: load PDFs and split them into overlapping chunks
def load_doc(list_file_path):
    if not list_file_path:
        return []
    loaders = [PyPDFLoader(x) for x in list_file_path]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
    return text_splitter.split_documents(documents)
# Create the vector database
def create_db(splits):
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(splits, embeddings)
# Initialize the database
def initialize_database(list_file_obj):
    if not list_file_obj:
        return None, "Error: no files uploaded!"
    list_file_path = list_file_obj  # file paths of the uploaded files
    doc_splits = load_doc(list_file_path)
    if not doc_splits:
        return None, "Error: no documents found!"
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully!"
# Initialize the LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, vector_db):
    if vector_db is None:
        return None
    # Cap max_new_tokens depending on the model family
    if "pythia" in llm_model or "bloom" in llm_model:
        max_tokens = min(max_tokens, 2048)
    else:
        max_tokens = min(max_tokens, 1024)
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
    )
    # The chain keeps the dialogue in its own buffer memory
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True
    )
    return qa_chain
# Initialize the LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, vector_db):
    if vector_db is None:
        return None, "Error: the database has not been created!"
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, vector_db)
    return qa_chain, "QA chain initialized. The chatbot is ready!"
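# A minimal sketch of using the pieces above outside the UI (hypothetical file
# name "paper.pdf"; assumes HF_Token is set and the endpoint model is reachable):
#   splits = load_doc(["paper.pdf"])
#   db = create_db(splits)
#   chain = initialize_llmchain(list_llm[0], 0.5, 512, db)
#   print(chain.invoke({"question": "What is the paper about?"})["answer"])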
# Conversation step: append the question and the chain's answer to the chat history
def conversation(qa_chain, message, history):
    if qa_chain is None:
        return qa_chain, history + [{"role": "assistant", "content": "The QA chain has not been initialized."}], message
    if not message.strip():
        return qa_chain, history + [{"role": "assistant", "content": "Please enter a question!"}], ""
    # The chain tracks chat_history through its memory, so only the question is passed in
    response = qa_chain.invoke({"question": message})
    response_text = response.get("answer", "No answer available.")
    formatted_response = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response_text},
    ]
    return qa_chain, formatted_response, ""
# Gradio UI
def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("<center><h1>RAG Chatbot with Pythia and BLOOM (CPU-compatible)</h1></center>")
        with gr.Row():
            with gr.Column():
                document = gr.Files(label="Upload PDF documents", type="filepath", file_types=[".pdf"], file_count="multiple")
                db_btn = gr.Button("Create vector database")
                db_status = gr.Textbox(label="Database status", value="Not created", interactive=False)
                llm_btn = gr.Radio(
                    ["Flan-T5 Base", "MiniLM", "Pythia 12B", "BLOOM 3B", "BLOOM 1.7B"],
                    label="Available LLMs",
                    value="Flan-T5 Base",
                    type="index",
                )
                slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
                slider_maxtokens = gr.Slider(1, 2048, 512, label="Max tokens")
                qachain_btn = gr.Button("Initialize QA chatbot")
                llm_status = gr.Textbox(label="Chatbot status", value="Not initialized", interactive=False)
            with gr.Column():
                chatbot = gr.Chatbot(label="Chatbot", height=400, type="messages")
                msg = gr.Textbox(label="Ask a question")
                submit_btn = gr.Button("Submit")
        # Wire up events
        db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_status])
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, vector_db], outputs=[qa_chain, llm_status])
        # The third output clears the question box after each turn
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, chatbot, msg])
    demo.launch(debug=True)
if __name__ == "__main__":
    demo()
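# To run locally (assumes a Hugging Face token with Inference API access,
# exposed under the env var name used above):
#   export HF_Token=...
#   python app.py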