DHEIVER committed on
Commit
77fdcad
·
verified ·
1 Parent(s): ca3571c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -17
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import gradio as gr
2
  import os
3
  import torch
4
- from langchain_community.vectorstores import FAISS, Chroma
5
  from langchain_community.document_loaders import PyPDFLoader
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain.chains import ConversationalRetrievalChain
@@ -10,16 +10,16 @@ from langchain_community.llms import HuggingFaceEndpoint
10
  from langchain.memory import ConversationBufferMemory
11
  from langchain_community.retrievers import BM25Retriever
12
  from langchain.retrievers import EnsembleRetriever
13
- from langchain.retrievers.multi_query import MultiQueryRetriever
14
 
15
  # Environment variable for API token
16
- api_token = os.getenv("API_TOKEN")
 
17
  if not api_token:
18
  raise ValueError("Environment variable 'FirstToken' not set. Please set the Hugging Face API token.")
19
 
20
  # Available LLM models
21
  list_llm = [
22
- "meta-llama/Meta-Llama-3-8B-Instruct",
23
  "mistralai/Mistral-7B-Instruct-v0.2",
24
  "deepseek-ai/deepseek-llm-7b-chat"
25
  ]
@@ -55,11 +55,6 @@ def create_chromadb(splits, persist_directory="chroma_db"):
55
  )
56
  return chromadb
57
 
58
- def create_faissdb(splits):
59
- """Create FAISS vector database from document splits."""
60
- embeddings = HuggingFaceEmbeddings()
61
- return FAISS.from_documents(splits, embeddings)
62
-
63
  # -----------------------------------------------------------------------------
64
  # Retrievers
65
  # -----------------------------------------------------------------------------
@@ -96,7 +91,11 @@ def initialize_database(list_file_obj, progress=gr.Progress()):
96
  # -----------------------------------------------------------------------------
97
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever):
98
  """Initialize the language model chain with error handling."""
 
 
 
99
  try:
 
100
  llm = HuggingFaceEndpoint(
101
  repo_id=llm_model,
102
  huggingfacehub_api_token=api_token,
@@ -127,6 +126,9 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever):
127
  # -----------------------------------------------------------------------------
128
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, retriever, progress=gr.Progress()):
129
  """Initialize the Language Model."""
 
 
 
130
  try:
131
  llm_name = list_llm[llm_option]
132
  print(f"Selected LLM model: {llm_name}")
@@ -150,7 +152,6 @@ def conversation(qa_chain, message, history, lang):
150
  if not qa_chain:
151
  return None, gr.update(value="Assistant not initialized"), history, "", 0, "", 0, "", 0
152
 
153
- # Add language instruction
154
  lang_instruction = " (Responda em Português)" if lang == "pt" else " (Respond in English)"
155
  query = message + lang_instruction
156
 
@@ -159,13 +160,11 @@ def conversation(qa_chain, message, history, lang):
159
  response = qa_chain.invoke({"question": query, "chat_history": formatted_chat_history})
160
  answer = response["answer"].split("Helpful Answer:")[-1].strip() if "Helpful Answer:" in response["answer"] else response["answer"]
161
 
162
- # Extract sources (handle cases where fewer than 3 documents are returned)
163
  sources = response["source_documents"]
164
  source_data = [("Unknown", 0)] * 3
165
  for i, doc in enumerate(sources[:3]):
166
  source_data[i] = (doc.page_content.strip(), doc.metadata["page"] + 1)
167
 
168
- # Update history without the language instruction
169
  new_history = history + [(message, answer)]
170
  return (
171
  qa_chain, gr.update(value=""), new_history,
@@ -214,7 +213,7 @@ def demo():
214
  slider_temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Analysis Precision")
215
  slider_maxtokens = gr.Slider(128, 9192, value=4096, step=128, label="Response Length")
216
  slider_topk = gr.Slider(1, 10, value=3, step=1, label="Analysis Diversity")
217
- qachain_btn = gr.Button("Initialize Assistant")
218
  llm_progress = gr.Textbox(value="Not initialized", label="Assistant Status")
219
 
220
  with gr.Column(scale=2):
@@ -232,10 +231,36 @@ def demo():
232
 
233
  # Event Handlers
234
  language_btn.change(lambda x: "en" if x == "English" else "pt", inputs=language_btn, outputs=language)
235
- db_btn.click(initialize_database, inputs=[document], outputs=[retriever, db_progress])
236
- qachain_btn.click(initialize_LLM, inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, retriever], outputs=[qa_chain, llm_progress])
237
- submit_btn.click(conversation, inputs=[qa_chain, msg, chatbot, language], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
238
- msg.submit(conversation, inputs=[qa_chain, msg, chatbot, language], outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
  demo.launch(debug=True)
241
 
 
1
  import gradio as gr
2
  import os
3
  import torch
4
+ from langchain_community.vectorstores import Chroma
5
  from langchain_community.document_loaders import PyPDFLoader
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
7
  from langchain.chains import ConversationalRetrievalChain
 
10
  from langchain.memory import ConversationBufferMemory
11
  from langchain_community.retrievers import BM25Retriever
12
  from langchain.retrievers import EnsembleRetriever
 
13
 
14
  # Environment variable for API token
15
+ api_token = os.getenv("FirstToken")
16
+ print(f"API Token loaded: {api_token[:5]}...") # Debug: Show first 5 chars of token
17
  if not api_token:
18
  raise ValueError("Environment variable 'FirstToken' not set. Please set the Hugging Face API token.")
19
 
20
  # Available LLM models
21
  list_llm = [
22
+ "mistralai/Mixtral-8x7B-Instruct-v0.1", # Publicly accessible
23
  "mistralai/Mistral-7B-Instruct-v0.2",
24
  "deepseek-ai/deepseek-llm-7b-chat"
25
  ]
 
55
  )
56
  return chromadb
57
 
 
 
 
 
 
58
  # -----------------------------------------------------------------------------
59
  # Retrievers
60
  # -----------------------------------------------------------------------------
 
91
  # -----------------------------------------------------------------------------
92
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever):
93
  """Initialize the language model chain with error handling."""
94
+ if retriever is None:
95
+ raise ValueError("Retriever is None. Please process documents first.")
96
+
97
  try:
98
+ print(f"Initializing LLM: {llm_model} with token: {api_token[:5]}...")
99
  llm = HuggingFaceEndpoint(
100
  repo_id=llm_model,
101
  huggingfacehub_api_token=api_token,
 
126
  # -----------------------------------------------------------------------------
127
  def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, retriever, progress=gr.Progress()):
128
  """Initialize the Language Model."""
129
+ if retriever is None:
130
+ return None, "Error: No database initialized. Please process documents first."
131
+
132
  try:
133
  llm_name = list_llm[llm_option]
134
  print(f"Selected LLM model: {llm_name}")
 
152
  if not qa_chain:
153
  return None, gr.update(value="Assistant not initialized"), history, "", 0, "", 0, "", 0
154
 
 
155
  lang_instruction = " (Responda em Português)" if lang == "pt" else " (Respond in English)"
156
  query = message + lang_instruction
157
 
 
160
  response = qa_chain.invoke({"question": query, "chat_history": formatted_chat_history})
161
  answer = response["answer"].split("Helpful Answer:")[-1].strip() if "Helpful Answer:" in response["answer"] else response["answer"]
162
 
 
163
  sources = response["source_documents"]
164
  source_data = [("Unknown", 0)] * 3
165
  for i, doc in enumerate(sources[:3]):
166
  source_data[i] = (doc.page_content.strip(), doc.metadata["page"] + 1)
167
 
 
168
  new_history = history + [(message, answer)]
169
  return (
170
  qa_chain, gr.update(value=""), new_history,
 
213
  slider_temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Analysis Precision")
214
  slider_maxtokens = gr.Slider(128, 9192, value=4096, step=128, label="Response Length")
215
  slider_topk = gr.Slider(1, 10, value=3, step=1, label="Analysis Diversity")
216
+ qachain_btn = gr.Button("Initialize Assistant", interactive=False) # Disabled by default
217
  llm_progress = gr.Textbox(value="Not initialized", label="Assistant Status")
218
 
219
  with gr.Column(scale=2):
 
231
 
232
  # Event Handlers
233
  language_btn.change(lambda x: "en" if x == "English" else "pt", inputs=language_btn, outputs=language)
234
+
235
+ def enable_qachain_btn(retriever, status):
236
+ return gr.update(interactive=retriever is not None and "successfully" in status)
237
+
238
+ db_btn.click(
239
+ initialize_database,
240
+ inputs=[document],
241
+ outputs=[retriever, db_progress]
242
+ ).then(
243
+ enable_qachain_btn,
244
+ inputs=[retriever, db_progress],
245
+ outputs=[qachain_btn]
246
+ )
247
+
248
+ qachain_btn.click(
249
+ initialize_LLM,
250
+ inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, retriever],
251
+ outputs=[qa_chain, llm_progress]
252
+ )
253
+
254
+ submit_btn.click(
255
+ conversation,
256
+ inputs=[qa_chain, msg, chatbot, language],
257
+ outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page]
258
+ )
259
+ msg.submit(
260
+ conversation,
261
+ inputs=[qa_chain, msg, chatbot, language],
262
+ outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page]
263
+ )
264
 
265
  demo.launch(debug=True)
266