la04 committed on
Commit
23cbcf8
·
verified ·
1 Parent(s): 675bca8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -41
app.py CHANGED
@@ -1,34 +1,29 @@
1
  import gradio as gr
2
  import os
3
- from langchain.vectorstores import FAISS # Import für Vektordatenbank FAISS
4
- from langchain.document_loaders import PyPDFLoader # PDF-Loader zum Laden der Dokumente
5
- from langchain.embeddings import HuggingFaceEmbeddings # Embeddings-Erstellung mit Hugging Face-Modellen
6
- from langchain.chains import ConversationalRetrievalChain # Chain für QA-Funktionalität
7
- from langchain.memory import ConversationBufferMemory # Speichern des Chat-Verlaufs im Speicher
8
- from langchain.llms import HuggingFaceHub # Für das Laden der Modelle von Hugging Face Hub
9
- from langchain.text_splitter import RecursiveCharacterTextSplitter # Aufteilen von Dokumenten in Chunks
10
 
11
- # Liste der LLM-Modelle (leichte CPU-freundliche Modelle)
12
- list_llm = ["google/flan-t5-small", "distilbert-base-uncased"]
13
- list_llm_simple = [os.path.basename(llm) for llm in list_llm]
14
 
15
- # PDF-Dokument laden und in Chunks aufteilen
16
  def load_doc(list_file_path):
17
  loaders = [PyPDFLoader(x) for x in list_file_path]
18
  pages = []
19
  for loader in loaders:
20
- pages.extend(loader.load()) # Laden der Seiten aus PDF
21
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32) # Chunks für CPU
22
  doc_splits = text_splitter.split_documents(pages)
23
  return doc_splits
24
 
25
- # Vektordatenbank erstellen
26
  def create_db(splits):
27
- embeddings = HuggingFaceEmbeddings() # Erstellen der Embeddings mit Hugging Face
28
- vectordb = FAISS.from_documents(splits, embeddings)
29
  return vectordb
30
 
31
- # Initialisierung des ConversationalRetrievalChain
32
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
33
  llm = HuggingFaceHub(
34
  repo_id=llm_model,
@@ -50,64 +45,46 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
50
  )
51
  return qa_chain
52
 
53
- # Initialisierung der Datenbank
54
  def initialize_database(list_file_obj):
55
  list_file_path = [x.name for x in list_file_obj if x is not None]
56
  doc_splits = load_doc(list_file_path)
57
  vector_db = create_db(doc_splits)
58
  return vector_db, "Datenbank erfolgreich erstellt!"
59
 
60
- # Initialisierung des LLM
61
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
62
  llm_name = list_llm[llm_option]
63
  qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db)
64
  return qa_chain, "LLM erfolgreich initialisiert! Chatbot ist bereit."
65
 
66
- # Chat-Historie formatieren
67
- def format_chat_history(message, chat_history):
68
- formatted_chat_history = []
69
- for user_message, bot_message in chat_history:
70
- formatted_chat_history.append(f"User: {user_message}")
71
- formatted_chat_history.append(f"Assistant: {bot_message}")
72
- return formatted_chat_history
73
-
74
- # Konversationsfunktion
75
  def conversation(qa_chain, message, history):
76
- formatted_chat_history = format_chat_history(message, history)
77
  response = qa_chain({"question": message, "chat_history": formatted_chat_history})
78
  response_answer = response["answer"]
79
  new_history = history + [(message, response_answer)]
80
  return qa_chain, gr.update(value=""), new_history
81
 
82
- # Gradio-Frontend
83
  def demo():
84
  with gr.Blocks() as demo:
85
  vector_db = gr.State()
86
  qa_chain = gr.State()
87
- gr.HTML("<center><h1>RAG PDF Chatbot</h1></center>")
88
  with gr.Row():
89
  with gr.Column():
90
- gr.Markdown("### Schritt 1: Lade PDF-Dokument hoch")
91
  document = gr.Files(height=300, file_count="multiple", file_types=[".pdf"], interactive=True)
92
  db_btn = gr.Button("Erstelle Vektordatenbank")
93
  db_progress = gr.Textbox(value="Nicht initialisiert", show_label=False)
94
- gr.Markdown("### Schritt 2: Wähle LLM und Einstellungen")
95
- llm_btn = gr.Radio(list_llm_simple, label="Verfügbare Modelle", value=list_llm_simple[0], type="index")
96
- slider_temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Temperature")
97
- slider_maxtokens = gr.Slider(64, 512, value=256, step=64, label="Max Tokens")
98
- slider_topk = gr.Slider(1, 10, value=3, step=1, label="Top-k")
99
  qachain_btn = gr.Button("Initialisiere QA-Chatbot")
100
- llm_progress = gr.Textbox(value="Nicht initialisiert", show_label=False)
101
 
102
  with gr.Column():
103
- gr.Markdown("### Schritt 3: Stelle Fragen an dein Dokument")
104
  chatbot = gr.Chatbot(height=400, type="messages")
105
  msg = gr.Textbox(placeholder="Frage stellen...")
106
  submit_btn = gr.Button("Absenden")
107
 
108
  db_btn.click(initialize_database, [document], [vector_db, db_progress])
109
- qachain_btn.click(initialize_LLM, [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], [qa_chain, llm_progress])
110
- msg.submit(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])
111
  submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])
112
  demo.launch(debug=True)
113
 
 
1
  import gradio as gr
2
  import os
3
+ from langchain.vectorstores import Chroma # Chroma als Vektordatenbank
4
+ from langchain.document_loaders import PyPDFLoader
5
+ from langchain.embeddings import HuggingFaceEmbeddings
6
+ from langchain.chains import ConversationalRetrievalChain
7
+ from langchain.memory import ConversationBufferMemory
8
+ from langchain.llms import HuggingFaceHub
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
 
11
+ list_llm = ["google/flan-t5-small", "sentence-transformers/all-MiniLM-L6-v2"]
 
 
12
 
 
13
def load_doc(list_file_path):
    """Load PDF files and split their pages into overlapping text chunks.

    Args:
        list_file_path: paths of the PDF files to ingest.

    Returns:
        The list of document chunks produced by the splitter
        (chunk_size=512, chunk_overlap=32 — small chunks suit CPU-only use).
    """
    pages = []
    for path in list_file_path:
        pages.extend(PyPDFLoader(path).load())
    splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=32)
    return splitter.split_documents(pages)
21
 
 
22
def create_db(splits):
    """Embed the given document chunks and index them in a Chroma vector store.

    Args:
        splits: document chunks as produced by ``load_doc``.

    Returns:
        The populated Chroma vector database.
    """
    hf_embeddings = HuggingFaceEmbeddings()
    return Chroma.from_documents(splits, hf_embeddings)
26
 
 
27
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db):
28
  llm = HuggingFaceHub(
29
  repo_id=llm_model,
 
45
  )
46
  return qa_chain
47
 
 
48
def initialize_database(list_file_obj):
    """Build the vector database from the PDF files uploaded via Gradio.

    Args:
        list_file_obj: Gradio file objects (``None`` entries are skipped).

    Returns:
        Tuple of (vector database, German status message for the UI).
    """
    paths = [file.name for file in list_file_obj if file is not None]
    chunks = load_doc(paths)
    return create_db(chunks), "Datenbank erfolgreich erstellt!"
53
 
 
54
def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db):
    """Create the conversational QA chain for the selected model.

    Args:
        llm_option: index into the module-level ``list_llm``.
        llm_temperature: sampling temperature for generation.
        max_tokens: generation token budget.
        top_k: retriever top-k.
        vector_db: vector store to retrieve from.

    Returns:
        Tuple of (QA chain, German status message for the UI).
    """
    selected_model = list_llm[llm_option]
    chain = initialize_llmchain(selected_model, llm_temperature, max_tokens, top_k, vector_db)
    return chain, "LLM erfolgreich initialisiert! Chatbot ist bereit."
58
 
 
 
 
 
 
 
 
 
 
59
def conversation(qa_chain, message, history):
    """Run one QA turn: query the chain with past history, append the answer.

    Args:
        qa_chain: the ConversationalRetrievalChain (Gradio state).
        message: the user's new question.
        history: list of (user, assistant) tuples so far.

    Returns:
        (unchanged qa_chain, update clearing the input box, extended history).
    """
    past_turns = []
    for user_msg, bot_msg in history:
        past_turns.append((f"User: {user_msg}", f"Assistant: {bot_msg}"))
    result = qa_chain({"question": message, "chat_history": past_turns})
    answer = result["answer"]
    return qa_chain, gr.update(value=""), history + [(message, answer)]
65
 
 
66
def demo():
    """Build and launch the Gradio UI: PDF upload -> vector DB -> QA chat."""
    with gr.Blocks() as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.HTML("<center><h1>RAG PDF Chatbot (Kostenlose Version)</h1></center>")
        with gr.Row():
            with gr.Column():
                document = gr.Files(height=300, file_count="multiple", file_types=[".pdf"], interactive=True)
                db_btn = gr.Button("Erstelle Vektordatenbank")
                db_progress = gr.Textbox(value="Nicht initialisiert", show_label=False)
                # BUG FIX: derive labels from list_llm so the UI stays in sync with
                # the model list, and use type="index" with a default selection —
                # initialize_LLM indexes into list_llm, so it needs an int, not a label.
                model_labels = [os.path.basename(m) for m in list_llm]
                llm_btn = gr.Radio(model_labels, label="Verfügbare Modelle",
                                   value=model_labels[0], type="index")
                slider_temperature = gr.Slider(0.01, 1.0, value=0.5, label="Temperature")
                slider_maxtokens = gr.Slider(64, 512, value=256, label="Max Tokens")
                # BUG FIX: the top-k control was dropped from the UI although
                # initialize_LLM still requires a top_k argument.
                slider_topk = gr.Slider(1, 10, value=3, step=1, label="Top-k")
                qachain_btn = gr.Button("Initialisiere QA-Chatbot")
                llm_progress = gr.Textbox(value="Nicht initialisiert", show_label=False)

            with gr.Column():
                # BUG FIX: conversation() produces (user, bot) tuple history, which is
                # incompatible with type="messages" (dict-based); use the tuple format.
                chatbot = gr.Chatbot(height=400)
                msg = gr.Textbox(placeholder="Frage stellen...")
                submit_btn = gr.Button("Absenden")

        db_btn.click(initialize_database, [document], [vector_db, db_progress])
        # BUG FIX: pass top_k (5th parameter) and capture BOTH return values of
        # initialize_LLM — it returns (qa_chain, status_message).
        qachain_btn.click(
            initialize_LLM,
            [llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            [qa_chain, llm_progress],
        )
        # BUG FIX: restore Enter-key submission alongside the button.
        msg.submit(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])
        submit_btn.click(conversation, [qa_chain, msg, chatbot], [qa_chain, msg, chatbot])
    demo.launch(debug=True)
90