from datasets import Dataset, load_from_disk
import faiss
import numpy as np
import torch
from transformers import RagRetriever, RagTokenizer, RagSequenceForGeneration

# Example script: build a tiny passage dataset, embed each passage with the
# pretrained RAG checkpoint, and persist both the dataset and a flat L2 FAISS
# index of the passage embeddings to disk.

# Example: create a dataset.
data = {"text": ["This is a sample text.", "Another sample text."]}
dataset = Dataset.from_dict(data)

# Save the dataset to disk.
dataset_path = "path/to/your/dataset"
dataset.save_to_disk(dataset_path)

# Create the FAISS index.
# Materialise the column as a plain list of strings for the tokenizer.
passages = list(dataset["text"])
tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")

inputs = tokenizer(passages, return_tensors="pt", padding=True, truncation=True)
attention_mask = inputs["attention_mask"]

# NOTE(review): the tokenizer and the encoder returned by `get_encoder()` may
# not share a vocabulary, and RagRetriever may expect passage embeddings from
# a dedicated context encoder stored in an "embeddings" column alongside
# "title"/"text" -- confirm against the transformers RAG documentation before
# using this index for retrieval.
with torch.no_grad():  # inference only; no autograd graph is needed
    # Pass tensors by keyword: handing the whole tokenizer output over
    # positionally would put it in the `input_ids` slot.
    encoder_output = model.get_encoder()(
        input_ids=inputs["input_ids"],
        attention_mask=attention_mask,
    )

hidden_states = encoder_output.last_hidden_state

# Mean-pool over real tokens only, so padding added to shorter passages does
# not dilute their embeddings. The clamp guards against division by zero for
# a fully masked row.
mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
pooled = (hidden_states * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

# FAISS expects a C-contiguous float32 matrix of shape (n_vectors, dim).
passage_embeddings = np.ascontiguousarray(
    pooled.detach().cpu().numpy(), dtype=np.float32
)

index = faiss.IndexFlatL2(passage_embeddings.shape[1])
index.add(passage_embeddings)

# Save the index to disk.
index_path = "path/to/your/index"
faiss.write_index(index, index_path)