import logging
import threading

import gradio as gr
import requests
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain_community.document_loaders import PyPDFLoader
from transformers import pipeline
# Set up logging
logging.basicConfig(level=logging.INFO)

# 1. Initialize the Hugging Face embedding model
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)

# Global state: the vector store and QA chain are created once a PDF is uploaded
vectorstore = None
qa_chain = None
# FastAPI backend
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # Allow all methods
    allow_headers=["*"],  # Allow all headers
)
class QueryRequest(BaseModel):
    question: str
    chat_history: list
    num_sources: int = 3  # Return three sources by default
@app.post("/query")
async def query(request: QueryRequest):
    if qa_chain is None:
        raise HTTPException(status_code=400, detail="No PDF has been uploaded yet.")
    try:
        logging.info(f"Received query: {request.question}")
        # JSON deserializes tuples as lists; the chain expects (user, assistant) pairs
        chat_history = [tuple(turn) for turn in request.chat_history]
        result = qa_chain({"question": request.question, "chat_history": chat_history})
        # Limit the number of returned sources
        sources = [
            {"source": doc.metadata["source"], "content": doc.page_content}
            for doc in result["source_documents"][:request.num_sources]
        ]
        response = {
            "answer": result["answer"],
            "sources": sources,
        }
        logging.info(f"Answer: {response['answer']}")
        return JSONResponse(content=response)
    except Exception as e:
        logging.error(f"Error processing query: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
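# Illustrative exchange with the /query endpoint. Only the field names come from
# QueryRequest and the response dict above; the example values are made up:
#
#   POST /query
#   {"question": "What is the report about?", "chat_history": [], "num_sources": 2}
#
#   200 OK
#   {"answer": "...", "sources": [{"source": "Page 3", "content": "..."}]}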
# 2. Load a PDF document and extract its contents
def load_pdf(file):
    loader = PyPDFLoader(file.name)
    pages = loader.load()

    # Tag every page with a running page number so the sources can be cited
    for i, page in enumerate(pages):
        page.metadata["source"] = f"Page {i + 1}"

    # Split the text into chunks for more accurate retrieval
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    split_texts = text_splitter.split_documents(pages)

    # Build a FAISS vector store from the split chunks
    vectorstore = FAISS.from_documents(split_texts, embeddings)
    return vectorstore
# 3. Initialize a generative language model.
# Note: HuggingFacePipeline only supports generative tasks (text-generation,
# text2text-generation, summarization), so a seq2seq model is used here instead of
# the extractive QA pipeline (deepset/roberta-base-squad2), which it cannot wrap.
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
# 4. Configure the prompt for retrieval-augmented generation
prompt_template = """
You are a helpful AI chatbot. Use the following information to answer the question:
{context}
Question: {question}
Answer:
"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
# 5. Build the ConversationalRetrievalChain (called as soon as a vector store exists)
def build_qa_chain(vectorstore):
    return ConversationalRetrievalChain.from_llm(
        llm=HuggingFacePipeline(pipeline=qa_pipeline),
        retriever=vectorstore.as_retriever(),
        combine_docs_chain_kwargs={"prompt": prompt},
        return_source_documents=True,
    )
# Gradio frontend
API_URL = "http://localhost:8000/query"  # Can be switched to the internal API

def chat_with_bot(user_input, chat_history):
    response = requests.post(
        API_URL,
        json={"question": user_input, "chat_history": chat_history}
    )
    if response.status_code == 200:
        data = response.json()
        answer = data["answer"]
        sources = data.get("sources", [])
        sources_text = "\n".join([f"{src['source']}: {src['content']}" for src in sources])
        chat_history.append((user_input, answer))
        # The Chatbot component expects the full history, not just the latest answer
        return chat_history, chat_history, sources_text
    else:
        chat_history.append((user_input, "An error occurred: " + response.text))
        return chat_history, chat_history, ""
def upload_pdf(file):
    # Build the vector store and the QA chain from the uploaded PDF
    global vectorstore, qa_chain
    vectorstore = load_pdf(file)
    qa_chain = build_qa_chain(vectorstore)
    return "PDF uploaded successfully."
with gr.Blocks() as demo:
    gr.Markdown("## Chatbot with RAG (LangChain)")

    # File upload component with a status display
    with gr.Row():
        upload_button = gr.File(label="Upload a PDF", file_count="single")
        upload_status = gr.Textbox(label="Status", interactive=False)
    upload_button.upload(upload_pdf, inputs=upload_button, outputs=upload_status)

    chatbot = gr.Chatbot(label="Chatbot")
    question = gr.Textbox(label="Your question")
    sources = gr.Textbox(label="Sources", interactive=False)
    submit = gr.Button("Send")
    clear = gr.Button("Clear chat history")
    chat_history = gr.State([])

    submit.click(
        fn=chat_with_bot,
        inputs=[question, chat_history],
        outputs=[chatbot, chat_history, sources],
        show_progress=True,
    )
    clear.click(lambda: ([], [], ""), inputs=[], outputs=[chatbot, chat_history, sources])
# Serve the FastAPI backend on port 8000 in a background thread so the Gradio
# frontend can reach API_URL, then launch the Gradio app.
threading.Thread(
    target=lambda: uvicorn.run(app, host="0.0.0.0", port=8000),
    daemon=True,
).start()

demo.launch()
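# Usage sketch (assuming this file is saved as app.py): running `python app.py`
# starts the FastAPI backend on port 8000 and the Gradio UI on its default port.
# Upload a PDF first; questions asked before an upload return a 400 error.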