# Spaces status: Runtime error
import chromadb
import faiss
import gradio as gr
import numpy as np
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
# Load a small subset of the wiki40b "en" training split (first 1,000 rows).
DATASET_SPLIT = "train[:1000]"
dataset = load_dataset("wiki40b", "en", split=DATASET_SPLIT)

# Keep only the "text" field of each row.
docs = [row["text"] for row in dataset]
print("Loaded dataset with", len(docs), "documents.")

# Load the embedding model used for both the documents and the queries.
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
embed_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Convert the texts to embeddings; the result is paired with `docs` by
# position when it is stored below.
embeddings = embed_model.encode(docs, show_progress_bar=True)
# Initialize the ChromaDB client; PersistentClient stores its data on disk
# under ./chroma_db.
chroma_client = chromadb.PersistentClient(path="./chroma_db")
collection = chroma_client.get_or_create_collection(name="wikipedia_docs")

# Store the embeddings in ChromaDB in batches rather than one call per
# document. `upsert` is used instead of `add` because the collection is
# persistent and the IDs are always "0".."N-1": re-running the script against
# an existing ./chroma_db would otherwise try to add IDs that already exist.
STORE_BATCH_SIZE = 500
for start in range(0, len(docs), STORE_BATCH_SIZE):
    stop = min(start + STORE_BATCH_SIZE, len(docs))
    collection.upsert(
        ids=[str(i) for i in range(start, stop)],  # Unique ID for each doc
        embeddings=[vector.tolist() for vector in embeddings[start:stop]],
        documents=docs[start:stop],
    )
print("Stored embeddings in ChromaDB!")
# Search function using ChromaDB
def search_wikipedia(query: str, top_k: int = 3) -> str:
    """Return the stored passages that best match *query*.

    The query is embedded with the module-level ``embed_model`` and looked up
    in the module-level ChromaDB ``collection``.

    Args:
        query: Free-text search query.
        top_k: Number of passages to request from the collection.

    Returns:
        The matching passages joined by a blank line, or an empty string when
        the query is empty/whitespace-only or nothing was returned.
    """
    # An empty query has no meaningful embedding; skip the lookup instead of
    # returning arbitrary passages.
    if not query or not query.strip():
        return ""

    query_embedding = embed_model.encode([query]).tolist()
    results = collection.query(
        query_embeddings=query_embedding,
        n_results=top_k,
    )

    # One list of documents comes back per query embedding; only one was sent.
    documents = results["documents"]
    if not documents:
        return ""
    return "\n\n".join(documents[0])  # Return top results
# Gradio Interface: a single text input wired to search_wikipedia, with the
# result shown as text. The options are collected first, then passed on.
_interface_options = {
    "fn": search_wikipedia,
    "inputs": "text",
    "outputs": "text",
    "title": "Wikipedia Search RAG",
    "description": "Enter a query and retrieve relevant Wikipedia passages.",
}
iface = gr.Interface(**_interface_options)
iface.launch()