arjunanand13 committed
Commit 1596101
1 Parent(s): c64a83f

Update app.py

Files changed (1):
  1. app.py +67 -47
app.py CHANGED
@@ -1,12 +1,12 @@
 import gradio as gr
 import os
 from typing import List, Dict
-import ragas
+from ragas import evaluate
 from ragas.metrics import (
-    context_relevancy,
-    faithfulness,
-    answer_relevancy,
-    context_recall
+    ContextRecall,
+    ContextRelevancy,
+    Faithfulness,
+    AnswerRelevancy
 )
 from datasets import load_dataset
 from langchain.text_splitter import (
@@ -81,7 +81,7 @@ def create_db(splits, db_choice: str = "faiss"):
     }
     return db_creators[db_choice](splits, embeddings)

-# Evaluation functions
+# Updated evaluation functions
 def load_evaluation_dataset():
     # Load example dataset from RAGAS
     dataset = load_dataset("explodinggradients/fiqa", split="test")
@@ -91,16 +91,10 @@ def evaluate_rag_pipeline(qa_chain, dataset):
     # Sample a few examples for evaluation
     eval_samples = dataset.select(range(5))

-    results = {
-        "context_relevancy": [],
-        "faithfulness": [],
-        "answer_relevancy": [],
-        "context_recall": []
-    }
-
+    # Prepare data for RAGAS evaluation
+    eval_data = []
     for sample in eval_samples:
         question = sample["question"]
-        ground_truth = sample["answer"]

         # Get response from the chain
         response = qa_chain.invoke({
@@ -108,40 +102,34 @@ def evaluate_rag_pipeline(qa_chain, dataset):
             "chat_history": []
         })

-        # Evaluate using RAGAS metrics
-        metrics = {
-            "context_relevancy": context_relevancy.score(
-                question=question,
-                answer=response["answer"],
-                contexts=response["source_documents"]
-            ),
-            "faithfulness": faithfulness.score(
-                question=question,
-                answer=response["answer"],
-                contexts=response["source_documents"]
-            ),
-            "answer_relevancy": answer_relevancy.score(
-                question=question,
-                answer=response["answer"]
-            ),
-            "context_recall": context_recall.score(
-                question=question,
-                answer=response["answer"],
-                contexts=response["source_documents"],
-                ground_truth=ground_truth
-            )
-        }
-
-        for metric, score in metrics.items():
-            results[metric].append(score)
+        eval_data.append({
+            "question": question,
+            "answer": response["answer"],
+            "ground_truth": sample["answer"],
+            "contexts": [doc.page_content for doc in response["source_documents"]]
+        })

-    # Calculate average scores
-    avg_results = {
-        metric: sum(scores) / len(scores)
-        for metric, scores in results.items()
-    }
+    # Initialize RAGAS metrics
+    metrics = [
+        ContextRecall(),
+        ContextRelevancy(),
+        Faithfulness(),
+        AnswerRelevancy()
+    ]
+
+    # Run evaluation
+    results = evaluate(
+        eval_data,
+        metrics=metrics
+    )

-    return avg_results
+    # Convert results to dictionary
+    return {
+        "context_recall": float(results["context_recall"]),
+        "context_relevancy": float(results["context_relevancy"]),
+        "faithfulness": float(results["faithfulness"]),
+        "answer_relevancy": float(results["answer_relevancy"])
+    }

 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
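A note on the evaluation changes above: in the ragas releases available around this commit, evaluate() is documented to take a datasets.Dataset (with question, answer, contexts and ground_truth columns) rather than a plain Python list, and the metrics are commonly passed as the lowercase instances exported by ragas.metrics. The sketch below is illustrative only and not part of app.py: run_ragas_evaluation is a hypothetical helper showing how the eval_data list built in evaluate_rag_pipeline could be wrapped, assuming ragas 0.1.x and an LLM/embeddings backend configured for scoring.

# Hypothetical helper, not part of this commit: wraps eval_data in a
# datasets.Dataset before scoring, assuming ragas 0.1.x.
from datasets import Dataset
from ragas import evaluate
from ragas.metrics import faithfulness, answer_relevancy, context_recall

def run_ragas_evaluation(eval_data):
    # eval_data: list of dicts with "question", "answer", "contexts" and
    # "ground_truth" keys, as built in evaluate_rag_pipeline above.
    eval_dataset = Dataset.from_list(eval_data)
    # evaluate() scores each metric over the dataset; it needs an LLM and
    # embeddings backend configured (OpenAI by default in ragas 0.1.x).
    results = evaluate(
        eval_dataset,
        metrics=[faithfulness, answer_relevancy, context_recall]
    )
    # The returned result behaves like a dict of averaged scores.
    return {
        "faithfulness": float(results["faithfulness"]),
        "answer_relevancy": float(results["answer_relevancy"]),
        "context_recall": float(results["context_recall"])
    }

With a chain already built by the app, the committed helpers are then exercised as dataset = load_evaluation_dataset() followed by scores = evaluate_rag_pipeline(qa_chain, dataset).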
 
@@ -186,6 +174,39 @@ def initialize_database(list_file_obj, splitting_strategy, db_choice, progress=g
     vector_db = create_db(doc_splits, db_choice)
     return vector_db, f"Database created using {splitting_strategy} splitting and {db_choice} vector database!"

+# Formatting chat history
+def format_chat_history(message, chat_history):
+    formatted_chat_history = []
+    for user_message, bot_message in chat_history:
+        formatted_chat_history.append(f"User: {user_message}")
+        formatted_chat_history.append(f"Assistant: {bot_message}")
+    return formatted_chat_history
+
+# Conversation function
+def conversation(qa_chain, message, history):
+    formatted_chat_history = format_chat_history(message, history)
+    response = qa_chain.invoke({
+        "question": message,
+        "chat_history": formatted_chat_history
+    })
+
+    response_answer = response["answer"]
+    if response_answer.find("Helpful Answer:") != -1:
+        response_answer = response_answer.split("Helpful Answer:")[-1]
+
+    response_sources = response["source_documents"]
+    response_source1 = response_sources[0].page_content.strip()
+    response_source2 = response_sources[1].page_content.strip()
+    response_source3 = response_sources[2].page_content.strip()
+
+    response_source1_page = response_sources[0].metadata["page"] + 1
+    response_source2_page = response_sources[1].metadata["page"] + 1
+    response_source3_page = response_sources[2].metadata["page"] + 1
+
+    new_history = history + [(message, response_answer)]
+
+    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
+
 def demo():
     with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky")) as demo:
         vector_db = gr.State()
@@ -279,7 +300,6 @@ def demo():
            queue=False
         )

-        # Chatbot event handlers remain the same
         msg.submit(conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
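A quick illustration of the new chat helpers (the example strings below are made up): format_chat_history flattens Gradio's (user, bot) history tuples into tagged strings for the chain, and the nine values returned by conversation line up one-to-one with the outputs list wired into msg.submit in the last hunk.

# Illustrative call, not part of app.py:
history = [("What is RAG?", "Retrieval-augmented generation.")]
format_chat_history("next question", history)
# -> ["User: What is RAG?", "Assistant: Retrieval-augmented generation."]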