nomadicsynth committed
Commit c241b7f · 1 Parent(s): aa83efc

Refactor dataset management and improve dataset update functionality

Files changed (2):
  1. app.py +23 -17
  2. dataset_utils.py +42 -38
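
With this refactor, app.py no longer carries the dataset constants; DatasetManager now ships sensible defaults and only insists on an embedding model. A minimal usage sketch of the new constructor, based on the signature introduced in this commit (the actual call site in app.py lies outside the hunks below):

from sentence_transformers import SentenceTransformer

from dataset_utils import DatasetManager

# The embedding model is mandatory: __init__ raises ValueError without one.
embedding_model = SentenceTransformer(
    "nomadicsynth/research-compass-arxiv-abstracts-embedding-model"
)

# dataset_name and hf_token now default inside DatasetManager
# (hf_token falls back to the HF_TOKEN environment variable).
dataset = DatasetManager(embedding_model=embedding_model)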
app.py CHANGED

@@ -1,19 +1,15 @@
 import json
 import os
 
-import faiss
 import gradio as gr
 import pandas as pd
 import spaces
 import torch
-from datasets import load_dataset
-from huggingface_hub import InferenceClient, hf_hub_download
-from huggingface_hub import login as hf_hub_login
-from huggingface_hub import upload_file
+from huggingface_hub import InferenceClient
 from sentence_transformers import SentenceTransformer
 
 from arxiv_stuff import ARXIV_CATEGORIES_FLAT
-from dataset_utils import DatasetManager
+from dataset_utils import DatasetManager, dataset_name
 
 # Get HF_TOKEN from environment variables
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -24,11 +20,6 @@ if persistent_storage:
     # Use persistent storage
     print("Using persistent storage")
 
-# Dataset details
-dataset_name = "nomadicsynth/arxiv-dataset-abstract-embeddings"
-dataset_revision = "v1.0.0"
-local_index_path = "arxiv_faiss_index.faiss"
-
 # Embedding model details
 embedding_model_name = "nomadicsynth/research-compass-arxiv-abstracts-embedding-model"
 embedding_model_revision = "2025-01-28_23-06-17-1epochs-12batch-32eval-512embed-final"
@@ -57,7 +48,18 @@ embedding_model = None
 reasoning_model = None
 
 
-def init_embedding_model(model_name_or_path: str, model_revision: str = None, hf_token: str = None) -> SentenceTransformer:
+def init_embedding_model(
+    model_name_or_path: str, model_revision: str = None, hf_token: str = None
+) -> SentenceTransformer:
+    """
+    Initialize the embedding model with the specified model name or path and revision.
+    Args:
+        model_name_or_path (str): The name or path of the model.
+        model_revision (str): The revision of the model.
+        hf_token (str): The Hugging Face token for authentication.
+    Returns:
+        SentenceTransformer: The initialized embedding model.
+    """
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     embedding_model = SentenceTransformer(
         model_name_or_path,
@@ -71,6 +73,13 @@ def init_embedding_model(model_name_or_path: str, model_revision: str = None, hf
 
 @spaces.GPU
 def embed_text(text: str | list[str]) -> torch.Tensor:
+    """
+    Generate embeddings for the given text using the embedding model.
+    Args:
+        text (str | list[str]): The text or list of texts to embed.
+    Returns:
+        torch.Tensor: The generated embeddings.
+    """
     global embedding_model
 
     # Strip any leading/trailing whitespace
@@ -393,12 +402,9 @@ def find_synergistic_papers(abstract: str, limit=25) -> list[dict]:
 
 def format_search_results_json(abstract: str) -> str:
     """Format search results as JSON for display"""
-    # Find papers synergistic with the given abstract
-    papers = find_synergistic_papers(abstract)
-
-    # Convert to JSON for display
+    papers = find_synergistic_papers(abstract, limit=10)
     json_output = json.dumps(papers, indent=2)
-    print(f"JSON output: {json_output}")
+
     return json_output
 
 
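The docstrings added above make the embedding path explicit: init_embedding_model builds a SentenceTransformer on CUDA when available, and embed_text wraps it behind @spaces.GPU. find_synergistic_papers itself is unchanged and not shown in these hunks; a hedged sketch of how its lookup step could use the FAISS index that dataset_utils.py builds on the "embedding" column (dataset_manager and k are illustrative names, not part of this commit):

# Illustrative only - the real find_synergistic_papers in app.py is not part of this diff.
query_embedding = embed_text(abstract).cpu().numpy()

# datasets' FAISS integration: query the index created by add_faiss_index(column="embedding")
scores, examples = dataset_manager.dataset["train"].get_nearest_examples(
    "embedding", query_embedding, k=10
)

# examples is a dict of columns; reshape it into one dict per paper
papers = [dict(zip(examples.keys(), row)) for row in zip(*examples.values())]
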
dataset_utils.py CHANGED
@@ -1,17 +1,18 @@
-from datasets import load_dataset
-from huggingface_hub import HfApi, hf_hub_download
-import faiss
-import os
 import datetime
+import os
 import time
 
-# from embedding_model import EmbeddingModel
-# from app import EmbeddingModel
-from arxiv_stuff import retrieve_arxiv_papers, ARXIV_CATEGORIES_FLAT
+import faiss
+from datasets import load_dataset
+from huggingface_hub import HfApi, hf_hub_download, upload_file
 from sentence_transformers import SentenceTransformer
 
+from arxiv_stuff import ARXIV_CATEGORIES_FLAT, retrieve_arxiv_papers
+
 # Dataset details
-dataset_name = "nomadicsynth/arxiv-dataset-abstract-embeddings"
+default_dataset_revision = "v1.0.0"
+local_index_path = "arxiv_faiss_index.faiss"
+
 HF_TOKEN = os.getenv("HF_TOKEN")
 
 
@@ -19,8 +20,8 @@ class DatasetManager:
 
     def __init__(
         self,
-        dataset_name: str,
-        embedding_model: SentenceTransformer,
+        dataset_name: str = "nomadicsynth/arxiv-dataset-abstract-embeddings",
+        embedding_model: SentenceTransformer = None,
         hf_token: str = None,
     ):
         """
@@ -33,16 +34,23 @@ class DatasetManager:
         self.dataset_name = dataset_name
         self.hf_token = hf_token
         self.embedding_model = embedding_model
-        self.dataset = None
+        self.revision = self.get_latest_revision()
+
+        if self.hf_token is None:
+            self.hf_token = HF_TOKEN
+        if self.embedding_model is None:
+            raise ValueError("Embedding model must be provided.")
 
+        self.dataset = None
         self.setup_dataset()
 
-    def get_revision_name(self):
+    def generate_revision_name(self):
         """Generate a timestamp-based revision name."""
         return datetime.datetime.now().strftime("v%Y-%m-%d")
 
     def get_latest_revision(self):
         """Return the latest timestamp-based revision."""
+        global default_dataset_revision
         api = HfApi()
         print(f"Fetching revisions for dataset: {self.dataset_name}")
 
@@ -58,8 +66,8 @@ class DatasetManager:
         ]
 
         if not timestamp_tags:
-            print("No valid timestamp-based revisions found. Using `v1.0.0` as default.")
-            return "v1.0.0"
+            print(f"No valid timestamp-based revisions found. Using `{default_dataset_revision}` as default.")
+            return default_dataset_revision
         print(f"Valid timestamp-based revisions: {timestamp_tags}")
 
         # Sort and return the most recent tag
@@ -71,22 +79,20 @@ class DatasetManager:
         """Load dataset with FAISS index."""
         print("Loading dataset from Hugging Face...")
 
-        # Fetch the latest revision dynamically
-        latest_revision = self.get_latest_revision()
-
         # Load dataset
         dataset = load_dataset(
-            dataset_name,
-            revision=latest_revision,
+            self.dataset_name,
+            revision=self.revision,
+            token=self.hf_token,
         )
 
         # Try to load the index from the Hub
        try:
             print("Downloading pre-built FAISS index...")
             index_path = hf_hub_download(
-                repo_id=dataset_name,
-                filename="arxiv_faiss_index.faiss",
-                revision=latest_revision,
+                repo_id=self.dataset_name,
+                filename=local_index_path,
+                revision=self.revision,
                 token=self.hf_token,
                 repo_type="dataset",
             )
@@ -173,6 +179,15 @@ class DatasetManager:
             }
         )
 
+        # Save the updated dataset to the Hub with a new revision
+        new_revision = self.generate_revision_name()
+        self.dataset.push_to_hub(
+            repo_id=self.dataset_name,
+            token=self.hf_token,
+            commit_message=f"Update dataset with new papers ({new_revision})",
+            revision=new_revision,
+        )
+
         # Update the FAISS index
         self.dataset["train"].add_faiss_index(
             column="embedding",
@@ -181,38 +196,27 @@ class DatasetManager:
         )
 
         # Save the FAISS index to the Hub
-        self.save_faiss_index_to_hub()
-
-        # Save the updated dataset to the Hub with a new revision
-        new_revision = self.get_revision_name()
-        self.dataset.push_to_hub(
-            repo_id=self.dataset_name,
-            token=self.hf_token,
-            commit_message=f"Update dataset with new papers ({new_revision})",
-            revision=new_revision,
-        )
+        self.save_faiss_index_to_hub(new_revision)
 
         print(f"Dataset updated and saved to the Hub with revision {new_revision}.")
 
-    def save_faiss_index_to_hub(self):
+    def save_faiss_index_to_hub(self, revision: str):
         """Save the FAISS index to the Hub for easy access"""
-        local_index_path = "arxiv_faiss_index.faiss"
+        global local_index_path
 
         # 1. Save the index to a local file
         self.dataset["train"].save_faiss_index("embedding", local_index_path)
         print(f"FAISS index saved locally to {local_index_path}")
 
         # 2. Upload the index file to the Hub
-        from huggingface_hub import upload_file
-
         remote_path = upload_file(
             path_or_fileobj=local_index_path,
             path_in_repo=local_index_path,  # Same name on the Hub
             repo_id=self.dataset_name,  # Use your dataset repo
             token=self.hf_token,
             repo_type="dataset",  # This is a dataset file
-            revision=self.get_revision_name(),  # Use the current revision
-            commit_message="Add FAISS index",  # Commit message
+            revision=revision,  # Use the new revision
+            commit_message=f"Add FAISS index for dataset revision {revision}",
         )
 
         print(f"FAISS index uploaded to Hub at {remote_path}")