la04 committed on
Commit
d82bfa1
·
verified ·
1 Parent(s): b73ff8c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -41
app.py CHANGED
@@ -1,37 +1,32 @@
1
  import os
2
  import gradio as gr
3
- from langchain.document_loaders import PyPDFLoader
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
- from langchain.embeddings import HuggingFaceEmbeddings
6
- from langchain.vectorstores import FAISS
7
  from langchain.chains import ConversationalRetrievalChain
8
  from langchain.memory import ConversationBufferMemory
9
- from langchain.llms import HuggingFacePipeline
10
  from transformers import pipeline
11
 
12
- # **Embeddings-Modell (kein API-Key nötig, lokal geladen)**
13
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
14
- LLM_MODEL_NAME = "google/flan-t5-small" # Alternativ: "google/flan-t5-base", etc.
15
 
16
- # **Dokumente laden und aufteilen**
17
  def load_and_split_docs(list_file_path):
 
 
18
  loaders = [PyPDFLoader(x) for x in list_file_path]
19
  documents = []
20
  for loader in loaders:
21
  documents.extend(loader.load())
22
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
23
- doc_splits = text_splitter.split_documents(documents)
24
- return doc_splits
25
 
26
- # **Vektor-Datenbank mit FAISS erstellen**
27
  def create_db(docs):
28
  embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
29
- faiss_index = FAISS.from_documents(docs, embeddings)
30
- return faiss_index
31
 
32
- # **LLM-Kette initialisieren**
33
  def initialize_llm_chain(llm_model, temperature, max_tokens, vector_db):
34
- # Hugging Face Pipeline lokal verwenden
35
  local_pipeline = pipeline(
36
  "text2text-generation",
37
  model=llm_model,
@@ -39,37 +34,15 @@ def initialize_llm_chain(llm_model, temperature, max_tokens, vector_db):
39
  temperature=temperature
40
  )
41
  llm = HuggingFacePipeline(pipeline=local_pipeline)
42
- memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
43
  retriever = vector_db.as_retriever()
44
-
45
- # Retrieval-Augmented QA-Kette
46
- qa_chain = ConversationalRetrievalChain.from_llm(
47
  llm,
48
  retriever=retriever,
49
  memory=memory,
50
  return_source_documents=True
51
  )
52
- return qa_chain
53
 
54
- # **Datenbank und Kette initialisieren**
55
- def initialize_database(list_file_obj):
56
- list_file_path = [x.name for x in list_file_obj if x is not None]
57
- doc_splits = load_and_split_docs(list_file_path)
58
- vector_db = create_db(doc_splits)
59
- return vector_db, "Datenbank erfolgreich erstellt!"
60
-
61
- def initialize_llm_chain_wrapper(llm_temperature, max_tokens, vector_db):
62
- qa_chain = initialize_llm_chain(LLM_MODEL_NAME, llm_temperature, max_tokens, vector_db)
63
- return qa_chain, "QA-Chatbot ist bereit!"
64
-
65
- # **Konversation mit QA-Kette führen**
66
- def conversation(qa_chain, message, history):
67
- response = qa_chain({"question": message, "chat_history": history})
68
- response_text = response["answer"]
69
- sources = [doc.metadata["source"] for doc in response["source_documents"]]
70
- return qa_chain, response_text, history + [(message, response_text)]
71
-
72
- # **Gradio-Benutzeroberfläche**
73
  def demo():
74
  with gr.Blocks() as demo:
75
  vector_db = gr.State()
@@ -86,15 +59,15 @@ def demo():
86
  qachain_btn = gr.Button("Initialisiere QA-Chatbot")
87
 
88
  with gr.Column():
89
- chatbot = gr.Chatbot(height=400)
90
  msg = gr.Textbox(placeholder="Frage eingeben...")
91
  submit_btn = gr.Button("Absenden")
92
 
93
  db_btn.click(initialize_database, [document], [vector_db, db_status])
94
  qachain_btn.click(initialize_llm_chain_wrapper, [slider_temperature, slider_max_tokens, vector_db], [qa_chain])
95
- submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])
96
 
97
- demo.launch(debug=True)
98
 
99
  if __name__ == "__main__":
100
  demo()
 
1
  import os
2
  import gradio as gr
3
+ from langchain_community.document_loaders import PyPDFLoader
4
  from langchain.text_splitter import RecursiveCharacterTextSplitter
5
+ from langchain_community.embeddings import HuggingFaceEmbeddings
6
+ from langchain_community.vectorstores import FAISS
7
  from langchain.chains import ConversationalRetrievalChain
8
  from langchain.memory import ConversationBufferMemory
9
+ from langchain_community.llms import HuggingFacePipeline
10
  from transformers import pipeline
11
 
 
12
  EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
13
+ LLM_MODEL_NAME = "google/flan-t5-small"
14
 
 
15
def load_and_split_docs(list_file_path):
    """Load PDF files and split them into overlapping text chunks.

    Parameters
    ----------
    list_file_path : list[str] | None
        Paths to the PDF files to ingest.

    Returns
    -------
    list
        Document chunks produced by ``RecursiveCharacterTextSplitter``;
        an empty list when no paths are given.
    """
    # Bug fix: the previous version returned a (list, str) tuple on this
    # path while returning a plain list below — callers pass the result
    # straight to create_db(), so the inconsistent type would break there.
    if not list_file_path:
        return []
    documents = []
    for path in list_file_path:
        # PyPDFLoader yields one Document per PDF page.
        documents.extend(PyPDFLoader(path).load())
    # 512-char chunks with a 32-char overlap preserve context at boundaries.
    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
    return splitter.split_documents(documents)
 
24
 
 
25
  def create_db(docs):
26
  embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
27
+ return FAISS.from_documents(docs, embeddings)
 
28
 
 
29
def initialize_llm_chain(llm_model, temperature, max_tokens, vector_db):
    """Build a ConversationalRetrievalChain backed by a local HF pipeline.

    Parameters
    ----------
    llm_model : str
        Hugging Face model id of a text2text-generation model.
    temperature : float
        Sampling temperature forwarded to the pipeline.
    max_tokens : int
        Generation length cap forwarded to the pipeline.
    vector_db :
        A vector store exposing ``as_retriever()`` (e.g. FAISS).

    Returns
    -------
    ConversationalRetrievalChain
        Chain that answers questions over the retriever's documents and
        also returns the source documents.
    """
    local_pipeline = pipeline(
        "text2text-generation",
        model=llm_model,
        # Fix: max_tokens was accepted but never forwarded — the length
        # slider in the UI had no effect. (The original kwarg line fell in
        # a hidden diff hunk; max_length is the text2text cap — confirm.)
        max_length=max_tokens,
        temperature=temperature,
    )
    llm = HuggingFacePipeline(pipeline=local_pipeline)
    # Fix: with return_source_documents=True the chain emits multiple output
    # keys, and LangChain requires memory.output_key to disambiguate which
    # one to store — otherwise saving the context raises a ValueError.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )
    retriever = vector_db.as_retriever()
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        memory=memory,
        return_source_documents=True,
    )
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def demo():
47
  with gr.Blocks() as demo:
48
  vector_db = gr.State()
 
59
  qachain_btn = gr.Button("Initialisiere QA-Chatbot")
60
 
61
  with gr.Column():
62
+ chatbot = gr.Chatbot(type='messages', height=400)
63
  msg = gr.Textbox(placeholder="Frage eingeben...")
64
  submit_btn = gr.Button("Absenden")
65
 
66
  db_btn.click(initialize_database, [document], [vector_db, db_status])
67
  qachain_btn.click(initialize_llm_chain_wrapper, [slider_temperature, slider_max_tokens, vector_db], [qa_chain])
68
+ submit_btn.click(conversation, [qa_chain, msg, []], [qa_chain, "message", "history"])
69
 
70
+ demo.launch(debug=True, enable_queue=True)
71
 
72
  if __name__ == "__main__":
73
  demo()