la04 commited on
Commit
90d6700
·
verified ·
1 Parent(s): 15da3c5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -42
app.py CHANGED
@@ -7,13 +7,19 @@ from langchain_community.vectorstores import FAISS
7
  from langchain.chains import ConversationalRetrievalChain
8
  from langchain.memory import ConversationBufferMemory
9
 
10
- # API-Token
11
- api_token = os.getenv("HF_TOKEN")
12
 
13
- # LLM-Optionen
14
- list_llm = ["google/flan-t5-small", "google/flan-t5-base"]
 
 
 
 
 
 
15
 
16
- # Dokumente laden und aufteilen
17
  def load_doc(list_file_path):
18
  if not list_file_path:
19
  return [], "Fehler: Keine Dokumente gefunden!"
@@ -21,15 +27,15 @@ def load_doc(list_file_path):
21
  documents = []
22
  for loader in loaders:
23
  documents.extend(loader.load())
24
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
25
  return text_splitter.split_documents(documents)
26
 
27
- # Vektor-Datenbank erstellen
28
  def create_db(splits):
29
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
30
  return FAISS.from_documents(splits, embeddings)
31
 
32
- # Datenbank initialisieren
33
  def initialize_database(list_file_obj):
34
  if not list_file_obj:
35
  return None, "Fehler: Keine Dateien hochgeladen!"
@@ -38,31 +44,36 @@ def initialize_database(list_file_obj):
38
  vector_db = create_db(doc_splits)
39
  return vector_db, "Datenbank erfolgreich erstellt!"
40
 
41
- # LLM-Kette initialisieren
42
- def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
43
  if vector_db is None:
44
  return None, "Fehler: Keine Vektordatenbank verfügbar."
45
- if max_tokens > 250:
46
- max_tokens = 250 # Begrenze max_new_tokens, um Fehler zu vermeiden
 
 
 
47
  llm = HuggingFaceEndpoint(
48
  repo_id=llm_model,
49
  huggingfacehub_api_token=api_token,
50
  temperature=temperature,
51
- max_new_tokens=max_tokens,
52
- top_k=top_k,
53
  )
54
  memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
55
  retriever = vector_db.as_retriever()
56
- return ConversationalRetrievalChain.from_llm(
 
57
  llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True
58
  )
59
 
60
- # LLM initialisieren
61
- def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
 
 
62
  if vector_db is None:
63
  return None, "Fehler: Datenbank wurde nicht erstellt!"
64
  llm_name = list_llm[llm_option]
65
- qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
66
  return qa_chain, "QA-Kette initialisiert. Chatbot ist bereit!"
67
 
68
  # Konversation
@@ -76,23 +87,27 @@ def conversation(qa_chain, message, history):
76
  formatted_response = history + [{"role": "user", "content": message}, {"role": "assistant", "content": response_text}]
77
  return qa_chain, formatted_response, formatted_response
78
 
79
- # Demo erstellen
80
  def demo():
81
  with gr.Blocks() as demo:
82
  vector_db = gr.State()
83
  qa_chain = gr.State()
84
- gr.Markdown("<center><h1>PDF-Chatbot mit kostenlosen Modellen</h1></center>")
85
 
86
  with gr.Row():
87
  with gr.Column():
88
- document = gr.Files(label="PDF-Dokument hochladen")
89
  db_btn = gr.Button("Erstelle Vektordatenbank")
90
  db_status = gr.Textbox(label="Datenbankstatus", value="Nicht erstellt", interactive=False)
91
 
92
- llm_btn = gr.Radio(["Flan-T5 Small", "Flan-T5 Base"], label="Verfügbare LLMs", value="Flan-T5 Small", type="index")
 
 
 
 
 
93
  slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
94
- slider_maxtokens = gr.Slider(1, 250, 128, label="Max Tokens") # Begrenzung auf 250
95
- slider_topk = gr.Slider(1, 10, 3, label="Top-k")
96
  qachain_btn = gr.Button("Initialisiere QA-Chatbot")
97
  llm_status = gr.Textbox(label="Chatbot-Status", value="Nicht initialisiert", interactive=False)
98
 
@@ -101,24 +116,12 @@ def demo():
101
  msg = gr.Textbox(label="Frage stellen")
102
  submit_btn = gr.Button("Absenden")
103
 
104
- # Event-Handling
105
- db_btn.click(
106
- initialize_database,
107
- inputs=[document],
108
- outputs=[vector_db, db_status]
109
- )
110
- qachain_btn.click(
111
- initialize_LLM,
112
- inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
113
- outputs=[qa_chain, llm_status]
114
- )
115
- submit_btn.click(
116
- conversation,
117
- inputs=[qa_chain, msg, chatbot],
118
- outputs=[qa_chain, chatbot, chatbot]
119
- )
120
-
121
- demo.launch(debug=True)
122
 
123
  if __name__ == "__main__":
124
  demo()
 
7
  from langchain.chains import ConversationalRetrievalChain
8
  from langchain.memory import ConversationBufferMemory
9
 
10
+ # Dein Hugging Face Read Token
11
+ api_token = os.getenv("HF_TOKEN", "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")
12
 
13
+ # Modelle für Auswahl
14
+ list_llm = [
15
+ "google/flan-t5-base", # Leichtes Instruktionsmodell
16
+ "sentence-transformers/all-MiniLM-L6-v2", # Embeddings-optimiertes Modell
17
+ "OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5", # Pythia 12B
18
+ "bigscience/bloom-3b", # Multilingualer BLOOM
19
+ "bigscience/bloom-1b7" # Leichtes BLOOM-Modell
20
+ ]
21
 
22
+ # Dokumentenverarbeitung
23
  def load_doc(list_file_path):
24
  if not list_file_path:
25
  return [], "Fehler: Keine Dokumente gefunden!"
 
27
  documents = []
28
  for loader in loaders:
29
  documents.extend(loader.load())
30
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
31
  return text_splitter.split_documents(documents)
32
 
33
+ # Erstelle Vektordatenbank
34
  def create_db(splits):
35
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
36
  return FAISS.from_documents(splits, embeddings)
37
 
38
+ # Initialisiere Datenbank
39
  def initialize_database(list_file_obj):
40
  if not list_file_obj:
41
  return None, "Fehler: Keine Dateien hochgeladen!"
 
44
  vector_db = create_db(doc_splits)
45
  return vector_db, "Datenbank erfolgreich erstellt!"
46
 
47
+ # Initialisiere LLM-Kette
48
+ def initialize_llmchain(llm_model, temperature, max_tokens, vector_db):
49
  if vector_db is None:
50
  return None, "Fehler: Keine Vektordatenbank verfügbar."
51
+ if "pythia" in llm_model or "bloom" in llm_model:
52
+ max_tokens = min(max_tokens, 2048)
53
+ else:
54
+ max_tokens = min(max_tokens, 1024)
55
+
56
  llm = HuggingFaceEndpoint(
57
  repo_id=llm_model,
58
  huggingfacehub_api_token=api_token,
59
  temperature=temperature,
60
+ max_new_tokens=max_tokens
 
61
  )
62
  memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
63
  retriever = vector_db.as_retriever()
64
+
65
+ qa_chain = ConversationalRetrievalChain.from_llm(
66
  llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True
67
  )
68
 
69
+ return qa_chain
70
+
71
+ # Initialisiere LLM
72
+ def initialize_LLM(llm_option, llm_temperature, max_tokens, vector_db):
73
  if vector_db is None:
74
  return None, "Fehler: Datenbank wurde nicht erstellt!"
75
  llm_name = list_llm[llm_option]
76
+ qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, vector_db)
77
  return qa_chain, "QA-Kette initialisiert. Chatbot ist bereit!"
78
 
79
  # Konversation
 
87
  formatted_response = history + [{"role": "user", "content": message}, {"role": "assistant", "content": response_text}]
88
  return qa_chain, formatted_response, formatted_response
89
 
90
+ # Gradio UI
91
  def demo():
92
  with gr.Blocks() as demo:
93
  vector_db = gr.State()
94
  qa_chain = gr.State()
95
+ gr.Markdown("<center><h1>RAG-Chatbot mit Pythia und BLOOM (CPU-kompatibel)</h1></center>")
96
 
97
  with gr.Row():
98
  with gr.Column():
99
+ document = gr.Files(label="PDF-Dokument hochladen", type="file", file_types=[".pdf"], file_count="multiple")
100
  db_btn = gr.Button("Erstelle Vektordatenbank")
101
  db_status = gr.Textbox(label="Datenbankstatus", value="Nicht erstellt", interactive=False)
102
 
103
+ llm_btn = gr.Radio(
104
+ ["Flan-T5 Base", "MiniLM", "Pythia 12B", "BLOOM 3B", "BLOOM 1.7B"],
105
+ label="Verfügbare LLMs",
106
+ value="Flan-T5 Base",
107
+ type="index"
108
+ )
109
  slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
110
+ slider_maxtokens = gr.Slider(1, 2048, 512, label="Max Tokens")
 
111
  qachain_btn = gr.Button("Initialisiere QA-Chatbot")
112
  llm_status = gr.Textbox(label="Chatbot-Status", value="Nicht initialisiert", interactive=False)
113
 
 
116
  msg = gr.Textbox(label="Frage stellen")
117
  submit_btn = gr.Button("Absenden")
118
 
119
+ # Events verknüpfen
120
+ db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_status])
121
+ qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, vector_db], outputs=[qa_chain, llm_status])
122
+ submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, chatbot, chatbot])
123
+
124
+ demo.launch(debug=True, share=True)
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  if __name__ == "__main__":
127
  demo()