import datetime
import os
import time

import faiss
from datasets import load_dataset
from huggingface_hub import HfApi, hf_hub_download, upload_file
from sentence_transformers import SentenceTransformer

from arxiv_stuff import ARXIV_CATEGORIES_FLAT, retrieve_arxiv_papers

# Dataset details
default_dataset_revision = "v1.0.0"
local_index_path = "arxiv_faiss_index.faiss"

HF_TOKEN = os.getenv("HF_TOKEN")


class DatasetManager:
    def __init__(
        self,
        dataset_name: str = "nomadicsynth/arxiv-dataset-abstract-embeddings",
        embedding_model: SentenceTransformer = None,
        hf_token: str = None,
    ):
        """
        Initialize the DatasetManager with the dataset name, Hugging Face token, and embedding model.

        Args:
            dataset_name (str): The name of the dataset on Hugging Face Hub.
            embedding_model (SentenceTransformer): The embedding model to use for generating embeddings.
            hf_token (str): The Hugging Face token for authentication.
        """
        self.dataset_name = dataset_name
        self.embedding_model = embedding_model

        # Fall back to the HF_TOKEN environment variable before any Hub calls are made
        self.hf_token = hf_token if hf_token is not None else HF_TOKEN

        if self.embedding_model is None:
            raise ValueError("Embedding model must be provided.")

        self.revision = self.get_latest_revision()
        self.dataset = None
        self.setup_dataset()

    def generate_revision_name(self):
        """Generate a timestamp-based revision name."""
        return datetime.datetime.now().strftime("v%Y-%m-%d")

    def get_latest_revision(self):
        """Return the latest timestamp-based revision."""
        api = HfApi()
        print(f"Fetching revisions for dataset: {self.dataset_name}")

        # List all tags in the repository
        refs = api.list_repo_refs(repo_id=self.dataset_name, repo_type="dataset", token=self.hf_token)
        tags = refs.tags
        print(f"Found tags: {[tag.name for tag in tags]}")

        # Keep only tags that match the "vYYYY-MM-DD" revision format
        timestamp_tags = []
        for tag in tags:
            try:
                datetime.datetime.strptime(tag.name, "v%Y-%m-%d")
                timestamp_tags.append(tag.name)
            except ValueError:
                continue

        if not timestamp_tags:
            print(f"No valid timestamp-based revisions found. Using `{default_dataset_revision}` as default.")
            return default_dataset_revision
        print(f"Valid timestamp-based revisions: {timestamp_tags}")

        # Sort and return the most recent tag ("vYYYY-MM-DD" sorts lexicographically by date)
        latest_revision = sorted(timestamp_tags)[-1]
        print(f"Latest revision determined: {latest_revision}")
        return latest_revision

    def setup_dataset(self):
        """Load the dataset and attach a FAISS index over the "embedding" column."""
        print("Loading dataset from Hugging Face...")

        # Load dataset
        dataset = load_dataset(
            self.dataset_name,
            revision=self.revision,
            token=self.hf_token,
        )

        # Try to load a pre-built index from the Hub
        try:
            print("Downloading pre-built FAISS index...")
            index_path = hf_hub_download(
                repo_id=self.dataset_name,
                filename=local_index_path,
                revision=self.revision,
                token=self.hf_token,
                repo_type="dataset",
            )

            print("Loading pre-built FAISS index...")
            dataset["train"].load_faiss_index("embedding", index_path)
            print("Pre-built FAISS index loaded successfully")
        except Exception as e:
            print(f"Could not load pre-built index: {e}")
            print("Building new FAISS index...")

            # Build the index locally if no pre-built index is available
            if "embedding" not in dataset["train"].features:
                raise ValueError("Dataset doesn't have an 'embedding' column, cannot create FAISS index")

            dataset["train"].add_faiss_index(
                column="embedding",
                metric_type=faiss.METRIC_INNER_PRODUCT,
                string_factory="HNSW,RFlat",  # HNSW search with exact re-ranking
            )

        print(f"Dataset loaded with {len(dataset['train'])} items and FAISS index ready")
        self.dataset = dataset
        return dataset

    def update_dataset_with_new_papers(self):
        """Fetch new papers from arXiv, skip duplicates, and update the dataset and FAISS index."""
        if self.dataset is None:
            self.setup_dataset()

        # Get the last update date from the dataset, defaulting to yesterday
        last_update_date = max(
            (datetime.datetime.strptime(row["update_date"], "%Y-%m-%d") for row in self.dataset["train"]),
            default=datetime.datetime.now() - datetime.timedelta(days=1),
        )

        # Page through the arXiv results until an empty batch comes back
        start = 0
        max_results_per_query = 100
        all_new_papers = []

        while True:
            # Fetch new papers from arXiv since the last update
            new_papers = retrieve_arxiv_papers(
                categories=list(ARXIV_CATEGORIES_FLAT.keys()),
                start_date=last_update_date,
                end_date=datetime.datetime.now(),
                start=start,
                max_results=max_results_per_query,
            )
            if not new_papers:
                break

            all_new_papers.extend(new_papers)
            start += max_results_per_query

            # Respect the arXiv rate limit of one query every 3 seconds
            time.sleep(3)

        # Filter out papers that are already in the dataset
        existing_ids = set(row["id"] for row in self.dataset["train"])
        unique_papers = [paper for paper in all_new_papers if paper["arxiv_id"] not in existing_ids]

        if not unique_papers:
            print("No new papers to add.")
            return

        # Drop the attached FAISS index before modifying the dataset; it is rebuilt below
        self.dataset["train"].drop_index("embedding")

        # Add new papers to the dataset. `Dataset.add_item` returns a new dataset,
        # so reassign the split on each iteration.
        for paper in unique_papers:
            embedding = self.embedding_model.encode(paper["abstract"]).tolist()
            self.dataset["train"] = self.dataset["train"].add_item(
                {
                    "id": paper["arxiv_id"],
                    "title": paper["title"],
                    "authors": ", ".join(paper["authors"]),
                    "categories": ", ".join(paper["categories"]),
                    "abstract": paper["abstract"],
                    "update_date": paper["published_date"],
                    "embedding": embedding,
                }
            )

        # Save the updated dataset to the Hub with a new revision
        new_revision = self.generate_revision_name()
        self.dataset.push_to_hub(
            repo_id=self.dataset_name,
            token=self.hf_token,
            commit_message=f"Update dataset with new papers ({new_revision})",
            revision=new_revision,
        )

        # Rebuild the FAISS index over the updated embedding column
        self.dataset["train"].add_faiss_index(
            column="embedding",
            metric_type=faiss.METRIC_INNER_PRODUCT,
            string_factory="HNSW,RFlat",
        )

        # Save the FAISS index to the Hub
        self.save_faiss_index_to_hub(new_revision)

        print(f"Dataset updated and saved to the Hub with revision {new_revision}.")

    def save_faiss_index_to_hub(self, revision: str):
        """Save the FAISS index to the Hub for easy access."""
        # 1. Save the index to a local file
        self.dataset["train"].save_faiss_index("embedding", local_index_path)
        print(f"FAISS index saved locally to {local_index_path}")

        # 2. Upload the index file to the Hub
        commit_info = upload_file(
            path_or_fileobj=local_index_path,
            path_in_repo=local_index_path,  # Same filename on the Hub
            repo_id=self.dataset_name,  # The dataset repo
            token=self.hf_token,
            repo_type="dataset",  # This is a dataset file
            revision=revision,  # Use the new revision
            commit_message=f"Add FAISS index for dataset revision {revision}",
        )
        print(f"FAISS index uploaded to the Hub: {commit_info}")

        # 3. Remove the local file now that it is stored on the Hub
        os.remove(local_index_path)
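

# Example usage (illustrative sketch, not part of the original module). The model name
# below is an assumption: use whichever SentenceTransformer produced the dataset's
# "embedding" column so that query vectors and stored vectors share the same space.
if __name__ == "__main__":
    embedding_model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed model name
    manager = DatasetManager(embedding_model=embedding_model, hf_token=HF_TOKEN)

    # Pull in newly published papers and push the refreshed dataset and index to the Hub
    manager.update_dataset_with_new_papers()

    # Query the FAISS index for abstracts similar to a free-text question
    query_vec = embedding_model.encode("transformer architectures for long documents")
    results = manager.dataset["train"].get_nearest_examples("embedding", query_vec, k=5)
    for score, title in zip(results.scores, results.examples["title"]):
        print(f"{score:.3f}  {title}")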