arjunanand13 committed on
Commit
184e87b
1 Parent(s): e1175ed

Update app3.py

Files changed (1)
  1. app3.py +377 -193
app3.py CHANGED
@@ -16,16 +16,152 @@ from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from sentence_transformers import SentenceTransformer, util
import torch

# Constants and setup
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
api_token = os.getenv("HF_TOKEN")

# Initialize sentence transformer for evaluation
sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

- # Text splitting strategies
def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
    splitters = {
        "recursive": RecursiveCharacterTextSplitter(
@@ -43,105 +179,38 @@ def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int
    }
    return splitters.get(strategy)

- # Custom evaluation metrics
- def calculate_semantic_similarity(text1: str, text2: str) -> float:
-     embeddings1 = sentence_model.encode([text1], convert_to_tensor=True)
-     embeddings2 = sentence_model.encode([text2], convert_to_tensor=True)
-     similarity = util.pytorch_cos_sim(embeddings1, embeddings2)
-     return float(similarity[0][0])
-
- def evaluate_response(question: str, answer: str, ground_truth: str, contexts: List[str]) -> Dict[str, float]:
-     # Answer similarity with ground truth
-     answer_similarity = calculate_semantic_similarity(answer, ground_truth)
-
-     # Context relevance - average similarity between question and contexts
-     context_scores = [calculate_semantic_similarity(question, ctx) for ctx in contexts]
-     context_relevance = np.mean(context_scores)
-
-     # Answer relevance - similarity between question and answer
-     answer_relevance = calculate_semantic_similarity(question, answer)
-
-     return {
-         "answer_similarity": answer_similarity,
-         "context_relevance": context_relevance,
-         "answer_relevance": answer_relevance,
-         "average_score": np.mean([answer_similarity, context_relevance, answer_relevance])
-     }
-
- # Load and split PDF document
- def load_doc(list_file_path: List[str], splitting_strategy: str = "recursive"):
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())

-     text_splitter = get_text_splitter(splitting_strategy)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

- # Vector database creation functions
- def create_faiss_db(splits, embeddings):
-     return FAISS.from_documents(splits, embeddings)
-
- def create_chroma_db(splits, embeddings):
-     return Chroma.from_documents(splits, embeddings)
-
- def create_qdrant_db(splits, embeddings):
-     return Qdrant.from_documents(
-         splits,
-         embeddings,
-         location=":memory:",
-         collection_name="pdf_docs"
-     )
-
def create_db(splits, db_choice: str = "faiss"):
    embeddings = HuggingFaceEmbeddings()
    db_creators = {
-         "faiss": create_faiss_db,
-         "chroma": create_chroma_db,
-         "qdrant": create_qdrant_db
-     }
-     return db_creators[db_choice](splits, embeddings)
-
- def load_evaluation_dataset():
-     dataset = load_dataset("explodinggradients/fiqa", split="test", trust_remote_code=True)
-     return dataset
-
- def evaluate_rag_pipeline(qa_chain, dataset):
-     # Sample a few examples for evaluation
-     eval_samples = dataset.select(range(5))
-
-     results = []
-     for sample in eval_samples:
-         question = sample["question"]
-
-         # Get response from the chain
-         response = qa_chain.invoke({
-             "question": question,
-             "chat_history": []
-         })
-
-         # Evaluate response
-         eval_result = evaluate_response(
-             question=question,
-             answer=response["answer"],
-             ground_truth=sample["answer"],
-             contexts=[doc.page_content for doc in response["source_documents"]]
        )
-
-         results.append(eval_result)
-
-     # Calculate average scores across all samples
-     avg_results = {
-         metric: float(np.mean([r[metric] for r in results]))
-         for metric in results[0].keys()
    }
-
-     return avg_results

- # Initialize langchain LLM chain
def initialize_llmchain(llm_choice, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
-     # Get the full model name from the index
    llm_model = list_llm[llm_choice]

    llm = HuggingFaceEndpoint(
@@ -149,8 +218,7 @@ def initialize_llmchain(llm_choice, temperature, max_tokens, top_k, vector_db, p
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
-         top_k=top_k,
-         model=llm_model # Add model parameter
    )

    memory = ConversationBufferMemory(
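This hunk drops the explicit model= kwarg from the HuggingFaceEndpoint call and keeps top_k. For orientation, a minimal sketch of the resulting constructor; the repo_id argument is an assumption, since the leading arguments of the call sit above this hunk:

    from langchain_community.llms import HuggingFaceEndpoint

    llm = HuggingFaceEndpoint(
        repo_id=llm_model,  # assumed: the endpoint is addressed by repo id outside the hunk
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )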
@@ -163,162 +231,278 @@ def initialize_llmchain(llm_choice, temperature, max_tokens, top_k, vector_db, p
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
-         chain_type="stuff",
        memory=memory,
-         return_source_documents=True,
-         verbose=False,
    )
    return qa_chain, "LLM initialized successfully!"

- def initialize_database(list_file_obj, splitting_strategy, db_choice, progress=gr.Progress()):
-     list_file_path = [x.name for x in list_file_obj if x is not None]
-     doc_splits = load_doc(list_file_path, splitting_strategy)
-     vector_db = create_db(doc_splits, db_choice)
-     return vector_db, f"Database created using {splitting_strategy} splitting and {db_choice} vector database!"
-
- def format_chat_history(message, chat_history):
-     formatted_chat_history = []
-     for user_message, bot_message in chat_history:
-         formatted_chat_history.append(f"User: {user_message}")
-         formatted_chat_history.append(f"Assistant: {bot_message}")
-     return formatted_chat_history
-
def conversation(qa_chain, message, history):
-     formatted_chat_history = format_chat_history(message, history)
    response = qa_chain.invoke({
        "question": message,
-         "chat_history": formatted_chat_history
    })

    response_answer = response["answer"]
-     if response_answer.find("Helpful Answer:") != -1:
        response_answer = response_answer.split("Helpful Answer:")[-1]

-     response_sources = response["source_documents"]
-     response_source1 = response_sources[0].page_content.strip()
-     response_source2 = response_sources[1].page_content.strip()
-     response_source3 = response_sources[2].page_content.strip()

-     response_source1_page = response_sources[0].metadata["page"] + 1
-     response_source2_page = response_sources[1].metadata["page"] + 1
-     response_source3_page = response_sources[2].metadata["page"] + 1

-     new_history = history + [(message, response_answer)]

-     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page

def demo():
    with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky")) as demo:
        vector_db = gr.State()
        qa_chain = gr.State()

-         gr.HTML("<center><h1>Enhanced RAG PDF Chatbot</h1></center>")
-         gr.Markdown("""<b>Query your PDF documents with advanced RAG capabilities!</b>""")

-         with gr.Row():
-             with gr.Column(scale=86):
-                 gr.Markdown("<b>Step 1 - Configure and Initialize RAG Pipeline</b>")
                with gr.Row():
-                     document = gr.Files(height=300, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF documents")

                with gr.Row():
-                     splitting_strategy = gr.Radio(
                        ["recursive", "fixed", "token"],
                        label="Text Splitting Strategy",
                        value="recursive"
                    )
-                     db_choice = gr.Radio(
-                         ["faiss", "chroma", "qdrant"],
-                         label="Vector Database",
-                         value="faiss"
                    )

                with gr.Row():
-                     db_btn = gr.Button("Create vector database")
-                     evaluate_btn = gr.Button("Evaluate RAG Pipeline")
-
-                 with gr.Row():
-                     db_progress = gr.Textbox(value="Not initialized", show_label=False)
-                     evaluation_results = gr.JSON(label="Evaluation Results")
-
-                 gr.Markdown("<b>Select Large Language Model (LLM) and input parameters</b>")
-                 with gr.Row():
-                     llm_btn = gr.Radio(list_llm_simple, label="Available LLMs", value=list_llm_simple[0], type="index")
-
-                 with gr.Row():
-                     with gr.Accordion("LLM input parameters", open=False):
-                         slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.5, step=0.1, label="Temperature")
-                         slider_maxtokens = gr.Slider(minimum=128, maximum=9192, value=4096, step=128, label="Max New Tokens")
-                         slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k")
-
-                 with gr.Row():
-                     qachain_btn = gr.Button("Initialize Question Answering Chatbot")
-                     llm_progress = gr.Textbox(value="Not initialized", show_label=False)
-
-             with gr.Column(scale=200):
-                 gr.Markdown("<b>Step 2 - Chat with your Document</b>")
-                 chatbot = gr.Chatbot(height=505)
-
-                 with gr.Accordion("Relevant context from the source document", open=False):
-                     with gr.Row():
-                         doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-                         source1_page = gr.Number(label="Page", scale=1)
-                     with gr.Row():
-                         doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-                         source2_page = gr.Number(label="Page", scale=1)
-                     with gr.Row():
-                         doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-                         source3_page = gr.Number(label="Page", scale=1)
-
-                 with gr.Row():
-                     msg = gr.Textbox(placeholder="Ask a question", container=True)
-                 with gr.Row():
-                     submit_btn = gr.Button("Submit")
-                     clear_btn = gr.ClearButton([msg, chatbot], value="Clear")
-
        # Event handlers
        db_btn.click(
            initialize_database,
-             inputs=[document, splitting_strategy, db_choice],
            outputs=[vector_db, db_progress]
        )

-         evaluate_btn.click(
-             lambda qa_chain: evaluate_rag_pipeline(qa_chain, load_evaluation_dataset()) if qa_chain else None,
-             inputs=[qa_chain],
-             outputs=[evaluation_results]
-         )
-
-         qachain_btn.click(
-             initialize_llmchain, # Fixed function name here
-             inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_progress]
-         ).then(
-             lambda: [None, "", 0, "", 0, "", 0],
-             inputs=None,
-             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-             queue=False
        )

-         msg.submit(conversation,
            inputs=[qa_chain, msg, chatbot],
-             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-             queue=False
        )

-         submit_btn.click(conversation,
            inputs=[qa_chain, msg, chatbot],
-             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-             queue=False
        )

        clear_btn.click(
            lambda: [None, "", 0, "", 0, "", 0],
-             inputs=None,
-             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
-             queue=False
        )
-
    demo.queue().launch(debug=True)

if __name__ == "__main__":
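The substance of this commit is to replace the hand-rolled cosine-similarity metrics removed above with RAGAS, as shown in the new version of the file below. A minimal standalone sketch of that evaluation call, mirroring the column names and metric classes used in the new code (exact names vary between ragas releases, so treat them as assumptions for other versions):

    from datasets import Dataset
    from ragas import evaluate
    from ragas.metrics import ContextRecall, AnswerRelevancy, Faithfulness, ContextPrecision

    # One row per evaluated question; older ragas releases expect a list-valued "ground_truths" column.
    rows = [{
        "question": "What does the document say about X?",
        "answer": "The document states ...",
        "contexts": ["retrieved chunk 1 ...", "retrieved chunk 2 ..."],
        "ground_truths": ["reference answer from the eval dataset"],
    }]

    scores = evaluate(Dataset.from_list(rows),
                      metrics=[ContextRecall(), AnswerRelevancy(), Faithfulness(), ContextPrecision()])
    print(scores)  # dict-like access, e.g. scores["faithfulness"], scores["context_recall"]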
 
from langchain.memory import ConversationBufferMemory
from sentence_transformers import SentenceTransformer, util
import torch
+ from ragas import evaluate
+ from ragas.metrics import (
+     ContextRecall,
+     AnswerRelevancy,
+     Faithfulness,
+     ContextPrecision
+ )
+ import pandas as pd

# Constants and setup
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]
api_token = os.getenv("HF_TOKEN")

+ CHUNK_SIZES = {
+     "small": {"recursive": 512, "fixed": 512, "token": 256},
+     "medium": {"recursive": 1024, "fixed": 1024, "token": 512}
+ }
+
# Initialize sentence transformer for evaluation
sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

+ class RAGEvaluator:
+     def __init__(self):
+         self.datasets = {
+             "squad": "squad_v2",
+             "msmarco": "ms_marco"
+         }
+         self.current_dataset = None
+         self.test_samples = []
+
+     def load_dataset(self, dataset_name: str, num_samples: int = 10):
+         """Load a smaller subset of questions with proper error handling"""
+         try:
+             if dataset_name == "squad":
+                 dataset = load_dataset("squad_v2", split="validation")
+                 # Select diverse questions
+                 samples = dataset.select(range(0, 1000, 100))[:num_samples]
+
+                 self.test_samples = []
+                 for sample in samples:
+                     # Check if answers exist and are not empty
+                     if sample.get("answers") and isinstance(sample["answers"], dict) and sample["answers"].get("text"):
+                         self.test_samples.append({
+                             "question": sample["question"],
+                             "ground_truth": sample["answers"]["text"][0],
+                             "context": sample["context"]
+                         })
+
+             elif dataset_name == "msmarco":
+                 dataset = load_dataset("ms_marco", "v2.1", split="dev")
+                 samples = dataset.select(range(0, 1000, 100))[:num_samples]
+
+                 self.test_samples = []
+                 for sample in samples:
+                     # Check for valid answers
+                     if sample.get("answers") and sample["answers"]:
+                         self.test_samples.append({
+                             "question": sample["query"],
+                             "ground_truth": sample["answers"][0],
+                             "context": sample["passages"][0]["passage_text"]
+                                 if isinstance(sample["passages"], list)
+                                 else sample["passages"]["passage_text"][0]
+                         })
+
+             self.current_dataset = dataset_name
+
+             # Return dataset info
+             return {
+                 "dataset": dataset_name,
+                 "num_samples": len(self.test_samples),
+                 "sample_questions": [s["question"] for s in self.test_samples[:3]],
+                 "status": "success"
+             }
+
+         except Exception as e:
+             print(f"Error loading dataset: {str(e)}")
+             return {
+                 "dataset": dataset_name,
+                 "error": str(e),
+                 "status": "failed"
+             }
+
+     def evaluate_configuration(self, vector_db, qa_chain, splitting_strategy: str, chunk_size: str) -> Dict:
+         """Evaluate with progress tracking and error handling"""
+         if not self.test_samples:
+             return {"error": "No dataset loaded"}
+
+         results = []
+         total_questions = len(self.test_samples)
+
+         # Add progress tracking
+         for i, sample in enumerate(self.test_samples):
+             print(f"Evaluating question {i+1}/{total_questions}")
+
+             try:
+                 response = qa_chain.invoke({
+                     "question": sample["question"],
+                     "chat_history": []
+                 })
+
+                 results.append({
+                     "question": sample["question"],
+                     "answer": response["answer"],
+                     "contexts": [doc.page_content for doc in response["source_documents"]],
+                     "ground_truths": [sample["ground_truth"]]
+                 })
+             except Exception as e:
+                 print(f"Error processing question {i+1}: {str(e)}")
+                 continue
+
+         if not results:
+             return {
+                 "configuration": f"{splitting_strategy}_{chunk_size}",
+                 "error": "No successful evaluations",
+                 "questions_evaluated": 0
+             }
+
+         try:
+             # Calculate RAGAS metrics
+             eval_dataset = Dataset.from_list(results)
+             metrics = [ContextRecall(), AnswerRelevancy(), Faithfulness(), ContextPrecision()]
+             scores = evaluate(eval_dataset, metrics=metrics)
+
+             return {
+                 "configuration": f"{splitting_strategy}_{chunk_size}",
+                 "questions_evaluated": len(results),
+                 "context_recall": float(scores['context_recall']),
+                 "answer_relevancy": float(scores['answer_relevancy']),
+                 "faithfulness": float(scores['faithfulness']),
+                 "context_precision": float(scores['context_precision']),
+                 "average_score": float(np.mean([
+                     scores['context_recall'],
+                     scores['answer_relevancy'],
+                     scores['faithfulness'],
+                     scores['context_precision']
+                 ]))
+             }
+         except Exception as e:
+             return {
+                 "configuration": f"{splitting_strategy}_{chunk_size}",
+                 "error": str(e),
+                 "questions_evaluated": len(results)
+             }
+
+ # Text splitting and database functions
def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
    splitters = {
        "recursive": RecursiveCharacterTextSplitter(

    }
    return splitters.get(strategy)

+ def load_doc(list_file_path: List[str], splitting_strategy: str, chunk_size: str):
+     chunk_size_value = CHUNK_SIZES[chunk_size][splitting_strategy]
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())

+     text_splitter = get_text_splitter(splitting_strategy, chunk_size_value)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits

def create_db(splits, db_choice: str = "faiss"):
    embeddings = HuggingFaceEmbeddings()
    db_creators = {
+         "faiss": lambda: FAISS.from_documents(splits, embeddings),
+         "chroma": lambda: Chroma.from_documents(splits, embeddings),
+         "qdrant": lambda: Qdrant.from_documents(
+             splits,
+             embeddings,
+             location=":memory:",
+             collection_name="pdf_docs"
        )
    }
+     return db_creators[db_choice]()
+
+ def initialize_database(list_file_obj, splitting_strategy, chunk_size, db_choice, progress=gr.Progress()):
+     list_file_path = [x.name for x in list_file_obj if x is not None]
+     doc_splits = load_doc(list_file_path, splitting_strategy, chunk_size)
+     vector_db = create_db(doc_splits, db_choice)
+     return vector_db, f"Database created using {splitting_strategy} splitting and {db_choice} vector database!"

def initialize_llmchain(llm_choice, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    llm_model = list_llm[llm_choice]

    llm = HuggingFaceEndpoint(
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
+         top_k=top_k
    )

    memory = ConversationBufferMemory(

    qa_chain = ConversationalRetrievalChain.from_llm(
        llm,
        retriever=retriever,
        memory=memory,
+         return_source_documents=True
    )
    return qa_chain, "LLM initialized successfully!"

def conversation(qa_chain, message, history):
+     """Fixed conversation function returning all required outputs"""
    response = qa_chain.invoke({
        "question": message,
+         "chat_history": [(hist[0], hist[1]) for hist in history]
    })

    response_answer = response["answer"]
+     if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]

+     # Get source documents, ensure we have exactly 3
+     sources = response["source_documents"][:3]
+     source_contents = []
+     source_pages = []

+     # Process available sources
+     for source in sources:
+         source_contents.append(source.page_content.strip())
+         source_pages.append(source.metadata.get("page", 0) + 1)

+     # Pad with empty values if we have fewer than 3 sources
+     while len(source_contents) < 3:
+         source_contents.append("")
+         source_pages.append(0)

+     # Return all required outputs in correct order
+     return (
+         qa_chain,                                # State
+         gr.update(value=""),                     # Clear message box
+         history + [(message, response_answer)],  # Updated chat history
+         source_contents[0],                      # First source
+         source_pages[0],                         # First page
+         source_contents[1],                      # Second source
+         source_pages[1],                         # Second page
+         source_contents[2],                      # Third source
+         source_pages[2]                          # Third page
+     )

def demo():
+     evaluator = RAGEvaluator()
+
    with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky")) as demo:
        vector_db = gr.State()
        qa_chain = gr.State()

+         gr.HTML("<center><h1>Enhanced RAG PDF Chatbot with Evaluation</h1></center>")

+         with gr.Tabs():
+             # Custom PDF Tab
+             with gr.Tab("Custom PDF Chat"):
                with gr.Row():
+                     with gr.Column(scale=86):
+                         gr.Markdown("<b>Step 1 - Configure and Initialize RAG Pipeline</b>")
+                         with gr.Row():
+                             document = gr.Files(
+                                 height=300,
+                                 file_count="multiple",
+                                 file_types=["pdf"],
+                                 interactive=True,
+                                 label="Upload PDF documents"
+                             )
+
+                         with gr.Row():
+                             splitting_strategy = gr.Radio(
+                                 ["recursive", "fixed", "token"],
+                                 label="Text Splitting Strategy",
+                                 value="recursive"
+                             )
+                             db_choice = gr.Radio(
+                                 ["faiss", "chroma", "qdrant"],
+                                 label="Vector Database",
+                                 value="faiss"
+                             )
+                             chunk_size = gr.Radio(
+                                 ["small", "medium"],
+                                 label="Chunk Size",
+                                 value="medium"
+                             )
+
+                         with gr.Row():
+                             db_btn = gr.Button("Create vector database")
+                             db_progress = gr.Textbox(
+                                 value="Not initialized",
+                                 show_label=False
+                             )
+
+                         gr.Markdown("<b>Step 2 - Configure LLM</b>")
+                         with gr.Row():
+                             llm_choice = gr.Radio(
+                                 list_llm_simple,
+                                 label="Available LLMs",
+                                 value=list_llm_simple[0],
+                                 type="index"
+                             )
+
+                         with gr.Row():
+                             with gr.Accordion("LLM Parameters", open=False):
+                                 temperature = gr.Slider(
+                                     minimum=0.01,
+                                     maximum=1.0,
+                                     value=0.5,
+                                     step=0.1,
+                                     label="Temperature"
+                                 )
+                                 max_tokens = gr.Slider(
+                                     minimum=128,
+                                     maximum=4096,
+                                     value=2048,
+                                     step=128,
+                                     label="Max Tokens"
+                                 )
+                                 top_k = gr.Slider(
+                                     minimum=1,
+                                     maximum=10,
+                                     value=3,
+                                     step=1,
+                                     label="Top K"
+                                 )
+
+                         with gr.Row():
+                             init_llm_btn = gr.Button("Initialize LLM")
+                             llm_progress = gr.Textbox(
+                                 value="Not initialized",
+                                 show_label=False
+                             )
+
+                     with gr.Column(scale=200):
+                         gr.Markdown("<b>Step 3 - Chat with Documents</b>")
+                         chatbot = gr.Chatbot(height=505)
+
+                         with gr.Accordion("Source References", open=False):
+                             with gr.Row():
+                                 source1 = gr.Textbox(label="Source 1", lines=2)
+                                 page1 = gr.Number(label="Page")
+                             with gr.Row():
+                                 source2 = gr.Textbox(label="Source 2", lines=2)
+                                 page2 = gr.Number(label="Page")
+                             with gr.Row():
+                                 source3 = gr.Textbox(label="Source 3", lines=2)
+                                 page3 = gr.Number(label="Page")
+
+                         with gr.Row():
+                             msg = gr.Textbox(
+                                 placeholder="Ask a question",
+                                 show_label=False
+                             )
+                         with gr.Row():
+                             submit_btn = gr.Button("Submit")
+                             clear_btn = gr.ClearButton(
+                                 [msg, chatbot],
+                                 value="Clear Chat"
+                             )
+
+             # Evaluation Tab
+             with gr.Tab("RAG Evaluation"):
+                 with gr.Row():
+                     dataset_choice = gr.Dropdown(
+                         choices=list(evaluator.datasets.keys()),
+                         label="Select Evaluation Dataset",
+                         value="squad"
+                     )
+                     load_dataset_btn = gr.Button("Load Dataset")

                with gr.Row():
+                     dataset_info = gr.JSON(label="Dataset Information")
+
+                 with gr.Row():
+                     eval_splitting_strategy = gr.Radio(
                        ["recursive", "fixed", "token"],
                        label="Text Splitting Strategy",
                        value="recursive"
                    )
+                     eval_chunk_size = gr.Radio(
+                         ["small", "medium"],
+                         label="Chunk Size",
+                         value="medium"
                    )

                with gr.Row():
+                     evaluate_btn = gr.Button("Run Evaluation")
+                     evaluation_results = gr.DataFrame(label="Evaluation Results")
+
        # Event handlers
        db_btn.click(
            initialize_database,
+             inputs=[document, splitting_strategy, chunk_size, db_choice],
            outputs=[vector_db, db_progress]
        )

+         init_llm_btn.click(
+             initialize_llmchain,
+             inputs=[llm_choice, temperature, max_tokens, top_k, vector_db],
            outputs=[qa_chain, llm_progress]
        )

+         msg.submit(
+             conversation,
            inputs=[qa_chain, msg, chatbot],
+             outputs=[qa_chain, msg, chatbot, source1, page1, source2, page2, source3, page3]
        )

+         submit_btn.click(
+             conversation,
            inputs=[qa_chain, msg, chatbot],
+             outputs=[qa_chain, msg, chatbot, source1, page1, source2, page2, source3, page3]
        )
+
+         def load_dataset_handler(dataset_name):
+             try:
+                 result = evaluator.load_dataset(dataset_name)
+                 if result.get("status") == "success":
+                     return {
+                         "dataset": result["dataset"],
+                         "samples_loaded": result["num_samples"],
+                         "example_questions": result["sample_questions"],
+                         "status": "ready for evaluation"
+                     }
+                 else:
+                     return {
+                         "error": result.get("error", "Unknown error occurred"),
+                         "status": "failed to load dataset"
+                     }
+             except Exception as e:
+                 return {
+                     "error": str(e),
+                     "status": "failed to load dataset"
+                 }

+         def run_evaluation(dataset_choice, splitting_strategy, chunk_size, vector_db, qa_chain):
+             if not evaluator.current_dataset:
+                 return pd.DataFrame()
+
+             results = evaluator.evaluate_configuration(
+                 vector_db=vector_db,
+                 qa_chain=qa_chain,
+                 splitting_strategy=splitting_strategy,
+                 chunk_size=chunk_size
+             )
+
+             return pd.DataFrame([results])
+
+         load_dataset_btn.click(
+             load_dataset_handler,
+             inputs=[dataset_choice],
+             outputs=[dataset_info]
+         )
+
+         evaluate_btn.click(
+             run_evaluation,
+             inputs=[
+                 dataset_choice,
+                 eval_splitting_strategy,
+                 eval_chunk_size,
+                 vector_db,
+                 qa_chain
+             ],
+             outputs=[evaluation_results]
+         )
+
+         # Clear button handlers
        clear_btn.click(
            lambda: [None, "", 0, "", 0, "", 0],
+             outputs=[chatbot, source1, page1, source2, page2, source3, page3]
        )
+
+     # Launch the demo
    demo.queue().launch(debug=True)

if __name__ == "__main__":