Spaces: Running on Zero

Commit · 261056f
1 Parent(s): 4236eec

Add DatasetManager for handling dataset operations and update dataset with new papers

- app.py +13 -79
- dataset_utils.py +221 -0
- update_dataset.py +21 -0
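In short, this commit moves dataset loading and FAISS-index handling out of app.py into a new DatasetManager class in dataset_utils.py, and adds update_dataset.py as a standalone script that refreshes the dataset with newly published arXiv papers. A condensed sketch of the resulting startup flow in app.py's __main__ block, assembled from the diff below (editor's summary, not an extra file in the commit):

    embedding_model = init_embedding_model(embedding_model_name, embedding_model_revision)
    reasoning_model = init_reasoning_model(reasoning_model_id)
    dataset = DatasetManager(
        dataset_name=dataset_name,
        hf_token=HF_TOKEN,
        embedding_model=embedding_model,
    )
    demo = create_interface()
    demo.queue(api_open=False).launch(ssr_mode=False, show_api=True)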
app.py
CHANGED
@@ -13,6 +13,7 @@ from huggingface_hub import upload_file
 from sentence_transformers import SentenceTransformer
 
 from arxiv_stuff import ARXIV_CATEGORIES_FLAT
+from dataset_utils import DatasetManager
 
 # Get HF_TOKEN from environment variables
 HF_TOKEN = os.getenv("HF_TOKEN")
@@ -56,88 +57,17 @@ embedding_model = None
 reasoning_model = None
 
 
-def save_faiss_index_to_hub():
-    """Save the FAISS index to the Hub for easy access"""
-    global dataset, local_index_path
-    # 1. Save the index to a local file
-    dataset["train"].save_faiss_index("embedding", local_index_path)
-    print(f"FAISS index saved locally to {local_index_path}")
-
-    # 2. Upload the index file to the Hub
-    remote_path = upload_file(
-        path_or_fileobj=local_index_path,
-        path_in_repo=local_index_path,  # Same name on the Hub
-        repo_id=dataset_name,  # Use your dataset repo
-        token=HF_TOKEN,
-        repo_type="dataset",  # This is a dataset file
-        revision=dataset_revision,  # Use the same revision as the dataset
-        commit_message="Add FAISS index",  # Commit message
-    )
-
-    print(f"FAISS index uploaded to Hub at {remote_path}")
-
-    # Remove the local file. It's now stored on the Hub.
-    os.remove(local_index_path)
-
-
-def setup_dataset():
-    """Load dataset with FAISS index"""
-    global dataset
-    print("Loading dataset from Hugging Face...")
-
-    # Load dataset
-    dataset = load_dataset(
-        dataset_name,
-        revision=dataset_revision,
-    )
-
-    # Try to load the index from the Hub
-    try:
-        print("Downloading pre-built FAISS index...")
-        index_path = hf_hub_download(
-            repo_id=dataset_name,
-            filename="arxiv_faiss_index.faiss",
-            revision=dataset_revision,
-            token=HF_TOKEN,
-            repo_type="dataset",
-        )
-
-        print("Loading pre-built FAISS index...")
-        dataset["train"].load_faiss_index("embedding", index_path)
-        print("Pre-built FAISS index loaded successfully")
-
-    except Exception as e:
-        print(f"Could not load pre-built index: {e}")
-        print("Building new FAISS index...")
-
-        # Add FAISS index if it doesn't exist
-        if not dataset["train"].features.get("embedding"):
-            print("Dataset doesn't have 'embedding' column, cannot create FAISS index")
-            raise ValueError("Dataset doesn't have 'embedding' column")
-
-        dataset["train"].add_faiss_index(
-            column="embedding",
-            metric_type=faiss.METRIC_INNER_PRODUCT,
-            string_factory="HNSW,RFlat",  # Using reranking
-        )
-
-        # Save the FAISS index to the Hub
-        save_faiss_index_to_hub()
-
-    print(f"Dataset loaded with {len(dataset['train'])} items and FAISS index ready")
-
-
-def init_embedding_model(model_name_or_path: str, model_revision: str = None) -> SentenceTransformer:
-    global embedding_model
-
+def init_embedding_model(model_name_or_path: str, model_revision: str = None, hf_token: str = None) -> SentenceTransformer:
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     embedding_model = SentenceTransformer(
         model_name_or_path,
         revision=model_revision,
-        token=HF_TOKEN,
+        token=hf_token,
         device=device,
     )
 
+    return embedding_model
+
 
 @spaces.GPU
 def embed_text(text: str | list[str]) -> torch.Tensor:
@@ -798,14 +728,18 @@ def create_interface():
 
 
 if __name__ == "__main__":
-    # Load dataset with FAISS index
-    setup_dataset()
-
     # Initialize the embedding model
-    init_embedding_model(embedding_model_name, embedding_model_revision)
+    embedding_model = init_embedding_model(embedding_model_name, embedding_model_revision)
 
     # Initialize the reasoning model
     reasoning_model = init_reasoning_model(reasoning_model_id)
 
+    # Load dataset with FAISS index
+    dataset = DatasetManager(
+        dataset_name=dataset_name,
+        hf_token=HF_TOKEN,
+        embedding_model=embedding_model,
+    )
+
     demo = create_interface()
     demo.queue(api_open=False).launch(ssr_mode=False, show_api=True)
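Note that embed_text itself is unchanged; only its signature appears above as context and its body sits outside these hunks. For orientation, a plausible minimal body is sketched here (an assumption for illustration, not code from this commit), wrapping SentenceTransformer.encode on the module-level model:

    @spaces.GPU
    def embed_text(text: str | list[str]) -> torch.Tensor:
        # Assumed implementation: encode with the module-level embedding model
        return embedding_model.encode(text, convert_to_tensor=True)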
dataset_utils.py
ADDED
@@ -0,0 +1,221 @@
+from datasets import load_dataset
+from huggingface_hub import HfApi, hf_hub_download
+import faiss
+import os
+import datetime
+import time
+
+# from embedding_model import EmbeddingModel
+# from app import EmbeddingModel
+from arxiv_stuff import retrieve_arxiv_papers, ARXIV_CATEGORIES_FLAT
+from sentence_transformers import SentenceTransformer
+
+# Dataset details
+dataset_name = "nomadicsynth/arxiv-dataset-abstract-embeddings"
+HF_TOKEN = os.getenv("HF_TOKEN")
+
+
+class DatasetManager:
+
+    def __init__(
+        self,
+        dataset_name: str,
+        embedding_model: SentenceTransformer,
+        hf_token: str = None,
+    ):
+        """
+        Initialize the DatasetManager with the dataset name, Hugging Face token, and embedding model.
+        Args:
+            dataset_name (str): The name of the dataset on Hugging Face Hub.
+            embedding_model (SentenceTransformer): The embedding model to use for generating embeddings.
+            hf_token (str): The Hugging Face token for authentication.
+        """
+        self.dataset_name = dataset_name
+        self.hf_token = hf_token
+        self.embedding_model = embedding_model
+        self.dataset = None
+
+        self.setup_dataset()
+
+    def get_revision_name(self):
+        """Generate a timestamp-based revision name."""
+        return datetime.datetime.now().strftime("v%Y-%m-%d")
+
+    def get_latest_revision(self):
+        """Return the latest timestamp-based revision."""
+        api = HfApi()
+        print(f"Fetching revisions for dataset: {self.dataset_name}")
+
+        # List all tags in the repository
+        refs = api.list_repo_refs(repo_id=self.dataset_name, repo_type="dataset", token=self.hf_token)
+        tags = refs.tags
+
+        print(f"Found tags: {[tag.name for tag in tags]}")
+
+        # Filter tags with the "vYYYY-MM-DD" format
+        timestamp_tags = [
+            tag.name for tag in tags if tag.name.startswith("v") and len(tag.name) == 11 and tag.name[1:11].isdigit()
+        ]
+
+        if not timestamp_tags:
+            print("No valid timestamp-based revisions found. Using `v1.0.0` as default.")
+            return "v1.0.0"
+        print(f"Valid timestamp-based revisions: {timestamp_tags}")
+
+        # Sort and return the most recent tag
+        latest_revision = sorted(timestamp_tags)[-1]
+        print(f"Latest revision determined: {latest_revision}")
+        return latest_revision
+
+    def setup_dataset(self):
+        """Load dataset with FAISS index."""
+        print("Loading dataset from Hugging Face...")
+
+        # Fetch the latest revision dynamically
+        latest_revision = self.get_latest_revision()
+
+        # Load dataset
+        dataset = load_dataset(
+            dataset_name,
+            revision=latest_revision,
+        )
+
+        # Try to load the index from the Hub
+        try:
+            print("Downloading pre-built FAISS index...")
+            index_path = hf_hub_download(
+                repo_id=dataset_name,
+                filename="arxiv_faiss_index.faiss",
+                revision=latest_revision,
+                token=self.hf_token,
+                repo_type="dataset",
+            )
+
+            print("Loading pre-built FAISS index...")
+            dataset["train"].load_faiss_index("embedding", index_path)
+            print("Pre-built FAISS index loaded successfully")
+
+        except Exception as e:
+            print(f"Could not load pre-built index: {e}")
+            print("Building new FAISS index...")
+
+            # Add FAISS index if it doesn't exist
+            if not dataset["train"].features.get("embedding"):
+                print("Dataset doesn't have 'embedding' column, cannot create FAISS index")
+                raise ValueError("Dataset doesn't have 'embedding' column")
+
+            dataset["train"].add_faiss_index(
+                column="embedding",
+                metric_type=faiss.METRIC_INNER_PRODUCT,
+                string_factory="HNSW,RFlat",  # Using reranking
+            )
+
+        print(f"Dataset loaded with {len(dataset['train'])} items and FAISS index ready")
+
+        self.dataset = dataset
+        return dataset
+
+    def update_dataset_with_new_papers(self):
+        """Fetch new papers from arXiv, ensure no duplicates, and update the dataset and FAISS index."""
+        if self.dataset is None:
+            self.setup_dataset()
+
+        # Get the last update date from the dataset
+        last_update_date = max(
+            [datetime.datetime.strptime(row["update_date"], "%Y-%m-%d") for row in self.dataset["train"]],
+            default=datetime.datetime.now() - datetime.timedelta(days=1),
+        )
+
+        # Initialize variables for iterative querying
+        start = 0
+        max_results_per_query = 100
+        all_new_papers = []
+
+        while True:
+            # Fetch new papers from arXiv since the last update
+            new_papers = retrieve_arxiv_papers(
+                categories=list(ARXIV_CATEGORIES_FLAT.keys()),
+                start_date=last_update_date,
+                end_date=datetime.datetime.now(),
+                start=start,
+                max_results=max_results_per_query,
+            )
+
+            if not new_papers:
+                break
+
+            all_new_papers.extend(new_papers)
+            start += max_results_per_query
+
+            # Respect the rate limit of 1 query every 3 seconds
+            time.sleep(3)
+
+        # Filter out duplicates
+        existing_ids = set(row["id"] for row in self.dataset["train"])
+        unique_papers = [paper for paper in all_new_papers if paper["arxiv_id"] not in existing_ids]
+
+        if not unique_papers:
+            print("No new papers to add.")
+            return
+
+        # Add new papers to the dataset
+        for paper in unique_papers:
+            embedding = self.embedding_model.embed_text(paper["abstract"])
+            self.dataset["train"].add_item(
+                {
+                    "id": paper["arxiv_id"],
+                    "title": paper["title"],
+                    "authors": ", ".join(paper["authors"]),
+                    "categories": ", ".join(paper["categories"]),
+                    "abstract": paper["abstract"],
+                    "update_date": paper["published_date"],
+                    "embedding": embedding,
+                }
+            )
+
+        # Update the FAISS index
+        self.dataset["train"].add_faiss_index(
+            column="embedding",
+            metric_type=faiss.METRIC_INNER_PRODUCT,
+            string_factory="HNSW,RFlat",
+        )
+
+        # Save the FAISS index to the Hub
+        self.save_faiss_index_to_hub()
+
+        # Save the updated dataset to the Hub with a new revision
+        new_revision = self.get_revision_name()
+        self.dataset.push_to_hub(
+            repo_id=self.dataset_name,
+            token=self.hf_token,
+            commit_message=f"Update dataset with new papers ({new_revision})",
+            revision=new_revision,
+        )
+
+        print(f"Dataset updated and saved to the Hub with revision {new_revision}.")
+
+    def save_faiss_index_to_hub(self):
+        """Save the FAISS index to the Hub for easy access"""
+        local_index_path = "arxiv_faiss_index.faiss"
+
+        # 1. Save the index to a local file
+        self.dataset["train"].save_faiss_index("embedding", local_index_path)
+        print(f"FAISS index saved locally to {local_index_path}")
+
+        # 2. Upload the index file to the Hub
+        from huggingface_hub import upload_file
+
+        remote_path = upload_file(
+            path_or_fileobj=local_index_path,
+            path_in_repo=local_index_path,  # Same name on the Hub
+            repo_id=self.dataset_name,  # Use your dataset repo
+            token=self.hf_token,
+            repo_type="dataset",  # This is a dataset file
+            revision=self.get_revision_name(),  # Use the current revision
+            commit_message="Add FAISS index",  # Commit message
+        )
+
+        print(f"FAISS index uploaded to Hub at {remote_path}")
+
+        # Remove the local file. It's now stored on the Hub.
+        os.remove(local_index_path)
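DatasetManager only loads or rebuilds the index; searching it is left to the application code. As a rough illustration of how the loaded index can be queried (a minimal sketch against the datasets FAISS API; find_similar_papers is a hypothetical helper, not part of this commit):

    def find_similar_papers(manager: DatasetManager, query: str, k: int = 5):
        # Encode the query with the same model that produced the stored embeddings
        query_embedding = manager.embedding_model.encode(query)
        # get_nearest_examples returns (scores, examples) for the named index
        scores, examples = manager.dataset["train"].get_nearest_examples("embedding", query_embedding, k=k)
        return list(zip(scores, examples["title"]))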
update_dataset.py
ADDED
@@ -0,0 +1,21 @@
+import os
+from dataset_utils import DatasetManager
+from app import init_embedding_model
+
+# Dataset details
+dataset_name = "nomadicsynth/arxiv-dataset-abstract-embeddings"
+HF_TOKEN = os.getenv("HF_TOKEN")
+
+if __name__ == "__main__":
+    # Initialize the embedding model
+    embedding_model = init_embedding_model(
+        model_name_or_path="nomadicsynth/research-compass-arxiv-abstracts-embedding-model",
+        model_revision="2025-01-28_23-06-17-1epochs-12batch-32eval-512embed-final",
+        hf_token=HF_TOKEN,
+    )
+
+    # Initialize DatasetManager with the embedding model
+    dataset_manager = DatasetManager(dataset_name=dataset_name, hf_token=HF_TOKEN, embedding_model=embedding_model)
+
+    # Update the dataset with new papers
+    dataset_manager.update_dataset_with_new_papers()
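For reference, the revision tags pushed by update_dataset_with_new_papers come from get_revision_name's "v%Y-%m-%d" format, and get_latest_revision then selects the lexicographically largest matching tag (sorted(...)[-1]), which for zero-padded dates is also the most recent. A quick illustration of the tag format (example only, not part of the commit):

    import datetime

    # Tag produced for a run on 1 April 2025
    tag = datetime.datetime(2025, 4, 1).strftime("v%Y-%m-%d")
    print(tag)  # v2025-04-01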