import gradio as gr
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, CrossEncoder
import faiss
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import os
import spaces
print("CUDA available:", torch.cuda.is_available())
@spaces.GPU
def claim_gpu():
    # Dummy function to make Spaces detect GPU usage
    pass

claim_gpu()
# Log in automatically if HF_TOKEN is present
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    from huggingface_hub import login
    login(token=hf_token)
# Load corpus
print("Loading dataset...")
dataset = load_dataset("rag-datasets/rag-mini-wikipedia", "text-corpus")
# Build a clean list of passage strings and use this corpus consistently
corpus = []
for item in dataset["passages"]:
    # Each item is a row dict; keep only the passage text rather than the dict's repr
    text = str(item["passage"]).strip()
    if text:
        corpus.append(text)
# Embedding model
print("Encoding corpus...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
corpus_embeddings = embedder.encode(corpus, convert_to_tensor=True, device='cpu')
corpus_embeddings_np = corpus_embeddings.numpy()
# FAISS index
index = faiss.IndexFlatL2(corpus_embeddings_np.shape[1])
index.add(corpus_embeddings_np)
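# Note: IndexFlatL2 performs exact (brute-force) L2 search, which is fine for a corpus this
# small; an approximate index would only matter at much larger scale.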
# Reranker model
# reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
# Generator (choose one: local HF model or OpenAI)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3", device_map="auto", torch_dtype=torch.float16)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=150)
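# Note: Mistral-7B-Instruct-v0.3 is a gated repo, so the HF_TOKEN login above is assumed to
# have been granted access; device_map="auto" with float16 expects a GPU to be available.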
def rag_pipeline(query):
    # Embed query
    query_embedding = embedder.encode([query], convert_to_tensor=True, device='cpu').numpy()

    # Retrieve top-k from FAISS
    D, I = index.search(query_embedding, k=5)
    retrieved_docs = [corpus[idx] for idx in I[0]]
    print("Retrieved indices:", I[0])
    print("Retrieved docs:")
    for doc in retrieved_docs:
        print("-", repr(doc))

    # # Rerank
    # rerank_pairs = [[str(query), str(doc)] for doc in retrieved_docs if isinstance(doc, str) and doc.strip()]
    # if not rerank_pairs:
    #     return "No valid documents found to rerank."
    # scores = reranker.predict(rerank_pairs)
    # reranked_docs = [doc for _, doc in sorted(zip(scores, retrieved_docs), reverse=True)]

    # Combine for context
    context = "\n\n".join(retrieved_docs[:2])
    prompt = f"""Answer the following question using the provided context.\n\nContext:\n{context}\n\nQuestion: {query}\nAnswer:"""

    # Generate
    response = generator(prompt)[0]["generated_text"]
    return response.split("Answer:")[-1].strip()
# Gradio UI
iface = gr.Interface(
    fn=rag_pipeline,
    inputs=gr.Textbox(lines=2, placeholder="How fast is a penguin?"),
    outputs="text",
    title="Mini RAG Wikipedia Demo",
    description="Retrieval-Augmented Generation on a small Wikipedia subset.",
)
iface.launch()