Update app.py

app.py CHANGED
@@ -2,18 +2,15 @@ import os
 import gradio as gr
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
-from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
 from langchain_community.vectorstores import FAISS
 from langchain.chains import ConversationalRetrievalChain
 from langchain.memory import ConversationBufferMemory
-from langchain_community.llms import HuggingFacePipeline
 from transformers import pipeline

-# Embedding and LLM models
 EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
 LLM_MODEL_NAME = "google/flan-t5-small"

-# **Load and split documents**
 def load_and_split_docs(list_file_path):
     if not list_file_path:
         return [], "Error: No documents found!"
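The import change above consolidates two deprecated `langchain_community` wrappers (`langchain_community.embeddings.HuggingFaceEmbeddings` and `langchain_community.llms.HuggingFacePipeline`) into the `langchain-huggingface` partner package. A minimal sketch of the migrated classes in use, with the model names taken from this file (the `max_new_tokens` value is an illustrative assumption):

    # Sketch: the same two classes, now imported from langchain-huggingface
    # (install with `pip install langchain-huggingface`).
    from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
    from transformers import pipeline

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    hf_pipe = pipeline("text2text-generation", model="google/flan-t5-small", max_new_tokens=64)
    llm = HuggingFacePipeline(pipeline=hf_pipe)
    print(llm.invoke("Answer briefly: what is retrieval-augmented generation?"))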
@@ -24,12 +21,10 @@ def load_and_split_docs(list_file_path):
     text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
     return text_splitter.split_documents(documents)

-# **Create the vector database with FAISS**
 def create_db(docs):
     embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
     return FAISS.from_documents(docs, embeddings)

-# **Initialize the database**
 def initialize_database(list_file_obj):
     if not list_file_obj or all(x is None for x in list_file_obj):
         return None, "Error: No files uploaded!"
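Context for the unchanged lines in this hunk: `load_and_split_docs` cuts the PDFs into 512-character chunks with 32 characters of overlap, and `create_db` embeds those chunks into an in-memory FAISS index. A hedged sketch of querying an index built this way (the sample document and query are invented for illustration):

    from langchain_core.documents import Document
    from langchain_huggingface import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS

    docs = [Document(page_content="FAISS indexes dense vectors for fast similarity search.",
                     metadata={"source": "example.pdf"})]
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    db = FAISS.from_documents(docs, embeddings)

    # Nearest chunks for a query; metadata["source"] is what conversation() cites
    for hit in db.similarity_search("What is FAISS used for?", k=1):
        print(hit.metadata["source"], "->", hit.page_content)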
@@ -38,14 +33,12 @@ def initialize_database(list_file_obj):
     vector_db = create_db(doc_splits)
     return vector_db, "Database created successfully!"

-# **Initialize the LLM chain (wrapper)**
 def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db):
     if vector_db is None:
         return None, "Error: Vector database not initialized!"
     qa_chain = initialize_llm_chain(temperature, max_tokens, vector_db)
     return qa_chain, "QA chatbot is ready!"

-# **Create the LLM chain**
 def initialize_llm_chain(temperature, max_tokens, vector_db):
     local_pipeline = pipeline(
         "text2text-generation",
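The body of `initialize_llm_chain` between the `pipeline(...)` call and the trailing `return_source_documents=True` is elided by the diff (new lines 45-55). A plausible sketch of how such a chain is typically assembled from the pieces visible in the imports; the retriever settings and memory configuration are assumptions, not the file's actual code:

    def initialize_llm_chain(temperature, max_tokens, vector_db):
        local_pipeline = pipeline(
            "text2text-generation",
            model=LLM_MODEL_NAME,
            temperature=temperature,
            do_sample=True,            # sampling so temperature takes effect
            max_new_tokens=max_tokens,
        )
        llm = HuggingFacePipeline(pipeline=local_pipeline)
        # output_key is needed because the chain returns both answer and sources
        memory = ConversationBufferMemory(
            memory_key="chat_history", output_key="answer", return_messages=True
        )
        return ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=vector_db.as_retriever(),
            memory=memory,
            return_source_documents=True,
        )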
@@ -63,27 +56,30 @@ def initialize_llm_chain(temperature, max_tokens, vector_db):
         return_source_documents=True
     )

-# **Run a conversation with the QA chain**
 def conversation(qa_chain, message, history):
     if qa_chain is None:
-        return None, "The QA chain has not been initialized!", history
+        return None, [{"role": "system", "content": "The QA chain has not been initialized!"}], history
     if not message.strip():
-        return qa_chain, "Please enter a question!", history
+        return qa_chain, [{"role": "system", "content": "Please enter a question!"}], history
     try:
-        response = qa_chain({"question": message, "chat_history": history})
+        history = history[-5:]  # Pass only the last 5 messages
+        response = qa_chain.invoke({"question": message, "chat_history": history})
         response_text = response["answer"]
         sources = [doc.metadata["source"] for doc in response["source_documents"]]
         sources_text = "\n".join(sources) if sources else "No sources available"
-        return qa_chain, f"{response_text}\n\n**Sources:**\n{sources_text}", history + [(message, response_text)]
+        formatted_response = [
+            {"role": "user", "content": message},
+            {"role": "assistant", "content": f"{response_text}\n\n**Sources:**\n{sources_text}"}
+        ]
+        return qa_chain, formatted_response, history + [(message, response_text)]
     except Exception as e:
-        return qa_chain, f"Error: {str(e)}", history
+        return qa_chain, [{"role": "system", "content": f"Error: {str(e)}"}], history

-# **Create the Gradio demo**
 def demo():
     with gr.Blocks() as demo:
         vector_db = gr.State()
         qa_chain = gr.State()
         chat_history = gr.State([])

         gr.HTML("<center><h1>RAG Chatbot with FAISS and Local Models</h1></center>")
         with gr.Row():
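Two behavioral changes land in this hunk: the chain is now called through `.invoke()` (bare `qa_chain(...)` calls are deprecated in recent LangChain releases), and the reply is returned as OpenAI-style role/content dicts instead of a plain string. That shape matches what Gradio's `Chatbot` renders when constructed with `type="messages"` (whether this file's `chatbot` component is created that way is not visible in the diff). A small illustration of the format:

    import gradio as gr

    sample = [
        {"role": "user", "content": "What is FAISS?"},
        {"role": "assistant", "content": "A vector search library.\n\n**Sources:**\nexample.pdf"},
    ]

    with gr.Blocks() as ui:
        gr.Chatbot(value=sample, type="messages")
    # ui.launch()

Note that the returned history state still appends `(message, response_text)` tuples, so the displayed messages and the stored history use different shapes.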
@@ -100,26 +96,11 @@ def demo():
         msg = gr.Textbox(label="Your question:", placeholder="Enter a question...")
         submit_btn = gr.Button("Submit")

-
-        db_btn.click(
-            initialize_database,
-            inputs=[document],  # Input: the uploaded documents
-            outputs=[vector_db, db_status]  # Output: vector database and status
-        )
-
-        qachain_btn.click(
-            initialize_llm_chain_wrapper,
-            inputs=[slider_temperature, slider_max_tokens, vector_db],
-            outputs=[qa_chain, db_status]
-        )
+        db_btn.click(initialize_database, [document], [vector_db, db_status])
+        qachain_btn.click(initialize_llm_chain_wrapper, [slider_temperature, slider_max_tokens, vector_db], [qa_chain])
+        submit_btn.click(conversation, [qa_chain, msg, chat_history], [qa_chain, chatbot, chat_history])

-        submit_btn.click(
-            conversation,
-            inputs=[qa_chain, msg, chat_history],  # chain, user question, chat history
-            outputs=[qa_chain, chatbot, chat_history]  # chain answer, chatbot output, updated history
-        )
-
-        demo.launch(debug=True)  # Without queue=True
+    demo.launch(debug=True)

 if __name__ == "__main__":
     demo()
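The event wiring is condensed to Gradio's positional `.click(fn, inputs, outputs)` form, one binding per line. Worth noting from the diff itself: `initialize_llm_chain_wrapper` returns two values, but the `qachain_btn` binding now lists only `[qa_chain]` as its output, so the status string from that path is dropped (recent Gradio versions reject mismatched output counts at runtime). A hedged smoke test of the callback chain without the UI, assuming the file is importable as `app` and `sample.pdf` is a placeholder document:

    from app import initialize_database, initialize_llm_chain_wrapper, conversation

    vector_db, status = initialize_database(["sample.pdf"])
    print(status)  # "Database created successfully!" on success

    qa_chain, status = initialize_llm_chain_wrapper(0.7, 256, vector_db)
    print(status)  # "QA chatbot is ready!"

    qa_chain, reply, history = conversation(qa_chain, "What is this document about?", [])
    print(reply)   # list of role/content dicts: the answer plus its sources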