la04 commited on
Commit
6e1c776
·
verified ·
1 Parent(s): 53680db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -143
app.py CHANGED
@@ -1,152 +1,67 @@
1
  import gradio as gr
2
- import requests
3
- from fastapi import FastAPI, HTTPException
4
- from pydantic import BaseModel
5
- from fastapi.responses import JSONResponse
6
- from fastapi.middleware.cors import CORSMiddleware
7
- from langchain.chains import ConversationalRetrievalChain
8
- from langchain_community.vectorstores import FAISS
9
- from langchain_community.embeddings import HuggingFaceEmbeddings
10
- from langchain_community.llms import HuggingFacePipeline
11
- from langchain.prompts import PromptTemplate
12
- from langchain_community.document_loaders import PyPDFLoader
13
- from langchain.text_splitter import RecursiveCharacterTextSplitter
14
- from transformers import pipeline, AutoModelForQuestionAnswering
15
- import logging
16
-
17
- # Logging einrichten
18
- logging.basicConfig(level=logging.INFO)
19
-
20
- # 1. Initialisiere das Embedding-Modell von Hugging Face
21
- embedding_model_name = "sentence-transformers/all-MiniLM-L6-v2"
22
- embeddings = HuggingFaceEmbeddings(model_name=embedding_model_name)
23
-
24
- # FastAPI Backend
25
- app = FastAPI()
26
-
27
- # CORS-Middleware hinzufügen
28
- app.add_middleware(
29
- CORSMiddleware,
30
- allow_origins=["*"], # Erlaubt alle Ursprünge
31
- allow_credentials=True,
32
- allow_methods=["*"], # Alle Methoden
33
- allow_headers=["*"], # Alle Header
34
- )
35
-
36
- class QueryRequest(BaseModel):
37
- question: str
38
- chat_history: list
39
- num_sources: int = 3 # Standardmäßig 3 Quellen zurückgeben
40
-
41
- @app.post("/query")
42
- async def query(request: QueryRequest):
43
- try:
44
- logging.info(f"Received query: {request.question}")
45
-
46
- result = qa_chain({"question": request.question, "chat_history": request.chat_history})
47
-
48
- # Begrenze die Anzahl der zurückgegebenen Quellen
49
- sources = [
50
- {"source": doc.metadata["source"], "content": doc.page_content}
51
- for doc in result["source_documents"][:request.num_sources]
52
- ]
53
-
54
- response = {
55
- "answer": result["answer"],
56
- "sources": sources
57
- }
58
-
59
- logging.info(f"Answer: {response['answer']}")
60
- return JSONResponse(content=response)
61
 
62
- except Exception as e:
63
- logging.error(f"Error processing query: {str(e)}")
64
- raise HTTPException(status_code=500, detail=str(e))
65
-
66
- # 2. Lade PDF-Dokumente und extrahiere Inhalte
67
- def load_pdf(file):
68
- loader = PyPDFLoader(file.name)
69
- pages = loader.load()
70
-
71
- # Text-Splitting für eine bessere Genauigkeit bei der Abfrage
72
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
73
-
74
- # Seiten mit einer fortlaufenden Seitenzahl versehen und in Document-Objekte umwandeln
75
- documents = [
76
- {"content": page.page_content, "metadata": {"source": f"Seite {i + 1}"}}
77
- for i, page in enumerate(pages)
78
- ]
79
- # Texte splitten
80
- split_texts = text_splitter.split_documents(pages)
81
-
82
- # FAISS Vektorspeicher mit den gesplitteten Texten
83
- vectorstore = FAISS.from_texts([text.page_content for text in split_texts], embeddings, metadatas=[text.metadata for text in split_texts])
84
- return vectorstore
85
-
86
- # 3. Initialisiere ein Frage-Antwort Modell
87
- qa_pipeline = pipeline("question-answering", model="deepset/roberta-base-squad2", tokenizer="deepset/roberta-base-squad2")
88
-
89
- # 4. Konfiguriere den Prompt für die Retrieval-Augmented-Generation
90
- prompt_template = """
91
- Du bist ein hilfreicher KI-Chatbot. Nutze die folgenden Informationen, um die Frage zu beantworten:
92
- {context}
93
-
94
- Frage: {question}
95
- Antwort:
96
- """
97
- prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
98
-
99
- # 5. Baue die ConversationalRetrievalChain
100
- qa_chain = ConversationalRetrievalChain(
101
- retriever=vectorstore.as_retriever(),
102
- llm=HuggingFacePipeline(pipeline=qa_pipeline),
103
- prompt=prompt,
104
- return_source_documents=True
105
- )
106
 
107
- # Gradio Frontend
108
- API_URL = "http://localhost:8000/query" # Kann auf die interne API umgestellt werden
 
 
109
 
110
- def chat_with_bot(user_input, chat_history):
111
- response = requests.post(
112
- API_URL,
113
- json={"question": user_input, "chat_history": chat_history}
114
- )
115
- if response.status_code == 200:
116
- data = response.json()
117
- answer = data["answer"]
118
- sources = data.get("sources", [])
119
- sources_text = "\n".join([f"{src['source']}: {src['content']}" for src in sources])
120
- chat_history.append((user_input, answer))
121
- return answer, chat_history, sources_text
122
- else:
123
- return "Ein Fehler ist aufgetreten: " + response.text, chat_history, ""
124
 
125
- def upload_pdf(file):
126
- vectorstore = load_pdf(file)
127
- return "PDF erfolgreich hochgeladen."
 
 
 
128
 
129
- with gr.Blocks() as demo:
130
- gr.Markdown("## Chatbot mit RAG (LangChain)")
131
 
132
- # Datei-Upload-Komponente
133
- with gr.Row():
134
- upload_button = gr.File(label="Lade PDF hoch", file_count="single")
135
- upload_button.upload(upload_pdf, inputs=upload_button, outputs="status")
136
-
137
- chatbot = gr.Chatbot(label="Chatbot")
138
- question = gr.Textbox(label="Deine Frage")
139
- sources = gr.Textbox(label="Quellen", interactive=False)
140
- submit = gr.Button("Senden")
141
- clear = gr.Button("Chatverlauf löschen")
142
- chat_history = gr.State([])
143
-
144
- submit.click(
145
- fn=chat_with_bot,
146
- inputs=[question, chat_history],
147
- outputs=[chatbot, chat_history, sources],
148
- show_progress=True,
149
  )
150
- clear.click(lambda: ([], [], ""), inputs=[], outputs=[chatbot, chat_history, sources])
151
 
152
- demo.launch()
 
 
1
  import gradio as gr
2
+ import easyocr
3
+ from pdf2image import convert_from_path
4
+ from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
5
+ import os
6
+
7
+ # Initialisiere EasyOCR für Deutsch
8
+ reader = easyocr.Reader(['de']) # für die deutsche Sprache
9
+
10
+ # Initialisiere das deutsche Modell und den Tokenizer für RAG
11
+ model_name = "deepset/gbert-base" # Beispiel für ein deutsches Modell
12
+ tokenizer = RagTokenizer.from_pretrained(model_name)
13
+ model = RagSequenceForGeneration.from_pretrained(model_name)
14
+ retriever = RagRetriever.from_pretrained(model_name, index_name="exact", use_dummy_dataset=True)
15
+
16
+ # OCR-Funktion: Konvertiert PDF zu Bildern und extrahiert Text mit EasyOCR
17
+ def extract_text_from_pdf(file):
18
+ # Konvertiere PDF-Seiten in Bilder
19
+ images = convert_from_path(file.name, 300) # 300 DPI für bessere Qualität
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
+ text = ""
22
+ # Extrahiere Text aus jedem Bild mit EasyOCR
23
+ for image in images:
24
+ ocr_result = reader.readtext(image)
25
+ for detection in ocr_result:
26
+ text += detection[1] + "\n"
27
+
28
+ return text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
+ # Funktion zur Generierung einer Antwort basierend auf dem hochgeladenen Dokument
31
+ def get_rag_answer(input_message, uploaded_file):
32
+ # Extrahiere den Text aus dem hochgeladenen PDF-Dokument mit OCR
33
+ document_text = extract_text_from_pdf(uploaded_file)
34
 
35
+ # Simuliere den Retrieval-Mechanismus, indem wir den extrahierten Text verwenden
36
+ inputs = tokenizer(input_message, return_tensors="pt")
37
+ retrieved_docs = retriever.retrieve(input_ids=inputs["input_ids"])
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ # Kombiniere die extrahierten Dokumente und frage das Modell zur Generierung einer Antwort
40
+ input_ids = tokenizer(input_message, return_tensors="pt").input_ids
41
+ generated_ids = model.generate(input_ids=input_ids,
42
+ decoder_start_token_id=model.config.pad_token_id,
43
+ num_beams=4,
44
+ max_length=100)
45
 
46
+ answer = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
 
47
 
48
+ # Hier könnten wir eine Referenz (z.B. Absatz, Seite) in die Antwort einfügen
49
+ references = "Referenz: Abschnitt X, Seite Y (aus Dokument)" # Füge diese Infos hinzu, falls möglich
50
+
51
+ return f"{answer} \n\n{references}"
52
+
53
+ # Gradio-Oberfläche
54
+ def gradio_interface():
55
+ iface = gr.Interface(
56
+ fn=get_rag_answer,
57
+ inputs=[
58
+ gr.Textbox(label="User Input", placeholder="Stellen Sie eine Frage..."),
59
+ gr.File(label="Laden Sie ein PDF-Dokument hoch", type="file") # Ermöglicht das Hochladen von PDF-Dateien
60
+ ],
61
+ outputs=gr.Textbox(label="Antwort des Chatbots"),
62
+ live=True # Sofortige Antwortgenerierung
 
 
63
  )
64
+ iface.launch()
65
 
66
+ # Starte die Gradio-Oberfläche
67
+ gradio_interface()