# clamp3 / extract_clamp3.py
# sander-wood's picture
# Upload 8 files
# ad822ab verified
# raw
# history blame
# 8.86 kB
import os
import torch
import numpy as np
from tqdm import tqdm
from config import *
from utils import *
from samplings import *
from accelerate import Accelerator
from transformers import BertConfig, AutoTokenizer
import argparse
# Command-line interface: two positional directories plus two optional flags.
parser = argparse.ArgumentParser(description="Feature extraction for CLaMP3.")
parser.add_argument("--epoch", type=str, default=None, help="Epoch of the checkpoint to load.")
parser.add_argument("input_dir", type=str, help="Directory containing input data files.")
parser.add_argument("output_dir", type=str, help="Directory to save the output features.")
parser.add_argument("--get_global", action="store_true", help="Get global feature.")
args = parser.parse_args()

# Expose the parsed values as module-level names used by the rest of the script.
epoch, input_dir, output_dir, get_global = (
    args.epoch,
    args.input_dir,
    args.output_dir,
    args.get_global,
)

# Recursively gather every file under input_dir whose extension is supported.
files = []
for current_dir, _subdirs, names in os.walk(input_dir):
    files.extend(
        os.path.join(current_dir, name)
        for name in names
        if name.endswith((".txt", ".abc", ".mtf", ".npy"))
    )
print(f"Found {len(files)} files in total")
# Create the accelerator and use the device it reports for this process.
accelerator = Accelerator()
device = accelerator.device
print("Using device:", device)


def _make_encoder_config(hidden_size, num_layers, max_positions):
    """Build a BertConfig with the sizing rules shared by both encoders.

    The head count is ``hidden_size // 64`` and the feed-forward width is
    ``hidden_size * 4``; ``vocab_size`` is fixed to 1.
    """
    return BertConfig(
        vocab_size=1,
        hidden_size=hidden_size,
        num_hidden_layers=num_layers,
        num_attention_heads=hidden_size // 64,
        intermediate_size=hidden_size * 4,
        max_position_embeddings=max_positions,
    )


# Encoder configurations for the audio and symbolic branches.
audio_config = _make_encoder_config(AUDIO_HIDDEN_SIZE, AUDIO_NUM_LAYERS, MAX_AUDIO_LENGTH)
symbolic_config = _make_encoder_config(M3_HIDDEN_SIZE, PATCH_NUM_LAYERS, PATCH_LENGTH)

# Instantiate the model on the target device, plus the text tokenizer and
# the symbolic-music patchilizer used when encoding inputs.
model = CLaMP3Model(
    audio_config=audio_config,
    symbolic_config=symbolic_config,
    text_model_name=TEXT_MODEL_NAME,
    hidden_size=CLAMP3_HIDDEN_SIZE,
    load_m3=CLAMP3_LOAD_M3,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
patchilizer = M3Patchilizer()

# Report the total number of parameters.
total_params = sum(p.numel() for p in model.parameters())
print("Total Parameter Number: " + str(total_params))

# Switch to evaluation mode, then load the checkpoint weights. When an epoch
# is given, the checkpoint file name gets an "_<epoch>" suffix before ".pth".
model.eval()
checkpoint_path = (
    CLAMP3_WEIGHTS_PATH
    if epoch is None
    else CLAMP3_WEIGHTS_PATH.replace(".pth", f"_{epoch}.pth")
)
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
print(f"Successfully Loaded CLaMP 3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
model.load_state_dict(checkpoint['model'])
def extract_feature(filename, get_global=get_global):
    """Extract CLaMP3 features from one text, symbolic-music or audio-feature file.

    The file is loaded according to its extension, cut into segments of at
    most the modality's maximum length, and every padded segment is passed
    through the matching ``model.get_*_features`` method.

    Args:
        filename: Path to a ``.txt``, ``.abc``, ``.mtf`` or ``.npy`` file.
        get_global: Forwarded to the model. When true, the per-segment outputs
            are merged into a single length-weighted average; otherwise the
            per-position outputs of all segments are concatenated along the
            first dimension.

    Returns:
        A torch tensor with the extracted feature(s).

    Raises:
        ValueError: If ``filename`` does not have a supported extension.
    """
    # Resolve the modality once, *before* touching the file, so unsupported
    # files raise the intended ValueError instead of e.g. a decode error.
    if filename.endswith(".txt"):
        modality = "text"
    elif filename.endswith((".abc", ".mtf")):
        modality = "symbolic"
    elif filename.endswith(".npy"):
        modality = "audio"
    else:
        raise ValueError(f"Unsupported file type: {filename}, only support .txt, .abc, .mtf, .npy files")

    if modality == "audio":
        input_data = torch.tensor(np.load(filename))
        # Flatten to 2-D (rows, feature_dim) and add an all-zero row at each end.
        input_data = input_data.reshape(-1, input_data.size(-1))
        zero_vec = torch.zeros((1, input_data.size(-1)))
        input_data = torch.cat((zero_vec, input_data, zero_vec), 0)
        max_input_length = MAX_AUDIO_LENGTH
    else:
        with open(filename, "r", encoding="utf-8") as f:
            item = f.read()
        if modality == "text":
            # Drop duplicate and empty lines. dict.fromkeys keeps first-seen
            # order, so the tokenised input (and therefore the feature) is
            # reproducible across runs; a set() would reorder lines arbitrarily.
            unique_lines = [line for line in dict.fromkeys(item.split("\n")) if line]
            item = tokenizer.sep_token.join(unique_lines)
            input_data = tokenizer(item, return_tensors="pt")["input_ids"].squeeze(0)
            max_input_length = MAX_TEXT_LENGTH
        else:
            input_data = torch.tensor(patchilizer.encode(item, add_special_patches=True))
            max_input_length = PATCH_LENGTH

    total_len = len(input_data)
    remain_chunk_len = total_len % max_input_length

    # Cut into consecutive windows, then swap the last (possibly short) window
    # for the final `max_input_length` items of the input. The part that
    # overlaps the previous window is trimmed / down-weighted further below.
    segment_list = [
        input_data[start:start + max_input_length]
        for start in range(0, total_len, max_input_length)
    ]
    segment_list[-1] = input_data[-max_input_length:]

    last_hidden_states_list = []
    for input_segment in segment_list:
        valid_len = input_segment.size(0)
        pad_len = max_input_length - valid_len

        # Padding content depends on the modality: pad token ids for text,
        # pad patches for symbolic music, zero vectors for audio features.
        if modality == "text":
            pad_indices = torch.ones(pad_len).long() * tokenizer.pad_token_id
        elif modality == "symbolic":
            pad_indices = torch.ones((pad_len, PATCH_SIZE)).long() * patchilizer.pad_token_id
        else:
            pad_indices = torch.zeros((pad_len, AUDIO_HIDDEN_SIZE)).float()

        # Mask is 1 for real positions and 0 for padding.
        input_masks = torch.cat((torch.tensor([1] * valid_len), torch.zeros(pad_len)), 0)
        input_segment = torch.cat((input_segment, pad_indices), 0)

        batch_inputs = input_segment.unsqueeze(0).to(device)
        batch_masks = input_masks.unsqueeze(0).to(device)
        if modality == "text":
            last_hidden_states = model.get_text_features(text_inputs=batch_inputs,
                                                         text_masks=batch_masks,
                                                         get_global=get_global)
        elif modality == "symbolic":
            last_hidden_states = model.get_symbolic_features(symbolic_inputs=batch_inputs,
                                                             symbolic_masks=batch_masks,
                                                             get_global=get_global)
        else:
            last_hidden_states = model.get_audio_features(audio_inputs=batch_inputs,
                                                          audio_masks=batch_masks,
                                                          get_global=get_global)
        if not get_global:
            # Keep only the outputs that correspond to non-padding positions.
            last_hidden_states = last_hidden_states[:, :valid_len, :]
        last_hidden_states_list.append(last_hidden_states)

    if not get_global:
        per_segment = [states[0] for states in last_hidden_states_list]
        # The last window overlaps the previous one unless the input length is
        # an exact multiple of the window size; keep only its new tail.
        if remain_chunk_len:
            per_segment[-1] = per_segment[-1][-remain_chunk_len:]
        return torch.concat(per_segment, 0)

    # Global mode: average the per-segment outputs, weighting each by the
    # number of input positions it newly contributes.
    chunk_lengths = [max_input_length] * (total_len // max_input_length)
    if remain_chunk_len:
        chunk_lengths.append(remain_chunk_len)
    feature_weights = torch.tensor(chunk_lengths, device=device).view(-1, 1)
    stacked = torch.concat(last_hidden_states_list, 0)
    return (stacked * feature_weights).sum(dim=0) / feature_weights.sum()
def process_directory(input_dir, output_dir, files):
    """Extract and save features for this process's share of ``files``.

    The file list is split into contiguous, equally sized slices, one per
    accelerator process; the last process additionally takes the remainder.
    Each feature is saved as ``<output_dir>/<path relative to input_dir>.npy``
    (original extension replaced). Files whose output already exists are
    skipped, and a failure on one file is reported without stopping the run.

    Args:
        input_dir: Root directory the paths in ``files`` were collected from.
        output_dir: Root directory that mirrors ``input_dir`` for the outputs.
        files: Paths of all input files (across all processes).
    """
    # calculate the number of files to process per GPU
    num_files_per_gpu = len(files) // accelerator.num_processes

    # calculate the start and end index for the current GPU
    start_idx = accelerator.process_index * num_files_per_gpu
    end_idx = start_idx + num_files_per_gpu
    if accelerator.process_index == accelerator.num_processes - 1:
        end_idx = len(files)

    files_to_process = files[start_idx:end_idx]

    # process the files
    for file in tqdm(files_to_process):
        # Mirror the input sub-directory below output_dir. relpath/join is used
        # instead of string slicing so a trailing separator on input_dir (e.g.
        # "in/") cannot glue the sub-directory onto the output name ("outsub").
        rel_dir = os.path.relpath(os.path.dirname(file), input_dir)
        if rel_dir == os.curdir:
            output_subdir = output_dir
        else:
            output_subdir = os.path.join(output_dir, rel_dir)

        try:
            os.makedirs(output_subdir, exist_ok=True)
        except OSError as e:
            print(output_subdir + " can not be created\n" + str(e))
            # Without the directory the save would fail anyway; skip the file
            # rather than spend a model forward pass on it.
            continue

        output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")

        if os.path.exists(output_file):
            print(f"Skipping {file}, output already exists")
            continue

        # Best effort: one bad file must not abort the whole batch.
        try:
            with torch.no_grad():
                features = extract_feature(file).unsqueeze(0)
            np.save(output_file, features.detach().cpu().numpy())
        except Exception as e:
            print(f"Failed to process {file}: {e}")
# Script entry point: run feature extraction over the files found under
# input_dir (each accelerator process handles only its own slice of the list).
process_directory(input_dir, output_dir, files)