import datetime
import os
import re
import time
import faiss
from datasets import load_dataset
from huggingface_hub import HfApi, hf_hub_download, upload_file
from sentence_transformers import SentenceTransformer
from arxiv_stuff import ARXIV_CATEGORIES_FLAT, retrieve_arxiv_papers
# Dataset details
default_dataset_revision = "v1.0.0"
local_index_path = "arxiv_faiss_index.faiss"
HF_TOKEN = os.getenv("HF_TOKEN")
class DatasetManager:
def __init__(
self,
dataset_name: str = "nomadicsynth/arxiv-dataset-abstract-embeddings",
embedding_model: SentenceTransformer = None,
hf_token: str = None,
):
"""
Initialize the DatasetManager with the dataset name, Hugging Face token, and embedding model.
Args:
dataset_name (str): The name of the dataset on Hugging Face Hub.
embedding_model (SentenceTransformer): The embedding model to use for generating embeddings.
hf_token (str): The Hugging Face token for authentication.
"""
        self.dataset_name = dataset_name
        # Fall back to the HF_TOKEN environment variable before any Hub API calls are made
        self.hf_token = hf_token if hf_token is not None else HF_TOKEN
        self.embedding_model = embedding_model
        if self.embedding_model is None:
            raise ValueError("Embedding model must be provided.")
        self.revision = self.get_latest_revision()
        self.dataset = None
        self.setup_dataset()
def generate_revision_name(self):
"""Generate a timestamp-based revision name."""
return datetime.datetime.now().strftime("v%Y-%m-%d")
def get_latest_revision(self):
"""Return the latest timestamp-based revision."""
global default_dataset_revision
api = HfApi()
print(f"Fetching revisions for dataset: {self.dataset_name}")
# List all tags in the repository
refs = api.list_repo_refs(repo_id=self.dataset_name, repo_type="dataset", token=self.hf_token)
tags = refs.tags
print(f"Found tags: {[tag.name for tag in tags]}")
        # Keep only tags matching the "vYYYY-MM-DD" format (the hyphens mean a plain isdigit() check would reject them)
        timestamp_tags = [tag.name for tag in tags if re.fullmatch(r"v\d{4}-\d{2}-\d{2}", tag.name)]
if not timestamp_tags:
print(f"No valid timestamp-based revisions found. Using `{default_dataset_revision}` as default.")
return default_dataset_revision
print(f"Valid timestamp-based revisions: {timestamp_tags}")
# Sort and return the most recent tag
latest_revision = sorted(timestamp_tags)[-1]
print(f"Latest revision determined: {latest_revision}")
return latest_revision
def setup_dataset(self):
"""Load dataset with FAISS index."""
print("Loading dataset from Hugging Face...")
# Load dataset
dataset = load_dataset(
self.dataset_name,
revision=self.revision,
token=self.hf_token,
)
# Try to load the index from the Hub
try:
print("Downloading pre-built FAISS index...")
index_path = hf_hub_download(
repo_id=self.dataset_name,
filename=local_index_path,
revision=self.revision,
token=self.hf_token,
repo_type="dataset",
)
print("Loading pre-built FAISS index...")
dataset["train"].load_faiss_index("embedding", index_path)
print("Pre-built FAISS index loaded successfully")
except Exception as e:
print(f"Could not load pre-built index: {e}")
print("Building new FAISS index...")
# Add FAISS index if it doesn't exist
if not dataset["train"].features.get("embedding"):
print("Dataset doesn't have 'embedding' column, cannot create FAISS index")
raise ValueError("Dataset doesn't have 'embedding' column")
dataset["train"].add_faiss_index(
column="embedding",
metric_type=faiss.METRIC_INNER_PRODUCT,
string_factory="HNSW,RFlat", # Using reranking
)
print(f"Dataset loaded with {len(dataset['train'])} items and FAISS index ready")
self.dataset = dataset
return dataset
def update_dataset_with_new_papers(self):
"""Fetch new papers from arXiv, ensure no duplicates, and update the dataset and FAISS index."""
if self.dataset is None:
self.setup_dataset()
# Get the last update date from the dataset
last_update_date = max(
[datetime.datetime.strptime(row["update_date"], "%Y-%m-%d") for row in self.dataset["train"]],
default=datetime.datetime.now() - datetime.timedelta(days=1),
)
# Initialize variables for iterative querying
start = 0
max_results_per_query = 100
all_new_papers = []
while True:
# Fetch new papers from arXiv since the last update
new_papers = retrieve_arxiv_papers(
categories=list(ARXIV_CATEGORIES_FLAT.keys()),
start_date=last_update_date,
end_date=datetime.datetime.now(),
start=start,
max_results=max_results_per_query,
)
if not new_papers:
break
all_new_papers.extend(new_papers)
start += max_results_per_query
# Respect the rate limit of 1 query every 3 seconds
time.sleep(3)
# Filter out duplicates
existing_ids = set(row["id"] for row in self.dataset["train"])
unique_papers = [paper for paper in all_new_papers if paper["arxiv_id"] not in existing_ids]
if not unique_papers:
print("No new papers to add.")
return
        # Add new papers to the dataset
        for paper in unique_papers:
            # SentenceTransformer exposes `encode` for turning text into an embedding vector
            embedding = self.embedding_model.encode(paper["abstract"])
            # Dataset.add_item returns a new Dataset rather than mutating in place, so reassign the split
            self.dataset["train"] = self.dataset["train"].add_item(
                {
                    "id": paper["arxiv_id"],
                    "title": paper["title"],
                    "authors": ", ".join(paper["authors"]),
                    "categories": ", ".join(paper["categories"]),
                    "abstract": paper["abstract"],
                    "update_date": paper["published_date"],
                    "embedding": embedding,
                }
            )
# Save the updated dataset to the Hub with a new revision
new_revision = self.generate_revision_name()
self.dataset.push_to_hub(
repo_id=self.dataset_name,
token=self.hf_token,
commit_message=f"Update dataset with new papers ({new_revision})",
revision=new_revision,
)
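        # Note: `revision` here is a branch name on the Hub, while get_latest_revision scans *tags*,
        # so a matching tag may need to be created for this revision to be picked up on the next run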
# Update the FAISS index
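        # Rebuild the index over the full split so the newly added rows become searchable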
self.dataset["train"].add_faiss_index(
column="embedding",
metric_type=faiss.METRIC_INNER_PRODUCT,
string_factory="HNSW,RFlat",
)
# Save the FAISS index to the Hub
self.save_faiss_index_to_hub(new_revision)
print(f"Dataset updated and saved to the Hub with revision {new_revision}.")
def save_faiss_index_to_hub(self, revision: str):
"""Save the FAISS index to the Hub for easy access"""
global local_index_path
# 1. Save the index to a local file
self.dataset["train"].save_faiss_index("embedding", local_index_path)
print(f"FAISS index saved locally to {local_index_path}")
# 2. Upload the index file to the Hub
remote_path = upload_file(
path_or_fileobj=local_index_path,
path_in_repo=local_index_path, # Same name on the Hub
repo_id=self.dataset_name, # Use your dataset repo
token=self.hf_token,
repo_type="dataset", # This is a dataset file
revision=revision, # Use the new revision
commit_message=f"Add FAISS index for dataset revision {revision}",
)
print(f"FAISS index uploaded to Hub at {remote_path}")
# Remove the local file. It's now stored on the Hub.
os.remove(local_index_path)
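

if __name__ == "__main__":
    # Minimal usage sketch, not part of the Space itself. It assumes HF_TOKEN is set in the
    # environment and that "all-MiniLM-L6-v2" stands in for whatever embedding model the app
    # actually uses; swap in the real model name as needed.
    model = SentenceTransformer("all-MiniLM-L6-v2")
    manager = DatasetManager(embedding_model=model, hf_token=HF_TOKEN)
    manager.update_dataset_with_new_papers()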