bluuebunny committed on
Commit d3c1ddf · verified · 1 parent(s): 2dd6dc6

Create update_embeddings.py

Files changed (1)
  1. update_embeddings.py +213 -0
update_embeddings.py ADDED
@@ -0,0 +1,213 @@
## Download the arXiv metadata from Kaggle
## https://www.kaggle.com/datasets/Cornell-University/arxiv

## Requires the Kaggle CLI (the 'kaggle' package) to be installed
## The CLI is run through subprocess rather than the Kaggle Python API,
## since it allows anonymous downloads without signing in
import subprocess
from datasets import load_dataset  # To load the dataset without exhausting RAM
from multiprocessing import cpu_count  # To get the number of cores
from sentence_transformers import SentenceTransformer  # For embedding the text locally
import torch  # For GPU support
import pandas as pd  # Data manipulation
from huggingface_hub import snapshot_download, HfApi  # Download previous embeddings; transact with huggingface.co
import json  # To build the Milvus-compatible $meta column
import os  # Folder and file creation, environment variables
from tqdm import tqdm  # Progress bar
from mixedbread_ai.client import MixedbreadAI  # For embedding the text via the mxbai API
import numpy as np  # For array manipulation

tqdm.pandas()  # Progress bar for pandas .progress_apply()

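# The imports above correspond to these packages (names are an assumption
# inferred from the import statements; versions are not pinned by this script):
#   pip install datasets sentence-transformers torch pandas huggingface-hub tqdm mixedbread-ai numpy kaggle
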
################################################################################
# Configuration

# Year to update embeddings for
year = '24'

# Flag to force download and conversion even if files already exist
FORCE = True

# Flag to embed the data locally; otherwise the mxbai API is used
LOCAL = False

# Flag to upload the data to the Hugging Face Hub
UPLOAD = True

# Model to use for embedding
model_name = "mixedbread-ai/mxbai-embed-large-v1"

# Number of cores to use for multiprocessing
num_cores = cpu_count() - 1

# Setup transaction details
repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus"
repo_type = "dataset"

# Subfolder in the dataset repo where the file is stored
folder_in_repo = "data"
allow_patterns = f"{folder_in_repo}/{year}.parquet"

# Where to store the local copy of the dataset
local_dir = repo_id

# Create embed folder
embed_folder = f"{year}-diff-embed"
os.makedirs(embed_folder, exist_ok=True)

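# With the settings above, the script pulls data/24.parquet from the dataset
# repo and writes the merged result to 24-diff-embed/24.parquet (see below)
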
################################################################################
# Download the dataset

# Dataset name
dataset_path = 'Cornell-University/arxiv'

# Download folder
download_folder = 'data'

# Data file path
download_file = f'{download_folder}/arxiv-metadata-oai-snapshot.json'

## Download the dataset if it doesn't exist
if not os.path.exists(download_file) or FORCE:

    print(f'Downloading {download_file}')

    # check=True raises CalledProcessError if the Kaggle CLI fails,
    # instead of continuing with a missing file
    subprocess.run(['kaggle', 'datasets', 'download', '--dataset', dataset_path, '--path', download_folder, '--unzip'], check=True)

    print(f'Downloaded {download_file}')

else:

    print(f'{download_file} already exists')
    print('Skipping download')

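# Fail fast if the snapshot still is not where we expect it (a defensive check
# added here; not part of the original flow):
assert os.path.exists(download_file), f"Expected metadata file at {download_file}"
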
################################################################################
# Filter by year and convert to parquet

# https://huggingface.co/docs/datasets/en/about_arrow#memory-mapping
# Load metadata
print("Loading JSON metadata")
dataset = load_dataset("json", data_files=download_file, num_proc=num_cores)

########################################
# Function to add year to metadata
def add_year(example):

    # Old-scheme IDs ('archive/YYMMnnn') carry the year right after the slash;
    # new-scheme IDs ('YYMM.nnnnn') carry it in the first two characters
    example['year'] = example['id'].split('/')[1][:2] if '/' in example['id'] else example['id'][:2]

    return example
########################################

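# Illustrative ID values (made-up numbers) showing the two arXiv schemes:
#   add_year({'id': 'hep-th/9901001'})['year'] -> '99'
#   add_year({'id': '2401.12345'})['year']     -> '24'
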
# Add year to metadata
print("Adding year to metadata")
dataset = dataset.map(add_year, num_proc=num_cores)

# Filter by year
print(f"Filtering metadata by year: {year}")
dataset = dataset.filter(lambda example: example['year'] == year, num_proc=num_cores)

# Convert to pandas
print(f"Loading metadata for year: {year} into pandas")
arxiv_metadata_split = dataset['train'].to_pandas()

################################################################################
# Load Model

if LOCAL:
    # Make the app device agnostic
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Load a pretrained Sentence Transformer model and move it to the appropriate device
    print(f"Loading model {model_name} to device: {device}")
    model = SentenceTransformer(model_name)
    model = model.to(device)
else:
    print("Setting up mxbai client")
    # Setup mxbai
    mxbai_api_key = os.getenv("MXBAI_API_KEY")
    mxbai = MixedbreadAI(api_key=mxbai_api_key)

########################################
# Function that does the embedding
def embed(input_text):

    if LOCAL:

        # Calculate embeddings by calling model.encode(), specifying the device
        embedding = model.encode(input_text, device=device)

    else:

        # Calculate embeddings by calling mxbai.embeddings()
        result = mxbai.embeddings(
            model=model_name,  # same model as the local path
            input=input_text,
            normalized=True,
            encoding_format='float',
            truncation_strategy='end'
        )

        embedding = np.array(result.data[0].embedding)

    return embedding
########################################

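# A quick sanity check, left commented out to avoid spending an API call
# (the sample text is made up; mxbai-embed-large-v1 produces 1024-dimensional vectors):
# vec = embed("We study the convergence of stochastic gradient descent.")
# assert vec.shape == (1024,)
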
################################################################################
# Gather preexisting embeddings

# Create local directory
os.makedirs(local_dir, exist_ok=True)

# Download the repo
snapshot_download(repo_id=repo_id, repo_type=repo_type, local_dir=local_dir, allow_patterns=allow_patterns)

# Gather previous embed file
previous_embed = f'{local_dir}/{folder_in_repo}/{year}.parquet'

# Load previous_embed
print(f"Loading previously embedded file: {previous_embed}")
previous_embeddings = pd.read_parquet(previous_embed)

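# The previous parquet is expected to have the columns ['id', 'vector', '$meta'],
# matching what this script writes below (an assumption about the repo contents):
# print(previous_embeddings.columns.tolist())
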
########################################
# Embed the new abstracts

# Find papers that are not in the previous embeddings
# .copy() gives an independent frame, avoiding pandas' SettingWithCopyWarning
# on the column assignments below
new_papers = arxiv_metadata_split[~arxiv_metadata_split['id'].isin(previous_embeddings['id'])].copy()

# Create a column for embeddings
print(f"Creating new embeddings for: {len(new_papers)} entries")
new_papers["vector"] = new_papers["abstract"].progress_apply(embed)

# Rename columns
new_papers.rename(columns={'title': 'Title', 'authors': 'Authors', 'abstract': 'Abstract'}, inplace=True)

# Add URL column
new_papers['URL'] = 'https://arxiv.org/abs/' + new_papers['id']

# Create Milvus-compatible parquet file; $meta is a JSON string of the metadata
new_papers['$meta'] = new_papers[['Title', 'Authors', 'Abstract', 'URL']].apply(lambda row: json.dumps(row.to_dict()), axis=1)

# Select id, vector and $meta to retain
selected_columns = ['id', 'vector', '$meta']

# Merge previous embeddings and new embeddings
new_embeddings = pd.concat([previous_embeddings, new_papers[selected_columns]])

# Save the embedded file
embed_filename = f'{embed_folder}/{year}.parquet'
print(f"Saving newly embedded dataframe to: {embed_filename}")
# Keeping index=False to avoid saving the index as a separate column in the parquet file
# This keeps Milvus from throwing an error when importing the parquet file
new_embeddings.to_parquet(embed_filename, index=False)

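# Illustrative $meta value for one row (field contents are made up):
# {"Title": "A Sample Paper", "Authors": "A. Author", "Abstract": "...",
#  "URL": "https://arxiv.org/abs/2401.12345"}
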
################################################################################

# Upload the new embeddings to the repo
if UPLOAD:

    print(f"Uploading new embeddings to: {repo_id}")
    access_token = os.getenv("HF_API_KEY")
    api = HfApi(token=access_token)

    # Upload all files within the folder to the specified repository
    api.upload_folder(repo_id=repo_id, folder_path=embed_folder, path_in_repo=folder_in_repo, repo_type=repo_type)
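
# To run (assumptions drawn from the code above: the Kaggle CLI is on PATH,
# MXBAI_API_KEY is set when LOCAL is False, and HF_API_KEY is set when UPLOAD is True):
#   python update_embeddings.py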