Create update_embeddings.py
update_embeddings.py  +213 -0  (ADDED)
@@ -0,0 +1,213 @@
## Download the arXiv metadata from Kaggle
## https://www.kaggle.com/datasets/Cornell-University/arxiv

## Requires the Kaggle CLI to be installed
## Uses subprocess to run Kaggle CLI commands instead of the Kaggle Python API,
## as that allows anonymous downloads without needing to sign in
import subprocess
from datasets import load_dataset  # To load the dataset without exhausting RAM
from multiprocessing import cpu_count  # To get the number of cores
from sentence_transformers import SentenceTransformer  # For embedding the text locally
import torch  # For GPU support
import pandas as pd  # Data manipulation
from huggingface_hub import snapshot_download  # Download previous embeddings
import json  # To build the Milvus-compatible $meta column
import os  # Folder and file creation
from tqdm import tqdm  # Progress bar
tqdm.pandas()  # Progress bar for pandas
from mixedbread_ai.client import MixedbreadAI  # For embedding the text via the mxbai API
import numpy as np  # For array manipulation
from huggingface_hub import HfApi  # To transact with huggingface.co

################################################################################
# Configuration

# Year to update embeddings for
year = '24'

# Flag to force download and conversion even if files already exist
FORCE = True

# Flag to embed the data locally; otherwise the mxbai API is used
LOCAL = False

# Flag to upload the data to the Hugging Face Hub
UPLOAD = True

# Model to use for embedding
model_name = "mixedbread-ai/mxbai-embed-large-v1"

# Number of cores to use for multiprocessing
num_cores = cpu_count() - 1

# Setup transaction details
repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus"
repo_type = "dataset"

# Subfolder in the repo of the dataset where the file is stored
folder_in_repo = "data"
allow_patterns = f"{folder_in_repo}/{year}.parquet"

# Where to store the local copy of the dataset
local_dir = repo_id

# Create embed folder
embed_folder = f"{year}-diff-embed"
os.makedirs(embed_folder, exist_ok=True)
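
# Note: with LOCAL=False the Mixedbread API key is read from the MXBAI_API_KEY
# environment variable, and with UPLOAD=True the Hugging Face token is read
# from HF_API_KEY (see the os.getenv() calls below).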

################################################################################
# Download the dataset

# Dataset name
dataset_path = 'Cornell-University/arxiv'

# Download folder
download_folder = 'data'

# Data file path
download_file = f'{download_folder}/arxiv-metadata-oai-snapshot.json'

## Download the dataset if it doesn't exist
if not os.path.exists(download_file) or FORCE:

    print(f'Downloading {download_file}')

    # check=True so a failed download raises here instead of failing silently later
    subprocess.run(['kaggle', 'datasets', 'download', '--dataset', dataset_path, '--path', download_folder, '--unzip'], check=True)

    print(f'Downloaded {download_file}')

else:

    print(f'{download_file} already exists')
    print('Skipping download')

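# At this point data/arxiv-metadata-oai-snapshot.json should exist locally;
# the 'kaggle' CLI must be on PATH for the subprocess call above to succeed.
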
################################################################################
# Filter by year and convert to parquet

# https://huggingface.co/docs/datasets/en/about_arrow#memory-mapping
# Load metadata
print(f"Loading json metadata")
dataset = load_dataset("json", data_files=download_file, num_proc=num_cores)

########################################
# Function to add year to metadata
def add_year(example):

    example['year'] = example['id'].split('/')[1][:2] if '/' in example['id'] else example['id'][:2]

    return example
########################################
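# For example (illustrative ids): an old-style id like 'hep-ph/9901001' yields
# year '99', while a new-style id like '2401.12345' yields year '24'.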

# Add year to metadata
print(f"Adding year to metadata")
dataset = dataset.map(add_year, num_proc=num_cores)

# Filter by year
print(f"Filtering metadata by year: {year}")
dataset = dataset.filter(lambda example: example['year'] == year, num_proc=num_cores)

# Convert to pandas
print(f"Loading metadata for year: {year} into pandas")
arxiv_metadata_split = dataset['train'].to_pandas()

################################################################################
# Load Model

if LOCAL:
    # Make the app device agnostic
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Load a pretrained Sentence Transformer model and move it to the appropriate device
    print(f"Loading model {model_name} to device: {device}")
    model = SentenceTransformer(model_name)
    model = model.to(device)
else:
    print("Setting up mxbai client")
    # Setup mxbai
    mxbai_api_key = os.getenv("MXBAI_API_KEY")
    mxbai = MixedbreadAI(api_key=mxbai_api_key)

########################################
# Function that does the embedding
def embed(input_text):

    if LOCAL:

        # Calculate embeddings by calling model.encode(), specifying the device
        embedding = model.encode(input_text, device=device)

    else:

        # Calculate embeddings by calling mxbai.embeddings()
        result = mxbai.embeddings(
            model='mixedbread-ai/mxbai-embed-large-v1',
            input=input_text,
            normalized=True,
            encoding_format='float',
            truncation_strategy='end'
        )

        embedding = np.array(result.data[0].embedding)

    return embedding
########################################
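
# Illustrative usage (assuming the model's 1024-dimensional output):
#   vec = embed("An example abstract about transformer models.")
#   vec.shape  # -> (1024,)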

################################################################################
# Gather preexisting embeddings

# Create local directory
os.makedirs(local_dir, exist_ok=True)

# Download the repo
snapshot_download(repo_id=repo_id, repo_type=repo_type, local_dir=local_dir, allow_patterns=allow_patterns)

# Gather previous embed file
previous_embed = f'{local_dir}/{folder_in_repo}/{year}.parquet'

# Load previous_embed
print(f"Loading previously embedded file: {previous_embed}")
previous_embeddings = pd.read_parquet(previous_embed)
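# The previous parquet is expected to already be in the Milvus-compatible
# layout used below, i.e. columns 'id', 'vector' and '$meta'.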

########################################
# Embed the new abstracts

# Find papers that are not in the previous embeddings
# (.copy() so the column assignments below work on an independent dataframe
# rather than a view of arxiv_metadata_split)
new_papers = arxiv_metadata_split[~arxiv_metadata_split['id'].isin(previous_embeddings['id'])].copy()

# Create a column for embeddings
print(f"Creating new embeddings for: {len(new_papers)} entries")
new_papers["vector"] = new_papers["abstract"].progress_apply(embed)

# Rename columns
new_papers.rename(columns={'title': 'Title', 'authors': 'Authors', 'abstract': 'Abstract'}, inplace=True)

# Add URL column
new_papers['URL'] = 'https://arxiv.org/abs/' + new_papers['id']

# Create Milvus-compatible parquet file, $meta is a json string of the metadata
new_papers['$meta'] = new_papers[['Title', 'Authors', 'Abstract', 'URL']].apply(lambda row: json.dumps(row.to_dict()), axis=1)
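# Each $meta value is a JSON string along the lines of (illustrative):
# {"Title": "...", "Authors": "...", "Abstract": "...", "URL": "https://arxiv.org/abs/2401.12345"}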

# Selecting id, vector and $meta to retain
selected_columns = ['id', 'vector', '$meta']

# Merge previous embeddings and new embeddings
new_embeddings = pd.concat([previous_embeddings, new_papers[selected_columns]])

# Save the embedded file
embed_filename = f'{embed_folder}/{year}.parquet'
print(f"Saving newly embedded dataframe to: {embed_filename}")
# Keeping index=False to avoid saving the index column as a separate column in the parquet file
# This keeps Milvus from throwing an error when importing the parquet file
new_embeddings.to_parquet(embed_filename, index=False)

################################################################################

# Upload the new embeddings to the repo
if UPLOAD:

    print(f"Uploading new embeddings to: {repo_id}")
    access_token = os.getenv("HF_API_KEY")
    api = HfApi(token=access_token)

    # Upload all files within the folder to the specified repository
    api.upload_folder(repo_id=repo_id, folder_path=embed_folder, path_in_repo=folder_in_repo, repo_type="dataset")
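
    # With folder_path=embed_folder and path_in_repo="data", the updated file
    # should end up at data/{year}.parquet in the dataset repo, overwriting
    # the snapshot that was downloaded above.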