# NOTE(review): removed non-Python scrape artifacts that preceded the code
# (a "File size" header, git commit hashes, and a stray rendered line-number
# gutter) — they were page-rendering residue, not part of the script.
"""Answer a natural-language question with a RAG (Retrieval-Augmented
Generation) model.

Fixes relative to the previous version:
- Use the ``facebook/rag-sequence-nq`` checkpoint so the weights match the
  ``RagSequenceForGeneration`` class (``rag-token-nq`` is trained for
  ``RagTokenForGeneration``).
- Attach the retriever to the model; ``generate()`` without a retriever
  requires precomputed ``context_input_ids`` and would otherwise fail.
- Drop the full ``wiki_dpr`` download: the wiki_dpr rows have columns
  ``id``/``text``/``title``/``embeddings`` (no ``query``/``passage``, so the
  old indexing raised ``KeyError``), and ``use_dummy_dataset=True`` already
  makes the retriever load its own small passage set.
- Let the model use its configured ``decoder_start_token_id`` instead of
  forcing it to the pad token id, which corrupts generation.
"""

from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from datasets import load_dataset  # kept for compatibility with other users of this module

# One checkpoint name used consistently for tokenizer, retriever, and model.
MODEL_NAME = "facebook/rag-sequence-nq"

# RagTokenizer bundles the DPR question-encoder tokenizer and the BART
# generator tokenizer (RAG is DPR+BART based, not T5).
tokenizer = RagTokenizer.from_pretrained(MODEL_NAME)

# Retriever over wiki_dpr passages. ``index_name="exact"`` matches the
# ``psgs_w100.nq.exact`` configuration; ``use_dummy_dataset=True`` downloads a
# tiny dummy passage set instead of the full multi-GB wiki_dpr index, so no
# separate load_dataset() call is needed.
retriever = RagRetriever.from_pretrained(
    MODEL_NAME, index_name="exact", use_dummy_dataset=True
)

# Attaching the retriever lets model.generate() fetch supporting passages
# internally from the question alone.
model = RagSequenceForGeneration.from_pretrained(MODEL_NAME, retriever=retriever)

# wiki_dpr contains passages, not questions, so the query is supplied directly.
question = "who holds the record in 100m freestyle"

# Tokenize the input question for the DPR question encoder.
inputs = tokenizer(question, return_tensors="pt")

# Deterministic beam search; the model supplies its own decoder start token.
outputs = model.generate(
    input_ids=inputs["input_ids"],
    num_beams=3,
    num_return_sequences=1,
    do_sample=False,
)

# batch_decode handles the (1, seq_len) output tensor; take the single answer.
generated_answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]

print(f"Question: {question}")
print(f"Answer: {generated_answer}")