RAG_test_1 / app.py
la04's picture
Update app.py
f1c2bc3 verified
raw
history blame
4.54 kB
import gradio as gr
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain_community.vectorstores import FAISS
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
# Hugging Face API token, read from the HF_TOKEN environment variable.
# NOTE(review): may be None if the variable is unset — HuggingFaceEndpoint
# calls below would then fail at request time; confirm deployment sets it.
api_token = os.getenv("HF_TOKEN")
# Candidate LLM repo ids; the Gradio radio (type="index") selects by position.
list_llm = ["google/flan-t5-small", "google/flan-t5-base"]
# Load documents and split them into chunks.
def load_doc(list_file_path):
    """Load PDF files and split them into overlapping text chunks.

    Args:
        list_file_path: Paths of PDF files to load. May be empty or None.

    Returns:
        A list of LangChain Document chunks (chunk_size=1024, overlap=64).
        Empty list when no paths are given.
    """
    # Bug fix: this branch previously returned a ([], "Fehler...") tuple
    # while the success path returned a plain list, so any caller passing
    # the result straight to FAISS.from_documents would crash. Return an
    # empty list so the return type is consistent.
    if not list_file_path:
        return []
    documents = []
    for path in list_file_path:
        documents.extend(PyPDFLoader(path).load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return text_splitter.split_documents(documents)
# Build the vector database.
def create_db(splits):
    """Create a FAISS vector store over *splits* using MiniLM sentence embeddings."""
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    return FAISS.from_documents(splits, embedding_model)
# Initialize the vector database from uploaded files.
def initialize_database(list_file_obj):
    """Build a FAISS vector DB from a list of Gradio file objects.

    Args:
        list_file_obj: Uploaded file objects (each with a ``.name`` path),
            or None.

    Returns:
        Tuple of (vector_db_or_None, status_message_str).
    """
    if not list_file_obj:
        return None, "Fehler: Keine Dateien hochgeladen!"
    list_file_path = [x.name for x in list_file_obj if x is not None]
    # Bug fix: if every entry was None the old code passed an empty path
    # list on to load_doc/create_db and crashed downstream. Guard here.
    if not list_file_path:
        return None, "Fehler: Keine Dateien hochgeladen!"
    doc_splits = load_doc(list_file_path)
    # Guard against PDFs that yielded no text chunks — FAISS.from_documents
    # cannot build an index from an empty corpus.
    if not doc_splits:
        return None, "Fehler: Keine Dokumente gefunden!"
    vector_db = create_db(doc_splits)
    return vector_db, "Datenbank erfolgreich erstellt!"
# Initialize the LLM chain.
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Create a ConversationalRetrievalChain over *vector_db* using an HF endpoint.

    Args:
        llm_model: Hugging Face repo id of the model.
        temperature: Sampling temperature.
        max_tokens: Max new tokens to generate.
        top_k: Top-k sampling parameter.
        vector_db: FAISS store used as the retriever, or None.

    Returns:
        The chain, or None when no vector DB is available.
    """
    # Bug fix: this branch previously returned a (None, msg) tuple while
    # the success path returned the bare chain. Callers store the raw
    # return value in gr.State and test `qa_chain is None`, so a tuple
    # would slip through as truthy. Return None for consistency.
    if vector_db is None:
        return None
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    # Memory keyed for ConversationalRetrievalChain; output_key="answer" is
    # required because return_source_documents=True yields multiple outputs.
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True
    )
# Initialize the LLM.
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Resolve the selected model and build the QA chain.

    Args:
        llm_option: Index into ``list_llm`` (from the Gradio radio, type="index").
        llm_temperature: Sampling temperature.
        max_tokens: Max new tokens to generate.
        top_k: Top-k sampling parameter.
        vector_db: FAISS store, or None when the DB was never built.

    Returns:
        Tuple of (qa_chain_or_None, status_message_str).
    """
    if vector_db is None:
        return None, "Datenbank wurde nicht erstellt!"
    selected_model = list_llm[llm_option]
    chain = initialize_llmchain(selected_model, llm_temperature, max_tokens, top_k, vector_db)
    return chain, "QA-Kette initialisiert. Chatbot ist bereit!"
# One chat turn.
def conversation(qa_chain, message, history):
    """Run one turn through the retrieval QA chain.

    Args:
        qa_chain: The ConversationalRetrievalChain, or None if uninitialized.
        message: The user's question.
        history: Chat transcript as a list of {"role", "content"} dicts.

    Returns:
        Tuple of (qa_chain, chatbot_messages, updated_history) matching the
        Gradio outputs [qa_chain, chatbot, chatbot].
    """
    if qa_chain is None:
        return None, [{"role": "system", "content": "Die QA-Kette wurde nicht initialisiert."}], history
    if not message.strip():
        return qa_chain, [{"role": "system", "content": "Bitte eine Frage eingeben!"}], history
    response = qa_chain.invoke({"question": message, "chat_history": history})
    response_text = response.get("answer", "Keine Antwort verfügbar.")
    # Note: the original also extracted response["source_documents"] metadata
    # into an unused local; dropped here since nothing consumed it.
    # Bug fix: record the user's turn as well — the old code appended only
    # the assistant message, silently dropping every question from the
    # transcript and from the chat_history fed back into the chain.
    formatted_response = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response_text},
    ]
    return qa_chain, formatted_response, formatted_response
# Build and launch the Gradio UI.
def demo():
    """Assemble the PDF-chatbot Gradio interface and launch it (blocking)."""
    # Local name `demo` shadows this function inside the body — harmless,
    # since the function is only re-entered via __main__.
    with gr.Blocks() as demo:
        # Cross-event state: FAISS store and QA chain survive between clicks.
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.Markdown("<center><h1>PDF-Chatbot mit kostenlosen Modellen</h1></center>")
        document = gr.Files(label="PDF-Dokument hochladen")
        db_btn = gr.Button("Erstelle Vektordatenbank")
        # type="index" → handlers receive the selected position as an int.
        llm_btn = gr.Radio(["Flan-T5 Small", "Flan-T5 Base"], label="Verfügbare LLMs", value="Flan-T5 Small", type="index")
        slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
        slider_maxtokens = gr.Slider(128, 2048, 512, label="Max Tokens")
        slider_topk = gr.Slider(1, 10, 3, label="Top-k")
        qachain_btn = gr.Button("Initialisiere QA-Chatbot")
        chatbot = gr.Chatbot(label="Chatbot", height=400)
        msg = gr.Textbox(label="Frage stellen")
        submit_btn = gr.Button("Absenden")
        # NOTE(review): initialize_database and initialize_LLM return two
        # values (result, status) but only one output component is wired
        # here — depending on the Gradio version this drops the status or
        # raises at event time; confirm and add a status Textbox output.
        db_btn.click(initialize_database, inputs=[document], outputs=[vector_db])
        qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain])
        # NOTE(review): conversation() emits role/content dicts; gr.Chatbot
        # here is constructed without type="messages" — verify the installed
        # Gradio accepts the messages format by default.
        submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, chatbot, chatbot])
    # debug=True: surface handler tracebacks in the console while developing.
    demo.launch(debug=True)
if __name__ == "__main__":
    demo()