Update app.py
app.py CHANGED
@@ -11,8 +11,8 @@ from langchain.text_splitter import (
 from langchain_community.vectorstores import FAISS, Chroma, Qdrant
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.chains import ConversationalRetrievalChain
-from langchain_community.embeddings import HuggingFaceEmbeddings
-from …
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_huggingface import HuggingFaceEndpoint
 from langchain.memory import ConversationBufferMemory
 from sentence_transformers import SentenceTransformer, util
 import torch
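For reference: HuggingFaceEndpoint was deprecated in langchain_community and now lives in the langchain-huggingface partner package, which is what this import swap reflects. A minimal standalone sketch of the new import path (the repo ID and token below are placeholders, not part of this commit):

# pip install langchain-huggingface
from langchain_huggingface import HuggingFaceEndpoint

# Placeholder repo ID and token; any text-generation model on the Hub
# works the same way through the Inference API.
llm = HuggingFaceEndpoint(
    repo_id="mistralai/Mistral-7B-Instruct-v0.2",
    huggingfacehub_api_token="hf_...",
    temperature=0.7,
    max_new_tokens=256,
)
print(llm.invoke("Say hello"))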
@@ -48,53 +48,46 @@ class RAGEvaluator:
         self.test_samples = []
 
     def load_dataset(self, dataset_name: str, num_samples: int = 10):
-        """Load …
+        """Load dataset with proper error handling"""
         try:
             if dataset_name == "squad":
                 dataset = load_dataset("squad_v2", split="validation")
-                # Select diverse questions
                 samples = dataset.select(range(0, 1000, 100))[:num_samples]
 
                 self.test_samples = []
                 for sample in samples:
-                    # …
-                    …
+                    # Handle SQuAD format
+                    answers = sample["answers"]
+                    if answers["text"]:  # Check if there are answers
                         self.test_samples.append({
                             "question": sample["question"],
-                            "ground_truth": …
+                            "ground_truth": answers["text"][0],
                             "context": sample["context"]
                         })
 
             elif dataset_name == "msmarco":
-                dataset = load_dataset("ms_marco", "v2.1", split="dev")
+                dataset = load_dataset("ms_marco", "v2.1", split="test")  # Changed from dev to test
                 samples = dataset.select(range(0, 1000, 100))[:num_samples]
 
                 self.test_samples = []
                 for sample in samples:
-                    # Check …
-                    if sample.get("answers") and sample["answers"]:
+                    if sample["answers"]:  # Check if answers exist
                         self.test_samples.append({
                             "question": sample["query"],
                             "ground_truth": sample["answers"][0],
-                            "context": sample["passages"][…
-                                if isinstance(sample["passages"], list)
-                                else sample["passages"]["passage_text"][0]
+                            "context": sample["passages"]["passage_text"][0]
                         })
 
             self.current_dataset = dataset_name
-
-            # Return dataset info
             return {
                 "dataset": dataset_name,
-                "…
-                "…
-                "status": "success"
+                "samples_loaded": len(self.test_samples),
+                "example_questions": [s["question"] for s in self.test_samples[:3]]
             }
 
         except Exception as e:
             print(f"Error loading dataset: {str(e)}")
             return {
-                "dataset": dataset_name,
                 "error": str(e),
                 "status": "failed"
             }
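The answers guard added above matters because squad_v2 mixes in unanswerable questions whose answers["text"] list is empty, so indexing [0] without the check raises IndexError. A small sketch of the record shape (these are real SQuAD v2 fields; the printed value is illustrative):

from datasets import load_dataset

ds = load_dataset("squad_v2", split="validation")
record = ds[0]
# "answers" holds parallel lists; unanswerable questions have empty lists.
print(record["question"])
print(record["answers"])  # e.g. {'text': ['France'], 'answer_start': [159]}
if record["answers"]["text"]:
    ground_truth = record["answers"]["text"][0]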
@@ -205,36 +198,58 @@ def create_db(splits, db_choice: str = "faiss"):
     return db_creators[db_choice]()
 
 def initialize_database(list_file_obj, splitting_strategy, chunk_size, db_choice, progress=gr.Progress()):
-    …
+    """Initialize vector database with error handling"""
+    try:
+        if not list_file_obj:
+            return None, "No files uploaded. Please upload PDF documents first."
+
+        list_file_path = [x.name for x in list_file_obj if x is not None]
+        if not list_file_path:
+            return None, "No valid files found. Please upload PDF documents."
+
+        doc_splits = load_doc(list_file_path, splitting_strategy, chunk_size)
+        if not doc_splits:
+            return None, "No content extracted from documents."
+
+        vector_db = create_db(doc_splits, db_choice)
+        return vector_db, f"Database created successfully using {splitting_strategy} splitting and {db_choice} vector database!"
+
+    except Exception as e:
+        return None, f"Error creating database: {str(e)}"
 
 def initialize_llmchain(llm_choice, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-    …
+    """Initialize LLM chain with error handling"""
+    try:
+        if vector_db is None:
+            return None, "Please create vector database first."
+
+        llm_model = list_llm[llm_choice]
+
+        llm = HuggingFaceEndpoint(
+            repo_id=llm_model,
+            huggingfacehub_api_token=api_token,
+            temperature=temperature,
+            max_new_tokens=max_tokens,
+            top_k=top_k
+        )
+
+        memory = ConversationBufferMemory(
+            memory_key="chat_history",
+            output_key='answer',
+            return_messages=True
+        )
 
-        …
+        retriever = vector_db.as_retriever()
+        qa_chain = ConversationalRetrievalChain.from_llm(
+            llm,
+            retriever=retriever,
+            memory=memory,
+            return_source_documents=True
+        )
+        return qa_chain, "LLM initialized successfully!"
+
+    except Exception as e:
+        return None, f"Error initializing LLM: {str(e)}"
 
 def conversation(qa_chain, message, history):
     """Fixed conversation function returning all required outputs"""
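Once initialize_llmchain returns a chain, it can be queried directly; because ConversationBufferMemory is attached with memory_key="chat_history", the chain tracks history itself and only needs the new question. A hypothetical usage sketch (argument values are illustrative, not from this commit):

qa_chain, status = initialize_llmchain(
    llm_choice=0, temperature=0.7, max_tokens=512, top_k=3, vector_db=vector_db
)
if qa_chain is not None:
    result = qa_chain.invoke({"question": "What is the main topic of the document?"})
    print(result["answer"])
    # return_source_documents=True exposes the retrieved chunks
    for doc in result["source_documents"]:
        print(doc.page_content[:80])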
@@ -424,12 +439,26 @@ def demo():
             initialize_database,
             inputs=[document, splitting_strategy, chunk_size, db_choice],
             outputs=[vector_db, db_progress]
+        ).then(
+            lambda x: gr.update(interactive=x is not None),
+            inputs=[vector_db],
+            outputs=[init_llm_btn]
         )
 
         init_llm_btn.click(
             initialize_llmchain,
             inputs=[llm_choice, temperature, max_tokens, top_k, vector_db],
             outputs=[qa_chain, llm_progress]
+        ).then(
+            lambda x: gr.update(interactive=x is not None),
+            inputs=[qa_chain],
+            outputs=[msg]
+        )
+
+        load_dataset_btn.click(
+            lambda x: evaluator.load_dataset(x),
+            inputs=[dataset_choice],
+            outputs=[dataset_info]
         )
 
         msg.submit(
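The .then() chaining added here is what gates the UI: the second step runs only after the click handler has written its outputs, so the lambda sees the fresh state value and toggles interactivity via gr.update. (Note that with a single input component, Gradio passes the bare value to the lambda, so the check is on x itself, not x[0].) A self-contained toy sketch of the pattern (component names are illustrative, not from app.py):

import gradio as gr

with gr.Blocks() as demo:
    state = gr.State(None)
    status = gr.Textbox(label="Status")
    build_btn = gr.Button("Build")
    ask_btn = gr.Button("Ask", interactive=False)

    def build():
        # Pretend to build a resource; return it plus a status message.
        return object(), "Built!"

    # .then() fires after the click handler finishes, so the gating
    # lambda reads the freshly written state value.
    build_btn.click(build, outputs=[state, status]).then(
        lambda s: gr.update(interactive=s is not None),
        inputs=[state],
        outputs=[ask_btn],
    )

demo.launch()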