import datetime
import os
import time

import faiss
from datasets import load_dataset
from huggingface_hub import HfApi, hf_hub_download, upload_file
from sentence_transformers import SentenceTransformer

from arxiv_stuff import ARXIV_CATEGORIES_FLAT, retrieve_arxiv_papers

# Dataset details
default_dataset_revision = "v1.0.0"
local_index_path = "arxiv_faiss_index.faiss"

HF_TOKEN = os.getenv("HF_TOKEN")


class DatasetManager:
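    """Manage the arXiv abstract-embeddings dataset and its FAISS index on the Hugging Face Hub.

    Loads the dataset at its latest timestamp-tagged revision, downloads or builds a
    FAISS index over the 'embedding' column, appends newly published arXiv papers,
    and pushes the updated data and index back to the Hub as a new revision.
    """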

    def __init__(
        self,
        dataset_name: str = "nomadicsynth/arxiv-dataset-abstract-embeddings",
        embedding_model: SentenceTransformer = None,
        hf_token: str = None,
    ):
        """
        Initialize the DatasetManager with the dataset name, Hugging Face token, and embedding model.
        Args:
            dataset_name (str): The name of the dataset on Hugging Face Hub.
            embedding_model (SentenceTransformer): The embedding model to use for generating embeddings.
            hf_token (str): The Hugging Face token for authentication.
        """
        self.dataset_name = dataset_name
        self.hf_token = hf_token if hf_token is not None else HF_TOKEN
        self.embedding_model = embedding_model

        if self.embedding_model is None:
            raise ValueError("Embedding model must be provided.")

        # Resolve the latest revision only after the token fallback so that
        # requests to the Hub are authenticated when HF_TOKEN is set.
        self.revision = self.get_latest_revision()

        self.dataset = None
        self.setup_dataset()

    def generate_revision_name(self):
        """Generate a timestamp-based revision name."""
        return datetime.datetime.now().strftime("v%Y-%m-%d")

    def get_latest_revision(self):
        """Return the latest timestamp-based revision."""
        api = HfApi()
        print(f"Fetching revisions for dataset: {self.dataset_name}")

        # List all tags in the repository
        refs = api.list_repo_refs(repo_id=self.dataset_name, repo_type="dataset", token=self.hf_token)
        tags = refs.tags

        print(f"Found tags: {[tag.name for tag in tags]}")

        # Filter tags with the "vYYYY-MM-DD" format. A plain isdigit() check would
        # reject the hyphens, so strip them before testing.
        timestamp_tags = [
            tag.name
            for tag in tags
            if tag.name.startswith("v") and len(tag.name) == 11 and tag.name[1:].replace("-", "").isdigit()
        ]

        if not timestamp_tags:
            print(f"No valid timestamp-based revisions found. Using `{default_dataset_revision}` as default.")
            return default_dataset_revision
        print(f"Valid timestamp-based revisions: {timestamp_tags}")

        # Sort and return the most recent tag
        latest_revision = sorted(timestamp_tags)[-1]
        print(f"Latest revision determined: {latest_revision}")
        return latest_revision

    def setup_dataset(self):
        """Load dataset with FAISS index."""
        print("Loading dataset from Hugging Face...")

        # Load dataset
        dataset = load_dataset(
            self.dataset_name,
            revision=self.revision,
            token=self.hf_token,
        )

        # Try to load the index from the Hub
        try:
            print("Downloading pre-built FAISS index...")
            index_path = hf_hub_download(
                repo_id=self.dataset_name,
                filename=local_index_path,
                revision=self.revision,
                token=self.hf_token,
                repo_type="dataset",
            )

            print("Loading pre-built FAISS index...")
            dataset["train"].load_faiss_index("embedding", index_path)
            print("Pre-built FAISS index loaded successfully")

        except Exception as e:
            print(f"Could not load pre-built index: {e}")
            print("Building new FAISS index...")

            # The index can only be built if the dataset has an 'embedding' column
            if "embedding" not in dataset["train"].column_names:
                print("Dataset doesn't have 'embedding' column, cannot create FAISS index")
                raise ValueError("Dataset doesn't have 'embedding' column")

            dataset["train"].add_faiss_index(
                column="embedding",
                metric_type=faiss.METRIC_INNER_PRODUCT,
                string_factory="HNSW,RFlat",  # Using reranking
            )

        print(f"Dataset loaded with {len(dataset['train'])} items and FAISS index ready")

        self.dataset = dataset
        return dataset

    def update_dataset_with_new_papers(self):
        """Fetch new papers from arXiv, ensure no duplicates, and update the dataset and FAISS index."""
        if self.dataset is None:
            self.setup_dataset()

        # Get the last update date from the dataset
        last_update_date = max(
            [datetime.datetime.strptime(row["update_date"], "%Y-%m-%d") for row in self.dataset["train"]],
            default=datetime.datetime.now() - datetime.timedelta(days=1),
        )

        # Initialize variables for iterative querying
        start = 0
        max_results_per_query = 100
        all_new_papers = []

        while True:
            # Fetch new papers from arXiv since the last update
            new_papers = retrieve_arxiv_papers(
                categories=list(ARXIV_CATEGORIES_FLAT.keys()),
                start_date=last_update_date,
                end_date=datetime.datetime.now(),
                start=start,
                max_results=max_results_per_query,
            )

            if not new_papers:
                break

            all_new_papers.extend(new_papers)
            start += max_results_per_query

            # Respect the rate limit of 1 query every 3 seconds
            time.sleep(3)

        # Filter out duplicates (column access is much faster than iterating rows)
        existing_ids = set(self.dataset["train"]["id"])
        unique_papers = [paper for paper in all_new_papers if paper["arxiv_id"] not in existing_ids]

        if not unique_papers:
            print("No new papers to add.")
            return

        # Add new papers to the dataset. Dataset.add_item returns a new dataset
        # rather than modifying in place, so reassign the split on each iteration.
        for paper in unique_papers:
            # SentenceTransformer provides encode(); convert to a plain list for Arrow serialization
            embedding = self.embedding_model.encode(paper["abstract"]).tolist()
            self.dataset["train"] = self.dataset["train"].add_item(
                {
                    "id": paper["arxiv_id"],
                    "title": paper["title"],
                    "authors": ", ".join(paper["authors"]),
                    "categories": ", ".join(paper["categories"]),
                    "abstract": paper["abstract"],
                    "update_date": paper["published_date"],
                    "embedding": embedding,
                }
            )

        # Save the updated dataset to the Hub with a new revision
        new_revision = self.generate_revision_name()
        self.dataset.push_to_hub(
            repo_id=self.dataset_name,
            token=self.hf_token,
            commit_message=f"Update dataset with new papers ({new_revision})",
            revision=new_revision,
        )

        # Update the FAISS index
        self.dataset["train"].add_faiss_index(
            column="embedding",
            metric_type=faiss.METRIC_INNER_PRODUCT,
            string_factory="HNSW,RFlat",
        )

        # Save the FAISS index to the Hub
        self.save_faiss_index_to_hub(new_revision)

        print(f"Dataset updated and saved to the Hub with revision {new_revision}.")

    def save_faiss_index_to_hub(self, revision: str):
        """Save the FAISS index to the Hub for easy access"""
        global local_index_path

        # 1. Save the index to a local file
        self.dataset["train"].save_faiss_index("embedding", local_index_path)
        print(f"FAISS index saved locally to {local_index_path}")

        # 2. Upload the index file to the Hub (recent versions of huggingface_hub
        #    return a CommitInfo object rather than a path)
        upload_file(
            path_or_fileobj=local_index_path,
            path_in_repo=local_index_path,  # Same name on the Hub
            repo_id=self.dataset_name,  # Use your dataset repo
            token=self.hf_token,
            repo_type="dataset",  # This is a dataset file
            revision=revision,  # Use the new revision
            commit_message=f"Add FAISS index for dataset revision {revision}",
        )

        print(f"FAISS index uploaded to Hub at {remote_path}")

        # Remove the local file. It's now stored on the Hub.
        os.remove(local_index_path)
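

# --- Usage sketch ------------------------------------------------------------
# A minimal example of driving DatasetManager when this file is run as a script.
# The model name "all-MiniLM-L6-v2" is an illustrative assumption; the embedding
# model must match the dimensionality of the dataset's 'embedding' column.
if __name__ == "__main__":
    model = SentenceTransformer("all-MiniLM-L6-v2")  # hypothetical model choice
    manager = DatasetManager(embedding_model=model, hf_token=HF_TOKEN)
    manager.update_dataset_with_new_papers()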