la04 committed
Commit 80396ad · verified · 1 Parent(s): 62d5470

Update app.py

Files changed (1)
  1. app.py +66 -118
app.py CHANGED
@@ -1,143 +1,91 @@
- import os
  import gradio as gr
  from langchain_community.document_loaders import PyPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
- from langchain_community.vectorstores import FAISS
  from langchain.chains import ConversationalRetrievalChain
  from langchain.memory import ConversationBufferMemory
- from transformers import pipeline

- EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
- LLM_MODEL_NAME = "google/flan-t5-small"

- MAX_INPUT_LENGTH = 512  # maximum input length for the model

- # **Load and split documents**
- def load_and_split_docs(list_file_path):
-     if not list_file_path:
-         return [], "Error: no documents found!"
      loaders = [PyPDFLoader(x) for x in list_file_path]
-     documents = []
      for loader in loaders:
-         documents.extend(loader.load())
-     text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
-     return text_splitter.split_documents(documents)
-
- # **Create the vector database with FAISS**
- def create_db(docs):
-     embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
-     return FAISS.from_documents(docs, embeddings)

  # **Initialize the database**
  def initialize_database(list_file_obj):
-     if not list_file_obj or all(x is None for x in list_file_obj):
-         return None, "Error: no files uploaded!"
      list_file_path = [x.name for x in list_file_obj if x is not None]
-     doc_splits = load_and_split_docs(list_file_path)
      vector_db = create_db(doc_splits)
-     print("Vector database created successfully!")
      return vector_db, "Database created successfully!"

- # **Initialize the QA chain (wrapper)**
- def initialize_llm_chain_wrapper(temperature, max_tokens, vector_db):
-     if vector_db is None:
-         print("Error: vector database missing!")
-         return None, "Error: the vector database has not been created! Please upload a PDF and click 'Create vector database'."
-
-     try:
-         print("Initializing QA chatbot...")
-         qa_chain = initialize_llm_chain(temperature, max_tokens, vector_db)
-         print("QA chatbot initialized successfully!")
-         return qa_chain, "QA chatbot is ready!"
-     except Exception as e:
-         print(f"Initialization error: {str(e)}")
-         return None, f"Initialization error: {str(e)}"
-
- # **Create the LLM chain**
- def initialize_llm_chain(temperature, max_tokens, vector_db):
-     print("Loading the model pipeline...")
-     local_pipeline = pipeline(
-         "text2text-generation",
-         model=LLM_MODEL_NAME,
-         max_length=max_tokens,
-         temperature=temperature
-     )
-     print(f"Model {LLM_MODEL_NAME} loaded successfully.")
-     llm = HuggingFacePipeline(pipeline=local_pipeline)
-     memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer")  # store only the answer
-     retriever = vector_db.as_retriever()
-     return ConversationalRetrievalChain.from_llm(
-         llm,
-         retriever=retriever,
-         memory=memory,
-         return_source_documents=True
-     )
-
- # **Run the conversation with the QA chain**
- def truncate_history(history, max_length=MAX_INPUT_LENGTH):
-     total_length = 0
-     truncated_history = []
-
-     for message in reversed(history):
-         total_length += len(message[0]) + len(message[1])
-         if total_length > max_length:
-             break
-         truncated_history.insert(0, message)
-
-     return truncated_history

  def conversation(qa_chain, message, history):
-     if qa_chain is None:
-         return None, [{"role": "system", "content": "The QA chain has not been initialized!"}], history
-     if not message.strip():
-         return qa_chain, [{"role": "system", "content": "Please enter a question!"}], history
-     try:
-         print(f"Question: {message}")
-         history = truncate_history(history)  # keep the history under 512 tokens
-         response = qa_chain.invoke({"question": message, "chat_history": history})
-         response_text = response["answer"]
-         sources = [doc.metadata["source"] for doc in response["source_documents"]]
-         sources_text = "\n".join(sources) if sources else "No sources available"
-
-         # Structured return value for `gr.Chatbot`
-         formatted_response = history + [
-             {"role": "user", "content": message},
-             {"role": "assistant", "content": f"{response_text}\n\n**Sources:**\n{sources_text}"}
-         ]
-         print("Answer generated successfully.")
-         return qa_chain, formatted_response, formatted_response
-     except Exception as e:
-         print(f"Error during the conversation: {str(e)}")
-         return qa_chain, [{"role": "system", "content": f"Error: {str(e)}"}], history

- # **Build the Gradio demo**
  def demo():
      with gr.Blocks() as demo:
-         vector_db = gr.State()  # state for the vector database
-         qa_chain = gr.State()  # state for the QA chain
-         chat_history = gr.State([])  # store the chat history
-
-         gr.HTML("<center><h1>RAG chatbot with FAISS and local models</h1></center>")
-         with gr.Row():
-             with gr.Column():
-                 document = gr.Files(file_types=[".pdf"], label="Upload PDF")
-                 db_btn = gr.Button("Create vector database")
-                 db_status = gr.Textbox(value="Status: not initialized", show_label=False)
-                 slider_temperature = gr.Slider(0.01, 1.0, value=0.5, label="Temperature")
-                 slider_max_tokens = gr.Slider(64, 512, value=256, label="Max Tokens")
-                 qachain_btn = gr.Button("Initialize QA chatbot")
-
-             with gr.Column():
-                 chatbot = gr.Chatbot(label="Chatbot", type='messages', height=400)
-                 msg = gr.Textbox(label="Your question:", placeholder="Enter a question...")
-                 submit_btn = gr.Button("Submit")
-
-         db_btn.click(initialize_database, [document], [vector_db, db_status])
-         qachain_btn.click(initialize_llm_chain_wrapper, [slider_temperature, slider_max_tokens, vector_db], [qa_chain, db_status])
-         submit_btn.click(conversation, [qa_chain, msg, chat_history], [qa_chain, chatbot, chat_history])
-
-     demo.launch(debug=True)

  if __name__ == "__main__":
      demo()
 
 
  import gradio as gr
+ import os
+ from langchain_community.vectorstores import FAISS
  from langchain_community.document_loaders import PyPDFLoader
  from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.llms import HuggingFaceEndpoint
  from langchain.chains import ConversationalRetrievalChain
  from langchain.memory import ConversationBufferMemory

+ # The token is read from the Space secrets
+ api_token = os.getenv("HF_TOKEN")  # no API token appears directly in the code

+ # Free LLM options (free tier)
+ list_llm = ["google/flan-t5-small", "google/flan-t5-base"]

+ # **Load and split documents**
+ def load_doc(list_file_path):
      loaders = [PyPDFLoader(x) for x in list_file_path]
+     pages = []
      for loader in loaders:
+         pages.extend(loader.load())
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
+     return text_splitter.split_documents(pages)
+
+ # **Create the vector database**
+ def create_db(splits):
+     embeddings = HuggingFaceEmbeddings()
+     return FAISS.from_documents(splits, embeddings)
+
+ # **Initialize the LLM chain**
+ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
+     llm = HuggingFaceEndpoint(
+         repo_id=llm_model,
+         huggingfacehub_api_token=api_token,  # pulls the API token from the Space secrets
+         temperature=temperature,
+         max_new_tokens=max_tokens,
+         top_k=top_k,
+     )
+     memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
+     retriever = vector_db.as_retriever()
+     qa_chain = ConversationalRetrievalChain.from_llm(
+         llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True
+     )
+     return qa_chain

  # **Initialize the database**
  def initialize_database(list_file_obj):
      list_file_path = [x.name for x in list_file_obj if x is not None]
+     doc_splits = load_doc(list_file_path)
      vector_db = create_db(doc_splits)
      return vector_db, "Database created successfully!"

+ # **Initialize the LLM**
+ def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
+     llm_name = list_llm[llm_option]
+     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
+     return qa_chain, "QA chain initialized. Chatbot is ready!"

+ # **Conversation**
  def conversation(qa_chain, message, history):
+     response = qa_chain.invoke({"question": message, "chat_history": history})
+     response_answer = response["answer"]
+     return qa_chain, history + [(message, response_answer)]  # fixed: return the updated history for the Chatbot, not a bare string

+ # **Build the demo**
  def demo():
      with gr.Blocks() as demo:
+         vector_db = gr.State()
+         qa_chain = gr.State()
+
+         gr.Markdown("<center><h1>PDF chatbot with free Hugging Face models</h1></center>")
+         document = gr.Files(label="Upload PDF documents", file_types=[".pdf"])
+         db_btn = gr.Button("Create vector database")
+         db_status = gr.Textbox(value="Status: not initialized", show_label=False)  # restored from the previous version so the status strings returned above have an output target
+         llm_btn = gr.Radio(["Flan-T5 Small", "Flan-T5 Base"], label="Available LLMs", value="Flan-T5 Small", type="index")
+         slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
+         slider_maxtokens = gr.Slider(128, 2048, 512, label="Max Tokens")
+         slider_topk = gr.Slider(1, 10, 3, label="Top-k")
+         qachain_btn = gr.Button("Initialize QA chatbot")
+         chatbot = gr.Chatbot(label="Chatbot", height=400)
+         msg = gr.Textbox(label="Ask a question")
+         submit_btn = gr.Button("Submit")
+
+         db_btn.click(initialize_database, inputs=[document], outputs=[vector_db, db_status])
+         qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain, db_status])
+         submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, chatbot])
+
+     demo.launch()

  if __name__ == "__main__":
      demo()
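
The updated app depends on an HF_TOKEN secret configured for the Space. As a quick sanity check of the endpoint path outside Gradio, something like the following should work (a minimal sketch, assuming HF_TOKEN is exported in the local environment; the model id and sampling values mirror the defaults above):

import os
from langchain_community.llms import HuggingFaceEndpoint

# Call the hosted model directly, outside the Gradio app.
# Assumes HF_TOKEN is set in the environment, mirroring the Space secret.
llm = HuggingFaceEndpoint(
    repo_id="google/flan-t5-small",  # first entry of list_llm
    huggingfacehub_api_token=os.getenv("HF_TOKEN"),
    temperature=0.5,  # UI slider defaults
    max_new_tokens=512,
    top_k=3,
)
print(llm.invoke("What is retrieval-augmented generation?"))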