bluuebunny committed on
Commit
0ee2db9
·
verified ·
1 Parent(s): ebae87d

Update update_embeddings.py

Browse files
Files changed (1) hide show
  1. update_embeddings.py +51 -17
update_embeddings.py CHANGED
@@ -19,12 +19,17 @@ from mixedbread_ai.client import MixedbreadAI # For embedding the text
19
  import numpy as np # For array manipulation
20
  from huggingface_hub import HfApi # To transact with huggingface.co
21
  import sys # To quit the script
 
 
 
 
 
22
 
23
  ################################################################################
24
  # Configuration
25
 
26
- # Year to update embeddings for
27
- year = '24'
28
 
29
  # Flag to force download and conversion even if files already exist
30
  FORCE = True
@@ -71,7 +76,8 @@ download_file = f'{download_folder}/arxiv-metadata-oai-snapshot.json'
71
  ## Download the dataset if it doesn't exist
72
  if not os.path.exists(download_file) or FORCE:
73
 
74
- print(f'Downloading {download_file}')
 
75
 
76
  subprocess.run(['kaggle', 'datasets', 'download', '--dataset', dataset_path, '--path', download_folder, '--unzip'])
77
 
@@ -79,8 +85,8 @@ if not os.path.exists(download_file) or FORCE:
79
 
80
  else:
81
 
82
- print(f'{download_file} already exists')
83
- print('Skipping download')
84
 
85
  ################################################################################
86
  # Filter by year and convert to parquet
@@ -88,7 +94,7 @@ else:
88
  # https://huggingface.co/docs/datasets/en/about_arrow#memory-mapping
89
  # Load metadata
90
  print(f"Loading json metadata")
91
- dataset = load_dataset("json", data_files= str(f"{download_file}"), num_proc=num_cores)
92
 
93
  ########################################
94
  # Function to add year to metadata
@@ -101,20 +107,24 @@ def add_year(example):
101
 
102
  # Add year to metadata
103
  print(f"Adding year to metadata")
104
- dataset = dataset.map(add_year, num_proc=num_cores)
105
 
106
  # Filter by year
107
  print(f"Filtering metadata by year: {year}")
108
- dataset = dataset.filter(lambda example: example['year'] == year, num_proc=num_cores)
109
 
110
  # Convert to pandas
111
  print(f"Loading metadata for year: {year} into pandas")
112
- arxiv_metadata_split = dataset['train'].to_pandas()
113
 
114
  ################################################################################
115
  # Load Model
116
 
117
  if LOCAL:
 
 
 
 
118
  # Make the app device agnostic
119
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
120
 
@@ -123,7 +133,8 @@ if LOCAL:
123
  model = SentenceTransformer(model_name)
124
  model = model.to(device)
125
  else:
126
- print("Setting up mxbai client")
 
127
  # Setup mxbai
128
  mxbai_api_key = os.getenv("MXBAI_API_KEY")
129
  mxbai = MixedbreadAI(api_key=mxbai_api_key)
@@ -162,12 +173,20 @@ os.makedirs(local_dir, exist_ok=True)
162
  # Download the repo
163
  snapshot_download(repo_id=repo_id, repo_type=repo_type, local_dir=local_dir, allow_patterns=allow_patterns)
164
 
165
- # Gather previous embed file
166
- previous_embed = f'{local_dir}/{folder_in_repo}/{year}.parquet'
 
 
167
 
168
- # Load previous_embed
169
- print(f"Loading previously embedded file: {previous_embed}")
170
- previous_embeddings = pd.read_parquet(previous_embed)
 
 
 
 
 
 
171
 
172
  ########################################
173
  # Embed the new abstracts
@@ -214,10 +233,25 @@ new_embeddings.to_parquet(embed_filename, index=False)
214
 
215
  # Upload the new embeddings to the repo
216
  if UPLOAD:
217
-
218
  print(f"Uploading new embeddings to: {repo_id}")
219
  access_token = os.getenv("HF_API_KEY")
220
  api = HfApi(token=access_token)
221
 
222
  # Upload all files within the folder to the specified repository
223
- api.upload_folder(repo_id=repo_id, folder_path=embed_folder, path_in_repo=folder_in_repo, repo_type="dataset")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  import numpy as np # For array manipulation
20
  from huggingface_hub import HfApi # To transact with huggingface.co
21
  import sys # To quit the script
22
+ import datetime # get current year
23
+ from time import time # To time the script
24
+
25
+ # Start timer
26
+ start = time()
27
 
28
  ################################################################################
29
  # Configuration
30
 
31
+ # Year to update embeddings for, get and set the current year
32
+ year = str(datetime.datetime.now().year)[2:]
33
 
34
  # Flag to force download and conversion even if files already exist
35
  FORCE = True
 
76
  ## Download the dataset if it doesn't exist
77
  if not os.path.exists(download_file) or FORCE:
78
 
79
+ print(f'Downloading {download_file}, if it exists it will be overwritten')
80
+ print('Set FORCE to False to skip download if file already exists')
81
 
82
  subprocess.run(['kaggle', 'datasets', 'download', '--dataset', dataset_path, '--path', download_folder, '--unzip'])
83
 
 
85
 
86
  else:
87
 
88
+ print(f'{download_file} already exists, skipping download')
89
+ print('Set FORCE = True to force download')
90
 
91
  ################################################################################
92
  # Filter by year and convert to parquet
 
94
  # https://huggingface.co/docs/datasets/en/about_arrow#memory-mapping
95
  # Load metadata
96
  print(f"Loading json metadata")
97
+ arxiv_metadata_all = load_dataset("json", data_files= str(f"{download_file}"))
98
 
99
  ########################################
100
  # Function to add year to metadata
 
107
 
108
  # Add year to metadata
109
  print(f"Adding year to metadata")
110
+ arxiv_metadata_all = arxiv_metadata_all.map(add_year, num_proc=num_cores)
111
 
112
  # Filter by year
113
  print(f"Filtering metadata by year: {year}")
114
+ arxiv_metadata_all = arxiv_metadata_all.filter(lambda example: example['year'] == year, num_proc=num_cores)
115
 
116
  # Convert to pandas
117
  print(f"Loading metadata for year: {year} into pandas")
118
+ arxiv_metadata_split = arxiv_metadata_all['train'].to_pandas()
119
 
120
  ################################################################################
121
  # Load Model
122
 
123
  if LOCAL:
124
+
125
+ print(f"Setting up local embedding model")
126
+ print("To use mxbai API, set LOCAL = False")
127
+
128
  # Make the app device agnostic
129
  device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
130
 
 
133
  model = SentenceTransformer(model_name)
134
  model = model.to(device)
135
  else:
136
+ print("Setting up mxbai API client")
137
+ print("To use local resources, set LOCAL = True")
138
  # Setup mxbai
139
  mxbai_api_key = os.getenv("MXBAI_API_KEY")
140
  mxbai = MixedbreadAI(api_key=mxbai_api_key)
 
173
  # Download the repo
174
  snapshot_download(repo_id=repo_id, repo_type=repo_type, local_dir=local_dir, allow_patterns=allow_patterns)
175
 
176
+ try:
177
+
178
+ # Gather previous embed file
179
+ previous_embed = f'{local_dir}/{folder_in_repo}/{year}.parquet'
180
 
181
+ # Load previous_embed
182
+ print(f"Loading previously embedded file: {previous_embed}")
183
+ previous_embeddings = pd.read_parquet(previous_embed)
184
+
185
+ except Exception as e:
186
+ print(f"Errored out with: {e}")
187
+ print(f"No previous embeddings found for year: {year}")
188
+ print("Creating new embeddings for all papers")
189
+ previous_embeddings = pd.DataFrame(columns=['id', 'vector', '$meta'])
190
 
191
  ########################################
192
  # Embed the new abstracts
 
233
 
234
  # Upload the new embeddings to the repo
235
  if UPLOAD:
236
+
237
  print(f"Uploading new embeddings to: {repo_id}")
238
  access_token = os.getenv("HF_API_KEY")
239
  api = HfApi(token=access_token)
240
 
241
  # Upload all files within the folder to the specified repository
242
+ api.upload_folder(repo_id=repo_id, folder_path=embed_folder, path_in_repo=folder_in_repo, repo_type="dataset")
243
+
244
+ print(f"Upload complete for year: {year}")
245
+
246
+ else:
247
+ print("Not uploading new embeddings to the repo")
248
+ print("To upload new embeddings, set UPLOAD to True")
249
+ ################################################################################
250
+
251
+ # Track time
252
+ end = time()
253
+
254
+ # Calculate and show time taken
255
+ print(f"Time taken: {end - start} seconds")
256
+
257
+ print("Done!")