sander-wood committed on
Commit ad822ab · verified · 1 Parent(s): 20c1bbc

Upload 8 files

Files changed (8)
  1. README.md +3 -3
  2. app.py +304 -0
  3. config.py +79 -0
  4. extract_clamp3.py +189 -0
  5. features.zip +3 -0
  6. requirements.txt +72 -0
  7. utils.py +574 -0
  8. wikimt-x-public.jsonl +0 -0
README.md CHANGED
@@ -1,8 +1,8 @@
  ---
  title: Clamp3
- emoji:
- colorFrom: gray
- colorTo: pink
+ emoji: 🗜️
+ colorFrom: pink
+ colorTo: yellow
  sdk: gradio
  sdk_version: 5.16.0
  app_file: app.py
app.py ADDED
@@ -0,0 +1,304 @@
+ import os
+ import torch
+ import numpy as np
+ import gradio as gr
+ import zipfile
+ import json
+ import requests
+ import subprocess
+ import shutil
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+
+ title = "# 🗜️ CLaMP 3 - Multimodal & Multilingual Semantic Music Search"
+
+ badges = """
+ <div style="text-align: center;">
+ <a href="#"><img src="https://img.shields.io/badge/CLaMP%203%20Homepage-Coming%20Soon-lightgrey?style=for-the-badge&logo=home-assistant" alt="Homepage"></a>
+ <a href="#"><img src="https://img.shields.io/badge/CLaMP%203%20Paper-Coming%20Soon-lightgrey?style=for-the-badge&logo=arxiv" alt="Paper"></a>
+ <a href="https://github.com/sanderwood/clamp3"><img src="https://img.shields.io/badge/CLaMP%203%20Code-GitHub-181717?style=for-the-badge&logo=github" alt="GitHub"></a>
+ <a href="https://huggingface.co/sander-wood/clamp3/tree/main"><img src="https://img.shields.io/badge/Model%20Weights-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Model Weights"></a>
+ <a href="https://huggingface.co/datasets/sander-wood/m4-rag"><img src="https://img.shields.io/badge/M4--RAG%20Pretraining%20Dataset-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Dataset"></a>
+ <a href="https://huggingface.co/datasets/sander-wood/wikimt-x"><img src="https://img.shields.io/badge/WikiMT--X%20Evaluation%20Benchmark-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Benchmark"></a>
+ </div>
+ <style>
+ div a {
+     display: inline-block;
+     margin: 5px;
+ }
+ div a img {
+     height: 30px;
+ }
+ </style>
+ """
+
+ description = """CLaMP 3 is a **multimodal and multilingual** music information retrieval (MIR) framework, supporting **sheet music, audio, and performance signals** in over **100 languages**. Using **contrastive learning**, it aligns these modalities in a shared space for **cross-modal retrieval**.
+
+ ### 🔍 **How This Demo Works**
+ - You can **retrieve music using any text input (in any language) or an image** (`.png`, `.jpg`).
+ - When using an image, **BLIP** generates a caption, which is then used for retrieval.
+ - Since CLaMP 3's training data includes **rich visual descriptions of musical scenes**, it can **match images to semantically relevant music**.
+
+ ### ⚠️ **Limitations**
+ - This demo retrieves music **only from the WikiMT-X benchmark (1,000 pieces)**.
+ - These pieces are **mainly from the U.S. and Western Europe (especially the U.S.)** and **mostly from the 20th century**.
+ - The retrieval results are **mostly limited to Western 20th-century music**, so you **won’t** find music from **other regions or historical periods**.
+ - If you need retrieval for a **different music collection**, deploy **CLaMP 3 on your own dataset**.
+
+ This demo is for **research purposes only**."""
+
+ # Load BLIP image captioning model and processor
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
+ # Download the weights file if it does not exist
+ weights_url = "https://huggingface.co/sander-wood/clamp3/resolve/main/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth"
+ weights_filename = "weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth"
+
+ if not os.path.exists(weights_filename):
+     print("Downloading weights file...")
+     response = requests.get(weights_url, stream=True)
+     response.raise_for_status()
+     with open(weights_filename, "wb") as f:
+         for chunk in response.iter_content(chunk_size=8192):
+             if chunk:
+                 f.write(chunk)
+     print("Weights file downloaded.")
+
+ ZIP_PATH = "features.zip"
+ if os.path.exists(ZIP_PATH):
+     print(f"Extracting {ZIP_PATH}...")
+     with zipfile.ZipFile(ZIP_PATH, "r") as zip_ref:
+         zip_ref.extractall(".")
+     print("Extraction complete.")
+
+ # Load metadata
+ metadata_map = {}
+ METADATA_FILE = "wikimt-x-public.jsonl"
+ if os.path.exists(METADATA_FILE):
+     with open(METADATA_FILE, "r", encoding="utf-8") as f:
+         for line in f:
+             data = json.loads(line)
+             metadata_map[data["id"]] = data
+ else:
+     print(f"Warning: {METADATA_FILE} not found.")
+
+ features_cache = {}
+
+ def get_info(folder_path):
+     """
+     Load all .npy files from the specified folder and return a dictionary
+     with the file names (without extension) as keys.
+     """
+     if folder_path in features_cache:
+         return features_cache[folder_path]
+     if not os.path.exists(folder_path):
+         return {}
+     files = sorted(os.listdir(folder_path))
+     features = {}
+     for file in files:
+         if file.endswith(".npy"):
+             key = file.split(".")[0]
+             try:
+                 features[key] = np.load(os.path.join(folder_path, file))[0]
+             except Exception as e:
+                 print(f"Error loading {file}: {e}")
+     features_cache[folder_path] = features
+     return features
+
+ def find_top_similar(query_file, reference_folder):
+     """
+     Compare the query feature with all reference features in the specified folder
+     using cosine similarity and return the top 10 candidate results in the format:
+     Title | Artists | sim: SimilarityScore.
+     """
+     top_k = 10
+     try:
+         query_feature = np.load(query_file.name)[0]
+     except Exception as e:
+         return [], f"Error loading query feature: {e}"
+     query_tensor = torch.tensor(query_feature, dtype=torch.float32).unsqueeze(dim=0)
+     key_features = get_info(reference_folder)
+     if not key_features:
+         return [], f"No reference features found in {reference_folder}."
+     ref_keys = list(key_features.keys())
+     ref_array = np.array([key_features[k] for k in ref_keys])
+     key_feats_tensor = torch.tensor(ref_array, dtype=torch.float32)
+     query_tensor_expanded = query_tensor.expand(key_feats_tensor.size(0), -1)
+     similarities = torch.cosine_similarity(query_tensor_expanded, key_feats_tensor, dim=1)
+     ranked_indices = torch.argsort(similarities, descending=True)
+     candidate_ids = []
+     candidate_display = []
+     for i in range(top_k):
+         if i < len(ref_keys):
+             candidate_idx = ranked_indices[i].item()
+             candidate_id = ref_keys[candidate_idx]
+             sim = round(similarities[candidate_idx].item(), 4)
+             meta = metadata_map.get(candidate_id, {})
+             title = meta.get("title", candidate_id)
+             artists = meta.get("artists", "Unknown")
+             if isinstance(artists, list):
+                 artists = ", ".join(artists)
+             candidate_ids.append(candidate_id)
+             candidate_display.append(f"{title} | {artists} | sim: {sim}")
+         else:
+             candidate_ids.append("N/A")
+             candidate_display.append("N/A")
+     return candidate_ids, candidate_display
+
+ def show_details(selected_id):
+     """
+     Return detailed metadata and embedded YouTube video HTML based on the candidate ID.
+     """
+     if selected_id == "N/A":
+         return ("", "", "", "", "", "", "", "")
+     data = metadata_map.get(selected_id, {})
+     if not data:
+         return ("No details found", "", "", "", "", "", "", "")
+     title = data.get("title", "")
+     artists = data.get("artists", "")
+     if isinstance(artists, list):
+         artists = ", ".join(artists)
+     genre = data.get("genre", "")
+     background = data.get("background", "")
+     analysis = data.get("analysis", "")
+     description = data.get("description", "")
+     scene = data.get("scene", "")
+     youtube_html = (
+         f'<iframe width="560" height="315" src="https://www.youtube.com/embed/{selected_id}" '
+         f'frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; '
+         f'gyroscope; picture-in-picture" allowfullscreen></iframe>'
+     )
+     return title, artists, genre, background, analysis, description, scene, youtube_html
+
+ def extract_features_from_text(text):
+     """
+     Save the input text to a file, call the CLaMP 3 feature extraction script,
+     and return the generated feature file path.
+     """
+     input_dir = "input_dir"
+     output_dir = "output_dir"
+     os.makedirs(input_dir, exist_ok=True)
+     os.makedirs(output_dir, exist_ok=True)
+     # Clear input_dir and output_dir
+     for d in [input_dir, output_dir]:
+         for filename in os.listdir(d):
+             file_path = os.path.join(d, filename)
+             if os.path.isfile(file_path) or os.path.islink(file_path):
+                 os.unlink(file_path)
+             elif os.path.isdir(file_path):
+                 shutil.rmtree(file_path)
+     input_file = os.path.join(input_dir, "input.txt")
+     print("Text input:", text)
+     with open(input_file, "w", encoding="utf-8") as f:
+         f.write(text)
+     command = ["python", "extract_clamp3.py", input_dir, output_dir, "--get_global"]
+     subprocess.run(command, check=True)
+     output_file = os.path.join(output_dir, "input.npy")
+     return output_file
+
+ def generate_caption(image):
+     """
+     Use the BLIP model to generate a descriptive caption for the given image.
+     """
+     inputs = processor(image, return_tensors="pt")
+     outputs = blip_model.generate(**inputs)
+     caption = processor.decode(outputs[0], skip_special_tokens=True)
+     return caption
+
+ class FileWrapper:
+     """
+     Simulate a file object with a .name attribute.
+     """
+     def __init__(self, path):
+         self.name = path
+
+ def search_wrapper(search_mode, text_input, image_input):
+     """
+     Perform retrieval based on the selected input mode:
+     - If search_mode is "Image", use the uploaded image to generate a caption, then extract features
+       and search in the "image/" folder.
+     - If search_mode is "Text", use the provided text to extract features and search in the "text/" folder.
+     """
+     if search_mode == "Image":
+         if image_input is None:
+             return text_input, gr.update(choices=[]), "Please upload an image.", "", "", "", "", "", "", ""
+         caption = generate_caption(image_input)
+         text_to_use = caption
+         reference_folder = "image/"
+     elif search_mode == "Text":
+         if not text_input or text_input.strip() == "":
+             return "Describe the music you're looking for (in any language)", gr.update(choices=[]), "Please enter text for retrieval.", "", "", "", "", "", "", ""
+         text_to_use = text_input
+         reference_folder = "text/"
+     else:
+         return "Describe the music you're looking for (in any language)", gr.update(choices=[]), "Invalid search mode selected.", "", "", "", "", "", "", ""
+
+     try:
+         output_file = extract_features_from_text(text_to_use)
+         query_file = FileWrapper(output_file)
+     except Exception as e:
+         return text_to_use, gr.update(choices=[]), f"Error during feature extraction: {e}", "", "", "", "", "", "", ""
+     candidate_ids, candidate_display = find_top_similar(query_file, reference_folder)
+     if not candidate_ids:
+         return text_to_use, gr.update(choices=[]), "", "", "", "", "", "", "", ""
+     choices = [(f"{i+1}. {disp}", cid) for i, (cid, disp) in enumerate(zip(candidate_ids, candidate_display))]
+     top_candidate = candidate_ids[0]
+     details = show_details(top_candidate)
+     return text_to_use, gr.update(choices=choices), *details
+
+ with gr.Blocks() as demo:
+     gr.Markdown(title)
+     gr.HTML(badges)
+     gr.Markdown(description)
+     gr.HTML(
+         """
+         <style>
+         .vertical-radio .gradio-radio label {
+             display: block !important;
+             margin-bottom: 5px;
+         }
+         </style>
+         """
+     )
+     with gr.Row():
+         with gr.Column():
+             search_mode = gr.Radio(
+                 choices=["Text", "Image"],
+                 label="Select Search Mode",
+                 value="Text",
+                 interactive=True,
+                 elem_classes=["vertical-radio"]
+             )
+             text_input = gr.Textbox(
+                 placeholder="Describe the music you're looking for (in any language)",
+                 lines=4
+             )
+             image_input = gr.Image(
+                 label="Or upload an image (PNG, JPG)",
+                 type="pil"
+             )
+             search_button = gr.Button("Search")
+             candidate_radio = gr.Radio(choices=[], label="Select Retrieval Result", interactive=True, elem_classes=["vertical-radio"])
+         with gr.Column():
+             gr.Markdown("### YouTube Video")
+             youtube_box = gr.HTML(label="YouTube Video")
+             gr.Markdown("### Metadata")
+             title_box = gr.Textbox(label="Title", interactive=False)
+             artists_box = gr.Textbox(label="Artists", interactive=False)
+             genre_box = gr.Textbox(label="Genre", interactive=False)
+             background_box = gr.Textbox(label="Background", interactive=False)
+             analysis_box = gr.Textbox(label="Analysis", interactive=False)
+             description_box = gr.Textbox(label="Description", interactive=False)
+             scene_box = gr.Textbox(label="Scene", interactive=False)
+     search_button.click(
+         fn=search_wrapper,
+         inputs=[search_mode, text_input, image_input],
+         outputs=[text_input, candidate_radio, title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box]
+     )
+     candidate_radio.change(
+         fn=show_details,
+         inputs=candidate_radio,
+         outputs=[title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box]
+     )
+
+ demo.launch()
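
For reference, here is a minimal sketch (not part of the committed files) of how the retrieval path above can be exercised outside the Gradio UI. It assumes `features.zip` has already been extracted so the precomputed reference features sit under `text/`, and it mirrors what `extract_features_from_text` and `find_top_similar` do in `app.py`; the query text is made up.

```python
# Hypothetical stand-alone query against the precomputed WikiMT-X text features.
import os, subprocess
import numpy as np
import torch

# 1) Extract a global text feature for the query, the same way app.py does.
os.makedirs("input_dir", exist_ok=True)
os.makedirs("output_dir", exist_ok=True)
with open("input_dir/input.txt", "w", encoding="utf-8") as f:
    f.write("a melancholic jazz ballad with solo piano")
subprocess.run(["python", "extract_clamp3.py", "input_dir", "output_dir", "--get_global"], check=True)

# 2) Rank the reference features under text/ by cosine similarity.
query = torch.tensor(np.load("output_dir/input.npy")[0], dtype=torch.float32).unsqueeze(0)
refs = {f[:-4]: np.load(os.path.join("text", f))[0]
        for f in sorted(os.listdir("text")) if f.endswith(".npy")}
mat = torch.tensor(np.array(list(refs.values())), dtype=torch.float32)
sims = torch.cosine_similarity(query.expand(mat.size(0), -1), mat, dim=1)
for i in torch.argsort(sims, descending=True)[:3].tolist():
    print(list(refs.keys())[i], round(sims[i].item(), 4))
```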
config.py ADDED
@@ -0,0 +1,79 @@
+ EVAL_SPLIT = 0.01 # Fraction of training data used for evaluation
+ WANDB_KEY = "<YOUR_WANDB_KEY>" # Weights and Biases API key
+
+ # -------------------- Configuration for M3 Training --------------------
+ M3_TRAIN_FOLDERS = [
+     "<YOUR_TRAINING_DATA_FOLDER>" # Directory containing training data for M3
+ ]
+
+ M3_EVAL_FOLDERS = [
+     "<YOUR_EVALUATION_DATA_FOLDER>" # Directory containing evaluation data for M3 (optional)
+ ]
+
+ PATCH_SIZE = 64 # Size of each patch
+ PATCH_LENGTH = 512 # Length of the patches
+ PATCH_NUM_LAYERS = 12 # Number of layers in the encoder
+ TOKEN_NUM_LAYERS = 3 # Number of layers in the decoder
+ M3_HIDDEN_SIZE = 768 # Size of the hidden layer
+
+ M3_NUM_EPOCH = 100 # Maximum number of epochs for training
+ M3_LEARNING_RATE = 1e-4 # Learning rate for the optimizer
+ M3_BATCH_SIZE = 16 # Batch size per GPU (single card) during training
+ M3_MASK_RATIO = 0.45 # Ratio of masked elements during training
+ M3_DETERMINISTIC = True # Ensures deterministic results with random seeds
+ M3_WANDB_LOG = True # Enable logging to Weights and Biases
+ M3_LOAD_CKPT = True # Load model weights from a checkpoint if available
+
+ M3_WEIGHTS_PATH = (
+     "weights_m3" +
+     "_h_size_" + str(M3_HIDDEN_SIZE) +
+     "_t_layers_" + str(TOKEN_NUM_LAYERS) +
+     "_p_layers_" + str(PATCH_NUM_LAYERS) +
+     "_p_size_" + str(PATCH_SIZE) +
+     "_p_length_" + str(PATCH_LENGTH) +
+     "_lr_" + str(M3_LEARNING_RATE) +
+     "_batch_" + str(M3_BATCH_SIZE) +
+     "_mask_" + str(M3_MASK_RATIO) + ".pth"
+ ) # Path to store the model weights
+ M3_LOGS_PATH = M3_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt") # Path to save training logs
+
+ # -------------------- Configuration for CLaMP3 Training ----------------
+ CLAMP3_TRAIN_JSONL = "<YOUR_TRAINING_JSONL_FILE>" # Path to the JSONL file with training data for CLaMP3
+ CLAMP3_EVAL_JSONL = "<YOUR_EVALUATION_JSONL_FILE>" # Path to the JSONL file with evaluation data for CLaMP3 (optional)
+
+ CLAMP3_HIDDEN_SIZE = 768 # Size of the hidden layer
+ TEXT_MODEL_NAME = "FacebookAI/xlm-roberta-base" # Name of the pre-trained text model
+ MAX_TEXT_LENGTH = 128 # Maximum allowed length for text input
+
+ AUDIO_HIDDEN_SIZE = 768 # Size of the hidden layer for audio features
+ AUDIO_NUM_LAYERS = 12 # Number of layers in the audio encoder
+ MAX_AUDIO_LENGTH = 128 # Maximum allowed length for audio input
+
+ CLAMP3_NUM_EPOCH = 100 # Maximum number of epochs for training
+ CLAMP3_LEARNING_RATE = 1e-5 # Learning rate for the optimizer
+ CLAMP3_BATCH_SIZE = 256 # Batch size per GPU (single card) during training
+ LOGIT_SCALE = 1 # Scaling factor for contrastive loss
+
+ FREEZE_TEXT = False # Freeze the weights of the text model and text projection layer
+ TEXT_DROPOUT = True # Whether to apply dropout during text processing
+ CLAMP3_DETERMINISTIC = True # Ensures deterministic results with random seeds
+ CLAMP3_LOAD_M3 = True # Load weights from the M3 model
+ CLAMP3_WANDB_LOG = True # Enable logging to Weights and Biases
+ CLAMP3_LOAD_CKPT = True # Load weights from a checkpoint if available
+ SAVE_EVERY = 5 # Save model weights every SAVE_EVERY epochs
+
+ CLAMP3_WEIGHTS_PATH = (
+     "weights_clamp3_saas" +
+     "_h_size_" + str(CLAMP3_HIDDEN_SIZE) +
+     "_t_model_" + TEXT_MODEL_NAME.replace("/", "_") +
+     "_t_length_" + str(MAX_TEXT_LENGTH) +
+     "_a_size_" + str(AUDIO_HIDDEN_SIZE) +
+     "_a_layers_" + str(AUDIO_NUM_LAYERS) +
+     "_a_length_" + str(MAX_AUDIO_LENGTH) +
+     "_s_size_" + str(M3_HIDDEN_SIZE) +
+     "_s_layers_" + str(PATCH_NUM_LAYERS) +
+     "_p_size_" + str(PATCH_SIZE) +
+     "_p_length_" + str(PATCH_LENGTH) + ".pth"
+ ) # Path to store CLaMP3 model weights
+ CLAMP3_LOGS_PATH = CLAMP3_WEIGHTS_PATH.replace("weights", "logs").replace("pth", "txt") # Path to save training logs
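
The checkpoint filename hard-coded in `app.py` is exactly the string this block derives from the hyperparameters; a quick, hypothetical sanity check (not part of the commit):

```python
# Hypothetical check: the derived weights path should equal the filename
# that app.py downloads from the sander-wood/clamp3 model repo.
from config import CLAMP3_WEIGHTS_PATH

expected = ("weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base"
            "_t_length_128_a_size_768_a_layers_12_a_length_128"
            "_s_size_768_s_layers_12_p_size_64_p_length_512.pth")
assert CLAMP3_WEIGHTS_PATH == expected, CLAMP3_WEIGHTS_PATH
```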
extract_clamp3.py ADDED
@@ -0,0 +1,189 @@
+ import os
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from config import *
+ from utils import *
+ from samplings import *
+ from accelerate import Accelerator
+ from transformers import BertConfig, AutoTokenizer
+ import argparse
+
+ # Parse command-line arguments
+ parser = argparse.ArgumentParser(description="Feature extraction for CLaMP3.")
+ parser.add_argument("--epoch", type=str, default=None, help="Epoch of the checkpoint to load.")
+ parser.add_argument("input_dir", type=str, help="Directory containing input data files.")
+ parser.add_argument("output_dir", type=str, help="Directory to save the output features.")
+ parser.add_argument("--get_global", action="store_true", help="Get global feature.")
+
+ args = parser.parse_args()
+
+ # Retrieve arguments
+ epoch = args.epoch
+ input_dir = args.input_dir
+ output_dir = args.output_dir
+ get_global = args.get_global
+
+ files = []
+ for root, dirs, fs in os.walk(input_dir):
+     for f in fs:
+         if f.endswith(".txt") or f.endswith(".abc") or f.endswith(".mtf") or f.endswith(".npy"):
+             files.append(os.path.join(root, f))
+
+ print(f"Found {len(files)} files in total")
+
+ # Initialize accelerator and device
+ accelerator = Accelerator()
+ device = accelerator.device
+ print("Using device:", device)
+
+ # Model and configuration setup
+ audio_config = BertConfig(vocab_size=1,
+                           hidden_size=AUDIO_HIDDEN_SIZE,
+                           num_hidden_layers=AUDIO_NUM_LAYERS,
+                           num_attention_heads=AUDIO_HIDDEN_SIZE//64,
+                           intermediate_size=AUDIO_HIDDEN_SIZE*4,
+                           max_position_embeddings=MAX_AUDIO_LENGTH)
+ symbolic_config = BertConfig(vocab_size=1,
+                              hidden_size=M3_HIDDEN_SIZE,
+                              num_hidden_layers=PATCH_NUM_LAYERS,
+                              num_attention_heads=M3_HIDDEN_SIZE//64,
+                              intermediate_size=M3_HIDDEN_SIZE*4,
+                              max_position_embeddings=PATCH_LENGTH)
+ model = CLaMP3Model(audio_config=audio_config,
+                     symbolic_config=symbolic_config,
+                     text_model_name=TEXT_MODEL_NAME,
+                     hidden_size=CLAMP3_HIDDEN_SIZE,
+                     load_m3=CLAMP3_LOAD_M3)
+ model = model.to(device)
+
+ tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
+ patchilizer = M3Patchilizer()
+
+ # print parameter number
+ print("Total Parameter Number: "+str(sum(p.numel() for p in model.parameters())))
+
+ # Load model weights
+ model.eval()
+ checkpoint_path = CLAMP3_WEIGHTS_PATH
+ if epoch is not None:
+     checkpoint_path = CLAMP3_WEIGHTS_PATH.replace(".pth", f"_{epoch}.pth")
+ checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+ print(f"Successfully Loaded CLaMP 3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
+ model.load_state_dict(checkpoint['model'])
+
+ def extract_feature(filename, get_global=get_global):
+     if not filename.endswith(".npy"):
+         with open(filename, "r", encoding="utf-8") as f:
+             item = f.read()
+
+     if filename.endswith(".txt"):
+         item = list(set(item.split("\n")))
+         item = "\n".join(item)
+         item = item.split("\n")
+         item = [c for c in item if len(c) > 0]
+         item = tokenizer.sep_token.join(item)
+         input_data = tokenizer(item, return_tensors="pt")
+         input_data = input_data['input_ids'].squeeze(0)
+         max_input_length = MAX_TEXT_LENGTH
+     elif filename.endswith(".abc") or filename.endswith(".mtf"):
+         input_data = patchilizer.encode(item, add_special_patches=True)
+         input_data = torch.tensor(input_data)
+         max_input_length = PATCH_LENGTH
+     elif filename.endswith(".npy"):
+         input_data = np.load(filename)
+         input_data = torch.tensor(input_data)
+         input_data = input_data.reshape(-1, input_data.size(-1))
+         zero_vec = torch.zeros((1, input_data.size(-1)))
+         input_data = torch.cat((zero_vec, input_data, zero_vec), 0)
+         max_input_length = MAX_AUDIO_LENGTH
+     else:
+         raise ValueError(f"Unsupported file type: {filename}, only support .txt, .abc, .mtf, .npy files")
+
+     segment_list = []
+     for i in range(0, len(input_data), max_input_length):
+         segment_list.append(input_data[i:i+max_input_length])
+     segment_list[-1] = input_data[-max_input_length:]
+
+     last_hidden_states_list = []
+
+     for input_segment in segment_list:
+         input_masks = torch.tensor([1]*input_segment.size(0))
+         if filename.endswith(".txt"):
+             pad_indices = torch.ones(MAX_TEXT_LENGTH - input_segment.size(0)).long() * tokenizer.pad_token_id
+         elif filename.endswith(".abc") or filename.endswith(".mtf"):
+             pad_indices = torch.ones((PATCH_LENGTH - input_segment.size(0), PATCH_SIZE)).long() * patchilizer.pad_token_id
+         else:
+             pad_indices = torch.ones((MAX_AUDIO_LENGTH - input_segment.size(0), AUDIO_HIDDEN_SIZE)).float() * 0.
+         input_masks = torch.cat((input_masks, torch.zeros(max_input_length - input_segment.size(0))), 0)
+         input_segment = torch.cat((input_segment, pad_indices), 0)
+
+         if filename.endswith(".txt"):
+             last_hidden_states = model.get_text_features(text_inputs=input_segment.unsqueeze(0).to(device),
+                                                          text_masks=input_masks.unsqueeze(0).to(device),
+                                                          get_global=get_global)
+         elif filename.endswith(".abc") or filename.endswith(".mtf"):
+             last_hidden_states = model.get_symbolic_features(symbolic_inputs=input_segment.unsqueeze(0).to(device),
+                                                              symbolic_masks=input_masks.unsqueeze(0).to(device),
+                                                              get_global=get_global)
+         else:
+             last_hidden_states = model.get_audio_features(audio_inputs=input_segment.unsqueeze(0).to(device),
+                                                           audio_masks=input_masks.unsqueeze(0).to(device),
+                                                           get_global=get_global)
+         if not get_global:
+             last_hidden_states = last_hidden_states[:, :input_masks.sum().long().item(), :]
+         last_hidden_states_list.append(last_hidden_states)
+
+     if not get_global:
+         last_hidden_states_list = [last_hidden_states[0] for last_hidden_states in last_hidden_states_list]
+         last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(input_data)%max_input_length):]
+         last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
+     else:
+         full_chunk_cnt = len(input_data) // max_input_length
+         remain_chunk_len = len(input_data) % max_input_length
+         if remain_chunk_len == 0:
+             feature_weights = torch.tensor([max_input_length] * full_chunk_cnt, device=device).view(-1, 1)
+         else:
+             feature_weights = torch.tensor([max_input_length] * full_chunk_cnt + [remain_chunk_len], device=device).view(-1, 1)
+
+         last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
+         last_hidden_states_list = last_hidden_states_list * feature_weights
+         last_hidden_states_list = last_hidden_states_list.sum(dim=0) / feature_weights.sum()
+
+     return last_hidden_states_list
+
+ def process_directory(input_dir, output_dir, files):
+     # calculate the number of files to process per GPU
+     num_files_per_gpu = len(files) // accelerator.num_processes
+
+     # calculate the start and end index for the current GPU
+     start_idx = accelerator.process_index * num_files_per_gpu
+     end_idx = start_idx + num_files_per_gpu
+     if accelerator.process_index == accelerator.num_processes - 1:
+         end_idx = len(files)
+
+     files_to_process = files[start_idx:end_idx]
+
+     # process the files
+     for file in tqdm(files_to_process):
+         output_subdir = output_dir + os.path.dirname(file)[len(input_dir):]
+         try:
+             os.makedirs(output_subdir, exist_ok=True)
+         except Exception as e:
+             print(output_subdir + " can not be created\n" + str(e))
+
+         output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")
+
+         if os.path.exists(output_file):
+             print(f"Skipping {file}, output already exists")
+             continue
+
+         try:
+             with torch.no_grad():
+                 features = extract_feature(file).unsqueeze(0)
+                 np.save(output_file, features.detach().cpu().numpy())
+         except Exception as e:
+             print(f"Failed to process {file}: {e}")
+
+ # process the files
+ process_directory(input_dir, output_dir, files)
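
As a usage note, `app.py` shells out to this script once per query with exactly the arguments below; a minimal sketch of the same call (hypothetical directory names), which writes one `.npy` global feature per supported input file. When launched through the Accelerate launcher, `process_directory` shards the file list across processes via `accelerator.process_index`.

```python
# Minimal sketch of invoking the extractor the same way app.py does:
# every .txt/.abc/.mtf/.npy file under input_dir/ yields a .npy feature under output_dir/.
import subprocess

subprocess.run(
    ["python", "extract_clamp3.py", "input_dir", "output_dir", "--get_global"],
    check=True,
)
```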
features.zip ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:60273370643e51e092a466d0e9a28041cfc944b2d0b55f6fbe926081ce1ff570
+ size 6242016
requirements.txt ADDED
@@ -0,0 +1,72 @@
+ # PyTorch (CPU-only version)
+ torch==2.4.0
+ torchvision==0.19.0
+ torchaudio==2.4.0
+ -f https://download.pytorch.org/whl/cpu
+
+ # Core dependencies
+ numpy==1.26.4
+ scipy==1.14.1
+ scikit-learn==1.5.1
+ pandas==1.3.5
+ tqdm==4.66.5
+ requests==2.32.3
+ pillow==9.5.0
+ pyyaml==6.0.1
+ typing-extensions==4.12.2
+
+ # Transformers and optimization
+ transformers==4.40.0
+ optimum==1.21.4
+ tokenizers==0.19.1
+ sentencepiece==0.2.0
+ safetensors==0.4.4
+ accelerate==0.34.0
+
+ # Audio processing
+ librosa==0.10.1
+ soundfile==0.12.1
+ pydub==0.25.1
+ soxr==0.5.0.post1
+ audioread==3.0.1
+ nnAudio==0.3.3
+
+ # MIDI and music processing
+ mido==1.3.0
+ music21==7.3.3
+ abctoolkit==0.0.4
+
+ # Natural language processing and text utilities
+ nltk==3.8.1
+ sacrebleu==2.4.3
+ sacremoses==0.0.53
+ langdetect==1.0.9
+ langid==1.1.6
+ language-data==1.2.0
+ regex==2023.8.8
+ unidecode==1.3.6
+
+ # Hugging Face Hub
+ huggingface-hub==0.24.6
+ datasets==2.21.0
+
+ # Logging and tracking
+ wandb==0.17.8
+ setproctitle==1.3.3
+ sentry-sdk==2.13.0
+
+ # Utilities
+ protobuf==5.28.0
+ filelock==3.12.2
+ tabulate==0.9.0
+ dill==0.3.8
+ fsspec==2024.6.1
+ xxhash==3.5.0
+ gitpython==3.1.43
+ certifi==2023.7.22
+ charset-normalizer==3.2.0
+ urllib3==2.0.4
+ yarl==1.9.7
+ idna==3.4
+ samplings==0.1.7
+ six==1.16.0
utils.py ADDED
@@ -0,0 +1,574 @@
+ import re
+ import os
+ import math
+ import torch
+ import random
+ from config import *
+ from unidecode import unidecode
+ from torch.nn import functional as F
+ from transformers import AutoModel, BertModel, GPT2LMHeadModel, PreTrainedModel, GPT2Config
+
+ try:
+     import torch.distributed.nn
+     from torch import distributed as dist
+
+     has_distributed = True
+ except ImportError:
+     has_distributed = False
+
+ try:
+     import horovod.torch as hvd
+ except ImportError:
+     hvd = None
+
+ class ClipLoss(torch.nn.Module):
+
+     def __init__(
+             self,
+             local_loss=False,
+             gather_with_grad=False,
+             cache_labels=False,
+             rank=0,
+             world_size=1,
+             use_horovod=False,
+     ):
+         super().__init__()
+         self.local_loss = local_loss
+         self.gather_with_grad = gather_with_grad
+         self.cache_labels = cache_labels
+         self.rank = rank
+         self.world_size = world_size
+         self.use_horovod = use_horovod
+
+         # cache state
+         self.prev_num_logits = 0
+         self.labels = {}
+
+     def gather_features(
+             self,
+             image_features,
+             text_features,
+             local_loss=False,
+             gather_with_grad=False,
+             rank=0,
+             world_size=1,
+             use_horovod=False
+     ):
+         assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'
+         if use_horovod:
+             assert hvd is not None, 'Please install horovod'
+             if gather_with_grad:
+                 all_image_features = hvd.allgather(image_features)
+                 all_text_features = hvd.allgather(text_features)
+             else:
+                 with torch.no_grad():
+                     all_image_features = hvd.allgather(image_features)
+                     all_text_features = hvd.allgather(text_features)
+                 if not local_loss:
+                     # ensure grads for local rank when all_* features don't have a gradient
+                     gathered_image_features = list(all_image_features.chunk(world_size, dim=0))
+                     gathered_text_features = list(all_text_features.chunk(world_size, dim=0))
+                     gathered_image_features[rank] = image_features
+                     gathered_text_features[rank] = text_features
+                     all_image_features = torch.cat(gathered_image_features, dim=0)
+                     all_text_features = torch.cat(gathered_text_features, dim=0)
+         else:
+             # We gather tensors from all gpus
+             if gather_with_grad:
+                 all_image_features = torch.cat(torch.distributed.nn.all_gather(image_features), dim=0)
+                 all_text_features = torch.cat(torch.distributed.nn.all_gather(text_features), dim=0)
+             else:
+                 gathered_image_features = [torch.zeros_like(image_features) for _ in range(world_size)]
+                 gathered_text_features = [torch.zeros_like(text_features) for _ in range(world_size)]
+                 dist.all_gather(gathered_image_features, image_features)
+                 dist.all_gather(gathered_text_features, text_features)
+                 if not local_loss:
+                     # ensure grads for local rank when all_* features don't have a gradient
+                     gathered_image_features[rank] = image_features
+                     gathered_text_features[rank] = text_features
+                 all_image_features = torch.cat(gathered_image_features, dim=0)
+                 all_text_features = torch.cat(gathered_text_features, dim=0)
+
+         return all_image_features, all_text_features
+
+     def get_ground_truth(self, device, num_logits) -> torch.Tensor:
+         # calculate ground-truth and cache if enabled
+         if self.prev_num_logits != num_logits or device not in self.labels:
+             labels = torch.arange(num_logits, device=device, dtype=torch.long)
+             if self.world_size > 1 and self.local_loss:
+                 labels = labels + num_logits * self.rank
+             if self.cache_labels:
+                 self.labels[device] = labels
+                 self.prev_num_logits = num_logits
+         else:
+             labels = self.labels[device]
+         return labels
+
+     def get_logits(self, image_features, text_features, logit_scale):
+         if self.world_size > 1:
+             all_image_features, all_text_features = self.gather_features(
+                 image_features, text_features,
+                 self.local_loss, self.gather_with_grad, self.rank, self.world_size, self.use_horovod)
+
+             if self.local_loss:
+                 logits_per_image = logit_scale * image_features @ all_text_features.T
+                 logits_per_text = logit_scale * text_features @ all_image_features.T
+             else:
+                 logits_per_image = logit_scale * all_image_features @ all_text_features.T
+                 logits_per_text = logits_per_image.T
+         else:
+             logits_per_image = logit_scale * image_features @ text_features.T
+             logits_per_text = logit_scale * text_features @ image_features.T
+
+         return logits_per_image, logits_per_text
+
+     def forward(self, image_features, text_features, logit_scale, output_dict=False):
+         device = image_features.device
+         logits_per_image, logits_per_text = self.get_logits(image_features, text_features, logit_scale)
+
+         labels = self.get_ground_truth(device, logits_per_image.shape[0])
+
+         total_loss = (
+             F.cross_entropy(logits_per_image, labels) +
+             F.cross_entropy(logits_per_text, labels)
+         ) / 2
+
+         return {"contrastive_loss": total_loss} if output_dict else total_loss
+
+ class M3Patchilizer:
+     def __init__(self):
+         self.delimiters = ["|:", "::", ":|", "[|", "||", "|]", "|"]
+         self.regexPattern = '(' + '|'.join(map(re.escape, self.delimiters)) + ')'
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+         self.mask_token_id = 3
+
+     def split_bars(self, body):
+         bars = re.split(self.regexPattern, ''.join(body))
+         bars = list(filter(None, bars))  # remove empty strings
+         if bars[0] in self.delimiters:
+             bars[1] = bars[0] + bars[1]
+             bars = bars[1:]
+         bars = [bars[i * 2] + bars[i * 2 + 1] for i in range(len(bars) // 2)]
+         return bars
+
+     def bar2patch(self, bar, patch_size=PATCH_SIZE):
+         patch = [self.bos_token_id] + [ord(c) for c in bar] + [self.eos_token_id]
+         patch = patch[:patch_size]
+         patch += [self.pad_token_id] * (patch_size - len(patch))
+         return patch
+
+     def patch2bar(self, patch):
+         return ''.join(chr(idx) if idx > self.mask_token_id else '' for idx in patch)
+
+     def encode(self,
+                item,
+                patch_size=PATCH_SIZE,
+                add_special_patches=False,
+                truncate=False,
+                random_truncate=False):
+         item = item.replace("L:1/8\n", "")
+         item = unidecode(item)
+         lines = re.findall(r'.*?\n|.*$', item)
+         lines = list(filter(None, lines))  # remove empty lines
+
+         patches = []
+
+         if lines[0].split(" ")[0] == "ticks_per_beat":
+             patch = ""
+             for line in lines:
+                 if patch.startswith(line.split(" ")[0]) and (len(patch) + len(" ".join(line.split(" ")[1:])) <= patch_size-2):
+                     patch = patch[:-1] + "\t" + " ".join(line.split(" ")[1:])
+                 else:
+                     if patch:
+                         patches.append(patch)
+                     patch = line
+             if patch!="":
+                 patches.append(patch)
+         else:
+             for line in lines:
+                 if len(line) > 1 and ((line[0].isalpha() and line[1] == ':') or line.startswith('%%')):
+                     patches.append(line)
+                 else:
+                     bars = self.split_bars(line)
+                     if bars:
+                         bars[-1] += '\n'
+                         patches.extend(bars)
+
+         if add_special_patches:
+             bos_patch = chr(self.bos_token_id) * patch_size
+             eos_patch = chr(self.eos_token_id) * patch_size
+             patches = [bos_patch] + patches + [eos_patch]
+
+         if len(patches) > PATCH_LENGTH and truncate:
+             choices = ["head", "tail", "middle"]
+             choice = random.choice(choices)
+             if choice=="head" or random_truncate==False:
+                 patches = patches[:PATCH_LENGTH]
+             elif choice=="tail":
+                 patches = patches[-PATCH_LENGTH:]
+             else:
+                 start = random.randint(1, len(patches)-PATCH_LENGTH)
+                 patches = patches[start:start+PATCH_LENGTH]
+
+         patches = [self.bar2patch(patch) for patch in patches]
+
+         return patches
+
+     def decode(self, patches):
+         return ''.join(self.patch2bar(patch) for patch in patches)
+
+ class M3PatchEncoder(PreTrainedModel):
+     def __init__(self, config):
+         super(M3PatchEncoder, self).__init__(config)
+         self.patch_embedding = torch.nn.Linear(PATCH_SIZE*128, M3_HIDDEN_SIZE)
+         torch.nn.init.normal_(self.patch_embedding.weight, std=0.02)
+         self.base = BertModel(config=config)
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+         self.mask_token_id = 3
+
+     def forward(self,
+                 input_patches,  # [batch_size, seq_length, hidden_size]
+                 input_masks):  # [batch_size, seq_length]
+         # Transform input_patches into embeddings
+         input_patches = torch.nn.functional.one_hot(input_patches, num_classes=128)
+         input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE*128).type(torch.FloatTensor)
+         input_patches = self.patch_embedding(input_patches.to(self.device))
+
+         # Apply BERT model to input_patches and input_masks
+         return self.base(inputs_embeds=input_patches, attention_mask=input_masks)
+
+ class M3TokenDecoder(PreTrainedModel):
+     def __init__(self, config):
+         super(M3TokenDecoder, self).__init__(config)
+         self.base = GPT2LMHeadModel(config=config)
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+         self.mask_token_id = 3
+
+     def forward(self,
+                 patch_features,  # [batch_size, hidden_size]
+                 target_patches):  # [batch_size, seq_length]
+         # get input embeddings
+         inputs_embeds = torch.nn.functional.embedding(target_patches, self.base.transformer.wte.weight)
+
+         # concatenate the encoded patches with the input embeddings
+         inputs_embeds = torch.cat((patch_features.unsqueeze(1), inputs_embeds[:,1:,:]), dim=1)
+
+         # preparing the labels for model training
+         target_masks = target_patches == self.pad_token_id
+         target_patches = target_patches.clone().masked_fill_(target_masks, -100)
+
+         # get the attention mask
+         target_masks = ~target_masks
+         target_masks = target_masks.type(torch.int)
+
+         return self.base(inputs_embeds=inputs_embeds,
+                          attention_mask=target_masks,
+                          labels=target_patches)
+
+     def generate(self,
+                  patch_feature,
+                  tokens):
+         # reshape the patch_feature and tokens
+         patch_feature = patch_feature.reshape(1, 1, -1)
+         tokens = tokens.reshape(1, -1)
+
+         # get input embeddings
+         tokens = torch.nn.functional.embedding(tokens, self.base.transformer.wte.weight)
+
+         # concatenate the encoded patches with the input embeddings
+         tokens = torch.cat((patch_feature, tokens[:,1:,:]), dim=1)
+
+         # get the outputs from the model
+         outputs = self.base(inputs_embeds=tokens)
+
+         # get the probabilities of the next token
+         probs = torch.nn.functional.softmax(outputs.logits.squeeze(0)[-1], dim=-1)
+
+         return probs.detach().cpu().numpy()
+
+ class M3Model(PreTrainedModel):
+     def __init__(self, encoder_config, decoder_config):
+         super(M3Model, self).__init__(encoder_config)
+         self.encoder = M3PatchEncoder(encoder_config)
+         self.decoder = M3TokenDecoder(decoder_config)
+         self.pad_token_id = 0
+         self.bos_token_id = 1
+         self.eos_token_id = 2
+         self.mask_token_id = 3
+
+     def forward(self,
+                 input_patches,  # [batch_size, seq_length, hidden_size]
+                 input_masks,  # [batch_size, seq_length]
+                 selected_indices,  # [batch_size, seq_length]
+                 target_patches):  # [batch_size, seq_length, hidden_size]
+         input_patches = input_patches.reshape(len(input_patches), -1, PATCH_SIZE).to(self.device)
+         input_masks = input_masks.to(self.device)
+         selected_indices = selected_indices.to(self.device)
+         target_patches = target_patches.reshape(len(target_patches), -1, PATCH_SIZE).to(self.device)
+
+         # Pass the input_patches and input_masks through the encoder
+         outputs = self.encoder(input_patches, input_masks)["last_hidden_state"]
+
+         # Use selected_indices to form target_patches
+         target_patches = target_patches[selected_indices.bool()]
+         patch_features = outputs[selected_indices.bool()]
+
+         # Pass patch_features and target_patches through the decoder
+         return self.decoder(patch_features, target_patches)
+
+ class CLaMP3Model(PreTrainedModel):
+     def __init__(self,
+                  audio_config,
+                  symbolic_config,
+                  global_rank=None,
+                  world_size=None,
+                  text_model_name=TEXT_MODEL_NAME,
+                  hidden_size=CLAMP3_HIDDEN_SIZE,
+                  load_m3=CLAMP3_LOAD_M3):
+         super(CLaMP3Model, self).__init__(symbolic_config)
+
+         self.text_model = AutoModel.from_pretrained(text_model_name)  # Load the text model
+         self.text_proj = torch.nn.Linear(self.text_model.config.hidden_size, hidden_size)  # Linear layer for text projections
+         torch.nn.init.normal_(self.text_proj.weight, std=0.02)  # Initialize weights with normal distribution
+
+         self.symbolic_model = M3PatchEncoder(symbolic_config)  # Initialize the symbolic model
+         self.symbolic_proj = torch.nn.Linear(M3_HIDDEN_SIZE, hidden_size)  # Linear layer for symbolic projections
+         torch.nn.init.normal_(self.symbolic_proj.weight, std=0.02)  # Initialize weights with normal distribution
+
+         self.audio_model = BertModel(audio_config)  # Initialize the audio model
+         self.audio_proj = torch.nn.Linear(audio_config.hidden_size, hidden_size)  # Linear layer for audio projections
+         torch.nn.init.normal_(self.audio_proj.weight, std=0.02)  # Initialize weights with normal distribution
+
+         if global_rank==None or world_size==None:
+             global_rank = 0
+             world_size = 1
+
+         self.loss_fn = ClipLoss(local_loss=False,
+                                 gather_with_grad=True,
+                                 cache_labels=False,
+                                 rank=global_rank,
+                                 world_size=world_size,
+                                 use_horovod=False)
+
+         if load_m3 and os.path.exists(M3_WEIGHTS_PATH):
+             checkpoint = torch.load(M3_WEIGHTS_PATH, map_location='cpu', weights_only=True)
+             decoder_config = GPT2Config(vocab_size=128,
+                                         n_positions=PATCH_SIZE,
+                                         n_embd=M3_HIDDEN_SIZE,
+                                         n_layer=TOKEN_NUM_LAYERS,
+                                         n_head=M3_HIDDEN_SIZE//64,
+                                         n_inner=M3_HIDDEN_SIZE*4)
+             model = M3Model(symbolic_config, decoder_config)
+             model.load_state_dict(checkpoint['model'])
+             self.symbolic_model = model.encoder
+             model = None
+             print(f"Successfully Loaded M3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
+
+     def set_trainable(self, freeze_list):
+         if "text_model" in freeze_list:
+             self.text_model.eval()
+             for param in self.text_model.parameters():
+                 param.requires_grad = False
+             print("Text Model Frozen")
+         else:
+             self.text_model.train()
+             for param in self.text_model.parameters():
+                 param.requires_grad = True
+             print("Text Model Training")
+
+         if "text_proj" in freeze_list:
+             self.text_proj.eval()
+             for param in self.text_proj.parameters():
+                 param.requires_grad = False
+             print("Text Projection Layer Frozen")
+         else:
+             self.text_proj.train()
+             for param in self.text_proj.parameters():
+                 param.requires_grad = True
+             print("Text Projection Layer Training")
+
+         if "symbolic_model" in freeze_list:
+             self.symbolic_model.eval()
+             for param in self.symbolic_model.parameters():
+                 param.requires_grad = False
+             print("Symbolic Model Frozen")
+         else:
+             self.symbolic_model.train()
+             for param in self.symbolic_model.parameters():
+                 param.requires_grad = True
+             print("Symbolic Model Training")
+
+         if "symbolic_proj" in freeze_list:
+             self.symbolic_proj.eval()
+             for param in self.symbolic_proj.parameters():
+                 param.requires_grad = False
+             print("Symbolic Projection Layer Frozen")
+         else:
+             self.symbolic_proj.train()
+             for param in self.symbolic_proj.parameters():
+                 param.requires_grad = True
+             print("Symbolic Projection Layer Training")
+
+         if "audio_model" in freeze_list:
+             self.audio_model.eval()
+             for param in self.audio_model.parameters():
+                 param.requires_grad = False
+             print("Audio Model Frozen")
+         else:
+             self.audio_model.train()
+             for param in self.audio_model.parameters():
+                 param.requires_grad = True
+             print("Audio Model Training")
+
+         if "audio_proj" in freeze_list:
+             self.audio_proj.eval()
+             for param in self.audio_proj.parameters():
+                 param.requires_grad = False
+             print("Audio Projection Layer Frozen")
+         else:
+             self.audio_proj.train()
+             for param in self.audio_proj.parameters():
+                 param.requires_grad = True
+             print("Audio Projection Layer Training")
+
+     def avg_pooling(self, input_features, input_masks):
+         input_masks = input_masks.unsqueeze(-1).to(self.device)  # add a dimension to match the feature dimension
+         input_features = input_features * input_masks  # apply mask to input_features
+         avg_pool = input_features.sum(dim=1) / input_masks.sum(dim=1)  # calculate average pooling
+
+         return avg_pool
+
+     def get_text_features(self,
+                           text_inputs,
+                           text_masks,
+                           get_global=False):
+         text_features = self.text_model(text_inputs.to(self.device),
+                                         attention_mask=text_masks.to(self.device))['last_hidden_state']
+
+         if get_global:
+             text_features = self.avg_pooling(text_features, text_masks)
+             text_features = self.text_proj(text_features)
+
+         return text_features
+
+     def get_symbolic_features(self,
+                               symbolic_inputs,
+                               symbolic_masks,
+                               get_global=False):
+         symbolic_features = self.symbolic_model(symbolic_inputs.to(self.device),
+                                                 symbolic_masks.to(self.device))['last_hidden_state']
+
+         if get_global:
+             symbolic_features = self.avg_pooling(symbolic_features, symbolic_masks)
+             symbolic_features = self.symbolic_proj(symbolic_features)
+
+         return symbolic_features
+
+     def get_audio_features(self,
+                            audio_inputs,
+                            audio_masks,
+                            get_global=False):
+         audio_features = self.audio_model(inputs_embeds=audio_inputs.to(self.device),
+                                           attention_mask=audio_masks.to(self.device))['last_hidden_state']
+
+         if get_global:
+             audio_features = self.avg_pooling(audio_features, audio_masks)
+             audio_features = self.audio_proj(audio_features)
+
+         return audio_features
+
+     def forward(self,
+                 text_inputs,  # [batch_size, seq_length]
+                 text_masks,  # [batch_size, seq_length]
+                 music_inputs,  # [batch_size, seq_length, hidden_size]
+                 music_masks,  # [batch_size, seq_length]
+                 music_modality):  # "symbolic" or "audio"
+         # Compute the text features
+         text_features = self.get_text_features(text_inputs, text_masks, get_global=True)
+
+         # Compute the music features
+         if music_modality=="symbolic":
+             music_features = self.get_symbolic_features(music_inputs, music_masks, get_global=True)
+         elif music_modality=="audio":
+             music_features = self.get_audio_features(music_inputs, music_masks, get_global=True)
+         else:
+             raise ValueError("music_modality must be either 'symbolic' or 'audio'")
+
+         return self.loss_fn(text_features,
+                             music_features,
+                             LOGIT_SCALE,
+                             output_dict=False)
+
+ def split_data(data, eval_ratio=EVAL_SPLIT):
+     random.shuffle(data)
+     split_idx = int(len(data)*eval_ratio)
+     eval_set = data[:split_idx]
+     train_set = data[split_idx:]
+     return train_set, eval_set
+
+ def mask_patches(target_patches, patchilizer, mode):
+     indices = list(range(len(target_patches)))
+     random.shuffle(indices)
+     selected_indices = indices[:math.ceil(M3_MASK_RATIO*len(indices))]
+     sorted_indices = sorted(selected_indices)
+     input_patches = torch.tensor(target_patches)
+
+     if mode=="eval":
+         choice = "original"
+     else:
+         choice = random.choices(["mask", "shuffle", "original"], weights=[0.8, 0.1, 0.1])[0]
+
+     if choice=="mask":
+         input_patches[sorted_indices] = torch.tensor([patchilizer.mask_token_id]*PATCH_SIZE)
+     elif choice=="shuffle":
+         for idx in sorted_indices:
+             patch = input_patches[idx]
+             try:
+                 index_eos = (patch == patchilizer.eos_token_id).nonzero().item()
+             except:
+                 index_eos = len(patch)
+
+             indices = list(range(1, index_eos))
+             random.shuffle(indices)
+             indices = [0] + indices + list(range(index_eos, len(patch)))
+             input_patches[idx] = patch[indices]
+
+     selected_indices = torch.zeros(len(target_patches))
+     selected_indices[sorted_indices] = 1.
+
+     return input_patches, selected_indices
+
+ def remove_instrument_info(item):
+     # remove instrument information from symbolic music
+     lines = re.findall(r'.*?\n|.*$', item)
+     lines = list(filter(None, lines))
+     if lines[0].split(" ")[0] == "ticks_per_beat":
+         type = "mtf"
+     else:
+         type = "abc"
+
+     cleaned_lines = []
+     for line in lines:
+         if type=="abc" and line.startswith("V:"):
+             # find the position of " nm=" or " snm="
+             nm_pos = line.find(" nm=")
+             snm_pos = line.find(" snm=")
+             # keep the part before " nm=" or " snm="
+             if nm_pos != -1:
+                 line = line[:nm_pos]
+             elif snm_pos != -1:
+                 line = line[:snm_pos]
+             if nm_pos != -1 or snm_pos != -1:
+                 line += "\n"
+         elif type=="mtf" and line.startswith("program_change"):
+             line = " ".join(line.split(" ")[:-1]) + " 0\n"
+
+         cleaned_lines.append(line)
+
+     return ''.join(cleaned_lines)
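
To illustrate the symbolic-side preprocessing, here is a small, hypothetical example of `M3Patchilizer` on a made-up ABC fragment. Each patch is a fixed-length list of `PATCH_SIZE` byte ids, and `add_special_patches=True` brackets the sequence with BOS/EOS patches.

```python
# Illustrative only: byte-level patchilization of a tiny ABC fragment using the
# M3Patchilizer defined above (encode() strips the "L:1/8" header line).
from utils import M3Patchilizer

abc = "X:1\nL:1/8\nK:C\nCDEF GABc | cBAG FEDC |\n"
patchilizer = M3Patchilizer()
patches = patchilizer.encode(abc, add_special_patches=True)
print(len(patches), len(patches[0]))  # number of patches, PATCH_SIZE (64)
print(patchilizer.decode(patches))    # round-trips the text minus the stripped header
```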
wikimt-x-public.jsonl ADDED
The diff for this file is too large to render. See raw diff