la04 commited on
Commit
0a1e2b9
·
verified ·
1 Parent(s): 8ca77ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -3
app.py CHANGED
@@ -9,9 +9,11 @@ from langchain.memory import ConversationBufferMemory
9
  from langchain_community.llms import HuggingFacePipeline
10
  from transformers import pipeline
11
 
 
12
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
13
  LLM_MODEL_NAME = "google/flan-t5-small"
14
 
 
15
  def load_and_split_docs(list_file_path):
16
  if not list_file_path:
17
  return [], "Fehler: Keine Dokumente gefunden!"
@@ -22,10 +24,12 @@ def load_and_split_docs(list_file_path):
22
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
23
  return text_splitter.split_documents(documents)
24
 
 
25
  def create_db(docs):
26
  embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
27
  return FAISS.from_documents(docs, embeddings)
28
 
 
29
  def initialize_database(list_file_obj):
30
  if not list_file_obj or all(x is None for x in list_file_obj):
31
  return None, "Fehler: Keine Dateien hochgeladen!"
@@ -34,10 +38,18 @@ def initialize_database(list_file_obj):
34
  vector_db = create_db(doc_splits)
35
  return vector_db, "Datenbank erfolgreich erstellt!"
36
 
37
- def initialize_llm_chain(llm_model, temperature, max_tokens, vector_db):
 
 
 
 
 
 
 
 
38
  local_pipeline = pipeline(
39
  "text2text-generation",
40
- model=llm_model,
41
  max_length=max_tokens,
42
  temperature=temperature
43
  )
@@ -51,6 +63,7 @@ def initialize_llm_chain(llm_model, temperature, max_tokens, vector_db):
51
  return_source_documents=True
52
  )
53
 
 
54
  def conversation(qa_chain, message, history):
55
  if qa_chain is None:
56
  return None, "Der QA-Chain wurde nicht initialisiert!", history
@@ -63,6 +76,7 @@ def conversation(qa_chain, message, history):
63
  except Exception as e:
64
  return qa_chain, f"Fehler: {str(e)}", history
65
 
 
66
  def demo():
67
  with gr.Blocks() as demo:
68
  vector_db = gr.State()
@@ -83,8 +97,9 @@ def demo():
83
  msg = gr.Textbox(placeholder="Frage eingeben...")
84
  submit_btn = gr.Button("Absenden")
85
 
 
86
  db_btn.click(initialize_database, [document], [vector_db, db_status])
87
- qachain_btn.click(initialize_llm_chain, [LLM_MODEL_NAME, slider_temperature, slider_max_tokens, vector_db], [qa_chain])
88
  submit_btn.click(conversation, [qa_chain, msg, []], [qa_chain, "message", "history"])
89
 
90
  demo.launch(debug=True, enable_queue=True)
 
9
  from langchain_community.llms import HuggingFacePipeline
10
  from transformers import pipeline
11
 
12
+ # Embeddings- und LLM-Modelle
13
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
14
  LLM_MODEL_NAME = "google/flan-t5-small"
15
 
16
+ # **Dokumente laden und aufteilen**
17
  def load_and_split_docs(list_file_path):
18
  if not list_file_path:
19
  return [], "Fehler: Keine Dokumente gefunden!"
 
24
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
25
  return text_splitter.split_documents(documents)
26
 
27
+ # **Vektor-Datenbank mit FAISS erstellen**
28
def create_db(docs):
    """Build a FAISS vector store from the given document chunks.

    Embeds each chunk with the module-level sentence-transformers model
    (EMBEDDINGS_MODEL_NAME) and indexes the results in FAISS.
    """
    return FAISS.from_documents(
        docs,
        HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME),
    )
31
 
32
+ # **Datenbank initialisieren**
33
  def initialize_database(list_file_obj):
34
  if not list_file_obj or all(x is None for x in list_file_obj):
35
  return None, "Fehler: Keine Dateien hochgeladen!"
 
38
  vector_db = create_db(doc_splits)
39
  return vector_db, "Datenbank erfolgreich erstellt!"
40
 
41
+ # **LLM-Kette initialisieren (Wrapper)**
42
def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db):
    """Guarded entry point for building the QA chain from the UI.

    Returns a (qa_chain, status_message) pair; the chain is None and the
    message is an error when no vector database has been initialized yet.
    """
    if vector_db is None:
        # Nothing to retrieve from — report the error instead of building a chain.
        return None, "Fehler: Vektordatenbank nicht initialisiert!"
    chain = initialize_llm_chain(temperature, max_tokens, vector_db)
    return chain, "QA-Chatbot ist bereit!"
47
+
48
+ # **LLM-Kette erstellen**
49
+ def initialize_llm_chain(temperature, max_tokens, vector_db):
50
  local_pipeline = pipeline(
51
  "text2text-generation",
52
+ model=LLM_MODEL_NAME,
53
  max_length=max_tokens,
54
  temperature=temperature
55
  )
 
63
  return_source_documents=True
64
  )
65
 
66
+ # **Konversation mit QA-Kette führen**
67
  def conversation(qa_chain, message, history):
68
  if qa_chain is None:
69
  return None, "Der QA-Chain wurde nicht initialisiert!", history
 
76
  except Exception as e:
77
  return qa_chain, f"Fehler: {str(e)}", history
78
 
79
+ # **Gradio-Demo erstellen**
80
  def demo():
81
  with gr.Blocks() as demo:
82
  vector_db = gr.State()
 
97
  msg = gr.Textbox(placeholder="Frage eingeben...")
98
  submit_btn = gr.Button("Absenden")
99
 
100
+ # Button-Events definieren
101
  db_btn.click(initialize_database, [document], [vector_db, db_status])
102
+ qachain_btn.click(initialize_llm_chain_wrapper, [slider_temperature, slider_max_tokens, vector_db], [qa_chain])
103
  submit_btn.click(conversation, [qa_chain, msg, []], [qa_chain, "message", "history"])
104
 
105
  demo.launch(debug=True, enable_queue=True)