Update app.py
app.py CHANGED

@@ -1,20 +1,12 @@
 import gradio as gr
 import os
 from typing import List, Dict
-
-from ragas.metrics import (
-    answer_relevancy,
-    faithfulness,
-    context_recall,
-    context_precision,
-    answer_correctness,
-    answer_similarity
-)
 from datasets import load_dataset
 from langchain.text_splitter import (
     RecursiveCharacterTextSplitter,
     CharacterTextSplitter,
-
 )
 from langchain_community.vectorstores import FAISS, Chroma, Qdrant
 from langchain_community.document_loaders import PyPDFLoader

@@ -22,13 +14,17 @@ from langchain.chains import ConversationalRetrievalChain
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceEndpoint
 from langchain.memory import ConversationBufferMemory
 import torch
 
-# Constants
 list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 api_token = os.getenv("HF_TOKEN")
 
 # Text splitting strategies
 def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
     splitters = {

@@ -40,14 +36,38 @@ def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap
         ),
-        "
-            embedding_function=HuggingFaceEmbeddings().embed_query,
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap
         )
     }
     return splitters.get(strategy)
 
 # Load and split PDF document
 def load_doc(list_file_path: List[str], splitting_strategy: str = "recursive"):
     loaders = [PyPDFLoader(x) for x in list_file_path]

@@ -83,17 +103,15 @@ def create_db(splits, db_choice: str = "faiss"):
     }
     return db_creators[db_choice](splits, embeddings)
 
-# Updated evaluation functions
 def load_evaluation_dataset():
-    # Load example dataset from RAGAS
     dataset = load_dataset("explodinggradients/fiqa", split="test")
     return dataset
 
-def prepare_ragas_dataset(qa_chain, dataset):
     # Sample a few examples for evaluation
     eval_samples = dataset.select(range(5))
 
-
     for sample in eval_samples:
         question = sample["question"]
 

@@ -103,40 +121,23 @@ def prepare_ragas_dataset(qa_chain, dataset):
             "chat_history": []
         })
 
-
-
-
-
-
-
-
-
-
-def evaluate_rag_pipeline(qa_chain, dataset):
-    ragas_dataset = prepare_ragas_dataset(qa_chain, dataset)
-
-    # Run RAGAS evaluation
-    results = evaluate(
-        ragas_dataset,
-        metrics=[
-            context_precision,
-            faithfulness,
-            answer_relevancy,
-            context_recall,
-            answer_correctness,
-            answer_similarity
-        ]
-    )
 
-    #
-
-
-
-        "answer_relevancy": float(results["answer_relevancy"]),
-        "context_recall": float(results["context_recall"]),
-        "answer_correctness": float(results["answer_correctness"]),
-        "answer_similarity": float(results["answer_similarity"])
     }
 
 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):

@@ -174,14 +175,12 @@ def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, pr
     )
     return qa_chain
 
-# Initialize database with chunking strategy and vector DB choice
 def initialize_database(list_file_obj, splitting_strategy, db_choice, progress=gr.Progress()):
     list_file_path = [x.name for x in list_file_obj if x is not None]
     doc_splits = load_doc(list_file_path, splitting_strategy)
     vector_db = create_db(doc_splits, db_choice)
     return vector_db, f"Database created using {splitting_strategy} splitting and {db_choice} vector database!"
 
-# Formatting chat history
 def format_chat_history(message, chat_history):
     formatted_chat_history = []
     for user_message, bot_message in chat_history:

@@ -189,7 +188,6 @@ def format_chat_history(message, chat_history):
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
 
-# Conversation function
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
     response = qa_chain.invoke({

@@ -230,7 +228,7 @@ def demo():
 
         with gr.Row():
             splitting_strategy = gr.Radio(
-                ["recursive", "fixed", "
                 label="Text Splitting Strategy",
                 value="recursive"
             )

app.py (new version):

 import gradio as gr
 import os
 from typing import List, Dict
+import numpy as np
 from datasets import load_dataset
 from langchain.text_splitter import (
     RecursiveCharacterTextSplitter,
     CharacterTextSplitter,
+    TokenTextSplitter
 )
 from langchain_community.vectorstores import FAISS, Chroma, Qdrant
 from langchain_community.document_loaders import PyPDFLoader

 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceEndpoint
 from langchain.memory import ConversationBufferMemory
+from sentence_transformers import SentenceTransformer, util
 import torch
 
+# Constants and setup
 list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
 list_llm_simple = [os.path.basename(llm) for llm in list_llm]
 api_token = os.getenv("HF_TOKEN")
 
+# Initialize sentence transformer for evaluation
+sentence_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
+
 # Text splitting strategies
 def get_text_splitter(strategy: str, chunk_size: int = 1024, chunk_overlap: int = 64):
     splitters = {

             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap
         ),
+        "token": TokenTextSplitter(
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap
         )
     }
     return splitters.get(strategy)
 
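For orientation, a small usage sketch of the splitter factory above. It is illustrative only: the strategy name and chunk settings simply mirror the keys defined in get_text_splitter, the sample text is made up, and the "token" splitter additionally requires the tiktoken package to be installed.

splitter = get_text_splitter("token", chunk_size=256, chunk_overlap=32)
if splitter is None:
    raise ValueError("Unknown splitting strategy")
sample_text = "Retrieval-augmented generation retrieves supporting chunks before answering. " * 40
chunks = splitter.split_text(sample_text)
print(f"Produced {len(chunks)} chunks; first chunk is {len(chunks[0])} characters long")
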
+# Custom evaluation metrics
+def calculate_semantic_similarity(text1: str, text2: str) -> float:
+    embeddings1 = sentence_model.encode([text1], convert_to_tensor=True)
+    embeddings2 = sentence_model.encode([text2], convert_to_tensor=True)
+    similarity = util.pytorch_cos_sim(embeddings1, embeddings2)
+    return float(similarity[0][0])
+
+def evaluate_response(question: str, answer: str, ground_truth: str, contexts: List[str]) -> Dict[str, float]:
+    # Answer similarity with ground truth
+    answer_similarity = calculate_semantic_similarity(answer, ground_truth)
+
+    # Context relevance - average similarity between question and contexts
+    context_scores = [calculate_semantic_similarity(question, ctx) for ctx in contexts]
+    context_relevance = np.mean(context_scores)
+
+    # Answer relevance - similarity between question and answer
+    answer_relevance = calculate_semantic_similarity(question, answer)
+
+    return {
+        "answer_similarity": answer_similarity,
+        "context_relevance": context_relevance,
+        "answer_relevance": answer_relevance,
+        "average_score": np.mean([answer_similarity, context_relevance, answer_relevance])
+    }
+
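The metrics above reduce evaluation to cosine similarity between MiniLM sentence embeddings, so they can be exercised on a single hand-written example. The question, answer, and contexts below are invented purely to illustrate the call shape; the returned values depend on the loaded model.

example_scores = evaluate_response(
    question="What drives a stock's price-to-earnings ratio?",
    answer="Growth expectations, interest rates, and perceived risk all influence the P/E ratio.",
    ground_truth="P/E ratios reflect expected earnings growth, prevailing interest rates, and risk.",
    contexts=[
        "The price-to-earnings ratio tends to rise when investors expect faster earnings growth.",
        "Higher interest rates generally compress equity valuation multiples.",
    ],
)
print(example_scores)  # keys: answer_similarity, context_relevance, answer_relevance, average_score
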
 # Load and split PDF document
 def load_doc(list_file_path: List[str], splitting_strategy: str = "recursive"):
     loaders = [PyPDFLoader(x) for x in list_file_path]

     }
     return db_creators[db_choice](splits, embeddings)
 
 def load_evaluation_dataset():
     dataset = load_dataset("explodinggradients/fiqa", split="test")
     return dataset
 
+def evaluate_rag_pipeline(qa_chain, dataset):
     # Sample a few examples for evaluation
     eval_samples = dataset.select(range(5))
 
+    results = []
     for sample in eval_samples:
         question = sample["question"]
 
             "chat_history": []
         })
 
+        # Evaluate response
+        eval_result = evaluate_response(
+            question=question,
+            answer=response["answer"],
+            ground_truth=sample["answer"],
+            contexts=[doc.page_content for doc in response["source_documents"]]
+        )
+
+        results.append(eval_result)
 
+    # Calculate average scores across all samples
+    avg_results = {
+        metric: float(np.mean([r[metric] for r in results]))
+        for metric in results[0].keys()
     }
+
+    return avg_results
 
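As a sketch of how this evaluation could be driven outside the Gradio UI: the snippet below assumes HF_TOKEN is set, that "example.pdf" is replaced with a real document, and that the "token" and "faiss" choices match the options defined elsewhere in this file; it is not part of the app's own flow.

pdf_files = [open("example.pdf", "rb")]  # placeholder path; initialize_database only reads .name
vector_db, status = initialize_database(pdf_files, "token", "faiss")
qa_chain = initialize_llmchain(list_llm[0], temperature=0.5, max_tokens=1024, top_k=3, vector_db=vector_db)
dataset = load_evaluation_dataset()
scores = evaluate_rag_pipeline(qa_chain, dataset)
for metric, value in scores.items():
    print(f"{metric}: {value:.3f}")
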
 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):

     )
     return qa_chain
 
 def initialize_database(list_file_obj, splitting_strategy, db_choice, progress=gr.Progress()):
     list_file_path = [x.name for x in list_file_obj if x is not None]
     doc_splits = load_doc(list_file_path, splitting_strategy)
     vector_db = create_db(doc_splits, db_choice)
     return vector_db, f"Database created using {splitting_strategy} splitting and {db_choice} vector database!"
 
 def format_chat_history(message, chat_history):
     formatted_chat_history = []
     for user_message, bot_message in chat_history:
 
         formatted_chat_history.append(f"Assistant: {bot_message}")
     return formatted_chat_history
 
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
     response = qa_chain.invoke({

 
         with gr.Row():
             splitting_strategy = gr.Radio(
+                ["recursive", "fixed", "token"],
                 label="Text Splitting Strategy",
                 value="recursive"
             )