Spaces:
Runtime error
Runtime error
transformers | |
torch | |
sentence-transformers | |
gradio | |
import gradio as gr | |
from transformers import pipeline | |
from sentence_transformers import SentenceTransformer, util | |
# Load models
# Both models are created once at import time; rag_evaluation reuses them on
# every call instead of reloading weights per request.
_RETRIEVER_NAME = "all-MiniLM-L6-v2"
_GENERATOR_NAME = "gpt2"

# Sentence transformer used to embed the query and the documents for retrieval.
retriever_model = SentenceTransformer(_RETRIEVER_NAME)
# Text-generation pipeline; swap _GENERATOR_NAME for a more suitable model.
generator_model = pipeline("text-generation", model=_GENERATOR_NAME)
# Sample document collection
# In-memory corpus searched by rag_evaluation; each entry is one sentence.
documents = [
    "The quick fox jumps over the lazy dog.",
    "A speedy brown fox jumps over the lazy dog.",
    "A cat in the hat is wearing a red bow.",
    "The sun sets behind the mountains, casting a warm glow.",
]
# Cache of document embeddings, keyed by the tuple of document strings so the
# corpus is re-encoded only when `documents` actually changes.
_doc_embedding_cache = {}


def _document_embeddings():
    """Return embeddings for ``documents``, encoding only when the list changes."""
    key = tuple(documents)
    if key not in _doc_embedding_cache:
        # Keep a single entry so an edited corpus does not leave stale tensors around.
        _doc_embedding_cache.clear()
        _doc_embedding_cache[key] = retriever_model.encode(
            documents, convert_to_tensor=True
        )
    return _doc_embedding_cache[key]


# Function to perform retrieval and generation
def rag_evaluation(query):
    """Retrieve the document most similar to *query* and generate a response.

    Args:
        query: The user's question as text. ``None`` or a whitespace-only
            string is treated as "no query".

    Returns:
        A ``(retrieved_doc, generated_response)`` tuple of strings. Both are
        empty strings when no query was supplied. Otherwise ``retrieved_doc``
        is the entry of ``documents`` with the highest cosine similarity to
        the query, and ``generated_response`` is the ``'generated_text'``
        field of the generator pipeline's first result.
    """
    # Nothing meaningful to embed or answer; skip both models entirely.
    if query is None or not query.strip():
        return "", ""

    # Document embeddings are invariant across requests, so reuse the cached ones.
    doc_embeddings = _document_embeddings()
    # Compute embedding for the query
    query_embedding = retriever_model.encode(query, convert_to_tensor=True)

    # Cosine similarity of the query against every document (row 0 = the query).
    similarities = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]

    # Convert the argmax result to a plain int before indexing the Python list.
    most_similar_idx = int(similarities.argmax())
    retrieved_doc = documents[most_similar_idx]

    # Generate a response based on the retrieved document.
    # max_new_tokens bounds only the continuation; the previous max_length=50
    # also counted the prompt, so a long query left no room to generate.
    prompt = f"Based on the document: {retrieved_doc}, answer: {query}"
    generated_response = generator_model(prompt, max_new_tokens=50)[0]["generated_text"]
    return retrieved_doc, generated_response
# Gradio interface
# One text input for the query; two text outputs matching the
# (retrieved_doc, generated_response) pair returned by rag_evaluation.
_query_input = gr.Textbox(label="Enter your query")
_result_outputs = [
    gr.Textbox(label="Retrieved Document"),
    gr.Textbox(label="Generated Response"),
]

iface = gr.Interface(
    fn=rag_evaluation,
    inputs=_query_input,
    outputs=_result_outputs,
    title="RAG Evaluation App",
    description="Evaluate retrieval and generation performance of a RAG system.",
)

# Start the web app.
iface.launch()