# HF_Python / app_old_rag.py — renamed from app.py (commit f188b30, by Reyad-Ahmmed)
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
from datasets import load_dataset

# Load the wiki_dpr dataset (its loading script requires trust_remote_code).
# NOTE(review): this download is large; the retriever below uses its own dummy
# index, so this load is kept only for parity with the original script and for
# inspecting passages — it is not what generate() retrieves from.
dataset = load_dataset('wiki_dpr', 'psgs_w100.nq.exact', trust_remote_code=True)

# RAG tokenizer: wraps the DPR question-encoder tokenizer and the BART
# generator tokenizer (not T5, as the original comment claimed).
# Use the "rag-sequence-nq" checkpoint so it matches the
# RagSequenceForGeneration class below — the original mixed the
# "rag-token-nq" checkpoint with the sequence-level model class.
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")

# Retriever with a small dummy index so the demo runs without downloading the
# full multi-GB wiki_dpr index ("exact" is the index the dummy dataset ships).
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
)

# The model must be constructed WITH the retriever, otherwise generate() has
# no way to fetch supporting passages (the original omitted it).
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq", retriever=retriever
)

# wiki_dpr rows only carry 'id', 'text', 'title' and 'embeddings' — the
# original sample["query"] / sample["passage"] lookups raised KeyError.
# Use a fixed natural-language question instead.
input_text = "who holds the record in 100m freestyle"

# Tokenize the question for the DPR question encoder.
inputs = tokenizer(input_text, return_tensors="pt")

# Generate the answer: deterministic beam search (3 beams, best sequence only).
# The decoder_start_token_id override from the original is unnecessary with a
# correctly paired checkpoint/class — the generation config supplies it.
outputs = model.generate(
    input_ids=inputs["input_ids"],
    num_beams=3,
    num_return_sequences=1,
    do_sample=False,
)

# Decode the generated token ids back to text via the generator tokenizer.
generated_answer = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
print(f"Question: {input_text}")
print(f"Answer: {generated_answer}")