Update update_embeddings.py
update_embeddings.py +51 -17
CHANGED
@@ -19,12 +19,17 @@ from mixedbread_ai.client import MixedbreadAI # For embedding the text
 import numpy as np # For array manipulation
 from huggingface_hub import HfApi # To transact with huggingface.co
 import sys # To quit the script
+import datetime # get current year
+from time import time # To time the script
+
+# Start timer
+start = time()
 
 ################################################################################
 # Configuration
 
-# Year to update embeddings for
-year =
+# Year to update embeddings for, get and set the current year
+year = str(datetime.datetime.now().year)[2:]
 
 # Flag to force download and conversion even if files already exist
 FORCE = True
@@ -71,7 +76,8 @@ download_file = f'{download_folder}/arxiv-metadata-oai-snapshot.json'
 ## Download the dataset if it doesn't exist
 if not os.path.exists(download_file) or FORCE:
 
-    print(f'Downloading {download_file}')
+    print(f'Downloading {download_file}, if it exists it will be overwritten')
+    print('Set FORCE to False to skip download if file already exists')
 
     subprocess.run(['kaggle', 'datasets', 'download', '--dataset', dataset_path, '--path', download_folder, '--unzip'])
 
@@ -79,8 +85,8 @@ if not os.path.exists(download_file) or FORCE:
 
 else:
 
-    print(f'{download_file} already exists')
-    print('
+    print(f'{download_file} already exists, skipping download')
+    print('Set FORCE = True to force download')
 
 ################################################################################
 # Filter by year and convert to parquet
@@ -88,7 +94,7 @@ else:
 # https://huggingface.co/docs/datasets/en/about_arrow#memory-mapping
 # Load metadata
 print(f"Loading json metadata")
-
+arxiv_metadata_all = load_dataset("json", data_files= str(f"{download_file}"))
 
 ########################################
 # Function to add year to metadata
@@ -101,20 +107,24 @@ def add_year(example):
 
 # Add year to metadata
 print(f"Adding year to metadata")
-
+arxiv_metadata_all = arxiv_metadata_all.map(add_year, num_proc=num_cores)
 
 # Filter by year
 print(f"Filtering metadata by year: {year}")
-
+arxiv_metadata_all = arxiv_metadata_all.filter(lambda example: example['year'] == year, num_proc=num_cores)
 
 # Convert to pandas
 print(f"Loading metadata for year: {year} into pandas")
-arxiv_metadata_split =
+arxiv_metadata_split = arxiv_metadata_all['train'].to_pandas()
 
 ################################################################################
 # Load Model
 
 if LOCAL:
+
+    print(f"Setting up local embedding model")
+    print("To use mxbai API, set LOCAL = False")
+
     # Make the app device agnostic
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
 
@@ -123,7 +133,8 @@ if LOCAL:
     model = SentenceTransformer(model_name)
     model = model.to(device)
 else:
-    print("Setting up mxbai client")
+    print("Setting up mxbai API client")
+    print("To use local resources, set LOCAL = True")
     # Setup mxbai
    mxbai_api_key = os.getenv("MXBAI_API_KEY")
     mxbai = MixedbreadAI(api_key=mxbai_api_key)
@@ -162,12 +173,20 @@ os.makedirs(local_dir, exist_ok=True)
 # Download the repo
 snapshot_download(repo_id=repo_id, repo_type=repo_type, local_dir=local_dir, allow_patterns=allow_patterns)
 
-# Gather previous embed file
-previous_embed = f'{local_dir}/{folder_in_repo}/{year}.parquet'
+try:
+
+    # Gather previous embed file
+    previous_embed = f'{local_dir}/{folder_in_repo}/{year}.parquet'
 
-# Load previous_embed
-print(f"Loading previously embedded file: {previous_embed}")
-previous_embeddings = pd.read_parquet(previous_embed)
+    # Load previous_embed
+    print(f"Loading previously embedded file: {previous_embed}")
+    previous_embeddings = pd.read_parquet(previous_embed)
+
+except Exception as e:
+    print(f"Errored out with: {e}")
+    print(f"No previous embeddings found for year: {year}")
+    print("Creating new embeddings for all papers")
+    previous_embeddings = pd.DataFrame(columns=['id', 'vector', '$meta'])
 
 ########################################
 # Embed the new abstracts
@@ -214,10 +233,25 @@ new_embeddings.to_parquet(embed_filename, index=False)
 
 # Upload the new embeddings to the repo
 if UPLOAD:
-
+
     print(f"Uploading new embeddings to: {repo_id}")
     access_token = os.getenv("HF_API_KEY")
     api = HfApi(token=access_token)
 
     # Upload all files within the folder to the specified repository
-    api.upload_folder(repo_id=repo_id, folder_path=embed_folder, path_in_repo=folder_in_repo, repo_type="dataset")
+    api.upload_folder(repo_id=repo_id, folder_path=embed_folder, path_in_repo=folder_in_repo, repo_type="dataset")
+
+    print(f"Upload complete for year: {year}")
+
+else:
+    print("Not uploading new embeddings to the repo")
+    print("To upload new embeddings, set UPLOAD to True")
+################################################################################
+
+# Track time
+end = time()
+
+# Calculate and show time taken
+print(f"Time taken: {end - start} seconds")
+
+print("Done!")
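Notes on the change follow. First, the year derivation: str(datetime.datetime.now().year)[2:] keeps the last two digits of the current year, matching the two-digit format the filter compares against. A minimal check, standard library only:

import datetime

year = str(datetime.datetime.now().year)[2:]  # "2024" -> "24"
# Equivalent, and explicit about the intent:
assert year == f"{datetime.datetime.now().year % 100:02d}"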
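The metadata pipeline is load_dataset -> map(add_year) -> filter -> to_pandas. The body of add_year and the value of num_cores sit outside these hunks, so the sketch below is only an assumption about how such a mapper plugs in; deriving the two-digit year from update_date is hypothetical, not the script's confirmed rule:

from datasets import load_dataset

def add_year(example):
    # Assumed logic: two-digit year from the record's update_date,
    # e.g. "2024-05-01" -> "24"; the real add_year may use another field
    example['year'] = example['update_date'][2:4]
    return example

arxiv_metadata_all = load_dataset("json", data_files="arxiv-metadata-oai-snapshot.json")
arxiv_metadata_all = arxiv_metadata_all.map(add_year, num_proc=4)
arxiv_metadata_all = arxiv_metadata_all.filter(lambda example: example['year'] == "24", num_proc=4)
arxiv_metadata_split = arxiv_metadata_all['train'].to_pandas()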
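The LOCAL flag picks between embedding on local hardware with sentence-transformers and calling the hosted mxbai API. A sketch of the local path; model_name is assigned outside this diff, so mixedbread-ai/mxbai-embed-large-v1 is an assumption inferred from the MixedbreadAI client in the other branch:

import torch
from sentence_transformers import SentenceTransformer

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Assumed model; the script's actual model_name is defined outside these hunks
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1").to(device)

vectors = model.encode(["example abstract text"])
print(vectors.shape)  # (1, 1024) for mxbai-embed-large-v1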
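The try/except gives downstream code a previous_embeddings frame in both cases: the stored parquet when one exists, or an empty frame with the same columns on a first run. Loading it before embedding suggests the "Embed the new abstracts" section (not shown in this diff) skips papers that already have a vector; a hypothetical version of that dedupe, assuming id is the join key as the fallback columns ['id', 'vector', '$meta'] imply:

# previous_embeddings: loaded parquet, or the empty fallback frame
already_embedded = set(previous_embeddings['id'])

# Keep only this year's papers with no stored vector yet (hypothetical)
new_papers = arxiv_metadata_split[~arxiv_metadata_split['id'].isin(already_embedded)]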