la04 committed on
Commit 1474017 · verified · 1 Parent(s): 44a313b

Create app.py

Files changed (1)
  1. app.py +152 -0
app.py ADDED
@@ -0,0 +1,152 @@
+ import gradio as gr
+ import requests
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from fastapi.responses import JSONResponse
+ from fastapi.middleware.cors import CORSMiddleware
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.llms import HuggingFacePipeline
+ from langchain.prompts import PromptTemplate
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from transformers import pipeline
+ import logging
+
+ # Set up logging
+ logging.basicConfig(level=logging.INFO)
+
+ # 1. Initialize the Hugging Face embedding model
+ embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
+ embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
+
+ # FastAPI backend
+ app = FastAPI()
+
+ # Add CORS middleware
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],      # allow all origins
+     allow_credentials=True,
+     allow_methods=["*"],      # all methods
+     allow_headers=["*"],      # all headers
+ )
+
+ # The vector store and QA chain are created once a PDF has been uploaded
+ vectorstore = None
+ qa_chain = None
+
+ class QueryRequest(BaseModel):
+     question: str
+     chat_history: list
+     num_sources: int = 3  # return three sources by default
+
+ @app.post("/query")
+ async def query(request: QueryRequest):
+     try:
+         if qa_chain is None:
+             raise HTTPException(status_code=400, detail="Bitte zuerst ein PDF hochladen.")
+
+         logging.info(f"Received query: {request.question}")
+
+         # The chain expects the chat history as a list of (question, answer) tuples
+         history = [tuple(turn) for turn in request.chat_history]
+         result = qa_chain({"question": request.question, "chat_history": history})
+
+         # Limit the number of returned sources
+         sources = [
+             {"source": doc.metadata["source"], "content": doc.page_content}
+             for doc in result["source_documents"][:request.num_sources]
+         ]
+
+         response = {
+             "answer": result["answer"],
+             "sources": sources
+         }
+
+         logging.info(f"Answer: {response['answer']}")
+         return JSONResponse(content=response)
+
+     except HTTPException:
+         raise
+     except Exception as e:
+         logging.error(f"Error processing query: {str(e)}")
+         raise HTTPException(status_code=500, detail=str(e))
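+
+ # Example request body for POST /query (shape defined by QueryRequest above):
+ #   {"question": "...", "chat_history": [["earlier question", "earlier answer"]], "num_sources": 3}
+ # The response contains the generated answer plus the cited chunks:
+ #   {"answer": "...", "sources": [{"source": "Seite 1", "content": "..."}]}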
+
+ # 2. Load PDF documents and extract their contents
+ def load_pdf(file):
+     loader = PyPDFLoader(file.name)
+     pages = loader.load()
+
+     # Split the text for better retrieval accuracy
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+
+     # Give each page a running page number so answers can cite their source
+     for i, page in enumerate(pages):
+         page.metadata["source"] = f"Seite {i + 1}"
+
+     # Split the pages into chunks (the page metadata is carried over)
+     split_texts = text_splitter.split_documents(pages)
+
+     # Build a FAISS vector store from the split chunks
+     vectorstore = FAISS.from_texts(
+         [text.page_content for text in split_texts],
+         embeddings,
+         metadatas=[text.metadata for text in split_texts],
+     )
+     return vectorstore
+
+ # 3. Initialize a generative model for answering questions
+ # HuggingFacePipeline only wraps generative pipelines ("text-generation",
+ # "text2text-generation", "summarization"), not extractive "question-answering"
+ # pipelines, so a seq2seq model is used to generate answers from the retrieved
+ # context. Any instruction-tuned seq2seq model works; flan-t5-base is one choice.
+ qa_pipeline = pipeline("text2text-generation", model="google/flan-t5-base")
+
+ # 4. Configure the prompt for retrieval-augmented generation
+ prompt_template = """
+ Du bist ein hilfreicher KI-Chatbot. Nutze die folgenden Informationen, um die Frage zu beantworten:
+ {context}
+
+ Frage: {question}
+ Antwort:
+ """
+ prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
+
+ # 5. Build the ConversationalRetrievalChain (called once the vector store exists)
+ def build_qa_chain(store):
+     return ConversationalRetrievalChain.from_llm(
+         llm=HuggingFacePipeline(pipeline=qa_pipeline),
+         retriever=store.as_retriever(),
+         combine_docs_chain_kwargs={"prompt": prompt},
+         return_source_documents=True,
+     )
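+
+ # When there is chat history, the chain first condenses it together with the new question
+ # into a standalone query, retrieves the most similar chunks from the vector store, and
+ # then fills {context} and {question} in the prompt above before calling the model.
+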
+ # Gradio frontend
+ API_URL = "http://localhost:8000/query"  # can be switched to the internal API
+
+ def chat_with_bot(user_input, chat_history):
+     response = requests.post(
+         API_URL,
+         json={"question": user_input, "chat_history": chat_history}
+     )
+     if response.status_code == 200:
+         data = response.json()
+         answer = data["answer"]
+         sources = data.get("sources", [])
+         sources_text = "\n".join([f"{src['source']}: {src['content']}" for src in sources])
+         chat_history.append((user_input, answer))
+     else:
+         chat_history.append((user_input, "Ein Fehler ist aufgetreten: " + response.text))
+         sources_text = ""
+     # The Chatbot component expects the full history, not just the latest answer
+     return chat_history, chat_history, sources_text
+
+ def upload_pdf(file):
+     # Build the vector store and the retrieval chain from the uploaded PDF
+     global vectorstore, qa_chain
+     vectorstore = load_pdf(file)
+     qa_chain = build_qa_chain(vectorstore)
+     return "PDF erfolgreich hochgeladen."
+
+ with gr.Blocks() as demo:
+     gr.Markdown("## Chatbot mit RAG (LangChain)")
+
+     # File upload component with a status display
+     with gr.Row():
+         upload_button = gr.File(label="Lade PDF hoch", file_count="single")
+         upload_status = gr.Textbox(label="Status", interactive=False)
+     upload_button.upload(upload_pdf, inputs=upload_button, outputs=upload_status)
+
+     chatbot = gr.Chatbot(label="Chatbot")
+     question = gr.Textbox(label="Deine Frage")
+     sources = gr.Textbox(label="Quellen", interactive=False)
+     submit = gr.Button("Senden")
+     clear = gr.Button("Chatverlauf löschen")
+     chat_history = gr.State([])
+
+     submit.click(
+         fn=chat_with_bot,
+         inputs=[question, chat_history],
+         outputs=[chatbot, chat_history, sources],
+         show_progress=True,
+     )
+     clear.click(lambda: ([], [], ""), inputs=[], outputs=[chatbot, chat_history, sources])
+
+ # demo.launch() starts only the Gradio UI; the FastAPI app must be served separately
+ # (e.g. with uvicorn) so that API_URL is reachable.
+ demo.launch()
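
The file ends with demo.launch(), which serves only the Gradio UI, while chat_with_bot posts to API_URL on port 8000; for those requests to succeed, the FastAPI app defined above must be served as well. A minimal sketch of one way to run both from a single process, assuming uvicorn is installed, port 8000 is free, and a uvicorn version that tolerates being started from a background thread; the run_api helper is only illustrative and not part of the committed file:

import threading
import uvicorn

def run_api():
    # Serve the FastAPI app (and its /query endpoint) on port 8000 so API_URL is reachable
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")

if __name__ == "__main__":
    threading.Thread(target=run_api, daemon=True).start()  # backend in the background
    demo.launch()                                           # Gradio UI in the foreground

Used this way, the snippet would replace the bare demo.launch() call at the end of the file. An alternative is gr.mount_gradio_app(app, demo, path="/"), which mounts the UI onto the FastAPI app so that one uvicorn process serves both; API_URL then stays at http://localhost:8000/query.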