DHEIVER committed on
Commit
a3e638d
·
verified ·
1 Parent(s): 08ceb44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -26
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import os
3
- import torch
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_community.document_loaders import PyPDFLoader
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -12,42 +12,50 @@ from langchain_community.retrievers import BM25Retriever
12
  from langchain.retrievers import EnsembleRetriever
13
 
14
  # Environment variable for API token
15
- api_token = os.getenv("API_TOKEN")
16
- print(f"API Token loaded: {api_token[:5]}...") # Debug: Show first 5 chars of token
17
  if not api_token:
18
- raise ValueError("Environment variable 'FirstToken' not set. Please set the Hugging Face API token.")
19
 
20
  # Available LLM models
21
  list_llm = [
22
- "mistralai/Mixtral-8x7B-Instruct-v0.1", # Publicly accessible
23
  "mistralai/Mistral-7B-Instruct-v0.2",
24
  "deepseek-ai/deepseek-llm-7b-chat"
25
  ]
26
  list_llm_simple = [os.path.basename(llm) for llm in list_llm]
27
 
28
  # -----------------------------------------------------------------------------
29
- # Document Loading and Splitting
30
  # -----------------------------------------------------------------------------
 
 
 
 
 
31
  def load_doc(list_file_path, progress=gr.Progress()):
32
- """Load and split PDF documents into chunks."""
33
  if not list_file_path:
34
  raise ValueError("No files provided for processing.")
35
 
36
- loaders = [PyPDFLoader(x) for x in list_file_path]
37
- pages = []
38
- for i, loader in enumerate(loaders):
39
- progress((i + 1) / len(loaders), "Loading PDFs...")
40
- pages.extend(loader.load())
41
 
42
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
43
- return text_splitter.split_documents(pages)
 
 
44
 
45
  # -----------------------------------------------------------------------------
46
- # Vector Database Creation
47
  # -----------------------------------------------------------------------------
48
- def create_chromadb(splits, persist_directory="chroma_db"):
49
- """Create ChromaDB vector database from document splits."""
50
- embeddings = HuggingFaceEmbeddings()
 
 
51
  chromadb = Chroma.from_documents(
52
  documents=splits,
53
  embedding=embeddings,
@@ -61,13 +69,13 @@ def create_chromadb(splits, persist_directory="chroma_db"):
61
  def create_bm25_retriever(splits):
62
  """Create BM25 retriever from document splits."""
63
  retriever = BM25Retriever.from_documents(splits)
64
- retriever.k = 3
65
  return retriever
66
 
67
  def create_ensemble_retriever(vector_db, bm25_retriever):
68
- """Create an ensemble retriever combining vector DB and BM25."""
69
  return EnsembleRetriever(
70
- retrievers=[vector_db.as_retriever(), bm25_retriever],
71
  weights=[0.7, 0.3]
72
  )
73
 
@@ -78,10 +86,12 @@ def initialize_database(list_file_obj, progress=gr.Progress()):
78
  """Initialize the document database with error handling."""
79
  try:
80
  list_file_path = [x.name for x in list_file_obj if x is not None]
 
81
  doc_splits = load_doc(list_file_path, progress)
82
- chromadb = create_chromadb(doc_splits)
83
  bm25_retriever = create_bm25_retriever(doc_splits)
84
  ensemble_retriever = create_ensemble_retriever(chromadb, bm25_retriever)
 
85
  return ensemble_retriever, "Database created successfully!"
86
  except Exception as e:
87
  return None, f"Error initializing database: {str(e)}"
@@ -90,7 +100,7 @@ def initialize_database(list_file_obj, progress=gr.Progress()):
90
  # Initialize LLM Chain
91
  # -----------------------------------------------------------------------------
92
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever):
93
- """Initialize the language model chain with error handling."""
94
  if retriever is None:
95
  raise ValueError("Retriever is None. Please process documents first.")
96
 
@@ -211,9 +221,9 @@ def demo():
211
  language_btn = gr.Radio(choices=["English", "Português"], label="Response Language", value="English")
212
  with gr.Accordion("Advanced Settings", open=False):
213
  slider_temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Analysis Precision")
214
- slider_maxtokens = gr.Slider(128, 9192, value=4096, step=128, label="Response Length")
215
- slider_topk = gr.Slider(1, 10, value=3, step=1, label="Analysis Diversity")
216
- qachain_btn = gr.Button("Initialize Assistant", interactive=False) # Disabled by default
217
  llm_progress = gr.Textbox(value="Not initialized", label="Assistant Status")
218
 
219
  with gr.Column(scale=2):
 
1
  import gradio as gr
2
  import os
3
+ from concurrent.futures import ThreadPoolExecutor
4
  from langchain_community.vectorstores import Chroma
5
  from langchain_community.document_loaders import PyPDFLoader
6
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
12
  from langchain.retrievers import EnsembleRetriever
13
 
14
# Environment variable for API token.
# Validate BEFORE slicing: the original printed api_token[:5] first, so a
# missing token crashed with TypeError ('NoneType' is not subscriptable)
# instead of raising the intended ValueError below.
api_token = os.getenv("FirstToken")
if not api_token:
    raise ValueError("Environment variable 'FirstToken' not set.")
# Only a short prefix is printed so the full secret never reaches the logs.
print(f"API Token loaded: {api_token[:5]}...")  # Debug
 
20
# Available LLM models (full Hugging Face repo ids).
list_llm = [
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "deepseek-ai/deepseek-llm-7b-chat",
]
# Short display names: keep only the segment after the final "/".
list_llm_simple = list(map(os.path.basename, list_llm))
27
 
28
  # -----------------------------------------------------------------------------
29
+ # Document Loading and Splitting (Optimized with Threading)
30
  # -----------------------------------------------------------------------------
31
def load_single_pdf(file_path):
    """Load one PDF file and return its pages."""
    # Construct the loader and load in a single expression; there is no
    # other use for the intermediate loader object.
    return PyPDFLoader(file_path).load()
35
+
36
def load_doc(list_file_path, progress=gr.Progress()):
    """Load and split PDF documents into chunks with multi-threading."""
    if not list_file_path:
        raise ValueError("No files provided for processing.")

    # Load every PDF in parallel worker threads, then flatten the
    # per-file page lists into one flat list of pages.
    with ThreadPoolExecutor() as executor:
        per_file_pages = executor.map(load_single_pdf, list_file_path)
        pages = [page for file_pages in per_file_pages for page in file_pages]

    progress(0.5, "Splitting documents...")
    # Larger chunks (2048/128) mean fewer embeddings to compute downstream.
    splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)
    return splitter.split_documents(pages)
50
 
51
  # -----------------------------------------------------------------------------
52
+ # Vector Database Creation (Optimized with Lightweight Embeddings)
53
  # -----------------------------------------------------------------------------
54
+ def create_chromadb(splits, persist_directory="chroma_db", progress=gr.Progress()):
55
+ """Create ChromaDB vector database with optimized embeddings."""
56
+ # Use a lighter embedding model
57
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
58
+ progress(0.7, "Creating vector database...")
59
  chromadb = Chroma.from_documents(
60
  documents=splits,
61
  embedding=embeddings,
 
69
def create_bm25_retriever(splits):
    """Create BM25 retriever from document splits."""
    bm25 = BM25Retriever.from_documents(splits)
    # Cap results at 2 documents to keep retrieval fast.
    bm25.k = 2
    return bm25
74
 
75
def create_ensemble_retriever(vector_db, bm25_retriever):
    """Create an ensemble retriever."""
    # Dense retriever is likewise capped at 2 documents per query.
    dense_retriever = vector_db.as_retriever(search_kwargs={"k": 2})
    return EnsembleRetriever(
        retrievers=[dense_retriever, bm25_retriever],
        weights=[0.7, 0.3],
    )
81
 
 
86
  """Initialize the document database with error handling."""
87
  try:
88
  list_file_path = [x.name for x in list_file_obj if x is not None]
89
+ progress(0.1, "Loading documents...")
90
  doc_splits = load_doc(list_file_path, progress)
91
+ chromadb = create_chromadb(doc_splits, progress=progress)
92
  bm25_retriever = create_bm25_retriever(doc_splits)
93
  ensemble_retriever = create_ensemble_retriever(chromadb, bm25_retriever)
94
+ progress(1.0, "Database creation complete!")
95
  return ensemble_retriever, "Database created successfully!"
96
  except Exception as e:
97
  return None, f"Error initializing database: {str(e)}"
 
100
  # Initialize LLM Chain
101
  # -----------------------------------------------------------------------------
102
  def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever):
103
+ """Initialize the language model chain."""
104
  if retriever is None:
105
  raise ValueError("Retriever is None. Please process documents first.")
106
 
 
221
  language_btn = gr.Radio(choices=["English", "Português"], label="Response Language", value="English")
222
  with gr.Accordion("Advanced Settings", open=False):
223
  slider_temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Analysis Precision")
224
+ slider_maxtokens = gr.Slider(128, 2048, value=1024, step=128, label="Response Length") # Reduced max_tokens
225
+ slider_topk = gr.Slider(1, 5, value=3, step=1, label="Analysis Diversity") # Reduced range
226
+ qachain_btn = gr.Button("Initialize Assistant", interactive=False)
227
  llm_progress = gr.Textbox(value="Not initialized", label="Assistant Status")
228
 
229
  with gr.Column(scale=2):