arjunanand13 committed on
Commit
8b077e8
1 Parent(s): e26ea16

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +225 -62
app.py CHANGED
@@ -8,7 +8,7 @@ from langchain.text_splitter import (
8
  CharacterTextSplitter,
9
  TokenTextSplitter
10
  )
11
- from langchain_community.vectorstores import FAISS, Chroma
12
  from langchain_community.document_loaders import PyPDFLoader
13
  from langchain.chains import ConversationalRetrievalChain
14
  from langchain_community.embeddings import HuggingFaceEmbeddings
@@ -25,12 +25,19 @@ from ragas.metrics import (
25
  )
26
  import pandas as pd
27
 
28
- # Constants and configurations
 
 
 
 
29
  CHUNK_SIZES = {
30
  "small": {"recursive": 512, "fixed": 512, "token": 256},
31
  "medium": {"recursive": 1024, "fixed": 1024, "token": 512}
32
  }
33
 
 
 
 
34
  class RAGEvaluator:
35
  def __init__(self):
36
  self.datasets = {
@@ -51,7 +58,7 @@ class RAGEvaluator:
51
  "context": sample["context"]
52
  }
53
  for sample in samples
54
- if sample["answers"]["text"] # Filter out samples without answers
55
  ]
56
  elif dataset_name == "msmarco":
57
  dataset = load_dataset("ms_marco", "v2.1", split="train")
@@ -63,16 +70,12 @@ class RAGEvaluator:
63
  "context": sample["passages"]["passage_text"][0]
64
  }
65
  for sample in samples
66
- if sample["answers"] # Filter out samples without answers
67
  ]
68
  self.current_dataset = dataset_name
69
  return self.test_samples
70
-
71
- def evaluate_configuration(self,
72
- vector_db,
73
- qa_chain,
74
- splitting_strategy: str,
75
- chunk_size: str) -> Dict:
76
  if not self.test_samples:
77
  return {"error": "No dataset loaded"}
78
 
@@ -90,21 +93,9 @@ class RAGEvaluator:
90
  "ground_truths": [sample["ground_truth"]]
91
  })
92
 
93
- # Convert to RAGAS dataset format
94
  eval_dataset = Dataset.from_list(results)
95
-
96
- # Calculate RAGAS metrics
97
- metrics = [
98
- ContextRecall(),
99
- AnswerRelevancy(),
100
- Faithfulness(),
101
- ContextPrecision()
102
- ]
103
-
104
- scores = evaluate(
105
- eval_dataset,
106
- metrics=metrics
107
- )
108
 
109
  return {
110
  "configuration": f"{splitting_strategy}_{chunk_size}",
@@ -120,6 +111,102 @@ class RAGEvaluator:
120
  ]))
121
  }
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  def demo():
124
  evaluator = RAGEvaluator()
125
 
@@ -132,12 +219,17 @@ def demo():
132
  with gr.Tabs():
133
  # Custom PDF Tab
134
  with gr.Tab("Custom PDF Chat"):
135
- # Your existing UI components here
136
  with gr.Row():
137
  with gr.Column(scale=86):
138
  gr.Markdown("<b>Step 1 - Configure and Initialize RAG Pipeline</b>")
139
  with gr.Row():
140
- document = gr.Files(height=300, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF documents")
 
 
 
 
 
 
141
 
142
  with gr.Row():
143
  splitting_strategy = gr.Radio(
@@ -145,8 +237,8 @@ def demo():
145
  label="Text Splitting Strategy",
146
  value="recursive"
147
  )
148
- db_choice = gr.Dropdown(
149
- ["faiss", "chroma"],
150
  label="Vector Database",
151
  value="faiss"
152
  )
@@ -156,7 +248,79 @@ def demo():
156
  value="medium"
157
  )
158
 
159
- # Rest of your existing UI components...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
161
  # Evaluation Tab
162
  with gr.Tab("RAG Evaluation"):
@@ -188,6 +352,30 @@ def demo():
188
  evaluation_results = gr.DataFrame(label="Evaluation Results")
189
 
190
  # Event handlers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  def load_dataset_handler(dataset_name):
192
  samples = evaluator.load_dataset(dataset_name)
193
  return {
@@ -199,7 +387,7 @@ def demo():
199
  def run_evaluation(dataset_choice, splitting_strategy, chunk_size, vector_db, qa_chain):
200
  if not evaluator.current_dataset:
201
  return pd.DataFrame()
202
-
203
  results = evaluator.evaluate_configuration(
204
  vector_db=vector_db,
205
  qa_chain=qa_chain,
@@ -207,11 +395,8 @@ def demo():
207
  chunk_size=chunk_size
208
  )
209
 
210
- # Convert results to DataFrame
211
- df = pd.DataFrame([results])
212
- return df
213
-
214
- # Connect event handlers
215
  load_dataset_btn.click(
216
  load_dataset_handler,
217
  inputs=[dataset_choice],
@@ -229,36 +414,14 @@ def demo():
229
  ],
230
  outputs=[evaluation_results]
231
  )
232
-
233
- qachain_btn.click(
234
- initialize_llmchain, # Fixed function name here
235
- inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
236
- outputs=[qa_chain, llm_progress]
237
- ).then(
238
- lambda: [None, "", 0, "", 0, "", 0],
239
- inputs=None,
240
- outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
241
- queue=False
242
- )
243
-
244
- msg.submit(conversation,
245
- inputs=[qa_chain, msg, chatbot],
246
- outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
247
- queue=False
248
- )
249
-
250
- submit_btn.click(conversation,
251
- inputs=[qa_chain, msg, chatbot],
252
- outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
253
- queue=False
254
- )
255
-
256
  clear_btn.click(
257
  lambda: [None, "", 0, "", 0, "", 0],
258
- inputs=None,
259
- outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
260
- queue=False
261
  )
 
 
262
  demo.queue().launch(debug=True)
263
 
264
  if __name__ == "__main__":
 
8
  CharacterTextSplitter,
9
  TokenTextSplitter
10
  )
11
+ from langchain_community.vectorstores import FAISS, Chroma, Qdrant
12
  from langchain_community.document_loaders import PyPDFLoader
13
  from langchain.chains import ConversationalRetrievalChain
14
  from langchain_community.embeddings import HuggingFaceEmbeddings
 
25
  )
26
  import pandas as pd
27
 
28
+ # Constants and setup
29
+ list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
30
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]
31
+ api_token = os.getenv("HF_TOKEN")
32
+
33
  CHUNK_SIZES = {
34
  "small": {"recursive": 512, "fixed": 512, "token": 256},
35
  "medium": {"recursive": 1024, "fixed": 1024, "token": 512}
36
  }
37
 
38
+ # Initialize sentence transformer for evaluation
39
+ sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
40
+
41
  class RAGEvaluator:
42
  def __init__(self):
43
  self.datasets = {
 
58
  "context": sample["context"]
59
  }
60
  for sample in samples
61
+ if sample["answers"]["text"]
62
  ]
63
  elif dataset_name == "msmarco":
64
  dataset = load_dataset("ms_marco", "v2.1", split="train")
 
70
  "context": sample["passages"]["passage_text"][0]
71
  }
72
  for sample in samples
73
+ if sample["answers"]
74
  ]
75
  self.current_dataset = dataset_name
76
  return self.test_samples
77
+
78
+ def evaluate_configuration(self, vector_db, qa_chain, splitting_strategy: str, chunk_size: str) -> Dict:
 
 
 
 
79
  if not self.test_samples:
80
  return {"error": "No dataset loaded"}
81
 
 
93
  "ground_truths": [sample["ground_truth"]]
94
  })
95
 
 
96
  eval_dataset = Dataset.from_list(results)
97
+ metrics = [ContextRecall(), AnswerRelevancy(), Faithfulness(), ContextPrecision()]
98
+ scores = evaluate(eval_dataset, metrics=metrics)
 
 
 
 
 
 
 
 
 
 
 
99
 
100
  return {
101
  "configuration": f"{splitting_strategy}_{chunk_size}",
 
111
  ]))
112
  }
113
 
114
+ # Text splitting and database functions
115
def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
    """Return a text splitter for the requested splitting strategy.

    Args:
        strategy: Name of the strategy: "recursive", "fixed" or "token".
        chunk_size: Chunk size forwarded to the splitter constructor.
        chunk_overlap: Chunk overlap forwarded to the splitter constructor.

    Returns:
        The configured splitter, or ``None`` when ``strategy`` is not one of
        the known names (callers must check for ``None``).
    """
    # Factories instead of instances: only the requested splitter is built,
    # so a failure constructing an unrelated splitter cannot break this call.
    splitter_factories = {
        "recursive": lambda: RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        ),
        "fixed": lambda: CharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        ),
        "token": lambda: TokenTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        ),
    }
    factory = splitter_factories.get(strategy)
    return factory() if factory is not None else None
131
+
132
def load_doc(list_file_path: List[str], splitting_strategy: str, chunk_size: str):
    """Load the given PDF files and split their pages into chunks.

    Args:
        list_file_path: Paths of the PDF files to load.
        splitting_strategy: Key selecting both the splitter and the column
            of ``CHUNK_SIZES`` to read.
        chunk_size: Size label used as the row key of ``CHUNK_SIZES``.

    Returns:
        The documents produced by the splitter's ``split_documents``.
    """
    # Resolve the numeric size first so a bad label fails before any file I/O.
    resolved_size = CHUNK_SIZES[chunk_size][splitting_strategy]

    # Build every loader up front, then read them in order.
    pdf_loaders = list(map(PyPDFLoader, list_file_path))
    all_pages = [page for pdf_loader in pdf_loaders for page in pdf_loader.load()]

    splitter = get_text_splitter(splitting_strategy, resolved_size)
    return splitter.split_documents(all_pages)
142
+
143
def create_db(splits, db_choice: str = "faiss"):
    """Build a vector store from document splits using the chosen backend.

    Args:
        splits: Documents to index.
        db_choice: Backend name: "faiss", "chroma" or "qdrant".

    Returns:
        The vector store returned by the backend's ``from_documents``.

    Raises:
        KeyError: If ``db_choice`` is not one of the supported names.
    """
    # The embedding model is created before the backend is chosen.
    embedding_model = HuggingFaceEmbeddings()

    if db_choice == "faiss":
        return FAISS.from_documents(splits, embedding_model)
    if db_choice == "chroma":
        return Chroma.from_documents(splits, embedding_model)
    if db_choice == "qdrant":
        return Qdrant.from_documents(
            splits,
            embedding_model,
            location=":memory:",
            collection_name="pdf_docs",
        )
    raise KeyError(db_choice)
156
+
157
def initialize_database(list_file_obj, splitting_strategy, chunk_size, db_choice, progress=gr.Progress()):
    """Create a vector database from the uploaded files.

    Args:
        list_file_obj: Uploaded file objects; each non-None entry must expose
            a ``name`` attribute holding its path. May be ``None`` or empty
            (presumably what the UI sends when nothing was uploaded; verify
            against the Gradio version in use).
        splitting_strategy: Text splitting strategy forwarded to ``load_doc``.
        chunk_size: Chunk size label forwarded to ``load_doc``.
        db_choice: Vector store backend forwarded to ``create_db``.
        progress: Gradio progress tracker (not used by this function).

    Returns:
        A ``(vector_db, status_message)`` tuple. ``vector_db`` is ``None``
        when no files were provided.
    """
    list_file_path = [x.name for x in (list_file_obj or []) if x is not None]

    # Without this guard a None upload raised TypeError, and an empty list
    # went on to build a database from zero documents.
    if not list_file_path:
        return None, "No PDF documents provided. Please upload at least one file."

    doc_splits = load_doc(list_file_path, splitting_strategy, chunk_size)
    vector_db = create_db(doc_splits, db_choice)
    return vector_db, f"Database created using {splitting_strategy} splitting and {db_choice} vector database!"
162
+
163
def initialize_llmchain(llm_choice, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    """Build a conversational retrieval chain over the given vector database.

    Args:
        llm_choice: Index into ``list_llm`` selecting the model repository.
        temperature: Sampling temperature passed to the endpoint.
        max_tokens: Passed to the endpoint as ``max_new_tokens``.
        top_k: Passed to the endpoint as ``top_k``.
        vector_db: Vector store; its ``as_retriever()`` supplies the retriever.
        progress: Gradio progress tracker (not used by this function).

    Returns:
        A ``(qa_chain, status_message)`` tuple.
    """
    endpoint = HuggingFaceEndpoint(
        repo_id=list_llm[llm_choice],
        huggingfacehub_api_token=api_token,
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )

    # output_key tells the memory which chain output to store, since the
    # chain also returns source documents.
    chat_memory = ConversationBufferMemory(
        memory_key="chat_history",
        output_key='answer',
        return_messages=True,
    )

    chain = ConversationalRetrievalChain.from_llm(
        endpoint,
        retriever=vector_db.as_retriever(),
        memory=chat_memory,
        return_source_documents=True,
    )
    return chain, "LLM initialized successfully!"
188
+
189
def conversation(qa_chain, message, history):
    """Run one chat turn through the QA chain and collect up to three sources.

    Args:
        qa_chain: Chain object exposing ``invoke``; its response must contain
            the keys "answer" and "source_documents".
        message: The user's question.
        history: Previous ``(user, assistant)`` turns; ``None`` is treated as
            an empty history.

    Returns:
        A 9-tuple matching the outputs bound by the UI event handlers:
        ``(qa_chain, textbox_update, new_history,
        source1, page1, source2, page2, source3, page3)``.
        Unused source slots are filled with ``""`` and page ``0``.
    """
    history = history or []
    response = qa_chain.invoke({
        "question": message,
        "chat_history": [(turn[0], turn[1]) for turn in history],
    })

    response_answer = response["answer"]
    # Keep only the text after the last "Helpful Answer:" marker, if present.
    if "Helpful Answer:" in response_answer:
        response_answer = response_answer.split("Helpful Answer:")[-1]

    sources = response["source_documents"][:3]
    source_contents = [s.page_content.strip() for s in sources]
    # Stored page numbers are shifted by one for display; missing ones show 1.
    source_pages = [s.metadata.get("page", 0) + 1 for s in sources]

    # Pad to exactly three slots. The old loop tested len(sources), which
    # never changed, so it never ended with fewer than three sources.
    missing = 3 - len(sources)
    source_contents.extend([""] * missing)
    source_pages.extend([0] * missing)

    new_history = history + [(message, response_answer)]
    (source1, source2, source3) = source_contents
    (page1, page2, page3) = source_pages

    # One value per bound output component, with sources and pages
    # interleaved, not folded into the chat history list.
    return (
        qa_chain,
        gr.update(value=""),
        new_history,
        source1, page1,
        source2, page2,
        source3, page3,
    )
209
+
210
  def demo():
211
  evaluator = RAGEvaluator()
212
 
 
219
  with gr.Tabs():
220
  # Custom PDF Tab
221
  with gr.Tab("Custom PDF Chat"):
 
222
  with gr.Row():
223
  with gr.Column(scale=86):
224
  gr.Markdown("<b>Step 1 - Configure and Initialize RAG Pipeline</b>")
225
  with gr.Row():
226
+ document = gr.Files(
227
+ height=300,
228
+ file_count="multiple",
229
+ file_types=["pdf"],
230
+ interactive=True,
231
+ label="Upload PDF documents"
232
+ )
233
 
234
  with gr.Row():
235
  splitting_strategy = gr.Radio(
 
237
  label="Text Splitting Strategy",
238
  value="recursive"
239
  )
240
+ db_choice = gr.Radio(
241
+ ["faiss", "chroma", "qdrant"],
242
  label="Vector Database",
243
  value="faiss"
244
  )
 
248
  value="medium"
249
  )
250
 
251
+ with gr.Row():
252
+ db_btn = gr.Button("Create vector database")
253
+ db_progress = gr.Textbox(
254
+ value="Not initialized",
255
+ show_label=False
256
+ )
257
+
258
+ gr.Markdown("<b>Step 2 - Configure LLM</b>")
259
+ with gr.Row():
260
+ llm_choice = gr.Radio(
261
+ list_llm_simple,
262
+ label="Available LLMs",
263
+ value=list_llm_simple[0],
264
+ type="index"
265
+ )
266
+
267
+ with gr.Row():
268
+ with gr.Accordion("LLM Parameters", open=False):
269
+ temperature = gr.Slider(
270
+ minimum=0.01,
271
+ maximum=1.0,
272
+ value=0.5,
273
+ step=0.1,
274
+ label="Temperature"
275
+ )
276
+ max_tokens = gr.Slider(
277
+ minimum=128,
278
+ maximum=4096,
279
+ value=2048,
280
+ step=128,
281
+ label="Max Tokens"
282
+ )
283
+ top_k = gr.Slider(
284
+ minimum=1,
285
+ maximum=10,
286
+ value=3,
287
+ step=1,
288
+ label="Top K"
289
+ )
290
+
291
+ with gr.Row():
292
+ init_llm_btn = gr.Button("Initialize LLM")
293
+ llm_progress = gr.Textbox(
294
+ value="Not initialized",
295
+ show_label=False
296
+ )
297
+
298
+ with gr.Column(scale=200):
299
+ gr.Markdown("<b>Step 3 - Chat with Documents</b>")
300
+ chatbot = gr.Chatbot(height=505)
301
+
302
+ with gr.Accordion("Source References", open=False):
303
+ with gr.Row():
304
+ source1 = gr.Textbox(label="Source 1", lines=2)
305
+ page1 = gr.Number(label="Page")
306
+ with gr.Row():
307
+ source2 = gr.Textbox(label="Source 2", lines=2)
308
+ page2 = gr.Number(label="Page")
309
+ with gr.Row():
310
+ source3 = gr.Textbox(label="Source 3", lines=2)
311
+ page3 = gr.Number(label="Page")
312
+
313
+ with gr.Row():
314
+ msg = gr.Textbox(
315
+ placeholder="Ask a question",
316
+ show_label=False
317
+ )
318
+ with gr.Row():
319
+ submit_btn = gr.Button("Submit")
320
+ clear_btn = gr.ClearButton(
321
+ [msg, chatbot],
322
+ value="Clear Chat"
323
+ )
324
 
325
  # Evaluation Tab
326
  with gr.Tab("RAG Evaluation"):
 
352
  evaluation_results = gr.DataFrame(label="Evaluation Results")
353
 
354
  # Event handlers
355
+ db_btn.click(
356
+ initialize_database,
357
+ inputs=[document, splitting_strategy, chunk_size, db_choice],
358
+ outputs=[vector_db, db_progress]
359
+ )
360
+
361
+ init_llm_btn.click(
362
+ initialize_llmchain,
363
+ inputs=[llm_choice, temperature, max_tokens, top_k, vector_db],
364
+ outputs=[qa_chain, llm_progress]
365
+ )
366
+
367
+ msg.submit(
368
+ conversation,
369
+ inputs=[qa_chain, msg, chatbot],
370
+ outputs=[qa_chain, msg, chatbot, source1, page1, source2, page2, source3, page3]
371
+ )
372
+
373
+ submit_btn.click(
374
+ conversation,
375
+ inputs=[qa_chain, msg, chatbot],
376
+ outputs=[qa_chain, msg, chatbot, source1, page1, source2, page2, source3, page3]
377
+ )
378
+
379
  def load_dataset_handler(dataset_name):
380
  samples = evaluator.load_dataset(dataset_name)
381
  return {
 
387
  def run_evaluation(dataset_choice, splitting_strategy, chunk_size, vector_db, qa_chain):
388
  if not evaluator.current_dataset:
389
  return pd.DataFrame()
390
+
391
  results = evaluator.evaluate_configuration(
392
  vector_db=vector_db,
393
  qa_chain=qa_chain,
 
395
  chunk_size=chunk_size
396
  )
397
 
398
+ return pd.DataFrame([results])
399
+
 
 
 
400
  load_dataset_btn.click(
401
  load_dataset_handler,
402
  inputs=[dataset_choice],
 
414
  ],
415
  outputs=[evaluation_results]
416
  )
417
+
418
+ # Clear button handlers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  clear_btn.click(
420
  lambda: [None, "", 0, "", 0, "", 0],
421
+ outputs=[chatbot, source1, page1, source2, page2, source3, page3]
 
 
422
  )
423
+
424
+ # Launch the demo
425
  demo.queue().launch(debug=True)
426
 
427
  if __name__ == "__main__":