import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
# Load the API token from an environment variable ("HF_Token" is the secret name used by this Space)
api_token = os.getenv("HF_Token")
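
# Minimal sanity check (an addition, not part of the original app): HuggingFaceEndpoint
# needs a valid token, and failing early here is clearer than a 401 from the
# Inference API later. The message text is illustrative.
if not api_token:
    raise RuntimeError("HF_Token is not set; add it as a secret in the Space settings.")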
# Models available for selection
list_llm = [
    "google/flan-t5-base",                             # lightweight instruction model
    "sentence-transformers/all-MiniLM-L6-v2",          # embedding model; kept for parity with the UI list, but not suited to text generation
    "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5",  # Pythia 12B
    "bigscience/bloom-3b",                             # multilingual BLOOM
    "bigscience/bloom-1b7",                            # lightweight BLOOM model
]
# Document processing: load PDFs and split them into overlapping chunks
def load_doc(list_file_path):
    if not list_file_path:
        # Return an empty list (not a (list, message) tuple) so callers always get the same type
        return []
    loaders = [PyPDFLoader(x) for x in list_file_path]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
    return text_splitter.split_documents(documents)
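
# Usage sketch ("example.pdf" is an illustrative path, not part of the app):
#   chunks = load_doc(["example.pdf"])
#   print(len(chunks), chunks[0].page_content[:80])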
# Create the vector database from the document chunks
def create_db(splits):
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    return FAISS.from_documents(splits, embeddings)
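
# Optional sketch: the FAISS index can be persisted and reloaded so documents are
# not re-embedded on every restart ("faiss_index" is an illustrative path):
#   vector_db.save_local("faiss_index")
#   vector_db = FAISS.load_local("faiss_index", embeddings,
#                                allow_dangerous_deserialization=True)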
# Initialize the database
def initialize_database(list_file_obj):
    if not list_file_obj:
        return None, "Error: no files uploaded!"
    list_file_path = list_file_obj  # gr.Files with type="filepath" yields a list of file paths
    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Database created successfully!"
# Initialize the LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, vector_db):
    if vector_db is None:
        # Return only the chain (or None); the caller attaches the status message
        return None
    # Cap max_new_tokens according to the model's context window
    if "pythia" in llm_model or "bloom" in llm_model:
        max_tokens = min(max_tokens, 2048)
    else:
        max_tokens = min(max_tokens, 1024)
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
    )
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True
    )
    return qa_chain
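
# Retrieval tuning (sketch): as_retriever also accepts search_kwargs, and limiting
# the number of retrieved chunks keeps the "stuff" prompt short on small-context
# models. The k value below is illustrative:
#   retriever = vector_db.as_retriever(search_kwargs={"k": 3})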
# Initialize the LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, vector_db):
    if vector_db is None:
        return None, "Error: the database has not been created!"
    llm_name = list_llm[llm_option]
    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, vector_db)
    return qa_chain, "QA chain initialized. The chatbot is ready!"
# Conversation
def conversation(qa_chain, message, history):
    # gr.Chatbot(type="messages") only accepts "user"/"assistant" roles, so status
    # messages are appended to the existing history as assistant turns
    if qa_chain is None:
        return qa_chain, history + [{"role": "assistant", "content": "The QA chain has not been initialized."}], message
    if not message.strip():
        return qa_chain, history + [{"role": "assistant", "content": "Please enter a question!"}], message
    # The chain's ConversationBufferMemory tracks the dialogue, so only the question is passed in
    response = qa_chain.invoke({"question": message})
    response_text = response.get("answer", "No answer available.")
    formatted_response = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response_text},
    ]
    return qa_chain, formatted_response, ""  # third output clears the input textbox
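
# Note: with return_source_documents=True, response["source_documents"] holds the
# retrieved chunks. A citation sketch (PyPDFLoader stores "source" and a 0-based
# "page" in each chunk's metadata):
#   src = response["source_documents"][0].metadata
#   citation = f'{src.get("source")}, page {src.get("page", 0) + 1}'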
# Gradio UI
def demo():
    with gr.Blocks() as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("<center><h1>RAG Chatbot with Pythia and BLOOM (CPU-compatible)</h1></center>")
        with gr.Row():
            with gr.Column():
                document = gr.Files(label="Upload PDF documents", type="filepath", file_types=[".pdf"], file_count="multiple")
                db_btn = gr.Button("Create vector database")
                db_status = gr.Textbox(label="Database status", value="Not created", interactive=False)
                llm_btn = gr.Radio(
                    ["Flan-T5 Base", "MiniLM", "Pythia 12B", "BLOOM 3B", "BLOOM 1.7B"],
                    label="Available LLMs",
                    value="Flan-T5 Base",
                    type="index",  # the index maps directly into list_llm
                )
                slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
                slider_maxtokens = gr.Slider(1, 2048, 512, label="Max Tokens")
                qachain_btn = gr.Button("Initialize QA chatbot")
                llm_status = gr.Textbox(label="Chatbot status", value="Not initialized", interactive=False)
            with gr.Column():
                chatbot = gr.Chatbot(label="Chatbot", height=400, type="messages")
                msg = gr.Textbox(label="Ask a question")
                submit_btn = gr.Button("Submit")
        # Wire up events; the third conversation output clears the question box
        db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_status])
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, vector_db], outputs=[qa_chain, llm_status])
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, chatbot, msg])
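        # Also send when Enter is pressed in the textbox (a small usability
        # addition mirroring the button wiring above)
        msg.submit(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, chatbot, msg])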
    demo.launch(debug=True)
if __name__ == "__main__":
    demo()