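"""
Feature extraction script for CLaMP 3.

Walks an input directory for .txt (text), .abc / .mtf (symbolic music), and
.npy (audio feature) files, encodes each file with the corresponding CLaMP 3
encoder, and saves the resulting features as .npy files that mirror the input
directory structure. Supports multi-GPU execution via Hugging Face Accelerate.
"""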
import os
import argparse
import torch
import numpy as np
from tqdm import tqdm
from config import *
from utils import *
from samplings import *
from accelerate import Accelerator
from transformers import BertConfig, AutoTokenizer
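
# NOTE: names such as CLaMP3Model and M3Patchilizer, and the upper-case constants
# (AUDIO_HIDDEN_SIZE, PATCH_LENGTH, MAX_TEXT_LENGTH, CLAMP3_WEIGHTS_PATH, ...)
# are assumed to be provided by the star imports from config and utils above.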

# Parse command-line arguments
parser = argparse.ArgumentParser(description="Feature extraction for CLaMP3.")
parser.add_argument("--epoch", type=str, default=None, help="Epoch of the checkpoint to load.")
parser.add_argument("input_dir", type=str, help="Directory containing input data files.")
parser.add_argument("output_dir", type=str, help="Directory to save the output features.")
parser.add_argument("--get_global", action="store_true", help="Extract a single global feature per file instead of token-level features.")
args = parser.parse_args()

# Retrieve arguments
epoch = args.epoch
input_dir = args.input_dir
output_dir = args.output_dir
get_global = args.get_global

# Collect all supported input files (.txt, .abc, .mtf, .npy) under input_dir
files = []
for root, dirs, fs in os.walk(input_dir):
    for f in fs:
        if f.endswith((".txt", ".abc", ".mtf", ".npy")):
            files.append(os.path.join(root, f))
print(f"Found {len(files)} files in total")

# Initialize accelerator and device
accelerator = Accelerator()
device = accelerator.device
print("Using device:", device)

# Model and configuration setup
audio_config = BertConfig(vocab_size=1,
                          hidden_size=AUDIO_HIDDEN_SIZE,
                          num_hidden_layers=AUDIO_NUM_LAYERS,
                          num_attention_heads=AUDIO_HIDDEN_SIZE//64,
                          intermediate_size=AUDIO_HIDDEN_SIZE*4,
                          max_position_embeddings=MAX_AUDIO_LENGTH)
symbolic_config = BertConfig(vocab_size=1,
                             hidden_size=M3_HIDDEN_SIZE,
                             num_hidden_layers=PATCH_NUM_LAYERS,
                             num_attention_heads=M3_HIDDEN_SIZE//64,
                             intermediate_size=M3_HIDDEN_SIZE*4,
                             max_position_embeddings=PATCH_LENGTH)
model = CLaMP3Model(audio_config=audio_config,
                    symbolic_config=symbolic_config,
                    text_model_name=TEXT_MODEL_NAME,
                    hidden_size=CLAMP3_HIDDEN_SIZE,
                    load_m3=CLAMP3_LOAD_M3)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_NAME)
patchilizer = M3Patchilizer()

# Print parameter count
print("Total Parameter Number: " + str(sum(p.numel() for p in model.parameters())))

# Load model weights
model.eval()
checkpoint_path = CLAMP3_WEIGHTS_PATH
if epoch is not None:
    checkpoint_path = CLAMP3_WEIGHTS_PATH.replace(".pth", f"_{epoch}.pth")
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
print(f"Successfully Loaded CLaMP 3 Checkpoint from Epoch {checkpoint['epoch']} with loss {checkpoint['min_eval_loss']}")
model.load_state_dict(checkpoint['model'])


def extract_feature(filename, get_global=get_global):
    # Read raw text for non-.npy inputs
    if not filename.endswith(".npy"):
        with open(filename, "r", encoding="utf-8") as f:
            item = f.read()

    if filename.endswith(".txt"):
        # Deduplicate lines, drop empty ones, and join them with the tokenizer's separator token
        item = list(set(item.split("\n")))
        item = "\n".join(item)
        item = item.split("\n")
        item = [c for c in item if len(c) > 0]
        item = tokenizer.sep_token.join(item)
        input_data = tokenizer(item, return_tensors="pt")
        input_data = input_data['input_ids'].squeeze(0)
        max_input_length = MAX_TEXT_LENGTH
    elif filename.endswith(".abc") or filename.endswith(".mtf"):
        # Symbolic music: convert to patch sequences
        input_data = patchilizer.encode(item, add_special_patches=True)
        input_data = torch.tensor(input_data)
        max_input_length = PATCH_LENGTH
    elif filename.endswith(".npy"):
        # Pre-extracted audio features: flatten and wrap with zero vectors at both ends
        input_data = np.load(filename)
        input_data = torch.tensor(input_data)
        input_data = input_data.reshape(-1, input_data.size(-1))
        zero_vec = torch.zeros((1, input_data.size(-1)))
        input_data = torch.cat((zero_vec, input_data, zero_vec), 0)
        max_input_length = MAX_AUDIO_LENGTH
    else:
        raise ValueError(f"Unsupported file type: {filename}, only support .txt, .abc, .mtf, .npy files")
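
    # Split the input into chunks of at most max_input_length. The last chunk is
    # re-taken as the final max_input_length items, so it is always full length
    # and may overlap the previous chunk; the overlap is handled when the
    # per-chunk features are combined below.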
    segment_list = []
    for i in range(0, len(input_data), max_input_length):
        segment_list.append(input_data[i:i+max_input_length])
    segment_list[-1] = input_data[-max_input_length:]

    last_hidden_states_list = []
    for input_segment in segment_list:
        input_masks = torch.tensor([1]*input_segment.size(0))
        # Pad each segment (and its mask) to the maximum input length of its modality
        if filename.endswith(".txt"):
            pad_indices = torch.ones(MAX_TEXT_LENGTH - input_segment.size(0)).long() * tokenizer.pad_token_id
        elif filename.endswith(".abc") or filename.endswith(".mtf"):
            pad_indices = torch.ones((PATCH_LENGTH - input_segment.size(0), PATCH_SIZE)).long() * patchilizer.pad_token_id
        else:
            pad_indices = torch.ones((MAX_AUDIO_LENGTH - input_segment.size(0), AUDIO_HIDDEN_SIZE)).float() * 0.
        input_masks = torch.cat((input_masks, torch.zeros(max_input_length - input_segment.size(0))), 0)
        input_segment = torch.cat((input_segment, pad_indices), 0)

        # Encode the segment with the encoder matching its modality
        if filename.endswith(".txt"):
            last_hidden_states = model.get_text_features(text_inputs=input_segment.unsqueeze(0).to(device),
                                                         text_masks=input_masks.unsqueeze(0).to(device),
                                                         get_global=get_global)
        elif filename.endswith(".abc") or filename.endswith(".mtf"):
            last_hidden_states = model.get_symbolic_features(symbolic_inputs=input_segment.unsqueeze(0).to(device),
                                                             symbolic_masks=input_masks.unsqueeze(0).to(device),
                                                             get_global=get_global)
        else:
            last_hidden_states = model.get_audio_features(audio_inputs=input_segment.unsqueeze(0).to(device),
                                                          audio_masks=input_masks.unsqueeze(0).to(device),
                                                          get_global=get_global)
        if not get_global:
            # Keep only the non-padded positions
            last_hidden_states = last_hidden_states[:, :input_masks.sum().long().item(), :]
        last_hidden_states_list.append(last_hidden_states)
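
    # Combine per-chunk outputs: token-level features are concatenated (trimming the
    # overlap in the final chunk), while global features are averaged with weights
    # proportional to each chunk's actual length.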
    if not get_global:
        last_hidden_states_list = [last_hidden_states[0] for last_hidden_states in last_hidden_states_list]
        # When len(input_data) is an exact multiple of max_input_length, the slice
        # below is [-0:], i.e. the whole final chunk, which is the intended result.
        last_hidden_states_list[-1] = last_hidden_states_list[-1][-(len(input_data)%max_input_length):]
        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
    else:
        full_chunk_cnt = len(input_data) // max_input_length
        remain_chunk_len = len(input_data) % max_input_length
        if remain_chunk_len == 0:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt, device=device).view(-1, 1)
        else:
            feature_weights = torch.tensor([max_input_length] * full_chunk_cnt + [remain_chunk_len], device=device).view(-1, 1)
        last_hidden_states_list = torch.concat(last_hidden_states_list, 0)
        last_hidden_states_list = last_hidden_states_list * feature_weights
        last_hidden_states_list = last_hidden_states_list.sum(dim=0) / feature_weights.sum()

    return last_hidden_states_list
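

# Output features mirror the input directory structure (one .npy per source file);
# files whose output already exists are skipped, so interrupted runs can be resumed.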
def process_directory(input_dir, output_dir, files):
    # Calculate the number of files to process per GPU
    num_files_per_gpu = len(files) // accelerator.num_processes

    # Calculate the start and end index for the current GPU
    start_idx = accelerator.process_index * num_files_per_gpu
    end_idx = start_idx + num_files_per_gpu
    if accelerator.process_index == accelerator.num_processes - 1:
        end_idx = len(files)
    files_to_process = files[start_idx:end_idx]

    # Process the files
    for file in tqdm(files_to_process):
        # Recreate the relative subdirectory under output_dir
        output_subdir = output_dir + os.path.dirname(file)[len(input_dir):]
        try:
            os.makedirs(output_subdir, exist_ok=True)
        except Exception as e:
            print(output_subdir + " cannot be created\n" + str(e))
        output_file = os.path.join(output_subdir, os.path.splitext(os.path.basename(file))[0] + ".npy")

        if os.path.exists(output_file):
            print(f"Skipping {file}, output already exists")
            continue

        try:
            with torch.no_grad():
                features = extract_feature(file).unsqueeze(0)
            np.save(output_file, features.detach().cpu().numpy())
        except Exception as e:
            print(f"Failed to process {file}: {e}")


# Process the files
process_directory(input_dir, output_dir, files)
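
# Example launch (the script name "extract_clamp3.py" and the paths are illustrative):
#   accelerate launch extract_clamp3.py <input_dir> <output_dir> --get_global
#   accelerate launch extract_clamp3.py <input_dir> <output_dir> --epoch 10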