mimbres's picture
.
a03c9b4
raw
history blame
50.1 kB
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""ymt3.py"""
import os
from typing import Union, Optional, Tuple, Dict, List, Any
from collections import Counter
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
import torchaudio # for debugging audio
import pytorch_lightning as pl
import numpy as np
import wandb
from einops import rearrange
from transformers import T5Config
from model.t5mod import T5EncoderYMT3, T5DecoderYMT3, MultiChannelT5Decoder
from model.t5mod_helper import task_cond_dec_generate
from model.perceiver_mod import PerceiverTFEncoder
from model.perceiver_helper import PerceiverTFConfig
from model.conformer_mod import ConformerYMT3Encoder
from model.conformer_helper import ConformerYMT3Config
from model.lm_head import LMHead
from model.pitchshift_layer import PitchShiftLayer
from model.spectrogram import get_spectrogram_layer_from_audio_cfg
from model.conv_block import PreEncoderBlockRes3B
from model.conv_block import PreEncoderBlockHFTT, PreEncoderBlockRes3BHFTT # added for hFTT-like pre-encoder
from model.projection_layer import get_projection_layer, get_multi_channel_projection_layer
from model.optimizers import get_optimizer
from model.lr_scheduler import get_lr_scheduler
from utils.note_event_dataclasses import Note
from utils.note2event import mix_notes
from utils.event2note import merge_zipped_note_events_and_ties_to_notes, DECODING_ERR_TYPES
from utils.metrics import compute_track_metrics
from utils.metrics import AMTMetrics
# from utils.utils import write_model_output_as_npy
from utils.utils import write_model_output_as_midi, create_inverse_vocab, write_err_cnt_as_json
from utils.utils import Timer
from utils.task_manager import TaskManager
from config.config import audio_cfg as default_audio_cfg
from config.config import model_cfg as default_model_cfg
from config.config import shared_cfg as default_shared_cfg
from config.config import T5_BASE_CFG
class YourMT3(pl.LightningModule):
"""YourMT3:
Lightning wrapper for multi-task music transcription Transformer.
"""
def __init__(
self,
audio_cfg: Optional[Dict] = None,
model_cfg: Optional[Dict] = None,
shared_cfg: Optional[Dict] = None,
pretrained: bool = False,
optimizer_name: str = 'adamwscale',
scheduler_name: str = 'cosine',
base_lr: float = None, # None: 'auto' for AdaFactor, 1e-3 for constant, 1e-2 for cosine
max_steps: Optional[int] = None,
weight_decay: float = 0.0,
init_factor: Optional[Union[str, float]] = None,
task_manager: TaskManager = TaskManager(),
eval_subtask_key: Optional[str] = "default",
eval_vocab: Optional[Dict] = None,
eval_drum_vocab: Optional[Dict] = None,
write_output_dir: Optional[str] = None,
write_output_vocab: Optional[Dict] = None,
onset_tolerance: float = 0.05,
add_pitch_class_metric: Optional[List[str]] = None,
add_melody_metric_to_singing: bool = True,
test_optimal_octave_shift: bool = False,
test_pitch_shift_layer: Optional[str] = None,
**kwargs: Any) -> None:
super().__init__()
if pretrained is True:
raise NotImplementedError("Pretrained model is not supported in this version.")
self.test_pitch_shift_layer = test_pitch_shift_layer # debug only
# Config
if model_cfg is None:
model_cfg = default_model_cfg # default config, not overwritten by args of trainer
if audio_cfg is None:
audio_cfg = default_audio_cfg # default config, not overwritten by args of trainer
if shared_cfg is None:
shared_cfg = default_shared_cfg # default config, not overwritten by args of trainer
# Spec Layer (need to define here to infer max token length)
self.spectrogram, spec_output_shape = get_spectrogram_layer_from_audio_cfg(
audio_cfg) # can be spec or melspec; output_shape is (T, F)
model_cfg["feat_length"] = spec_output_shape[0] # T of (T, F)
# Task manger and Tokens
self.task_manager = task_manager
self.max_total_token_length = self.task_manager.max_total_token_length
# Task Conditioning
self.use_task_cond_encoder = bool(model_cfg["use_task_conditional_encoder"])
self.use_task_cond_decoder = bool(model_cfg["use_task_conditional_decoder"])
# Select Encoder type, Model-specific Config
assert model_cfg["encoder_type"] in ["t5", "perceiver-tf", "conformer"]
assert model_cfg["decoder_type"] in ["t5", "multi-t5"]
self.encoder_type = model_cfg["encoder_type"] # {"t5", "perceiver-tf", "conformer"}
self.decoder_type = model_cfg["decoder_type"] # {"t5", "multi-t5"}
encoder_config = model_cfg["encoder"][self.encoder_type] # mutable
decoder_config = model_cfg["decoder"][self.decoder_type] # mutable
# Positional Encoding
if isinstance(model_cfg["num_max_positions"], str) and model_cfg["num_max_positions"] == 'auto':
encoder_config["num_max_positions"] = int(model_cfg["feat_length"] +
self.task_manager.max_task_token_length + 10)
decoder_config["num_max_positions"] = int(self.max_total_token_length + 10)
else:
assert isinstance(model_cfg["num_max_positions"], int)
encoder_config["num_max_positions"] = model_cfg["num_max_positions"]
decoder_config["num_max_positions"] = model_cfg["num_max_positions"]
# Select Pre-Encoder and Pre-Decoder type
if model_cfg["pre_encoder_type"] == "default":
model_cfg["pre_encoder_type"] = model_cfg["pre_encoder_type_default"].get(model_cfg["encoder_type"], None)
elif model_cfg["pre_encoder_type"] in [None, "none", "None", "0"]:
model_cfg["pre_encoder_type"] = None
if model_cfg["pre_decoder_type"] == "default":
model_cfg["pre_decoder_type"] = model_cfg["pre_decoder_type_default"].get(model_cfg["encoder_type"]).get(
model_cfg["decoder_type"], None)
elif model_cfg["pre_decoder_type"] in [None, "none", "None", "0"]:
model_cfg["pre_decoder_type"] = None
self.pre_encoder_type = model_cfg["pre_encoder_type"]
self.pre_decoder_type = model_cfg["pre_decoder_type"]
# Pre-encoder
self.pre_encoder = nn.Sequential()
if self.pre_encoder_type in ["conv", "conv1d_t", "conv1d_f"]:
kernel_size = (3, 3)
avp_kernel_size = (1, 2)
if self.pre_encoder_type == "conv1d_t":
kernel_size = (3, 1)
elif self.pre_encoder_type == "conv1d_f":
kernel_size = (1, 3)
self.pre_encoder.append(
PreEncoderBlockRes3B(1,
model_cfg["conv_out_channels"],
kernel_size=kernel_size,
avp_kernerl_size=avp_kernel_size,
activation="relu"))
pre_enc_output_shape = (spec_output_shape[0], spec_output_shape[1] // 2**3, model_cfg["conv_out_channels"]
) # (T, F, C) excluding batch dim
elif self.pre_encoder_type == "hftt":
self.pre_encoder.append(PreEncoderBlockHFTT())
pre_enc_output_shape = (spec_output_shape[0], spec_output_shape[1], 128) # (T, F, C) excluding batch dim
elif self.pre_encoder_type == "res3b_hftt":
self.pre_encoder.append(PreEncoderBlockRes3BHFTT())
pre_enc_output_shape = (spec_output_shape[0], spec_output_shape[1] // 2**3, 128)
else:
pre_enc_output_shape = spec_output_shape # (T, F) excluding batch dim
# Auto-infer `d_feat` and `d_model`, `vocab_size`, and `num_max_positions`
if isinstance(model_cfg["d_feat"], str) and model_cfg["d_feat"] == 'auto':
if self.encoder_type == "perceiver-tf" and encoder_config["attention_to_channel"] is True:
model_cfg["d_feat"] = pre_enc_output_shape[-2] # TODO: better readablity
else:
model_cfg["d_feat"] = pre_enc_output_shape[-1] # C of (T, F, C) or F or (T, F)
if self.encoder_type == "perceiver-tf" and isinstance(encoder_config["d_model"], str):
if encoder_config["d_model"] == 'q':
encoder_config["d_model"] = encoder_config["d_latent"]
elif encoder_config["d_model"] == 'kv':
encoder_config["d_model"] = model_cfg["d_feat"]
else:
raise ValueError(f"Unknown d_model: {encoder_config['d_model']}")
# # required for PerceiverTF with attention_to_channel option
# if self.encoder_type == "perceiver-tf":
# if encoder_config["attention_to_channel"] is True:
# encoder_config["kv_dim"] = model_cfg["d_feat"] # TODO: better readablity
# else:
# encoder_config["kv_dim"] = model_cfg["conv_out_channels"]
if isinstance(model_cfg["vocab_size"], str) and model_cfg["vocab_size"] == 'auto':
model_cfg["vocab_size"] = task_manager.num_tokens
if isinstance(model_cfg["num_max_positions"], str) and model_cfg["num_max_positions"] == 'auto':
model_cfg["num_max_positions"] = int(
max(model_cfg["feat_length"], model_cfg["event_length"]) + self.task_manager.max_task_token_length + 10)
# Pre-decoder
self.pre_decoder = nn.Sequential()
if self.encoder_type == "perceiver-tf" and self.decoder_type == "t5":
t, f, c = pre_enc_output_shape # perceiver-tf: (110, 128, 128) for 2s
encoder_output_shape = (t, encoder_config["num_latents"], encoder_config["d_latent"]) # (T, K, D_source)
decoder_input_shape = (t, decoder_config["d_model"]) # (T, D_target)
proj_layer = get_projection_layer(input_shape=encoder_output_shape,
output_shape=decoder_input_shape,
proj_type=self.pre_decoder_type)
self.pre_encoder_output_shape = pre_enc_output_shape
self.encoder_output_shape = encoder_output_shape
self.decoder_input_shape = decoder_input_shape
self.pre_decoder.append(proj_layer)
elif self.encoder_type in ["t5", "conformer"] and self.decoder_type == "t5":
pass
elif self.encoder_type == "perceiver-tf" and self.decoder_type == "multi-t5":
# NOTE: this is experiemental, only for multi-channel decoding with 13 classes
assert encoder_config["num_latents"] % decoder_config["num_channels"] == 0
encoder_output_shape = (encoder_config["num_latents"], encoder_config["d_model"])
decoder_input_shape = (decoder_config["num_channels"], decoder_config["d_model"])
proj_layer = get_multi_channel_projection_layer(input_shape=encoder_output_shape,
output_shape=decoder_input_shape,
proj_type=self.pre_decoder_type)
self.pre_decoder.append(proj_layer)
else:
raise NotImplementedError(
f"Encoder type {self.encoder_type} and decoder type {self.decoder_type} is not implemented yet.")
# Positional Encoding, Vocab, etc.
if self.encoder_type in ["t5", "conformer"]:
encoder_config["num_max_positions"] = decoder_config["num_max_positions"] = model_cfg["num_max_positions"]
else: # perceiver-tf uses separate positional encoding
encoder_config["num_max_positions"] = model_cfg["feat_length"]
decoder_config["num_max_positions"] = model_cfg["num_max_positions"]
encoder_config["vocab_size"] = decoder_config["vocab_size"] = model_cfg["vocab_size"]
# Print and save updated configs
self.audio_cfg = audio_cfg
self.model_cfg = model_cfg
self.shared_cfg = shared_cfg
self.save_hyperparameters()
if self.global_rank == 0:
print(self.hparams)
# Encoder and Decoder and LM-head
self.encoder = None
self.decoder = None
self.lm_head = LMHead(decoder_config, 1.0, model_cfg["tie_word_embeddings"])
self.embed_tokens = nn.Embedding(decoder_config["vocab_size"], decoder_config["d_model"])
self.embed_tokens.weight.data.normal_(mean=0.0, std=1.0)
self.shift_right_fn = None
self.set_encoder_decoder() # shift_right_fn is also set here
# Model as ModuleDict
# self.model = nn.ModuleDict({
# "pitchshift": self.pitchshift, # no grad; created in setup() only for training,
# and called by training_step()
# "spectrogram": self.spectrogram, # no grad
# "pre_encoder": self.pre_encoder,
# "encoder": self.encoder,
# "pre_decoder": self.pre_decoder,
# "decoder": self.decoder,
# "embed_tokens": self.embed_tokens,
# "lm_head": self.lm_head,
# })
# Tables (for logging)
columns = ['Ep', 'Track ID', 'Pred Events', 'Actual Events', 'Pred Notes', 'Actual Notes']
self.sample_table = wandb.Table(columns=columns)
# Output MIDI
if write_output_dir is not None:
if write_output_vocab is None:
from config.vocabulary import program_vocab_presets
self.midi_output_vocab = program_vocab_presets["gm_ext_plus"]
else:
self.midi_output_vocab = write_output_vocab
self.midi_output_inverse_vocab = create_inverse_vocab(self.midi_output_vocab)
def set_encoder_decoder(self) -> None:
"""Set encoder, decoder, lm_head and emb_tokens from self.model_cfg"""
# Generate and update T5Config
t5_basename = self.model_cfg["t5_basename"]
if t5_basename in T5_BASE_CFG.keys():
# Load from pre-defined config in config.py
t5_config = T5Config(**T5_BASE_CFG[t5_basename])
else:
# Load from HuggingFace hub
t5_config = T5Config.from_pretrained(t5_basename)
# Create encoder, decoder, lm_head and embed_tokens
if self.encoder_type == "t5":
self.encoder = T5EncoderYMT3(self.model_cfg["encoder"]["t5"], t5_config)
elif self.encoder_type == "perceiver-tf":
perceivertf_config = PerceiverTFConfig()
perceivertf_config.update(self.model_cfg["encoder"]["perceiver-tf"])
self.encoder = PerceiverTFEncoder(perceivertf_config)
elif self.encoder_type == "conformer":
conformer_config = ConformerYMT3Config()
conformer_config.update(self.model_cfg["encoder"]["conformer"])
self.encoder = ConformerYMT3Encoder(conformer_config)
if self.decoder_type == "t5":
self.decoder = T5DecoderYMT3(self.model_cfg["decoder"]["t5"], t5_config)
elif self.decoder_type == "multi-t5":
self.decoder = MultiChannelT5Decoder(self.model_cfg["decoder"]["multi-t5"], t5_config)
# `shift_right` function for decoding
self.shift_right_fn = self.decoder._shift_right
def setup(self, stage: str) -> None:
# Defining metrics
if self.hparams.eval_vocab is None:
extra_classes_per_dataset = [None]
else:
extra_classes_per_dataset = [
list(v.keys()) if v is not None else None for v in self.hparams.eval_vocab
] # e.g. [['Piano'], ['Guitar'], ['Piano'], ['Piano', 'Strings', 'Winds'], None]
# For direct addition of extra metrics using full metric name
extra_metrics = None
if self.hparams.add_melody_metric_to_singing is True:
extra_metrics = ["melody_rpa_Singing Voice", "melody_rca_Singing Voice", "melody_oa_Singing Voice"]
# Add pitch class metric
if self.hparams.add_pitch_class_metric is not None:
for sublist in extra_classes_per_dataset:
for name in self.hparams.add_pitch_class_metric:
if sublist is not None and name in sublist:
sublist += [name + "_pc"]
extra_classes_unique = list(
set(item for sublist in extra_classes_per_dataset if sublist is not None
for item in sublist)) # e.g. ['Strings', 'Winds', 'Guitar', 'Piano']
dm = self.trainer.datamodule
# Train/Vaidation-only
if stage == "fit":
self.val_metrics_macro = AMTMetrics(prefix=f'validation/macro_', extra_classes=extra_classes_unique)
self.val_metrics = nn.ModuleList() # val_metric is a list of AMTMetrics objects
for i in range(dm.num_val_dataloaders):
self.val_metrics.append(
AMTMetrics(prefix=f'validation/({dm.get_val_dataset_name(i)})',
extra_classes=extra_classes_per_dataset[i],
error_types=DECODING_ERR_TYPES))
# Add pitchshift layer
if self.shared_cfg["AUGMENTATION"]["train_pitch_shift_range"] in [None, [0, 0]]:
self.pitchshift = None
else:
# torchaudio pitchshifter requires a dummy input for initialization in DDP
input_shape = (self.shared_cfg["BSZ"]["train_local"], 1, self.audio_cfg["input_frames"])
self.pitchshift = PitchShiftLayer(
pshift_range=self.shared_cfg["AUGMENTATION"]["train_pitch_shift_range"],
expected_input_shape=input_shape,
device=self.device)
# Test-only
elif stage == "test":
# self.test_metrics_macro = AMTMetrics(
# prefix=f'test/macro_', extra_classes=extra_classes_unique)
self.test_metrics = nn.ModuleList()
for i in range(dm.num_test_dataloaders):
self.test_metrics.append(
AMTMetrics(prefix=f'test/({dm.get_test_dataset_name(i)})',
extra_classes=extra_classes_per_dataset[i],
extra_metrics=extra_metrics,
error_types=DECODING_ERR_TYPES))
# Test pitch shift layer: debug only
if self.test_pitch_shift_layer is not None:
self.test_pitch_shift_semitone = int(self.test_pitch_shift_layer)
self.pitchshift = PitchShiftLayer(
pshift_range=[self.test_pitch_shift_semitone, self.test_pitch_shift_semitone])
def configure_optimizers(self) -> None:
"""Configure optimizer and scheduler"""
optimizer, base_lr = get_optimizer(models_dict=self.named_parameters(),
optimizer_name=self.hparams.optimizer_name,
base_lr=self.hparams.base_lr,
weight_decay=self.hparams.weight_decay)
if self.hparams.optimizer_name.lower() == 'adafactor' and self.hparams.base_lr == None:
print("Using AdaFactor with auto learning rate and no scheduler")
return [optimizer]
if self.hparams.optimizer_name.lower() == 'dadaptadam':
print("Using dAdaptAdam with auto learning rate and no scheduler")
return [optimizer]
elif self.hparams.base_lr == None:
print(f"Using default learning rate {base_lr} of {self.hparams.optimizer_name} as base learning rate.")
self.hparams.base_lr = base_lr
scheduler_cfg = self.shared_cfg["LR_SCHEDULE"]
if self.hparams.max_steps != -1:
# overwrite total_steps
scheduler_cfg["total_steps"] = self.hparams.max_steps
_lr_scheduler = get_lr_scheduler(optimizer,
scheduler_name=self.hparams.scheduler_name,
base_lr=base_lr,
scheduler_cfg=scheduler_cfg)
lr_scheduler = {'scheduler': _lr_scheduler, 'interval': 'step', 'frequency': 1}
return [optimizer], [lr_scheduler]
def forward(
self,
x: torch.FloatTensor,
target_tokens: torch.LongTensor,
# task_tokens: Optional[torch.LongTensor] = None,
**kwargs) -> Dict:
""" Forward pass with teacher-forcing for training and validation.
Args:
x: (B, 1, T) waveform with default T=32767
target_tokens: (B, C, N) tokenized sequence of length N=event_length
task_tokens: (B, C, task_len) tokenized task
Returns:
{
'logits': (B, N + task_len + 1, vocab_size)
'loss': (1, )
}
NOTE: all the commented shapes are in the case of original MT3 setup.
"""
x = self.spectrogram(x) # mel-/spectrogram: (b, 256, 512) or (B, T, F)
x = self.pre_encoder(x) # projection to d_model: (B, 256, 512)
# TODO: task_cond_encoder would not work properly because of 3-d task_tokens
# if task_tokens is not None and task_tokens.numel() > 0 and self.use_task_cond_encoder is True:
# # append task embedding to encoder input
# task_embed = self.embed_tokens(task_tokens) # (B, task_len, 512)
# x = torch.cat([task_embed, x], dim=1) # (B, task_len + 256, 512)
enc_hs = self.encoder(inputs_embeds=x)["last_hidden_state"] # (B, T', D)
enc_hs = self.pre_decoder(enc_hs) # (B, T', D) or (B, K, T, D)
# if task_tokens is not None and task_tokens.numel() > 0 and self.use_task_cond_decoder is True:
# # append task token to decoder input and output label
# labels = torch.cat([task_tokens, target_tokens], dim=2) # (B, C, task_len + N)
# else:
# labels = target_tokens # (B, C, N)
labels = target_tokens # (B, C, N)
if labels.shape[1] == 1: # for single-channel decoders, e.g. t5.
labels = labels.squeeze(1) # (B, N)
dec_input_ids = self.shift_right_fn(labels) # t5:(B, N), multi-t5:(B, C, N)
dec_inputs_embeds = self.embed_tokens(dec_input_ids) # t5:(B, N, D), multi-t5:(B, C, N, D)
dec_hs, _ = self.decoder(inputs_embeds=dec_inputs_embeds, encoder_hidden_states=enc_hs, return_dict=False)
if self.model_cfg["tie_word_embeddings"] is True:
dec_hs = dec_hs * (self.model_cfg["decoder"][self.decoder_type]["d_model"]**-0.5)
logits = self.lm_head(dec_hs)
loss = None
labels = labels.masked_fill(labels == 0, value=-100) # ignore pad tokens for loss
loss_fct = CrossEntropyLoss(ignore_index=-100)
loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
return {"logits": logits, "loss": loss}
def inference(self,
x: torch.FloatTensor,
task_tokens: Optional[torch.LongTensor] = None,
max_token_length: Optional[int] = None,
**kwargs: Any) -> torch.Tensor:
""" Inference from audio batch by cached autoregressive decoding.
Args:
x: (b, 1, t) waveform with t=32767
task_token: (b, c, task_len) tokenized task. If None, will not append task embeddings (from task_tokens) to input.
max_length: Maximum length of generated sequence. If None, self.max_total_token_length.
**kwargs: https://huggingface.co/docs/transformers/v4.27.2/en/main_classes/text_generation#transformers.GenerationMixin.generate
Returns:
res_tokens: (b, n) resulting tokenized sequence of variable length < max_length
"""
if self.test_pitch_shift_layer is not None:
x_ps = self.pitchshift(x, self.test_pitch_shift_semitone)
x = x_ps
# From spectrogram to pre-decoder is the same pipeline as in forward()
x = self.spectrogram(x) # mel-/spectrogram: (b, 256, 512) or (B, T, F)
x = self.pre_encoder(x) # projection to d_model: (B, 256, 512)
if task_tokens is not None and task_tokens.numel() > 0 and self.use_task_cond_encoder is True:
# append task embedding to encoder input
task_embed = self.embed_tokens(task_tokens) # (B, task_len, 512)
x = torch.cat([task_embed, x], dim=1) # (B, task_len + 256, 512)
enc_hs = self.encoder(inputs_embeds=x)["last_hidden_state"] # (B, task_len + 256, 512)
enc_hs = self.pre_decoder(enc_hs) # (B, task_len + 256, 512)
# Cached-autoregressive decoding with task token (can be None) as prefix
if max_token_length is None:
max_token_length = self.max_total_token_length
pred_ids = task_cond_dec_generate(decoder=self.decoder,
decoder_type=self.decoder_type,
embed_tokens=self.embed_tokens,
lm_head=self.lm_head,
encoder_hidden_states=enc_hs,
shift_right_fn=self.shift_right_fn,
prefix_ids=task_tokens,
max_length=max_token_length) # (B, task_len + N) or (B, C, task_len + N)
if pred_ids.dim() == 2:
pred_ids = pred_ids.unsqueeze(1) # (B, 1, task_len + N)
if self.test_pitch_shift_layer is None:
return pred_ids
else:
return pred_ids, x_ps
def inference_file(
self,
bsz: int,
audio_segments: torch.FloatTensor, # (n_items, 1, segment_len): from a single file
note_token_array: Optional[torch.LongTensor] = None,
task_token_array: Optional[torch.LongTensor] = None,
# subtask_key: Optional[str] = "default"
) -> Tuple[List[np.ndarray], Optional[torch.Tensor]]:
""" Inference from audio batch by autoregressive decoding:
Args:
bsz: batch size
audio_segments: (n_items, 1, segment_len): segmented audio from a single file
note_token_array: (n_items, max_token_len): Optional. If token_array is None, will not return loss.
subtask_key: (str): If None, not using subtask prefix. By default, using "default" defined in task manager.
"""
# if subtask_key is not None:
# _subtask_token = torch.LongTensor(
# self.task_manager.get_eval_subtask_prefix_dict()[subtask_key]).to(self.device)
n_items = audio_segments.shape[0]
loss = 0.
pred_token_array_file = [] # each element is (B, C, L) np.ndarray
x_ps_concat = []
for i in range(0, n_items, bsz):
if i + bsz > n_items: # last batch can be smaller
x = audio_segments[i:n_items].to(self.device)
# if subtask_key is not None:
# b = n_items - i # bsz for the last batch
# task_tokens = _subtask_token.expand((b, -1)) # (b, task_len)
if note_token_array is not None:
target_tokens = note_token_array[i:n_items].to(self.device)
if task_token_array is not None and task_token_array.numel() > 0:
task_tokens = task_token_array[i:n_items].to(self.device)
else:
task_tokens = None
else:
x = audio_segments[i:i + bsz].to(self.device) # (bsz, 1, segment_len)
# if subtask_key is not None:
# task_tokens = _subtask_token.expand((bsz, -1)) # (bsz, task_len)
if note_token_array is not None:
target_tokens = note_token_array[i:i + bsz].to(self.device) # (bsz, token_len)
if task_token_array is not None and task_token_array.numel() > 0:
task_tokens = task_token_array[i:i + bsz].to(self.device)
else:
task_tokens = None
# token prediction (fast-autoregressive decoding)
# if subtask_key is not None:
# preds = self.inference(x, task_tokens).detach().cpu().numpy()
# else:
# preds = self.inference(x).detach().cpu().numpy()
if self.test_pitch_shift_layer is not None: # debug only
preds, x_ps = self.inference(x, task_tokens)
preds = preds.detach().cpu().numpy()
x_ps_concat.append(x_ps.detach().cpu())
else:
preds = self.inference(x, task_tokens).detach().cpu().numpy()
if len(preds) != len(x):
raise ValueError(f'preds: {len(preds)}, x: {len(x)}')
pred_token_array_file.append(preds)
# validation loss (by teacher forcing)
if note_token_array is not None:
loss_weight = x.shape[0] / n_items
loss += self(x, target_tokens)['loss'] * loss_weight
# loss += self(x, target_tokens, task_tokens)['loss'] * loss_weight
else:
loss = None
if self.test_pitch_shift_layer is not None: # debug only
if self.hparams.write_output_dir is not None:
x_ps_concat = torch.cat(x_ps_concat, dim=0)
return pred_token_array_file, loss, x_ps_concat.flatten().unsqueeze(0)
else:
return pred_token_array_file, loss
def training_step(self, batch, batch_idx) -> torch.Tensor:
# batch: {
# 'dataset1': [Tuple[audio_segments(b, 1, t), tokens(b, max_token_len), ...]]
# 'dataset2': [Tuple[audio_segments(b, 1, t), tokens(b, max_token_len), ...]]
# 'dataset3': ...
# }
audio_segments, note_tokens, pshift_steps = [torch.cat(t, dim=0) for t in zip(*batch.values())]
if self.pitchshift is not None:
# Pitch shift
n_groups = len(batch)
audio_segments = torch.chunk(audio_segments, n_groups, dim=0)
pshift_steps = torch.chunk(pshift_steps, n_groups, dim=0)
for p in pshift_steps:
assert p.eq(p[0]).all().item()
audio_segments = torch.cat([self.pitchshift(a, p[0].item()) for a, p in zip(audio_segments, pshift_steps)],
dim=0)
loss = self(audio_segments, note_tokens)['loss']
self.log('train_loss',
loss,
on_step=True,
on_epoch=True,
prog_bar=True,
batch_size=note_tokens.shape[0],
sync_dist=True)
# print('lr', self.trainer.optimizers[0].param_groups[0]['lr'])
return loss
def validation_step(self, batch, batch_idx, dataloader_idx=0) -> Dict:
# File-wise validation
if self.task_manager.num_decoding_channels == 1:
bsz = self.shared_cfg["BSZ"]["validation"]
else:
bsz = self.shared_cfg["BSZ"]["validation"] // self.task_manager.num_decoding_channels * 3
# audio_segments, notes_dict, note_token_array, task_token_array = batch
audio_segments, notes_dict, note_token_array = batch
task_token_array = None
# Loop through the tensor in chunks of bsz (=subbsz actually)
n_items = audio_segments.shape[0]
start_secs_file = [32767 * i / 16000 for i in range(n_items)]
with Timer() as t:
pred_token_array_file, loss = self.inference_file(bsz, audio_segments, note_token_array, task_token_array)
"""
notes_dict: # Ground truth notes
{
'mtrack_id': int,
'program': List[int],
'is_drum': bool,
'duration_sec': float,
'notes': List[Note],
}
"""
# Process a list of channel-wise token arrays for a file
num_channels = self.task_manager.num_decoding_channels
pred_notes_in_file = []
n_err_cnt = Counter()
for ch in range(num_channels):
pred_token_array_ch = [arr[:, ch, :] for arr in pred_token_array_file] # (B, L)
zipped_note_events_and_tie, list_events, ne_err_cnt = self.task_manager.detokenize_list_batches(
pred_token_array_ch, start_secs_file, return_events=True)
pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie)
pred_notes_in_file.append(pred_notes_ch)
n_err_cnt += n_err_cnt_ch
pred_notes = mix_notes(pred_notes_in_file) # This is the mixed notes from all channels
if self.hparams.write_output_dir is not None:
track_info = [notes_dict[k] for k in notes_dict.keys() if k.endswith("_id")][0]
dataset_info = [k for k in notes_dict.keys() if k.endswith('_id')][0][:-3]
# write_model_output_as_npy(zipped_note_events_and_tie, self.hparams.write_output_dir,
# track_info)
write_model_output_as_midi(pred_notes,
self.hparams.write_output_dir,
track_info,
self.midi_output_inverse_vocab,
output_dir_suffix=str(dataset_info) + '_' +
str(self.hparams.eval_subtask_key))
# generate sample text to display in log table
# pred_events_text = [str([list_events[0][:200]])]
# pred_notes_text = [str([pred_notes[:200]])]
# this is local GPU metric per file, not global metric in DDP
drum_metric, non_drum_metric, instr_metric = compute_track_metrics(
pred_notes,
notes_dict['notes'],
eval_vocab=self.hparams.eval_vocab[dataloader_idx],
eval_drum_vocab=self.hparams.eval_drum_vocab,
onset_tolerance=self.hparams.onset_tolerance,
add_pitch_class_metric=self.hparams.add_pitch_class_metric)
self.val_metrics[dataloader_idx].bulk_update(drum_metric)
self.val_metrics[dataloader_idx].bulk_update(non_drum_metric)
self.val_metrics[dataloader_idx].bulk_update(instr_metric)
self.val_metrics_macro.bulk_update(drum_metric)
self.val_metrics_macro.bulk_update(non_drum_metric)
self.val_metrics_macro.bulk_update(instr_metric)
# Log sample table: predicted notes and ground truth notes
# if batch_idx in (0, 1) and self.global_rank == 0:
# actual_notes_text = [str([notes_dict['notes'][:200]])]
# actual_tokens = token_array[0, :200].detach().cpu().numpy().tolist()
# actual_events_text = [str(self.tokenizer._decode(actual_tokens))]
# track_info = [notes_dict[k] for k in notes_dict.keys() if k.endswith("_id")]
# self.sample_table.add_data(self.current_epoch, track_info, pred_events_text,
# actual_events_text, pred_notes_text, actual_notes_text)
# self.logger.log_table('Samples', self.sample_table.columns, self.sample_table.data)
decoding_time_sec = t.elapsed_time()
self.log('val_loss', loss, prog_bar=True, batch_size=n_items, sync_dist=True)
# self.val_metrics[dataloader_idx].bulk_update_errors({'decoding_time': decoding_time_sec})
def on_validation_epoch_end(self) -> None:
for val_metrics in self.val_metrics:
self.log_dict(val_metrics.bulk_compute(), sync_dist=True)
val_metrics.bulk_reset()
self.log_dict(self.val_metrics_macro.bulk_compute(), sync_dist=True)
self.val_metrics_macro.bulk_reset()
def test_step(self, batch, batch_idx, dataloader_idx=0) -> Dict:
# File-wise evaluation
if self.task_manager.num_decoding_channels == 1:
bsz = self.shared_cfg["BSZ"]["validation"]
else:
bsz = self.shared_cfg["BSZ"]["validation"] // self.task_manager.num_decoding_channels * 3
# audio_segments, notes_dict, note_token_array, task_token_array = batch
audio_segments, notes_dict, note_token_array = batch
task_token_array = None
# Test pitch shift layer: debug only
if self.test_pitch_shift_layer is not None and self.test_pitch_shift_semitone != 0:
for n in notes_dict['notes']:
if n.is_drum == False:
n.pitch = n.pitch + self.test_pitch_shift_semitone
# Loop through the tensor in chunks of bsz (=subbsz actually)
n_items = audio_segments.shape[0]
start_secs_file = [32767 * i / 16000 for i in range(n_items)]
if self.test_pitch_shift_layer is not None and self.hparams.write_output_dir is not None:
pred_token_array_file, loss, x_ps = self.inference_file(bsz, audio_segments, None, None)
else:
pred_token_array_file, loss = self.inference_file(bsz, audio_segments, None, None)
if len(pred_token_array_file) > 0:
# Process a list of channel-wise token arrays for a file
num_channels = self.task_manager.num_decoding_channels
pred_notes_in_file = []
n_err_cnt = Counter()
for ch in range(num_channels):
pred_token_array_ch = [arr[:, ch, :] for arr in pred_token_array_file] # (B, L)
zipped_note_events_and_tie, list_events, ne_err_cnt = self.task_manager.detokenize_list_batches(
pred_token_array_ch, start_secs_file, return_events=True)
pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie)
pred_notes_in_file.append(pred_notes_ch)
n_err_cnt += n_err_cnt_ch
pred_notes = mix_notes(pred_notes_in_file) # This is the mixed notes from all channels
if self.test_pitch_shift_layer is not None and self.hparams.write_output_dir is not None:
# debug only
wav_output_dir = os.path.join(self.hparams.write_output_dir, f"model_output_{dataset_info}")
os.makedirs(wav_output_dir, exist_ok=True)
wav_output_file = os.path.join(wav_output_dir, f"{track_info}_ps_{self.test_pitch_shift_semitone}.wav")
torchaudio.save(wav_output_file, x_ps.squeeze(1), 16000, bits_per_sample=16)
drum_metric, non_drum_metric, instr_metric = compute_track_metrics(
pred_notes,
notes_dict['notes'],
eval_vocab=self.hparams.eval_vocab[dataloader_idx],
eval_drum_vocab=self.hparams.eval_drum_vocab,
onset_tolerance=self.hparams.onset_tolerance,
add_pitch_class_metric=self.hparams.add_pitch_class_metric,
add_melody_metric=['Singing Voice'] if self.hparams.add_melody_metric_to_singing else None,
add_frame_metric=True,
add_micro_metric=True,
add_multi_f_metric=True)
if self.hparams.write_output_dir is not None and self.global_rank == 0:
# write model output to file
track_info = [notes_dict[k] for k in notes_dict.keys() if k.endswith("_id")][0]
dataset_info = [k for k in notes_dict.keys() if k.endswith('_id')][0][:-3]
f_score = f"OnF{non_drum_metric['onset_f']:.2f}_MulF{instr_metric['multi_f']:.2f}"
write_model_output_as_midi(pred_notes,
self.hparams.write_output_dir,
track_info,
self.midi_output_inverse_vocab,
output_dir_suffix=str(dataset_info) + '_' +
str(self.hparams.eval_subtask_key) + '_' + f_score)
write_err_cnt_as_json(track_info, self.hparams.write_output_dir,
str(dataset_info) + '_' + str(self.hparams.eval_subtask_key) + '_' + f_score,
n_err_cnt, ne_err_cnt)
# Test with optimal octave shift
if self.hparams.test_optimal_octave_shift:
track_info = [notes_dict[k] for k in notes_dict.keys() if k.endswith("_id")][0]
dataset_info = [k for k in notes_dict.keys() if k.endswith('_id')][0][:-3]
score = [instr_metric['onset_f_Bass']]
ref_notes_plus = []
ref_notes_minus = []
for note in notes_dict['notes']:
if note.is_drum == True:
ref_notes_plus.append(note)
ref_notes_minus.append(note)
else:
ref_notes_plus.append(
Note(is_drum=note.is_drum,
program=note.program,
onset=note.onset,
offset=note.offset,
pitch=note.pitch + 12,
velocity=note.velocity))
ref_notes_minus.append(
Note(is_drum=note.is_drum,
program=note.program,
onset=note.onset,
offset=note.offset,
pitch=note.pitch - 12,
velocity=note.velocity))
drum_metric_plus, non_drum_metric_plus, instr_metric_plus = compute_track_metrics(
pred_notes,
ref_notes_plus,
eval_vocab=self.hparams.eval_vocab[dataloader_idx],
eval_drum_vocab=self.hparams.eval_drum_vocab,
onset_tolerance=self.hparams.onset_tolerance,
add_pitch_class_metric=self.hparams.add_pitch_class_metric)
drum_metric_minus, non_drum_metric_minus, instr_metric_minus = compute_track_metrics(
ref_notes_minus,
notes_dict['notes'],
eval_vocab=self.hparams.eval_vocab[dataloader_idx],
eval_drum_vocab=self.hparams.eval_drum_vocab,
onset_tolerance=self.hparams.onset_tolerance,
add_pitch_class_metric=self.hparams.add_pitch_class_metric)
score.append(instr_metric_plus['onset_f_Bass'])
score.append(instr_metric_minus['onset_f_Bass'])
max_index = score.index(max(score))
if max_index == 0:
print(f"ZERO: {track_info}, z/p/m: {score[0]:.2f}/{score[1]:.2f}/{score[2]:.2f}")
elif max_index == 1:
# plus
instr_metric['onset_f_Bass'] = instr_metric_plus['onset_f_Bass']
print(f"PLUS: {track_info}, z/p/m: {score[0]:.2f}/{score[1]:.2f}/{score[2]:.2f}")
write_model_output_as_midi(ref_notes_plus,
self.hparams.write_output_dir,
track_info + '_ref_octave_plus',
self.midi_output_inverse_vocab,
output_dir_suffix=str(dataset_info) + '_' +
str(self.hparams.eval_subtask_key))
else:
# minus
instr_metric['onset_f_Bass'] = instr_metric_minus['onset_f_Bass']
print(f"MINUS: {track_info}, z/p/m: {score[0]:.2f}/{score[1]:.2f}/{score[2]:.2f}")
write_model_output_as_midi(ref_notes_minus,
self.hparams.write_output_dir,
track_info + '_ref_octave_minus',
self.midi_output_,
output_dir_suffix=str(dataset_info) + '_' +
str(self.hparams.eval_subtask_key))
self.test_metrics[dataloader_idx].bulk_update(drum_metric)
self.test_metrics[dataloader_idx].bulk_update(non_drum_metric)
self.test_metrics[dataloader_idx].bulk_update(instr_metric)
# self.test_metrics_macro.bulk_update(drum_metric)
# self.test_metrics_macro.bulk_update(non_drum_metric)
# self.test_metrics_macro.bulk_update(instr_metric)
def on_test_epoch_end(self) -> None:
# all_gather is done seeminglesly by torchmetrics
for test_metrics in self.test_metrics:
self.log_dict(test_metrics.bulk_compute(), sync_dist=True)
test_metrics.bulk_reset()
# self.log_dict(self.test_metrics_macro.bulk_compute(), sync_dist=True)
# self.test_metrics_macro.bulk_reset()
def test_case_forward_mt3():
import torch
from config.config import audio_cfg, model_cfg, shared_cfg
from model.ymt3 import YourMT3
model = YourMT3()
model.eval()
x = torch.randn(2, 1, 32767)
labels = torch.randint(0, 596, (2, 1, 1024), requires_grad=False) # (B, C=1, T)
task_tokens = torch.LongTensor([])
output = model.forward(x, labels, task_tokens)
logits, loss = output['logits'], output['loss']
assert logits.shape == (2, 1024, 596) # (B, N, vocab_size)
def test_case_inference_mt3():
import torch
from config.config import audio_cfg, model_cfg, shared_cfg
from model.ymt3 import YourMT3
model_cfg["num_max_positions"] = 1024 + 3 + 1
model = YourMT3(model_cfg=model_cfg)
model.eval()
x = torch.randn(2, 1, 32767)
task_tokens = torch.randint(0, 596, (2, 3), requires_grad=False)
pred_ids = model.inference(x, task_tokens, max_token_length=10) # (2, 3, 9) (B, C, L-task_len)
# TODO: need to check the length of pred_ids when task_tokens is not None
def test_case_forward_enc_perceiver_tf_dec_t5():
import torch
from model.ymt3 import YourMT3
from config.config import audio_cfg, model_cfg, shared_cfg
model_cfg["encoder_type"] = "perceiver-tf"
audio_cfg["codec"] = "spec"
audio_cfg["hop_length"] = 300
model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg)
model.eval()
x = torch.randn(2, 1, 32767)
labels = torch.randint(0, 596, (2, 1, 1024), requires_grad=False)
# forward
output = model.forward(x, labels)
logits, loss = output['logits'], output['loss'] # logits: (2, 1024, 596) (B, N, vocab_size)
# inference
pred_ids = model.inference(x, None, max_token_length=3) # (2, 1, 3) (B, C, L)
def test_case_forward_enc_conformer_dec_t5():
import torch
from model.ymt3 import YourMT3
from config.config import audio_cfg, model_cfg, shared_cfg
model_cfg["encoder_type"] = "conformer"
audio_cfg["codec"] = "melspec"
audio_cfg["hop_length"] = 128
model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg)
model.eval()
x = torch.randn(2, 1, 32767)
labels = torch.randint(0, 596, (2, 1024), requires_grad=False)
# forward
output = model.forward(x, labels)
logits, loss = output['logits'], output['loss'] # logits: (2, 1024, 596) (B, N, vocab_size)
# inference
pred_ids = model.inference(x, None, 20) # (2, 1, 20) (B, C, L)
def test_case_enc_perceiver_tf_dec_multi_t5():
import torch
from model.ymt3 import YourMT3
from config.config import audio_cfg, model_cfg, shared_cfg
model_cfg["encoder_type"] = "perceiver-tf"
model_cfg["decoder_type"] = "multi-t5"
model_cfg["encoder"]["perceiver-tf"]["attention_to_channel"] = True
model_cfg["encoder"]["perceiver-tf"]["num_latents"] = 26
audio_cfg["codec"] = "spec"
audio_cfg["hop_length"] = 300
model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg)
model.eval()
x = torch.randn(2, 1, 32767)
labels = torch.randint(0, 596, (2, 13, 200), requires_grad=False) # (B, C, T)
# x = model.spectrogram(x)
# x = model.pre_encoder(x) # (2, 110, 128, 128) (B, T, C, D)
# enc_hs = model.encoder(inputs_embeds=x)["last_hidden_state"] # (2, 110, 128, 128) (B, T, C, D)
# enc_hs = model.pre_decoder(enc_hs) # (2, 13, 110, 512) (B, C, T, D)
# dec_input_ids = model.shift_right_fn(labels) # (2, 13, 200) (B, C, T)
# dec_inputs_embeds = model.embed_tokens(dec_input_ids) # (2, 13, 200, 512) (B, C, T, D)
# dec_hs, _ = model.decoder(
# inputs_embeds=dec_inputs_embeds, encoder_hidden_states=enc_hs, return_dict=False)
# logits = model.lm_head(dec_hs) # (2, 13, 200, 596) (B, C, T, vocab_size)
# forward
x = torch.randn(2, 1, 32767)
labels = torch.randint(0, 596, (2, 13, 200), requires_grad=False) # (B, C, T)
output = model.forward(x, labels)
logits, loss = output['logits'], output['loss'] # (2, 13, 200, 596) (B, C, T, vocab_size)
# inference
model.max_total_token_length = 123 # to save time..
pred_ids = model.inference(x, None) # (2, 13, 123) (B, C, L)