# NOTE: this script was recovered from a Hugging Face Spaces page that showed a
# "Runtime error" banner; the page residue and table artifacts have been removed.
"""Minimal RAG (Retrieval-Augmented Generation) demo using facebook/rag-token-nq.

Fixes relative to the original script:
- every line carried a trailing ``| |`` table artifact, so the file did not parse;
- the model was built WITHOUT the retriever, so ``model.generate(input_ids=...)``
  raised at runtime (no retriever and no ``context_input_ids`` to fall back on);
- ``wiki_dpr`` rows contain ``id``/``text``/``title``/``embeddings`` fields, not
  ``query``/``passage``, so ``sample["query"]`` raised ``KeyError`` — the question
  is now supplied explicitly (wiki_dpr holds passages, not questions);
- decoding goes through ``batch_decode``, the supported RagTokenizer API, and the
  spurious ``decoder_start_token_id=pad_token_id`` override is dropped (the RAG
  generator config already provides the correct decoder start token).
"""
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
# NOTE(review): `load_dataset` is kept importable for callers that want the real
# wiki_dpr corpus, but the full `psgs_w100.nq.exact` download (tens of GB) is no
# longer triggered here — the retriever below uses the dummy index anyway.
from datasets import load_dataset  # noqa: F401

# The RAG tokenizer bundles the DPR question-encoder tokenizer and the BART
# generator tokenizer used by the rag-token-nq checkpoint.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")

# Retriever over the compressed wiki_dpr index; use_dummy_dataset=True keeps the
# demo lightweight instead of downloading the full index.
retriever = RagRetriever.from_pretrained(
    "facebook/rag-token-nq", index_name="compressed", use_dummy_dataset=True
)

# The retriever MUST be attached to the model; otherwise generate() has no way
# to fetch context passages and fails at runtime.
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-token-nq", retriever=retriever
)

# wiki_dpr rows are passages, not question/answer pairs, so ask a question
# explicitly rather than reading nonexistent sample["query"] / sample["passage"].
input_text = "who holds the record in 100m freestyle"

# Tokenize the question for the DPR question encoder.
inputs = tokenizer(input_text, return_tensors="pt")

# Beam-search generation; retrieval happens inside generate() via the attached
# retriever, so only the question ids are needed here.
outputs = model.generate(
    input_ids=inputs["input_ids"],
    num_beams=3,
    num_return_sequences=1,
    do_sample=False,
)

# batch_decode is the supported decoding entry point on RagTokenizer.
generated_answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(f"Question: {input_text}")
print(f"Answer: {generated_answer}")