la04 committed on
Commit
f1c2bc3
·
verified ·
1 Parent(s): 80396ad

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -32
app.py CHANGED
@@ -1,76 +1,87 @@
1
  import gradio as gr
2
  import os
3
- from langchain_community.vectorstores import FAISS
4
  from langchain_community.document_loaders import PyPDFLoader
5
  from langchain.text_splitter import RecursiveCharacterTextSplitter
6
- from langchain_community.embeddings import HuggingFaceEmbeddings
7
- from langchain_community.llms import HuggingFaceEndpoint
8
  from langchain.chains import ConversationalRetrievalChain
9
  from langchain.memory import ConversationBufferMemory
10
 
11
- # Der Token wird sicher aus den Space Secrets abgerufen
12
- api_token = os.getenv("HF_TOKEN") # Kein direkter API-Token im Code sichtbar
13
 
14
- # Kostenlose LLM-Optionen (Free-Version)
15
  list_llm = ["google/flan-t5-small", "google/flan-t5-base"]
16
 
17
# Load the documents and split them into chunks
def load_doc(list_file_path):
    """Load every PDF in *list_file_path* and return its content as chunks.

    One ``PyPDFLoader`` is created per path, the results of all loaders are
    gathered into a single list, and that list is split by a
    ``RecursiveCharacterTextSplitter`` (``chunk_size=1024``,
    ``chunk_overlap=64``).
    """
    pdf_loaders = [PyPDFLoader(path) for path in list_file_path]
    loaded_pages = [
        page for pdf_loader in pdf_loaders for page in pdf_loader.load()
    ]
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return splitter.split_documents(loaded_pages)
25
 
26
# Build the vector database
def create_db(splits):
    """Index *splits* in a FAISS store using a default ``HuggingFaceEmbeddings``."""
    embedding_model = HuggingFaceEmbeddings()
    vector_store = FAISS.from_documents(splits, embedding_model)
    return vector_store
30
 
31
# Initialise the LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Create a ``ConversationalRetrievalChain`` on top of *vector_db*.

    The LLM is a ``HuggingFaceEndpoint`` for the repository *llm_model*,
    authenticated with the module-level ``api_token`` and configured with the
    given sampling parameters. Conversation state is kept in a
    ``ConversationBufferMemory`` keyed on ``"chat_history"``.
    """
    endpoint = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    doc_retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        endpoint,
        retriever=doc_retriever,
        chain_type="stuff",
        memory=chat_memory,
        return_source_documents=True,
    )
46
 
47
# Initialise the database
def initialize_database(list_file_obj):
    """Build the vector database from the uploaded file objects.

    Args:
        list_file_obj: Uploaded file objects (each exposing ``.name``), or
            ``None`` when nothing was uploaded. ``None`` entries are ignored.

    Returns:
        ``(vector_db, status_message)``. ``vector_db`` is ``None`` when there
        is no file to index, so callers can tell failure from success.
    """
    # Treat a missing upload like an empty one instead of iterating None.
    list_file_path = [x.name for x in (list_file_obj or []) if x is not None]
    if not list_file_path:
        return None, "Fehler: Keine Dateien hochgeladen!"
    doc_splits = load_doc(list_file_path)
    vector_db = create_db(doc_splits)
    return vector_db, "Datenbank erfolgreich erstellt!"
53
-
54
# Initialise the LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Build the QA chain for the model at index *llm_option* of ``list_llm``.

    Returns the chain together with a status message.
    """
    selected_model = list_llm[llm_option]
    chain = initialize_llmchain(
        selected_model, llm_temperature, max_tokens, top_k, vector_db
    )
    return chain, "QA-Kette initialisiert. Chatbot ist bereit!"
59
 
60
# Conversation
def conversation(qa_chain, message, history):
    """Ask *qa_chain* the question *message* and extend *history* with the turn.

    Returns ``(qa_chain, answer, updated_history)`` where the new history
    entry is the ``(message, answer)`` pair.
    """
    result = qa_chain.invoke({"question": message, "chat_history": history})
    answer = result["answer"]
    updated_history = history + [(message, answer)]
    return qa_chain, answer, updated_history
 
 
65
 
66
- # **Demo erstellen**
67
  def demo():
68
  with gr.Blocks() as demo:
69
  vector_db = gr.State()
70
  qa_chain = gr.State()
71
-
72
- gr.Markdown("<center><h1>PDF-Chatbot mit kostenfreien Hugging Face-Modellen</h1></center>")
73
- document = gr.Files(label="Lade PDF-Dokumente hoch", file_types=[".pdf"])
74
  db_btn = gr.Button("Erstelle Vektordatenbank")
75
  llm_btn = gr.Radio(["Flan-T5 Small", "Flan-T5 Base"], label="Verfügbare LLMs", value="Flan-T5 Small", type="index")
76
  slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
@@ -85,7 +96,7 @@ def demo():
85
  qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain])
86
  submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, chatbot, chatbot])
87
 
88
- demo.launch()
89
 
90
  if __name__ == "__main__":
91
  demo()
 
1
  import gradio as gr
2
  import os
 
3
  from langchain_community.document_loaders import PyPDFLoader
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
6
+ from langchain_community.vectorstores import FAISS
7
  from langchain.chains import ConversationalRetrievalChain
8
  from langchain.memory import ConversationBufferMemory
9
 
10
+ # API-Token
11
+ api_token = os.getenv("HF_TOKEN")
12
 
13
+ # LLM-Optionen
14
  list_llm = ["google/flan-t5-small", "google/flan-t5-base"]
15
 
16
# Load the documents and split them into chunks
def load_doc(list_file_path):
    """Load every PDF in *list_file_path* and return its content as chunks.

    Args:
        list_file_path: Paths handed one by one to ``PyPDFLoader``. May be
            empty or ``None``.

    Returns:
        The list produced by ``RecursiveCharacterTextSplitter`` (``chunk_size=1024``,
        ``chunk_overlap=64``). An empty list is returned when there is nothing
        to load, so the return type is the same on every path.
    """
    if not list_file_path:
        # Always return a list: the caller feeds the result straight to
        # create_db, which cannot handle a (list, message) tuple.
        return []
    loaders = [PyPDFLoader(x) for x in list_file_path]
    documents = []
    for loader in loaders:
        documents.extend(loader.load())
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    return text_splitter.split_documents(documents)
26
 
27
# Build the vector database
def create_db(splits):
    """Index *splits* in a FAISS store.

    Embeddings come from ``HuggingFaceEmbeddings`` with the
    ``sentence-transformers/all-MiniLM-L6-v2`` model.
    """
    embedding_model = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )
    vector_store = FAISS.from_documents(splits, embedding_model)
    return vector_store
31
 
32
# Initialise the database
def initialize_database(list_file_obj):
    """Build the vector database from the uploaded file objects.

    Args:
        list_file_obj: Uploaded file objects (each exposing ``.name``), or
            ``None``/empty when nothing was uploaded. ``None`` entries are
            ignored.

    Returns:
        ``(vector_db, status_message)``. ``vector_db`` is ``None`` whenever no
        database could be built; the message then explains why.
    """
    if not list_file_obj:
        return None, "Fehler: Keine Dateien hochgeladen!"
    list_file_path = [x.name for x in list_file_obj if x is not None]
    if not list_file_path:
        # Every entry was None: there is nothing to load.
        return None, "Fehler: Keine Dateien hochgeladen!"
    doc_splits = load_doc(list_file_path)
    if not doc_splits:
        # Nothing could be extracted; do not try to index an empty
        # document list.
        return None, "Fehler: Keine Dokumente gefunden!"
    vector_db = create_db(doc_splits)
    return vector_db, "Datenbank erfolgreich erstellt!"
40
+
41
# Initialise the LLM chain
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
    """Create a ``ConversationalRetrievalChain`` on top of *vector_db*.

    Args:
        llm_model: Repository id passed to ``HuggingFaceEndpoint``.
        temperature: Sampling temperature for the endpoint.
        max_tokens: Forwarded as ``max_new_tokens``.
        top_k: Forwarded as ``top_k``.
        vector_db: Vector store providing ``as_retriever()``, or ``None``.

    Returns:
        The chain, or ``None`` when no vector database is available. A single
        value is returned on every path so callers can store it directly and
        test it against ``None``.
    """
    if vector_db is None:
        # Return a bare None, not a (None, message) tuple: a tuple is truthy
        # and would be mistaken for a usable chain.
        return None
    llm = HuggingFaceEndpoint(
        repo_id=llm_model,
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        llm, retriever=retriever, chain_type="stuff", memory=memory, return_source_documents=True
    )
 
57
 
58
# Initialise the LLM
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Build the QA chain for the model at index *llm_option* of ``list_llm``.

    Returns ``(qa_chain, status_message)``; the chain is ``None`` when
    *vector_db* is ``None``.
    """
    if vector_db is not None:
        selected_model = list_llm[llm_option]
        chain = initialize_llmchain(
            selected_model, llm_temperature, max_tokens, top_k, vector_db
        )
        return chain, "QA-Kette initialisiert. Chatbot ist bereit!"
    return None, "Datenbank wurde nicht erstellt!"
65
 
66
# Conversation
def conversation(qa_chain, message, history):
    """Answer *message* with *qa_chain* and return the updated transcript.

    Args:
        qa_chain: Object exposing ``invoke``; ``None`` if no chain was built.
        message: The user's question. ``None`` or blank is rejected with a
            notice.
        history: Previous transcript as a list of ``{"role", "content"}``
            dicts, or ``None``. It is not modified.

    Returns:
        ``(qa_chain, transcript, transcript)``. The same transcript is given
        for both list outputs so they can never disagree. On success it gains
        the user turn followed by the assistant turn; on a rejected call it
        gains a single notice entry.
    """
    transcript = list(history) if history else []
    if qa_chain is None:
        notice = transcript + [{"role": "system", "content": "Die QA-Kette wurde nicht initialisiert."}]
        return None, notice, notice
    if not message or not message.strip():
        notice = transcript + [{"role": "system", "content": "Bitte eine Frage eingeben!"}]
        return qa_chain, notice, notice
    response = qa_chain.invoke({"question": message, "chat_history": transcript})
    response_text = response.get("answer", "Keine Antwort verfügbar.")
    # Record the question as well as the answer; otherwise the transcript
    # would show answers without the questions that prompted them.
    updated = transcript + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": response_text},
    ]
    return qa_chain, updated, updated
77
 
78
+ # Demo erstellen
79
  def demo():
80
  with gr.Blocks() as demo:
81
  vector_db = gr.State()
82
  qa_chain = gr.State()
83
+ gr.Markdown("<center><h1>PDF-Chatbot mit kostenlosen Modellen</h1></center>")
84
+ document = gr.Files(label="PDF-Dokument hochladen")
 
85
  db_btn = gr.Button("Erstelle Vektordatenbank")
86
  llm_btn = gr.Radio(["Flan-T5 Small", "Flan-T5 Base"], label="Verfügbare LLMs", value="Flan-T5 Small", type="index")
87
  slider_temperature = gr.Slider(0.01, 1.0, 0.5, label="Temperature")
 
96
  qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], outputs=[qa_chain])
97
  submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot], outputs=[qa_chain, chatbot, chatbot])
98
 
99
+ demo.launch(debug=True)
100
 
101
  if __name__ == "__main__":
102
  demo()