# inkling/dataset_utils.py
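"""Utilities for managing the arXiv abstract-embeddings dataset on the Hugging Face Hub.

The DatasetManager class loads the dataset with a FAISS index over its abstract
embeddings, fetches newly published papers from arXiv, and pushes updated dataset
revisions (plus a rebuilt index) back to the Hub.
"""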
import datetime
import os
import time

import faiss
from datasets import load_dataset
from huggingface_hub import HfApi, hf_hub_download, upload_file
from sentence_transformers import SentenceTransformer

from arxiv_stuff import ARXIV_CATEGORIES_FLAT, retrieve_arxiv_papers

# Dataset details
default_dataset_revision = "v1.0.0"
local_index_path = "arxiv_faiss_index.faiss"
HF_TOKEN = os.getenv("HF_TOKEN")


class DatasetManager:
def __init__(
self,
dataset_name: str = "nomadicsynth/arxiv-dataset-abstract-embeddings",
embedding_model: SentenceTransformer = None,
hf_token: str = None,
):
"""
        Initialize the DatasetManager with the dataset name, Hugging Face token, and embedding model.

        Args:
            dataset_name (str): The name of the dataset on Hugging Face Hub.
            embedding_model (SentenceTransformer): The embedding model to use for generating embeddings.
            hf_token (str): The Hugging Face token for authentication.
        """
        self.dataset_name = dataset_name
        # Resolve the token before any Hub calls are made, falling back to the environment
        self.hf_token = hf_token if hf_token is not None else HF_TOKEN
        if embedding_model is None:
            raise ValueError("Embedding model must be provided.")
        self.embedding_model = embedding_model
        self.revision = self.get_latest_revision()
        self.dataset = None
        self.setup_dataset()

    def generate_revision_name(self):
        """Generate a timestamp-based revision name, e.g. 'v2025-01-31'."""
        return datetime.datetime.now().strftime("v%Y-%m-%d")

    def get_latest_revision(self):
        """Return the most recent timestamp-based revision, or the default revision if none exist."""
api = HfApi()
print(f"Fetching revisions for dataset: {self.dataset_name}")
        # List all tags and branches in the repository
        refs = api.list_repo_refs(repo_id=self.dataset_name, repo_type="dataset", token=self.hf_token)
        # Revisions may live on tags or branches (`push_to_hub` below writes to a branch)
        revision_names = [ref.name for ref in list(refs.tags) + list(refs.branches)]
        print(f"Found revisions: {revision_names}")
        # Keep only revisions matching the "vYYYY-MM-DD" format
        timestamp_tags = []
        for name in revision_names:
            try:
                datetime.datetime.strptime(name, "v%Y-%m-%d")
                timestamp_tags.append(name)
            except ValueError:
                continue
if not timestamp_tags:
print(f"No valid timestamp-based revisions found. Using `{default_dataset_revision}` as default.")
return default_dataset_revision
print(f"Valid timestamp-based revisions: {timestamp_tags}")
# Sort and return the most recent tag
latest_revision = sorted(timestamp_tags)[-1]
print(f"Latest revision determined: {latest_revision}")
        return latest_revision

    def setup_dataset(self):
        """Load the dataset with a FAISS index."""
print("Loading dataset from Hugging Face...")
# Load dataset
dataset = load_dataset(
self.dataset_name,
revision=self.revision,
token=self.hf_token,
)
# Try to load the index from the Hub
try:
print("Downloading pre-built FAISS index...")
index_path = hf_hub_download(
repo_id=self.dataset_name,
filename=local_index_path,
revision=self.revision,
token=self.hf_token,
repo_type="dataset",
)
print("Loading pre-built FAISS index...")
dataset["train"].load_faiss_index("embedding", index_path)
print("Pre-built FAISS index loaded successfully")
except Exception as e:
print(f"Could not load pre-built index: {e}")
print("Building new FAISS index...")
            # The dataset needs an "embedding" column to build the index from
            if "embedding" not in dataset["train"].features:
                raise ValueError("Dataset doesn't have an 'embedding' column, cannot create a FAISS index")
dataset["train"].add_faiss_index(
column="embedding",
metric_type=faiss.METRIC_INNER_PRODUCT,
string_factory="HNSW,RFlat", # Using reranking
)
print(f"Dataset loaded with {len(dataset['train'])} items and FAISS index ready")
self.dataset = dataset
        return dataset

    def update_dataset_with_new_papers(self):
"""Fetch new papers from arXiv, ensure no duplicates, and update the dataset and FAISS index."""
if self.dataset is None:
self.setup_dataset()
# Get the last update date from the dataset
last_update_date = max(
[datetime.datetime.strptime(row["update_date"], "%Y-%m-%d") for row in self.dataset["train"]],
default=datetime.datetime.now() - datetime.timedelta(days=1),
)
# Initialize variables for iterative querying
start = 0
max_results_per_query = 100
all_new_papers = []
while True:
# Fetch new papers from arXiv since the last update
new_papers = retrieve_arxiv_papers(
categories=list(ARXIV_CATEGORIES_FLAT.keys()),
start_date=last_update_date,
end_date=datetime.datetime.now(),
start=start,
max_results=max_results_per_query,
)
if not new_papers:
break
all_new_papers.extend(new_papers)
start += max_results_per_query
# Respect the rate limit of 1 query every 3 seconds
time.sleep(3)
# Filter out duplicates
existing_ids = set(row["id"] for row in self.dataset["train"])
unique_papers = [paper for paper in all_new_papers if paper["arxiv_id"] not in existing_ids]
if not unique_papers:
print("No new papers to add.")
return
        # Drop the attached FAISS index (if any) before adding rows; it is rebuilt below
        if "embedding" in self.dataset["train"].list_indexes():
            self.dataset["train"].drop_index("embedding")
        # Add new papers to the dataset. `Dataset.add_item` returns a new dataset
        # rather than modifying it in place, so reassign the split on each iteration.
        for paper in unique_papers:
            embedding = self.embedding_model.encode(paper["abstract"])
            self.dataset["train"] = self.dataset["train"].add_item(
                {
                    "id": paper["arxiv_id"],
                    "title": paper["title"],
                    "authors": ", ".join(paper["authors"]),
                    "categories": ", ".join(paper["categories"]),
                    "abstract": paper["abstract"],
                    "update_date": paper["published_date"],
                    "embedding": embedding,
                }
            )
# Save the updated dataset to the Hub with a new revision
new_revision = self.generate_revision_name()
self.dataset.push_to_hub(
repo_id=self.dataset_name,
token=self.hf_token,
commit_message=f"Update dataset with new papers ({new_revision})",
revision=new_revision,
)
        # Rebuild the FAISS index so it covers the newly added papers
self.dataset["train"].add_faiss_index(
column="embedding",
metric_type=faiss.METRIC_INNER_PRODUCT,
string_factory="HNSW,RFlat",
)
# Save the FAISS index to the Hub
self.save_faiss_index_to_hub(new_revision)
        print(f"Dataset updated and saved to the Hub with revision {new_revision}.")

    def save_faiss_index_to_hub(self, revision: str):
        """Save the FAISS index to the Hub for easy access."""
# 1. Save the index to a local file
self.dataset["train"].save_faiss_index("embedding", local_index_path)
print(f"FAISS index saved locally to {local_index_path}")
# 2. Upload the index file to the Hub
remote_path = upload_file(
path_or_fileobj=local_index_path,
path_in_repo=local_index_path, # Same name on the Hub
repo_id=self.dataset_name, # Use your dataset repo
token=self.hf_token,
repo_type="dataset", # This is a dataset file
revision=revision, # Use the new revision
commit_message=f"Add FAISS index for dataset revision {revision}",
)
print(f"FAISS index uploaded to Hub at {remote_path}")
# Remove the local file. It's now stored on the Hub.
os.remove(local_index_path)
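

if __name__ == "__main__":
    # Minimal usage sketch. The model name "all-MiniLM-L6-v2" is an assumed
    # placeholder, not necessarily the embedding model this project actually
    # uses, and HF_TOKEN must be set in the environment for Hub access.
    model = SentenceTransformer("all-MiniLM-L6-v2")
    manager = DatasetManager(embedding_model=model)
    manager.update_dataset_with_new_papers()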