arif670 commited on
Commit
2424894
·
verified ·
1 Parent(s): 2795690

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -2
app.py CHANGED
@@ -2,8 +2,8 @@ import torch
2
  from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration
3
  from datasets import load_dataset
4
 
5
- # Step 1: Load the dataset with the trust_remote_code flag enabled
6
- dataset = load_dataset("wiki_dpr", trust_remote_code=True)
7
 
8
  # Step 2: Load the retriever using the pre-trained model, with use_dummy_dataset=True and trust_remote_code=True
9
  retriever = RagRetriever.from_pretrained(
 
2
  from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration
3
  from datasets import load_dataset
4
 
5
+ # Step 1: Load the dataset with the trust_remote_code flag enabled and a valid config name
6
+ dataset = load_dataset("wiki_dpr", "psgs_w100.nq.exact", trust_remote_code=True)
7
 
8
  # Step 2: Load the retriever using the pre-trained model, with use_dummy_dataset=True and trust_remote_code=True
9
  retriever = RagRetriever.from_pretrained(