"""Minimal RAG (Retrieval-Augmented Generation) question-answering demo.

Loads the ``facebook/rag-sequence-nq`` checkpoint together with its retriever
(backed by the dummy wiki_dpr index, so no multi-GB download is required) and
answers one natural-language question.
"""

from transformers import RagRetriever, RagSequenceForGeneration, RagTokenizer

# NOTE: RagSequenceForGeneration pairs with the *sequence* checkpoint, not
# "facebook/rag-token-nq" (that one belongs to RagTokenForGeneration).
# RAG's generator is BART-based, not T5-based.
MODEL_NAME = "facebook/rag-sequence-nq"

# The RAG tokenizer bundles the DPR question-encoder tokenizer and the BART
# generator tokenizer.
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)

# use_dummy_dataset=True loads a tiny stand-in for the full wiki_dpr index
# (the real psgs_w100.nq index is a ~75 GB download). The retriever manages
# the passage dataset internally, so there is no need for a separate
# datasets.load_dataset("wiki_dpr", ...) call. The dummy dataset ships with
# the "exact" index.
retriever = RagRetriever.from_pretrained(
    MODEL_NAME,
    index_name="exact",
    use_dummy_dataset=True,
)

# The retriever must be attached to the model; without it, generate() has no
# way to fetch passages and requires precomputed context_input_ids.
model = RagSequenceForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)

# wiki_dpr rows contain only passages ('id', 'text', 'title', 'embeddings');
# there are no 'query'/'passage' question fields, so supply the question
# directly instead of sampling the dataset.
input_text = "who holds the record in 100m freestyle"

# Tokenize the question with the DPR question-encoder tokenizer.
inputs = tokenizer(input_text, return_tensors="pt")

# Deterministic beam search; the checkpoint's own decoder_start_token_id is
# used (overriding it with pad_token_id, as the original did, is unnecessary
# for BART-based RAG).
outputs = model.generate(
    input_ids=inputs["input_ids"],
    num_beams=3,
    num_return_sequences=1,
    do_sample=False,
)

# batch_decode delegates to the generator (BART) tokenizer.
generated_answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

print(f"Question: {input_text}")
print(f"Answer: {generated_answer}")