Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -2,3 +2,50 @@ transformers
|
|
2 |
torch
|
3 |
sentence-transformers
|
4 |
gradio
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
torch
|
3 |
sentence-transformers
|
4 |
gradio
|
5 |
+
|
6 |
+
import gradio as gr
|
7 |
+
from transformers import pipeline
|
8 |
+
from sentence_transformers import SentenceTransformer, util
|
9 |
+
|
10 |
+
# Load models
|
11 |
+
retriever_model = SentenceTransformer('all-MiniLM-L6-v2') # A sentence transformer for retrieval
|
12 |
+
generator_model = pipeline('text-generation', model='gpt2') # Use a suitable model
|
13 |
+
|
14 |
+
# Sample document collection
|
15 |
+
documents = [
|
16 |
+
"The quick fox jumps over the lazy dog.",
|
17 |
+
"A speedy brown fox jumps over the lazy dog.",
|
18 |
+
"A cat in the hat is wearing a red bow.",
|
19 |
+
"The sun sets behind the mountains, casting a warm glow."
|
20 |
+
]
|
21 |
+
|
22 |
+
# Function to perform retrieval and generation
|
23 |
+
def rag_evaluation(query):
|
24 |
+
# Compute embeddings for the documents
|
25 |
+
doc_embeddings = retriever_model.encode(documents, convert_to_tensor=True)
|
26 |
+
|
27 |
+
# Compute embedding for the query
|
28 |
+
query_embedding = retriever_model.encode(query, convert_to_tensor=True)
|
29 |
+
|
30 |
+
# Calculate cosine similarity
|
31 |
+
similarities = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]
|
32 |
+
|
33 |
+
# Get the most similar document
|
34 |
+
most_similar_idx = similarities.argmax()
|
35 |
+
retrieved_doc = documents[most_similar_idx]
|
36 |
+
|
37 |
+
# Generate a response based on the retrieved document
|
38 |
+
generated_response = generator_model(f"Based on the document: {retrieved_doc}, answer: {query}", max_length=50)[0]['generated_text']
|
39 |
+
|
40 |
+
return retrieved_doc, generated_response
|
41 |
+
|
42 |
+
# Gradio interface
|
43 |
+
iface = gr.Interface(
|
44 |
+
fn=rag_evaluation,
|
45 |
+
inputs=gr.Textbox(label="Enter your query"),
|
46 |
+
outputs=[gr.Textbox(label="Retrieved Document"), gr.Textbox(label="Generated Response")],
|
47 |
+
title="RAG Evaluation App",
|
48 |
+
description="Evaluate retrieval and generation performance of a RAG system."
|
49 |
+
)
|
50 |
+
|
51 |
+
iface.launch()
|