nomadicsynth committed
Commit 261056f · 1 Parent(s): 4236eec

Add DatasetManager for handling dataset operations and update dataset with new papers

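In short, this commit moves the dataset and FAISS-index handling out of module-level functions in `app.py` and into a new `DatasetManager` class in `dataset_utils.py`, with a small `update_dataset.py` script to drive the refresh. A minimal usage sketch of the new wiring (illustrative only, not part of the diff; the model and dataset names are the ones hard-coded in `update_dataset.py` and `dataset_utils.py`):

```python
# Sketch of the new entry-point wiring introduced by this commit.
import os

from app import init_embedding_model
from dataset_utils import DatasetManager

hf_token = os.getenv("HF_TOKEN")

# init_embedding_model now takes the token explicitly and returns the model.
embedding_model = init_embedding_model(
    "nomadicsynth/research-compass-arxiv-abstracts-embedding-model",
    hf_token=hf_token,
)

# DatasetManager loads the latest dataset revision and attaches (or builds) the FAISS index.
manager = DatasetManager(
    dataset_name="nomadicsynth/arxiv-dataset-abstract-embeddings",
    embedding_model=embedding_model,
    hf_token=hf_token,
)

# Fetch newly published arXiv papers, dedupe, re-index, and push a new dataset revision.
manager.update_dataset_with_new_papers()
```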
Files changed (3)
  1. app.py +13 -79
  2. dataset_utils.py +221 -0
  3. update_dataset.py +21 -0
app.py CHANGED
@@ -13,6 +13,7 @@ from huggingface_hub import upload_file
 from sentence_transformers import SentenceTransformer
 
 from arxiv_stuff import ARXIV_CATEGORIES_FLAT
+from dataset_utils import DatasetManager
 
 # Get HF_TOKEN from environment variables
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -56,88 +57,17 @@ embedding_model = None
 reasoning_model = None
 
 
-def save_faiss_index_to_hub():
-    """Save the FAISS index to the Hub for easy access"""
-    global dataset, local_index_path
-    # 1. Save the index to a local file
-    dataset["train"].save_faiss_index("embedding", local_index_path)
-    print(f"FAISS index saved locally to {local_index_path}")
-
-    # 2. Upload the index file to the Hub
-    remote_path = upload_file(
-        path_or_fileobj=local_index_path,
-        path_in_repo=local_index_path,  # Same name on the Hub
-        repo_id=dataset_name,  # Use your dataset repo
-        token=HF_TOKEN,
-        repo_type="dataset",  # This is a dataset file
-        revision=dataset_revision,  # Use the same revision as the dataset
-        commit_message="Add FAISS index",  # Commit message
-    )
-
-    print(f"FAISS index uploaded to Hub at {remote_path}")
-
-    # Remove the local file. It's now stored on the Hub.
-    os.remove(local_index_path)
-
-
-def setup_dataset():
-    """Load dataset with FAISS index"""
-    global dataset
-    print("Loading dataset from Hugging Face...")
-
-    # Load dataset
-    dataset = load_dataset(
-        dataset_name,
-        revision=dataset_revision,
-    )
-
-    # Try to load the index from the Hub
-    try:
-        print("Downloading pre-built FAISS index...")
-        index_path = hf_hub_download(
-            repo_id=dataset_name,
-            filename="arxiv_faiss_index.faiss",
-            revision=dataset_revision,
-            token=HF_TOKEN,
-            repo_type="dataset",
-        )
-
-        print("Loading pre-built FAISS index...")
-        dataset["train"].load_faiss_index("embedding", index_path)
-        print("Pre-built FAISS index loaded successfully")
-
-    except Exception as e:
-        print(f"Could not load pre-built index: {e}")
-        print("Building new FAISS index...")
-
-        # Add FAISS index if it doesn't exist
-        if not dataset["train"].features.get("embedding"):
-            print("Dataset doesn't have 'embedding' column, cannot create FAISS index")
-            raise ValueError("Dataset doesn't have 'embedding' column")
-
-        dataset["train"].add_faiss_index(
-            column="embedding",
-            metric_type=faiss.METRIC_INNER_PRODUCT,
-            string_factory="HNSW,RFlat",  # Using reranking
-        )
-
-        # Save the FAISS index to the Hub
-        save_faiss_index_to_hub()
-
-    print(f"Dataset loaded with {len(dataset['train'])} items and FAISS index ready")
-
-
-def init_embedding_model(model_name_or_path: str, model_revision: str = None) -> SentenceTransformer:
-    global embedding_model
-
+def init_embedding_model(model_name_or_path: str, model_revision: str = None, hf_token: str = None) -> SentenceTransformer:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     embedding_model = SentenceTransformer(
         model_name_or_path,
         revision=model_revision,
-        token=HF_TOKEN,
+        token=hf_token,
         device=device,
     )
 
+    return embedding_model
+
 
 @spaces.GPU
 def embed_text(text: str | list[str]) -> torch.Tensor:
@@ -798,14 +728,18 @@ def create_interface():
 
 
 if __name__ == "__main__":
-    # Load dataset with FAISS index
-    setup_dataset()
-
     # Initialize the embedding model
-    init_embedding_model(embedding_model_name, embedding_model_revision)
+    embedding_model = init_embedding_model(embedding_model_name, embedding_model_revision)
 
     # Initialize the reasoning model
     reasoning_model = init_reasoning_model(reasoning_model_id)
 
+    # Load dataset with FAISS index
+    dataset = DatasetManager(
+        dataset_name=dataset_name,
+        hf_token=HF_TOKEN,
+        embedding_model=embedding_model,
+    )
+
     demo = create_interface()
     demo.queue(api_open=False).launch(ssr_mode=False, show_api=True)
dataset_utils.py ADDED
@@ -0,0 +1,221 @@
+from datasets import load_dataset
+from huggingface_hub import HfApi, hf_hub_download
+import faiss
+import os
+import datetime
+import time
+
+# from embedding_model import EmbeddingModel
+# from app import EmbeddingModel
+from arxiv_stuff import retrieve_arxiv_papers, ARXIV_CATEGORIES_FLAT
+from sentence_transformers import SentenceTransformer
+
+# Dataset details
+dataset_name = "nomadicsynth/arxiv-dataset-abstract-embeddings"
+HF_TOKEN = os.getenv("HF_TOKEN")
+
+
+class DatasetManager:
+
+    def __init__(
+        self,
+        dataset_name: str,
+        embedding_model: SentenceTransformer,
+        hf_token: str = None,
+    ):
+        """
+        Initialize the DatasetManager with the dataset name, Hugging Face token, and embedding model.
+        Args:
+            dataset_name (str): The name of the dataset on Hugging Face Hub.
+            embedding_model (SentenceTransformer): The embedding model to use for generating embeddings.
+            hf_token (str): The Hugging Face token for authentication.
+        """
+        self.dataset_name = dataset_name
+        self.hf_token = hf_token
+        self.embedding_model = embedding_model
+        self.dataset = None
+
+        self.setup_dataset()
+
+    def get_revision_name(self):
+        """Generate a timestamp-based revision name."""
+        return datetime.datetime.now().strftime("v%Y-%m-%d")
+
+    def get_latest_revision(self):
+        """Return the latest timestamp-based revision."""
+        api = HfApi()
+        print(f"Fetching revisions for dataset: {self.dataset_name}")
+
+        # List all tags in the repository
+        refs = api.list_repo_refs(repo_id=self.dataset_name, repo_type="dataset", token=self.hf_token)
+        tags = refs.tags
+
+        print(f"Found tags: {[tag.name for tag in tags]}")
+
+        # Filter tags with the "vYYYY-MM-DD" format
+        timestamp_tags = [
+            tag.name for tag in tags if tag.name.startswith("v") and len(tag.name) == 11 and tag.name[1:11].isdigit()
+        ]
+
+        if not timestamp_tags:
+            print("No valid timestamp-based revisions found. Using `v1.0.0` as default.")
+            return "v1.0.0"
+        print(f"Valid timestamp-based revisions: {timestamp_tags}")
+
+        # Sort and return the most recent tag
+        latest_revision = sorted(timestamp_tags)[-1]
+        print(f"Latest revision determined: {latest_revision}")
+        return latest_revision
+
+    def setup_dataset(self):
+        """Load dataset with FAISS index."""
+        print("Loading dataset from Hugging Face...")
+
+        # Fetch the latest revision dynamically
+        latest_revision = self.get_latest_revision()
+
+        # Load dataset
+        dataset = load_dataset(
+            dataset_name,
+            revision=latest_revision,
+        )
+
+        # Try to load the index from the Hub
+        try:
+            print("Downloading pre-built FAISS index...")
+            index_path = hf_hub_download(
+                repo_id=dataset_name,
+                filename="arxiv_faiss_index.faiss",
+                revision=latest_revision,
+                token=self.hf_token,
+                repo_type="dataset",
+            )
+
+            print("Loading pre-built FAISS index...")
+            dataset["train"].load_faiss_index("embedding", index_path)
+            print("Pre-built FAISS index loaded successfully")
+
+        except Exception as e:
+            print(f"Could not load pre-built index: {e}")
+            print("Building new FAISS index...")
+
+            # Add FAISS index if it doesn't exist
+            if not dataset["train"].features.get("embedding"):
+                print("Dataset doesn't have 'embedding' column, cannot create FAISS index")
+                raise ValueError("Dataset doesn't have 'embedding' column")
+
+            dataset["train"].add_faiss_index(
+                column="embedding",
+                metric_type=faiss.METRIC_INNER_PRODUCT,
+                string_factory="HNSW,RFlat",  # Using reranking
+            )
+
+        print(f"Dataset loaded with {len(dataset['train'])} items and FAISS index ready")
+
+        self.dataset = dataset
+        return dataset
+
+    def update_dataset_with_new_papers(self):
+        """Fetch new papers from arXiv, ensure no duplicates, and update the dataset and FAISS index."""
+        if self.dataset is None:
+            self.setup_dataset()
+
+        # Get the last update date from the dataset
+        last_update_date = max(
+            [datetime.datetime.strptime(row["update_date"], "%Y-%m-%d") for row in self.dataset["train"]],
+            default=datetime.datetime.now() - datetime.timedelta(days=1),
+        )
+
+        # Initialize variables for iterative querying
+        start = 0
+        max_results_per_query = 100
+        all_new_papers = []
+
+        while True:
+            # Fetch new papers from arXiv since the last update
+            new_papers = retrieve_arxiv_papers(
+                categories=list(ARXIV_CATEGORIES_FLAT.keys()),
+                start_date=last_update_date,
+                end_date=datetime.datetime.now(),
+                start=start,
+                max_results=max_results_per_query,
+            )
+
+            if not new_papers:
+                break
+
+            all_new_papers.extend(new_papers)
+            start += max_results_per_query
+
+            # Respect the rate limit of 1 query every 3 seconds
+            time.sleep(3)
+
+        # Filter out duplicates
+        existing_ids = set(row["id"] for row in self.dataset["train"])
+        unique_papers = [paper for paper in all_new_papers if paper["arxiv_id"] not in existing_ids]
+
+        if not unique_papers:
+            print("No new papers to add.")
+            return
+
+        # Add new papers to the dataset
+        for paper in unique_papers:
+            embedding = self.embedding_model.embed_text(paper["abstract"])
+            self.dataset["train"].add_item(
+                {
+                    "id": paper["arxiv_id"],
+                    "title": paper["title"],
+                    "authors": ", ".join(paper["authors"]),
+                    "categories": ", ".join(paper["categories"]),
+                    "abstract": paper["abstract"],
+                    "update_date": paper["published_date"],
+                    "embedding": embedding,
+                }
+            )
+
+        # Update the FAISS index
+        self.dataset["train"].add_faiss_index(
+            column="embedding",
+            metric_type=faiss.METRIC_INNER_PRODUCT,
+            string_factory="HNSW,RFlat",
+        )
+
+        # Save the FAISS index to the Hub
+        self.save_faiss_index_to_hub()
+
+        # Save the updated dataset to the Hub with a new revision
+        new_revision = self.get_revision_name()
+        self.dataset.push_to_hub(
+            repo_id=self.dataset_name,
+            token=self.hf_token,
+            commit_message=f"Update dataset with new papers ({new_revision})",
+            revision=new_revision,
+        )
+
+        print(f"Dataset updated and saved to the Hub with revision {new_revision}.")
+
+    def save_faiss_index_to_hub(self):
+        """Save the FAISS index to the Hub for easy access"""
+        local_index_path = "arxiv_faiss_index.faiss"
+
+        # 1. Save the index to a local file
+        self.dataset["train"].save_faiss_index("embedding", local_index_path)
+        print(f"FAISS index saved locally to {local_index_path}")
+
+        # 2. Upload the index file to the Hub
+        from huggingface_hub import upload_file
+
+        remote_path = upload_file(
+            path_or_fileobj=local_index_path,
+            path_in_repo=local_index_path,  # Same name on the Hub
+            repo_id=self.dataset_name,  # Use your dataset repo
+            token=self.hf_token,
+            repo_type="dataset",  # This is a dataset file
+            revision=self.get_revision_name(),  # Use the current revision
+            commit_message="Add FAISS index",  # Commit message
+        )
+
+        print(f"FAISS index uploaded to Hub at {remote_path}")
+
+        # Remove the local file. It's now stored on the Hub.
+        os.remove(local_index_path)
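A note on the index settings used in `setup_dataset`: the `HNSW,RFlat` string factory builds an HNSW graph with an exact (flat) refinement step for re-ranking, and `METRIC_INNER_PRODUCT` is typically paired with normalised embeddings so that inner product behaves like cosine similarity. A minimal query sketch against the loaded index via the `datasets` FAISS integration (illustrative only, not part of the commit; assumes a `DatasetManager` instance `manager` built as in the sketch above and a made-up query string):

```python
# Query sketch against the "embedding" FAISS index; the query text is hypothetical,
# and whether embeddings should be normalised depends on how the model was trained.
import numpy as np

query = "graph neural networks for molecular property prediction"
query_embedding = manager.embedding_model.encode(query, normalize_embeddings=True)

# get_nearest_examples returns (scores, examples) for the named index.
scores, examples = manager.dataset["train"].get_nearest_examples(
    "embedding", np.asarray(query_embedding, dtype=np.float32), k=5
)
for score, title in zip(scores, examples["title"]):
    print(f"{score:.3f}  {title}")
```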
update_dataset.py ADDED
@@ -0,0 +1,21 @@
+import os
+from dataset_utils import DatasetManager
+from app import init_embedding_model
+
+# Dataset details
+dataset_name = "nomadicsynth/arxiv-dataset-abstract-embeddings"
+HF_TOKEN = os.getenv("HF_TOKEN")
+
+if __name__ == "__main__":
+    # Initialize the embedding model
+    embedding_model = init_embedding_model(
+        model_name_or_path="nomadicsynth/research-compass-arxiv-abstracts-embedding-model",
+        model_revision="2025-01-28_23-06-17-1epochs-12batch-32eval-512embed-final",
+        hf_token=HF_TOKEN,
+    )
+
+    # Initialize DatasetManager with the embedding model
+    dataset_manager = DatasetManager(dataset_name=dataset_name, hf_token=HF_TOKEN, embedding_model=embedding_model)
+
+    # Update the dataset with new papers
+    dataset_manager.update_dataset_with_new_papers()
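Presumably `update_dataset.py` is meant to be run as a one-shot job (for example on a schedule) with `HF_TOKEN` exported in the environment: each run picks up papers published since the dataset's most recent `update_date`, embeds their abstracts, rebuilds the FAISS index, and pushes both the index and a new `vYYYY-MM-DD` dataset revision to the Hub.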