la04 commited on
Commit
c3bf080
·
verified ·
1 Parent(s): 625e6f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -37
app.py CHANGED
@@ -2,18 +2,15 @@ import os
2
  import gradio as gr
3
  from langchain_community.document_loaders import PyPDFLoader
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
- from langchain_community.embeddings import HuggingFaceEmbeddings
6
  from langchain_community.vectorstores import FAISS
7
  from langchain.chains import ConversationalRetrievalChain
8
  from langchain.memory import ConversationBufferMemory
9
- from langchain_community.llms import HuggingFacePipeline
10
  from transformers import pipeline
11
 
12
- # Embeddings- und LLM-Modelle
13
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
14
  LLM_MODEL_NAME = "google/flan-t5-small"
15
 
16
- # **Dokumente laden und aufteilen**
17
  def load_and_split_docs(list_file_path):
18
  if not list_file_path:
19
  return [], "Fehler: Keine Dokumente gefunden!"
@@ -24,12 +21,10 @@ def load_and_split_docs(list_file_path):
24
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
25
  return text_splitter.split_documents(documents)
26
 
27
- # **Vektor-Datenbank mit FAISS erstellen**
28
  def create_db(docs):
29
  embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
30
  return FAISS.from_documents(docs, embeddings)
31
 
32
- # **Datenbank initialisieren**
33
  def initialize_database(list_file_obj):
34
  if not list_file_obj or all(x is None for x in list_file_obj):
35
  return None, "Fehler: Keine Dateien hochgeladen!"
@@ -38,14 +33,12 @@ def initialize_database(list_file_obj):
38
  vector_db = create_db(doc_splits)
39
  return vector_db, "Datenbank erfolgreich erstellt!"
40
 
41
- # **LLM-Kette initialisieren (Wrapper)**
42
  def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db):
43
  if vector_db is None:
44
  return None, "Fehler: Vektordatenbank nicht initialisiert!"
45
  qa_chain = initialize_llm_chain(temperature, max_tokens, vector_db)
46
  return qa_chain, "QA-Chatbot ist bereit!"
47
 
48
- # **LLM-Kette erstellen**
49
  def initialize_llm_chain(temperature, max_tokens, vector_db):
50
  local_pipeline = pipeline(
51
  "text2text-generation",
@@ -63,27 +56,30 @@ def initialize_llm_chain(temperature, max_tokens, vector_db):
63
  return_source_documents=True
64
  )
65
 
66
- # **Konversation mit QA-Kette führen**
67
  def conversation(qa_chain, message, history):
68
  if qa_chain is None:
69
- return None, "Der QA-Chain wurde nicht initialisiert!", history
70
  if not message.strip():
71
- return qa_chain, "Bitte eine Frage eingeben!", history
72
  try:
73
- response = qa_chain({"question": message, "chat_history": history})
 
74
  response_text = response["answer"]
75
  sources = [doc.metadata["source"] for doc in response["source_documents"]]
76
  sources_text = "\n".join(sources) if sources else "Keine Quellen verfügbar"
77
- return qa_chain, f"{response_text}\n\n**Quellen:**\n{sources_text}", history + [(message, response_text)]
 
 
 
 
78
  except Exception as e:
79
- return qa_chain, f"Fehler: {str(e)}", history
80
 
81
- # **Gradio-Demo erstellen**
82
  def demo():
83
  with gr.Blocks() as demo:
84
- vector_db = gr.State() # Zustand für die Vektordatenbank
85
- qa_chain = gr.State() # Zustand für den QA-Chain
86
- chat_history = gr.State([]) # Chatverlauf speichern
87
 
88
  gr.HTML("<center><h1>RAG Chatbot mit FAISS und lokalen Modellen</h1></center>")
89
  with gr.Row():
@@ -100,26 +96,11 @@ def demo():
100
  msg = gr.Textbox(label="Deine Frage:", placeholder="Frage eingeben...")
101
  submit_btn = gr.Button("Absenden")
102
 
103
- # **Button-Events definieren**
104
- db_btn.click(
105
- initialize_database,
106
- inputs=[document], # Eingabe der hochgeladenen Dokumente
107
- outputs=[vector_db, db_status] # Ausgabe: Vektor-Datenbank und Status
108
- )
109
-
110
- qachain_btn.click(
111
- initialize_llm_chain_wrapper,
112
- inputs=[slider_temperature, slider_max_tokens, vector_db],
113
- outputs=[qa_chain, db_status]
114
- )
115
 
116
- submit_btn.click(
117
- conversation,
118
- inputs=[qa_chain, msg, chat_history], # Chatkette, Nutzerfrage, Chatverlauf
119
- outputs=[qa_chain, chatbot, chat_history] # Antwort der Kette, Chatbot-Ausgabe, neuer Verlauf
120
- )
121
-
122
- demo.launch(debug=True) # Ohne queue=True
123
 
124
  if __name__ == "__main__":
125
  demo()
 
2
  import gradio as gr
3
  from langchain_community.document_loaders import PyPDFLoader
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
6
  from langchain_community.vectorstores import FAISS
7
  from langchain.chains import ConversationalRetrievalChain
8
  from langchain.memory import ConversationBufferMemory
 
9
  from transformers import pipeline
10
 
 
11
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
12
  LLM_MODEL_NAME = "google/flan-t5-small"
13
 
 
14
  def load_and_split_docs(list_file_path):
15
  if not list_file_path:
16
  return [], "Fehler: Keine Dokumente gefunden!"
 
21
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
22
  return text_splitter.split_documents(documents)
23
 
 
24
  def create_db(docs):
25
  embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
26
  return FAISS.from_documents(docs, embeddings)
27
 
 
28
  def initialize_database(list_file_obj):
29
  if not list_file_obj or all(x is None for x in list_file_obj):
30
  return None, "Fehler: Keine Dateien hochgeladen!"
 
33
  vector_db = create_db(doc_splits)
34
  return vector_db, "Datenbank erfolgreich erstellt!"
35
 
 
36
  def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db):
37
  if vector_db is None:
38
  return None, "Fehler: Vektordatenbank nicht initialisiert!"
39
  qa_chain = initialize_llm_chain(temperature, max_tokens, vector_db)
40
  return qa_chain, "QA-Chatbot ist bereit!"
41
 
 
42
  def initialize_llm_chain(temperature, max_tokens, vector_db):
43
  local_pipeline = pipeline(
44
  "text2text-generation",
 
56
  return_source_documents=True
57
  )
58
 
 
59
  def conversation(qa_chain, message, history):
60
  if qa_chain is None:
61
+ return None, [{"role": "system", "content": "Der QA-Chain wurde nicht initialisiert!"}], history
62
  if not message.strip():
63
+ return qa_chain, [{"role": "system", "content": "Bitte eine Frage eingeben!"}], history
64
  try:
65
+ history = history[-5:] # Nur die letzten 5 Nachrichten übergeben
66
+ response = qa_chain.invoke({"question": message, "chat_history": history})
67
  response_text = response["answer"]
68
  sources = [doc.metadata["source"] for doc in response["source_documents"]]
69
  sources_text = "\n".join(sources) if sources else "Keine Quellen verfügbar"
70
+ formatted_response = [
71
+ {"role": "user", "content": message},
72
+ {"role": "assistant", "content": f"{response_text}\n\n**Quellen:**\n{sources_text}"}
73
+ ]
74
+ return qa_chain, formatted_response, history + [(message, response_text)]
75
  except Exception as e:
76
+ return qa_chain, [{"role": "system", "content": f"Fehler: {str(e)}"}], history
77
 
 
78
  def demo():
79
  with gr.Blocks() as demo:
80
+ vector_db = gr.State()
81
+ qa_chain = gr.State()
82
+ chat_history = gr.State([])
83
 
84
  gr.HTML("<center><h1>RAG Chatbot mit FAISS und lokalen Modellen</h1></center>")
85
  with gr.Row():
 
96
  msg = gr.Textbox(label="Deine Frage:", placeholder="Frage eingeben...")
97
  submit_btn = gr.Button("Absenden")
98
 
99
+ db_btn.click(initialize_database, [document], [vector_db, db_status])
100
+ qachain_btn.click(initialize_llm_chain_wrapper, [slider_temperature, slider_max_tokens, vector_db], [qa_chain])
101
+ submit_btn.click(conversation, [qa_chain, msg, chat_history], [qa_chain, chatbot, chat_history])
 
 
 
 
 
 
 
 
 
102
 
103
+ demo.launch(debug=True)
 
 
 
 
 
 
104
 
105
  if __name__ == "__main__":
106
  demo()