la04 committed
Commit 244a9ba · verified · 1 Parent(s): 231e3ba

Update app.py

Files changed (1)
  1. app.py +50 -39
app.py CHANGED
@@ -1,88 +1,99 @@
- import gradio as gr
  import os
- from langchain.vectorstores import SimpleVectorStore  # direct, no additional dependencies
  from langchain.document_loaders import PyPDFLoader
  from langchain.embeddings import HuggingFaceEmbeddings
  from langchain.chains import ConversationalRetrievalChain
  from langchain.memory import ConversationBufferMemory
- from langchain.llms import HuggingFaceHub
- from langchain.text_splitter import RecursiveCharacterTextSplitter

- list_llm = ["google/flan-t5-small", "sentence-transformers/all-MiniLM-L6-v2"]

- def load_doc(list_file_path):
      loaders = [PyPDFLoader(x) for x in list_file_path]
-     pages = []
      for loader in loaders:
-         pages.extend(loader.load())
      text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
-     doc_splits = text_splitter.split_documents(pages)
      return doc_splits

- def create_db(splits):
-     vectordb = SimpleVectorStore.from_documents(splits)  # stored in memory
-     return vectordb

- def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
-     llm = HuggingFaceHub(
-         repo_id=llm_model,
-         model_kwargs={
-             "temperature": temperature,
-             "max_length": max_tokens,
-         }
      )
      memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
      retriever = vector_db.as_retriever()
      qa_chain = ConversationalRetrievalChain.from_llm(
          llm,
          retriever=retriever,
-         chain_type="stuff",
          memory=memory,
-         return_source_documents=True,
      )
      return qa_chain

  def initialize_database(list_file_obj):
      list_file_path = [x.name for x in list_file_obj if x is not None]
-     doc_splits = load_doc(list_file_path)
      vector_db = create_db(doc_splits)
      return vector_db, "Datenbank erfolgreich erstellt!"

- def initialize_LLM(llm_option, llm_temperature, max_tokens, vector_db):
-     llm_name = list_llm[llm_option]
-     qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, 3, vector_db)
-     return qa_chain, "Chatbot ist bereit."

  def conversation(qa_chain, message, history):
-     formatted_chat_history = [(f"User: {m}", f"Assistant: {r}") for m, r in history]
-     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
-     response_answer = response["answer"]
-     new_history = history + [(message, response_answer)]
-     return qa_chain, gr.update(value=""), new_history

  def demo():
      with gr.Blocks() as demo:
          vector_db = gr.State()
          qa_chain = gr.State()
-         gr.HTML("<center><h1>PDF QA Chatbot (Kostenlose Version)</h1></center>")
          with gr.Row():
              with gr.Column():
-                 document = gr.Files(file_types=[".pdf"], interactive=True)
                  db_btn = gr.Button("Erstelle Vektordatenbank")
-                 db_progress = gr.Textbox(value="Nicht initialisiert", show_label=False)
-                 llm_btn = gr.Radio(["Flan-T5-Small", "MiniLM"], label="Verfügbare Modelle")
                  slider_temperature = gr.Slider(0.01, 1.0, value=0.5, label="Temperature")
-                 slider_maxtokens = gr.Slider(64, 512, value=256, label="Max Tokens")
                  qachain_btn = gr.Button("Initialisiere QA-Chatbot")

              with gr.Column():
                  chatbot = gr.Chatbot(height=400)
-                 msg = gr.Textbox(placeholder="Frage stellen...")
                  submit_btn = gr.Button("Absenden")

-         db_btn.click(initialize_database, [document], [vector_db, db_progress])
-         qachain_btn.click(initialize_LLM, [llm_btn, slider_temperature, slider_maxtokens, vector_db], [qa_chain])
          submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])
      demo.launch(debug=True)

  if __name__ == "__main__":

  import os
+ import gradio as gr
  from langchain.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
  from langchain.embeddings import HuggingFaceEmbeddings
+ from langchain.vectorstores import FAISS
  from langchain.chains import ConversationalRetrievalChain
  from langchain.memory import ConversationBufferMemory
+ from langchain.llms import HuggingFacePipeline
+ from transformers import pipeline

+ # **Embeddings model (no API key needed, loaded locally)**
+ EMBEDDINGS_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
+ LLM_MODEL_NAME = "google/flan-t5-small"  # alternatively: "google/flan-t5-base", etc.

+ # **Load and split the documents**
+ def load_and_split_docs(list_file_path):
      loaders = [PyPDFLoader(x) for x in list_file_path]
+     documents = []
      for loader in loaders:
+         documents.extend(loader.load())
      text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
+     doc_splits = text_splitter.split_documents(documents)
      return doc_splits

+ # **Create the vector database with FAISS**
+ def create_db(docs):
+     embeddings = HuggingFaceEmbeddings(model_name=EMBEDDINGS_MODEL_NAME)
+     faiss_index = FAISS.from_documents(docs, embeddings)
+     return faiss_index

+ # **Initialize the LLM chain**
+ def initialize_llm_chain(llm_model, temperature, max_tokens, vector_db):
+     # use a Hugging Face pipeline locally
+     local_pipeline = pipeline(
+         "text2text-generation",
+         model=llm_model,
+         max_length=max_tokens,
+         temperature=temperature
      )
+     llm = HuggingFacePipeline(pipeline=local_pipeline)
      memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
      retriever = vector_db.as_retriever()
+
+     # retrieval-augmented QA chain
      qa_chain = ConversationalRetrievalChain.from_llm(
          llm,
          retriever=retriever,
          memory=memory,
+         return_source_documents=True
      )
      return qa_chain

+ # **Initialize the database and the chain**
  def initialize_database(list_file_obj):
      list_file_path = [x.name for x in list_file_obj if x is not None]
+     doc_splits = load_and_split_docs(list_file_path)
      vector_db = create_db(doc_splits)
      return vector_db, "Datenbank erfolgreich erstellt!"

+ def initialize_llm_chain_wrapper(llm_temperature, max_tokens, vector_db):
+     qa_chain = initialize_llm_chain(LLM_MODEL_NAME, llm_temperature, max_tokens, vector_db)
+     return qa_chain, "QA-Chatbot ist bereit!"

+ # **Run the conversation through the QA chain**
  def conversation(qa_chain, message, history):
+     response = qa_chain({"question": message, "chat_history": history})
+     response_text = response["answer"]
+     sources = [doc.metadata["source"] for doc in response["source_documents"]]
+     return qa_chain, response_text, history + [(message, response_text)]

+ # **Gradio user interface**
  def demo():
      with gr.Blocks() as demo:
          vector_db = gr.State()
          qa_chain = gr.State()
+
+         gr.HTML("<center><h1>RAG Chatbot mit FAISS und lokalen Modellen</h1></center>")
          with gr.Row():
              with gr.Column():
+                 document = gr.Files(file_types=[".pdf"], label="PDF hochladen")
                  db_btn = gr.Button("Erstelle Vektordatenbank")
+                 db_status = gr.Textbox(value="Status: Nicht initialisiert", show_label=False)
                  slider_temperature = gr.Slider(0.01, 1.0, value=0.5, label="Temperature")
+                 slider_max_tokens = gr.Slider(64, 512, value=256, label="Max Tokens")
                  qachain_btn = gr.Button("Initialisiere QA-Chatbot")

              with gr.Column():
                  chatbot = gr.Chatbot(height=400)
+                 msg = gr.Textbox(placeholder="Frage eingeben...")
                  submit_btn = gr.Button("Absenden")

+         db_btn.click(initialize_database, [document], [vector_db, db_status])
+         qachain_btn.click(initialize_llm_chain_wrapper, [slider_temperature, slider_max_tokens, vector_db], [qa_chain])
          submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])
+
      demo.launch(debug=True)

  if __name__ == "__main__":
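
Two caveats for anyone building on the new version; both concern library behavior outside this commit, so treat the snippet as a hedged sketch rather than part of the change. First, in classic LangChain, a ConversationBufferMemory attached to a chain that also sets return_source_documents=True raises a "one output key expected" ValueError when the chain saves context, unless the memory is told which output to keep via output_key="answer". Second, conversation returns response_text into the output slot bound to the msg textbox, so the question box is overwritten with the answer; the pre-commit version cleared it with gr.update(value="") instead. A minimal sketch of both fixes, reusing the diff's own names:

    import gradio as gr
    from langchain.memory import ConversationBufferMemory

    # The chain returns both "answer" and "source_documents"; the memory
    # must be told which one to store, or save_context() raises ValueError.
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key="answer",
        return_messages=True,
    )

    def conversation(qa_chain, message, history):
        response = qa_chain({"question": message, "chat_history": history})
        response_text = response["answer"]
        new_history = history + [(message, response_text)]
        # Clear the input box, as the old version did; the answer reaches
        # the UI through the Chatbot component via new_history.
        return qa_chain, gr.update(value=""), new_history

The sources list computed in the committed conversation is never used; appending it to response_text or routing it to its own gr.Textbox would surface the retrieved pages.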
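
On retrieval depth: the old initialize_llmchain signature carried a top_k parameter that was never wired into the retriever, and the new version drops it entirely, so vector_db.as_retriever() falls back to LangChain's default of four retrieved chunks. If the chunk count should be configurable again, the retriever accepts search_kwargs; a one-line sketch, not part of the commit:

    retriever = vector_db.as_retriever(search_kwargs={"k": 3})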
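
On the Temperature slider: a transformers text2text-generation pipeline only applies temperature when sampling is enabled, so as committed the slider is effectively cosmetic and decoding stays greedy. Enabling sampling would look like this (again a sketch; do_sample is not set anywhere in the commit):

    from transformers import pipeline

    local_pipeline = pipeline(
        "text2text-generation",
        model="google/flan-t5-small",
        max_length=256,
        do_sample=True,  # without this, temperature is ignored and decoding stays greedy
        temperature=0.5,
    )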
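
To try the commit locally: the imports imply, roughly, the PyPI packages gradio, langchain, transformers, sentence-transformers, faiss-cpu, and pypdf (torch comes in as the transformers backend), and the file presumably ends with a demo() call under the __main__ guard, which the diff rendering above cuts off after line 99. With those installed, python app.py starts the Gradio UI with debug=True on the usual local port.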