import gradio as gr
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain.chains import ConversationalRetrievalChain
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import pipeline
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
# 1. Initialize the Hugging Face embedding model
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
# FastAPI backend
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # allow all methods
    allow_headers=["*"],  # allow all headers
)
class QueryRequest(BaseModel):
    question: str
    chat_history: list
    num_sources: int = 3  # return 3 sources by default
@app.post("/query")
async def query(request: QueryRequest):
    if qa_chain is None:
        raise HTTPException(status_code=400, detail="No PDF has been uploaded yet.")
    try:
        logging.info(f"Received query: {request.question}")
        # JSON serialization turns the (question, answer) tuples into lists; convert them back
        chat_history = [tuple(turn) for turn in request.chat_history]
        result = qa_chain({"question": request.question, "chat_history": chat_history})
        # Limit the number of returned sources
        sources = [
            {"source": doc.metadata["source"], "content": doc.page_content}
            for doc in result["source_documents"][:request.num_sources]
        ]
        response = {
            "answer": result["answer"],
            "sources": sources,
        }
        logging.info(f"Answer: {response['answer']}")
        return JSONResponse(content=response)
    except Exception as e:
        logging.error(f"Error processing query: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
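# Example request against this endpoint (sketch; assumes the API is reachable on
# localhost:8000 and a PDF has already been uploaded via the UI):
#
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is this document about?", "chat_history": [], "num_sources": 2}'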
# 2. Load PDF documents and extract their content
def load_pdf(file):
    loader = PyPDFLoader(file.name)
    pages = loader.load()

    # Split the text into chunks for more accurate retrieval
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)

    # Tag each page with a running page number so the /query endpoint can cite it
    for i, page in enumerate(pages):
        page.metadata["source"] = f"Page {i + 1}"

    # Split the pages into chunks
    split_texts = text_splitter.split_documents(pages)

    # Build a FAISS vector store from the chunks
    vectorstore = FAISS.from_texts(
        [text.page_content for text in split_texts],
        embeddings,
        metadatas=[text.metadata for text in split_texts],
    )
    return vectorstore
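# Quick sanity check (sketch, independent of the QA chain): the returned FAISS
# store can be queried directly, e.g.
#
#   # vs = load_pdf(uploaded_file)
#   # hits = vs.similarity_search("some search term", k=3)
#   # for hit in hits:
#   #     print(hit.metadata["source"], hit.page_content[:80])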
# 3. Initialize the language model used for answer generation.
# Note: HuggingFacePipeline only supports generation-style tasks (text-generation,
# text2text-generation, summarization), not the extractive "question-answering"
# task, so a text2text-generation model is substituted here.
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base", max_new_tokens=256)
llm = HuggingFacePipeline(pipeline=qa_pipeline)
# 4. Configure the prompt for retrieval-augmented generation
prompt_template = """
You are a helpful AI chatbot. Use the following information to answer the question:
{context}
Question: {question}
Answer:
"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
# 5. Build the ConversationalRetrievalChain.
# The chain needs a vector store, which only exists after a PDF has been
# uploaded, so it is created in upload_pdf() below.
def build_qa_chain(vectorstore):
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        combine_docs_chain_kwargs={"prompt": prompt},
        return_source_documents=True,
    )

qa_chain = None  # set once a PDF has been uploaded
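# Direct usage example (sketch; assumes a PDF has already been turned into a
# vector store with load_pdf):
#
#   # chain = build_qa_chain(load_pdf(uploaded_file))
#   # result = chain({"question": "Summarize page 1.", "chat_history": []})
#   # print(result["answer"], [d.metadata["source"] for d in result["source_documents"]])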
# Gradio frontend
API_URL = "http://localhost:8000/query"  # can be switched to the internal API

def chat_with_bot(user_input, chat_history):
    response = requests.post(
        API_URL,
        json={"question": user_input, "chat_history": chat_history},
    )
    if response.status_code == 200:
        data = response.json()
        answer = data["answer"]
        sources = data.get("sources", [])
        sources_text = "\n".join(f"{src['source']}: {src['content']}" for src in sources)
    else:
        answer = "An error occurred: " + response.text
        sources_text = ""
    chat_history.append((user_input, answer))
    # The Chatbot component expects the full list of (user, bot) pairs,
    # so the updated history is returned for both the chatbot and the state.
    return chat_history, chat_history, sources_text
def upload_pdf(file):
    global qa_chain
    vectorstore = load_pdf(file)
    qa_chain = build_qa_chain(vectorstore)
    return "PDF uploaded successfully."
with gr.Blocks() as demo:
    gr.Markdown("## Chatbot with RAG (LangChain)")

    # File upload component
    with gr.Row():
        upload_button = gr.File(label="Upload a PDF", file_count="single")
        upload_status = gr.Textbox(label="Status", interactive=False)
    upload_button.upload(upload_pdf, inputs=upload_button, outputs=upload_status)

    chatbot = gr.Chatbot(label="Chatbot")
    question = gr.Textbox(label="Your question")
    sources = gr.Textbox(label="Sources", interactive=False)
    submit = gr.Button("Send")
    clear = gr.Button("Clear chat history")
    chat_history = gr.State([])
    submit.click(
        fn=chat_with_bot,
        inputs=[question, chat_history],
        outputs=[chatbot, chat_history, sources],
        show_progress=True,
    )
    clear.click(lambda: ([], [], ""), inputs=[], outputs=[chatbot, chat_history, sources])

demo.launch()
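# Note: demo.launch() only starts the Gradio server; the FastAPI `app` defined
# above is not served by it, so http://localhost:8000/query will only respond if
# the API is started separately (e.g. `uvicorn app:app --port 8000`, assuming
# this file is named app.py). An alternative sketch is to mount the Gradio UI
# onto the FastAPI app and serve both from a single uvicorn process:
#
#   # app = gr.mount_gradio_app(app, demo, path="/ui")
#   # then run:  uvicorn app:app --host 0.0.0.0 --port 8000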