# HF_Python/app.py
from transformers import RagTokenizer, RagRetriever, RagSequenceForGeneration
# No separate dataset download is needed: the RagRetriever below ships with
# its own wiki_dpr passage index (a small dummy index here, for a quick demo).
# Initialize the RAG tokenizer (wraps a DPR question-encoder tokenizer and a
# BART generator tokenizer)
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
# Initialize the RAG retriever with the compressed wiki_dpr index
# (use_dummy_dataset=True keeps the download small for demo purposes)
retriever = RagRetriever.from_pretrained(
    "facebook/rag-sequence-nq", index_name="compressed", use_dummy_dataset=True
)
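
# Optional sketch: for real retrieval quality, the dummy index can be swapped
# for the full wiki_dpr index by dropping use_dummy_dataset. Beware that this
# triggers a very large download (tens of GB) on first use:
# retriever = RagRetriever.from_pretrained(
#     "facebook/rag-sequence-nq", index_name="compressed"
# )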
# Initialize the RAG sequence model (BART-based generator; the rag-sequence-nq
# checkpoint matches RagSequenceForGeneration, while rag-token-nq pairs with
# RagTokenForGeneration) and attach the retriever so generate() can fetch passages
model = RagSequenceForGeneration.from_pretrained(
    "facebook/rag-sequence-nq", retriever=retriever
)
# Pose a sample question (wiki_dpr is a passage corpus with no "query" or
# "passage" fields, so a hard-coded question stands in here)
input_text = "who holds the record in 100m freestyle"
# Tokenize the input question
inputs = tokenizer(input_text, return_tensors="pt")
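
# Optional sketch of what happens under the hood: the retriever can be called
# directly to fetch supporting passages for the encoded question (pattern from
# the RagRetriever documentation; shown for inspection only):
# question_hidden_states = model.question_encoder(inputs["input_ids"])[0]
# docs_dict = retriever(
#     inputs["input_ids"].numpy(),
#     question_hidden_states.detach().numpy(),
#     return_tensors="pt",
# )
# # docs_dict["context_input_ids"] holds the question joined with each passage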
# Generate the answer; the model retrieves supporting passages internally via
# the attached retriever (the config already sets the correct decoder start token)
outputs = model.generate(input_ids=inputs["input_ids"],
                         num_beams=3,
                         num_return_sequences=1,
                         do_sample=False)
# Decode the generated output
generated_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(f"Question: {input_text}")
print(f"Answer: {generated_answer}")