import datetime
import os
import re
import time

import faiss
from datasets import load_dataset
from huggingface_hub import HfApi, hf_hub_download, upload_file
from sentence_transformers import SentenceTransformer

from arxiv_stuff import ARXIV_CATEGORIES_FLAT, retrieve_arxiv_papers

# Dataset details
default_dataset_revision = "v1.0.0"
local_index_path = "arxiv_faiss_index.faiss"

HF_TOKEN = os.getenv("HF_TOKEN")


class DatasetManager:
    def __init__(
        self,
        dataset_name: str = "nomadicsynth/arxiv-dataset-abstract-embeddings",
        embedding_model: SentenceTransformer = None,
        hf_token: str = None,
    ):
        """
        Initialize the DatasetManager with the dataset name, Hugging Face token, and embedding model.

        Args:
            dataset_name (str): The name of the dataset on Hugging Face Hub.
            embedding_model (SentenceTransformer): The embedding model to use for generating embeddings.
            hf_token (str): The Hugging Face token for authentication.
        """
        self.dataset_name = dataset_name
        self.hf_token = hf_token if hf_token is not None else HF_TOKEN
        self.embedding_model = embedding_model

        # Validate inputs and resolve the token fallback before any Hub calls,
        # so get_latest_revision() authenticates with the right credentials
        if self.embedding_model is None:
            raise ValueError("Embedding model must be provided.")
        self.revision = self.get_latest_revision()

        self.dataset = None
        self.setup_dataset()

    def generate_revision_name(self):
        """Generate a timestamp-based revision name."""
        return datetime.datetime.now().strftime("v%Y-%m-%d")

    def get_latest_revision(self):
        """Return the latest timestamp-based revision."""
        api = HfApi()
        print(f"Fetching revisions for dataset: {self.dataset_name}")

        # List all tags in the repository
        refs = api.list_repo_refs(repo_id=self.dataset_name, repo_type="dataset", token=self.hf_token)
        tags = refs.tags
        print(f"Found tags: {[tag.name for tag in tags]}")

        # Keep only tags in the "vYYYY-MM-DD" format
        timestamp_tags = [tag.name for tag in tags if re.fullmatch(r"v\d{4}-\d{2}-\d{2}", tag.name)]

        if not timestamp_tags:
            print(f"No valid timestamp-based revisions found. Using `{default_dataset_revision}` as default.")
            return default_dataset_revision
        print(f"Valid timestamp-based revisions: {timestamp_tags}")

        # ISO dates sort lexicographically, so the latest tag is the maximum
        latest_revision = sorted(timestamp_tags)[-1]
        print(f"Latest revision determined: {latest_revision}")
        return latest_revision

    def setup_dataset(self):
        """Load the dataset and attach a FAISS index over the abstract embeddings."""
        print("Loading dataset from Hugging Face...")

        # Load dataset
        dataset = load_dataset(
            self.dataset_name,
            revision=self.revision,
            token=self.hf_token,
        )

        # Try to load the pre-built index from the Hub
        try:
            print("Downloading pre-built FAISS index...")
            index_path = hf_hub_download(
                repo_id=self.dataset_name,
                filename=local_index_path,
                revision=self.revision,
                token=self.hf_token,
                repo_type="dataset",
            )
            print("Loading pre-built FAISS index...")
            dataset["train"].load_faiss_index("embedding", index_path)
            print("Pre-built FAISS index loaded successfully")
        except Exception as e:
            print(f"Could not load pre-built index: {e}")
            print("Building new FAISS index...")

            # The index can only be built if the embeddings already exist
            if "embedding" not in dataset["train"].column_names:
                print("Dataset doesn't have 'embedding' column, cannot create FAISS index")
                raise ValueError("Dataset doesn't have 'embedding' column")

            dataset["train"].add_faiss_index(
                column="embedding",
                metric_type=faiss.METRIC_INNER_PRODUCT,
                string_factory="HNSW,RFlat",  # HNSW search with exact re-ranking
            )

        print(f"Dataset loaded with {len(dataset['train'])} items and FAISS index ready")
        self.dataset = dataset
        return dataset
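
    # Illustrative sketch (not part of the original class): querying the attached
    # index with the standard `datasets` API. The query text below is made up.
    #
    #     query = self.embedding_model.encode("sparse attention for long documents")
    #     scores, hits = self.dataset["train"].get_nearest_examples("embedding", query, k=5)
    #     for title in hits["title"]:
    #         print(title)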

    def update_dataset_with_new_papers(self):
        """Fetch new papers from arXiv, ensure no duplicates, and update the dataset and FAISS index."""
        if self.dataset is None:
            self.setup_dataset()

        # Get the last update date from the dataset
        last_update_date = max(
            [datetime.datetime.strptime(row["update_date"], "%Y-%m-%d") for row in self.dataset["train"]],
            default=datetime.datetime.now() - datetime.timedelta(days=1),
        )

        # Page through the arXiv results until no more papers are returned
        start = 0
        max_results_per_query = 100
        all_new_papers = []

        while True:
            # Fetch new papers from arXiv since the last update
            new_papers = retrieve_arxiv_papers(
                categories=list(ARXIV_CATEGORIES_FLAT.keys()),
                start_date=last_update_date,
                end_date=datetime.datetime.now(),
                start=start,
                max_results=max_results_per_query,
            )
            if not new_papers:
                break

            all_new_papers.extend(new_papers)
            start += max_results_per_query

            # Respect the arXiv API rate limit of 1 query every 3 seconds
            time.sleep(3)

        # Filter out papers that are already in the dataset
        existing_ids = set(row["id"] for row in self.dataset["train"])
        unique_papers = [paper for paper in all_new_papers if paper["arxiv_id"] not in existing_ids]

        if not unique_papers:
            print("No new papers to add.")
            return

        # `datasets` does not allow transforming or pushing a dataset while an
        # index is attached, so drop the FAISS index first; it is rebuilt below
        if "embedding" in self.dataset["train"].list_indexes():
            self.dataset["train"].drop_index("embedding")

        # Add new papers to the dataset. `add_item` returns a new dataset
        # rather than mutating in place, so the split must be reassigned.
        for paper in unique_papers:
            embedding = self.embedding_model.encode(paper["abstract"])
            self.dataset["train"] = self.dataset["train"].add_item(
                {
                    "id": paper["arxiv_id"],
                    "title": paper["title"],
                    "authors": ", ".join(paper["authors"]),
                    "categories": ", ".join(paper["categories"]),
                    "abstract": paper["abstract"],
                    "update_date": paper["published_date"],
                    "embedding": embedding,
                }
            )

        # Save the updated dataset to the Hub with a new revision
        new_revision = self.generate_revision_name()
        self.dataset.push_to_hub(
            repo_id=self.dataset_name,
            token=self.hf_token,
            commit_message=f"Update dataset with new papers ({new_revision})",
            revision=new_revision,
        )

        # Rebuild the FAISS index over the updated embeddings
        self.dataset["train"].add_faiss_index(
            column="embedding",
            metric_type=faiss.METRIC_INNER_PRODUCT,
            string_factory="HNSW,RFlat",
        )

        # Save the FAISS index to the Hub
        self.save_faiss_index_to_hub(new_revision)

        print(f"Dataset updated and saved to the Hub with revision {new_revision}.")

    def save_faiss_index_to_hub(self, revision: str):
        """Save the FAISS index to the Hub for easy access."""
        # 1. Save the index to a local file
        self.dataset["train"].save_faiss_index("embedding", local_index_path)
        print(f"FAISS index saved locally to {local_index_path}")

        # 2. Upload the index file to the Hub
        remote_path = upload_file(
            path_or_fileobj=local_index_path,
            path_in_repo=local_index_path,  # Same name on the Hub
            repo_id=self.dataset_name,  # Use your dataset repo
            token=self.hf_token,
            repo_type="dataset",  # This is a dataset file
            revision=revision,  # Use the new revision
            commit_message=f"Add FAISS index for dataset revision {revision}",
        )
        print(f"FAISS index uploaded to Hub at {remote_path}")

        # 3. Remove the local copy; it is now stored on the Hub
        os.remove(local_index_path)
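

if __name__ == "__main__":
    # Minimal usage sketch. The model name below is an assumption for
    # illustration; the Space may use a different embedding model.
    model = SentenceTransformer("all-MiniLM-L6-v2")
    manager = DatasetManager(embedding_model=model, hf_token=HF_TOKEN)

    # Fetch papers published since the last revision and push a new revision
    manager.update_dataset_with_new_papers()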