Update app.py
Browse files
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
-
import
|
4 |
from langchain_community.vectorstores import Chroma
|
5 |
from langchain_community.document_loaders import PyPDFLoader
|
6 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
@@ -12,42 +12,50 @@ from langchain_community.retrievers import BM25Retriever
|
|
12 |
from langchain.retrievers import EnsembleRetriever
|
13 |
|
14 |
# Environment variable for API token
|
15 |
-
api_token = os.getenv("
|
16 |
-
print(f"API Token loaded: {api_token[:5]}...") # Debug
|
17 |
if not api_token:
|
18 |
-
raise ValueError("Environment variable 'FirstToken' not set.
|
19 |
|
20 |
# Available LLM models
|
21 |
list_llm = [
|
22 |
-
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
23 |
"mistralai/Mistral-7B-Instruct-v0.2",
|
24 |
"deepseek-ai/deepseek-llm-7b-chat"
|
25 |
]
|
26 |
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
|
27 |
|
28 |
# -----------------------------------------------------------------------------
|
29 |
-
# Document Loading and Splitting
|
30 |
# -----------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
31 |
def load_doc(list_file_path, progress=gr.Progress()):
|
32 |
-
"""Load and split PDF documents into chunks."""
|
33 |
if not list_file_path:
|
34 |
raise ValueError("No files provided for processing.")
|
35 |
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
pages.extend(loader.load())
|
41 |
|
42 |
-
|
43 |
-
|
|
|
|
|
44 |
|
45 |
# -----------------------------------------------------------------------------
|
46 |
-
# Vector Database Creation
|
47 |
# -----------------------------------------------------------------------------
|
48 |
-
def create_chromadb(splits, persist_directory="chroma_db"):
|
49 |
-
"""Create ChromaDB vector database
|
50 |
-
|
|
|
|
|
51 |
chromadb = Chroma.from_documents(
|
52 |
documents=splits,
|
53 |
embedding=embeddings,
|
@@ -61,13 +69,13 @@ def create_chromadb(splits, persist_directory="chroma_db"):
|
|
61 |
def create_bm25_retriever(splits):
|
62 |
"""Create BM25 retriever from document splits."""
|
63 |
retriever = BM25Retriever.from_documents(splits)
|
64 |
-
retriever.k =
|
65 |
return retriever
|
66 |
|
67 |
def create_ensemble_retriever(vector_db, bm25_retriever):
|
68 |
-
"""Create an ensemble retriever
|
69 |
return EnsembleRetriever(
|
70 |
-
retrievers=[vector_db.as_retriever(), bm25_retriever],
|
71 |
weights=[0.7, 0.3]
|
72 |
)
|
73 |
|
@@ -78,10 +86,12 @@ def initialize_database(list_file_obj, progress=gr.Progress()):
|
|
78 |
"""Initialize the document database with error handling."""
|
79 |
try:
|
80 |
list_file_path = [x.name for x in list_file_obj if x is not None]
|
|
|
81 |
doc_splits = load_doc(list_file_path, progress)
|
82 |
-
chromadb = create_chromadb(doc_splits)
|
83 |
bm25_retriever = create_bm25_retriever(doc_splits)
|
84 |
ensemble_retriever = create_ensemble_retriever(chromadb, bm25_retriever)
|
|
|
85 |
return ensemble_retriever, "Database created successfully!"
|
86 |
except Exception as e:
|
87 |
return None, f"Error initializing database: {str(e)}"
|
@@ -90,7 +100,7 @@ def initialize_database(list_file_obj, progress=gr.Progress()):
|
|
90 |
# Initialize LLM Chain
|
91 |
# -----------------------------------------------------------------------------
|
92 |
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever):
|
93 |
-
"""Initialize the language model chain
|
94 |
if retriever is None:
|
95 |
raise ValueError("Retriever is None. Please process documents first.")
|
96 |
|
@@ -211,9 +221,9 @@ def demo():
|
|
211 |
language_btn = gr.Radio(choices=["English", "Português"], label="Response Language", value="English")
|
212 |
with gr.Accordion("Advanced Settings", open=False):
|
213 |
slider_temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Analysis Precision")
|
214 |
-
slider_maxtokens = gr.Slider(128,
|
215 |
-
slider_topk = gr.Slider(1,
|
216 |
-
qachain_btn = gr.Button("Initialize Assistant", interactive=False)
|
217 |
llm_progress = gr.Textbox(value="Not initialized", label="Assistant Status")
|
218 |
|
219 |
with gr.Column(scale=2):
|
|
|
1 |
import gradio as gr
|
2 |
import os
|
3 |
+
from concurrent.futures import ThreadPoolExecutor
|
4 |
from langchain_community.vectorstores import Chroma
|
5 |
from langchain_community.document_loaders import PyPDFLoader
|
6 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
|
|
12 |
from langchain.retrievers import EnsembleRetriever
|
13 |
|
14 |
# Environment variable for API token
|
15 |
+
api_token = os.getenv("FirstToken")
|
16 |
+
print(f"API Token loaded: {api_token[:5]}...") # Debug
|
17 |
if not api_token:
|
18 |
+
raise ValueError("Environment variable 'FirstToken' not set.")
|
19 |
|
20 |
# Available LLM models
|
21 |
list_llm = [
|
22 |
+
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
23 |
"mistralai/Mistral-7B-Instruct-v0.2",
|
24 |
"deepseek-ai/deepseek-llm-7b-chat"
|
25 |
]
|
26 |
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
|
27 |
|
28 |
# -----------------------------------------------------------------------------
|
29 |
+
# Document Loading and Splitting (Optimized with Threading)
|
30 |
# -----------------------------------------------------------------------------
|
31 |
+
def load_single_pdf(file_path):
    """Read one PDF from disk and return its pages as LangChain documents.

    NOTE(review): relies on PyPDFLoader being imported at module level.
    """
    return PyPDFLoader(file_path).load()
|
35 |
+
|
36 |
def load_doc(list_file_path, progress=gr.Progress()):
    """Load every PDF in *list_file_path* concurrently, then split into chunks.

    Parameters:
        list_file_path: paths of the PDF files to ingest.
        progress: gradio progress tracker (the default-instance pattern is
            gradio's documented idiom for progress reporting).

    Returns:
        The list of document chunks produced by the splitter.

    Raises:
        ValueError: when no file paths are supplied.
    """
    if not list_file_path:
        raise ValueError("No files provided for processing.")

    # Fan the per-file loads out across worker threads: PDF parsing is
    # I/O-heavy enough that the waits overlap under the GIL.
    with ThreadPoolExecutor() as pool:
        per_file_pages = pool.map(load_single_pdf, list_file_path)
        all_pages = [page for file_pages in per_file_pages for page in file_pages]

    progress(0.5, "Splitting documents...")
    # Larger chunks (2048 chars, 128 overlap) keep the chunk count down.
    splitter = RecursiveCharacterTextSplitter(chunk_size=2048, chunk_overlap=128)
    return splitter.split_documents(all_pages)
|
50 |
|
51 |
# -----------------------------------------------------------------------------
|
52 |
+
# Vector Database Creation (Optimized with Lightweight Embeddings)
|
53 |
# -----------------------------------------------------------------------------
|
54 |
+
def create_chromadb(splits, persist_directory="chroma_db", progress=gr.Progress()):
|
55 |
+
"""Create ChromaDB vector database with optimized embeddings."""
|
56 |
+
# Use a lighter embedding model
|
57 |
+
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
|
58 |
+
progress(0.7, "Creating vector database...")
|
59 |
chromadb = Chroma.from_documents(
|
60 |
documents=splits,
|
61 |
embedding=embeddings,
|
|
|
69 |
def create_bm25_retriever(splits):
    """Build a keyword (BM25) retriever over the given document splits."""
    bm25 = BM25Retriever.from_documents(splits)
    # Cap hits at 2 so the lexical side of retrieval stays fast.
    bm25.k = 2
    return bm25
|
74 |
|
75 |
def create_ensemble_retriever(vector_db, bm25_retriever):
    """Combine dense (vector) and sparse (BM25) retrieval into one retriever.

    The dense side is weighted 0.7 and the BM25 side 0.3.
    """
    # Limit the dense retriever to 2 documents, mirroring the BM25 cap.
    dense = vector_db.as_retriever(search_kwargs={"k": 2})
    return EnsembleRetriever(retrievers=[dense, bm25_retriever], weights=[0.7, 0.3])
|
81 |
|
|
|
86 |
"""Initialize the document database with error handling."""
|
87 |
try:
|
88 |
list_file_path = [x.name for x in list_file_obj if x is not None]
|
89 |
+
progress(0.1, "Loading documents...")
|
90 |
doc_splits = load_doc(list_file_path, progress)
|
91 |
+
chromadb = create_chromadb(doc_splits, progress=progress)
|
92 |
bm25_retriever = create_bm25_retriever(doc_splits)
|
93 |
ensemble_retriever = create_ensemble_retriever(chromadb, bm25_retriever)
|
94 |
+
progress(1.0, "Database creation complete!")
|
95 |
return ensemble_retriever, "Database created successfully!"
|
96 |
except Exception as e:
|
97 |
return None, f"Error initializing database: {str(e)}"
|
|
|
100 |
# Initialize LLM Chain
|
101 |
# -----------------------------------------------------------------------------
|
102 |
def initialize_llmchain(llm_model, temperature, max_tokens, top_k, retriever):
|
103 |
+
"""Initialize the language model chain."""
|
104 |
if retriever is None:
|
105 |
raise ValueError("Retriever is None. Please process documents first.")
|
106 |
|
|
|
221 |
language_btn = gr.Radio(choices=["English", "Português"], label="Response Language", value="English")
|
222 |
with gr.Accordion("Advanced Settings", open=False):
|
223 |
slider_temperature = gr.Slider(0.01, 1.0, value=0.5, step=0.1, label="Analysis Precision")
|
224 |
+
slider_maxtokens = gr.Slider(128, 2048, value=1024, step=128, label="Response Length") # Reduced max_tokens
|
225 |
+
slider_topk = gr.Slider(1, 5, value=3, step=1, label="Analysis Diversity") # Reduced range
|
226 |
+
qachain_btn = gr.Button("Initialize Assistant", interactive=False)
|
227 |
llm_progress = gr.Textbox(value="Not initialized", label="Assistant Status")
|
228 |
|
229 |
with gr.Column(scale=2):
|