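# Gradio app: RAG PDF chatbot with a RAGAS-based evaluation tab.
# Builds a retrieval pipeline over uploaded PDFs (configurable text splitter,
# chunk size, and vector store) and scores configurations against SQuAD v2 /
# MS MARCO samples using RAGAS metrics.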
import gradio as gr
import os
from typing import List, Dict
import numpy as np
from datasets import load_dataset, Dataset  # Dataset is needed to build the RAGAS evaluation set below
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    CharacterTextSplitter,
    TokenTextSplitter
)
from langchain_community.vectorstores import FAISS, Chroma
from langchain_community.document_loaders import PyPDFLoader
from langchain.chains import ConversationalRetrievalChain
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.llms import HuggingFaceEndpoint
from langchain.memory import ConversationBufferMemory
from sentence_transformers import SentenceTransformer, util
import torch
from ragas import evaluate
from ragas.metrics import (
    ContextRecall,
    AnswerRelevancy,
    Faithfulness,
    ContextPrecision
)
import pandas as pd
# Constants and configurations
CHUNK_SIZES = {
    "small": {"recursive": 512, "fixed": 512, "token": 256},
    "medium": {"recursive": 1024, "fixed": 1024, "token": 512}
}
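# Note: the per-strategy values above are assumed to be characters for the
# "recursive"/"fixed" splitters and tokens for the "token" splitter; they are
# keyed by the "small"/"medium" options exposed in the UI radio buttons below.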
class RAGEvaluator:
    def __init__(self):
        self.datasets = {
            "squad": "squad_v2",
            "msmarco": "ms_marco"
        }
        self.current_dataset = None
        self.test_samples = []
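
    # Load a benchmark dataset and normalize each sample to a common
    # {question, ground_truth, context} shape; samples without a gold
    # answer are filtered out.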
    def load_dataset(self, dataset_name: str, num_samples: int = 50):
        if dataset_name == "squad":
            dataset = load_dataset("squad_v2", split="validation")
            samples = dataset.select(range(num_samples))
            self.test_samples = [
                {
                    "question": sample["question"],
                    "ground_truth": sample["answers"]["text"][0] if sample["answers"]["text"] else "",
                    "context": sample["context"]
                }
                for sample in samples
                if sample["answers"]["text"]  # Filter out samples without answers
            ]
        elif dataset_name == "msmarco":
            dataset = load_dataset("ms_marco", "v2.1", split="train")
            samples = dataset.select(range(num_samples))
            self.test_samples = [
                {
                    "question": sample["query"],
                    "ground_truth": sample["answers"][0] if sample["answers"] else "",
                    "context": sample["passages"]["passage_text"][0]
                }
                for sample in samples
                if sample["answers"]  # Filter out samples without answers
            ]
        self.current_dataset = dataset_name
        return self.test_samples
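
    # Run every test question through the QA chain, collect answers and
    # retrieved source contexts, and score the configuration with RAGAS.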
    def evaluate_configuration(self,
                               vector_db,
                               qa_chain,
                               splitting_strategy: str,
                               chunk_size: str) -> Dict:
        if not self.test_samples:
            return {"error": "No dataset loaded"}

        results = []
        for sample in self.test_samples:
            response = qa_chain.invoke({
                "question": sample["question"],
                "chat_history": []
            })
            results.append({
                "question": sample["question"],
                "answer": response["answer"],
                "contexts": [doc.page_content for doc in response["source_documents"]],
                "ground_truths": [sample["ground_truth"]]
            })

        # Convert to RAGAS dataset format
        eval_dataset = Dataset.from_list(results)

        # Calculate RAGAS metrics
        metrics = [
            ContextRecall(),
            AnswerRelevancy(),
            Faithfulness(),
            ContextPrecision()
        ]
        scores = evaluate(
            eval_dataset,
            metrics=metrics
        )

        return {
            "configuration": f"{splitting_strategy}_{chunk_size}",
            "context_recall": float(scores['context_recall']),
            "answer_relevancy": float(scores['answer_relevancy']),
            "faithfulness": float(scores['faithfulness']),
            "context_precision": float(scores['context_precision']),
            "average_score": float(np.mean([
                scores['context_recall'],
                scores['answer_relevancy'],
                scores['faithfulness'],
                scores['context_precision']
            ]))
        }
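
# Example (sketch): standalone use of RAGEvaluator, assuming `vector_db` and
# `qa_chain` have already been built by the pipeline-initialization code that
# is elided from this listing:
#
#   evaluator = RAGEvaluator()
#   evaluator.load_dataset("squad", num_samples=10)
#   scores = evaluator.evaluate_configuration(
#       vector_db=vector_db,
#       qa_chain=qa_chain,
#       splitting_strategy="recursive",
#       chunk_size="medium",
#   )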
def demo():
    evaluator = RAGEvaluator()

    with gr.Blocks(theme=gr.themes.Default(primary_hue="red", secondary_hue="pink", neutral_hue="sky")) as demo:
        vector_db = gr.State()
        qa_chain = gr.State()
        gr.HTML("<center><h1>Enhanced RAG PDF Chatbot with Evaluation</h1></center>")

        with gr.Tabs():
            # Custom PDF Tab
            with gr.Tab("Custom PDF Chat"):
                # Your existing UI components here
                with gr.Row():
                    with gr.Column(scale=86):
                        gr.Markdown("<b>Step 1 - Configure and Initialize RAG Pipeline</b>")
                        with gr.Row():
                            document = gr.Files(height=300, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload PDF documents")
                        with gr.Row():
                            splitting_strategy = gr.Radio(
                                ["recursive", "fixed", "token"],
                                label="Text Splitting Strategy",
                                value="recursive"
                            )
                            db_choice = gr.Dropdown(
                                ["faiss", "chroma"],
                                label="Vector Database",
                                value="faiss"
                            )
                            chunk_size = gr.Radio(
                                ["small", "medium"],
                                label="Chunk Size",
                                value="medium"
                            )
                        # Rest of your existing UI components...
            # Evaluation Tab
            with gr.Tab("RAG Evaluation"):
                with gr.Row():
                    dataset_choice = gr.Dropdown(
                        choices=list(evaluator.datasets.keys()),
                        label="Select Evaluation Dataset",
                        value="squad"
                    )
                    load_dataset_btn = gr.Button("Load Dataset")
                with gr.Row():
                    dataset_info = gr.JSON(label="Dataset Information")
                with gr.Row():
                    eval_splitting_strategy = gr.Radio(
                        ["recursive", "fixed", "token"],
                        label="Text Splitting Strategy",
                        value="recursive"
                    )
                    eval_chunk_size = gr.Radio(
                        ["small", "medium"],
                        label="Chunk Size",
                        value="medium"
                    )
                with gr.Row():
                    evaluate_btn = gr.Button("Run Evaluation")
                    evaluation_results = gr.DataFrame(label="Evaluation Results")
        # Event handlers
        def load_dataset_handler(dataset_name):
            samples = evaluator.load_dataset(dataset_name)
            return {
                "dataset": dataset_name,
                "num_samples": len(samples),
                "sample_questions": [s["question"] for s in samples[:3]]
            }

        def run_evaluation(dataset_choice, splitting_strategy, chunk_size, vector_db, qa_chain):
            if not evaluator.current_dataset:
                return pd.DataFrame()

            results = evaluator.evaluate_configuration(
                vector_db=vector_db,
                qa_chain=qa_chain,
                splitting_strategy=splitting_strategy,
                chunk_size=chunk_size
            )
            # Convert results to DataFrame
            df = pd.DataFrame([results])
            return df

        # Connect event handlers
        load_dataset_btn.click(
            load_dataset_handler,
            inputs=[dataset_choice],
            outputs=[dataset_info]
        )

        evaluate_btn.click(
            run_evaluation,
            inputs=[
                dataset_choice,
                eval_splitting_strategy,
                eval_chunk_size,
                vector_db,
                qa_chain
            ],
            outputs=[evaluation_results]
        )
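
        # The handlers below assume the chat UI components and helpers elided
        # from this listing (qachain_btn, llm_btn, slider_temperature,
        # slider_maxtokens, slider_topk, llm_progress, msg, submit_btn,
        # clear_btn, chatbot, doc_source1-3, source1-3_page,
        # initialize_llmchain, conversation) are defined as in the original app.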
        qachain_btn.click(
            initialize_llmchain,
            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
            outputs=[qa_chain, llm_progress]
        ).then(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )

        msg.submit(conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )

        submit_btn.click(conversation,
            inputs=[qa_chain, msg, chatbot],
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )

        clear_btn.click(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
            queue=False
        )
    demo.queue().launch(debug=True)


if __name__ == "__main__":
    demo()