Update app.py: replace the RAGAS-based evaluation with custom sentence-transformers similarity metrics and add a token-based text splitting option
app.py CHANGED
@@ -1,20 +1,12 @@
 import gradio as gr
 import os
 from typing import List, Dict
-
-from ragas.metrics import (
-    answer_relevancy,
-    faithfulness,
-    context_recall,
-    context_precision,
-    answer_correctness,
-    answer_similarity
-)
+import numpy as np
 from datasets import load_dataset
 from langchain.text_splitter import (
     RecursiveCharacterTextSplitter,
     CharacterTextSplitter,
-
+    TokenTextSplitter
 )
 from langchain_community.vectorstores import FAISS, Chroma, Qdrant
 from langchain_community.document_loaders import PyPDFLoader
@@ -22,13 +14,17 @@ from langchain.chains import ConversationalRetrievalChain
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceEndpoint
 from langchain.memory import ConversationBufferMemory
+from sentence_transformers import SentenceTransformer, util
 import torch

-# Constants
+# Constants and setup
 list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 api_token = os.getenv("HF_TOKEN")

+# Initialize sentence transformer for evaluation
+sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+
 # Text splitting strategies
 def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
     splitters = {
@@ -40,14 +36,38 @@ def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap
         ),
-        "
-            embedding_function=HuggingFaceEmbeddings().embed_query,
+        "token": TokenTextSplitter(
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap
         )
     }
     return splitters.get(strategy)

+# Custom evaluation metrics
+def calculate_semantic_similarity(text1: str, text2: str) -> float:
+    embeddings1 = sentence_model.encode([text1], convert_to_tensor=True)
+    embeddings2 = sentence_model.encode([text2], convert_to_tensor=True)
+    similarity = util.pytorch_cos_sim(embeddings1, embeddings2)
+    return float(similarity[0][0])
+
+def evaluate_response(question: str, answer: str, ground_truth: str, contexts: List[str]) -> Dict[str, float]:
+    # Answer similarity with ground truth
+    answer_similarity = calculate_semantic_similarity(answer, ground_truth)
+
+    # Context relevance - average similarity between question and contexts
+    context_scores = [calculate_semantic_similarity(question, ctx) for ctx in contexts]
+    context_relevance = np.mean(context_scores)
+
+    # Answer relevance - similarity between question and answer
+    answer_relevance = calculate_semantic_similarity(question, answer)
+
+    return {
+        "answer_similarity": answer_similarity,
+        "context_relevance": context_relevance,
+        "answer_relevance": answer_relevance,
+        "average_score": np.mean([answer_similarity, context_relevance, answer_relevance])
+    }
+
 # Load and split PDF document
 def load_doc(list_file_path: List[str], splitting_strategy: str = "recursive"):
     loaders = [PyPDFLoader(x) for x in list_file_path]
@@ -83,17 +103,15 @@ def create_db(splits, db_choice: str = "faiss"):
     }
     return db_creators[db_choice](splits, embeddings)

-# Updated evaluation functions
 def load_evaluation_dataset():
-    # Load example dataset from RAGAS
     dataset = load_dataset("explodinggradients/fiqa", split="test")
     return dataset

-def prepare_ragas_dataset(qa_chain, dataset):
+def evaluate_rag_pipeline(qa_chain, dataset):
     # Sample a few examples for evaluation
     eval_samples = dataset.select(range(5))

-
+    results = []
     for sample in eval_samples:
         question = sample["question"]

@@ -103,40 +121,23 @@ def prepare_ragas_dataset(qa_chain, dataset):
             "chat_history": []
         })

-
-
-
-
-
-
-
-
-
-def evaluate_rag_pipeline(qa_chain, dataset):
-    ragas_dataset = prepare_ragas_dataset(qa_chain, dataset)
-
-    # Run RAGAS evaluation
-    results = evaluate(
-        ragas_dataset,
-        metrics=[
-            context_precision,
-            faithfulness,
-            answer_relevancy,
-            context_recall,
-            answer_correctness,
-            answer_similarity
-        ]
-    )
+        # Evaluate response
+        eval_result = evaluate_response(
+            question=question,
+            answer=response["answer"],
+            ground_truth=sample["answer"],
+            contexts=[doc.page_content for doc in response["source_documents"]]
+        )
+
+        results.append(eval_result)

-    #
-
-
-
-        "answer_relevancy": float(results["answer_relevancy"]),
-        "context_recall": float(results["context_recall"]),
-        "answer_correctness": float(results["answer_correctness"]),
-        "answer_similarity": float(results["answer_similarity"])
+    # Calculate average scores across all samples
+    avg_results = {
+        metric: float(np.mean([r[metric] for r in results]))
+        for metric in results[0].keys()
     }
+
+    return avg_results

 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
@@ -174,14 +175,12 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     )
     return qa_chain

-# Initialize database with chunking strategy and vector DB choice
 def initialize_database(list_file_obj, splitting_strategy, db_choice, progress=gr.Progress()):
     list_file_path = [x.name for x in list_file_obj if x is not None]
     doc_splits = load_doc(list_file_path, splitting_strategy)
     vector_db = create_db(doc_splits, db_choice)
     return vector_db, f"Database created using {splitting_strategy} splitting and {db_choice} vector database!"

-# Formatting chat history
 def format_chat_history(message, chat_history):
     formatted_chat_history = []
     for user_message, bot_message in chat_history:
@@ -189,7 +188,6 @@ def format_chat_history(message, chat_history):
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history

-# Conversation function
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
     response = qa_chain.invoke({
@@ -230,7 +228,7 @@ def demo():

         with gr.Row():
             splitting_strategy = gr.Radio(
-                ["recursive", "fixed", "
+                ["recursive", "fixed", "token"],
                 label="Text Splitting Strategy",
                 value="recursive"
             )
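For anyone who wants to sanity-check the new scoring outside the Space, here is a minimal sketch of the cosine-similarity metrics this commit introduces. It assumes sentence-transformers and numpy are installed and reuses the same all-MiniLM-L6-v2 checkpoint; the sample question, answer and contexts are made-up placeholders, and the helper below mirrors (rather than imports) the calculate_semantic_similarity / evaluate_response logic from app.py.

from sentence_transformers import SentenceTransformer, util
import numpy as np

# Same checkpoint as app.py; downloading it requires network access.
model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

def cos_sim(text1: str, text2: str) -> float:
    # Encode both strings and return their cosine similarity as a plain float,
    # mirroring calculate_semantic_similarity in app.py.
    emb1 = model.encode([text1], convert_to_tensor=True)
    emb2 = model.encode([text2], convert_to_tensor=True)
    return float(util.pytorch_cos_sim(emb1, emb2)[0][0])

# Placeholder inputs standing in for one FiQA sample and its retrieved contexts.
question = "How can I reduce risk in my portfolio?"
answer = "Spreading money across different asset classes lowers overall risk."
ground_truth = "Diversifying across asset classes reduces portfolio risk."
contexts = ["Diversification spreads risk across stocks, bonds and cash."]

scores = {
    "answer_similarity": cos_sim(answer, ground_truth),
    "context_relevance": float(np.mean([cos_sim(question, c) for c in contexts])),
    "answer_relevance": cos_sim(question, answer),
}
scores["average_score"] = float(np.mean(list(scores.values())))
print(scores)  # four scores, higher means closer embeddings

Because all three metrics are plain embedding similarities, they track topical overlap rather than factual correctness, which is the main trade-off relative to the RAGAS metrics this commit removes.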
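The new evaluate_rag_pipeline relies on a small input contract that is mostly visible in the diff: the chain's invoke must return a dict with "answer" and "source_documents" (documents exposing page_content), and the dataset must support select and yield rows with "question" and "answer" keys. Below is a hedged sketch that captures that contract with stub objects; the stub classes and row values are illustrative only, the stub additionally assumes the invoke payload carries a "question" key (that call sits in an unchanged part of app.py not shown in this diff), and the commented import at the end assumes app.py can be imported without launching the Gradio UI.

from dataclasses import dataclass
from typing import Dict, List

@dataclass
class StubDoc:
    # Mimics a LangChain Document: only page_content is read by the evaluation.
    page_content: str

class StubChain:
    # Mimics ConversationalRetrievalChain.invoke: returns an answer plus sources.
    def invoke(self, inputs: Dict) -> Dict:
        return {
            "answer": f"Stub answer to: {inputs['question']}",
            "source_documents": [StubDoc("a retrieved passage about the topic")],
        }

class StubDataset:
    # Mimics the two pieces of the datasets API that the evaluation uses.
    def __init__(self, rows: List[Dict]):
        self.rows = rows
    def select(self, indices):
        return StubDataset([self.rows[i] for i in indices])
    def __iter__(self):
        return iter(self.rows)

rows = [{"question": f"question {i}", "answer": f"reference answer {i}"} for i in range(5)]

# To run against the real implementation (downloads the MiniLM model on import):
# from app import evaluate_rag_pipeline
# print(evaluate_rag_pipeline(StubChain(), StubDataset(rows)))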