arjunanand13 committed
Commit 6cec587 · verified · 1 Parent(s): bcb57b7

Update app.py

Files changed (1):
  1. app.py +51 -53

app.py CHANGED
@@ -1,20 +1,12 @@
 import gradio as gr
 import os
 from typing import List, Dict
-from ragas import evaluate
-from ragas.metrics import (
-    answer_relevancy,
-    faithfulness,
-    context_recall,
-    context_precision,
-    answer_correctness,
-    answer_similarity
-)
+import numpy as np
 from datasets import load_dataset
 from langchain.text_splitter import (
     RecursiveCharacterTextSplitter,
     CharacterTextSplitter,
-    SemanticTextSplitter
+    TokenTextSplitter
 )
 from langchain_community.vectorstores import FAISS, Chroma, Qdrant
 from langchain_community.document_loaders import PyPDFLoader
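
Note: the removed SemanticTextSplitter is not something langchain.text_splitter actually exports (the comparable splitter, SemanticChunker, lives in langchain_experimental), so the old import raised ImportError before the app could start. TokenTextSplitter does exist and splits on tiktoken token counts rather than characters. A minimal sketch of the replacement splitter, outside this commit and assuming tiktoken is installed:

    from langchain.text_splitter import TokenTextSplitter

    # Same defaults the app passes; chunk sizes are in tokens, not characters.
    splitter = TokenTextSplitter(chunk_size=1024, chunk_overlap=64)
    chunks = splitter.split_text("long document text " * 500)
    print(len(chunks))
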
@@ -22,13 +14,17 @@ from langchain.chains import ConversationalRetrievalChain
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceEndpoint
 from langchain.memory import ConversationBufferMemory
+from sentence_transformers import SentenceTransformer, util
 import torch
 
-# Constants
+# Constants and setup
 list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 api_token = os.getenv("HF_TOKEN")
 
+# Initialize sentence transformer for evaluation
+sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+
 # Text splitting strategies
 def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
     splitters = {
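
Note: the SentenceTransformer is created once at module level, so every evaluation call reuses the same loaded weights. A quick standalone check of the model this commit pins (all-MiniLM-L6-v2 emits 384-dimensional embeddings):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
    emb = model.encode(["hello world"], convert_to_tensor=True)
    print(emb.shape)  # torch.Size([1, 384])
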
@@ -40,14 +36,38 @@ def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap
         ),
-        "semantic": SemanticTextSplitter(
-            embedding_function=HuggingFaceEmbeddings().embed_query,
+        "token": TokenTextSplitter(
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap
         )
     }
     return splitters.get(strategy)
 
+# Custom evaluation metrics
+def calculate_semantic_similarity(text1: str, text2: str) -> float:
+    embeddings1 = sentence_model.encode([text1], convert_to_tensor=True)
+    embeddings2 = sentence_model.encode([text2], convert_to_tensor=True)
+    similarity = util.pytorch_cos_sim(embeddings1, embeddings2)
+    return float(similarity[0][0])
+
+def evaluate_response(question: str, answer: str, ground_truth: str, contexts: List[str]) -> Dict[str, float]:
+    # Answer similarity with ground truth
+    answer_similarity = calculate_semantic_similarity(answer, ground_truth)
+
+    # Context relevance - average similarity between question and contexts
+    context_scores = [calculate_semantic_similarity(question, ctx) for ctx in contexts]
+    context_relevance = np.mean(context_scores)
+
+    # Answer relevance - similarity between question and answer
+    answer_relevance = calculate_semantic_similarity(question, answer)
+
+    return {
+        "answer_similarity": answer_similarity,
+        "context_relevance": context_relevance,
+        "answer_relevance": answer_relevance,
+        "average_score": np.mean([answer_similarity, context_relevance, answer_relevance])
+    }
+
 # Load and split PDF document
 def load_doc(list_file_path: List[str], splitting_strategy: str = "recursive"):
     loaders = [PyPDFLoader(x) for x in list_file_path]
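
Note: every metric above is a raw cosine similarity from util.pytorch_cos_sim, so individual scores fall roughly in [-1, 1] (in practice near [0, 1] for related English text) rather than the calibrated 0-1 scales RAGAS reported; they are comparable across runs of this evaluator, not against RAGAS numbers. A standalone sketch mirroring the answer_similarity metric on hypothetical inputs:

    from sentence_transformers import SentenceTransformer, util

    model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    def similarity(a: str, b: str) -> float:
        # Mirrors calculate_semantic_similarity in app.py.
        ea = model.encode([a], convert_to_tensor=True)
        eb = model.encode([b], convert_to_tensor=True)
        return float(util.pytorch_cos_sim(ea, eb)[0][0])

    print(similarity(
        "A stock split divides each existing share into multiple shares.",
        "A stock split increases the number of shares outstanding.",
    ))
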
@@ -83,17 +103,15 @@ def create_db(splits, db_choice: str = "faiss"):
     }
     return db_creators[db_choice](splits, embeddings)
 
-# Updated evaluation functions
 def load_evaluation_dataset():
-    # Load example dataset from RAGAS
     dataset = load_dataset("explodinggradients/fiqa", split="test")
     return dataset
 
-def prepare_ragas_dataset(qa_chain, dataset):
+def evaluate_rag_pipeline(qa_chain, dataset):
     # Sample a few examples for evaluation
     eval_samples = dataset.select(range(5))
 
-    ragas_dataset = []
+    results = []
     for sample in eval_samples:
         question = sample["question"]
 
@@ -103,40 +121,23 @@ def prepare_ragas_dataset(qa_chain, dataset):
             "chat_history": []
         })
 
-        ragas_dataset.append({
-            "question": question,
-            "answer": response["answer"],
-            "contexts": [doc.page_content for doc in response["source_documents"]],
-            "ground_truth": sample["answer"]
-        })
-
-    return ragas_dataset
-
-def evaluate_rag_pipeline(qa_chain, dataset):
-    ragas_dataset = prepare_ragas_dataset(qa_chain, dataset)
-
-    # Run RAGAS evaluation
-    results = evaluate(
-        ragas_dataset,
-        metrics=[
-            context_precision,
-            faithfulness,
-            answer_relevancy,
-            context_recall,
-            answer_correctness,
-            answer_similarity
-        ]
-    )
+        # Evaluate response
+        eval_result = evaluate_response(
+            question=question,
+            answer=response["answer"],
+            ground_truth=sample["answer"],
+            contexts=[doc.page_content for doc in response["source_documents"]]
+        )
+
+        results.append(eval_result)
 
-    # Convert results to a dictionary
-    return {
-        "context_precision": float(results["context_precision"]),
-        "faithfulness": float(results["faithfulness"]),
-        "answer_relevancy": float(results["answer_relevancy"]),
-        "context_recall": float(results["context_recall"]),
-        "answer_correctness": float(results["answer_correctness"]),
-        "answer_similarity": float(results["answer_similarity"])
+    # Calculate average scores across all samples
+    avg_results = {
+        metric: float(np.mean([r[metric] for r in results]))
+        for metric in results[0].keys()
     }
+
+    return avg_results
 
 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
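
Note: the aggregation keys the dict comprehension off results[0], which assumes at least one sample was evaluated; dataset.select(range(5)) guarantees that as long as the fiqa test split is non-empty. A toy illustration of the averaging step with made-up per-sample scores:

    import numpy as np

    results = [
        {"answer_similarity": 0.8, "context_relevance": 0.6},
        {"answer_similarity": 0.6, "context_relevance": 0.4},
    ]
    avg = {m: float(np.mean([r[m] for r in results])) for m in results[0]}
    print(avg)  # each metric averaged across the samples
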
@@ -174,14 +175,12 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     )
     return qa_chain
 
-# Initialize database with chunking strategy and vector DB choice
 def initialize_database(list_file_obj, splitting_strategy, db_choice, progress=gr.Progress()):
     list_file_path = [x.name for x in list_file_obj if x is not None]
     doc_splits = load_doc(list_file_path, splitting_strategy)
     vector_db = create_db(doc_splits, db_choice)
     return vector_db, f"Database created using {splitting_strategy} splitting and {db_choice} vector database!"
 
-# Formatting chat history
 def format_chat_history(message, chat_history):
     formatted_chat_history = []
     for user_message, bot_message in chat_history:
@@ -189,7 +188,6 @@ def format_chat_history(message, chat_history):
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
 
-# Conversation function
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
     response = qa_chain.invoke({
@@ -230,7 +228,7 @@ def demo():
 
         with gr.Row():
             splitting_strategy = gr.Radio(
-                ["recursive", "fixed", "semantic"],
+                ["recursive", "fixed", "token"],
                 label="Text Splitting Strategy",
                 value="recursive"
            )
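
Note: the radio choices must stay in sync with the keys of the splitters dict in get_text_splitter, since splitters.get(strategy) silently returns None for an unknown name. A defensive variant (hypothetical, not part of this commit) would fail loudly instead:

    def get_text_splitter_strict(strategy: str, **kwargs):
        # Wraps app.py's get_text_splitter and rejects unknown strategies.
        splitter = get_text_splitter(strategy, **kwargs)
        if splitter is None:
            raise ValueError(f"Unknown splitting strategy: {strategy!r}")
        return splitter
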
 