import gradio as gr
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from langchain.chains import ConversationalRetrievalChain
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import pipeline
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)

# 1. Initialize the Hugging Face embedding model
embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
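
# Sanity check (illustrative, not part of the original): all-MiniLM-L6-v2
# maps text to 384-dimensional vectors, e.g.
#   vec = embeddings.embed_query("test sentence")  # len(vec) == 384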

# FastAPI Backend
app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # allow all origins
    allow_credentials=True,
    allow_methods=["*"],  # allow all methods
    allow_headers=["*"],  # allow all headers
)

class QueryRequest(BaseModel):
    question: str
    chat_history: list
    num_sources: int = 3  # return 3 sources by default

@app.post("/query")
async def query(request: QueryRequest):
    try:
        logging.info(f"Received query: {request.question}")
        
        result = qa_chain({"question": request.question, "chat_history": request.chat_history})
        
        # Begrenze die Anzahl der zurückgegebenen Quellen
        sources = [
            {"source": doc.metadata["source"], "content": doc.page_content}
            for doc in result["source_documents"][:request.num_sources]
        ]
        
        response = {
            "answer": result["answer"],
            "sources": sources
        }
        
        logging.info(f"Answer: {response['answer']}")
        return JSONResponse(content=response)
    
    except Exception as e:
        logging.error(f"Error processing query: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
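
# Example request against the endpoint (illustrative; assumes the backend is
# served on localhost:8000, e.g. via uvicorn):
#
#   curl -X POST http://localhost:8000/query \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is the PDF about?", "chat_history": []}'
#
# Response shape: {"answer": "...", "sources": [{"source": "...", "content": "..."}]}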

# 2. Load PDF documents and extract their content
def load_pdf(file):
    loader = PyPDFLoader(file.name)
    pages = loader.load()

    # Tag each page with a running page number so answers can cite it
    # (the original built a separate, unused list of dicts for this)
    for i, page in enumerate(pages):
        page.metadata["source"] = f"Page {i + 1}"

    # Split the text into chunks for more accurate retrieval
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    split_docs = text_splitter.split_documents(pages)

    # Build the FAISS vector store from the split chunks; page metadata travels along
    vectorstore = FAISS.from_documents(split_docs, embeddings)
    return vectorstore
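
# Note on the splitter settings (illustrative): RecursiveCharacterTextSplitter
# counts characters by default, so chunk_size=500 / chunk_overlap=50 yields
# roughly 500-character chunks that share about 50 characters across boundaries.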

# 3. Initialize a generative question-answering model.
# NOTE: HuggingFacePipeline only wraps generative tasks (text-generation,
# text2text-generation, summarization), so the extractive "question-answering"
# pipeline used originally would fail at runtime; it is swapped here for a
# small text2text-generation model (google/flan-t5-base, chosen as an example).
qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base", max_new_tokens=256)
llm = HuggingFacePipeline(pipeline=qa_pipeline)

# 4. Configure the prompt for retrieval-augmented generation
prompt_template = """
You are a helpful AI chatbot. Use the following information to answer the question:
{context}

Question: {question}
Answer:
"""
prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

# 5. Build the ConversationalRetrievalChain.
# The chain needs the vector store, which only exists after a PDF upload, so
# it is created lazily in upload_pdf instead of at import time (the original
# referenced an undefined `vectorstore` here and used the wrong constructor).
qa_chain = None

def build_qa_chain(vectorstore):
    return ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=vectorstore.as_retriever(),
        combine_docs_chain_kwargs={"prompt": prompt},
        return_source_documents=True,
    )
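
# Direct use of the chain once a PDF is indexed (a sketch; "example.pdf" is a
# hypothetical path, not from the original):
#   vs = load_pdf(open("example.pdf", "rb"))
#   chain = build_qa_chain(vs)
#   out = chain({"question": "What is this document about?", "chat_history": []})
#   print(out["answer"], out["source_documents"])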

# Gradio frontend
API_URL = "http://localhost:8000/query"  # can be switched to the internal API

def chat_with_bot(user_input, chat_history):
    response = requests.post(
        API_URL,
        json={"question": user_input, "chat_history": chat_history}
    )
    if response.status_code == 200:
        data = response.json()
        answer = data["answer"]
        sources = data.get("sources", [])
        sources_text = "\n".join(f"{src['source']}: {src['content']}" for src in sources)
        chat_history.append((user_input, answer))
    else:
        chat_history.append((user_input, "An error occurred: " + response.text))
        sources_text = ""
    # gr.Chatbot renders the full history of (user, bot) pairs, so return the
    # history as the first output rather than the bare answer string
    return chat_history, chat_history, sources_text

def upload_pdf(file):
    # Store the chain globally so the /query endpoint can use it
    global qa_chain
    vectorstore = load_pdf(file)
    qa_chain = build_qa_chain(vectorstore)
    return "PDF uploaded successfully."

with gr.Blocks() as demo:
    gr.Markdown("## Chatbot with RAG (LangChain)")

    # File upload component
    with gr.Row():
        upload_button = gr.File(label="Upload PDF", file_count="single")
        upload_status = gr.Textbox(label="Status", interactive=False)
    # `outputs` must be a component, not the string "status"
    upload_button.upload(upload_pdf, inputs=upload_button, outputs=upload_status)

    chatbot = gr.Chatbot(label="Chatbot")
    question = gr.Textbox(label="Your question")
    sources = gr.Textbox(label="Sources", interactive=False)
    submit = gr.Button("Send")
    clear = gr.Button("Clear chat history")
    chat_history = gr.State([])

    submit.click(
        fn=chat_with_bot,
        inputs=[question, chat_history],
        outputs=[chatbot, chat_history, sources],
        show_progress=True,
    )
    clear.click(lambda: ([], [], ""), inputs=[], outputs=[chatbot, chat_history, sources])

if __name__ == "__main__":
    # Launch the Gradio frontend only when run directly, so importing this
    # module for the FastAPI backend does not block (see the run sketch below)
    demo.launch()
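
# How the two halves run together (a sketch; "app.py" and the port are
# placeholders, assuming this file is saved as app.py):
#
#   Terminal 1 - FastAPI backend:   uvicorn app:app --port 8000
#   Terminal 2 - Gradio frontend:   python app.py
#
# The frontend then talks to the backend over API_URL.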