la04 commited on
Commit
a344264
·
verified ·
1 Parent(s): 2c0df56

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -7
app.py CHANGED
@@ -11,6 +11,7 @@ from transformers import pipeline
11
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
12
  LLM_MODEL_NAME = "google/flan-t5-small"
13
 
 
14
  def load_and_split_docs(list_file_path):
15
  if not list_file_path:
16
  return [], "Fehler: Keine Dokumente gefunden!"
@@ -21,31 +22,46 @@ def load_and_split_docs(list_file_path):
21
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
22
  return text_splitter.split_documents(documents)
23
 
 
24
  def create_db(docs):
25
  embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
26
  return FAISS.from_documents(docs, embeddings)
27
 
 
28
  def initialize_database(list_file_obj):
29
  if not list_file_obj or all(x is None for x in list_file_obj):
30
  return None, "Fehler: Keine Dateien hochgeladen!"
31
  list_file_path = [x.name for x in list_file_obj if x is not None]
32
  doc_splits = load_and_split_docs(list_file_path)
33
  vector_db = create_db(doc_splits)
 
34
  return vector_db, "Datenbank erfolgreich erstellt!"
35
 
 
36
  def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db):
37
  if vector_db is None:
38
- return None, "Fehler: Vektordatenbank nicht initialisiert!"
39
- qa_chain = initialize_llm_chain(temperature, max_tokens, vector_db)
40
- return qa_chain, "QA-Chatbot ist bereit!"
 
 
 
 
 
 
 
 
41
 
 
42
  def initialize_llm_chain(temperature, max_tokens, vector_db):
 
43
  local_pipeline = pipeline(
44
  "text2text-generation",
45
  model=LLM_MODEL_NAME,
46
  max_length=max_tokens,
47
  temperature=temperature
48
  )
 
49
  llm = HuggingFacePipeline(pipeline=local_pipeline)
50
  memory = ConversationBufferMemory(memory_key="chat_history")
51
  retriever = vector_db.as_retriever()
@@ -56,12 +72,14 @@ def initialize_llm_chain(temperature, max_tokens, vector_db):
56
  return_source_documents=True
57
  )
58
 
 
59
  def conversation(qa_chain, message, history):
60
  if qa_chain is None:
61
  return None, [{"role": "system", "content": "Der QA-Chain wurde nicht initialisiert!"}], history
62
  if not message.strip():
63
  return qa_chain, [{"role": "system", "content": "Bitte eine Frage eingeben!"}], history
64
  try:
 
65
  history = history[-5:] # Beschränke den Verlauf auf die letzten 5 Nachrichten
66
  response = qa_chain.invoke({"question": message, "chat_history": history})
67
  response_text = response["answer"]
@@ -73,15 +91,18 @@ def conversation(qa_chain, message, history):
73
  {"role": "user", "content": message},
74
  {"role": "assistant", "content": f"{response_text}\n\n**Quellen:**\n{sources_text}"}
75
  ]
 
76
  return qa_chain, formatted_response, formatted_response
77
  except Exception as e:
 
78
  return qa_chain, [{"role": "system", "content": f"Fehler: {str(e)}"}], history
79
 
 
80
  def demo():
81
  with gr.Blocks() as demo:
82
- vector_db = gr.State()
83
- qa_chain = gr.State()
84
- chat_history = gr.State([])
85
 
86
  gr.HTML("<center><h1>RAG Chatbot mit FAISS und lokalen Modellen</h1></center>")
87
  with gr.Row():
@@ -99,7 +120,7 @@ def demo():
99
  submit_btn = gr.Button("Absenden")
100
 
101
  db_btn.click(initialize_database, [document], [vector_db, db_status])
102
- qachain_btn.click(initialize_llm_chain_wrapper, [slider_temperature, slider_max_tokens, vector_db], [qa_chain])
103
  submit_btn.click(conversation, [qa_chain, msg, chat_history], [qa_chain, chatbot, chat_history])
104
 
105
  demo.launch(debug=True)
 
11
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
12
  LLM_MODEL_NAME = "google/flan-t5-small"
13
 
14
+ # **Dokumente laden und aufteilen**
15
  def load_and_split_docs(list_file_path):
16
  if not list_file_path:
17
  return [], "Fehler: Keine Dokumente gefunden!"
 
22
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
23
  return text_splitter.split_documents(documents)
24
 
25
+ # **Vektor-Datenbank mit FAISS erstellen**
26
  def create_db(docs):
27
  embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
28
  return FAISS.from_documents(docs, embeddings)
29
 
30
+ # **Datenbank initialisieren**
31
  def initialize_database(list_file_obj):
32
  if not list_file_obj or all(x is None for x in list_file_obj):
33
  return None, "Fehler: Keine Dateien hochgeladen!"
34
  list_file_path = [x.name for x in list_file_obj if x is not None]
35
  doc_splits = load_and_split_docs(list_file_path)
36
  vector_db = create_db(doc_splits)
37
+ print("Vektordatenbank erfolgreich erstellt!")
38
  return vector_db, "Datenbank erfolgreich erstellt!"
39
 
40
+ # **QA-Kette initialisieren (Wrapper)**
41
  def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db):
42
  if vector_db is None:
43
+ print("Fehler: Vektordatenbank nicht vorhanden!")
44
+ return None, "Fehler: Die Vektordatenbank wurde nicht erstellt! Bitte lade ein PDF hoch und klicke 'Erstelle Vektordatenbank'."
45
+
46
+ try:
47
+ print("Initialisiere QA-Chatbot...")
48
+ qa_chain = initialize_llm_chain(temperature, max_tokens, vector_db)
49
+ print("QA-Chatbot erfolgreich initialisiert!")
50
+ return qa_chain, "QA-Chatbot ist bereit!"
51
+ except Exception as e:
52
+ print(f"Fehler bei der Initialisierung: {str(e)}")
53
+ return None, f"Fehler bei der Initialisierung: {str(e)}"
54
 
55
+ # **LLM-Kette erstellen**
56
  def initialize_llm_chain(temperature, max_tokens, vector_db):
57
+ print("Lade Modellpipeline...")
58
  local_pipeline = pipeline(
59
  "text2text-generation",
60
  model=LLM_MODEL_NAME,
61
  max_length=max_tokens,
62
  temperature=temperature
63
  )
64
+ print(f"Modell {LLM_MODEL_NAME} erfolgreich geladen.")
65
  llm = HuggingFacePipeline(pipeline=local_pipeline)
66
  memory = ConversationBufferMemory(memory_key="chat_history")
67
  retriever = vector_db.as_retriever()
 
72
  return_source_documents=True
73
  )
74
 
75
+ # **Konversation mit QA-Kette führen**
76
  def conversation(qa_chain, message, history):
77
  if qa_chain is None:
78
  return None, [{"role": "system", "content": "Der QA-Chain wurde nicht initialisiert!"}], history
79
  if not message.strip():
80
  return qa_chain, [{"role": "system", "content": "Bitte eine Frage eingeben!"}], history
81
  try:
82
+ print(f"Frage: {message}")
83
  history = history[-5:] # Beschränke den Verlauf auf die letzten 5 Nachrichten
84
  response = qa_chain.invoke({"question": message, "chat_history": history})
85
  response_text = response["answer"]
 
91
  {"role": "user", "content": message},
92
  {"role": "assistant", "content": f"{response_text}\n\n**Quellen:**\n{sources_text}"}
93
  ]
94
+ print("Antwort erfolgreich generiert.")
95
  return qa_chain, formatted_response, formatted_response
96
  except Exception as e:
97
+ print(f"Fehler während der Konversation: {str(e)}")
98
  return qa_chain, [{"role": "system", "content": f"Fehler: {str(e)}"}], history
99
 
100
+ # **Gradio-Demo erstellen**
101
  def demo():
102
  with gr.Blocks() as demo:
103
+ vector_db = gr.State() # Zustand für die Vektordatenbank
104
+ qa_chain = gr.State() # Zustand für den QA-Chain
105
+ chat_history = gr.State([]) # Chatverlauf speichern
106
 
107
  gr.HTML("<center><h1>RAG Chatbot mit FAISS und lokalen Modellen</h1></center>")
108
  with gr.Row():
 
120
  submit_btn = gr.Button("Absenden")
121
 
122
  db_btn.click(initialize_database, [document], [vector_db, db_status])
123
+ qachain_btn.click(initialize_llm_chain_wrapper, [slider_temperature, slider_max_tokens, vector_db], [qa_chain, db_status])
124
  submit_btn.click(conversation, [qa_chain, msg, chat_history], [qa_chain, chatbot, chat_history])
125
 
126
  demo.launch(debug=True)