diff --git a/.gitignore b/.gitignore index 35987d26bae7cd773ac2c0131187b93f8d399b43..c8ebd1620ba95e19b56f2ae7dcec022cfb036305 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ -amt/ +amt/logs/ examples/ diff --git a/amt/src/.coverage b/amt/src/.coverage new file mode 100644 index 0000000000000000000000000000000000000000..6f20dba0326a42a0aaa082ea70dcf84d52e8605e Binary files /dev/null and b/amt/src/.coverage differ diff --git a/amt/src/.coveragerc b/amt/src/.coveragerc new file mode 100644 index 0000000000000000000000000000000000000000..51d0cf37bf38607b6ffd273ced0985c9d519c428 --- /dev/null +++ b/amt/src/.coveragerc @@ -0,0 +1,5 @@ +[run] +omit = + train.py + test.py + install*.py diff --git a/amt/src/config/.DS_Store b/amt/src/config/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..271e64a4b94feb3ff6977a16ba1250591c068a98 Binary files /dev/null and b/amt/src/config/.DS_Store differ diff --git a/amt/src/config/config.py b/amt/src/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..295ebacbbfc7836281d26ca3400b367772d3112f --- /dev/null +++ b/amt/src/config/config.py @@ -0,0 +1,272 @@ +"""config.py""" +import numpy as np +# yapf: disable +""" +audio_cfg: +- Used by 'ymt3' to create a spectrogram layer. +- Input shape of model is determined by audio_cfg. +- 'train.py' arguments can override these defaults. +""" +audio_cfg = { + # Overwrittable by args in train.py + "codec": "melspec", # {melspec, spec} melspec for MT3, spec for PerceiverTF + "hop_length": 128, # {128, 300} 128 for MT3, 300 for PerceiverTF + # Shared audio parameters + "audio_backend": "torchaudio", # {torchaudio, nnAudio} + "sample_rate": 16000, + "input_frames": 32767, # number of input frames (~=2.048 s), determining in-/output shape of front layers. + "n_fft": 2048, + "n_mels": 512, # only for melspec + "f_min": 50.0, + "f_max": 8000.0, +} # TODO: currently dataloader is not updated by "input_frames" + +""" +model_cfg: +- Encoder type dictates use of T5_CFG or PERCEIVER_TF_CFG. +- 'train.py' arguments can override these defaults. +""" +model_cfg = { + "encoder_type": "t5", # {"t5", "perceiver-tf", "conformer"} + "decoder_type": "t5", # {"t5", "multi-t5"} + "pre_encoder_type": "default", # {None, "default", "conv", "conv1d", "conv2d_avpt"} by default, t5:None, perceiver:conv. + "pre_encoder_type_default": {"t5": None, "perceiver-tf": "conv", "conformer": None}, + "pre_decoder_type": "default", # {None, 'linear', 'conv1', 'mlp', 'group_linear'} see model/projection_layer.py + "pre_decoder_type_default": { # [enc_type][dec_type] + "t5": {"t5": None,}, + "perceiver-tf": {"t5": "linear", "multi-t5": "mc_shared_linear"}, + "conformer": {"t5": None,}, + }, + "conv_out_channels": 128, # number of filters for 'conv' pre_encoder. Otherwise ignored. + "t5_basename": "google/t5-v1_1-small", + "pretrained": False, # bool, if True, load pretrained weights from t5_basename. Mismatched layers are ignored. + "use_task_conditional_encoder": True, # True by default, but default task is None. So not activated by default. + "use_task_conditional_decoder": True, # True by default, but default task is None. So not activated by default. + "d_feat": "auto", # Input audio feature dimension for encoder. Automatically inferred by audio_cfg and existence of pre_encoders. + "tie_word_embeddings": True, # If True, weights of embed_tokens and lm_head are tied for stabilizing gradients. + "vocab_size": "auto", # int or "auto", automatically inferred by task manager. 
+ "num_max_positions": "auto", # int or "auto". Length of positional encoding. Automatically inferred by "feat_length", "event_length" and task_manager.max_task_token_length. + # 'vocab_size', 'tie_word_embeddings' and 'num_max_positions' are auto-copied to encoder and decoder configs in the below. + "encoder": { + "t5": { + "d_model": 512, # Hidden size of T5 encoder. + "num_heads": 6, + "num_layers": 8, + "dropout_rate": 0.05, + "position_encoding_type": "sinusoidal", # {'sinusoidal', 'trainable'}. + "ff_widening_factor": 2, # wideening factor for MLP/MoE layers. Default is 2 in T5. + "ff_layer_type": "t5_gmlp", # {'t5_gmlp', 'moe', 'mlp', 'gmlp'}. 'moe' for mixture of experts, 'mlp' for standard transformer dense layer, 'gmlp' for simple gated MLP. + }, + "perceiver-tf": { + "num_latents": 24, # number of latents in Perceiver. 24 in perceiver-tf paper. + "d_latent": 128, # latent dimension of Perceiver. 128 in perceiver-tf paper. + "d_model": "q", # int or "q" or "kv". Inner-dim of sca and local/temporal self-att. + # "q" follows "latent_dim". "kv" follows "d_feat". Best practice is to inc-/decrease 'd_latent', instead of 'd_model'. + "num_blocks": 3, # number of Perceiver-TF blocks in encoder. L in the paper. + "num_local_transformers_per_block": 2, # N in the paper. + "num_temporal_transformers_per_block": 2, # M in the paper. + "sca_use_query_residual": False, + "dropout_rate": 0.1, + "position_encoding_type": "trainable", # {'trainable', 'rotary', 'alibi', 'alibit', None, 'tkd','td', 'tk', 'kdt'}. alibit is alibi with trainable slopes. + "attention_to_channel": True, # Whether to use channel attention in sca. + "layer_norm_type": "layer_norm", # {'layer_norm', 'rms_norm'} + "ff_layer_type": "mlp", # {'moe', 'mlp', gmlp}. 'moe' for mixture of experts, 'mlp' for standard transformer dense layer, 'gmlp' for simple gated MLP. + "ff_widening_factor": 1, # wideening factor for MLP/MoE layers. Default is 1. + "moe_num_experts": 4, # number of experts in MoE layer. Default is 4. Disabled if ff_layer_type is not 'moe'. + "moe_topk": 2, # top-k routing in MoE layer. Default is 2. Disabled if ff_layer_type is not 'moe'. + "hidden_act": 'gelu', # activation function in MLP/MoE layer. Default is 'gelu'. {'gelu', 'silu', 'relu'} + "rotary_type_sca": "pixel", # {'l'|'lang', 'p'|'pixel'}. Default is 'pixel'. + "rotary_type_latent": "pixel", # {'l'|'lang', 'p'|'pixel'}. Default is 'pixel'. + "rotary_type_temporal": "lang", # {'l'|'lang', 'p'|'pixel'}. Default is 'lang'. + "rotary_apply_to_keys": False, # Whether to apply rotary to keys. Default is False. + "rotary_partial_pe": False, # Whether to use partial positional encoding. Default is False. + }, + "conformer": { + "d_model": 512, # Hidden size of T5 encoder. + "intermediate_size": 512, # or 2048. size of the intermediate feed forward layer in each T5Block + "num_heads": 8, + "num_layers": 8, + "dropout_rate": 0.1, + "layerdrop": 0.1, # see https://arxiv.org/abs/1909.11556 + "position_encoding_type": "rotary", # {'rotary', 'relative'}. + "conv_dim": (512, 512, 512, 512, 512, 512, 512), + "conv_stride": (5, 2, 2, 2, 2, 2, 2), + "conv_kernel": (10, 3, 3, 3, 3, 3, 3), + "conv_depthwise_kernel_size": 31, + }, + + }, + "decoder": { + "t5": { + "d_model": 512, # Hidden size of T5 encoder. If encoder has lower dim, it is projected to this dim for enc-dec cross att. + "num_heads": 6, + "num_layers": 8, + "dropout_rate": 0.05, + "position_encoding_type": "sinusoidal", # {'sinusoidal', 'trainable'}. 
+ "ff_widening_factor": 2, # wideening factor for MLP/MoE layers. Default is 2 in T5. + "ff_layer_type": "t5_gmlp", # {'t5_gmlp', 'moe', 'mlp', 'gmlp'}. 'moe' for mixture of experts, 'mlp' for standard transformer dense layer, 'gmlp' for simple gated MLP. + }, + "multi-t5": { + "d_model": 512, # Hidden size of T5 encoder. Recommended: {256 or 512} + "num_heads": 6, + "num_layers": 8, + "dropout_rate": 0.05, + "position_encoding_type": "sinusoidal", # {'sinusoidal', 'trainable'}. + "ff_widening_factor": 2, # wideening factor for MLP/MoE layers. Default is 2 in T5. + "ff_layer_type": "t5_gmlp", # {'t5_gmlp', 'moe', 'mlp', 'gmlp'}. 'moe' for mixture of experts, 'mlp' for standard transformer dense layer, 'gmlp' for simple gated MLP. + "num_channels": 13, + }, + }, + "feat_length": "auto", # Input audio feature length for encoder. Automatically inferred by audio_cfg. + # mt3: 256 time steps + "event_length": 1024, # max length of event tokens excluding task tokens <-- 128 for multi-t5 + "init_factor": 1.0, # initialization factor for embedding layers +} + +# yapf: enable +shared_cfg = { + "PATH": { + "data_home": "../../data", # path to the data directory. If using relative path, it is relative to /src directory. + }, + "BSZ": { # global batch size is local_bsz * n_GPUs in DDP mode + "train_sub": 12, #20, # sub-batch size is per CPU worker + "train_local": 24, #40, # local batch size is per GPU in DDP mode + "validation": 64, # validation batch size is per GPU in DDP mode + "test": 64, + }, + "AUGMENTATION": { + "train_random_amp_range": [0.8, 1.1], # min and max amplitude scaling factor + "train_stem_iaug_prob": 0.7, # probability of stem activation in intra-stem augmentation + "train_stem_xaug_policy": { + "max_k": 3, + "tau": 0.3, + "alpha": 1.0, + "max_subunit_stems": 12, # the number of subunit stems to be reduced to this number of stems + "p_include_singing": None, # NOT IMPLEMENTED; probability of including singing for cross augmented examples. if None, use base probaility. + "no_instr_overlap": True, + "no_drum_overlap": True, + "uhat_intra_stem_augment": True, + }, + "train_pitch_shift_range": [-2, 2], # [min, max] in semitones. None or [0, 0] for no pitch shift. + }, + "DATAIO": { # do not set `shuffle` here. + "num_workers": 4, # num_worker is per GPU in DDP mode + "prefetch_factor": 2, #2, + "pin_memory": True, + "persistent_workers": False, + }, + "CHECKPOINT": { + "save_top_k": 4, # max top k checkpoints to save + "monitor": 'validation/macro_onset_f', + "mode": 'max', + # "every_n_epochs": 20, # only working when check_val_every_n_epoch is 0 + "save_last": True, # save last model + "filename": "{epoch}-{step}", + }, + "TRAINER": { # do not coverwrite args in this section + "limit_train_batches": 1.0, # How much of training dataset to check (float = fraction, int = num_batches) + "limit_val_batches": 1.0, + "limit_test_batches": 1.0, + "gradient_clip_val": 1.0, # {0 or None} means don't clip. + "accumulate_grad_batches": 1, #1, # Accumulates grads every k batches. If set to 1, no effect. 
+ "check_val_every_n_epoch": 1, #5, 1 for very large dataset such as EGMD + "num_sanity_val_steps": 0, + }, + "WANDB": { + "save_dir": "../logs", + "cache_dir": "../logs/.wandb_cache", + "resume": "allow", + "anonymous": "allow", # {never, allow, must} + "mode": "online", # {online, offline, disabled} + }, + "LR_SCHEDULE": { + # "scheduler_type": "cosine", # {legacy, cosine, constant} + "warmup_steps": 1000, # only for cosine scheduler, legacy scheduler follows T5's legacy schedule + "total_steps": 100000, # argparser of train.py can overwrite this + "final_cosine": 1e-5, # only for cosine scheduler + }, + "TOKENIZER": { + "max_shift_steps": "auto", # max number of shift steps in the model. (int) or "auto". If "auto", it is set by audio_cfg["input_frames"] and shift_steps_ms. 206 with default setup. + "shift_step_ms": 10, # shift step in ms + }, +} + +T5_BASE_CFG = { + "google/t5-v1_1-small": { + "architectures": ["T5ForConditionalGeneration"], + "d_ff": + 1024, # size of the intermediate feed forward layer in each T5Block. Can be overwrten by ff_widening_factor in model_cfg. + "d_kv": 64, # d_kv has to be equal to d_model // num_heads. + # "d_model": 512, # encoder hiddnen size, defined by model_cfg + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + # "dropout_rate": 0.05, # can be overwritten by args in ymt3 + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + "is_encoder_decoder": True, + "is_gated_act": True, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + # "num_decoder_layers": 8, # defined by model_cfg + # "num_heads": 6, # defined by model_cfg + # "num_layers": 8, # defined by model_cfg + "output_past": True, + "pad_token_id": 0, + "relative_attention_num_buckets": 32, + # "tie_word_embeddings": True, + "use_cache": True, + # "vocab_size": 1391 # vocab_size is automatically set by the task manager... + }, + "google/t5-efficient-small": { + "architectures": ["T5ForConditionalGeneration"], + "d_ff": 2048, + "d_kv": 64, + "d_model": 512, + "decoder_start_token_id": 0, + "dropout_rate": 0.1, + "eos_token_id": 1, + "feed_forward_proj": "relu", + "initializer_factor": 1.0, + "is_encoder_decoder": True, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + "num_decoder_layers": 6, + "num_heads": 8, + "num_layers": 6, + "pad_token_id": 0, + "relative_attention_num_buckets": 32, + "torch_dtype": "float32", + "transformers_version": "4.17.0.dev0", + "use_cache": True, + }, +} + +# yapf: enable +DEEPSPEED_CFG = { + "zero_allow_untested_optimizer": True, + "optimizer": { + "type": "adam", + "params": { + "lr": 1e-4, + "betas": [0.998, 0.999], + "eps": 1e-3, + "weight_decay": 0.001, + "adam_w_mode": True, + } + }, + "scheduler": { + "type": "WarmupLR", + "params": { + "last_batch_iteration": -1, + "warmup_min_lr": 0, + "warmup_max_lr": 3e-5, + "warmup_num_steps": 100, + }, + }, + "zero_optimization": { + "stage": 0, #0,1,2,3 + # "offload_optimizer": + # False, # Enable Offloading optimizer state/calculation to the host CPU + }, +} diff --git a/amt/src/config/data_presets.py b/amt/src/config/data_presets.py new file mode 100644 index 0000000000000000000000000000000000000000..8947f18110d50404ce9c10dd7449c9a8db0ec5c0 --- /dev/null +++ b/amt/src/config/data_presets.py @@ -0,0 +1,811 @@ +""" data.py: +Data presets for training and evaluation. + +Single Presets: + musicnet_mt3 + musicnet_em + musicnet_thickstun + slakh + guitarset + ... + +Multi Presets: + all_mmegs + ... 
+ +""" +from config.vocabulary import * +from config.vocabulary import drum_vocab_presets, program_vocab_presets +from utils.utils import deduplicate_splits, merge_splits, merge_vocab + +data_preset_single_cfg = { + "musicnet_mt3": { + "eval_vocab": [MUSICNET_INSTR_CLASS], + "dataset_name": "musicnet", + "train_split": "train_mt3", + "validation_split": "validation_mt3_acoustic", + "test_split": "test_mt3_acoustic", + "has_stem": False, + }, + "musicnet_mt3_synth_only": { # sanity-check + "eval_vocab": [MUSICNET_INSTR_CLASS], + "dataset_name": "musicnet", + "train_split": "train_mt3_synth", + "validation_split": "validation_mt3_synth", + "test_split": "test_mt3_acoustic", + "has_stem": False, + }, + "musicnet_mt3_em": { + "eval_vocab": [MUSICNET_INSTR_CLASS], + "dataset_name": "musicnet", + "train_split": "train_mt3_em", + "validation_split": "validation_mt3_em", + "test_split": "test_mt3_em", + "has_stem": False, + }, + "musicnet_thickstun": { # exp4 + "eval_vocab": [MUSICNET_INSTR_CLASS], + "dataset_name": "musicnet", + "train_split": "train_thickstun", + "validation_split": "test_thickstun", + "test_split": "test_thickstun", + "has_stem": False, + }, + "musicnet_thickstun_em": { # NOTE: this is not the use of external 'synth' in the paper, but the use of 'synth' within the dataset + "eval_vocab": [MUSICNET_INSTR_CLASS], + "dataset_name": "musicnet", + "train_split": "train_thickstun_em", + "validation_split": "test_thickstun_em", + "test_split": "test_thickstun_em", + "has_stem": False, + }, + "musicnet_thickstun_ext": { # exp4 + "eval_vocab": [MUSICNET_INSTR_CLASS], + "dataset_name": "musicnet", + "train_split": "train_thickstun", + "validation_split": "test_thickstun_ext", + "test_split": "test_thickstun_ext", + "has_stem": False, + }, + "musicnet_thickstun_ext_em": { # NOTE: this is not the use of external 'synth' in the paper, but the use of 'synth' within the dataset + "eval_vocab": [MUSICNET_INSTR_CLASS], + "dataset_name": "musicnet", + "train_split": "train_thickstun_em", + "validation_split": "test_thickstun_ext_em", + "test_split": "test_thickstun_ext_em", + "has_stem": False, + }, + "maps_default": { + "eval_vocab": [PIANO_SOLO_CLASS], + "dataset_name": "maps", + "train_split": "train", + "validation_split": "test", + "test_split": "test", + "has_stem": False, + }, + "maps_all": { + "eval_vocab": [None], + "dataset_name": "maps", + "train_split": "all", + "validation_split": None, + "test_split": None, + "has_stem": False, + }, + "maestro": { + "eval_vocab": [PIANO_SOLO_CLASS], + "dataset_name": "maestro", + "train_split": "train", + "validation_split": "validation", + "test_split": "test", + "has_stem": False, + }, + "maestro_final": { + "eval_vocab": [PIANO_SOLO_CLASS], + "dataset_name": "maestro", + "train_split": merge_splits(["train", "validation"], dataset_name="maestro"), + "validation_split": "test", + "test_split": "test", + "has_stem": False, + }, + "guitarset": { # 4 random players for train, 1 for valid, and 1 for test + "eval_vocab": [GUITAR_SOLO_CLASS], + "dataset_name": "guitarset", + "train_split": "train", + "validation_split": "validation", + "test_split": "test", + "has_stem": False, + }, + "guitarset_pshift": { # guitarset + pitch shift + "eval_vocab": [GUITAR_SOLO_CLASS], + "dataset_name": "guitarset", + "train_split": "train_pshift", + "validation_split": "validation", + "test_split": "test", + "has_stem": False, + }, + "guitarset_progression": { # progression 1 and 2 as train, progression 3 as test + "eval_vocab": [GUITAR_SOLO_CLASS], + "dataset_name": 
"guitarset", + "train_split": merge_splits(["progression_1", "progression_2"], dataset_name="guitarset"), + "validation_split": "progression_3", + "test_split": "progression_3", + "has_stem": False, + }, + "guitarset_progression_pshift": { # guuitarset_progression + pitch shift + "eval_vocab": [GUITAR_SOLO_CLASS], + "dataset_name": "guitarset", + "train_split": merge_splits(["progression_1_pshift", "progression_2_pshift"], dataset_name="guitarset"), + "validation_split": "progression_3", + "test_split": "progression_3", + "has_stem": False, + }, + "guitarset_minus_bn": { # guuitarset_style + pitch shift + "eval_vocab": [GUITAR_SOLO_CLASS], + "dataset_name": "guitarset", + "train_split": merge_splits(["Funk_pshift", "SS_pshift", "Jazz_pshift", "Rock_pshift"], + dataset_name="guitarset"), + "validation_split": "BN", + "test_split": "BN", + "has_stem": False, + }, + "guitarset_minus_funk": { # guuitarset_style + pitch shift + "eval_vocab": [GUITAR_SOLO_CLASS], + "dataset_name": "guitarset", + "train_split": merge_splits(["BN_pshift", "SS_pshift", "Jazz_pshift", "Rock_pshift"], + dataset_name="guitarset"), + "validation_split": "Funk", + "test_split": "Funk", + "has_stem": False, + }, + "guitarset_minus_ss": { # guuitarset_style + pitch shift + "eval_vocab": GUITAR_SOLO_CLASS, + "dataset_name": "guitarset", + "train_split": merge_splits(["BN_pshift", "Funk_pshift", "Jazz_pshift", "Rock_pshift"], + dataset_name="guitarset"), + "validation_split": "SS", + "test_split": "SS", + "has_stem": False, + }, + "guitarset_minus_jazz": { # guuitarset_style + pitch shift + "eval_vocab": [GUITAR_SOLO_CLASS], + "dataset_name": "guitarset", + "train_split": merge_splits(["BN_pshift", "Funk_pshift", "SS_pshift", "Rock_pshift"], + dataset_name="guitarset"), + "validation_split": "Jazz", + "test_split": "Jazz", + "has_stem": False, + }, + "guitarset_minus_rock": { # guuitarset_style + pitch shift + "eval_vocab": [GUITAR_SOLO_CLASS], + "dataset_name": "guitarset", + "train_split": merge_splits(["BN_pshift", "Funk_pshift", "SS_pshift", "Jazz_pshift"], + dataset_name="guitarset"), + "validation_split": "Rock", + "test_split": "Rock", + "has_stem": False, + }, + "guitarset_all": { + "eval_vocab": [None], + "dataset_name": "guitarset", + "train_split": "all", + "validation_split": None, + "test_split": None, + "has_stem": False, + }, + "enstdrums_dtp": { + "eval_vocab": [None], + "eval_drum_vocab": drum_vocab_presets["ksh"], + "dataset_name": "enstdrums", + "train_split": merge_splits(["drummer_1_dtp", "drummer_2_dtp", "drummer_1_dtp", "drummer_2_dtp"], dataset_name="enstdrums"), + "validation_split": "drummer_1_dtp", # for sanity check + "test_split": "drummer_3_dtp", + "has_stem": False, + }, + "enstdrums_dtm": { + "eval_vocab": [None], + "eval_drum_vocab": drum_vocab_presets["ksh"], + "dataset_name": "enstdrums", + "train_split": merge_splits(["drummer_1_dtm", "drummer_2_dtm", "drummer_1_dtp", "drummer_2_dtp"], dataset_name="enstdrums"), + "validation_split": "drummer_3_dtm_r2", # 0.6 * drum + "test_split": "drummer_3_dtm_r1", # 0.75 * drum + "has_stem": True, + }, + "enstdrums_random_dtm": { # single dataset training as a denoising ADT model + "eval_vocab": [None], + "eval_drum_vocab": drum_vocab_presets["ksh"], + "dataset_name": "enstdrums", + "train_split": "train_dtm", + "validation_split": "validation_dtm", + "test_split": "test_dtm", + "has_stem": True, + }, + "enstdrums_random": { # multi dataset training with random split of 70:15:15 + "eval_vocab": [None], + "eval_drum_vocab": drum_vocab_presets["ksh"], + 
"dataset_name": "enstdrums", + "train_split": "train_dtp", + "validation_split": "test_dtm", + "test_split": "test_dtm", + "has_stem": True, + }, + "enstdrums_random_plus_dtd": { # multi dataset training plus dtd + "eval_vocab": [None], + "eval_drum_vocab": drum_vocab_presets["ksh"], + "dataset_name": "enstdrums", + "train_split": merge_splits(["train_dtp", "all_dtd"], dataset_name="enstdrums"), + "validation_split": "test_dtm", + "test_split": "test_dtm", + "has_stem": True, + }, + "mir_st500": { + "eval_vocab": [SINGING_SOLO_CLASS], + "dataset_name": "mir_st500", + "train_split": "train_stem", + "validation_split": "test", + "test_split": "test", + "has_stem": True, + }, + "mir_st500_voc": { + "eval_vocab": [SINGING_SOLO_CLASS], + "dataset_name": "mir_st500", + "train_split": "train_vocal", + "validation_split": "test_vocal", + "test_split": "test_vocal", + "has_stem": False, + }, + "mir_st500_voc_debug": { # using train_vocal for test (for debugging) + "eval_vocab": [SINGING_SOLO_CLASS], + "dataset_name": "mir_st500", + "train_split": "train_vocal", + "validation_split": "test_vocal", + "test_split": "train_vocal", + "has_stem": False, + }, + "slakh": { + "eval_vocab": [GM_INSTR_CLASS], + "eval_drum_vocab": drum_vocab_presets["gm"], + "dataset_name": "slakh", + "train_split": "train", + "validation_split": "validation", + "test_split": "test", + "has_stem": True, + }, + "slakh_final": { + "eval_vocab": [GM_INSTR_CLASS], + "eval_drum_vocab": drum_vocab_presets["gm"], + "dataset_name": "slakh", + "train_split": merge_splits(["train", "validation"], dataset_name="slakh"), + "validation_split": "test", + "test_split": "test", + "has_stem": True, + }, + "rwc_pop_bass": { + "eval_vocab": [BASS_SOLO_CLASS], + "add_pitch_class_metric": ["Bass"], + "dataset_name": "rwc_pop", + "train_split": None, + "validation_split": "bass", + "test_split": "bass", + "has_stem": False, + }, + "rwc_pop_full": { + "eval_vocab": [GM_INSTR_CLASS_PLUS], + "add_pitch_class_metric": list(GM_INSTR_CLASS_PLUS.keys()), + "dataset_name": "rwc_pop", + "train_split": None, + "validation_split": "full", + "test_split": "full", + "has_stem": False, + }, + "egmd": { + "eval_vocab": [None], + "eval_drum_vocab": drum_vocab_presets["ksh"], + "dataset_name": "egmd", + "train_split": "train", + "validation_split": "validation", + "test_split": "test_reduced", # EGMD has 5000+ test files, so we reudce it to 200 files to save time + # "train_limit_num_files": 4402, #8804, # 17608, # limit the number of files for training to random choice of half. 
+ "has_stem": False, + }, + "urmp": { + "eval_vocab": [GM_INSTR_CLASS], + "dataset_name": "urmp", + "train_split": "train", + "validation_split": "test", + "test_split": "test", + "has_stem": True, + }, + "cmedia": { + "eval_vocab": [SINGING_SOLO_CLASS], + "dataset_name": "cmedia", + "train_split": "train_stem", + "validation_split": "train", + "test_split": "train", + "has_stem": True, + }, + "cmedia_voc": { + "eval_vocab": [SINGING_SOLO_CLASS], + "dataset_name": "cmedia", + "train_split": "train_vocal", + "validation_split": "train_vocal", + "test_split": "train_vocal", + "has_stem": False, + }, + "idmt_smt_bass": { + "eval_vocab": [BASS_SOLO_CLASS], + "dataset_name": "idmt_smt_bass", + "train_split": "train", + "validation_split": "validation", + "test_split": "validation", + "has_stem": False, + }, + "geerdes": { # full mix dataset for evaluation + "eval_vocab": [GM_INSTR_CLASS_PLUS], + "dataset_name": "geerdes", + "train_split": None, + "validation_split": None, + "test_split": "all", + "has_stem": False, + }, + "geerdes_sep": { # Using vocal/accomp separation for evalutation + "eval_vocab": [GM_INSTR_CLASS_PLUS], + "dataset_name": "geerdes", + "train_split": None, + "validation_split": None, + "test_split": "all_sep", + "has_stem": False, + }, + "geerdes_half": { # Using half dataset for train/val + "eval_vocab": [GM_INSTR_CLASS_PLUS], + "dataset_name": "geerdes", + "train_split": "train", + "validation_split": "validation", + "test_split": "validation", + "has_stem": False, + }, + "geerdes_half_sep": { # Using half dataset with vocal/accomp separation for train/val + "eval_vocab": [GM_INSTR_CLASS_PLUS], + "dataset_name": "geerdes", + "train_split": "train_sep", + "validation_split": "validation_sep", + "test_split": "validation_sep", + "has_stem": False, + }, +} + +data_preset_multi_cfg = { + "musicnet_mt3_em_synth_plus_maps": { + "presets": ["musicnet_mt3_em_synth", "maps_all"], + "weights": [0.6, 0.4], + "eval_vocab": [MUSICNET_INSTR_CLASS], + }, + "musicnet_em_synth_table2_plus_maps": { + "presets": ["musicnet_em_synth_table2", "maps_all"], + "weights": [0.6, 0.4], + "eval_vocab": [MUSICNET_INSTR_CLASS], + }, + "musicnet_em_synth_table2_plus_maps_multi": { + "presets": ["musicnet_em_synth_table2", "maps_default"], + "weights": [0.6, 0.4], + "eval_vocab": [MUSICNET_INSTR_CLASS], + }, + "guitarset_progression_plus_maps": { + "presets": ["guitarset_progression", "maps_all"], + "weights": [0.5, 0.5], + "eval_vocab": [GUITAR_SOLO_CLASS], + }, + "guitarset_pshift_plus_maps": { + "presets": ["guitarset_pshift", "maps_default"], + "weights": [0.6, 0.4], + "eval_vocab": [merge_vocab([GUITAR_SOLO_CLASS, PIANO_SOLO_CLASS])], + }, + "guitarset_pshift_plus_musicnet_thick": { + "presets": ["guitarset_pshift", "musicnet_thickstun_em"], + "weights": [0.5, 0.5], + "eval_vocab": [merge_vocab([GUITAR_SOLO_CLASS, PIANO_SOLO_CLASS])], + }, + "multi_sanity_check": { + "presets": ["musicnet_mt3_synth_only", "musicnet_mt3_synth_only"], + "weights": [0.6, 0.4], + "eval_vocab": [MUSICNET_INSTR_CLASS], + }, + "all_mmegs": { + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", "guitarset_pshift" + ], + "weights": [0.2, 0.2, 0.2, 0.2, 0.2], + "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_gt_cv0": { + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", 
"enstdrums_dtp", "guitarset_minus_bn" + ], + "weights": [0.2, 0.2, 0.2, 0.2, 0.2], + "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_gt_cv1": { + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_minus_funk" + ], + "weights": [0.2, 0.2, 0.2, 0.2, 0.2], + "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_gt_cv2": { + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", "guitarset_minus_ss" + ], + "weights": [0.2, 0.2, 0.2, 0.2, 0.2], + "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_gt_cv3": { + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_minus_rock" + ], + "weights": [0.2, 0.2, 0.2, 0.2, 0.2], + "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_gt_cv4": { + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_minus_jazz" + ], + "weights": [0.2, 0.2, 0.2, 0.2, 0.2], + "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_enstdrums_random": { + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_random", "guitarset" + ], + "weights": [0.2, 0.2, 0.2, 0.2, 0.2], + "eval_vocab": [None] * 5, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_plus_egmd": { + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_random_plus_dtd", + "guitarset", "egmd" + ], + "weights": [0.2, 0.2, 0.2, 0.1, 0.1, 0.2], + "eval_vocab": [None] * 6, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_dtp_egmd": { + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", "guitarset", "egmd" + ], + "weights": [0.2, 0.2, 0.2, 0.1, 0.1, 0.2], + "eval_vocab": [None] * 6, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_weighted_slakh": { + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", "guitarset_pshift", "egmd" + ], + "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.2], + "eval_vocab": 
[None] * 6, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_weighted_mt3": { # for comparison with MT3 + "presets": [ + "slakh", "musicnet_mt3", "mir_st500_voc", "enstdrums_dtp", + "guitarset_progression_pshift", "egmd" + ], + "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.2], + "eval_vocab": [None] * 6, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_weighted_mt3_em": { # musicnet_mt3_em + "presets": [ + "slakh", "musicnet_mt3_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_progression_pshift", "egmd" + ], + "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.2], + "eval_vocab": [None] * 6, # None means instrument-agnoßstic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_urmp": { + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_pshift", "egmd", "urmp" + ], + "weights": [0.5, 0.2, 0.1, 0.05, 0.05, 0.05, 0.1], + "eval_vocab": [None] * 7, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_urmp_mt3": { # for comparison with MT3 including URMP + "presets": [ + "slakh", "musicnet_mt3", "mir_st500_voc", "enstdrums_dtp", + "guitarset_progression", "egmd", "urmp" + ], + "weights": [0.5, 0.2, 0.1, 0.05, 0.05, 0.0125, 0.1], + "eval_vocab": [None] * 7, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_urmp_mt3_em": { # musicnet_mt3_em including URMP + "presets": [ + "slakh", "musicnet_mt3_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_progression", "egmd", "urmp" + ], + "weights": [0.5, 0.2, 0.1, 0.05, 0.05, 0.0125, 0.1], + "eval_vocab": [None] * 7, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_maestro": { # including Mestro and URMP + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_pshift", "egmd", "urmp", "maestro" + ], + "weights": [0.5, 0.1, 0.125, 0.075, 0.025, 0.01, 0.1, 0.1], + "eval_vocab": [None] * 8, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_maestro_mt3": { # for comparison with MT3 including URMP + "presets": [ + "slakh", "musicnet_mt3", "mir_st500_voc", "enstdrums_dtp", + "guitarset_progression", "egmd", "urmp", "maestro" + ], + "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.0125, 0.1, 0.1], + "eval_vocab": [None] * 8, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + 
"val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_maestro_mt3_em": { # musicnet_mt3_em including URMP + "presets": [ + "slakh", "musicnet_mt3_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_progression", "egmd", "urmp", "maestro" + ], + "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.0125, 0.1, 0.1], + "eval_vocab": [None] * 8, # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "singing_v1": { # slakh + mir_st500 without spleeter + "presets": ["slakh", "mir_st500"], + "weights": [0.8, 0.2], + "eval_vocab": [None, SINGING_SOLO_CLASS], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_singing_v1": { # for singing-only task + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_stem", "enstdrums_dtp", + "guitarset_pshift", "egmd", "urmp", "maestro" + ], + "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.0125, 0.1, 0.1], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_singing_drum_v1": { # for singing-only and drum-only tasks + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_stem", "enstdrums_dtm", + "guitarset_pshift", "egmd", "urmp", "maestro" + ], + "weights": [0.5, 0.1, 0.1, 0.05, 0.05, 0.0125, 0.1, 0.1], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_cross": { # including Mestro and URMP + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_pshift", "egmd", "urmp", "maestro" + ], + "weights": [0.5, 0.1, 0.125, 0.075, 0.025, 0.01, 0.1, 0.1], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_cross_rebal": { # rebalanced for cross-augment, using spleeter + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_pshift", "egmd", "urmp", "maestro" + ], + "weights": [0.4, 0.15, 0.15, 0.075, 0.025, 0.01, 0.1, 0.1], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_cross_rebal2": { # rebalanced for cross-augment, using spleeter + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_pshift", "egmd", "urmp", "maestro" + ], + "weights": [0.275, 0.19, 0.19, 0.1, 0.025, 0.02, 0.1, 0.1], + "eval_vocab": [None, None, 
SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_cross_rebal4": { # rebalanced for cross-augment, using spleeter + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_pshift", "egmd", "urmp", "maestro" + ], + "weights": [0.258, 0.19, 0.2, 0.125, 0.022, 0.005, 0.1, 0.1], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_cross_rebal5": { # rebalanced for cross-augment, using spleeter + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_pshift", "egmd", "urmp", "maestro" + ], + "weights": [0.295, 0.19, 0.24, 0.05, 0.02, 0.005, 0.1, 0.1], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_cross_stem": { # accomp stem for sub-task learning + rebalanced for cross-augment + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_stem", "enstdrums_dtm", + "guitarset_pshift", "egmd", "urmp", "maestro" + ], + "weights": [0.4, 0.15, 0.15, 0.075, 0.025, 0.01, 0.1, 0.1], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_cross_stem_rebal3": { # accomp stem for sub-task learning + rebalanced for cross-augment + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_stem", "enstdrums_dtm", + "guitarset_pshift", "egmd", "urmp", "maestro" + ], + "weights": [0.265, 0.18, 0.21, 0.1, 0.025, 0.02, 0.1, 0.1], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_cross_v6": { # +cmeida +idmt_smt_bass + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset", "egmd", "urmp", "maestro", "idmt_smt_bass", "cmedia_voc", + ], + "weights": [0.295, 0.19, 0.19, 0.05, 0.01, 0.005, 0.1, 0.1, 0.01, 0.05], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None, BASS_SOLO_CLASS, SINGING_SOLO_CLASS], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_cross_v6_geerdes": { # +geerdes_half + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset", "egmd", "urmp", "maestro", "idmt_smt_bass", "cmedia_voc", + "geerdes_half", "geerdes_half_sep" + ], + "weights": [0.295, 0.19, 0.19, 0.05, 0.01, 
0.005, 0.075, 0.075, 0.01, 0.05, 0.025, 0.025], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None, BASS_SOLO_CLASS, + SINGING_SOLO_CLASS, GM_INSTR_CLASS_PLUS, GM_INSTR_CLASS_PLUS], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_cross_v6_geerdes_rebal": { # +geerdes_half + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset", "egmd", "urmp", "maestro", "idmt_smt_bass", "cmedia_voc", + "geerdes_half", "geerdes_half_sep" + ], + "weights": [0.245, 0.175, 0.19, 0.05, 0.01, 0.005, 0.075, 0.05, 0.01, 0.05, 0.075, 0.075], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None, BASS_SOLO_CLASS, + SINGING_SOLO_CLASS, GM_INSTR_EXT_CLASS_PLUS, GM_INSTR_EXT_CLASS_PLUS], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_cross_v7": { + "presets": [ + "slakh", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_progression_pshift", "egmd", "urmp", "maestro", "idmt_smt_bass", "cmedia_voc", + ], + "weights": [0.295, 0.19, 0.191, 0.05, 0.01, 0.004, 0.1, 0.1, 0.01, 0.05], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None, BASS_SOLO_CLASS, SINGING_SOLO_CLASS], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_cross_final": { + "presets": [ + "slakh_final", "musicnet_thickstun_em", "mir_st500_voc", "enstdrums_dtp", + "guitarset_progression_pshift", "egmd", "urmp", "maestro_final", "idmt_smt_bass", "cmedia_voc", + ], + "weights": [0.295, 0.19, 0.191, 0.05, 0.01, 0.004, 0.1, 0.1, 0.01, 0.05], + "eval_vocab": [None, None, SINGING_SOLO_CLASS, None, None, None, None, None, BASS_SOLO_CLASS, SINGING_SOLO_CLASS], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "all_eval_final": { # The final evaluation set + "presets": [ + "slakh", "musicnet_thickstun", "musicnet_thickstun_em", "musicnet_thickstun_ext", + "musicnet_thickstun_ext_em", "mir_st500_voc", "mir_st500", "enstdrums_dtp", + "enstdrums_dtm", "guitarset_progression_pshift", "rwc_pop_bass", "maestro", "urmp", + "maps_default", "rwc_pop_full", # "geerdes", "geerdes_sep", + ], + "eval_vocab": [ + GM_INSTR_CLASS, MUSICNET_INSTR_CLASS, MUSICNET_INSTR_CLASS, MUSICNET_INSTR_CLASS, + MUSICNET_INSTR_CLASS, SINGING_SOLO_CLASS, SINGING_SOLO_CLASS, None, + None, None, BASS_SOLO_CLASS, PIANO_SOLO_CLASS, GM_INSTR_CLASS, + PIANO_SOLO_CLASS, GM_INSTR_CLASS_PLUS, # GM_INSTR_CLASS_PLUS, GM_INSTR_CLASS_PLUS + ], + "eval_drum_vocab": drum_vocab_presets["ksh"], + }, + "geerdes_eval": { # Geerdes evaluation sets for models trained without Geerdes. 
+ "presets": ["geerdes_sep", "geerdes"], + "eval_vocab": [GM_INSTR_CLASS_PLUS, GM_INSTR_CLASS_PLUS], + "eval_drum_vocab": drum_vocab_presets["gm"], + }, + "geerdes_half_eval": { # Geerdes evaluation sets for models trained with Geerdes-half + "presets": ["geerdes_half_sep", "geerdes_half"], + "eval_vocab": [GM_INSTR_CLASS_PLUS, GM_INSTR_CLASS_PLUS], + "eval_drum_vocab": drum_vocab_presets["gm"], + }, + "minimal": { # slakh + mir_st500 with spleeter + "presets": ["slakh", "mir_st500_voc"], + "weights": [0.8, 0.2], + "eval_vocab": [None, SINGING_SOLO_CLASS], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, + "singing_debug": { # slakh + mir_st500 with spleeter + "presets": ["mir_st500_voc_debug"], + "weights": [1.0], + "eval_vocab": [SINGING_SOLO_CLASS], # None means instrument-agnostic F1 for each dataset + "eval_drum_vocab": drum_vocab_presets["ksh"], # for drums, kick-snare-hihat metric + "val_max_num_files": 20, # max 20 files per dataset + "test_max_num_files": None, + }, +} diff --git a/amt/src/config/task.py b/amt/src/config/task.py new file mode 100644 index 0000000000000000000000000000000000000000..aca181f3461b1ff06d87b13937358338f5dd0503 --- /dev/null +++ b/amt/src/config/task.py @@ -0,0 +1,119 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""task.py""" +from config.vocabulary import * +from utils.note_event_dataclasses import Event + +task_cfg = { + "mt3_midi": { # 11 classes + drum class + "name": "mt3_midi", + "train_program_vocab": program_vocab_presets["mt3_midi"], + "train_drum_vocab": drum_vocab_presets["gm"], + }, + "mt3_midi_plus": { # 11 classes + singing + drum class + "name": "mt3_midi_plus", + "train_program_vocab": program_vocab_presets["mt3_midi_plus"], + "train_drum_vocab": drum_vocab_presets["gm"], + }, + "mt3_full": { # 34 classes (except drums) as in MT3 paper + "name": "mt3_full", + "train_program_vocab": program_vocab_presets["mt3_full"], + "train_drum_vocab": drum_vocab_presets["gm"], + }, + "mt3_full_plus": { # 34 classes (except drums) as in MT3 paper + singing + drum class + "name": "mt3_full_plus", + "train_program_vocab": program_vocab_presets["mt3_full_plus"], + "train_drum_vocab": drum_vocab_presets["gm"], + }, + "gm_ext_plus": { # 13 classes + singing + chorus (except drums) + "name": "gm_ext_plus", + "train_program_vocab": program_vocab_presets["gm_ext_plus"], + "train_drum_vocab": drum_vocab_presets["gm"], + }, + "singing_v1": { + "name": "singing", + "train_program_vocab": program_vocab_presets["mt3_full_plus"], + "train_drum_vocab": drum_vocab_presets["gm"], + "subtask_tokens": ["task", "transcribe_singing", "transcribe_all"], + "ignore_decoding_tokens": ["task", "transcribe_singing", "transcribe_all"], + "max_task_token_length": 2, + "eval_subtask_prefix": { + "default": [Event("transcribe_all", 0), Event("task", 0)], + "singing-only": [Event("transcribe_singing", 0), + Event("task", 0)], + } + }, + "singing_drum_v1": { + "name": "singing_drum", + "train_program_vocab": program_vocab_presets["mt3_full_plus"], + "train_drum_vocab": drum_vocab_presets["gm"], + "subtask_tokens": ["task", 
"transcribe_singing", "transcribe_drum", "transcribe_all"], + "ignore_decoding_tokens": [ + "task", "transcribe_singing", "transcribe_drum", "transcribe_all" + ], + "max_task_token_length": 2, + "eval_subtask_prefix": { + "default": [Event("transcribe_all", 0), Event("task", 0)], + "singing-only": [Event("transcribe_singing", 0), + Event("task", 0)], + "drum-only": [Event("transcribe_drum", 0), + Event("task", 0)], + } + }, + "mc13": { # multi-channel decoding task of {11 classes + drums + singing} + "name": "mc13", + "train_program_vocab": program_vocab_presets["gm_plus"], + "train_drum_vocab": drum_vocab_presets["gm"], + "num_decoding_channels": len(program_vocab_presets["gm_plus"]) + 1, # 13 + "max_note_token_length_per_ch": 512, # multi-channel decoding exclusive parameter + "mask_loss_strategy": None, # multi-channel decoding exclusive parameter + }, + "mc13_256": { # multi-channel decoding task of {11 classes + drums + singing} + "name": "mc13_256", + "train_program_vocab": program_vocab_presets["gm_plus"], + "train_drum_vocab": drum_vocab_presets["gm"], + "num_decoding_channels": len(program_vocab_presets["gm_plus"]) + 1, # 13 + "max_note_token_length_per_ch": 256, # multi-channel decoding exclusive parameter + "mask_loss_strategy": None, # multi-channel decoding exclusive parameter + }, + "mc13_full_plus": { # multi-channel decoding task of {34 classes + drums + singing & chorus} mapped to 13 channels + "name": "mc13_full_plus", + "train_program_vocab": program_vocab_presets["mt3_full_plus"], + "train_drum_vocab": drum_vocab_presets["gm"], + "program2channel_vocab_source": program_vocab_presets["gm_plus"], + "num_decoding_channels": 13, + "max_note_token_length_per_ch": 512, # multi-channel decoding exclusive parameter + "mask_loss_strategy": None, # multi-channel decoding exclusive parameter + }, + "mc13_full_plus_256": { # multi-channel decoding task of {34 classes + drums + singing & chorus} mapped to 13 channels + "name": "mc13_full_plus_256", + "train_program_vocab": program_vocab_presets["mt3_full_plus"], + "train_drum_vocab": drum_vocab_presets["gm"], + "program2channel_vocab_source": program_vocab_presets["gm_plus"], + "num_decoding_channels": 13, + "max_note_token_length_per_ch": 256, # multi-channel decoding exclusive parameter + "mask_loss_strategy": None, # multi-channel decoding exclusive parameter + }, + "exc_v1": { + "name": "exclusive", + "train_program_vocab": program_vocab_presets["mt3_full_plus"], + "train_drum_vocab": drum_vocab_presets["gm"], + "subtask_tokens": ["transcribe", "all", ":"], + # "ignore_decoding_tokens": [ + # "task", "transcribe_singing", "transcribe_drum", "transcribe_all" + # ], + # "max_task_token_length": 2, + "ignore_decoding_tokens_from_and_to": ["transcribe", ":"], + "eval_subtask_prefix": { # this is the main task that transcribe all instruments + "default": [Event("transcribe", 0), Event("all", 0), Event(":", 0)], + }, + "shuffle_subtasks": True, + }, +} diff --git a/amt/src/config/vocabulary.py b/amt/src/config/vocabulary.py new file mode 100644 index 0000000000000000000000000000000000000000..55de16d21e35652e1ca80ffdbe97dda69d00b4e7 --- /dev/null +++ b/amt/src/config/vocabulary.py @@ -0,0 +1,384 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. 
+"""vocabulary.py + +Vocabulary for instrument classes. Vocabulary can be used as train_vocab +or test_vocab in data_presets.py or train.py arguments. + +- When it is used as train_vocab, it maps the instrument classes to the first + program number of the class. For example, if you use 'GM_INSTR_CLASS' as + train_vocab, then the program number of 'Piano' is [0,1,2,3,4,5,6,7]. These + program numbers are trained as program [0] in the model. + + - When it is used as eval_vocab, any program number in the instrument class + is considered as correct. + + +MUSICNET_INSTR_CLASS: 3 classes used for MusicNet benchmark +GM_INSTR_CLASS: equivalent to 'MIDI Class' defined by MT3. +GM_INSTR_CLASS_PLUS: GM_INSTR_CLASS + singing voice +GM_INSTR_FULL: 128 GM instruments, which is extended from 'MT3_FULL' +MT3_FULL: this matches the class names in Table 3 of MT3 paper +ENST_DRUM_NOTES: 20 drum notes used in ENST dataset +GM_DRUM_NOTES: 45 GM drum notes with percussions + +Program 128 is reserved for 'drum' internally. +Program 129 is reserved for 'unannotated', internally. +Program 100 is reserved for 'singing voice (melody)' in GM_INSTR_CLASS_PLUS. +Program 101 is reserved for 'singing voice (chorus)' in GM_INSTR_CLASS_PLUS. + + +""" +# yapf: disable +import numpy as np + +PIANO_SOLO_CLASS = { + "Piano": np.arange(0, 8), +} + +GUITAR_SOLO_CLASS = { + "Guitar": np.arange(24, 32), +} + +SINGING_SOLO_CLASS = { + "Singing Voice": [100, 101], +} + +SINGING_CHORUS_SEP_CLASS = { + "Singing Voice": [100], + "Singing Voice (chorus)": [101], +} + +BASS_SOLO_CLASS = { + "Bass": np.arange(32, 40), +} + +MUSICNET_INSTR_CLASS = { + "Piano": np.arange(0, 8), + "Strings": np.arange(40, 52), # Solo strings + ensemble strings + "Winds": np.arange(64, 80), # Reed + Pipe +} + +GM_INSTR_CLASS = { + "Piano": np.arange(0, 8), + "Chromatic Percussion": np.arange(8, 16), + "Organ": np.arange(16, 24), + "Guitar": np.arange(24, 32), + "Bass": np.arange(32, 40), + "Strings": np.arange(40, 56), # Strings + Ensemble + # "Strings": np.arange(40, 48), + # "Ensemble": np.arange(48, 56), + "Brass": np.arange(56, 64), + "Reed": np.arange(64, 72), + "Pipe": np.arange(72, 80), + "Synth Lead": np.arange(80, 88), + "Synth Pad": np.arange(88, 96), +} + +GM_INSTR_CLASS_PLUS = GM_INSTR_CLASS.copy() +GM_INSTR_CLASS_PLUS["Singing Voice"] = [100, 101] + +GM_INSTR_EXT_CLASS = { # Best for enjoyable MIDI file generation + "Acoustic Piano": [0, 1, 3, 6, 7], + "Electric Piano": [2, 4, 5], + "Chromatic Percussion": np.arange(8, 16), + "Organ": np.arange(16, 24), + "Guitar (clean)": np.arange(24, 28), + "Guitar (distortion)": [30, 28, 29, 31], # np.arange(28, 32), + "Bass": [33, 32, 34, 35, 36, 37, 38, 39], # np.arange(32, 40), + "Strings": [48, 40, 41, 42, 43, 44, 45, 46, 47, 49, 50, 51, 52, 53, 54, 55], # np.arange(40, 56), + "Brass": np.arange(56, 64), + "Reed": np.arange(64, 72), + "Pipe": np.arange(72, 80), + "Synth Lead": np.arange(80, 88), + "Synth Pad": np.arange(88, 96), +} +GM_INSTR_EXT_CLASS_PLUS = GM_INSTR_EXT_CLASS.copy() +GM_INSTR_EXT_CLASS_PLUS["Singing Voice"] = [100] +GM_INSTR_EXT_CLASS_PLUS["Singing Voice (chorus)"] = [101] + +GM_INSTR_FULL = { + "Acoustic Grand Piano": [0], + "Bright Acoustic Piano": [1], + "Electric Grand Piano": [2], + "Honky-tonk Piano": [3], + "Electric Piano 1": [4], + "Electric Piano 2": [5], + "Harpsichord": [6], + "Clavinet": [7], + "Celesta": [8], + "Glockenspiel": [9], + "Music Box": [10], + "Vibraphone": [11], + "Marimba": [12], + "Xylophone": [13], + "Tubular Bells": [14], + "Dulcimer": [15], + "Drawbar 
Organ": [16], + "Percussive Organ": [17], + "Rock Organ": [18], + "Church Organ": [19], + "Reed Organ": [20], + "Accordion": [21], + "Harmonica": [22], + "Tango Accordion": [23], + "Acoustic Guitar (nylon)": [24], + "Acoustic Guitar (steel)": [25], + "Electric Guitar (jazz)": [26], + "Electric Guitar (clean)": [27], + "Electric Guitar (muted)": [28], + "Overdriven Guitar": [29], + "Distortion Guitar": [30], + "Guitar Harmonics": [31], + "Acoustic Bass": [32], + "Electric Bass (finger)": [33], + "Electric Bass (pick)": [34], + "Fretless Bass": [35], + "Slap Bass 1": [36], + "Slap Bass 2": [37], + "Synth Bass 1": [38], + "Synth Bass 2": [39], + "Violin": [40], + "Viola": [41], + "Cello": [42], + "Contrabass": [43], + "Tremolo Strings": [44], + "Pizzicato Strings": [45], + "Orchestral Harp": [46], + "Timpani": [47], + "String Ensemble 1": [48], + "String Ensemble 2": [49], + "Synth Strings 1": [50], + "Synth Strings 2": [51], + "Choir Aahs": [52], + "Voice Oohs": [53], + "Synth Choir": [54], + "Orchestra Hit": [55], + "Trumpet": [56], + "Trombone": [57], + "Tuba": [58], + "Muted Trumpet": [59], + "French Horn": [60], + "Brass Section": [61], + "Synth Brass 1": [62], + "Synth Brass 2": [63], + "Soprano Sax": [64], + "Alto Sax": [65], + "Tenor Sax": [66], + "Baritone Sax": [67], + "Oboe": [68], + "English Horn": [69], + "Bassoon": [70], + "Clarinet": [71], + "Piccolo": [72], + "Flute": [73], + "Recorder": [74], + "Pan Flute": [75], + "Bottle Blow": [76], + "Shakuhachi": [77], + "Whistle": [78], + "Ocarina": [79], + "Lead 1 (square)": [80], + "Lead 2 (sawtooth)": [81], + "Lead 3 (calliope)": [82], + "Lead 4 (chiff)": [83], + "Lead 5 (charang)": [84], + "Lead 6 (voice)": [85], + "Lead 7 (fifths)": [86], + "Lead 8 (bass + lead)": [87], + "Pad 1 (new age)": [88], + "Pad 2 (warm)": [89], + "Pad 3 (polysynth)": [90], + "Pad 4 (choir)": [91], + "Pad 5 (bowed)": [92], + "Pad 6 (metallic)": [93], + "Pad 7 (halo)": [94], + "Pad 8 (sweep)": [95], + # "FX 1 (rain)": [96], + # "FX 2 (soundtrack)": [97], + # "FX 3 (crystal)": [98], + # "FX 4 (atmosphere)": [99], + # "FX 5 (brightness)": [100], + # "FX 6 (goblins)": [101], + # "FX 7 (echoes)": [102], + # "FX 8 (sci-fi)": [103], + # "Sitar": [104], + # "Banjo": [105], + # "Shamisen": [106], + # "Koto": [107], + # "Kalimba": [108], + # "Bagpipe": [109], + # "Fiddle": [110], + # "Shanai": [111], + # "Tinkle Bell": [112], + # "Agogo": [113], + # "Steel Drums": [114], + # "Woodblock": [115], + # "Taiko Drum": [116], + # "Melodic Tom": [117], + # "Synth Drum": [118], + # "Reverse Cymbal": [119], + # "Guitar Fret Noise": [120], + # "Breath Noise": [121], + # "Seashore": [122], + # "Bird Tweet": [123], + # "Telephone Ring": [124], + # "Helicopter": [125], + # "Applause": [126], + # "Gunshot": [127] +} + +MT3_FULL = { # this matches the class names in Table 3 of MT3 paper + "Acoustic Piano": [0, 1, 3, 6, 7], + "Electric Piano": [2, 4, 5], + "Chromatic Percussion": np.arange(8, 16), + "Organ": np.arange(16, 24), + "Acoustic Guitar": np.arange(24, 26), + "Clean Electric Guitar": np.arange(26, 29), + "Distorted Electric Guitar": np.arange(29, 32), + "Acoustic Bass": [32, 35], + "Electric Bass": [33, 34, 36, 37, 38, 39], + "Violin": [40], + "Viola": [41], + "Cello": [42], + "Contrabass": [43], + "Orchestral Harp": [46], + "Timpani": [47], + "String Ensemble": [48, 49, 44, 45], + "Synth Strings": [50, 51], + "Choir and Voice": [52, 53, 54], + "Orchestra Hit": [55], + "Trumpet": [56, 59], + "Trombone": [57], + "Tuba": [58], + "French Horn": [60], + "Brass Section": [61, 62, 
63], + "Soprano/Alto Sax": [64, 65], + "Tenor Sax": [66], + "Baritone Sax": [67], + "Oboe": [68], + "English Horn": [69], + "Bassoon": [70], + "Clarinet": [71], + "Pipe": [73, 72, 74, 75, 76, 77, 78, 79], + "Synth Lead": np.arange(80, 88), + "Synth Pad": np.arange(88, 96), +} + +MT3_FULL_PLUS = MT3_FULL.copy() +MT3_FULL_PLUS["Singing Voice"] = [100] +MT3_FULL_PLUS["Singing Voice (chorus)"] = [101] + +ENST_DRUM_NOTES = { + "bd": [36], # Kick Drum + "sd": [38], # Snare Drum + "sweep": [0], # Brush sweep + "sticks": [1], # Sticks + "rs": [2], # Rim shot + "cs": [37], # X-stick + "chh": [42], # Closed Hi-Hat + "ohh": [46], # Open Hi-Hat + "cb": [56], # Cowbell + "c": [3], # Other Cymbals + "lmt": [47], # Low Mid Tom + "mt": [48], # Mid Tom + "mtr": [58], # Mid Tom Rim + "lt": [45], # Low Tom + "ltr": [50], # Low Tom Rim + "lft": [41], # Low Floor Tom + "rc": [51], # Ride Cymbal + "ch": [52], # Chinese Cymbal + "cr": [49], # Crash Cymbal + "spl": [55], # Splash Cymbal +} + +EGMD_DRUM_NOTES = { + "Kick Drum": [36], # Listed by order of most common annotation + "Snare X-stick": [37], # Snare X-Stick, https://youtu.be/a2KFrrKaoYU?t=80 + "Snare Drum": [38], # Snare (head) and Electric Snare + "Closed Hi-Hat": [42, 44, 22], # 44 is pedal hi-hat + "Open Hi-Hat": [46, 26], + "Cowbell": [56], + "High Floor Tom": [43], + "Low Floor Tom": [41], # Lowest Tom + "Low Tom": [45], + "Low-Mid Tom": [47], + "Mid Tom": [48], + "Low Tom (Rim)": [50], # TD-17: 47, 50, 58 + "Mid Tom (Rim)": [58], + # "Ride Cymbal": [51, 53, 59], + "Ride": [51], + "Ride (Bell)": [53], # https://youtu.be/b94hZoM5s3k?t=323 + "Ride (Edge)": [59], + "Chinese Cymbal": [52], + "Crash Cymbal": [49, 57], + "Splash Cymbal": [55], +} + +# Inspired by Roland TD-17 MIDI note map, https://rolandus.zendesk.com/hc/en-us/articles/360005173411-TD-17-Default-Factory-MIDI-Note-Map +GM_DRUM_NOTES = { + "Kick Drum": [36, 35], # Listed by order of most common annotation + "Snare X-stick": [37, 2], # Snare X-Stick, https://youtu.be/a2KFrrKaoYU?t=80 + "Snare Drum": [38, 40], # Snare (head) and Electric Snare + "Closed Hi-Hat": [42, 44, 22], # 44 is pedal hi-hat + "Open Hi-Hat": [46, 26], + "Cowbell": [56], + "High Floor Tom": [43], + "Low Floor Tom": [41], # Lowest Tom + "Low Tom": [45], + "Low-Mid Tom": [47], + "Mid Tom": [48], + "Low Tom (Rim)": [50], # TD-17: 47, 50, 58 + "Mid Tom (Rim)": [58], + # "Ride Cymbal": [51, 53, 59], + "Ride": [51], + "Ride (Bell)": [53], # https://youtu.be/b94hZoM5s3k?t=323 + "Ride (Edge)": [59], + "Chinese Cymbal": [52], + "Crash Cymbal": [49, 57], + "Splash Cymbal": [55], +} + +KICK_SNARE_HIHAT = { + "Kick Drum": [36, 35], + "Snare Drum": [38, 40], + # "Snare Drum + X-Stick": [38, 40, 37, 2], + # "Snare X-stick": [37, 2], # Snare X-Stick, https://youtu.be/a2KFrrKaoYU?t=80 + "Hi-Hat": [42, 44, 46, 22, 26], + # "Ride Cymbal": [51, 53, 59], + # "Hi-Hat + Ride": [42, 44, 46, 22, 26, 51, 53, 59], + # "HiHat + all Cymbals": [42, 44, 46, 22, 26, 51, 53, 59, 52, 49, 57, 55], + # "Kick Drum + Low Tom": [36, 35, 45], + # "All Cymbal": [51, 53, 59, 52, 49, 57, 55] + # "all": np.arange(30, 60) +} + +drum_vocab_presets = { + "gm": GM_DRUM_NOTES, + "egmd": EGMD_DRUM_NOTES, + "enst": ENST_DRUM_NOTES, + "ksh": KICK_SNARE_HIHAT, + "kshr": { + "Kick Drum": [36, 35], + "Snare Drum": [38, 40], + "Hi-Hat": [42, 44, 46, 22, 26, 51, 53, 59], + } +} + +program_vocab_presets = { + "gm_full": GM_INSTR_FULL, # 96 classes (except drums) + "mt3_full": MT3_FULL, # 34 classes (except drums) as in MT3 paper + "mt3_midi": GM_INSTR_CLASS, # 11 classes 
(except drums) as in MT3 paper + "mt3_midi_plus": GM_INSTR_CLASS_PLUS, # 11 classes + singing (except drums) + "mt3_full_plus": MT3_FULL_PLUS, # 34 classes (except drums) mt3_full + singing (except drums) + "gm": GM_INSTR_CLASS, # 11 classes (except drums) + "gm_plus": GM_INSTR_CLASS_PLUS, # 11 classes + singing (except drums) + "gm_ext_plus": GM_INSTR_EXT_CLASS_PLUS, # 13 classes + singing + chorus (except drums) +} diff --git a/amt/src/extras/.DS_Store b/amt/src/extras/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..e3902033640d68d922c0d9020010e49dceda5a29 Binary files /dev/null and b/amt/src/extras/.DS_Store differ diff --git a/amt/src/extras/Dockerfile b/amt/src/extras/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..95cc82a62cf20cba34fd2935c949a1e625e91b83 --- /dev/null +++ b/amt/src/extras/Dockerfile @@ -0,0 +1,18 @@ +FROM pytorch/pytorch:2.0.1-cuda11.7-cudnn8-devel +LABEL maintainer="https://github.com/mimbres/YourMT3" + +ENV TZ=Europe/London \ + DEBIAN_FRONTEND=noninteractive +RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone + +RUN apt-get update +ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 + +RUN apt-get update --fix-missing && apt-get install -y wget curl \ + nano git ffmpeg sox tmux htop +RUN pip3 install --upgrade pip +RUN pip3 install mirdata mido git+https://github.com/craffel/mir_eval.git \ + matplotlib lightning>=2.0.2 pytest-timeout pytest deprecated librosa \ + einops transformers wandb + +CMD [ "/bin/bash" ] \ No newline at end of file diff --git a/amt/src/extras/check_drum_channel_slakh.py b/amt/src/extras/check_drum_channel_slakh.py new file mode 100644 index 0000000000000000000000000000000000000000..fd89011bfe7465f6a35a62ff8d650f2269c24212 --- /dev/null +++ b/amt/src/extras/check_drum_channel_slakh.py @@ -0,0 +1,24 @@ +from utils.mirdata_dev.datasets import slakh16k + + +def check_drum_channel_slakh(data_home: str): + ds = slakh16k.Dataset(data_home, version='default') + for track_id in ds.track_ids: + is_drum = ds.track(track_id).is_drum + midi = MidiFile(ds.track(track_id).midi_path) + cnt = 0 + for msg in midi: + if 'note' in msg.type: + if is_drum and (msg.channel != 9): + print('found drum track with channel != 9 in track_id: ', + track_id) + if not is_drum and (msg.channel == 9): + print( + 'found non-drum track with channel == 9 in track_id: ', + track_id) + if is_drum and (msg.channel == 9): + cnt += 1 + if cnt > 0: + print(f'found {cnt} notes in drum track with ch 9 in track_id: ', + track_id) + return \ No newline at end of file diff --git a/amt/src/extras/dataset_mutable_var_sanity_check.py b/amt/src/extras/dataset_mutable_var_sanity_check.py new file mode 100644 index 0000000000000000000000000000000000000000..e6a9dd5c7e91b209e435de5a1a37ea9f22ef80e0 --- /dev/null +++ b/amt/src/extras/dataset_mutable_var_sanity_check.py @@ -0,0 +1,81 @@ +for n in range(1000): + sampled_data = ds.__getitem__(n) + + a = deepcopy(sampled_data['note_event_segments']) + b = deepcopy(sampled_data['note_event_segments']) + + for (note_events, tie_note_events, start_time) in list(zip(*b.values())): + note_events = pitch_shift_note_events(note_events, 2) + tie_note_events = pitch_shift_note_events(tie_note_events, 2) + + # compare + for i, (note_events, tie_note_events, start_time) in enumerate(list(zip(*b.values()))): + for j, ne in enumerate(note_events): + if ne.is_drum is False: + if ne.pitch != a['note_events'][i][j].pitch + 2: + print(i, j) + assert ne.pitch == a['note_events'][i][j].pitch + 2 + + 
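        # The same +2 pitch-shift check is applied below to the tie note events
        # (notes already sounding at the start of the segment), comparing the
        # shifted copy `b` against the untouched copy `a`.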
for k, tne in enumerate(tie_note_events): + assert tne.pitch == a['tie_note_events'][i][k].pitch + 2 + + print('test {} passed'.format(n)) + + +def assert_note_events_almost_equal(actual_note_events, + predicted_note_events, + ignore_time=False, + ignore_activity=True, + delta=5.1e-3): + """ + Asserts that the given lists of Note instances are equal up to a small + floating-point tolerance, similar to `assertAlmostEqual` of `unittest`. + Tolerance is 5e-3 by default, which is 5 ms for 100 ticks-per-second. + + If `ignore_time` is True, then the time field is ignored. (useful for + comparing tie note events, default is False) + + If `ignore_activity` is True, then the activity field is ignored (default + is True). + """ + assert len(actual_note_events) == len(predicted_note_events) + for j, (actual_note_event, + predicted_note_event) in enumerate(zip(actual_note_events, predicted_note_events)): + if ignore_time is False: + assert abs(actual_note_event.time - predicted_note_event.time) <= delta + assert actual_note_event.is_drum == predicted_note_event.is_drum + if actual_note_event.is_drum is False and predicted_note_event.is_drum is False: + assert actual_note_event.program == predicted_note_event.program + assert actual_note_event.pitch == predicted_note_event.pitch + assert actual_note_event.velocity == predicted_note_event.velocity + if ignore_activity is False: + assert actual_note_event.activity == predicted_note_event.activity + + +cache_old = deepcopy(dict(ds.cache)) +for n in range(500): + sampled_data = ds.__getitem__(n) + cache_new = ds.cache + cnt = 0 + for k, v in cache_new.items(): + if k in cache_old: + cnt += 1 + assert (cache_new[k]['programs'] == cache_old[k]['programs']).all() + assert (cache_new[k]['is_drum'] == cache_old[k]['is_drum']).all() + assert (cache_new[k]['has_stems'] == cache_old[k]['has_stems']) + assert (cache_new[k]['has_unannotated'] == cache_old[k]['has_unannotated']) + assert (cache_new[k]['audio_array'] == cache_old[k]['audio_array']).all() + + for nes_new, nes_old in zip(cache_new[k]['note_event_segments']['note_events'], + cache_old[k]['note_event_segments']['note_events']): + assert_note_events_almost_equal(nes_new, nes_old) + + for tnes_new, tnes_old in zip(cache_new[k]['note_event_segments']['tie_note_events'], + cache_old[k]['note_event_segments']['tie_note_events']): + assert_note_events_almost_equal(tnes_new, tnes_old, ignore_time=True) + + for s_new, s_old in zip(cache_new[k]['note_event_segments']['start_times'], + cache_old[k]['note_event_segments']['start_times']): + assert s_new == s_old + cache_old = deepcopy(dict(ds.cache)) + print(n, cnt) diff --git a/amt/src/extras/datasets_eval_testing.py b/amt/src/extras/datasets_eval_testing.py new file mode 100644 index 0000000000000000000000000000000000000000..e70805d7dcec406a8038bda1c359f5f4d731bd38 --- /dev/null +++ b/amt/src/extras/datasets_eval_testing.py @@ -0,0 +1,42 @@ +from utils.datasets_eval import AudioFileDataset +from torch.utils.data import DataLoader +import pytorch_lightning as pl + + +def test(): + + ds = AudioFileDataset() + dl = DataLoader( + ds, batch_size=None, collate_fn=lambda k: k + ) # empty collate_fn is required to use mixed types. 
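    # batch_size=None disables the DataLoader's automatic batching, and the
    # identity collate_fn returns each sample unchanged, so samples that mix
    # tensors with strings/metadata are passed through without stacking.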
+ + for x, y in dl: + break + + class MyModel(pl.LightningModule): + + def __init__(self, **kwargs): + super().__init__() + + def forward(self, x): + return x + + def training_step(self, batch, batch_idx): + return 0 + + def validation_step(self, batch, batch_idx): + print(batch) + return 0 + + def train_dataloader(self): + return dl + + def val_dataloader(self): + return dl + + def configure_optimizers(self): + return None + + model = MyModel() + trainer = pl.Trainer() + trainer.validate(model) \ No newline at end of file diff --git a/amt/src/extras/demo_cross_augmentation.py b/amt/src/extras/demo_cross_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..fb21d608f09c72f053180ba5470fb2e8ac16998b --- /dev/null +++ b/amt/src/extras/demo_cross_augmentation.py @@ -0,0 +1,69 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +from typing import Dict, Tuple +from copy import deepcopy +import soundfile as sf +import torch +from utils.data_modules import AMTDataModule +from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg +from utils.augment import intra_stem_augment_processor + + +def get_ds(data_preset_multi: Dict, train_num_samples_per_epoch: int = 90000): + dm = AMTDataModule(data_preset_multi=data_preset_multi, train_num_samples_per_epoch=train_num_samples_per_epoch) + dm.setup('fit') + dl = dm.train_dataloader() + ds = dl.flattened[0].dataset + return ds + + +def debug_func(num_segments: int = 10): + sampled_data, sampled_ids = ds._get_rand_segments_from_cache(num_segments) + ux_sampled_data, _ = ds._get_rand_segments_from_cache(ux_count_sum, False, sampled_ids) + s = deepcopy(sampled_data) + intra_stem_augment_processor(sampled_data, submix_audio=False) + + +def gen_audio(index: int = 0): + # audio_arr: (b, 1, nframe), note_token_arr: (b, l), task_token_arr: (b, task_l) + audio_arr, note_token_arr, task_token_arr = ds.__getitem__(index) + + # merge all the segments into one audio file + audio = audio_arr.permute(0, 2, 1).reshape(-1).squeeze().numpy() + + # save the audio file + sf.write('xaug_demo_audio.wav', audio, 16000, subtype='PCM_16') + + +data_preset_multi = data_preset_multi_cfg["all_cross_rebal5"] +ds = get_ds(data_preset_multi) +ds.random_amp_range = [0.8, 1.1] +ds.stem_xaug_policy = { + "max_k": 5, + "tau": 0.3, + "alpha": 1.0, + "max_subunit_stems": 12, + "no_instr_overlap": True, + "no_drum_overlap": True, + "uhat_intra_stem_augment": True, +} +gen_audio(3) + +# for k in ds.cache.keys(): +# arr = ds.cache[k]['audio_array'] +# arr = np.sum(arr, axis=1).reshape(-1) +# # sf.write(f'xxx/{k}.wav', arr, 16000, subtype='PCM_16') +# if np.min(arr) > -0.5: +# print(k) + +# arr = ds.cache[52]['audio_array'] +# for i in range(arr.shape[1]): +# a = arr[:, i, :].reshape(-1) +# sf.write(f'xxx52/52_{i}.wav', a, 16000, subtype='PCM_16') diff --git a/amt/src/extras/demo_intra_augmentation.py b/amt/src/extras/demo_intra_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..d9b2ba36e0f945c771d1e0336d8c0ce287c011be --- /dev/null +++ b/amt/src/extras/demo_intra_augmentation.py @@ -0,0 +1,52 @@ +# Copyright 2024 The YourMT3 Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +import numpy as np +import torch +import json +import soundfile as sf +from utils.datasets_train import get_cache_data_loader + + +def get_filelist(track_id: int) -> dict: + filelist = '../../data/yourmt3_indexes/slakh_train_file_list.json' + with open(filelist, 'r') as f: + fl = json.load(f) + new_filelist = dict() + for key, value in fl.items(): + if int(key) == track_id: + new_filelist[0] = value + return new_filelist + + +def get_ds(track_id: int, random_amp_range: list = [1., 1.], stem_aug_prob: float = 0.8): + filelist = get_filelist(track_id) + dl = get_cache_data_loader(filelist, + 'train', + 1, + 1, + random_amp_range=random_amp_range, + stem_aug_prob=stem_aug_prob, + shuffle=False) + ds = dl.dataset + return ds + + +def gen_audio(track_id: int, n_segments: int = 30, random_amp_range: list = [1., 1.], stem_aug_prob: float = 0.8): + ds = get_ds(track_id, random_amp_range, stem_aug_prob) + audio = [] + for i in range(n_segments): + audio.append(ds.__getitem__(0)[0]) + # audio.append(ds.__getitem__(i)[0]) + + audio = torch.concat(audio, dim=2).numpy()[0, 0, :] + sf.write('audio.wav', audio, 16000, subtype='PCM_16') + + +gen_audio(1, 20) diff --git a/amt/src/extras/download_mirst500.py b/amt/src/extras/download_mirst500.py new file mode 100644 index 0000000000000000000000000000000000000000..9c79736ecf3594e646730c7c7ef511c43f5c7d5b --- /dev/null +++ b/amt/src/extras/download_mirst500.py @@ -0,0 +1,50 @@ +import os +import json +import numpy as np +from pytube import YouTube + + +def downloadMp3(yt, idx, askPath=0): + # extract only audio + video = yt.streams.filter(only_audio=True).first() + + destination = 'mp3File' + # check for destination to save file + if (askPath == 1): + print("Enter the destination (leave blank for default dir mp3File)") + destination = str(input(">> ")) or 'mp3File' + + # download the file + out_file = video.download(output_path=destination) + + # save the file + # base, ext = os.path.splitext(out_file) + dir_path, file_base = os.path.split(out_file) + + new_file = os.path.join(dir_path, f'{idx}.mp3') + os.rename(out_file, new_file) + # result of success + print(yt.title + " has been successfully downloaded.") + + +MISSING_FILE_IDS = [ + 16, 26, 33, 38, 40, 50, 53, 55, 60, 81, 82, 98, 107, 122, 126, 127, 129, 141, 145, 150, 172, + 201, 205, 206, 215, 216, 221, 226, 232, 240, 243, 245, 255, 257, 267, 273, 278, 279, 285, 287, + 291, 304, 312, 319, 321, 325, 329, 332, 333, 336, 337, 342, 359, 375, 402, 417, 438, 445, 454, + 498 +] + +data_link_file = '../../../data/mir_St500_yourmt3_16k/MIR-ST500_20210206/MIR-ST500_link.json' +data_link = json.load(open(data_link_file, 'r')) +download_fail = [] + +for i in MISSING_FILE_IDS: + print(f'Downloading {i}...') + yt = YouTube(data_link[str(i)]) + try: + downloadMp3(yt, idx=i) + except: + download_fail.append(i) + print(f'Failed to download {i}.') + +print(f'Failed to download {len(download_fail)} files: {download_fail}') \ No newline at end of file diff --git a/amt/src/extras/fig/label_smooth_interval_of_interest.png b/amt/src/extras/fig/label_smooth_interval_of_interest.png new file mode 100644 index 0000000000000000000000000000000000000000..0df751cbc10cd50a91e520e934ef597e36c60b79 Binary files /dev/null and 
b/amt/src/extras/fig/label_smooth_interval_of_interest.png differ diff --git a/amt/src/extras/fig/pitchshift_benchnmark.png b/amt/src/extras/fig/pitchshift_benchnmark.png new file mode 100644 index 0000000000000000000000000000000000000000..4ceed96c858172dd5b87643e2215f8541d936cc4 Binary files /dev/null and b/amt/src/extras/fig/pitchshift_benchnmark.png differ diff --git a/amt/src/extras/fig/pitchshift_stretch_and_resampler_process_time.png b/amt/src/extras/fig/pitchshift_stretch_and_resampler_process_time.png new file mode 100644 index 0000000000000000000000000000000000000000..40bd49f6906558af08bed5371debcc545a7e8637 Binary files /dev/null and b/amt/src/extras/fig/pitchshift_stretch_and_resampler_process_time.png differ diff --git a/amt/src/extras/inspecting_slakh_bass.py b/amt/src/extras/inspecting_slakh_bass.py new file mode 100644 index 0000000000000000000000000000000000000000..e86e517896630078007022508ec28b6ab20800fb --- /dev/null +++ b/amt/src/extras/inspecting_slakh_bass.py @@ -0,0 +1,34 @@ +import mirdata +from utils.mirdata_dev.datasets import slakh16k + +ds = slakh16k.Dataset(data_home='../../data', version='2100-yourmt3-16k') +mtrack_ids = ds.mtrack_ids + +# Collect plugin names +plugin_names = set() +cnt = 0 +for mtrack_id in mtrack_ids: + mtrack = ds.multitrack(mtrack_id) + for track_id in mtrack.track_ids: + track = ds.track(track_id) + if track.instrument.lower() == 'bass': + if track.plugin_name == 'upright_bass.nkm': + print(f'{str(cnt)}: {track_id}: {track.plugin_name}') + # if track.plugin_name not in plugin_names: + # plugin_names.add(track.plugin_name) + # print(f'{str(cnt)}: {track_id}: {track.plugin_name}') + # cnt += 1 +""" +0: Track00001-S03: scarbee_rickenbacker_bass_palm_muted.nkm +1: Track00002-S01: classic_bass.nkm +2: Track00004-S01: scarbee_rickenbacker_bass.nkm +3: Track00005-S04: scarbee_jay_bass_both.nkm +4: Track00006-S03: pop_bass.nkm +5: Track00008-S00: scarbee_pre_bass.nkm +6: Track00013-S00: jazz_upright.nkm +7: Track00014-S01: funk_bass.nkm +8: Track00016-S01: scarbee_mm_bass.nkm +9: Track00024-S07: upright_bass.nkm +10: Track00027-S03: scarbee_jay_bass_slap_both.nkm +11: Track00094-S08: upright_bass2.nkm +""" \ No newline at end of file diff --git a/amt/src/extras/install_deepspeed.md b/amt/src/extras/install_deepspeed.md new file mode 100644 index 0000000000000000000000000000000000000000..25921090d2a29f3f6117a0a0a85a6423b0335278 --- /dev/null +++ b/amt/src/extras/install_deepspeed.md @@ -0,0 +1,28 @@ +""" + +# not required on pytorch 2.0:latest container +pip install cupy-cuda11x -f https://pip.cupy.dev/aarch64 + +apt-get update +apt-get install git +apt-get install libaio-dev + +DS_BUILD_OPS=1 pip install deepspeed +ds_report + + +pip install deepspeed==0.7.7 + +git clone https://github.com/NVIDIA/apex +cd apex +pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./ + +In case you have trouble building apex from source we recommend using the NGC containers + from here which come with a pre-built PyTorch and apex release. 
+ +nvcr.io/nvidia/pytorch:23.01-py3 + +pip install deepspeed, pip install transformers[deepspeed] +https://www.deepspeed.ai/docs/config-json/#autotuning + +""" diff --git a/amt/src/extras/label_smoothing.py b/amt/src/extras/label_smoothing.py new file mode 100644 index 0000000000000000000000000000000000000000..a7e052dd808238e2d6e8eab0353a1b67c481a3a5 --- /dev/null +++ b/amt/src/extras/label_smoothing.py @@ -0,0 +1,67 @@ +import torch +import numpy as np +import matplotlib.pyplot as plt + +a = torch.signal.windows.gaussian(11, sym=True, std=3) +plt.plot(a) + + +def gaussian_smoothing(y_hot, mu=5, sigma=0.865): + """ + y_hot: one-hot encoded array + """ + #sigma = np.sqrt(np.abs(np.log(0.05) / ((4 - mu)**2))) / 2 + + # Generate index array + i = np.arange(len(y_hot)) + + # Gaussian function + y_smooth = np.exp(-(i - mu)**2 / (2 * sigma**2)) + + # Normalize the resulting array + y_smooth /= y_smooth.sum() + return y_smooth, sigma + + +# y_ls = (1 - α) * y_hot + α / K, where K is the number of classes, alpha is the smoothing parameter + +y_hot = torch.zeros(11) +y_hot[5] = 1 +plt.plot(y_hot, 'b.-') + +alpha = 0.3 +y_ls = (1 - alpha) * y_hot + alpha / 10 +plt.plot(y_ls, 'r.-') + +y_gs, std = gaussian_smoothing(y_hot, A=0.5) +plt.plot(y_gs, 'g.-') + +y_gst_a, std = gaussian_smoothing(y_hot, A=0.5, mu=5.5) +plt.plot(y_gst_a, 'y.-') + +y_gst_b, std = gaussian_smoothing(y_hot, A=0.5, mu=5.8) +plt.plot(y_gst_b, 'c.-') + +plt.legend([ + 'y_hot', 'label smoothing' + '\n' + '(alpha=0.3)', + 'gaussian smoothing' + '\n' + 'for interval of interest' + '\n' + 'mu=5', + 'gaussian smoothing' + '\n' + 'mu=5.5', 'gaussian smoothing' + '\n' + 'mu=5.8' +]) + +plt.grid() +plt.xticks(np.arange(11), np.arange(0, 110, 10)) +plt.xlabel('''Time (ms) +original (quantized) one hot label: +[0,0,0,0,0,1,0,0,0,0,0] +\n +label smooting is defined as: + y_ls = (1 - α) * y_hot + α / K, +where K is the number of classes, α is the smoothing parameter +\n +gaussian smoothing for the interval (± 10ms) of interest: +y_gs = A * exp(-(i - mu)**2 / (2 * sigma**2)) +with sigma = 0.865 an mu = 5 +\n +gaussian smoothing with unqunatized target timing: +mu = 5.5 for 55ms target timing +''') diff --git a/amt/src/extras/multi_channel_seqlen_stats.py b/amt/src/extras/multi_channel_seqlen_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..52804eb622653f776f1e1a89b43a263dd72e14a4 --- /dev/null +++ b/amt/src/extras/multi_channel_seqlen_stats.py @@ -0,0 +1,177 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. 
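# [Editor's note on extras/label_smoothing.py above] The script calls
# gaussian_smoothing(y_hot, A=0.5, mu=...), but the function as defined only
# accepts (y_hot, mu, sigma), so those calls would raise a TypeError. Below is
# a minimal sketch of a signature that matches the call sites and the formula
# quoted in the plot label, y_gs = A * exp(-(i - mu)**2 / (2 * sigma**2)); the
# amplitude handling is an assumption, not the authors' code.
import numpy as np


def gaussian_smoothing_sketch(y_hot, mu=5, sigma=0.865, A=1.0):
    """Unnormalized Gaussian target centered at index `mu`; its peak value is A."""
    i = np.arange(len(y_hot))
    y_gs = A * np.exp(-(i - mu)**2 / (2 * sigma**2))
    return y_gs, sigma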
+from typing import Dict, Tuple +from copy import deepcopy +from collections import Counter +import numpy as np +import torch +from utils.data_modules import AMTDataModule +from utils.task_manager import TaskManager +from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg +from utils.augment import intra_stem_augment_processor + + +def get_ds(data_preset_multi: Dict, task_name: str, train_num_samples_per_epoch: int = 90000): + tm = TaskManager(task_name=task_name) + tm.max_note_token_length_per_ch = 1024 # only to check the max length + dm = AMTDataModule(data_preset_multi=data_preset_multi, + task_manager=tm, + train_num_samples_per_epoch=train_num_samples_per_epoch) + dm.setup('fit') + dl = dm.train_dataloader() + ds = dl.flattened[0].dataset + return ds + + +data_preset_multi = data_preset_multi_cfg["all_cross_v6"] +task_name = "mc13" # "mt3_full_plus" +ds = get_ds(data_preset_multi, task_name=task_name) +ds.random_amp_range = [0.8, 1.1] +ds.stem_xaug_policy = { + "max_k": 5, + "tau": 0.3, + "alpha": 1.0, + "max_subunit_stems": 12, + "no_instr_overlap": True, + "no_drum_overlap": True, + "uhat_intra_stem_augment": True, +} + +length_all = [] +for i in range(40000): + if i % 5000 == 0: + print(i) + audio_arr, note_token_arr, task_totken_arr, pshift_steps = ds.__getitem__(i) + lengths = torch.sum(note_token_arr != 0, dim=2).flatten().cpu().tolist() + length_all.extend(lengths) + +length_all = np.asarray(length_all) + +# stats +empty_sequence = np.sum(length_all < 3) / len(length_all) * 100 +print("empty_sequences:", f"{empty_sequence:.2f}", "%") + +mean_except_empty = np.mean(length_all[length_all > 2]) +print("mean_except_empty:", mean_except_empty) + +median_except_empty = np.median(length_all[length_all > 2]) +print("median_except_empty:", median_except_empty) + +ch_less_than_768 = np.sum(length_all < 768) / len(length_all) * 100 +print("ch_less_than_768:", f"{ch_less_than_768:.2f}", "%") + +ch_larger_than_512 = np.sum(length_all > 512) / len(length_all) * 100 +print("ch_larger_than_512:", f"{ch_larger_than_512:.6f}", "%") + +ch_larger_than_256 = np.sum(length_all > 256) / len(length_all) * 100 +print("ch_larger_than_256:", f"{ch_larger_than_256:.6f}", "%") + +ch_larger_than_128 = np.sum(length_all > 128) / len(length_all) * 100 +print("ch_larger_than_128:", f"{ch_larger_than_128:.6f}", "%") + +ch_larger_than_64 = np.sum(length_all > 64) / len(length_all) * 100 +print("ch_larger_than_64:", f"{ch_larger_than_64:.6f}", "%") + +song_length_all = length_all.reshape(-1, 13) +song_larger_than_512 = 0 +song_larger_than_256 = 0 +song_larger_than_128 = 0 +song_larger_than_64 = 0 +for l in song_length_all: + if np.sum(l > 512) > 0: + song_larger_than_512 += 1 + if np.sum(l > 256) > 0: + song_larger_than_256 += 1 + if np.sum(l > 128) > 0: + song_larger_than_128 += 1 + if np.sum(l > 64) > 0: + song_larger_than_64 += 1 +num_songs = len(song_length_all) +print("song_larger_than_512:", f"{song_larger_than_512/num_songs*100:.4f}", "%") +print("song_larger_than_256:", f"{song_larger_than_256/num_songs*100:.4f}", "%") +print("song_larger_than_128:", f"{song_larger_than_128/num_songs*100:.4f}", "%") +print("song_larger_than_64:", f"{song_larger_than_64/num_songs*100:.4f}", "%") + +instr_dict = { + 0: "Piano", + 1: "Chromatic Percussion", + 2: "Organ", + 3: "Guitar", + 4: "Bass", + 5: "Strings + Ensemble", + 6: "Brass", + 7: "Reed", + 8: "Pipe", + 9: "Synth Lead", + 10: "Synth Pad", + 11: "Singing", + 12: "Drums", +} +cnt_larger_than_512 = Counter() +for i in np.where(length_all > 
512)[0] % 13: + cnt_larger_than_512[i] += 1 +print("larger_than_512:") +for k, v in cnt_larger_than_512.items(): + print(f" - {instr_dict[k]}: {v}") + +cnt_larger_than_256 = Counter() +for i in np.where(length_all > 256)[0] % 13: + cnt_larger_than_256[i] += 1 +print("larger_than_256:") +for k, v in cnt_larger_than_256.items(): + print(f" - {instr_dict[k]}: {v}") + +cnt_larger_than_128 = Counter() +for i in np.where(length_all > 128)[0] % 13: + cnt_larger_than_128[i] += 1 +print("larger_than_128:") +for k, v in cnt_larger_than_128.items(): + print(f" - {instr_dict[k]}: {v}") +""" +empty_sequences: 91.06 % +mean_except_empty: 36.68976799156269 +median_except_empty: 31.0 +ch_less_than_768: 100.00 % +ch_larger_than_512: 0.000158 % +ch_larger_than_256: 0.015132 % +ch_larger_than_128: 0.192061 % +ch_larger_than_64: 0.661260 % +song_larger_than_512: 0.0021 % +song_larger_than_256: 0.1926 % +song_larger_than_128: 2.2280 % +song_larger_than_64: 6.1033 % + +larger_than_512: + - Guitar: 7 + - Strings + Ensemble: 3 +larger_than_256: + - Piano: 177 + - Guitar: 680 + - Strings + Ensemble: 79 + - Organ: 2 + - Chromatic Percussion: 11 + - Bass: 1 + - Synth Lead: 2 + - Brass: 1 + - Reed: 5 +larger_than_128: + - Guitar: 4711 + - Strings + Ensemble: 1280 + - Piano: 5548 + - Bass: 211 + - Synth Pad: 22 + - Pipe: 18 + - Chromatic Percussion: 55 + - Synth Lead: 22 + - Organ: 75 + - Reed: 161 + - Brass: 45 + - Drums: 11 +""" diff --git a/amt/src/extras/npy_speed_benchmark.py b/amt/src/extras/npy_speed_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f311b67932fdc3276c72e63e54ae11efb62770 --- /dev/null +++ b/amt/src/extras/npy_speed_benchmark.py @@ -0,0 +1,187 @@ +import os +from tasks.utils.event_codec import Event, EventRange +from tasks.utils import event_codec + +ec = event_codec.Codec( + max_shift_steps=1000, # this means 0,1,...,1000 + steps_per_second=100, + event_ranges=[ + EventRange('pitch', min_value=0, max_value=127), + EventRange('velocity', min_value=0, max_value=1), + EventRange('tie', min_value=0, max_value=0), + EventRange('program', min_value=0, max_value=127), + EventRange('drum', min_value=0, max_value=127), + ], + ) + +events = [ + Event(type='shift', value=0), # actually not needed + Event(type='shift', value=1), # 10 ms shift + Event(type='shift', value=1000), # 10 s shift + Event(type='pitch', value=0), # lowest pitch 8.18 Hz + Event(type='pitch', value=60), # C4 or 261.63 Hz + Event(type='pitch', value=127), # highest pitch G9 or 12543.85 Hz + Event(type='velocity', value=0), # lowest velocity) + Event(type='velocity', value=1), # lowest velocity) + Event(type='tie', value=0), # tie + Event(type='program', value=0), # program + Event(type='program', value=127), # program + Event(type='drum', value=0), # drum + Event(type='drum', value=127), # drum +] + +events = events * 100 +tokens = [ec.encode_event(e) for e in events] +tokens = np.array(tokens, dtype=np.int16) + +import csv +# Save events to a CSV file +with open('events.csv', 'w', newline='') as file: + writer = csv.writer(file) + for event in events: + writer.writerow([event.type, event.value]) + +# Load events from a CSV file +with open('events.csv', 'r') as file: + reader = csv.reader(file) + events2 = [Event(row[0], int(row[1])) for row in reader] + + +import json +# Save events to a JSON file +with open('events.json', 'w') as file: + json.dump([event.__dict__ for event in events], file) + +# Load events from a JSON file +with open('events.json', 'r') as file: + events = [Event(**event_dict) for 
event_dict in json.load(file)] + + + + +"""----------------------------""" +# Write the tokens to a npy file +import numpy as np +np.save('tokens.npy', tokens) + +def t_npy(): + t = np.load('tokens.npy', allow_pickle=True) # allow pickle doesn't affect speed + +os.makedirs('temp', exist_ok=True) +for i in range(2400): + np.save(f'temp/tokens{i}.npy', tokens) + +def t_npy2400(): + for i in range(2400): + t = np.load(f'temp/tokens{i}.npy') +def t_npy2400_take200(): + for i in range(200): + t = np.load(f'temp/tokens{i}.npy') + +import shutil +shutil.rmtree('temp', ignore_errors=True) + +# Write the 2400 tokens to a single npy file +data = dict() +for i in range(2400): + data[f'arr{i}'] = tokens.copy() +np.save(f'tokens_2400x.npy', data) +def t_npy2400single(): + t = np.load('tokens_2400x.npy', allow_pickle=True).item() + +def t_mmap2400single(): + t = np.load('tokens_2400x.npy', mmap_mode='r') + +# Write the tokens to a npz file +np.savez('tokens.npz', arr0=tokens) +def t_npz(): + npz_file = np.load('tokens.npz') + tt = npz_file['arr0'] + +data = dict() +for i in range(2400): + data[f'arr{i}'] = tokens +np.savez('tokens.npz', **data ) +def t_npz2400(): + npz_file = np.load('tokens.npz') + for i in range(2400): + tt = npz_file[f'arr{i}'] + +def t_npz2400_take200(): + npz_file = np.load('tokens.npz') + # npz_file.files + for i in range(200): + tt = npz_file[f'arr{i}'] + + +# Write the tokens to a txt file +with open('tokens.txt', 'w') as file: + file.write(' '.join(map(str, tokens))) + +def t_txt(): + # Read the tokens from the file + with open('tokens.txt', 'r') as file: + t = list(map(int, file.read().split())) + t = np.array(t) + + +# Write the tokens to a CSV file +with open('tokens.csv', 'w', newline='') as file: + writer = csv.writer(file) + writer.writerow(tokens) + +def t_csv(): + # Read the tokens from the CSV file + with open('tokens.csv', 'r') as file: + reader = csv.reader(file) + t = list(map(int, next(reader))) + t = np.array(t) + + +# Write the tokens to a JSON file +with open('tokens.json', 'w') as file: + json.dump(tokens, file) + +def t_json(): + # Read the tokens from the JSON file + with open('tokens.json', 'r') as file: + t = json.load(file) + t = np.array(t) + +with open('tokens_2400x.json', 'w') as file: + json.dump(data, file) + +def t_json2400single(): + # Read the tokens from the JSON file + with open('tokens_2400x.json', 'r') as file: + t = json.load(file) + +def t_mmap(): + t = np.load('tokens.npy', mmap_mode='r') + +# Write the tokens to bytes file + + + + +np.savetxt('tokens.ntxt', tokens) +def t_ntxt(): + t = np.loadtxt('tokens.ntxt').astype(np.int32) + +%timeit t_npz() # 139 us +%timeit t_mmap() # 3.12 ms +%timeit t_npy() # 87.8 us +%timeit t_txt() # 109 152 us +%timeit t_csv() # 145 190 us +%timeit t_json() # 72.8 119 us +%timeit t_ntxt() # 878 us + +%timeit t_npy2400() # 212 ms; 2400 files in a folder +%timeit t_npz2400() # 296 ms; uncompreesed 1000 arrays in a single file + +%timeit t_npy2400_take200() # 17.4 ms; 25 Mb +%timeit t_npz2400_take200() # 28.8 ms; 3.72 ms for 10 arrays; 25 Mb +%timeit t_npy2400single() # 4 ms; frozen dictionary containing 2400 arrays; 6.4 Mb; int16 +%timeit t_mmap2400single() # dictionary is not supported +%timeit t_json2400single() # 175 ms; 17 Mb +# 2400 files from 100ms hop for 4 minutes \ No newline at end of file diff --git a/amt/src/extras/perceivertf_inspect.py b/amt/src/extras/perceivertf_inspect.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca467326c0fb973da042c96350c5bb6d22d7c4e --- /dev/null +++ 
b/amt/src/extras/perceivertf_inspect.py @@ -0,0 +1,640 @@ +import numpy as np +import torch +import torch.nn.functional as F + + +def l2_normalize(matrix): + """ + L2 Normalize the matrix along its rows. + + Parameters: + matrix (numpy.ndarray): The input matrix. + + Returns: + numpy.ndarray: The L2 normalized matrix. + """ + l2_norms = np.linalg.norm(matrix, axis=1, keepdims=True) + normalized_matrix = matrix / l2_norms + return normalized_matrix + + +def z_normalize(matrix): + """ + Z-normalize the matrix along its rows (mean=0 and std=1). + Z-normalization is also known as "standardization", and derives from z-score. + Z = (X - mean) / std + Z-nomarlized, each row has mean=0 and std=1. + + Parameters: + matrix (numpy.ndarray): The input matrix. + + Returns: + numpy.ndarray: The Z normalized matrix. + """ + mean = np.mean(matrix, axis=1, keepdims=True) + std = np.std(matrix, axis=1, keepdims=True) + normalized_matrix = (matrix - mean) / std + return normalized_matrix + + +def l2_normalize_tensors(tensor_tuple): + """ + Applies L2 normalization on the last two dimensions for each tensor in a tuple. + + Parameters: + tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors, each of shape (1, k, 30, 30). + + Returns: + tuple of torch.Tensor: A tuple containing N L2-normalized tensors. + """ + normalized_tensors = [] + for tensor in tensor_tuple: + # Ensure the tensor is a floating-point type + tensor = tensor.float() + + # Calculate L2 norm on the last two dimensions, keeping the dimensions using keepdim=True + l2_norm = torch.linalg.norm(tensor, dim=(-2, -1), keepdim=True) + + # Apply L2 normalization + normalized_tensor = tensor / ( + l2_norm + 1e-7) # Small value to avoid division by zero + + normalized_tensors.append(normalized_tensor) + + return tuple(normalized_tensors) + + +def z_normalize_tensors(tensor_tuple): + """ + Applies Z-normalization on the last two dimensions for each tensor in a tuple. + + Parameters: + tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors, each of shape (1, k, 30, 30). + + Returns: + tuple of torch.Tensor: A tuple containing N Z-normalized tensors. + """ + normalized_tensors = [] + for tensor in tensor_tuple: + # Ensure the tensor is a floating-point type + tensor = tensor.float() + + # Calculate mean and std on the last two dimensions + mean = tensor.mean(dim=(-2, -1), keepdim=True) + std = tensor.std(dim=(-2, -1), keepdim=True) + + # Apply Z-normalization + normalized_tensor = (tensor - mean) / ( + std + 1e-7) # Small value to avoid division by zero + + normalized_tensors.append(normalized_tensor) + + return tuple(normalized_tensors) + + +def apply_temperature_to_attention_tensors(tensor_tuple, temperature=1.0): + """ + Applies temperature scaling to the attention weights in each tensor in a tuple. + + Parameters: + tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors, + each of shape (1, k, 30, 30). + temperature (float): Temperature parameter to control the sharpness + of the attention weights. Default is 1.0. + + Returns: + tuple of torch.Tensor: A tuple containing N tensors with scaled attention weights. 
+ """ + scaled_attention_tensors = [] + + for tensor in tensor_tuple: + # Ensure the tensor is a floating-point type + tensor = tensor.float() + + # Flatten the last two dimensions + flattened_tensor = tensor.reshape(1, tensor.shape[1], + -1) # Modified line here + + # Apply temperature scaling and softmax along the last dimension + scaled_attention = flattened_tensor / temperature + scaled_attention = F.softmax(scaled_attention, dim=-1) + + # Reshape to original shape + scaled_attention = scaled_attention.view_as(tensor) + + scaled_attention_tensors.append(scaled_attention) + + return tuple(scaled_attention_tensors) + + +def shorten_att(tensor_tuple, length=30): + shortend_tensors = [] + for tensor in tensor_tuple: + shortend_tensors.append(tensor[:, :, :length, :length]) + return tuple(shortend_tensors) + + +def keep_top_k(matrix, k=6): + """ + Keep only the top k values in each row, set the rest to 0. + + Parameters: + matrix (numpy.ndarray): The input matrix. + k (int): The number of top values to keep in each row. + + Returns: + numpy.ndarray: The transformed matrix. + """ + topk_indices_per_row = np.argpartition(matrix, -k, axis=1)[:, -k:] + result_matrix = np.zeros_like(matrix) + + for i in range(matrix.shape[0]): + result_matrix[i, topk_indices_per_row[i]] = matrix[ + i, topk_indices_per_row[i]] + return result_matrix + + +def test_case_forward_enc_perceiver_tf_dec_t5(): + import torch + from model.ymt3 import YourMT3 + from config.config import audio_cfg, model_cfg, shared_cfg + model_cfg["encoder_type"] = "perceiver-tf" + model_cfg["encoder"]["perceiver-tf"]["attention_to_channel"] = True + model_cfg["encoder"]["perceiver-tf"]["num_latents"] = 24 + model_cfg["decoder_type"] = "t5" + model_cfg["pre_decoder_type"] = "default" + + audio_cfg["codec"] = "spec" + audio_cfg["hop_length"] = 300 + model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg) + model.eval() + + # x = torch.randn(2, 1, 32767) + # labels = torch.randint(0, 400, (2, 1024), requires_grad=False) + + # # forward + # output = model.forward(x, labels) + + # # inference + # result = model.inference(x, None) + + # display latents + checkpoint = torch.load( + "../logs/ymt3/ptf_all_cross_rebal5_spec300_xk2_amp0811_edr_005_attend_c_full_plus_b52/checkpoints/model.ckpt", + map_location="cpu") + state_dict = checkpoint['state_dict'] + new_state_dict = { + k: v + for k, v in state_dict.items() if 'pitchshift' not in k + } + model.load_state_dict(new_state_dict, strict=False) + + latents = model.encoder.latent_array.latents.detach().numpy() + import matplotlib.pyplot as plt + import numpy as np + from sklearn.metrics.pairwise import cosine_similarity + cos = cosine_similarity(latents) + + from utils.data_modules import AMTDataModule + from einops import rearrange + dm = AMTDataModule(data_preset_multi={"presets": ["slakh"]}) + dm.setup("test") + dl = dm.test_dataloader() + ds = list(dl.values())[0].dataset + audio, notes, tokens, _ = ds.__getitem__(7) + x = audio[[16], ::] + label = tokens[[16], :] + # spectrogram + x_spec = model.spectrogram(x) + plt.imshow(x_spec[0].detach().numpy().T, aspect='auto', origin='lower') + plt.title("spectrogram") + plt.xlabel('time step') + plt.ylabel('frequency bin') + plt.show() + x_conv = model.pre_encoder(x_spec) + # Create a larger figure + plt.figure( + figsize=(15, + 10)) # Adjust these numbers as needed for width and height + plt.subplot(2, 4, 1) + plt.imshow(x_spec[0].detach().numpy().T, aspect='auto', origin='lower') + plt.title("spectrogram") + plt.xlabel('time step') + 
plt.ylabel('frequency bin') + plt.subplot(2, 4, 2) + plt.imshow(x_conv[0][:, :, 0].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("conv(spec), ch=0") + plt.xlabel('time step') + plt.ylabel('F') + plt.subplot(2, 4, 3) + plt.imshow(x_conv[0][:, :, 42].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("ch=42") + plt.xlabel('time step') + plt.ylabel('F') + plt.subplot(2, 4, 4) + plt.imshow(x_conv[0][:, :, 80].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("ch=80") + plt.xlabel('time step') + plt.ylabel('F') + plt.subplot(2, 4, 5) + plt.imshow(x_conv[0][:, :, 11].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("ch=11") + plt.xlabel('time step') + plt.ylabel('F') + plt.subplot(2, 4, 6) + plt.imshow(x_conv[0][:, :, 20].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("ch=20") + plt.xlabel('time step') + plt.ylabel('F') + plt.subplot(2, 4, 7) + plt.imshow(x_conv[0][:, :, 77].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("ch=77") + plt.xlabel('time step') + plt.ylabel('F') + plt.subplot(2, 4, 8) + plt.imshow(x_conv[0][:, :, 90].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("ch=90") + plt.xlabel('time step') + plt.ylabel('F') + plt.tight_layout() + plt.show() + + # encoding + output = model.encoder(inputs_embeds=x_conv, + output_hidden_states=True, + output_attentions=True) + enc_hs_all, att, catt = output["hidden_states"], output[ + "attentions"], output["cross_attentions"] + enc_hs_last = enc_hs_all[2] + + # enc_hs: time-varying encoder hidden state + plt.subplot(2, 3, 1) + plt.imshow(enc_hs_all[0][0][:, :, 21].detach().numpy().T) + plt.title('ENC_HS B0, d21') + plt.colorbar(orientation='horizontal') + plt.ylabel('latent k') + plt.xlabel('t') + plt.subplot(2, 3, 4) + plt.imshow(enc_hs_all[0][0][:, :, 127].detach().numpy().T) + plt.colorbar(orientation='horizontal') + plt.title('B0, d127') + plt.ylabel('latent k') + plt.xlabel('t') + plt.subplot(2, 3, 2) + plt.imshow(enc_hs_all[1][0][:, :, 21].detach().numpy().T) + plt.colorbar(orientation='horizontal') + plt.title('B1, d21') + plt.ylabel('latent k') + plt.xlabel('t') + plt.subplot(2, 3, 5) + plt.imshow(enc_hs_all[1][0][:, :, 127].detach().numpy().T) + plt.colorbar(orientation='horizontal') + plt.title('B1, d127') + plt.ylabel('latent k') + plt.xlabel('t') + plt.subplot(2, 3, 3) + plt.imshow(enc_hs_all[2][0][:, :, 21].detach().numpy().T) + plt.colorbar(orientation='horizontal') + plt.title('B2, d21') + plt.ylabel('latent k') + plt.xlabel('t') + plt.subplot(2, 3, 6) + plt.imshow(enc_hs_all[2][0][:, :, 127].detach().numpy().T) + plt.colorbar(orientation='horizontal') + plt.title('B2, d127') + plt.ylabel('latent k') + plt.xlabel('t') + plt.tight_layout() + plt.show() + + enc_hs_proj = model.pre_decoder(enc_hs_last) + plt.imshow(enc_hs_proj[0].detach().numpy()) + plt.title( + 'ENC_HS_PROJ: linear projection of encoder output, which is used for enc-dec cross attention' + ) + plt.colorbar(orientation='horizontal') + plt.ylabel('latent k') + plt.xlabel('d') + plt.show() + + plt.subplot(221) + plt.imshow(enc_hs_all[2][0][0, :, :].detach().numpy(), aspect='auto') + plt.title('enc_hs, t=0') + plt.ylabel('latent k') + plt.xlabel('d') + plt.subplot(222) + plt.imshow(enc_hs_all[2][0][10, :, :].detach().numpy(), aspect='auto') + plt.title('enc_hs, t=10') + plt.ylabel('latent k') + plt.xlabel('d') + plt.subplot(223) + plt.imshow(enc_hs_all[2][0][20, :, :].detach().numpy(), aspect='auto') + plt.title('enc_hs, t=20') + 
plt.ylabel('latent k') + plt.xlabel('d') + plt.subplot(224) + plt.imshow(enc_hs_all[2][0][30, :, :].detach().numpy(), aspect='auto') + plt.title('enc_hs, t=30') + plt.ylabel('latent k') + plt.xlabel('d') + plt.tight_layout() + plt.show() + + # enc_hs correlation: which dim has most unique info? + plt.subplot(1, 3, 1) + a = rearrange(enc_hs_last, '1 t k d -> t (k d)').detach().numpy() + plt.imshow(cosine_similarity(a)) + plt.title("enc hs, t x t cos_sim") + plt.subplot(1, 3, 2) + b = rearrange(enc_hs_last, '1 t k d -> k (t d)').detach().numpy() + plt.imshow(cosine_similarity(b)) + plt.title("enc hs, k x k cos_sim") + plt.subplot(1, 3, 3) + c = rearrange(enc_hs_last, '1 t k d -> d (k t)').detach().numpy() + plt.imshow(cosine_similarity(c)) + plt.title("cross att, d x d cos_sim") + plt.tight_layout() + plt.show() + + # enc latent + plt.imshow(model.encoder.latent_array.latents.detach().numpy()) + plt.title('latent array') + plt.xlabel('d') + plt.ylabel('latent k') + plt.show() + + # enc Spectral Cross Attention: (T x head x K x D). How latent K attends to conv channel C? + plt.subplot(311) + plt.imshow( + torch.sum(torch.sum(catt[0][0], axis=0), axis=0).detach().numpy()) + plt.title('block=0') + plt.ylabel('latent k') + plt.xlabel('conv channel') + plt.subplot(312) + plt.imshow( + torch.sum(torch.sum(catt[1][0], axis=0), axis=0).detach().numpy()) + plt.title('block=1') + plt.ylabel('latent k') + plt.xlabel('conv channel') + plt.subplot(313) + plt.imshow( + torch.sum(torch.sum(catt[2][0], axis=0), axis=0).detach().numpy()) + plt.title('block=2') + plt.ylabel('latent k') + plt.xlabel('conv channel') + plt.tight_layout() + plt.show() + # enc Latent Self-attention: How latent K attends to K? + plt.subplot(231) + plt.imshow(torch.sum(torch.sum(att[0][0], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B0L0') + plt.xlabel('latent k') + plt.ylabel('latent k') + plt.subplot(234) + plt.imshow(torch.sum(torch.sum(att[0][1], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B0L1') + plt.xlabel('latent k') + plt.ylabel('latent k') + plt.subplot(232) + plt.imshow(torch.sum(torch.sum(att[1][0], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B1L0') + plt.xlabel('latent k') + plt.ylabel('latent k') + plt.subplot(235) + plt.imshow(torch.sum(torch.sum(att[1][1], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B1L1') + plt.xlabel('latent k') + plt.ylabel('latent k') + plt.subplot(233) + plt.imshow(torch.sum(torch.sum(att[2][0], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B2L0') + plt.xlabel('latent k') + plt.ylabel('latent k') + plt.subplot(236) + plt.imshow(torch.sum(torch.sum(att[2][1], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B2L1') + plt.xlabel('latent k') + plt.ylabel('latent k') + plt.tight_layout() + plt.show() + # Time varying, different head for latent self-attention + plt.subplot(231) + plt.imshow(att[0][0][30, 3, :, :].detach().numpy()) + plt.title('B0L0, t=30, Head=3') + plt.colorbar(orientation='horizontal') + plt.xlabel('k') + plt.ylabel('k') + plt.subplot(234) + plt.imshow(att[0][1][30, 3, :, :].detach().numpy()) + plt.title('B0L1, t=30, Head=3') + plt.colorbar(orientation='horizontal') + plt.xlabel('k') + plt.ylabel('k') + plt.subplot(232) + plt.imshow(att[1][0][30, 3, :, :].detach().numpy()) + plt.title('B1L0, t=30, Head=3') + plt.colorbar(orientation='horizontal') + plt.xlabel('k') + plt.ylabel('k') + plt.subplot(235) + plt.imshow(att[1][1][30, 3, :, 
:].detach().numpy()) + plt.title('B1L1, t=30, Head=3') + plt.colorbar(orientation='horizontal') + plt.xlabel('k') + plt.ylabel('k') + plt.subplot(233) + plt.imshow(att[2][0][30, 3, :, :].detach().numpy()) + plt.title('B2L0, t=30, Head=3') + plt.colorbar(orientation='horizontal') + plt.xlabel('k') + plt.ylabel('k') + plt.subplot(236) + plt.imshow(att[2][1][30, 3, :, :].detach().numpy()) + plt.title('B2L1, t=30, Head=3') + plt.colorbar(orientation='horizontal') + plt.xlabel('k') + plt.ylabel('k') + plt.tight_layout() + plt.show() + plt.subplot(231) + plt.imshow(att[0][0][30, 5, :, :].detach().numpy()) + plt.title('B0L0, t=30, Head=5') + plt.colorbar(orientation='horizontal') + plt.xlabel('k') + plt.ylabel('k') + plt.subplot(234) + plt.imshow(att[0][1][30, 5, :, :].detach().numpy()) + plt.title('B0L1, t=30, Head=5') + plt.colorbar(orientation='horizontal') + plt.xlabel('k') + plt.ylabel('k') + plt.subplot(232) + plt.imshow(att[1][0][30, 5, :, :].detach().numpy()) + plt.title('B1L0, t=30, Head=5') + plt.colorbar(orientation='horizontal') + plt.xlabel('k') + plt.ylabel('k') + plt.subplot(235) + plt.imshow(att[1][1][30, 5, :, :].detach().numpy()) + plt.title('B1L1, t=30, Head=5') + plt.colorbar(orientation='horizontal') + plt.xlabel('k') + plt.ylabel('k') + plt.subplot(233) + plt.imshow(att[2][0][30, 5, :, :].detach().numpy()) + plt.title('B2L0, t=30, Head=5') + plt.colorbar(orientation='horizontal') + plt.xlabel('k') + plt.ylabel('k') + plt.subplot(236) + plt.imshow(att[2][1][30, 5, :, :].detach().numpy()) + plt.title('B2L1, t=30, Head=5') + plt.colorbar(orientation='horizontal') + plt.xlabel('k') + plt.ylabel('k') + plt.tight_layout() + plt.show() + + # Temporal Self-attention: (K x H x T x T) How time t attends to time t? + plt.subplot(231) + plt.imshow(torch.sum(torch.sum(att[0][2], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B0L2') + plt.xlabel('t') + plt.ylabel('t') + plt.subplot(234) + plt.imshow(torch.sum(torch.sum(att[0][3], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B0L3') + plt.xlabel('t') + plt.ylabel('t') + plt.subplot(232) + plt.imshow(torch.sum(torch.sum(att[1][2], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B1L2') + plt.xlabel('t') + plt.ylabel('t') + plt.subplot(235) + plt.imshow(torch.sum(torch.sum(att[1][3], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B1L3') + plt.xlabel('t') + plt.ylabel('t') + plt.subplot(233) + plt.imshow(torch.sum(torch.sum(att[2][2], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B2L2') + plt.xlabel('t') + plt.ylabel('t') + plt.subplot(236) + plt.imshow(torch.sum(torch.sum(att[2][3], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B2L3') + plt.xlabel('t') + plt.ylabel('t') + plt.tight_layout() + plt.show() + + # decoding + dec_input_ids = model.shift_right_fn(label) + dec_inputs_embeds = model.embed_tokens(dec_input_ids) + dec_output = model.decoder(inputs_embeds=dec_inputs_embeds, + encoder_hidden_states=enc_hs_proj, + output_attentions=True, + output_hidden_states=True, + return_dict=True) + dec_att, dec_catt = dec_output.attentions, dec_output.cross_attentions + dec_hs_all = dec_output.hidden_states + + # dec att + plt.subplot(1, 2, 1) + plt.imshow(torch.sum(dec_att[0][0], axis=0).detach().numpy()) + plt.title('decoder attention, layer0') + plt.xlabel('decoder time step') + plt.ylabel('decoder time step') + plt.subplot(1, 2, 2) + plt.imshow(torch.sum(dec_att[7][0], axis=0).detach().numpy()) + plt.title('decoder 
attention, layer8') + plt.xlabel('decoder time step') + plt.show() + # dec catt + plt.imshow(np.rot90((torch.sum(dec_catt[7][0], + axis=0))[:1000, :].detach().numpy()), + origin='upper', + aspect='auto') + plt.colorbar() + plt.title('decoder cross att, layer8') + plt.xlabel('decoder time step') + plt.ylabel('encoder frame') + plt.show() + # dec catt by head with xxx + dec_att_z = z_normalize_tensors(shorten_att(dec_att)) + plt.imshow(dec_att_z[0][0, 0, :, :].detach().numpy()) + from bertviz import head_view + token = [] + for i in label[0, :30]: + token.append(str(i)) + head_view(dec_att_z, tokens) + + # dec_hs + plt.subplot(1, 2, 1) + plt.imshow(dec_hs_all[0][0].detach().numpy(), origin='upper') + plt.colorbar(orientation='horizontal') + plt.title('decoder hidden state, layer1') + plt.xlabel('hidden dim') + plt.ylabel('time step') + plt.subplot(1, 2, 2) + plt.imshow(dec_hs_all[7][0].detach().numpy(), origin='upper') + plt.colorbar(orientation='horizontal') + plt.title('decoder hidden state, layer8') + plt.xlabel('hidden dim') + plt.show() + + # lm head + logits = model.lm_head(dec_hs_all[0]) + plt.imshow(logits[0][0:200, :].detach().numpy(), origin='upper') + plt.title('lm head softmax') + plt.xlabel('vocab dim') + plt.ylabel('time step') + plt.xlim([1000, 1350]) + plt.show() + softmax = torch.nn.Softmax(dim=2) + logits_sm = softmax(logits) + plt.imshow(logits_sm[0][0:200, :].detach().numpy(), origin='upper') + plt.title('lm head softmax') + plt.xlabel('vocab dim') + plt.ylabel('time step') + plt.xlim([1000, 1350]) + plt.show() diff --git a/amt/src/extras/perceivertf_multi_inspect.py b/amt/src/extras/perceivertf_multi_inspect.py new file mode 100644 index 0000000000000000000000000000000000000000..6e15c86073a7bf987534867cd9cee051798d4c87 --- /dev/null +++ b/amt/src/extras/perceivertf_multi_inspect.py @@ -0,0 +1,778 @@ +import numpy as np +import torch +import torch.nn.functional as F +import torchaudio +from matplotlib.animation import FuncAnimation + +def l2_normalize(matrix): + """ + L2 Normalize the matrix along its rows. + + Parameters: + matrix (numpy.ndarray): The input matrix. + + Returns: + numpy.ndarray: The L2 normalized matrix. + """ + l2_norms = np.linalg.norm(matrix, axis=1, keepdims=True) + normalized_matrix = matrix / l2_norms + return normalized_matrix + + +def z_normalize(matrix): + """ + Z-normalize the matrix along its rows (mean=0 and std=1). + Z-normalization is also known as "standardization", and derives from z-score. + Z = (X - mean) / std + Z-nomarlized, each row has mean=0 and std=1. + + Parameters: + matrix (numpy.ndarray): The input matrix. + + Returns: + numpy.ndarray: The Z normalized matrix. + """ + mean = np.mean(matrix, axis=1, keepdims=True) + std = np.std(matrix, axis=1, keepdims=True) + normalized_matrix = (matrix - mean) / std + return normalized_matrix + + +def l2_normalize_tensors(tensor_tuple): + """ + Applies L2 normalization on the last two dimensions for each tensor in a tuple. + + Parameters: + tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors, each of shape (1, k, 30, 30). + + Returns: + tuple of torch.Tensor: A tuple containing N L2-normalized tensors. 
+ """ + normalized_tensors = [] + for tensor in tensor_tuple: + # Ensure the tensor is a floating-point type + tensor = tensor.float() + + # Calculate L2 norm on the last two dimensions, keeping the dimensions using keepdim=True + l2_norm = torch.linalg.norm(tensor, dim=(-2, -1), keepdim=True) + + # Apply L2 normalization + normalized_tensor = tensor / ( + l2_norm + 1e-7) # Small value to avoid division by zero + + normalized_tensors.append(normalized_tensor) + + return tuple(normalized_tensors) + + +def z_normalize_tensors(tensor_tuple): + """ + Applies Z-normalization on the last two dimensions for each tensor in a tuple. + + Parameters: + tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors, each of shape (1, k, 30, 30). + + Returns: + tuple of torch.Tensor: A tuple containing N Z-normalized tensors. + """ + normalized_tensors = [] + for tensor in tensor_tuple: + # Ensure the tensor is a floating-point type + tensor = tensor.float() + + # Calculate mean and std on the last two dimensions + mean = tensor.mean(dim=(-2, -1), keepdim=True) + std = tensor.std(dim=(-2, -1), keepdim=True) + + # Apply Z-normalization + normalized_tensor = (tensor - mean) / ( + std + 1e-7) # Small value to avoid division by zero + + normalized_tensors.append(normalized_tensor) + + return tuple(normalized_tensors) + + +def apply_temperature_to_attention_tensors(tensor_tuple, temperature=1.0): + """ + Applies temperature scaling to the attention weights in each tensor in a tuple. + + Parameters: + tensor_tuple (tuple of torch.Tensor): A tuple containing N tensors, + each of shape (1, k, 30, 30). + temperature (float): Temperature parameter to control the sharpness + of the attention weights. Default is 1.0. + + Returns: + tuple of torch.Tensor: A tuple containing N tensors with scaled attention weights. + """ + scaled_attention_tensors = [] + + for tensor in tensor_tuple: + # Ensure the tensor is a floating-point type + tensor = tensor.float() + + # Flatten the last two dimensions + flattened_tensor = tensor.reshape(1, tensor.shape[1], + -1) # Modified line here + + # Apply temperature scaling and softmax along the last dimension + scaled_attention = flattened_tensor / temperature + scaled_attention = F.softmax(scaled_attention, dim=-1) + + # Reshape to original shape + scaled_attention = scaled_attention.view_as(tensor) + + scaled_attention_tensors.append(scaled_attention) + + return tuple(scaled_attention_tensors) + + +def shorten_att(tensor_tuple, length=30): + shortend_tensors = [] + for tensor in tensor_tuple: + shortend_tensors.append(tensor[:, :, :length, :length]) + return tuple(shortend_tensors) + + +def keep_top_k(matrix, k=6): + """ + Keep only the top k values in each row, set the rest to 0. + + Parameters: + matrix (numpy.ndarray): The input matrix. + k (int): The number of top values to keep in each row. + + Returns: + numpy.ndarray: The transformed matrix. 
+ """ + topk_indices_per_row = np.argpartition(matrix, -k, axis=1)[:, -k:] + result_matrix = np.zeros_like(matrix) + + for i in range(matrix.shape[0]): + result_matrix[i, topk_indices_per_row[i]] = matrix[ + i, topk_indices_per_row[i]] + return result_matrix + + +def test_case_forward_enc_perceiver_tf_dec_multi_t5(): + import torch + from model.ymt3 import YourMT3 + from config.config import audio_cfg, model_cfg, shared_cfg + model_cfg["encoder_type"] = "perceiver-tf" + + model_cfg["encoder"]["perceiver-tf"]["attention_to_channel"] = True + model_cfg["encoder"]["perceiver-tf"]["num_latents"] = 26 + + model_cfg["decoder_type"] = "multi-t5" + + audio_cfg["codec"] = "spec" + audio_cfg["hop_length"] = 300 + model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg) + model.eval() + + # x = torch.randn(2, 1, 32767) + # labels = torch.randint(0, 400, (2, 1024), requires_grad=False) + + # # forward + # output = model.forward(x, labels) + + # # inference + # result = model.inference(x, None) + + # display latents + checkpoint = torch.load( + "../logs/ymt3/ptf_mc13_256_all_cross_v6_xk5_amp0811_edr005_attend_c_full_plus_2psn_nl26_sb_b26r_800k/checkpoints/model.ckpt", + map_location="cpu") + state_dict = checkpoint['state_dict'] + new_state_dict = { + k: v + for k, v in state_dict.items() if 'pitchshift' not in k + } + model.load_state_dict(new_state_dict, strict=False) + + latents = model.encoder.latent_array.latents.detach().numpy() + import matplotlib.pyplot as plt + import numpy as np + from sklearn.metrics.pairwise import cosine_similarity + cos = cosine_similarity(latents) + + from utils.data_modules import AMTDataModule + from einops import rearrange + # dm = AMTDataModule(data_preset_multi={"presets": ["slakh"]}) + #dm.setup("test") + # dl = dm.test_dataloader() + # ds = list(dl.values())[0].dataset + # audio, notes, tokens, _ = ds.__getitem__(7) + # x = audio[[16], ::] + # label = tokens[[16], :] + + # from utils.task_manager import TaskManager + # tm = TaskManager(task_name='mc13_256') + # dm = AMTDataModule(data_preset_multi={"presets": ["slakh"]}, + # task_manager=tm, + # train_stem_iaug_prob=None, + # train_stem_xaug_policy=None) + # dm.setup('fit') + # dl = dm.train_dataloader() + # ds = dl.flattened[0].dataset + # audio,tokens, _, _ = ds.__getitem__(67) + # x = audio[[5], ::] + # label = tokens[[5], :] + # save audio + # torchaudio.save("singing.wav", x[0, :, :], 16000) + + x, _ = torchaudio.load('piano.wav')#'test.wav') + x = x.unsqueeze(0) + + # spectrogram + x_spec = model.spectrogram(x) + x_conv = model.pre_encoder(x_spec) + # Create a larger figure + plt.figure( + figsize=(15, + 10)) # Adjust these numbers as needed for width and height + plt.subplot(2, 4, 1) + plt.imshow(x_spec[0].detach().numpy().T, aspect='auto', origin='lower') + plt.title("spectrogram") + plt.xlabel('time step') + plt.ylabel('frequency bin') + plt.subplot(2, 4, 2) + plt.imshow(x_conv[0][:, :, 0].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("conv(spec), ch=0") + plt.xlabel('time step') + plt.ylabel('F') + plt.subplot(2, 4, 3) + plt.imshow(x_conv[0][:, :, 42].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("ch=42") + plt.xlabel('time step') + plt.ylabel('F') + plt.subplot(2, 4, 4) + plt.imshow(x_conv[0][:, :, 80].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("ch=80") + plt.xlabel('time step') + plt.ylabel('F') + plt.subplot(2, 4, 5) + plt.imshow(x_conv[0][:, :, 11].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("ch=11") + 
plt.xlabel('time step') + plt.ylabel('F') + plt.subplot(2, 4, 6) + plt.imshow(x_conv[0][:, :, 20].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("ch=20") + plt.xlabel('time step') + plt.ylabel('F') + plt.subplot(2, 4, 7) + plt.imshow(x_conv[0][:, :, 77].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("ch=77") + plt.xlabel('time step') + plt.ylabel('F') + plt.subplot(2, 4, 8) + plt.imshow(x_conv[0][:, :, 90].detach().numpy().T, + aspect='auto', + origin='lower') + plt.title("ch=90") + plt.xlabel('time step') + plt.ylabel('F') + plt.tight_layout() + plt.show() + + # encoding + output = model.encoder(inputs_embeds=x_conv, + output_hidden_states=True, + output_attentions=True) + enc_hs_all, att, catt = output["hidden_states"], output[ + "attentions"], output["cross_attentions"] + enc_hs_last = enc_hs_all[2] + + # enc_hs: time-varying encoder hidden state + plt.subplot(2, 3, 1) + plt.imshow(enc_hs_all[0][0][:, :, 21].detach().numpy().T) + plt.title('ENC_HS B0, d21') + plt.colorbar(orientation='horizontal') + plt.ylabel('latent k') + plt.xlabel('t') + plt.subplot(2, 3, 4) + plt.imshow(enc_hs_all[0][0][:, :, 127].detach().numpy().T) + plt.colorbar(orientation='horizontal') + plt.title('B0, d127') + plt.ylabel('latent k') + plt.xlabel('t') + plt.subplot(2, 3, 2) + plt.imshow(enc_hs_all[1][0][:, :, 21].detach().numpy().T) + plt.colorbar(orientation='horizontal') + plt.title('B1, d21') + plt.ylabel('latent k') + plt.xlabel('t') + plt.subplot(2, 3, 5) + plt.imshow(enc_hs_all[1][0][:, :, 127].detach().numpy().T) + plt.colorbar(orientation='horizontal') + plt.title('B1, d127') + plt.ylabel('latent k') + plt.xlabel('t') + plt.subplot(2, 3, 3) + plt.imshow(enc_hs_all[2][0][:, :, 21].detach().numpy().T) + plt.colorbar(orientation='horizontal') + plt.title('B2, d21') + plt.ylabel('latent k') + plt.xlabel('t') + plt.subplot(2, 3, 6) + plt.imshow(enc_hs_all[2][0][:, :, 127].detach().numpy().T) + plt.colorbar(orientation='horizontal') + plt.title('B2, d127') + plt.ylabel('latent k') + plt.xlabel('t') + plt.tight_layout() + plt.show() + + # enc_hs: time-varying encoder hidden state by k (block, 1, t, k, d) + # --> (t, d) for each k in last block + data = enc_hs_all[2][0].detach().numpy() # (T, K, D) + fig, axs = plt.subplots( + 5, 5, figsize=(10, 9)) # 25 subplots arranged in 5 rows and 5 columns + axs = axs.flatten( + ) # Flatten the 2D array of axes to 1D for easy iteration + + for k in range(25): # Iterating through K indices from 0 to 24 + axs[k].imshow(data[:, k, :].T, + cmap='viridis') # Transposing the matrix to swap T and D + axs[k].set_title(f'k={k}') + axs[k].set_xlabel('Time step') + axs[k].set_ylabel('Dim') + + # Adjusting layout for better visibility + plt.tight_layout() + plt.show() + + #!! 
Projected encoder hidden state for 13 channels, that is conditioning for decoder + enc_hs_proj = model.pre_decoder(enc_hs_last) + fig, axs = plt.subplots(1, 13, figsize=(26, 8)) # 13 subplots in a row + data = enc_hs_proj[0].detach().numpy() + for ch in range(13): + axs[ch].imshow(np.rot90(data[ch]), cmap='viridis') # Rotate 90 degrees + axs[ch].set_title(f'ch: {ch}') + axs[ch].set_xlabel('Time step') + axs[ch].set_ylabel('Dim') + plt.suptitle( + 'linear projection of encoder outputs by channel, which is conditioning for enc-dec cross attention', + y=0.1, + fontsize=12) + plt.tight_layout(rect=[0, 0.1, 1, 1]) + plt.show() + + plt.subplot(221) + plt.imshow(enc_hs_all[2][0][0, :, :].detach().numpy(), aspect='auto') + plt.title('enc_hs, t=0') + plt.ylabel('latent k') + plt.xlabel('d') + plt.subplot(222) + plt.imshow(enc_hs_all[2][0][10, :, :].detach().numpy(), aspect='auto') + plt.title('enc_hs, t=10') + plt.ylabel('latent k') + plt.xlabel('d') + plt.subplot(223) + plt.imshow(enc_hs_all[2][0][20, :, :].detach().numpy(), aspect='auto') + plt.title('enc_hs, t=20') + plt.ylabel('latent k') + plt.xlabel('d') + plt.subplot(224) + plt.imshow(enc_hs_all[2][0][30, :, :].detach().numpy(), aspect='auto') + plt.title('enc_hs, t=30') + plt.ylabel('latent k') + plt.xlabel('d') + plt.tight_layout() + plt.show() + + # enc_hs correlation: which dim has most unique info? + plt.subplot(1, 3, 1) + a = rearrange(enc_hs_last, '1 t k d -> t (k d)').detach().numpy() + plt.imshow(cosine_similarity(a)) + plt.title("enc hs, t x t cos_sim") + plt.subplot(1, 3, 2) + b = rearrange(enc_hs_last, '1 t k d -> k (t d)').detach().numpy() + plt.imshow(cosine_similarity(b)) + plt.title("enc hs, k x k cos_sim") + plt.subplot(1, 3, 3) + c = rearrange(enc_hs_last, '1 t k d -> d (k t)').detach().numpy() + plt.imshow(cosine_similarity(c)) + plt.title("cross att, d x d cos_sim") + plt.tight_layout() + plt.show() + + #!! enc latent + plt.imshow(model.encoder.latent_array.latents.detach().numpy()) + plt.title('latent array') + plt.xlabel('d') + plt.ylabel('latent k') + plt.show() + + #!! enc Spectral Cross Attention: (T x head x K x D). How latent K attends to conv channel C? + plt.subplot(311) + plt.imshow( + torch.sum(torch.sum(catt[0][0], axis=0), axis=0).detach().numpy()) + plt.title('block=0') + plt.ylabel('latent k') + plt.xlabel('conv channel') + plt.subplot(312) + plt.imshow( + torch.sum(torch.sum(catt[1][0], axis=0), axis=0).detach().numpy()) + plt.title('block=1') + plt.ylabel('latent k') + plt.xlabel('conv channel') + plt.subplot(313) + plt.imshow( + torch.sum(torch.sum(catt[2][0], axis=0), axis=0).detach().numpy()) + plt.title('block=2') + plt.ylabel('latent k') + plt.xlabel('conv channel') + # f'spectral cross attention. T-C-F Model', + # y=0, + # fontsize=12) + plt.tight_layout() + plt.show() + + #!! 
Animation of SCA for varying time, head in last block + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 6)) # Adjusted figsize for better layout + + # Function to update the plots for each frame in the animation + def update(t): + # Clear previous images + ax1.clear() + ax2.clear() + + # Update subplot for h=3 + ax1.imshow(catt[2][0][t, 3, :, :].detach().numpy()) + ax1.set_title(f'block=2, t={t}, head=3') + ax1.set_ylabel('latent k'); ax1.set_xlabel('conv channel') + + # Update subplot for h=5 + ax2.imshow(catt[2][0][t, 5, :, :].detach().numpy()) + ax2.set_title(f'block=2, t={t}, head=5') + ax2.set_ylabel('latent k'); ax2.set_xlabel('conv channel') + + # Adjust layout + fig.tight_layout() + + # Create the animation + anim = FuncAnimation(fig, update, frames=range(0, 110), interval=200) + anim.save('animation.gif', writer='pillow', fps=5) + + + + fig, axs = plt.subplots(3, 1, figsize=(12, 18), gridspec_kw={'height_ratios': [1, 1, 0.5]}) # Adjusted for different subplot sizes + + # Subplots for catt visualization (h=3 and h=5) + ax_catt3, ax_catt5, ax_att_row = axs + + # Creating 8 subplots for att visualization within the third row + for i in range(8): + ax_att_row = fig.add_subplot(3, 8, 17 + i) # Adding subplots in the third row + + # Update function for the combined animation + def combined_update_smaller_att(t): + # Update subplot for catt with h=3 + ax_catt3.clear() + ax_catt3.imshow(catt[2][0][t, 3, :, :].detach().numpy()) + ax_catt3.set_title(f'block=2, t={t}, head=3') + ax_catt3.set_ylabel('latent k'); ax_catt3.set_xlabel('conv channel') + + # Update subplot for catt with h=5 + ax_catt5.clear() + ax_catt5.imshow(catt[2][0][t, 5, :, :].detach().numpy()) + ax_catt5.set_title(f'block=2, t={t}, head=5') + ax_catt5.set_ylabel('latent k'); ax_catt5.set_xlabel('conv channel') + + # Update subplots for att (8 heads in one row) + for i in range(8): + ax = fig.add_subplot(3, 8, 17 + i) + ax.clear() + ax.imshow(att[0][1][t, i, :, :].detach().numpy(), cmap='viridis') + ax.set_title(f't={t}, head={i}') + ax.set_xlabel('k') + ax.set_ylabel('k') + ax.axis('square') # Make each subplot square-shaped + + # Adjust layout + fig.tight_layout() + combined_anim_smaller_att = FuncAnimation(fig, combined_update_smaller_att, frames=range(0, 110), interval=200) + combined_anim_smaller_att.save('combined_animation_smaller_att.gif', writer='pillow', fps=5) + + + + + + # enc Latent Self-attention: How latent K attends to K? 
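+    # The six subplots below all show latent self-attention pooled over heads and
+    # time. Assuming att[block][layer] is shaped (T, heads, K, K), as used
+    # throughout this script, the same reduction can be written as a small helper:
+    def pooled_latent_att(att_tuple, block, layer):
+        # Sum over heads (axis=1), then over time steps (axis=0) -> (K, K) map.
+        return torch.sum(torch.sum(att_tuple[block][layer], axis=1), axis=0)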
+ plt.subplot(231) + plt.imshow(torch.sum(torch.sum(att[0][0], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B0L0') + plt.xlabel('latent k') + plt.ylabel('latent k') + plt.subplot(234) + plt.imshow(torch.sum(torch.sum(att[0][1], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B0L1') + plt.xlabel('latent k') + plt.ylabel('latent k') + plt.subplot(232) + plt.imshow(torch.sum(torch.sum(att[1][0], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B1L0') + plt.xlabel('latent k') + plt.ylabel('latent k') + plt.subplot(235) + plt.imshow(torch.sum(torch.sum(att[1][1], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B1L1') + plt.xlabel('latent k') + plt.ylabel('latent k') + plt.subplot(233) + plt.imshow(torch.sum(torch.sum(att[2][0], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B2L0') + plt.xlabel('latent k') + plt.ylabel('latent k') + plt.subplot(236) + plt.imshow(torch.sum(torch.sum(att[2][1], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B2L1') + plt.xlabel('latent k') + plt.ylabel('latent k') + plt.tight_layout() + plt.show() + # Time varying, different head for latent self-attention + #!!! Display latent self-attention for each head + bl = 0 # first latent transformer block, last layer att + data = att[bl][1].detach().numpy() + time_steps = [30, 50, 100] + fig, axs = plt.subplots( + len(time_steps), 8, + figsize=(16, 6)) # Subplots for each time step and head + for i, t in enumerate(time_steps): + for head in range(8): + axs[i, head].imshow(data[t, head, :, :], cmap='viridis') + axs[i, head].set_title(f't={t}, head={head}') + axs[i, head].set_xlabel('k') + axs[i, head].set_ylabel('k') + plt.suptitle( + f'latent transformer block={bl}, last layer self-attention over time', + y=0, + fontsize=12) + plt.tight_layout() + plt.show() + + bl = 1 # second latent transformer block, last layer att + data = att[bl][1].detach().numpy() + time_steps = [30, 50, 100] + fig, axs = plt.subplots( + len(time_steps), 8, + figsize=(16, 6)) # Subplots for each time step and head + for i, t in enumerate(time_steps): + for head in range(8): + axs[i, head].imshow(data[t, head, :, :], cmap='viridis') + axs[i, head].set_title(f't={t}, head={head}') + axs[i, head].set_xlabel('k') + axs[i, head].set_ylabel('k') + plt.suptitle( + f'latent transformer block={bl}, last layer self-attention over time', + y=0, + fontsize=12) + plt.tight_layout() + plt.show() + + bl = 2 # last latent transformer block, last layer att + data = att[bl][1].detach().numpy() + time_steps = [30, 50, 100] + fig, axs = plt.subplots( + len(time_steps), 8, + figsize=(16, 6)) # Subplots for each time step and head + for i, t in enumerate(time_steps): + for head in range(8): + axs[i, head].imshow(data[t, head, :, :], cmap='viridis') + axs[i, head].set_title(f't={t}, head={head}') + axs[i, head].set_xlabel('k') + axs[i, head].set_ylabel('k') + plt.suptitle( + f'latent transformer block={bl}, last layer self-attention over time', + y=0, + fontsize=12) + plt.tight_layout() + plt.show() + + # Temporal Self-attention: (K x H x T x T) How time t attends to time t? 
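+    # Sketch of an additional diagnostic (not used by the plots below): the mean
+    # attention distance, i.e. how far in time each query position attends on
+    # average, computed from a head/latent-pooled (T, T) attention map.
+    def mean_attention_distance(att_map):
+        # att_map: (T, T); rows are query time steps, columns are key time steps.
+        T = att_map.shape[-1]
+        idx = torch.arange(T, dtype=att_map.dtype)
+        dist = (idx[None, :] - idx[:, None]).abs()  # |query - key| in frames
+        w = att_map / (att_map.sum(dim=-1, keepdim=True) + 1e-7)  # row-normalize
+        return (w * dist).sum(dim=-1).mean()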
+ plt.subplot(231) + plt.imshow(torch.sum(torch.sum(att[0][2], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B0L2') + plt.xlabel('t') + plt.ylabel('t') + plt.subplot(234) + plt.imshow(torch.sum(torch.sum(att[0][3], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B0L3') + plt.xlabel('t') + plt.ylabel('t') + plt.subplot(232) + plt.imshow(torch.sum(torch.sum(att[1][2], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B1L2') + plt.xlabel('t') + plt.ylabel('t') + plt.subplot(235) + plt.imshow(torch.sum(torch.sum(att[1][3], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B1L3') + plt.xlabel('t') + plt.ylabel('t') + plt.subplot(233) + plt.imshow(torch.sum(torch.sum(att[2][2], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B2L2') + plt.xlabel('t') + plt.ylabel('t') + plt.subplot(236) + plt.imshow(torch.sum(torch.sum(att[2][3], axis=1), + axis=0).detach().numpy(), + origin='upper') + plt.title('B2L3') + plt.xlabel('t') + plt.ylabel('t') + plt.tight_layout() + plt.show() + + # decoding + dec_input_ids = model.shift_right_fn(label) + dec_inputs_embeds = model.embed_tokens(dec_input_ids) + dec_output = model.decoder(inputs_embeds=dec_inputs_embeds, + encoder_hidden_states=enc_hs_proj, + output_attentions=True, + output_hidden_states=True, + return_dict=True) + dec_att, dec_catt = dec_output.attentions, dec_output.cross_attentions + dec_hs_all = dec_output.hidden_states + dec_last_hs = dec_output.last_hidden_state + + # lm head + logits = model.lm_head(dec_last_hs) + + # pred ids + pred_ids = torch.argmax(logits, dim=3) + + # dec att + plt.subplot(1, 2, 1) + plt.imshow(torch.sum(dec_att[5][0], axis=0).detach().numpy()) + plt.title('decoder attention, layer0') + plt.xlabel('decoder time step') + plt.ylabel('decoder time step') + plt.subplot(1, 2, 2) + plt.imshow(torch.sum(dec_att[7][0], axis=0).detach().numpy()) + plt.title('decoder attention, final layer') + plt.xlabel('decoder step') + plt.show() + + + # dec catt + def remove_values_after_eos(catt_np, pred_ids, max_k): + # catt_np: (k, head, t, t) + # pred_ids: (1, k, t)) + max_length = pred_ids.shape[-1] + seq_lengths = np.zeros((max_k), dtype=np.int32) + for k in range(max_k): + for t in range(max_length): + if pred_ids[0, k, t] == 1: + break + catt_np[k, :, t+1:, :] = 0 + # catt_np[k, :, :, t+1:] = 0 + seq_lengths[k] = t+1 + return catt_np, seq_lengths + + # data = dec_catt[1].detach().numpy() # last layer's cross attention + l = 4 + data = dec_catt[l].detach().numpy() + data, seq_lengths = remove_values_after_eos(data, pred_ids, max_k=13) + seq_lengths[:]= 256 + + fig, axs = plt.subplots(13, 6, figsize=(21, 39)) # 13 rows (for k=0:12) and 7 columns (for head=0:6) + for k in range(13): + s = seq_lengths[k] + for head in range(6): + axs[k, head].imshow(data[k, head, :s, :].T, aspect='auto', cmap='viridis') + axs[k, head].set_title(f'Layer {l}, k={k}, head={head}') + axs[k, head].set_xlabel('Decoder step') + axs[k, head].set_ylabel('Encoder frame') + plt.tight_layout() + plt.show() + + + # # dec catt by head with xxx + # dec_att_z = z_normalize_tensors(shorten_att(dec_att)) + # plt.imshow(dec_att_z[0][0, 0, :, :].detach().numpy()) + # from bertviz import head_view + # token = [] + # for i in label[0, :30]: + # token.append(str(i)) + # head_view(dec_att_z, tokens) + + # dec_hs + plt.subplot(1, 2, 1) + k=2 + plt.imshow(dec_last_hs[0][k].detach().numpy(), origin='upper') + plt.colorbar(orientation='horizontal') + plt.title('decoder last hidden state, 
k=0') + plt.xlabel('hidden dim') + plt.ylabel('time step') + plt.subplot(1, 2, 2) + k=12 + plt.imshow(dec_last_hs[0][k].detach().numpy(), origin='upper') + plt.colorbar(orientation='horizontal') + plt.title('decoder last hidden state, k=12') + plt.xlabel('hidden dim') + plt.show() + + # lm head + logits = model.lm_head(dec_last_hs) + k=6 + plt.imshow(logits[0][k][0:200, :].detach().numpy().T, origin='upper') + plt.title('lm head output') + plt.xlabel('vocab dim') + plt.ylabel('time step') + plt.show() + softmax = torch.nn.Softmax(dim=3) + logits_sm = softmax(logits) # B, K, T, V + k=6 + plt.imshow(logits_sm[0][k][:255, :].detach().numpy().T, origin='upper') + plt.title('lm head softmax') + plt.xlabel('vocab dim') + plt.ylabel('time step') + # plt.xlim([1000, 1350]) + plt.show() + + k = 10 + print(torch.argmax(logits, dim=3)[0,k,:]) + + + + diff --git a/amt/src/extras/pitch_shift_benchmark.py b/amt/src/extras/pitch_shift_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..a333e13bf189c4f52dc055ed466ce32a0e57b663 --- /dev/null +++ b/amt/src/extras/pitch_shift_benchmark.py @@ -0,0 +1,167 @@ +""" Test the speed of the augmentation """ +import torch +import torchaudio + +# Device +device = torch.device("cuda") +# device = torch.device("cpu") + +# Music +# x, _ = torchaudio.load("music.wav") +# slice_length = 32767 +# n_slices = 80 +# slices = [x[0, i * slice_length:(i + 1) * slice_length] for i in range(n_slices)] +# x = torch.stack(slices) # (80, 32767) +# Sine wave +t = torch.arange(0, 2.0479, 1 / 16000) # 2.05 seconds at 16kHz +x = torch.sin(2 * torch.pi * 440 * t) * 0.5 +x = x.reshape(1, 1, 32767).tile(80, 1, 1) +x = x.to(device) + +############################################################################################ +# torch-audiomentation: https://github.com/asteroid-team/torch-audiomentation +# +# process time : 1.18 s ± 5.35 ms +# process time : 58 ms +# GPU memory usage: 3.8 GB per 1 semitone +############################################################################################ +import torch +from torch_audiomentations import Compose, PitchShift, Gain, PolarityInversion + +apply_augmentation = Compose(transforms=[ + # Gain( + # min_gain_in_db=-15.0, + # max_gain_in_db=5.0, + # p=0.5, + # ), + # PolarityInversion(p=0.5) + PitchShift( + min_transpose_semitones=0, + max_transpose_semitones=2.2, + mode="per_batch", #"per_example", + p=1.0, + p_mode="per_batch", + sample_rate=16000, + target_rate=16000) +]) +x_am = apply_augmentation(x, sample_rate=16000) + +############################################################################################ +# torchaudio: +# +# process time : 4.01 s ± 19.6 ms per loop +# process time : 25.1 ms ± 161 µs per loop +# memory usage : 1.2 (growth to 5.49) GB per 1 semitone +############################################################################################ +from torchaudio import transforms + +ta_transform = transforms.PitchShift(16000, n_steps=2).to(device) +x_ta = ta_transform(x) + +############################################################################################ +# YourMT3 pitch_shift_layer: +# +# process time : 389ms ± 22ms, (stretch=143 ms, resampler=245 ms) +# process time : 7.18 ms ± 17.3 µs (stretch=6.47 ms, resampler=0.71 ms) +# memory usage: 16 MB per 1 semitone (average) +############################################################################################ +from model.pitchshift_layer import PitchShiftLayer + +ps_ymt3 = PitchShiftLayer(pshift_range=[2, 2], fs=16000, 
min_gcd=16, n_fft=2048).to(device) +x_ymt3 = ps_ymt3(x, 2) + +############################################################################################ +# Plot 1: Comparison of Process Time and GPU Memory Usage for 3 Pitch Shifting Methods +############################################################################################ +import matplotlib.pyplot as plt + +# Model names +models = ['torch-audiomentation', 'torchaudio', 'YourMT3:PitchShiftLayer'] + +# Process time (CPU) in seconds +cpu_time = [1.18, 4.01, 0.389] + +# Process time (GPU) in milliseconds +gpu_time = [58, 25.1, 7.18] + +# GPU memory usage in GB +gpu_memory = [3.8, 5.49, 0.016] + +# Creating subplots +fig, axs = plt.subplots(1, 3, figsize=(15, 5)) + +# Creating bar charts +bar1 = axs[0].bar(models, cpu_time, color=['#FFB6C1', '#ADD8E6', '#98FB98']) +bar2 = axs[1].bar(models, gpu_time, color=['#FFB6C1', '#ADD8E6', '#98FB98']) +bar3 = axs[2].bar(models, gpu_memory, color=['#FFB6C1', '#ADD8E6', '#98FB98']) + +# Adding labels and titles +axs[0].set_ylabel('Time (s)') +axs[0].set_title('Process Time (CPU) bsz=80') +axs[1].set_ylabel('Time (ms)') +axs[1].set_title('Process Time (GPU) bsz=80') +axs[2].set_ylabel('Memory (GB)') +axs[2].set_title('GPU Memory Usage per semitone') + +# Adding grid for better readability of the plots +for ax in axs: + ax.grid(axis='y') + ax.set_yscale('log') + ax.set_xticklabels(models, rotation=45, ha="right") + +# Adding text labels above the bars +for i, rect in enumerate(bar1): + axs[0].text( + rect.get_x() + rect.get_width() / 2, + rect.get_height(), + f'{cpu_time[i]:.2f} s', + ha='center', + va='bottom') +for i, rect in enumerate(bar2): + axs[1].text( + rect.get_x() + rect.get_width() / 2, + rect.get_height(), + f'{gpu_time[i]:.2f} ms', + ha='center', + va='bottom') +for i, rect in enumerate(bar3): + axs[2].text( + rect.get_x() + rect.get_width() / 2, + rect.get_height(), + f'{gpu_memory[i]:.3f} GB', + ha='center', + va='bottom') +plt.tight_layout() +plt.show() + +############################################################################################ +# Plot 2: Stretch and Resampler Processing Time Contribution +############################################################################################ +# Data +processing_type = ['Stretch (Phase Vocoder)', 'Resampler (Conv1D)'] +cpu_times = [143, 245] # [Stretch, Resampler] times for CPU in milliseconds +gpu_times = [6.47, 0.71] # [Stretch, Resampler] times for GPU in milliseconds + +# Creating subplots +fig, axs = plt.subplots(1, 2, figsize=(12, 6)) + +# Plotting bar charts +axs[0].bar(processing_type, cpu_times, color=['#ADD8E6', '#98FB98']) +axs[1].bar(processing_type, gpu_times, color=['#ADD8E6', '#98FB98']) + +# Adding labels and titles +axs[0].set_ylabel('Time (ms)') +axs[0].set_title('Contribution of CPU Processing Time: YMT3-PS (BSZ=80)') +axs[1].set_title('Contribution of GPU Processing Time: YMT3-PS (BSZ=80)') + +# Adding grid for better readability of the plots +for ax in axs: + ax.grid(axis='y') + ax.set_yscale('log') # Log scale to better visualize the smaller values + +# Adding values on top of the bars +for ax, times in zip(axs, [cpu_times, gpu_times]): + for idx, time in enumerate(times): + ax.text(idx, time, f"{time:.2f} ms", ha='center', va='bottom', fontsize=8) +plt.tight_layout() +plt.show() diff --git a/amt/src/extras/remove_silence_musicnet_midi.py b/amt/src/extras/remove_silence_musicnet_midi.py new file mode 100644 index 0000000000000000000000000000000000000000..096816d80eed31b3745506f6a560b4743312ceeb --- 
/dev/null +++ b/amt/src/extras/remove_silence_musicnet_midi.py @@ -0,0 +1,32 @@ +import os +import glob + +from utils.midi import midi2note +from utils.note2event import note2note_event +from utils.note_event_dataclasses import Note +from utils.note_event_dataclasses import NoteEvent +from utils.midi import note_event2midi + +data_home = '../../data' +dataset_name = 'musicnet' +base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') +mid_pattern = os.path.join(base_dir, '*_midi', '*.mid') +mid_files = glob.glob(mid_pattern, recursive=True) + +for mid_file in mid_files: + notes, _ = midi2note(mid_file) + first_onset_time = notes[0].onset + fixed_notes = [] + for note in notes: + fixed_notes.append( + Note( + is_drum=note.is_drum, + program=note.program, + onset=note.onset - first_onset_time, + offset=note.offset - first_onset_time, + pitch=note.pitch, + velocity=note.velocity)) + assert len(notes) == len(fixed_notes) + fixed_note_events = note2note_event(fixed_notes, return_activity=False) + note_event2midi(fixed_note_events, mid_file) + print(f'Overwriting {mid_file}') diff --git a/amt/src/extras/rotary_positional_embedding.py b/amt/src/extras/rotary_positional_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..30afd292c21b9e214876bfdc20f568c5fa90dcd4 --- /dev/null +++ b/amt/src/extras/rotary_positional_embedding.py @@ -0,0 +1,191 @@ +"""rotary_positional_embedding.py - Rotary Positional Embedding + +code from github.com/lucidrains/rotary-embedding-torch + +MIT License +""" + +from math import pi, log +import torch +from torch import nn, einsum +from einops import rearrange, repeat + + +def exists(val): + return val is not None + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions' + shape_len = list(shape_lens)[0] + + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims) + ]), 'invalid dimensions for broadcastable concatentation' + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +# rotary embedding helper functions +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, '... d r -> ... 
(d r)') + + +def apply_rotary_emb(freqs, t, start_index=0, scale=1.): + rot_dim, seq_len = freqs.shape[-1], t.shape[-2] + freqs = freqs[-seq_len:, :] + + freqs = freqs.to(t) + end_index = start_index + rot_dim + assert rot_dim <= t.shape[ + -1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + return torch.cat((t_left, t, t_right), dim=-1) + + +# learned rotation helpers +def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None): + if exists(freq_ranges): + rotations = einsum('..., f -> ... f', rotations, freq_ranges) + rotations = rearrange(rotations, '... r f -> ... (r f)') + + rotations = repeat(rotations, '... n -> ... (n r)', r=2) + return apply_rotary_emb(rotations, t, start_index=start_index) + + +# classes +class RotaryEmbedding(nn.Module): + + def __init__(self, + dim, + custom_freqs=None, + freqs_for='lang', + theta=10000, + max_freq=10, + num_freqs=1, + learned_freq=False, + use_xpos=False, + xpos_scale_base=512, + interpolate_factor=1., + theta_rescale_factor=1.): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + theta *= theta_rescale_factor**(dim / (dim - 2)) + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f'unknown modality {freqs_for}') + + self.cache = dict() + self.cache_scale = dict() + self.freqs = nn.Parameter(freqs, requires_grad=learned_freq) + + # interpolation factors + + assert interpolate_factor >= 1. 
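+        # Position interpolation: get_seq_pos() divides positions by this factor,
+        # so e.g. interpolate_factor=2. makes a sequence of length 2N reuse the
+        # rotary angles of a length-N sequence (a common way to extend context).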
+ self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + if not use_xpos: + self.register_buffer('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + self.register_buffer('scale', scale) + + def get_seq_pos(self, seq_len, device, dtype, offset=0): + return (torch.arange(seq_len, device=device, dtype=dtype) + + offset) / self.interpolate_factor + + def rotate_queries_or_keys(self, t, seq_dim=-2, offset=0, freq_seq_len=None): + assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + if exists(freq_seq_len): + assert freq_seq_len >= seq_len + seq_len = freq_seq_len + + freqs = self.forward( + lambda: self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset), + cache_key=f'freqs:{seq_len}|offset:{offset}') + return apply_rotary_emb(freqs, t) + + def rotate_queries_with_cached_keys(self, q, k, seq_dim=-2): + q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] + assert q_len <= k_len + q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, freq_seq_len=k_len) + k = self.rotate_queries_or_keys(k, seq_dim=seq_dim) + return q, k + + def rotate_queries_and_keys(self, q, k, seq_dim=-2): + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + seq = self.get_seq_pos(seq_len, dtype=dtype, device=device) + freqs = self.forward(lambda: seq, cache_key=f'freqs:{seq_len}') + scale = self.get_scale(lambda: seq, cache_key=f'scale:{seq_len}').to(dtype) + rotated_q = apply_rotary_emb(freqs, q, scale=scale) + rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1) + return rotated_q, rotated_k + + def get_scale(self, t, cache_key=None): + assert self.use_xpos + + if exists(cache_key) and cache_key in self.cache: + return self.cache[cache_key] + + if callable(t): + t = t() + + scale = 1. + if self.use_xpos: + power = (t - len(t) // 2) / self.scale_base + scale = self.scale**rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim=-1) + + if exists(cache_key): + self.cache[cache_key] = scale + + return scale + + def forward(self, t, cache_key=None): + if exists(cache_key) and cache_key in self.cache: + return self.cache[cache_key] + + if callable(t): + t = t() + + freqs = self.freqs + + freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = repeat(freqs, '... n -> ... 
(n r)', r=2) + + if exists(cache_key): + self.cache[cache_key] = freqs + + return freqs \ No newline at end of file diff --git a/amt/src/extras/run_spleeter_mir1k.sh b/amt/src/extras/run_spleeter_mir1k.sh new file mode 100644 index 0000000000000000000000000000000000000000..8352d5cffdcfe1e7e45c7d772790026d86557420 --- /dev/null +++ b/amt/src/extras/run_spleeter_mir1k.sh @@ -0,0 +1,17 @@ +#!/bin/bash +shopt -s globstar +for file in "$1"/**/*.wav; do + echo $file + output_dir="tmp" + spleeter separate -b 256k -B tensorflow -p spleeter:2stems -o $output_dir $file -f {instrument}.{codec} + sox --ignore-length tmp/accompaniment.wav -r 16000 -c 1 -b 16 tmp/accompaniment_16k.wav + sox --ignore-length tmp/vocals.wav -r 16000 -c 1 -b 16 tmp/vocals_16k.wav + acc_file="${file//.wav/_accompaniment.wav}" + voc_file="${file//.wav/_vocals.wav}" + mv -f "tmp/accompaniment_16k.wav" $acc_file + mv -f "tmp/vocals_16k.wav" $voc_file + echo $acc_file + echo $voc_file + rm -rf tmp +done +rm -rf pretrained_models \ No newline at end of file diff --git a/amt/src/extras/run_spleeter_mirst500.sh b/amt/src/extras/run_spleeter_mirst500.sh new file mode 100644 index 0000000000000000000000000000000000000000..bf155454a2837733f7ac8d3c6690eafdf18270f2 --- /dev/null +++ b/amt/src/extras/run_spleeter_mirst500.sh @@ -0,0 +1,13 @@ +#!/bin/bash +shopt -s globstar +for file in "$1"/**/*.wav; do + output_dir="${file%/*}" + input_file="$output_dir/converted_Mixture.wav" + spleeter separate -p spleeter:2stems -o $output_dir $input_file -f {instrument}.{codec} + ffmpeg -i "$output_dir/vocals.wav" -acodec pcm_s16le -ac 1 -ar 16000 -y "$output_dir/vocals_16k.wav" + ffmpeg -i "$output_dir/accompaniment.wav" -acodec pcm_s16le -ac 1 -ar 16000 -y "$output_dir/accompaniment_16k.wav" + rm "$output_dir/vocals.wav" + rm "$output_dir/accompaniment.wav" + mv "$output_dir/vocals_16k.wav" "$output_dir/vocals.wav" + mv "$output_dir/accompaniment_16k.wav" "$output_dir/accompaniment.wav" +done diff --git a/amt/src/extras/run_spleeter_mirst500_cmedia.sh b/amt/src/extras/run_spleeter_mirst500_cmedia.sh new file mode 100644 index 0000000000000000000000000000000000000000..bf155454a2837733f7ac8d3c6690eafdf18270f2 --- /dev/null +++ b/amt/src/extras/run_spleeter_mirst500_cmedia.sh @@ -0,0 +1,13 @@ +#!/bin/bash +shopt -s globstar +for file in "$1"/**/*.wav; do + output_dir="${file%/*}" + input_file="$output_dir/converted_Mixture.wav" + spleeter separate -p spleeter:2stems -o $output_dir $input_file -f {instrument}.{codec} + ffmpeg -i "$output_dir/vocals.wav" -acodec pcm_s16le -ac 1 -ar 16000 -y "$output_dir/vocals_16k.wav" + ffmpeg -i "$output_dir/accompaniment.wav" -acodec pcm_s16le -ac 1 -ar 16000 -y "$output_dir/accompaniment_16k.wav" + rm "$output_dir/vocals.wav" + rm "$output_dir/accompaniment.wav" + mv "$output_dir/vocals_16k.wav" "$output_dir/vocals.wav" + mv "$output_dir/accompaniment_16k.wav" "$output_dir/accompaniment.wav" +done diff --git a/amt/src/extras/swap_channel.py b/amt/src/extras/swap_channel.py new file mode 100644 index 0000000000000000000000000000000000000000..39c71a6765c91c633ffaf423ebb3217f5334a946 --- /dev/null +++ b/amt/src/extras/swap_channel.py @@ -0,0 +1,122 @@ +import numpy as np + +a = np.arange(12).reshape(2, 3, 2) # (batch, channel, dim) +print(a) +array([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]]) + +swap_mat = create_swap_channel_mat(input_shape, swap_channel=(1, 2)) + +# will swap channel 1 and 2 of batch 0 with channel 1 and 2 of batch 1 +b = a @ swap_mat +print(b) +# expected output +array([[[0, 1], [8, 9], 
[10, 11]], [[6, 7], [2, 3], [4, 5]]])
+
+import torch
+
+
+def swap_channels_between_batches(a_tensor, swap_channels):
+    # Copy the tensor to avoid modifying the original tensor
+    result_tensor = a_tensor.clone()
+
+    # Unpack the channels to be swapped
+    ch1, ch2 = swap_channels
+
+    # Swap the specified channels between batches
+    result_tensor[0, ch1, :], result_tensor[1, ch1, :] = a_tensor[1, ch1, :].clone(), a_tensor[0, ch1, :].clone()
+    result_tensor[0, ch2, :], result_tensor[1, ch2, :] = a_tensor[1, ch2, :].clone(), a_tensor[0, ch2, :].clone()
+
+    return result_tensor
+
+
+# Define a sample tensor 'a_tensor'
+a_tensor = torch.tensor([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], dtype=torch.float32)
+
+# Define channels to swap
+swap_channels = (1, 2)  # Channels to swap between batches
+
+# Swap the channels between batches
+swapped_tensor = swap_channels_between_batches(a_tensor, swap_channels)
+
+# Print the original tensor and the tensor after swapping channels between batches
+print("Original Tensor 'a_tensor':")
+print(a_tensor)
+print("\nTensor after swapping channels between batches:")
+print(swapped_tensor)
+
+#-------------------------------------------------
+
+import torch
+from einops import rearrange
+
+
+def shift(arr, num, fill_value=np.nan):
+    result = np.empty_like(arr)
+    if num > 0:
+        result[:num] = fill_value
+        result[num:] = arr[:-num]
+    elif num < 0:
+        result[num:] = fill_value
+        result[:num] = arr[-num:]
+    else:
+        result[:] = arr
+    return result
+
+
+def create_batch_swap_matrix(batch_size, channels, swap_channels):
+    swap_mat = np.eye(batch_size * channels)
+
+    for c in swap_channels:
+        idx1 = c  # channel index to swap in the first batch item
+        idx2 = c + channels  # channel index to swap in the second batch item
+
+        swap_mat[idx1, idx1], swap_mat[idx2, idx2] = 0, 0  # zero out the diagonal entries
+        swap_mat[idx1, idx2], swap_mat[idx2, idx1] = 1, 1  # swap the corresponding channels
+    return swap_mat
+
+
+def create_batch_swap_matrix(batch_size, channels, swap_channels):
+    swap_mat = np.eye(batch_size * channels)
+
+    # perform the swap for every requested channel
+    for c in swap_channels:
+        idx1 = np.arange(c, batch_size * channels, channels)  # indices of this channel in every batch item
+        idx2 = (idx1 + channels) % (batch_size * channels)  # partner indices, wrapped around with modulo
+
+        swap_mat[idx1, idx1] = 0
+        swap_mat[idx2, idx2] = 0
+        swap_mat[idx1, idx2] = 1
+        swap_mat[idx2, idx1] = 1
+
+    return swap_mat
+
+
+def swap_channels_between_batches(input_tensor, swap_matrix):
+    reshaped_tensor = rearrange(input_tensor, 'b c d -> (b c) d')
+    swapped_tensor = swap_matrix @ reshaped_tensor
+    return rearrange(swapped_tensor, '(b c) d -> b c d', b=input_tensor.shape[0])
+
+
+# Example parameters
+batch_size = 2
+channels = 3
+# swap_info = {
+#     : [1, 2]  # batch_index: [channel_indices]
+# }
+swap_channels = [1, 2]  # channels to swap
+
+# Create an example tensor
+input_tensor = torch.tensor([[[0, 1], [2, 3], [4, 5]], [[6, 7], [8, 9], [10, 11]]], dtype=torch.float32)
+
+# Create the swap matrix
+swap_matrix = create_batch_swap_matrix(batch_size, channels, swap_channels)
+swap_matrix = torch.Tensor(swap_matrix)
+
+# Perform the channel swap
+swapped_tensor = swap_channels_between_batches(input_tensor, swap_matrix)
+
+# Print the results
+print("Original Tensor:")
+print(input_tensor)
+print("\nSwapped Tensor:")
+print(swapped_tensor)
diff --git a/amt/src/extras/t5_dev.py b/amt/src/extras/t5_dev.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6d9a842ab729853bd370f75a3934172bb47e2b8
--- /dev/null
+++ b/amt/src/extras/t5_dev.py
@@ -0,0 +1,41 @@
+import torch
+from transformers import T5Config
+from model.t5mod import T5ForConditionalGeneration
+
+a = {
+    "architectures": 
["T5ForConditionalGeneration"], + "d_ff": 1024, # size of the intermediate feed forward layer in each T5Block + "d_kv": 64, # d_kv has to be equal to d_model // num_heads. + # "d_model": 512, # encoder hiddnen size, defined by model_cfg + "decoder_start_token_id": 0, + "dense_act_fn": "gelu_new", + # "dropout_rate": 0.05, # can be overwritten by args in ymt3 + "eos_token_id": 1, + "feed_forward_proj": "gated-gelu", + "initializer_factor": 1.0, + # "is_encoder_decoder": True, + "is_gated_act": True, + "layer_norm_epsilon": 1e-06, + "model_type": "t5", + # "num_decoder_layers": 8, + "num_heads": 6, + "num_layers": 8, + "output_past": True, + "pad_token_id": 0, + "relative_attention_num_buckets": 32, + "use_cache": True, + "vocab_size": 1391 # vocab_size is automatically set by the task manager... +} +cfg = T5Config(**a) +cfg.num_decoder_layers = 4 +cfg.num_layers = 0 + +model = T5ForConditionalGeneration(cfg) +print(model) + +x = torch.rand(((2, 256, 512))) +out = model.encoder.forward(inputs_embeds=x) + +enc_hs = torch.rand((2, 256, 512)) +labels = torch.randint(0, 1391, (2, 256)) +pred = model(encoder_outputs=(enc_hs,), labels=labels) # important (enc_hs,) comma! diff --git a/amt/src/extras/t5perceiver.py b/amt/src/extras/t5perceiver.py new file mode 100644 index 0000000000000000000000000000000000000000..998cb320bf1016c0933ee78e45b17c420a434138 --- /dev/null +++ b/amt/src/extras/t5perceiver.py @@ -0,0 +1,443 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +""" Bare wrapper of HF PyTorch T5 and Perceiver with the following modifications: +- PerceiverTF encoder +- ResConv pre-encoder +- Projection layers for dynamic dimension matching +- Sinusoidal absolute positional embeddings +- Positional embeddings from Perceiver implementation +- Task conditioning on encoder and decoder by input tokens +""" +import copy +import warnings +from typing import Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from torch.utils.checkpoint import checkpoint + +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.modeling_utils import PreTrainedModel +from transformers.models.t5.modeling_t5 import (T5LayerNorm, T5Block, PARALLELIZE_DOCSTRING, DEPARALLELIZE_DOCSTRING, + T5_START_DOCSTRING, T5_INPUTS_DOCSTRING, _CONFIG_FOR_DOC, + __HEAD_MASK_WARNING_MSG) +from transformers.modeling_outputs import (Seq2SeqLMOutput, BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions) +from transformers import T5Config #, T5PreTrainedModel +from model.ops import FixedSinusoidalPositionalEmbedding + +# additional imports +from model.t5mod import T5Stack +from transformers.models.t5.modeling_t5 import (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5DenseActDense, + T5DenseGatedActDense, T5Attention, load_tf_weights_in_t5, + is_torch_fx_proxy) + +from transformers.utils import (DUMMY_INPUTS, DUMMY_MASK) + +logger = logging.get_logger(__name__) + + +class T5PerceiverPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = None + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["T5Block"] + _keep_in_fp32_modules = ["wo"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def _init_weights(self, module): + """Initialize the weights""" + factor = self.config.initializer_factor # Used for testing weights initialization + if isinstance(module, T5LayerNorm): + module.weight.data.fill_(factor * 1.0) + elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)): + # Mesh TensorFlow embeddings initialization + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624 + module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0) + if hasattr(module, "lm_head") and not self.config.tie_word_embeddings: + module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0) + elif isinstance(module, T5DenseActDense): + # Mesh TensorFlow FF initialization + # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56 + # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89 + module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model)**-0.5)) + if hasattr(module.wi, "bias") and module.wi.bias is not None: + module.wi.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff)**-0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5DenseGatedActDense): + module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model)**-0.5)) + if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None: + module.wi_0.bias.data.zero_() + module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model)**-0.5)) + if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None: + module.wi_1.bias.data.zero_() + module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff)**-0.5)) + if hasattr(module.wo, "bias") and module.wo.bias is not None: + module.wo.bias.data.zero_() + elif isinstance(module, T5Attention): + # Mesh TensorFlow attention initialization to avoid scaling before softmax + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136 + d_model = self.config.d_model + key_value_proj_dim = self.config.d_kv + n_heads = self.config.num_heads + module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim)**-0.5)) + module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5)) + module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim)**-0.5)) + if module.has_relative_attention_bias: + module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model)**-0.5)) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (T5Attention, T5Stack)): + module.gradient_checkpointing = value + + def _shift_right(self, input_ids): + decoder_start_token_id = self.config.decoder_start_token_id + pad_token_id = 
self.config.pad_token_id + + assert decoder_start_token_id is not None, ( + "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id." + " See T5 docs for more information") + + # shift inputs to the right + if is_torch_fx_proxy(input_ids): + # Item assignment is not supported natively for proxies. + shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id) + shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1) + else: + shifted_input_ids = input_ids.new_zeros(input_ids.shape) + shifted_input_ids[..., 1:] = input_ids[..., :-1].clone() + shifted_input_ids[..., 0] = decoder_start_token_id + + assert pad_token_id is not None, "self.model.config.pad_token_id has to be defined." + # replace possible -100 values in labels by `pad_token_id` + shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) + + return shifted_input_ids + + +class T5PerceiverForConditionalGeneration(T5PerceiverPreTrainedModel): + config_class = None + load_tf_weights = load_tf_weights_in_t5 + base_model_prefix = "transformer" + is_parallelizable = True + supports_gradient_checkpointing = True + _no_split_modules = ["T5Block"] + _keep_in_fp32_modules = ["wo"] + + @property + def dummy_inputs(self): + input_ids = torch.tensor(DUMMY_INPUTS) + input_mask = torch.tensor(DUMMY_MASK) + dummy_inputs = { + "decoder_input_ids": input_ids, + "input_ids": input_ids, + "decoder_attention_mask": input_mask, + } + return dummy_inputs + + def __init__( + self, + model_cfg: dict, + # config: T5Config, + # use_fixed_absolute_pe: bool = True, + # num_max_positions: int = 1025 + ): + super().__init__(config) + self.model_dim = config.d_model + """ mod: absolute position embedding """ + self.use_fixed_absolute_pe = use_fixed_absolute_pe + + self.shared = nn.Embedding(config.vocab_size, config.d_model) + + encoder_config = copy.deepcopy(config) + encoder_config.is_decoder = False + encoder_config.use_cache = False + encoder_config.is_encoder_decoder = False + self.encoder = T5Stack(encoder_config, + self.shared, + use_fixed_absolute_pe=use_fixed_absolute_pe, + num_max_positions=num_max_positions) + + decoder_config = copy.deepcopy(config) + decoder_config.is_decoder = True + decoder_config.is_encoder_decoder = False + decoder_config.num_layers = config.num_decoder_layers + self.decoder = T5Stack(decoder_config, + self.shared, + use_fixed_absolute_pe=use_fixed_absolute_pe, + num_max_positions=num_max_positions) + + self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, new_embeddings): + self.shared = new_embeddings + self.encoder.set_input_embeddings(new_embeddings) + self.decoder.set_input_embeddings(new_embeddings) + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_output_embeddings(self): + return self.lm_head + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.BoolTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + decoder_head_mask: 
Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ..., + config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for + labels in `[0, ..., config.vocab_size]` + + Returns: + + Examples: + + ```python + >>> from transformers import AutoTokenizer, T5ForConditionalGeneration + + >>> tokenizer = AutoTokenizer.from_pretrained("t5-small") + >>> model = T5ForConditionalGeneration.from_pretrained("t5-small") + + >>> # training + >>> input_ids = tokenizer("The walks in park", return_tensors="pt").input_ids + >>> labels = tokenizer(" cute dog the ", return_tensors="pt").input_ids + >>> outputs = model(input_ids=input_ids, labels=labels) + >>> loss = outputs.loss + >>> logits = outputs.logits + + >>> # inference + >>> input_ids = tokenizer( + ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt" + ... ).input_ids # Batch size 1 + >>> outputs = model.generate(input_ids) + >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True)) + >>> # studies have shown that owning a dog is good for you. 
+ ```""" + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask + if head_mask is not None and decoder_head_mask is None: + if self.config.num_layers == self.config.num_decoder_layers: + warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning) + decoder_head_mask = head_mask + + # Encode if needed (training, first prediction pass) + if encoder_outputs is None: + # Convert encoder inputs in embeddings if needed + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + hidden_states = encoder_outputs[0] + + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + + if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None: + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.decoder.first_device) + hidden_states = hidden_states.to(self.decoder.first_device) + if decoder_input_ids is not None: + decoder_input_ids = decoder_input_ids.to(self.decoder.first_device) + if attention_mask is not None: + attention_mask = attention_mask.to(self.decoder.first_device) + if decoder_attention_mask is not None: + decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device) + + # Decode + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + inputs_embeds=decoder_inputs_embeds, + past_key_values=past_key_values, + encoder_hidden_states=hidden_states, + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + sequence_output = decoder_outputs[0] + + # Set device for model parallelism + if self.model_parallel: + torch.cuda.set_device(self.encoder.first_device) + self.lm_head = self.lm_head.to(self.encoder.first_device) + sequence_output = sequence_output.to(self.lm_head.weight.device) + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + sequence_output = sequence_output * (self.model_dim**-0.5) + + lm_logits = self.lm_head(sequence_output) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss(ignore_index=-100) + # move labels to correct device to enable PP + labels = labels.to(lm_logits.device) + loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1)) + # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666 + + if not return_dict: + output = (lm_logits,) + decoder_outputs[1:] 
+ encoder_outputs + return ((loss,) + output) if loss is not None else output + + return Seq2SeqLMOutput( + loss=loss, + logits=lm_logits, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past_key_values is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past_key_values, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + } + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + def _reorder_cache(self, past_key_values, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past_key_values is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past_key_values + + reordered_decoder_past = () + for layer_past_states in past_key_values: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + (layer_past_state.index_select( + 0, beam_idx.to(layer_past_state.device)),) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + +from transformers import PreTrainedModel, PretrainedConfig +from transformers import AutoModel, AutoConfig + + +class MyConfig(T5Config, PerceiverConfig): + model_type = 'mymodel' + + def __init__(self, important_param=42, **kwargs): + super().__init__(**kwargs) + self.important_param = important_param diff --git a/amt/src/extras/unimax_sampler/README.md b/amt/src/extras/unimax_sampler/README.md new file mode 100644 index 0000000000000000000000000000000000000000..02397c64c85839a8bb4f3edc3bc1deb0f186c029 --- /dev/null +++ b/amt/src/extras/unimax_sampler/README.md @@ -0,0 +1,45 @@ +# UniMax Language Dataset Sampler with DDP support + +This repository contains an unofficial implementation of the UNIMAX sampling algorithm using PyTorch. The UNIMAX algorithm ["UniMax: Fairer and more Effective Language Sampling for Large-Scale Multilingual Pretraining" by HW Chung et al. (ICLR 2023)](https://arxiv.org/abs/2304.09151) is used to generate a sampling distribution of languages based on their character counts, a total character budget, and a specified number of epochs per language. This can be useful for training language models on datasets with imbalanced language distribution. + +## Contents + +1. 
`unimax_sampler.py`: This Python file contains the `UnimaxSampler` class, a PyTorch `Sampler` that uses the UNIMAX algorithm. + +2. `test_unimax_sampler.py`: This Python file contains a unit test for the `UnimaxSampler` class to ensure its correct functionality. + +## Usage + +```python +from torch.utils.data import Dataset, DataLoader +from unimax_sampler import UnimaxSampler + +# Define your parameters +language_character_counts = [100, 200, 300, 400, 500] +total_character_budget = 1000 +num_epochs = 2 + +# Create the UnimaxSampler +unimax_sampler = UnimaxSampler(language_character_counts, total_character_budget, num_epochs) +``` + +Then, use the sampler as the sampler argument when creating a DataLoader. + +```python +# Disable shuffle when using custom sampler... +data_loader = DataLoader(my_dataset, batch_size=2, shuffle=None, sampler=unimax_sampler) +``` + +For DDP, +```python +if torch.distributed.is_initialized(): + sampler = DistributedUnimaxSampler(...) +else: + return unimax_sampler(...) +``` + +## Note +The initial version of this code was created by [Chat GPT-4](https://chat.openai.com/), based on the pseudocode provided in the [UNIMAX](https://arxiv.org/abs/2304.09151) paper. Subsequently, the code was manually revised for `PyTorch` Distributed Data Parallel ([DDP](https://pytorch.org/docs/stable/notes/ddp.html)) framework. The DistributedSamplerWrapper implementation is derived from an earlier version found in the [Catalyst](https://github.com/catalyst-team/catalyst) project. + +## License +This project is licensed under the MIT License. \ No newline at end of file diff --git a/amt/src/extras/unimax_sampler/demo.py b/amt/src/extras/unimax_sampler/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..6480ac8d702e2c7b950dd10b906d4f2c9d76beac --- /dev/null +++ b/amt/src/extras/unimax_sampler/demo.py @@ -0,0 +1,15 @@ +from utils.unimax_sampler.unimax_sampler import UnimaxSampler + +language_character_counts = [100, 200, 300, 400, 500] +total_character_budget = 1000 +num_epochs = 2 + +# Create the UnimaxSampler. +sampler = UnimaxSampler(language_character_counts, total_character_budget, num_epochs) + +# Define the expected output. This will depend on your specific implementation of Unimax. +expected_output = torch.tensor([0.1, 0.2, 0.3, 0.2, 0.2]) + +# Use PyTorch's allclose function to compare the computed and expected outputs. +# The absolute tolerance parameter atol specifies the maximum difference allowed for the test to pass. +self.assertTrue(torch.allclose(sampler.p, expected_output, atol=1e-6)) \ No newline at end of file diff --git a/amt/src/extras/unimax_sampler/unimax_sampler.py b/amt/src/extras/unimax_sampler/unimax_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..fd06ba1fd314f90e5770b1a359b8509080665c4c --- /dev/null +++ b/amt/src/extras/unimax_sampler/unimax_sampler.py @@ -0,0 +1,168 @@ +import torch +from torch.utils.data import DistributedSampler +from torch.utils.data import Dataset, Sampler +from torch.utils.data import RandomSampler +from operator import itemgetter +from typing import List, Union, Iterator, Optional + + +class DatasetFromSampler(Dataset): + """Dataset to create indexes from `Sampler`. From catalyst library. + + Args: + sampler: PyTorch sampler + """ + + def __init__(self, sampler: Sampler): + """Initialisation for DatasetFromSampler.""" + self.sampler = sampler + self.sampler_list = None + + def __getitem__(self, index: int): + """Gets element of the dataset. 
+ + Args: + index: index of the element in the dataset + + Returns: + Single element by index + """ + if self.sampler_list is None: + self.sampler_list = list(self.sampler) + return self.sampler_list[index] + + def __len__(self) -> int: + """ + Returns: + int: length of the dataset + """ + return len(self.sampler) + + +class DistributedSamplerWrapper(DistributedSampler): + """ + Wrapper over `Sampler` for distributed training. + Allows you to use any sampler in distributed mode. + From https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py + + It is especially useful in conjunction with + `torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSamplerWrapper instance as a DataLoader + sampler, and load a subset of subsampled data of the original dataset + that is exclusive to it. + + .. note:: + Sampler is assumed to be of constant size. + """ + + def __init__( + self, + sampler, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + ): + """ + + Args: + sampler: Sampler used for subsampling + num_replicas (int, optional): Number of processes participating in + distributed training + rank (int, optional): Rank of the current process + within ``num_replicas`` + shuffle (bool, optional): If true (default), + sampler will shuffle the indices + """ + super(DistributedSamplerWrapper, self).__init__( + DatasetFromSampler(sampler), + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) + self.sampler = sampler + + def __iter__(self) -> Iterator[int]: + """Iterate over sampler. + + Returns: + python iterator + """ + self.dataset = DatasetFromSampler(self.sampler) + indexes_of_indexes = super().__iter__() + subsampler_indexes = self.dataset + return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) + + +class UnimaxSampler(Sampler): + # Initialize the sampler with the character counts for each language, + # the total character budget, and the number of epochs per language. + def __init__(self, language_character_counts: List[int], total_character_budget: int, + num_epochs: int) -> None: + self.language_character_counts = torch.tensor(language_character_counts) + self.total_character_budget = total_character_budget + self.num_epochs = num_epochs + # Compute the sampling distribution p. + self.p = self._unimax() + + # Define how to iterate over the data. We'll use PyTorch's multinomial + # function to generate indices according to the distribution p. + def __iter__(self) -> iter: + return iter(torch.multinomial(self.p, len(self.p), replacement=True).tolist()) + + # Define the length of the sampler as the number of languages. + def __len__(self) -> int: + return len(self.p) + + # Implement the UNIMAX algorithm to compute the sampling distribution p. + def _unimax(self) -> torch.Tensor: + # Sort languages by character count. + L, indices = torch.sort(self.language_character_counts) + # Initialize the remaining budget to the total character budget. + B = float(self.total_character_budget) + i = 0 + # Initialize the budget per language. + U = torch.zeros_like(L) + # For each language... + for idx in indices: + # Compute the remaining budget per-language. + bl = B / (len(L) - i) + cl = L[idx] + # If per-language budget exceeds N epochs of the language, use N epochs. + if bl > cl * self.num_epochs: + Ul = cl * self.num_epochs + # Otherwise use uniform per-language budget. + else: + Ul = bl + # Store the computed budget. + U[idx] = Ul + # Update the remaining budget. 
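+ # (removing the budget consumed by this language means the languages processed
+ # later, which have larger character counts, split whatever budget remains)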
+ B -= Ul + # Move to the next language. + i += 1 + # Normalize the budget to create a distribution. + p = U / U.sum() + # Return the computed distribution. + return p + + +class DistributedUnimaxSampler(UnimaxSampler): + + def __init__(self, + language_character_counts: List[int], + total_character_budget: int, + num_epochs: int, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True) -> None: + + super().__init__(language_character_counts, total_character_budget, num_epochs) + self.distributed_sampler = DistributedSamplerWrapper(self, num_replicas, rank, shuffle) + + def __iter__(self): + return iter(self.distributed_sampler) + + def __len__(self): + return len(self.distributed_sampler) + + def set_epoch(self, epoch): + self.distributed_sampler.set_epoch(epoch) \ No newline at end of file diff --git a/amt/src/install_dataset.py b/amt/src/install_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a1f8f25ef342dde659c9c358bb8493a15592474b --- /dev/null +++ b/amt/src/install_dataset.py @@ -0,0 +1,285 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +""" install_dataset.py """ +import os +import argparse +import mirdata +from typing import Optional, Tuple, Union +from utils.preprocess.generate_dataset_stats import generate_dataset_stats_for_all_datasets, update_dataset_stats_for_new_dataset +from utils.mirdata_dev.datasets import slakh16k +from utils.preprocess.preprocess_slakh import preprocess_slakh16k, add_program_and_is_drum_info_to_file_list +from utils.preprocess.preprocess_musicnet import preprocess_musicnet16k +from utils.preprocess.preprocess_maps import preprocess_maps16k +from utils.preprocess.preprocess_maestro import preprocess_maestro16k +from utils.preprocess.preprocess_guitarset import preprocess_guitarset16k, create_filelist_by_style_guitarset16k +from utils.preprocess.preprocess_enstdrums import preprocess_enstdrums16k, create_filelist_dtm_random_enstdrums16k +from utils.preprocess.preprocess_mir_st500 import preprocess_mir_st500_16k +from utils.preprocess.preprocess_cmedia import preprocess_cmedia_16k +from utils.preprocess.preprocess_rwc_pop_full import preprocess_rwc_pop_full16k +from utils.preprocess.preprocess_rwc_pop import preprocess_rwc_pop16k +from utils.preprocess.preprocess_egmd import preprocess_egmd16k +from utils.preprocess.preprocess_mir1k import preprocess_mir1k_16k +from utils.preprocess.preprocess_urmp import preprocess_urmp16k +from utils.preprocess.preprocess_idmt_smt_bass import preprocess_idmt_smt_bass_16k +from utils.preprocess.preprocess_geerdes import preprocess_geerdes16k +from utils.utils import download_and_extract #, download_and_extract_zenodo_restricted + +# zenodo_token = "eyJhbGciOiJIUzUxMiIsImlhdCI6MTcxMDE1MDYzNywiZXhwIjoxNzEyNzA3MTk5fQ.eyJpZCI6ImRmODA5NzZlLTBjM2QtNDk5NS05YjM0LWFiNGM4NzJhMmZhMSIsImRhdGEiOnt9LCJyYW5kb20iOiIwMzY5ZDcxZjc2NTMyN2UyYmVmN2ExYjJkMmMyYTRhNSJ9.0aHnNC-7ivWQO6l8twjLR0NDH4boC0uOolAAmogVt7XRi2PHU5MEKBQoK7-wgDdnmWEIqEIvoLO6p8KTnsY9dg" + + +def install_slakh(data_home=os.PathLike, no_down=False) -> None: + if not no_down: + ds = slakh16k.Dataset(data_home, version='2100-yourmt3-16k') + ds.download(partial_download=['2100-yourmt3-16k', 'index']) + del (ds) + preprocess_slakh16k(data_home, 
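+ # keep the downloaded source stems (delete_source_files=False) and apply the
+ # bass-octave fix (fix_bass_octave=True) during preprocessing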
delete_source_files=False, fix_bass_octave=True) + add_program_and_is_drum_info_to_file_list(data_home) + + +def install_musicnet(data_home=os.PathLike, no_down=False) -> None: + if not no_down: + url = "https://zenodo.org/record/7811639/files/musicnet_yourmt3_16k.tar.gz?download=1" + checksum = "a2da7c169e26d452a4e8b9bef498b3d7" + download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum) + preprocess_musicnet16k(data_home, dataset_name='musicnet') + + +def install_maps(data_home=os.PathLike, no_down=False, sanity_check=False) -> None: + if not no_down: + url = "https://zenodo.org/record/7812075/files/maps_yourmt3_16k.tar.gz?download=1" + checksum = "6b070d162c931cd5e69c16ef2398a649" + download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum) + preprocess_maps16k(data_home, dataset_name='maps', ignore_pedal=False, sanity_check=sanity_check) + + +def install_maestro(data_home=os.PathLike, no_down=False, sanity_check=False) -> None: + if not no_down: + url = "https://zenodo.org/record/7852176/files/maestro_yourmt3_16k.tar.gz?download=1" + checksum = "c17c6a188d936e5ff3870ef27144d397" + download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum) + preprocess_maestro16k(data_home, dataset_name='maestro', ignore_pedal=False, sanity_check=sanity_check) + + +def install_guitarset(data_home=os.PathLike, no_down=False) -> None: + if not no_down: + url = "https://zenodo.org/record/7831843/files/guitarset_yourmt3_16k.tar.gz?download=1" + checksum = "e3cfe0cc9394d91d9c290ce888821360" + download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum) + preprocess_guitarset16k(data_home, dataset_name='guitarset') + create_filelist_by_style_guitarset16k(data_home, dataset_name='guitarset') + + +def install_enstdrums(data_home, no_down=False) -> None: + if not no_down: + url = "https://zenodo.org/record/7831843/files/enstdrums_yourmt3_16k.tar.gz?download=1" + checksum = "7e28c2a923e4f4162b3d83877cedb5eb" + download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum) + preprocess_enstdrums16k(data_home, dataset_name='enstdrums') + create_filelist_dtm_random_enstdrums16k(data_home, dataset_name='enstdrums') + + +def install_egmd(data_home, no_down=False) -> None: + if not no_down: + url = "https://zenodo.org/record/7831072/files/egmc_yourmt3_16k.tar.gz?download=1" + checksum = "4f615157ea4c52a64c6c9dcf68bf2bde" + download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum) + preprocess_egmd16k(data_home, dataset_name='egmd') + + +def install_mirst500(data_home, zenodo_token, no_down=False, sanity_check=True, apply_correction=False) -> None: + """ Update Oct 2023: MIR-ST500 with FULL audio files""" + if not no_down: + url = "https://zenodo.org/records/10016397/files/mir_st500_yourmt3_16k.tar.gz?download=1" + checksum = "98eb52eb2456ce4034e21750f309da13" + download_and_extract(data_home, url, check_sum=checksum, zenodo_token=zenodo_token) + preprocess_mir_st500_16k(data_home, dataset_name='mir_st500', sanity_check=sanity_check) + + +def install_cmedia(data_home, zenodo_token, no_down=False, sanity_check=True) -> None: + if not no_down: + url = "https://zenodo.org/records/10016397/files/cmedia_yourmt3_16k.tar.gz?download=1" + checksum = "e6cca23577ba7588e9ed9711a398f7cf" + download_and_extract(data_home, url, check_sum=checksum, zenodo_token=zenodo_token) + preprocess_cmedia_16k(data_home, dataset_name='cmedia', sanity_check=sanity_check, apply_correction=True) + + +def install_rwc_pop(data_home, 
zenodo_token, no_down=False) -> None: + if not no_down: + url = "https://zenodo.org/records/10016397/files/rwc_pop_yourmt3_16k.tar.gz?download=1" + checksum = "ad459f9fa1b6b87676b2fb37c0ba5dfc" + download_and_extract(data_home, url, check_sum=checksum, zenodo_token=zenodo_token) + preprocess_rwc_pop16k(data_home, dataset_name='rwc_pop') # bass transcriptions + preprocess_rwc_pop_full16k(data_home, dataset_name='rwc_pop') # full transcriptions + + +def install_mir1k(data_home, no_down=False) -> None: + if not no_down: + url = "https://zenodo.org/record/7955481/files/mir1k_yourmt3_16k.tar.gz?download=1" + checksum = "4cbac56a4e971432ca807efd5cb76d67" + download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum) + # preprocess_mir1k_16k(data_home, dataset_name='mir1k') + + +def install_urmp(data_home, no_down=False) -> None: + if not no_down: + url = "https://zenodo.org/record/8021437/files/urmp_yourmt3_16k.tar.gz?download=1" + checksum = "4f539c71678a77ba34f6dfca41072102" + download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum) + preprocess_urmp16k(data_home, dataset_name='urmp') + + +def install_idmt_smt_bass(data_home, no_down=False) -> None: + if not no_down: + url = "https://zenodo.org/records/10009959/files/idmt_smt_bass_yourmt3_16k.tar.gz?download=1" + checksum = "0c95f91926a1e95b1f5d075c05b7eb76" + download_and_extract(data_home, url, remove_tar_file=True, check_sum=checksum) + preprocess_idmt_smt_bass_16k(data_home, dataset_name='idmt_smt_bass', sanity_check=True, + edit_audio=False) # the donwloaded audio has already been edited + + +def install_random_nsynth(data_home, no_down=False) -> None: + return + + +def install_geerdes(data_home) -> None: + try: + preprocess_geerdes16k(data_home, dataset_name='geerdes', sanity_check=False) + except Exception as e: + print(e) + print("Geerdes dataset is not available for download. Please contact the dataset provider.") + + +def regenerate_dataset_stats(data_home) -> None: + generate_dataset_stats_for_all_datasets(data_home) + + +def get_cached_zenodo_token() -> str: + # check if cached token exists + if not os.path.exists('.cached_zenodo_token'): + raise Exception("Cached Zenodo token not found. Please enter your Zenodo token.") + # read cached token + with open('.cached_zenodo_token', 'r') as f: + zenodo_token = f.read().strip() + print(f"Using cached Zenodo token: {zenodo_token}") + return zenodo_token + + +def cache_zenodo_token(zenodo_token: str) -> None: + with open('.cached_zenodo_token', 'w') as f: + f.write(zenodo_token) + print("Your Zenodo token is cached.") + + +def option_prompt(data_home: os.PathLike, no_download: bool = False) -> None: + print("Select the dataset(s) to install (enter comma-separated numbers):") + print("1. Slakh") + print("2. MusicNet") + print("3. MAPS") + print("4. Maestro") + print("5. GuitarSet") + print("6. ENST-drums") + print("7. EGMD") + print("8. MIR-ST500 ** Restricted Access **") + print("9. CMedia ** Restricted Access **") + print("10. RWC-Pop (Bass and Full) ** Restricted Access **") + print("11. MIR-1K (NOT SUPPORTED)") + print("12. URMP") + print("13. IDMT-SMT-Bass") + print("14. Random-NSynth") + print("15. Geerdes") + print("16. Regenerate Dataset Stats (experimental)") + print("17. Request Token for ** Restricted Access **") + print("18. 
Exit") + + choice = input("Enter your choices (multiple choices with comma): ") + choices = [c.strip() for c in choice.split(',')] + + if "18" in choices: + print("Exiting.") + else: + # ask for Zenodo token + for c in choices: + if int(c) in [8, 9, 10]: + if no_download is True: + zenodo_token = None + else: + zenodo_token = input("Enter Zenodo token, or press enter to use the cached token:") + if zenodo_token == "": + zenodo_token = get_cached_zenodo_token() + else: + cache_zenodo_token(zenodo_token) + break + + if "1" in choices: + install_slakh(data_home, no_down=no_download) + if "2" in choices: + install_musicnet(data_home, no_down=no_download) + if "3" in choices: + install_maps(data_home, no_down=no_download) + if "4" in choices: + install_maestro(data_home, no_down=no_download) + if "5" in choices: + install_guitarset(data_home, no_down=no_download) + if "6" in choices: + install_enstdrums(data_home, no_down=no_download) + if "7" in choices: + install_egmd(data_home, no_down=no_download) + if "8" in choices: + install_mirst500(data_home, zenodo_token, no_down=no_download) + if "9" in choices: + install_cmedia(data_home, zenodo_token, no_down=no_download) + if "10" in choices: + install_rwc_pop(data_home, zenodo_token, no_down=no_download) + if "11" in choices: + install_mir1k(data_home, no_down=no_download) + if "12" in choices: + install_urmp(data_home, no_down=no_download) + if "13" in choices: + install_idmt_smt_bass(data_home, no_down=no_download) + if "14" in choices: + install_random_nsynth(data_home, no_down=no_download) + if "15" in choices: + install_geerdes(data_home) # not available for download + if "16" in choices: + regenerate_dataset_stats(data_home, no_down=no_download) + if "17" in choices: + print("\nPlease visit https://zenodo.org/records/10016397 to request a Zenodo token.") + print("Upon submitting your request, you will receive an email with a link labeled 'Access the record'.") + print("Copy the token that follows 'token=' in that link.") + if not any(int(c) in range(16) for c in choices): + print("Invalid choice(s). Please enter valid numbers separated by commas.") + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description='Dataset installer script.') + # data home dir + parser.add_argument( + 'data_home', + type=str, + nargs='?', + default=None, + help='Path to data home directory. If None, use the default path defined in src/config/config.py') + # `no_download` option + parser.add_argument('--nodown', + '-nd', + action='store_true', + help='Flag to control downloading. 
If set, no downloading will occur.') + args = parser.parse_args() + + if args.data_home is None: + from config.config import shared_cfg + data_home = shared_cfg["PATH"]["data_home"] + else: + data_home = args.data_home + os.makedirs(data_home, exist_ok=True) + no_download = args.nodown + + option_prompt(data_home, no_download) diff --git a/amt/src/model/RoPE/RoPE.py b/amt/src/model/RoPE/RoPE.py new file mode 100644 index 0000000000000000000000000000000000000000..0884d027c25393e76ac9ff292c162abc9e4fcbbc --- /dev/null +++ b/amt/src/model/RoPE/RoPE.py @@ -0,0 +1,306 @@ +"""rotary_embedding.py - Rotary Embedding based on https://github.com/lucidrains/rotary-embedding-torch""" +from typing import Literal, Union, Optional +from math import pi, log +from einops import rearrange, repeat + +import torch +from torch.nn import Module, ModuleList +from torch.cuda.amp import autocast +from torch import nn, einsum, broadcast_tensors, Tensor + + +# helper functions +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +# broadcat, as tortoise-tts was using it +def broadcat(tensors, dim=-1): + broadcasted_tensors = broadcast_tensors(*tensors) + return torch.cat(broadcasted_tensors, dim=dim) + + +# rotary embedding helper functions +def rotate_half(x): + x = rearrange(x, '... (d r) -> ... d r', r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, '... d r -> ... (d r)') + + +@autocast(enabled=False) +def apply_rotary_emb(freqs, t, start_index=0, scale=1., seq_dim=-2): + """Applies rotary embedding for pixels.""" + if t.ndim == 3: + seq_len = t.shape[seq_dim] + freqs = freqs[-seq_len:].to(t) + + rot_dim = freqs.shape[-1] + end_index = start_index + rot_dim + + assert rot_dim <= t.shape[ + -1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}' + + t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] + t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + return torch.cat((t_left, t, t_right), dim=-1) + + +# learned rotation helpers +def apply_learned_rotations(rotations, t, start_index=0, freq_ranges=None): + if exists(freq_ranges): + rotations = einsum('..., f -> ... f', rotations, freq_ranges) + rotations = rearrange(rotations, '... r f -> ... (r f)') + + rotations = repeat(rotations, '... n -> ... (n r)', r=2) + return apply_rotary_emb(rotations, t, start_index=start_index) + + +# classes +class RotaryEmbedding(Module): + + def __init__(self, + dim, + custom_freqs: Optional[Tensor] = None, + freqs_for: Union[Literal['lang'], Literal['pixel'], Literal['constant']] = 'lang', + theta=10000, + max_freq=10, + num_freqs=1, + learned_freq=False, + use_xpos=False, + xpos_scale_base=512, + interpolate_factor=1., + theta_rescale_factor=1., + seq_before_head_dim=False, + cache_if_possible=True): + super().__init__() + # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning + # has some connection to NTK literature + # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ + + theta *= theta_rescale_factor**(dim / (dim - 2)) + + self.freqs_for = freqs_for + + if exists(custom_freqs): + freqs = custom_freqs + elif freqs_for == 'lang': + freqs = 1. 
/ (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim)) + elif freqs_for == 'pixel': + freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi + elif freqs_for == 'constant': + freqs = torch.ones(num_freqs).float() + + self.cache_if_possible = cache_if_possible + + self.tmp_store('cached_freqs', None) + self.tmp_store('cached_scales', None) + + self.freqs = nn.Parameter(freqs, requires_grad=learned_freq) + + self.learned_freq = learned_freq + + # dummy for device + + self.tmp_store('dummy', torch.tensor(0)) + + # default sequence dimension + + self.seq_before_head_dim = seq_before_head_dim + self.default_seq_dim = -3 if seq_before_head_dim else -2 + + # interpolation factors + + assert interpolate_factor >= 1. + self.interpolate_factor = interpolate_factor + + # xpos + + self.use_xpos = use_xpos + if not use_xpos: + self.tmp_store('scale', None) + return + + scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) + self.scale_base = xpos_scale_base + self.tmp_store('scale', scale) + + @property + def device(self): + return self.dummy.device + + def tmp_store(self, key, value): + self.register_buffer(key, value, persistent=False) + + def get_seq_pos(self, seq_len, device, dtype, offset=0): + return (torch.arange(seq_len, device=device, dtype=dtype) + offset) / self.interpolate_factor + + def rotate_queries_or_keys(self, t, seq_dim=None, offset=0, freq_seq_len=None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert not self.use_xpos, 'you must use `.rotate_queries_and_keys` method instead and pass in both queries and keys, for length extrapolatable rotary embeddings' + + device, dtype, seq_len = t.device, t.dtype, t.shape[seq_dim] + + if exists(freq_seq_len): + assert freq_seq_len >= seq_len + seq_len = freq_seq_len + + freqs = self.forward(self.get_seq_pos(seq_len, device=device, dtype=dtype, offset=offset), + seq_len=seq_len, + offset=offset) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + + return apply_rotary_emb(freqs, t, seq_dim=seq_dim) + + def rotate_queries_with_cached_keys(self, q, k, seq_dim=None, offset=0): + seq_dim = default(seq_dim, self.default_seq_dim) + + q_len, k_len = q.shape[seq_dim], k.shape[seq_dim] + assert q_len <= k_len + rotated_q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, freq_seq_len=k_len) + rotated_k = self.rotate_queries_or_keys(k, seq_dim=seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def rotate_queries_and_keys(self, q, k, seq_dim=None): + seq_dim = default(seq_dim, self.default_seq_dim) + + assert self.use_xpos + device, dtype, seq_len = q.device, q.dtype, q.shape[seq_dim] + + seq = self.get_seq_pos(seq_len, dtype=dtype, device=device) + + freqs = self.forward(seq, seq_len=seq_len) + scale = self.get_scale(seq, seq_len=seq_len).to(dtype) + + if seq_dim == -3: + freqs = rearrange(freqs, 'n d -> n 1 d') + scale = rearrange(scale, 'n d -> n 1 d') + + rotated_q = apply_rotary_emb(freqs, q, scale=scale, seq_dim=seq_dim) + rotated_k = apply_rotary_emb(freqs, k, scale=scale**-1, seq_dim=seq_dim) + + rotated_q = rotated_q.type(q.dtype) + rotated_k = rotated_k.type(k.dtype) + + return rotated_q, rotated_k + + def get_scale(self, t: Tensor, seq_len: Optional[int] = None, offset=0): + assert self.use_xpos + + should_cache = (self.cache_if_possible and exists(seq_len)) + + if ( + should_cache and \ + exists(self.cached_scales) and \ + (seq_len + offset) <= self.cached_scales.shape[0] + ): + return self.cached_scales[offset:(offset + seq_len)] + 
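+ # cache miss: fall through and recompute the xpos decay scale from the
+ # (centered) sequence positions; `scale = 1.` below is only a neutral initial value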
+ scale = 1. + if self.use_xpos: + power = (t - len(t) // 2) / self.scale_base + scale = self.scale**rearrange(power, 'n -> n 1') + scale = torch.cat((scale, scale), dim=-1) + + if should_cache: + self.tmp_store('cached_scales', scale) + + return scale + + def get_axial_freqs(self, *dims): + Colon = slice(None) + all_freqs = [] + + for ind, dim in enumerate(dims): + if self.freqs_for == 'pixel': + pos = torch.linspace(-1, 1, steps=dim, device=self.device) + else: + pos = torch.arange(dim, device=self.device) + + freqs = self.forward(pos, seq_len=dim) + + all_axis = [None] * len(dims) + all_axis[ind] = Colon + + new_axis_slice = (Ellipsis, *all_axis, Colon) + all_freqs.append(freqs[new_axis_slice]) + + all_freqs = broadcast_tensors(*all_freqs) + return torch.cat(all_freqs, dim=-1) + + @autocast(enabled=False) + def forward(self, t: Tensor, seq_len=None, offset=0): + should_cache = ( + self.cache_if_possible and \ + not self.learned_freq and \ + exists(seq_len) and \ + self.freqs_for != 'pixel' + ) + + if ( + should_cache and \ + exists(self.cached_freqs) and \ + (offset + seq_len) <= self.cached_freqs.shape[0] + ): + return self.cached_freqs[offset:(offset + seq_len)].detach() + + freqs = self.freqs + + freqs = einsum('..., f -> ... f', t.type(freqs.dtype), freqs) + freqs = repeat(freqs, '... n -> ... (n r)', r=2) + + if should_cache: + self.tmp_store('cached_freqs', freqs.detach()) + + return freqs + + # custom method for applying rotary embeddings + @torch.compiler.disable + def apply_rotary_custom(self, t: torch.Tensor): + """Apply rotary embeddings to queries and keys, if k is None, only q is rotated. + Depending on the freqs type, the rotation will be different.""" + if self.freqs_for == 'lang': + return self.rotate_queries_or_keys(t, seq_dim=-2) + elif self.freqs_for == 'pixel': + return apply_rotary_emb(self.get_axial_freqs(t.shape[-2]), t) + else: + raise ValueError(f"freqs_for must be 'lang' or 'pixel', but got {self.freqs_for}") + + +def test_rotary_embedding_lang(): + d = 32 # d by head + q = torch.ones(1, 4, 110, 32) # (B, H, T, D) for multi-head attention + rdim = d // 2 # will do a partial rotation on half, or d + + rotary = RotaryEmbedding(dim=rdim, freqs_for="lang") + q = rotary.rotate_queries_or_keys(q, seq_dim=-2) + + # visualize + import matplotlib.pyplot as plt + plt.imshow(q[0, 0, :, :].numpy().T, origin='lower') + + +def test_rotary_embedding_pixel(): + d = 32 # d by head + q = torch.ones(1, 4, 128, 32) # (B*T, H, F, C/H) for multi-head attention + rdim = d // 2 # will do a partial rotation on half + + rotary = RotaryEmbedding(dim=rdim, freqs_for="pixel", max_freq=10) + freqs = rotary.get_axial_freqs(128) + + q = apply_rotary_emb(freqs, q) # also k, if needed + + # visualize + import matplotlib.pyplot as plt + plt.imshow(q[0, 0, :, :].numpy().T, origin='lower') diff --git a/amt/src/model/conformer_helper.py b/amt/src/model/conformer_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..736b82c663ae793b6e61ce99da98ac669cd8d980 --- /dev/null +++ b/amt/src/model/conformer_helper.py @@ -0,0 +1,169 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. 
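+# conformer_helper.py: ConformerYMT3Config (encoder hyperparameters) and
+# ConformerYMT3PreTrainedModel (weight initialization and gradient-checkpointing hooks).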
+import math +from typing import Optional, Union + +from torch import nn +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel + + +class ConformerYMT3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`ConformerYMT3Encoder`]. It is used to + instantiate an ConformerYMT3Encoder according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Wav2Vec2Conformer + [facebook/wav2vec2-conformer-rel-pos-large](https://huggingface.co/facebook/wav2vec2-conformer-rel-pos-large) + architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + + Args: + d_model (`int`, *optional*, defaults to 512): + Dimensionality of the encoder layers and the pooler layer. + num_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + intermediate_size (`int`, *optional*, defaults to 2048): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + dropout_rate (`float`, *optional*, defaults to 0.05): + The dropout probability for all fully connected layers in the embeddings, encoder, and pooler. + layerdrop (`float`, *optional*, defaults to 0.1): + The LayerDrop probability. See the [LayerDrop paper](see https://arxiv.org/abs/1909.11556) for more + details. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_eps (`float`, *optional*, defaults to 1e-12): + The epsilon used by the layer normalization layers. + conv_dim (`Tuple[int]` or `List[int]`, *optional*, defaults to `(512, 512, 512, 512, 512, 512, 512)`): + A tuple of integers defining the number of input and output channels of each 1D convolutional layer in the + feature encoder. The length of *conv_dim* defines the number of 1D convolutional layers. + conv_stride (`Tuple[int]` or `List[int]`, *optional*, defaults to `(5, 2, 2, 2, 2, 2, 2)`): + A tuple of integers defining the stride of each 1D convolutional layer in the feature encoder. The length + of *conv_stride* defines the number of convolutional layers and has to match the length of *conv_dim*. + conv_kernel (`Tuple[int]` or `List[int]`, *optional*, defaults to `(10, 3, 3, 3, 3, 3, 3)`): + A tuple of integers defining the kernel size of each 1D convolutional layer in the feature encoder. The + length of *conv_kernel* defines the number of convolutional layers and has to match the length of + *conv_dim*. + conv_bias (`bool`, *optional*, defaults to `False`): + Whether the 1D convolutional layers have a bias. + output_hidden_size (`int`, *optional*): + Dimensionality of the encoder output layer. If not defined, this defaults to *hidden-size*. Only relevant + if `add_adapter is True`. 
+ position_encoding_type (`str`, *optional*, defaults to `"relative"`): + Can be specified to `relative` or `rotary` for relative or rotary position embeddings respectively. If left + `None` no relative position embedding is applied. + rotary_embedding_base (`int`, *optional*, defaults to 10000): + If `"rotary"` position embeddings are used, defines the size of the embedding base. + num_max_positions (`int`, *optional*, defaults to 5000): + if `"relative"` position embeddings are used, defines the maximum source input positions. + conv_depthwise_kernel_size (`int`, defaults to 31): + Kernel size of convolutional depthwise 1D layer in Conformer blocks. + + Example: + + ```python + >>> from transformers import ConformerYMT3Config, ConformerYMT3Encoder + + >>> # Initializing a ConformerYMT3Encoder configuration + >>> configuration = ConformerYMT3Config() + + >>> # Initializing a model (with random weights) from the facebook/wav2vec2-conformer-rel-pos-large style configuration + >>> model = ConformerYMT3Encoder(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "conformer-ymt3" + + def __init__( + self, + d_model=512, # 768 + num_layers=8, # ConformerYMT3Encoder + num_heads=8, # ConformerYMT3SelfAttention + intermediate_size=2048, # 3072,# used in intermediate_dense of ConformerYMT3FeedForward + hidden_act="gelu", # used in intermediate_act_fn of ConformerYMT3FeedForward + dropout_rate=0.1, + layerdrop=0.1, + initializer_range=0.02, + layer_norm_eps=1e-5, + conv_dim=(512, 512, 512, 512, 512, 512, 512), + conv_stride=(5, 2, 2, 2, 2, 2, 2), + conv_kernel=(10, 3, 3, 3, 3, 3, 3), + conv_bias=False, + position_encoding_type="rotary", + rotary_embedding_base=10000, + num_max_positions=1024, + conv_depthwise_kernel_size=31, + **kwargs, + ): + super().__init__(**kwargs) + self.d_model = d_model + self.conv_dim = list(conv_dim) + self.conv_stride = list(conv_stride) + self.conv_kernel = list(conv_kernel) + self.conv_bias = conv_bias + self.num_layers = num_layers + self.intermediate_size = intermediate_size + self.hidden_act = hidden_act + self.num_heads = num_heads + self.dropout_rate = dropout_rate + + self.layerdrop = layerdrop + self.layer_norm_eps = layer_norm_eps + self.initializer_range = initializer_range + self.num_max_positions = num_max_positions + self.position_encoding_type = position_encoding_type + self.rotary_embedding_base = rotary_embedding_base + + # Conformer-block related + self.conv_depthwise_kernel_size = conv_depthwise_kernel_size + + +class ConformerYMT3PreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = ConformerYMT3Config + base_model_prefix = "wav2vec2_conformer" + main_input_name = "input_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights""" + if module.__class__.__name__ == "ConformerYMT3SelfAttention": + if hasattr(module, "pos_bias_u"): + nn.init.xavier_uniform_(module.pos_bias_u) + if hasattr(module, "pos_bias_v"): + nn.init.xavier_uniform_(module.pos_bias_v) + elif isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + nn.init.kaiming_normal_(module.weight) + if module.bias is not None: + k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) + nn.init.uniform_(module.bias, a=-k, b=k) + + def _set_gradient_checkpointing(self, module, value=False): + if module.__class__.__name__ == "ConformerYMT3Encoder": + module.gradient_checkpointing = value diff --git a/amt/src/model/conformer_mod.py b/amt/src/model/conformer_mod.py new file mode 100644 index 0000000000000000000000000000000000000000..7c76131317546e7e19b9eb1f3c76087a546c0e3d --- /dev/null +++ b/amt/src/model/conformer_mod.py @@ -0,0 +1,439 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +from typing import Tuple, Literal, Any, Optional +import math + +import torch +from torch import nn +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutput + +from model.conformer_helper import ConformerYMT3Config, ConformerYMT3PreTrainedModel +from model.positional_encoding import (Wav2Vec2ConformerRelPositionalEmbedding, + Wav2Vec2ConformerRotaryPositionalEmbedding) + + +class ConformerYMT3FeedForward(nn.Module): + + def __init__(self, config): + super().__init__() + self.intermediate_dropout = nn.Dropout(config.dropout_rate) + + self.intermediate_dense = nn.Linear(config.d_model, config.intermediate_size) + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + self.output_dense = nn.Linear(config.intermediate_size, config.d_model) + self.output_dropout = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.intermediate_dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.intermediate_dropout(hidden_states) + + hidden_states = self.output_dense(hidden_states) + hidden_states = self.output_dropout(hidden_states) + return hidden_states + + +class ConformerYMT3ConvolutionModule(nn.Module): + """Convolution block used in the conformer block""" + + def __init__(self, config): + super().__init__() + if (config.conv_depthwise_kernel_size - 1) % 2 == 1: + raise ValueError("`config.conv_depthwise_kernel_size` should be a odd number for 'SAME' padding") + self.layer_norm = nn.LayerNorm(config.d_model) + self.pointwise_conv1 = torch.nn.Conv1d( + config.d_model, + 2 * config.d_model, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.glu = torch.nn.GLU(dim=1) + 
self.depthwise_conv = torch.nn.Conv1d( + config.d_model, + config.d_model, + config.conv_depthwise_kernel_size, + stride=1, + padding=(config.conv_depthwise_kernel_size - 1) // 2, + groups=config.d_model, + bias=False, + ) + self.batch_norm = torch.nn.BatchNorm1d(config.d_model) + self.activation = ACT2FN[config.hidden_act] + self.pointwise_conv2 = torch.nn.Conv1d( + config.d_model, + config.d_model, + kernel_size=1, + stride=1, + padding=0, + bias=False, + ) + self.dropout = torch.nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states): + hidden_states = self.layer_norm(hidden_states) + # exchange the temporal dimension and the feature dimension + hidden_states = hidden_states.transpose(1, 2) + + # GLU mechanism + # => (batch, 2*channel, dim) + hidden_states = self.pointwise_conv1(hidden_states) + # => (batch, channel, dim) + hidden_states = self.glu(hidden_states) + + # 1D Depthwise Conv + hidden_states = self.depthwise_conv(hidden_states) + hidden_states = self.batch_norm(hidden_states) + hidden_states = self.activation(hidden_states) + + hidden_states = self.pointwise_conv2(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class ConformerYMT3SelfAttention(nn.Module): + """Construct a ConformerSelfAttention object. + Can be enhanced with rotary or relative position embeddings. + """ + + def __init__(self, config): + super().__init__() + + self.head_size = config.d_model // config.num_heads + self.num_heads = config.num_heads + self.position_encoding_type = config.position_encoding_type + + self.linear_q = nn.Linear(config.d_model, config.d_model) + self.linear_k = nn.Linear(config.d_model, config.d_model) + self.linear_v = nn.Linear(config.d_model, config.d_model) + self.linear_out = nn.Linear(config.d_model, config.d_model) + + self.dropout = nn.Dropout(p=config.dropout_rate) + + if self.position_encoding_type == "relative": + # linear transformation for positional encoding + self.linear_pos = nn.Linear(config.d_model, config.d_model, bias=False) + # these two learnable bias are used in matrix c and matrix d + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + self.pos_bias_u = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + self.pos_bias_v = nn.Parameter(torch.zeros(self.num_heads, self.head_size)) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # self-attention mechanism + batch_size, sequence_length, d_model = hidden_states.size() + + # make sure query/key states can be != value states + query_key_states = hidden_states + value_states = hidden_states + + if self.position_encoding_type == "rotary": + if relative_position_embeddings is None: + raise ValueError( + "`relative_position_embeddings` has to be defined when `self.position_encoding_type == 'rotary'") + query_key_states = self._apply_rotary_embedding(query_key_states, relative_position_embeddings) + + # project query_key_states and value_states + query = self.linear_q(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + key = self.linear_k(query_key_states).view(batch_size, -1, self.num_heads, self.head_size) + value = self.linear_v(value_states).view(batch_size, -1, self.num_heads, self.head_size) + + # => (batch, head, time1, d_k) + query = 
query.transpose(1, 2) + key = key.transpose(1, 2) + value = value.transpose(1, 2) + + if self.position_encoding_type == "relative": + if relative_position_embeddings is None: + raise ValueError("`relative_position_embeddings` has to be defined when `self.position_encoding_type ==" + " 'relative'") + # apply relative_position_embeddings to qk scores + # as proposed in Transformer_XL: https://arxiv.org/abs/1901.02860 + scores = self._apply_relative_embeddings(query=query, + key=key, + relative_position_embeddings=relative_position_embeddings) + else: + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size) + + # apply attention_mask if necessary + if attention_mask is not None: + scores = scores + attention_mask + + # => (batch, head, time1, time2) + probs = torch.softmax(scores, dim=-1) + probs = self.dropout(probs) + + # => (batch, head, time1, d_k) + hidden_states = torch.matmul(probs, value) + + # => (batch, time1, d_model) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, self.num_heads * self.head_size) + hidden_states = self.linear_out(hidden_states) + + return hidden_states, probs + + def _apply_rotary_embedding(self, hidden_states, relative_position_embeddings): + batch_size, sequence_length, d_model = hidden_states.size() + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads, self.head_size) + + cos = relative_position_embeddings[0, :sequence_length, ...] + sin = relative_position_embeddings[1, :sequence_length, ...] + + # rotate hidden_states with rotary embeddings + hidden_states = hidden_states.transpose(0, 1) + rotated_states_begin = hidden_states[..., :self.head_size // 2] + rotated_states_end = hidden_states[..., self.head_size // 2:] + rotated_states = torch.cat((-rotated_states_end, rotated_states_begin), dim=rotated_states_begin.ndim - 1) + hidden_states = (hidden_states * cos) + (rotated_states * sin) + hidden_states = hidden_states.transpose(0, 1) + + hidden_states = hidden_states.view(batch_size, sequence_length, self.num_heads * self.head_size) + + return hidden_states + + def _apply_relative_embeddings(self, query, key, relative_position_embeddings): + # 1. project positional embeddings + # => (batch, head, 2*time1-1, d_k) + proj_relative_position_embeddings = self.linear_pos(relative_position_embeddings) + proj_relative_position_embeddings = proj_relative_position_embeddings.view(relative_position_embeddings.size(0), + -1, self.num_heads, self.head_size) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(1, 2) + proj_relative_position_embeddings = proj_relative_position_embeddings.transpose(2, 3) + + # 2. Add bias to query + # => (batch, head, time1, d_k) + query = query.transpose(1, 2) + q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) + q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) + + # 3. attention score: first compute matrix a and matrix c + # as described in https://arxiv.org/abs/1901.02860 Section 3.3 + # => (batch, head, time1, time2) + scores_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) + + # 4. then compute matrix b and matrix d + # => (batch, head, time1, 2*time1-1) + scores_bd = torch.matmul(q_with_bias_v, proj_relative_position_embeddings) + + # 5. 
shift matrix b and matrix d + zero_pad = torch.zeros((*scores_bd.size()[:3], 1), device=scores_bd.device, dtype=scores_bd.dtype) + scores_bd_padded = torch.cat([zero_pad, scores_bd], dim=-1) + scores_bd_padded_shape = scores_bd.size()[:2] + (scores_bd.shape[3] + 1, scores_bd.shape[2]) + scores_bd_padded = scores_bd_padded.view(*scores_bd_padded_shape) + scores_bd = scores_bd_padded[:, :, 1:].view_as(scores_bd) + scores_bd = scores_bd[:, :, :, :scores_bd.size(-1) // 2 + 1] + + # 6. sum matrices + # => (batch, head, time1, time2) + scores = (scores_ac + scores_bd) / math.sqrt(self.head_size) + + return scores + + +class ConformerYMT3EncoderLayer(nn.Module): + """Conformer block based on https://arxiv.org/abs/2005.08100.""" + + def __init__(self, config): + super().__init__() + embed_dim = config.d_model + dropout = config.dropout_rate + + # Feed-forward 1 + self.ffn1_layer_norm = nn.LayerNorm(embed_dim) + self.ffn1 = ConformerYMT3FeedForward(config) + + # Self-Attention + self.self_attn_layer_norm = nn.LayerNorm(embed_dim) + self.self_attn_dropout = torch.nn.Dropout(dropout) + self.self_attn = ConformerYMT3SelfAttention(config) + + # Conformer Convolution + self.conv_module = ConformerYMT3ConvolutionModule(config) + + # Feed-forward 2 + self.ffn2_layer_norm = nn.LayerNorm(embed_dim) + self.ffn2 = ConformerYMT3FeedForward(config) + self.final_layer_norm = nn.LayerNorm(embed_dim) + + def forward( + self, + hidden_states, + attention_mask: Optional[torch.Tensor] = None, + relative_position_embeddings: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ): + hidden_states = hidden_states + + # 1. Feed-Forward 1 layer + residual = hidden_states + hidden_states = self.ffn1_layer_norm(hidden_states) + hidden_states = self.ffn1(hidden_states) + hidden_states = hidden_states * 0.5 + residual + residual = hidden_states + + # 2. Self-Attention layer + hidden_states = self.self_attn_layer_norm(hidden_states) + hidden_states, attn_weigts = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = self.self_attn_dropout(hidden_states) + hidden_states = hidden_states + residual + + # 3. Convolutional Layer + residual = hidden_states + hidden_states = self.conv_module(hidden_states) + hidden_states = residual + hidden_states + + # 4. 
Feed-Forward 2 Layer + residual = hidden_states + hidden_states = self.ffn2_layer_norm(hidden_states) + hidden_states = self.ffn2(hidden_states) + hidden_states = hidden_states * 0.5 + residual + hidden_states = self.final_layer_norm(hidden_states) + + return hidden_states, attn_weigts + + +class ConformerYMT3Encoder(nn.Module): + + def __init__(self, config): + super().__init__() + self.config = config + + if config.position_encoding_type == "relative": + self.embed_positions = Wav2Vec2ConformerRelPositionalEmbedding(config) + elif config.position_encoding_type == "rotary": + self.embed_positions = Wav2Vec2ConformerRotaryPositionalEmbedding(config) + else: + self.embed_positions = None + + # self.pos_conv_embed = Wav2Vec2ConformerPositionalConvEmbedding(config) + self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.dropout_rate) + self.layers = nn.ModuleList([ConformerYMT3EncoderLayer(config) for _ in range(config.num_layers)]) + self.gradient_checkpointing = False + + def forward( + self, + inputs_embeds: torch.FloatTensor, # (B, T, D) + attention_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + return_dict: Optional[bool] = True, + ): + if output_attentions is None: + output_attentions = self.config.output_attentions + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + if return_dict is None: + return_dict = self.config.use_return_dict + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + # inputs_embeds as hidden_states + hidden_states = inputs_embeds + + if attention_mask is not None: + # make sure padded tokens output 0 + hidden_states[~attention_mask] = 0.0 + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand(attention_mask.shape[0], 1, attention_mask.shape[-1], + attention_mask.shape[-1]) + + hidden_states = self.dropout(hidden_states) + + if self.embed_positions is not None: + relative_position_embeddings = self.embed_positions(hidden_states) + else: + relative_position_embeddings = None + + for i, layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer: + # under deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + # create gradient checkpointing function + def create_custom_forward(module): + + def custom_forward(*inputs): + return module(*inputs, output_attentions) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(layer), + hidden_states, + attention_mask, + relative_position_embeddings, + ) + else: + layer_outputs = layer( + hidden_states, + attention_mask=attention_mask, + relative_position_embeddings=relative_position_embeddings, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + 
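+ # after the last Conformer block: apply the final encoder LayerNorm, then append
+ # the resulting hidden state to `all_hidden_states` if output_hidden_states is set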
hidden_states = self.layer_norm(hidden_states) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +def test(): + import torch + from model.conformer_mod import ConformerYMT3Encoder + from model.conformer_helper import ConformerYMT3Config + from model.ops import count_parameters + config = ConformerYMT3Config() + encoder = ConformerYMT3Encoder(config) + encoder.eval() + # num params: 48,468,992 w/ intermediate_size=2048 + # num params: 23,278,592 w/ intermediate_size=512 + x = torch.randn(2, 256, 512) # (B, T, D) + enc_hs = encoder.forward(inputs_embeds=x)['last_hidden_state'] # (B, T, D) diff --git a/amt/src/model/conv_block.py b/amt/src/model/conv_block.py new file mode 100644 index 0000000000000000000000000000000000000000..383e93e46951a859139b0276d12fc69ff61b6a58 --- /dev/null +++ b/amt/src/model/conv_block.py @@ -0,0 +1,217 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +from typing import Literal +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +def init_layer(layer: nn.Module) -> None: + """Initialize a Linear or Convolutional layer.""" + nn.init.xavier_uniform_(layer.weight) + if hasattr(layer, "bias") and layer.bias is not None: + layer.bias.data.zero_() + + +def init_bn(bn: nn.Module) -> None: + """Initialize a Batchnorm layer.""" + bn.bias.data.zero_() + bn.weight.data.fill_(1.0) + bn.running_mean.data.zero_() + bn.running_var.data.fill_(1.0) + + +def act(x: torch.Tensor, activation: str) -> torch.Tensor: + """Activation function.""" + funcs = {"relu": F.relu_, "leaky_relu": lambda x: F.leaky_relu_(x, 0.01), "swish": lambda x: x * torch.sigmoid(x)} + return funcs.get(activation, lambda x: Exception("Incorrect activation!"))(x) + + +class Res2DAVPBlock(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, avp_kernel_size, activation): + """Convolutional residual block modified fromr bytedance/music_source_separation.""" + super().__init__() + + padding = kernel_size[0] // 2, kernel_size[1] // 2 + + self.activation = activation + self.bn1, self.bn2 = nn.BatchNorm2d(out_channels), nn.BatchNorm2d(out_channels) + + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=False) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, padding=padding, bias=False) + + self.is_shortcut = in_channels != out_channels + if self.is_shortcut: + self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1)) + + self.avp = nn.AvgPool2d(avp_kernel_size) + self.init_weights() + + def init_weights(self): + for m in [self.conv1, self.conv2] + ([self.shortcut] if self.is_shortcut else []): + init_layer(m) + for m in [self.bn1, self.bn2]: + init_bn(m) + + def forward(self, x): + origin = x + x = act(self.bn1(self.conv1(x)), self.activation) + x = self.bn2(self.conv2(x)) + x += self.shortcut(origin) if self.is_shortcut else origin + x = act(x, self.activation) + return self.avp(x) + + +class 
PreEncoderBlockRes3B(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size=(3, 3), avp_kernerl_size=(1, 2), activation='relu'): + """Pre-Encoder with 3 Res2DAVPBlocks.""" + super().__init__() + + self.blocks = nn.ModuleList([ + Res2DAVPBlock(in_channels if i == 0 else out_channels, out_channels, kernel_size, avp_kernerl_size, + activation) for i in range(3) + ]) + + def forward(self, x): # (B, T, F) + x = rearrange(x, 'b t f -> b 1 t f') + for block in self.blocks: + x = block(x) + return rearrange(x, 'b c t f -> b t f c') + + +def test_res3b(): + # mel-spec input + x = torch.randn(2, 256, 512) # (B, T, F) + pre = PreEncoderBlockRes3B(in_channels=1, out_channels=128) + x = pre(x) # (2, 256, 64, 128): B T,F,C + + x = torch.randn(2, 110, 1024) # (B, T, F) + pre = PreEncoderBlockRes3B(in_channels=1, out_channels=128) + x = pre(x) # (2, 110, 128, 128): B,T,F,C + + +# ==================================================================================================================== +# PreEncoderBlockHFTT: hFT-Transformer-like Pre-encoder +# ==================================================================================================================== +class PreEncoderBlockHFTT(nn.Module): + + def __init__(self, margin_pre=15, margin_post=16) -> None: + """Pre-Encoder with hFT-Transformer-like convolutions.""" + super().__init__() + + self.margin_pre, self.margin_post = margin_pre, margin_post + self.conv = nn.Conv2d(1, 4, kernel_size=(1, 5), padding='same', padding_mode='zeros') + self.emb_freq = nn.Linear(128, 128) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (B, T, F) + x = rearrange(x, 'b t f -> b 1 f t') # (B, 1, F, T) or (2, 1, 128, 110) + x = F.pad(x, (self.margin_pre, self.margin_post), value=1e-7) # (B, 1, F, T+margin) or (2,1,128,141) + x = self.conv(x) # (B, C, F, T+margin) or (2, 4, 128, 141) + x = x.unfold(dimension=3, size=32, step=1) # (B, c1, T, F, c2) or (2, 4, 128, 110, 32) + x = rearrange(x, 'b c1 f t c2 -> b t f (c1 c2)') # (B, T, F, C) or (2, 110, 128, 128) + return self.emb_freq(x) # (B, T, F, C) or (2, 110, 128, 128) + + +def test_hftt(): + # from model.spectrogram import get_spectrogram_layer_from_audio_cfg + # from config.config import audio_cfg as default_audio_cfg + # audio_cfg = default_audio_cfg + # audio_cfg['codec'] = 'melspec' + # audio_cfg['hop_length'] = 300 + # audio_cfg['n_mels'] = 128 + # x = torch.randn(2, 1, 32767) + # mspec, _ = get_spectrogram_layer_from_audio_cfg(audio_cfg) + # x = mspec(x) + x = torch.randn(2, 110, 128) # (B, T, F) + pre_enc_hftt = PreEncoderBlockHFTT() + y = pre_enc_hftt(x) # (2, 110, 128, 128): B, T, F, C + + +# ==================================================================================================================== +# PreEncoderBlockRes3BHFTT: hFT-Transformer-like Pre-encoder with Res2DAVPBlock and spec input +# ==================================================================================================================== +class PreEncoderBlockRes3BHFTT(nn.Module): + + def __init__(self, margin_pre: int = 15, margin_post: int = 16) -> None: + """Pre-Encoder with hFT-Transformer-like convolutions. + + Args: + margin_pre (int): padding before the input + margin_post (int): padding after the input + stack_dim (Literal['c', 'f']): stack dimension. 
channel or frequency + + """ + super().__init__() + self.margin_pre, self.margin_post = margin_pre, margin_post + self.res3b = PreEncoderBlockRes3B(in_channels=1, out_channels=4) + self.emb_freq = nn.Linear(128, 128) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # x: (B, T, F) or (2, 110, 1024), input spectrogram + x = rearrange(x, 'b t f -> b f t') # (2, 1024, 110): B,F,T + x = F.pad(x, (self.margin_pre, self.margin_post), value=1e-7) # (2, 1024, 141): B,F,T+margin + x = rearrange(x, 'b f t -> b t f') # (2, 141, 1024): B,T+margin,F + x = self.res3b(x) # (2, 141, 128, 4): B,T+margin,F,C + x = x.unfold(dimension=1, size=32, step=1) # (B, T, F, C1, C2) or (2, 110, 128, 4, 32) + x = rearrange(x, 'b t f c1 c2 -> b t f (c1 c2)') # (B, T, F, C) or (2, 110, 128, 128) + return self.emb_freq(x) # (B, T, F, C) or (2, 110, 128, 128) + + +def test_res3b_hftt(): + # from model.spectrogram import get_spectrogram_layer_from_audio_cfg + # from config.config import audio_cfg as default_audio_cfg + # audio_cfg = default_audio_cfg + # audio_cfg['codec'] = 'spec' + # audio_cfg['hop_length'] = 300 + # x = torch.randn(2, 1, 32767) + # spec, _ = get_spectrogram_layer_from_audio_cfg(audio_cfg) + # x = spec(x) # (2, 110, 1024): B,T,F + x = torch.randn(2, 110, 1024) # (B, T, F) + pre_enc_res3b_hftt = PreEncoderBlockRes3BHFTT() + y = pre_enc_res3b_hftt(x) # (2, 110, 128, 128): B, T, F, C + + +# # ==================================================================================================================== +# # PreEncoderBlockConv1D: Pre-encoder without activation, with Melspec input +# # ==================================================================================================================== +# class PreEncoderBlockConv1D(nn.Module): + +# def __init__(self, +# in_channels, +# out_channels, +# kernel_size=3) -> None: +# """Pre-Encoder with 1D convolution.""" +# super().__init__() +# self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, stride=1) +# self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, stride=1) + +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# # x: (B, T, F) or (2, 128, 256), input melspec +# x = rearrange(x, 'b t f -> b f t') # (2, 256, 128): B,F,T +# x = self.conv1(x) # (2, 128, 128): B,F,T +# return rearrange(x, 'b f t -> b t f') # (2, 110, 128): B,T,F + +# def test_conv1d(): +# # from model.spectrogram import get_spectrogram_layer_from_audio_cfg +# # from config.config import audio_cfg as default_audio_cfg +# # audio_cfg = default_audio_cfg +# # audio_cfg['codec'] = 'melspec' +# # audio_cfg['hop_length'] = 256 +# # audio_cfg['n_mels'] = 512 +# # x = torch.randn(2, 1, 32767) +# # mspec, _ = get_spectrogram_layer_from_audio_cfg(audio_cfg) +# # x = mspec(x) +# x = torch.randn(2, 128, 128) # (B, T, F) +# pre_enc_conv1d = PreEncoderBlockConv1D(in_channels=1, out_channels=128) +# y = pre_enc_conv1d(x) # (2, 110, 128, 128): B, T, F, C diff --git a/amt/src/model/ff_layer.py b/amt/src/model/ff_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..8595eb09535864f7e77470bed59b5be9e6d2d4da --- /dev/null +++ b/amt/src/model/ff_layer.py @@ -0,0 +1,238 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. 
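Before ff_layer.py, one point about the unfold-based framing in the hFT-style pre-encoders of conv_block.py above: padding the time axis by margin_pre + margin_post = 31 and then taking sliding windows of size 32 with step 1 leaves the frame count unchanged, since (T + 31) - 32 + 1 = T, and the 4 conv channels times the 32 window frames are flattened into the 128-dim output feature. The following is a minimal shape sketch, not part of the diff, assuming the default margins and a toy (2, 4, 128, 110) activation as in PreEncoderBlockHFTT:

```python
import torch
import torch.nn.functional as F
from einops import rearrange

x = torch.randn(2, 4, 128, 110)             # (B, C, F, T): toy activation after the 1x5 conv
x = F.pad(x, (15, 16), value=1e-7)          # (2, 4, 128, 141): pad T by margin_pre=15, margin_post=16
x = x.unfold(dimension=3, size=32, step=1)  # (2, 4, 128, 110, 32): (141 - 32) + 1 = 110 windows over T
x = rearrange(x, 'b c1 f t c2 -> b t f (c1 c2)')  # (2, 110, 128, 128): 4 channels x 32 window frames
print(x.shape)                              # torch.Size([2, 110, 128, 128])
```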
+"""ff_layer.py + +This module contains the implementation of the feedforward layers. + + Supported ff_layer_type: + 'mlp': Multi-Layer Perceptron + 'gmlp': Gated Multi-Layer Perceptron, simplified version of Mixtral Expert with num_experts=1 and top_k=1. + This is not the spatial gating MLP (https://arxiv.org/abs/2105.08050). + 'moe': Mixtral of Experts, modified from the original source code: + https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/models/mixtral/modeling_mixtral.py + + Usage: + from model.ff_layer import get_ff_layer + + config = PerceiverTFConfig() # or any type of PretrainedConfig() + config.ff_layer_type = 'moe' # or 'mlp' + config.moe_num_experts = 4 + config.moe_topk = 2 + config.hidden_act = 'gelu' # or any type of activation function, e.g., 'silu' + + ff_layer = get_ff_layer(config, input_size, widening_factor) + + What ff_layer returns: + - It returns (hidden_states, router_logits) for MoE and (hidden_states, None) for MLP. + - router_logits has the shape of (batch_size * sequence_length, n_experts) for MoE. + + +""" +from typing import Any, Tuple +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.configuration_utils import PretrainedConfig +from transformers.activations import ACT2FN +from model.ops import get_layer_norm +from model.ops import optional_compiler_disable, optional_compiler_dynamic + + +class MixtralBlockSparseTop2MLP(nn.Module): + """ + The Gated Multilayer Perceptron (GMLP) used in Mixtral of Experts (MoE). + + """ + + def __init__(self, config: PretrainedConfig, input_size: int, widening_factor: int): + super().__init__() + self.hidden_dim = input_size + self.ffn_dim = int(input_size * widening_factor) + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.gate = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.gate(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. 
+ """ + + def __init__(self, config, input_size: int, widening_factor: int): + super().__init__() + self.hidden_dim = input_size + self.widening_factor = widening_factor + self.num_experts = config.moe_num_experts + self.top_k = config.moe_topk + + # gating + self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False) + self.experts = nn.ModuleList( + [MixtralBlockSparseTop2MLP(config, self.hidden_dim, self.widening_factor) for _ in range(self.num_experts)]) + + @optional_compiler_disable + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros((batch_size * sequence_length, hidden_dim), + dtype=hidden_states.dtype, + device=hidden_states.device) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + + +class MLP(nn.Module): + """A Standard Transformer-style dense module to follow attention.""" + + def __init__(self, config: PretrainedConfig, input_size: int, widening_factor: int): + super().__init__() + self.dense1 = nn.Linear(input_size, widening_factor * input_size) + self.dense2 = nn.Linear(widening_factor * input_size, input_size) + + if isinstance(config.hidden_act, str): + self.intermediate_act_fn = ACT2FN[config.hidden_act] + else: + self.intermediate_act_fn = config.hidden_act + + def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, Any]: + hidden_states = self.dense1(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + hidden_states = self.dense2(hidden_states) + return hidden_states, None + + +class SimpleGMLP(nn.Module): + """A Simple Gated Multilayer Perceptron (aka. 'gmlp'), without the spatial gating mechanism. 
+ + Note that this is not the spatial gating MLP (https://arxiv.org/abs/2105.08050). + - A simplified MLP w/ gating mechanism adapted from Mixtral Expert, as when + the number of experts and top_k are both set to 1.) + - Added a dropout layer. + - This was also used in T5 v1.1. + """ + + def __init__(self, config: PretrainedConfig, input_size: int, widening_factor: int): + super().__init__() + self.hidden_dim = input_size + self.ffn_dim = int(input_size * widening_factor) + + self.w1 = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.gate = nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.act_fn = ACT2FN[config.hidden_act] + self.dropout1 = nn.Dropout(config.dropout_rate) + self.dropout2 = nn.Dropout(config.dropout_rate) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.gate(hidden_states) + current_hidden_states = self.dropout1(current_hidden_states) + current_hidden_states = self.w2(current_hidden_states) + current_hidden_states = self.dropout2( + current_hidden_states) # Residual connection is applied outside of this module. + return current_hidden_states, None + + +def get_ff_layer(config: PretrainedConfig, input_size: int, widening_factor: int): + if config.ff_layer_type == 'moe': + assert hasattr(config, 'moe_num_experts') and hasattr(config, 'moe_topk') and hasattr(config, 'hidden_act') + return MixtralSparseMoeBlock(config, input_size, widening_factor) + elif config.ff_layer_type == 'mlp': + assert hasattr(config, 'hidden_act') + return MLP(config, input_size, widening_factor) + elif config.ff_layer_type == 'gmlp': + assert hasattr(config, 'hidden_act') + return SimpleGMLP(config, input_size, widening_factor) + else: + raise ValueError( + f"Unsupported ff_layer_type: {config.ff_layer_type}. Supported types are 'moe', 'mlp' and 'gmlp'.") + + +def test_get_ff_layer(): + from model.ff_layer import get_ff_layer + from model.perceiver_helper import PerceiverTFConfig + input_size = 32 + widening_factor = 1 + + # Test for MoE + config = PerceiverTFConfig() # or any type of PretrainedConfig() + config.ff_layer_type = 'moe' + config.moe_num_experts = 4 + config.moe_topk = 2 + config.hidden_act = 'silu' + + ff_layer = get_ff_layer(config, input_size, widening_factor) + x = torch.rand(2, 8, input_size) + hidden_states, router_logits = ff_layer(x) + print(hidden_states.shape, router_logits.shape) # (2, 8, 32), (2*8, 4) + + # Test for MLP + config.ff_layer_type = 'mlp' + config.hidden_act = 'gelu' + + ff_layer = get_ff_layer(config, input_size, widening_factor) + hidden_states, _ = ff_layer(x) + print(hidden_states.shape) # (2, 8, 32) + + # Test for (simple)gMLP + config.ff_layer_type = 'gmlp' + config.hidden_act = 'silu' + ff_layer = get_ff_layer(config, input_size, widening_factor) + hidden_states, _ = ff_layer(x) + print(hidden_states.shape) # (2, 8, 32) diff --git a/amt/src/model/init_train.py b/amt/src/model/init_train.py new file mode 100644 index 0000000000000000000000000000000000000000..83d46d1ebb98f742744c0e523b58bc2fe6f72f0e --- /dev/null +++ b/amt/src/model/init_train.py @@ -0,0 +1,281 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. 
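To make the expert routing in MixtralSparseMoeBlock (ff_layer.py above) easier to follow before moving on to init_train.py: the router emits one logit per expert, a softmax is taken over all experts, only the top-k weights are kept, and those weights are renormalized to sum to one before mixing the selected experts' outputs. A minimal sketch with made-up logits, assuming the config.py defaults of 4 experts and top-2 routing:

```python
import torch
import torch.nn.functional as F

router_logits = torch.tensor([[2.0, 0.5, 0.1, -1.0]])       # (num_tokens=1, n_experts=4), toy values
routing_weights = F.softmax(router_logits, dim=1)            # softmax over all experts
topk_w, topk_idx = torch.topk(routing_weights, k=2, dim=-1)  # keep the two largest weights
topk_w = topk_w / topk_w.sum(dim=-1, keepdim=True)           # renormalize the kept weights to sum to 1
print(topk_idx.tolist(), topk_w.tolist())                    # [[0, 1]] with weights ~[0.82, 0.18]
```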
+"""init_train.py""" +from typing import Tuple, Literal, Any +from copy import deepcopy +import os +import argparse +import pytorch_lightning as pl +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.callbacks import LearningRateMonitor +from pytorch_lightning.utilities import rank_zero_only +from config.config import shared_cfg as default_shared_cfg +from config.config import audio_cfg as default_audio_cfg +from config.config import model_cfg as default_model_cfg +from config.config import DEEPSPEED_CFG + + +def initialize_trainer(args: argparse.Namespace, + stage: Literal['train', 'test'] = 'train') -> Tuple[pl.Trainer, WandbLogger, dict]: + """Initialize trainer and logger""" + shared_cfg = deepcopy(default_shared_cfg) + + # create save dir + os.makedirs(shared_cfg["WANDB"]["save_dir"], exist_ok=True) + + # collecting specific checkpoint from exp_id with extension (@xxx where xxx is checkpoint name) + if "@" in args.exp_id: + args.exp_id, checkpoint_name = args.exp_id.split("@") + else: + checkpoint_name = "last.ckpt" + + # checkpoint dir + lightning_dir = os.path.join(shared_cfg["WANDB"]["save_dir"], args.project, args.exp_id) + + # create logger + if args.wandb_mode is not None: + shared_cfg["WANDB"]["mode"] = str(args.wandb_mode) + if shared_cfg["WANDB"].get("cache_dir", None) is not None: + os.environ["WANDB_CACHE_DIR"] = shared_cfg["WANDB"].get("cache_dir") + del shared_cfg["WANDB"]["cache_dir"] # remove cache_dir from shared_cfg + wandb_logger = WandbLogger(log_model="all", + project=args.project, + id=args.exp_id, + allow_val_change=True, + **shared_cfg['WANDB']) + + # check if any checkpoint exists + last_ckpt_path = os.path.join(lightning_dir, "checkpoints", checkpoint_name) + if os.path.exists(os.path.join(last_ckpt_path)): + print(f'Resuming from {last_ckpt_path}') + elif stage == 'train': + print(f'No checkpoint found in {last_ckpt_path}. Starting from scratch') + last_ckpt_path = None + else: + raise ValueError(f'No checkpoint found in {last_ckpt_path}. 
Quit...') + + # add info + dir_info = dict(lightning_dir=lightning_dir, last_ckpt_path=last_ckpt_path) + + # define checkpoint callback + checkpoint_callback = ModelCheckpoint(**shared_cfg["CHECKPOINT"],) + + # define lr scheduler monitor callback + lr_monitor = LearningRateMonitor(logging_interval='step') + + # deepspeed strategy + if args.strategy == 'deepspeed': + strategy = pl.strategies.DeepSpeedStrategy(config=DEEPSPEED_CFG) + + # validation interval + if stage == 'train' and args.val_interval is not None: + shared_cfg["TRAINER"]["check_val_every_n_epoch"] = None + shared_cfg["TRAINER"]["val_check_interval"] = int(args.val_interval) + + # define trainer + sync_batchnorm = False + if stage == 'train': + # train batch size + if args.train_batch_size is not None: + train_sub_bsz = int(args.train_batch_size[0]) + train_local_bsz = int(args.train_batch_size[1]) + if train_local_bsz % train_sub_bsz == 0: + shared_cfg["BSZ"]["train_sub"] = train_sub_bsz + shared_cfg["BSZ"]["train_local"] = train_local_bsz + else: + raise ValueError( + f'Local batch size {train_local_bsz} must be divisible by sub batch size {train_sub_bsz}') + + # ddp strategy + if args.strategy == 'ddp': + args.strategy = 'ddp_find_unused_parameters_true' # fix for conformer or pitchshifter having unused parameter issue + + # sync-batchnorm + if args.sync_batchnorm is True: + sync_batchnorm = True + + train_params = dict(**shared_cfg["TRAINER"], + devices=args.num_gpus if args.num_gpus == 'auto' else int(args.num_gpus), + num_nodes=int(args.num_nodes), + strategy=strategy if args.strategy == 'deepspeed' else args.strategy, + precision=args.precision, + max_epochs=args.max_epochs if stage == 'train' else None, + max_steps=args.max_steps if stage == 'train' else -1, + logger=wandb_logger, + callbacks=[checkpoint_callback, lr_monitor], + sync_batchnorm=sync_batchnorm) + trainer = pl.trainer.trainer.Trainer(**train_params) + + # Update wandb logger (for DDP) + if trainer.global_rank == 0: + wandb_logger.experiment.config.update(args, allow_val_change=True) + + return trainer, wandb_logger, dir_info, shared_cfg + + +def update_config(args, shared_cfg, stage: Literal['train', 'test'] = 'train'): + """Update audio/model/shared configurations with args""" + audio_cfg = default_audio_cfg + model_cfg = default_model_cfg + + # Only update config when training + if stage == 'train': + # Augmentation parameters + if args.random_amp_range is not None: + shared_cfg["AUGMENTATION"]["train_random_amp_range"] = list( + (float(args.random_amp_range[0]), float(args.random_amp_range[1]))) + if args.stem_iaug_prob is not None: + shared_cfg["AUGMENTATION"]["train_stem_iaug_prob"] = float(args.stem_iaug_prob) + + if args.xaug_max_k is not None: + shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["max_k"] = int(args.xaug_max_k) + if args.xaug_tau is not None: + shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["tau"] = float(args.xaug_tau) + if args.xaug_alpha is not None: + shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["alpha"] = float(args.xaug_alpha) + if args.xaug_no_instr_overlap is not None: + shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["no_instr_overlap"] = bool(args.xaug_no_instr_overlap) + if args.xaug_no_drum_overlap is not None: + shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["no_drum_overlap"] = bool(args.xaug_no_drum_overlap) + if args.uhat_intra_stem_augment is not None: + shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["uhat_intra_stem_augment"] = bool( + args.uhat_intra_stem_augment) + + if 
args.pitch_shift_range is not None: + if args.pitch_shift_range in [["0", "0"], [0, 0]]: + shared_cfg["AUGMENTATION"]["train_pitch_shift_range"] = None + else: + shared_cfg["AUGMENTATION"]["train_pitch_shift_range"] = list( + (int(args.pitch_shift_range[0]), int(args.pitch_shift_range[1]))) + + train_stem_iaug_prob = shared_cfg["AUGMENTATION"]["train_stem_iaug_prob"] + random_amp_range = shared_cfg["AUGMENTATION"]["train_random_amp_range"] + train_stem_xaug_policy = shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"] + print(f'Random amp range: {random_amp_range}\n' + + f'Intra-stem augmentation probability: {train_stem_iaug_prob}\n' + + f'Stem augmentation policy: {train_stem_xaug_policy}\n' + + f'Pitch shift range: {shared_cfg["AUGMENTATION"]["train_pitch_shift_range"]}\n') + + # Update audio config + if args.audio_codec != None: + assert args.audio_codec in ['spec', 'melspec'] + audio_cfg["codec"] = str(args.audio_codec) + if args.hop_length != None: + audio_cfg["hop_length"] = int(args.hop_length) + if args.n_mels != None: + audio_cfg["n_mels"] = int(args.n_mels) + if args.input_frames != None: + audio_cfg["input_frames"] = int(args.input_frames) + + # Update shared config + if shared_cfg["TOKENIZER"]["max_shift_steps"] == "auto": + shift_steps_ms = shared_cfg["TOKENIZER"]["shift_step_ms"] + input_frames = audio_cfg["input_frames"] + fs = audio_cfg["sample_rate"] + max_shift_steps = (input_frames / fs) // (shift_steps_ms / 1000) + 2 # 206 by default + shared_cfg["TOKENIZER"]["max_shift_steps"] = int(max_shift_steps) + + # Update model config + if args.encoder_type != None: + model_cfg["encoder_type"] = str(args.encoder_type) + if args.decoder_type != None: + model_cfg["decoder_type"] = str(args.decoder_type) + if args.pre_encoder_type != "default": + model_cfg["pre_encoder_type"] = str(args.pre_encoder_type) + if args.pre_decoder_type != 'default': + model_cfg["pre_decoder_type"] = str(args.pre_decoder_type) + if args.conv_out_channels != None: + model_cfg["conv_out_channels"] = int(args.conv_out_channels) + assert isinstance(args.task_cond_decoder, bool) and isinstance(args.task_cond_encoder, bool) + model_cfg["use_task_conditional_encoder"] = args.task_cond_encoder + model_cfg["use_task_conditional_decoder"] = args.task_cond_decoder + + if args.encoder_position_encoding_type != 'default': + if args.encoder_position_encoding_type in ['None', 'none', '0']: + model_cfg["encoder"][model_cfg["encoder_type"]]["position_encoding_type"] = None + elif args.encoder_position_encoding_type in [ + 'sinusoidal', 'rope', 'trainable', 'alibi', 'alibit', 'tkd', 'td', 'tk', 'kdt' + ]: + model_cfg["encoder"][model_cfg["encoder_type"]]["position_encoding_type"] = str( + args.encoder_position_encoding_type) + else: + raise ValueError(f'Encoder PE type {args.encoder_position_encoding_type} not supported') + if args.decoder_position_encoding_type != 'default': + if args.decoder_position_encoding_type in ['None', 'none', '0']: + raise ValueError('Decoder PE type cannot be None') + elif args.decoder_position_encoding_type in ['sinusoidal', 'trainable']: + model_cfg["decoder"][model_cfg["decoder_type"]]["position_encoding_type"] = str( + args.decoder_position_encoding_type) + else: + raise ValueError(f'Decoder PE {args.decoder_position_encoding_type} not supported') + + if args.tie_word_embedding is not None: + model_cfg["tie_word_embedding"] = bool(args.tie_word_embedding) + + if args.d_feat != None: + model_cfg["d_feat"] = int(args.d_feat) + if args.d_latent != None: + 
model_cfg['encoder']['perceiver-tf']["d_latent"] = int(args.d_latent) + if args.num_latents != None: + model_cfg['encoder']['perceiver-tf']['num_latents'] = int(args.num_latents) + if args.perceiver_tf_d_model != None: + model_cfg['encoder']['perceiver-tf']['d_model'] = int(args.perceiver_tf_d_model) + if args.num_perceiver_tf_blocks != None: + model_cfg["encoder"]["perceiver-tf"]["num_blocks"] = int(args.num_perceiver_tf_blocks) + if args.num_perceiver_tf_local_transformers_per_block != None: + model_cfg["encoder"]["perceiver-tf"]["num_local_transformers_per_block"] = int( + args.num_perceiver_tf_local_transformers_per_block) + if args.num_perceiver_tf_temporal_transformers_per_block != None: + model_cfg["encoder"]["perceiver-tf"]["num_temporal_transformers_per_block"] = int( + args.num_perceiver_tf_temporal_transformers_per_block) + if args.attention_to_channel != None: + model_cfg["encoder"]["perceiver-tf"]["attention_to_channel"] = bool(args.attention_to_channel) + if args.sca_use_query_residual != None: + model_cfg["encoder"]["perceiver-tf"]["sca_use_query_residual"] = bool(args.sca_use_query_residual) + if args.layer_norm_type != None: + model_cfg["encoder"]["perceiver-tf"]["layer_norm"] = str(args.layer_norm_type) + if args.ff_layer_type != None: + model_cfg["encoder"]["perceiver-tf"]["ff_layer_type"] = str(args.ff_layer_type) + if args.ff_widening_factor != None: + model_cfg["encoder"]["perceiver-tf"]["ff_widening_factor"] = int(args.ff_widening_factor) + if args.moe_num_experts != None: + model_cfg["encoder"]["perceiver-tf"]["moe_num_experts"] = int(args.moe_num_experts) + if args.moe_topk != None: + model_cfg["encoder"]["perceiver-tf"]["moe_topk"] = int(args.moe_topk) + if args.hidden_act != None: + model_cfg["encoder"]["perceiver-tf"]["hidden_act"] = str(args.hidden_act) + if args.rotary_type != None: + assert len( + args.rotary_type + ) == 3, "rotary_type must be a 3-letter string (e.g. 
'ppl': 'pixel' for SCA, 'pixel' for latent, 'lang' for temporal transformer)" + model_cfg["encoder"]["perceiver-tf"]["rotary_type_sca"] = str(args.rotary_type)[0] + model_cfg["encoder"]["perceiver-tf"]["rotary_type_latent"] = str(args.rotary_type)[1] + model_cfg["encoder"]["perceiver-tf"]["rotary_type_temporal"] = str(args.rotary_type)[2] + if args.rope_apply_to_keys != None: + model_cfg["encoder"]["perceiver-tf"]["rope_apply_to_keys"] = bool(args.rope_apply_to_keys) + if args.rope_partial_pe != None: + model_cfg["encoder"]["perceiver-tf"]["rope_partial_pe"] = bool(args.rope_partial_pe) + + if args.decoder_ff_layer_type != None: + model_cfg["decoder"][model_cfg["decoder_type"]]["ff_layer_type"] = str(args.decoder_ff_layer_type) + if args.decoder_ff_widening_factor != None: + model_cfg["decoder"][model_cfg["decoder_type"]]["ff_widening_factor"] = int(args.decoder_ff_widening_factor) + + if args.event_length != None: + model_cfg["event_length"] = int(args.event_length) + + if stage == 'train': + if args.encoder_dropout_rate != None: + model_cfg["encoder"][model_cfg["encoder_type"]]["dropout_rate"] = float(args.encoder_dropout_rate) + if args.decoder_dropout_rate != None: + model_cfg["decoder"][model_cfg["decoder_type"]]["dropout_rate"] = float(args.decoder_dropout_rate) + + return shared_cfg, audio_cfg, model_cfg # return updated configs diff --git a/amt/src/model/lm_head.py b/amt/src/model/lm_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7a12b17c0f225dad545d3c9c03cd81565094924f --- /dev/null +++ b/amt/src/model/lm_head.py @@ -0,0 +1,40 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""lm_head.py""" +import torch +from torch import nn +from typing import Optional, Dict + + +class LMHead(nn.Module): + """Language Model Head with tied weights.""" + + def __init__(self, decoder_config: Dict, init_factor: float = 1.0, tie_word_embeddings: bool = True): + + super().__init__() + self.d_model = decoder_config["d_model"] + self.init_factor = init_factor + self.tie_word_embeddings = tie_word_embeddings + + self.lm_head = nn.Linear(decoder_config["d_model"], decoder_config["vocab_size"], bias=False) + self._init_weights() + + def _init_weights(self): + if self.tie_word_embeddings is False: + self.lm_head.weight.data.normal_(mean=0.0, std=self.init_factor * 1.0) + + def forward(self, decoder_hs: torch.FloatTensor) -> torch.FloatTensor: + if self.tie_word_embeddings is True: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + decoder_hs = decoder_hs * (self.d_model**-0.5) + + lm_logits = self.lm_head(decoder_hs) + return lm_logits diff --git a/amt/src/model/lr_scheduler.py b/amt/src/model/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..db680c581f7a5eeb617e8a623ae66c6692acb8d7 --- /dev/null +++ b/amt/src/model/lr_scheduler.py @@ -0,0 +1,91 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""lr_schedule.py""" +import torch +from typing import Dict, Optional + + +def get_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_name: str, base_lr: float, scheduler_cfg: Dict): + + if scheduler_name.lower() == 'cosine': + from torch.optim.lr_scheduler import ( + SequentialLR, + LinearLR, + CosineAnnealingLR, + ) + + scheduler1 = LinearLR( + optimizer, + start_factor=0.5, + end_factor=1, + total_iters=scheduler_cfg["warmup_steps"], + last_epoch=-1, + ) + + scheduler2 = CosineAnnealingLR( + optimizer, + T_max=scheduler_cfg["total_steps"] - scheduler_cfg["warmup_steps"], + eta_min=scheduler_cfg["final_cosine"], + ) + + lr_scheduler = SequentialLR(optimizer, + schedulers=[scheduler1, scheduler2], + milestones=[scheduler_cfg["warmup_steps"]]) + elif scheduler_name.lower() == 'legacy': + import math + from torch.optim.lr_scheduler import ( + SequentialLR, + LinearLR, + LambdaLR, + ) + + msg = "You are using T5 legacy LR Schedule, it's independent from the optim.base_lr" + print(msg) + + num_steps_optimizer1 = math.ceil(scheduler_cfg["total_steps"] * 0.9) + iters_left_for_optimizer2 = scheduler_cfg["total_steps"] - num_steps_optimizer1 + + scheduler1 = LambdaLR(optimizer, lambda step: min(base_lr, 1.0 / math.sqrt(step)) / base_lr + if step else base_lr / base_lr) + + scheduler2 = LinearLR(optimizer, + start_factor=(min(base_lr, 1.0 / math.sqrt(num_steps_optimizer1)) / base_lr), + end_factor=0, + total_iters=iters_left_for_optimizer2, + last_epoch=-1) + + lr_scheduler = SequentialLR( + optimizer, + schedulers=[scheduler1, scheduler2], + milestones=[num_steps_optimizer1], + ) + elif scheduler_name.lower() == 'constant': + from transformers import get_scheduler + lr_scheduler = get_scheduler( + name=scheduler_name.lower(), + optimizer=optimizer, + ) + else: + raise NotImplementedError + + return lr_scheduler + + +def extra_stats(args, model, optimizer): + stats = {} + + if args.logging.weights_l2: + weights_l2 = sum(p.detach().norm(2).item()**2 for p in model.parameters())**0.5 + stats['weights_l2'] = weights_l2 + + cur_lr = optimizer.param_groups[0]['lr'] + stats['lr'] = cur_lr + + return stats diff --git a/amt/src/model/ops.py b/amt/src/model/ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d11fb7d93b17318ea30b60e8f7012e1faa90f8de --- /dev/null +++ b/amt/src/model/ops.py @@ -0,0 +1,111 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +""" op.py """ +import math +from packaging.version import parse as VersionParse + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers.models.t5.modeling_t5 import T5LayerNorm as RMSNorm + + +def get_layer_norm(dim: int, layer_norm_type: str = "layer_norm", layer_norm_eps: float = 1e-5): + """Get layer normalization layer. + Args: + dim (int): Feature dimension + layer_norm_type (str): "layer_norm" or "rms_norm" + layer_norm_eps (float): Epsilon value for numerical stability + + Returns: + nn.Module: Layer normalization layer + """ + if layer_norm_type == "rms_norm": + # T5LayerNorm is equivalent to RMSNorm. 
https://arxiv.org/abs/1910.07467 + return RMSNorm(hidden_size=dim, eps=layer_norm_eps) + else: + return nn.LayerNorm(normalized_shape=dim, eps=layer_norm_eps) + + +def check_all_elements_equal(x: torch.Tensor) -> bool: + return x.eq(x[0]).all().item() + + +def minmax_normalize(x: torch.Tensor, eps: float = 0.008) -> torch.FloatTensor: + """Min-max normalization: + + x_norm = (x - x_min) / (x_max - x_min + eps) + + Args: + x (torch.Tensor): (B, T, F) + Returns: + torch.Tensor: (B, T, F) with output range of [0, 1] + """ + x_max = rearrange(x, "b t f -> b (t f)").max(1, keepdim=True)[0] + x_min = rearrange(x, "b t f -> b (f t)").min(1, keepdim=True)[0] + x_max = x_max[:, None, :] # (B,1,1) + x_min = x_min[:, None, :] # (B,1,1) + return (x - x_min) / (x_max - x_min + eps) + + +def count_parameters(model): + num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + num_params = sum(p.numel() for p in model.parameters()) + return num_trainable_params, num_params + + +def adjust_b_to_gcd(a, b, min_gcd=16): + """ + Adjust the value of b to ensure the GCD(a, b) is at least min_gcd with minimum change to b. + + Parameters: + - a (int): A positive integer + - b (int): A positive integer + - min_gcd (int): The minimum desired GCD + + Returns: + - int: The adjusted value of b + """ + current_gcd = math.gcd(a, b) + + # If current GCD is already greater than or equal to min_gcd, return b as it is. + if current_gcd >= min_gcd: + return b + + # If a is less than min_gcd, then it's impossible to get a GCD of at least min_gcd. + if a < min_gcd: + raise ValueError("a must be at least as large as min_gcd.") + + # Adjust b by trying increments and decrements, preferring the smallest absolute change. + adjusted_b_up = b + adjusted_b_down = b + + while True: + adjusted_b_up += 1 + adjusted_b_down -= 1 + + if math.gcd(a, adjusted_b_up) >= min_gcd: + return adjusted_b_up + elif math.gcd(a, adjusted_b_down) >= min_gcd: + return adjusted_b_down + + +def optional_compiler_disable(func): + if VersionParse(torch.__version__) >= VersionParse("2.1"): + # If the version is 2.1 or higher, apply the torch.compiler.disable decorator. + return torch.compiler.disable(func) + else: + # If the version is below 2.1, return the original function. + return func + + +def optional_compiler_dynamic(func): + return torch.compile(func, dynamic=True) diff --git a/amt/src/model/optimizers.py b/amt/src/model/optimizers.py new file mode 100644 index 0000000000000000000000000000000000000000..10234d68d207d2098fbd38e3ccabefe2e6a45796 --- /dev/null +++ b/amt/src/model/optimizers.py @@ -0,0 +1,218 @@ +""" optimizers.py + +Code based on nanoT5 project: + https://github.com/PiotrNawrot/nanoT5/blob/main/nanoT5/utils/copied_utils.py + ++ D-adapt Adam from https://github.com/facebookresearch/dadaptation +""" +import importlib +import math +import torch + +from typing import Iterable, Tuple +from torch import nn +from torch.optim import Optimizer +from transformers import Adafactor +from torch.optim import AdamW + + +class AdamWScale(Optimizer): + """ + This AdamW implementation is copied from Huggingface. + We modified it with Adagrad scaling by rms of a weight tensor + + Implements Adam algorithm with weight decay fix as introduced in [Decoupled Weight Decay + Regularization](https://arxiv.org/abs/1711.05101). + + Parameters: + params (`Iterable[nn.parameter.Parameter]`): + Iterable of parameters to optimize or dictionaries defining parameter groups. 
+ lr (`float`, *optional*, defaults to 1e-3): + The learning rate to use. + betas (`Tuple[float,float]`, *optional*, defaults to (0.9, 0.999)): + Adam's betas parameters (b1, b2). + eps (`float`, *optional*, defaults to 1e-6): + Adam's epsilon for numerical stability. + weight_decay (`float`, *optional*, defaults to 0): + Decoupled weight decay to apply. + correct_bias (`bool`, *optional*, defaults to `True`): + Whether or not to correct bias in Adam (for instance, in Bert TF repository they use `False`). + no_deprecation_warning (`bool`, *optional*, defaults to `False`): + A flag used to disable the deprecation warning (set to `True` to disable the warning). + """ + + def __init__( + self, + params: Iterable[nn.parameter.Parameter], + lr: float = 1e-3, + betas: Tuple[float, float] = (0.9, 0.999), + eps: float = 1e-6, + weight_decay: float = 0.0, + correct_bias: bool = True, + ): + if lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr} - should be >= 0.0") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter: {betas[0]} - should be in [0.0, 1.0)") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter: {betas[1]} - should be in [0.0, 1.0)") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps} - should be >= 0.0") + defaults = dict( + lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias) + super().__init__(params, defaults) + + @staticmethod + def _rms(tensor): + return tensor.norm(2) / (tensor.numel()**0.5) + + def step(self, closure=None): + """ + Performs a single optimization step. + + Arguments: + closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss. + """ + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + if grad.is_sparse: + raise RuntimeError( + "Adam does not support sparse gradients, please consider SparseAdam instead" + ) + + state = self.state[p] + beta1, beta2 = group["betas"] + + # State initialization + if len(state) == 0: + state["step"] = 0 + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like(p.data) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like(p.data) + + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + + state["step"] += 1 + + # Decay the first and second moment running average coefficient + # In-place operations to update the averages at the same time + exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1)) + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2) + denom = exp_avg_sq.sqrt().add_(group["eps"]) + + step_size = group["lr"] + if group["correct_bias"]: # No bias correction for Bert + bias_correction1 = 1.0 - beta1**state["step"] + bias_correction2 = 1.0 - beta2**state["step"] + step_size = step_size * math.sqrt(bias_correction2) / bias_correction1 + + # /Adapt Step from Adagrad + step_size = step_size * max(1e-3, self._rms(p.data)) + # /Adapt Step from Adagrad + + p.data.addcdiv_(exp_avg, denom, value=-step_size) + + # Just adding the square of the weights to the loss function is *not* + # the correct way of using L2 regularization/weight decay with Adam, + # since that will interact with the m and v parameters in strange ways. + # + # Instead we want to decay the weights in a manner that doesn't interact + # with the m/v parameters. 
This is equivalent to adding the square + # of the weights to the loss with plain (non-momentum) SGD. + # Add weight decay at the end (fixed version) + if group["weight_decay"] > 0.0: + p.data.add_(p.data, alpha=(-group["lr"] * group["weight_decay"])) + + return loss + + +# def get_optimizer(models_dict: nn.ModuleDict, +# optimizer_name: str, +# base_lr: float, +# weight_decay: float = 0.): + +# no_decay = [ +# "bias", "LayerNorm", "layernorm", "layer_norm", "ln", "BatchNorm", "bn", "batch_norm", +# "batchnorm" +# ] + + +# optimizer_grouped_parameters = [] +# for name, current_model in models_dict.items(): +# if current_model is None: +# continue +# optimizer_grouped_parameters += [ +# { +# "params": [ +# p for n, p in current_model.named_parameters() +# if not any(nd in n for nd in no_decay) +# ], +# "weight_decay": weight_decay, +# }, +# { +# "params": [ +# p for n, p in current_model.named_parameters() +# if any(nd in n for nd in no_decay) +# ], +# "weight_decay": 0.0, +# }, +# ] +def get_optimizer(models_dict: nn.ModuleDict, + optimizer_name: str, + base_lr: float, + weight_decay: float = 0.): + + no_decay = [ + "bias", "LayerNorm", "layernorm", "layer_norm", "ln", "BatchNorm", "bn", "batch_norm", + "batchnorm" + ] + optimizer_grouped_parameters = [] + for n, p in models_dict: + # drop pitch shifter + if 'pshifters' in n: + continue + # no decay + if n in no_decay: + optimizer_grouped_parameters.append({"params": [p], "weight_decay": 0.0}) + else: + optimizer_grouped_parameters.append({"params": [p], "weight_decay": weight_decay}) + + if optimizer_name.lower() == 'adamw': + base_lr = 1e-03 if base_lr == None else float(base_lr) + opt = AdamW(optimizer_grouped_parameters, lr=base_lr) + elif optimizer_name.lower() == 'adafactor': + if base_lr == None: + opt = Adafactor( + optimizer_grouped_parameters, + lr=None, + scale_parameter=True, + relative_step=True, + warmup_init=True) + else: + opt = Adafactor(optimizer_grouped_parameters, lr=base_lr, relative_step=False) + elif optimizer_name.lower() == 'adamwscale': + base_lr = 1e-02 if base_lr == None else float(base_lr) + opt = AdamWScale( + optimizer_grouped_parameters, + lr=base_lr, + ) + elif optimizer_name.lower() == 'cpuadam': + dspd = importlib.import_module('deepspeed') + base_lr = 1e-03 if base_lr == None else float(base_lr) + opt = dspd.ops.adam.cpu_adam.DeepSpeedCPUAdam(optimizer_grouped_parameters, lr=base_lr) + elif optimizer_name.lower() == 'dadaptadam': + dadaptation = importlib.import_module('dadaptation') + base_lr = 1.0 if base_lr == None else float(base_lr) + opt = dadaptation.DAdaptAdam(optimizer_grouped_parameters, lr=base_lr) + else: + raise NotImplementedError(optimizer_name) + + return opt, base_lr diff --git a/amt/src/model/perceiver_helper.py b/amt/src/model/perceiver_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..fb52f1f9e71514aca2df338b5cf209fdbd6cc115 --- /dev/null +++ b/amt/src/model/perceiver_helper.py @@ -0,0 +1,290 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. 
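One detail of AdamWScale (optimizers.py above) worth spelling out before the Perceiver code: on top of the usual AdamW update, the step size is multiplied by max(1e-3, rms(p)), where rms(p) = ||p||_2 / sqrt(numel(p)), so parameter tensors with larger typical magnitude take proportionally larger steps (the "Adapt Step from Adagrad" lines). A minimal numeric sketch of just that scaling rule, illustrative only and not a full optimizer step:

```python
import torch

def rms(tensor: torch.Tensor) -> float:
    # Same quantity as AdamWScale._rms: L2 norm divided by the square root of the element count.
    return (tensor.norm(2) / (tensor.numel() ** 0.5)).item()

p = torch.full((4, 4), 0.5)                   # toy weight tensor with rms(p) = 0.5
base_step = 1e-2                              # step size after bias correction
scaled_step = base_step * max(1e-3, rms(p))   # 1e-2 * 0.5 = 5e-3
print(scaled_step)
```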
+from dataclasses import dataclass +from typing import Optional, Tuple +import torch +from torch import nn +from transformers.utils import ModelOutput +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel +# from transformers.models.perceiver.modeling_perceiver import (PerceiverAbstractPositionEncoding, +# PerceiverTrainablePositionEncoding, +# PerceiverFourierPositionEncoding) + + +class PerceiverTFConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`PerceiverTF`]. It is used to instantiate an + Perceiver model according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the Perceiver + [deepmind/language-perceiver](https://huggingface.co/deepmind/language-perceiver) architecture. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + num_latents (`int`, *optional*, defaults to 256): + The number of latents. + d_latents (`int`, *optional*, defaults to 1280): + Dimension of the latent embeddings. + d_model (`int`, *optional*, defaults to 768): + Dimension of the inputs. Should only be provided in case [*PerceiverTextPreprocessor*] is used or no + preprocessor is provided. + kv_dim (`int`, *optional*, defaults to 128): + num_blocks (`int`, *optional*, defaults to 1): + Number of blocks in the Transformer encoder. + num_self_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each self-attention layer in the Transformer encoder. + num_cross_attention_heads (`int`, *optional*, defaults to 8): + Number of attention heads for each cross-attention layer in the Transformer encoder. + num_local_transformers_per_block (`int`, *optional*, defaults to 2): + Number of local Transformer layers per Transformer block in the Transformer encoder. + num_temporal_transformers_per_block (`int`, *optional*, defaults to 2): + Number of temporal Transformer layers per Transformer block in the Transformer encoder. + shared_parallel_temporal_transformers (`bool`, *optional*, defaults to `False`): + Whether to share the parameters across the K parallel temporal Transformers in each block. + qk_channels (`int`, *optional*): + Dimension to project the queries + keys before applying attention in the cross-attention and self-attention + layers of the encoder. Will default to preserving the dimension of the queries if not specified. + v_channels (`int`, *optional*): + Dimension to project the values before applying attention in the cross-attention and self-attention layers + of the encoder. Will default to preserving the dimension of the queries if not specified. + ** DEPRECATED ** cross_attention_shape_for_attention (`str`, *optional*, defaults to `'kv'`): + Dimension to use when downsampling the queries and keys in the cross-attention layer of the encoder. + ** DEPRECATED ** self_attention_widening_factor (`int`, *optional*, defaults to 1): + Dimension of the feed-forward layer in the cross-attention layer of the Transformer encoder. + cross_attention_widening_factor (`int`, *optional*, defaults to 1): + Dimension of the feed-forward layer in the self-attention layers of the Transformer encoder. 
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` are supported. + dropout_rate (`float`, *optional*, defaults to 0.1): + The dropout ratio for the attention probabilities. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + layer_norm_type (`str`, *optional*, defaults to `'layer_norm'`): + The type of layer normalization to use. Can be one of {'layer_norm', 'rms_norm'}. + layer_norm_eps (`float`, *optional*, defaults to 1e-5): + The epsilon used by the layer normalization layers. + sca_use_query_residual (`bool`, *optional*, defaults to `True`): + Whether to add a query residual in the spectral cross attention (SCA) layer of the encoder. + use_query_residual (`float`, *optional*, defaults to `True`): + Whether to add a query residual in the cross-attention layer of the encoder. + position_encoding_type (`str`, *optional*, defaults to `'trainable'`): + Type of position encoding to use. Can be one of {'trainable', 'alibi', 'alibit', 'rope', None}. + num_max_positions (`int`, *optional*, defaults to 331): + Maximum number of positions to use for the position encoding. + vocab_size (`int`, *optional*, defaults to 262): + Vocabulary size for the masked language modeling model. + attention_to_channel (`bool`, defaults to `False`): + Whether SCA should attend to the channel dimension. If False, attention to frequency bin dimension. + ff_layer_type (`str`, *optional*, defaults to `'mlp'`): + Type of feed-forward layer to use. Can be one of {'mlp', 'moe'}. + ff_widening_factor (`int`, *optional*, defaults to 1): + Widening factor for the feed-forward layers in the MLP/MoE. + moe_num_experts (`int`, *optional*, defaults to 4): + Number of experts to use in the mixture of experts (MoE) feed-forward layer. + Only used if `ff_layer_type` is set to `'moe'`. + moe_topk (`int`, *optional*, defaults to 2): + Number of top experts to use in the mixture of experts (MoE) feed-forward layer. + Only used if `ff_layer_type` is set to `'moe'`. + rope_type_sca (`str`, *optional*, defaults to `pixel`): Can be one of {'l'|lang', 'p'|'pixel', None}. + RoPE index type for SCA. Only used if `position_encoding_type` is set to `rope`. + rope_type_latent (`str`, *optional*, defaults to `pixel`): Can be one of {'l'|'lang', 'p'|'pixel', None}. + RoPE index type for Latent Transformer. Only used if `position_encoding_type` is set to `'rope'`. + rope_type_temporal (`str`, *optional*, defaults to `lang`): Can be one of {'l'|'lang', 'p'|'pixel', None}. + RoPE index type for Temporal Transformer. Only used if `position_encoding_type` is set to `'rope'`. + rope_apply_to_keys (`bool`, *optional*, defaults to `False`): Whether to apply RoPE to the keys in the + self/cross-attention layers. Only used if `position_encoding_type` is set to `'rope'`. + rope_partial_pe (`bool`, *optional*, defaults to `False`): Whether to use partial RoPE in the self/cross-attention. + Only used if `position_encoding_type` is set to `'rope'`. + rope_trainable (`bool`, *optional*, defaults to `False`): Whether to make the RoPE trainable. 
Only used if + + Example: + + ```python + >>> from model.perceiver_mod import PerceiverTFEncodel, PerceiverTFConfig + + >>> # Initializing a Perceiver deepmind/language-perceiver style configuration + >>> configuration = PerceiverTFConfig() + + >>> # Initializing a model from the deepmind/language-perceiver style configuration + >>> model = PerceiverTFEncoder(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + model_type = "perceivertf" + + def __init__( + self, + num_latents=24, + d_latents=128, + d_model=128, + kv_dim=128, + num_blocks=3, + num_self_attention_heads=8, + num_cross_attention_heads=8, + num_local_transformers_per_block=2, + num_temporal_transformers_per_block=2, + qk_channels=128, + v_channels=128, + cross_attention_shape_for_attention="q", + # self_attention_widening_factor=1, ** DEPRECATED ** + # cross_attention_widening_factor=1, ** DEPRECATED ** + hidden_act="gelu", + dropout_rate=0.1, + initializer_range=0.02, + layer_norm_type="layer_norm", + layer_norm_eps=1e-5, + sca_use_query_residual=True, + use_query_residual=True, + position_encoding_type="trainable", + num_max_positions=330, + vocab_size=1391, + attention_to_channel=False, + ff_layer_type="mlp", + ff_widening_factor=1, + moe_num_experts=4, + moe_topk=2, + rope_type_sca="pixel", + rope_type_latent="pixel", + rope_type_temporal="lang", + rope_apply_to_keys=False, + rope_partial_pe=False, + rope_trainable=False, + **kwargs, + ): + super().__init__(**kwargs) + + self.num_latents = num_latents + self.d_latents = d_latents + self.d_model = d_model + self.kv_dim = kv_dim + self.qk_channels = qk_channels + self.v_channels = v_channels + + self.num_blocks = num_blocks + self.num_self_attention_heads = num_self_attention_heads + self.num_cross_attention_heads = num_cross_attention_heads + self.num_local_transformers_per_block = num_local_transformers_per_block + self.num_temporal_transformers_per_block = num_temporal_transformers_per_block + self.sca_use_query_residual = sca_use_query_residual + self.use_query_residual = use_query_residual + self.position_encoding_type = position_encoding_type + self.num_max_positions = num_max_positions + # self.self_attention_widening_factor = self_attention_widening_factor + # self.cross_attention_widening_factor = cross_attention_widening_factor + self.cross_attention_shape_for_attention = cross_attention_shape_for_attention + self.attention_to_channel = attention_to_channel + self.ff_layer_type = ff_layer_type + self.ff_widening_factor = ff_widening_factor + self.moe_num_experts = moe_num_experts + self.moe_topk = moe_topk + self.rope_type_sca = rope_type_sca + self.rope_type_latent = rope_type_latent + self.rope_type_temporal = rope_type_temporal + self.rope_apply_to_keys = rope_apply_to_keys + self.rope_partial_pe = rope_partial_pe + self.rope_trainable = rope_trainable + + self.hidden_act = hidden_act + self.dropout_rate = dropout_rate + self.initializer_range = initializer_range + self.layer_norm_type = layer_norm_type + self.layer_norm_eps = layer_norm_eps + + # masked language modeling attributes + self.vocab_size = vocab_size + + +class PerceiverTFPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = PerceiverTFConfig + base_model_prefix = "perceivertf" + main_input_name = "inputs" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, (nn.Linear, nn.Conv2d)): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif hasattr(module, "latents"): + module.latents.data.normal_(mean=0.0, std=self.config.initializer_range) + elif hasattr(module, "_pos_emb") and isinstance(module._pos_emb, nn.Parameter): + # initialize PerceiverTFTrainablePE + module._pos_emb.data.normal_(mean=0.0, std=self.config.initializer_range) + elif hasattr(module, "_pos_emb_temporal"): + # initialize PerceiverTFTrainablePE + module._pos_emb_temporal.data.normal_(mean=0.0, std=self.config.initializer_range) + elif hasattr(module, "slopes") and isinstance(module.slopes, nn.Parameter): + # initialize AlibiPositionalBias + module.reset_parameters() + elif isinstance(module, nn.ParameterDict): + for modality in module.keys(): + module[modality].data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + # elif hasattr(module, "position_embeddings") and isinstance( + # module, PerceiverTrainablePositionEncoding): + # module.position_embeddings.data.normal_(mean=0.0, std=self.config.initializer_range) + + +# Replace the 'ModelOutputWithCrossAttentions' with 'MoEModelOutputWithCrossAttentions' for MoE +@dataclass +class MoEModelOutputWithCrossAttentions(ModelOutput): + """ + Base class for model's outputs, with potential hidden states and attentions. + Plus, router_probs for Mixture of Experts models. + + Args: + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. 
+ + Attentions weights of the decoder's cross-attention layer, after the attention softmax, used to compute the + weighted average in the cross-attention heads. + router_probs (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_probs=True` and `config.add_router_probs=True` is passed or when `config.output_router_probs=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`. + + Raw router probabilities that are computed by MoE routers, these terms are used to compute the auxiliary + loss and the z_loss for Mixture of Experts models. + """ + + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + cross_attentions: Optional[Tuple[torch.FloatTensor]] = None + router_logits: Optional[Tuple[torch.FloatTensor]] = None diff --git a/amt/src/model/perceiver_mod.py b/amt/src/model/perceiver_mod.py new file mode 100644 index 0000000000000000000000000000000000000000..749d9cd5e5b62c76671dfcdbf1cb078c372c7a2a --- /dev/null +++ b/amt/src/model/perceiver_mod.py @@ -0,0 +1,912 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""perceiver_mod.py + + Implementation of the PerceiverTF encoder with: + - AliBi positional bias + - Mixtral of Experts (MoE) feedforward layer + +""" +import math +from einops import rearrange +from typing import Optional, Tuple, Union, List, Dict, Literal + +import torch +from torch import nn +from transformers.models.perceiver.modeling_perceiver import PerceiverSelfOutput +from transformers.pytorch_utils import (apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer) +from model.perceiver_helper import MoEModelOutputWithCrossAttentions +from model.perceiver_helper import PerceiverTFPreTrainedModel, PerceiverTFConfig +from model.positional_encoding import AlibiPositionalBias, get_rotary_emb +from model.ops import get_layer_norm +from model.ff_layer import get_ff_layer + + +class PerceiverEmbeddings(nn.Module): + """Construct the latent embeddings sharable with token embeddings in the decoder.""" + + def __init__(self, config, shared_emb: Optional[nn.Parameter] = None): + super().__init__() + if shared_emb is not None: + self.latents = shared_emb + assert self.latents.shape == (config.num_latents, config.d_latents) + else: + self.latents = nn.Parameter(torch.randn(config.num_latents, config.d_latents)) + + def forward(self, batch_size: int): + return self.latents.expand(batch_size, -1, -1) + + +class PerceiverTFTrainablePE(nn.Module): + """Construct the trainable absolute positional embeddings.""" + + def __init__(self, position_encoding_type: Literal['trainable', 'tkd', 'td', 'tk', 'kdt'], max_t: int, k: int, + d: int) -> None: + super().__init__() + self.position_encoding_type = position_encoding_type + self.max_t = max_t + self.k = k + self.d = d + + if position_encoding_type in ['trainable', 'tkd']: + self._pos_emb = nn.Parameter(torch.randn(max_t, k, d)) + elif position_encoding_type == 'td': + self._pos_emb = nn.Parameter(torch.randn(max_t, d)) + elif position_encoding_type == 'tk': + self._pos_emb = nn.Parameter(torch.randn(max_t, k)) + elif position_encoding_type == 'kdt': + self._pos_emb = 
nn.Parameter(torch.randn(k, d)) + self._pos_emb_temporal = nn.Parameter(torch.randn(max_t, d)) + else: + raise ValueError(f'unknown position encoding type {position_encoding_type}') + + def forward(self): + pos_emb_temporal = None + + if self.position_encoding_type in ['trainable', 'tkd']: + pos_emb = self._pos_emb + elif self.position_encoding_type == 'td': + pos_emb = self._pos_emb.unsqueeze(1).expand(-1, self.k, -1) + elif self.position_encoding_type == 'tk': + pos_emb = self._pos_emb.unsqueeze(-1).expand(-1, -1, self.d) + elif self.position_encoding_type == 'kdt': + pos_emb = self._pos_emb.unsqueeze(0).expand(self.max_t, -1, -1) + pos_emb_temporal = self._pos_emb_temporal + + return pos_emb, pos_emb_temporal + + +class PerceiverAlibiSelfAttention(nn.Module): + """ + Multi-headed {cross, self}-attention + Alibi/Rotary positional bias/emb: + - Can be used both in the encoder as well as in the decoder. + - Modified from PerceiverSelfAttention in modeling_perceiver.py to support Alibi positional bias + + """ + + def __init__( + self, + config, + is_cross_attention=False, + qk_channels=None, + v_channels=None, + num_heads=1, + q_dim=None, + kv_dim=None, + rotary_emb=None, + ): + super().__init__() + self.num_heads = num_heads + # Q and K must have the same number of channels. + # Default to preserving Q's input's shape. + if qk_channels is None: + qk_channels = q_dim + # V's num_channels determines the shape of the output of QKV-attention. + # Default to the same number of channels used in the key-query operation. + if v_channels is None: + v_channels = qk_channels + if qk_channels % num_heads != 0: + raise ValueError(f"qk_channels ({qk_channels}) must be divisible by num_heads ({num_heads}).") + if v_channels % num_heads != 0: + raise ValueError(f"v_channels ({v_channels}) must be divisible by num_heads ({num_heads}).") + + self.qk_channels = qk_channels + self.v_channels = v_channels + self.qk_channels_per_head = self.qk_channels // num_heads + self.v_channels_per_head = self.v_channels // num_heads + + # Layer normalization + self.layernorm1 = get_layer_norm(q_dim, config.layer_norm_type, config.layer_norm_eps) + if is_cross_attention: + self.layernorm2 = get_layer_norm(kv_dim, config.layer_norm_type, config.layer_norm_eps) + else: + self.layernorm2 = nn.Identity() + # self.layernorm1 = nn.LayerNorm(q_dim) + # self.layernorm2 = nn.LayerNorm(kv_dim) if is_cross_attention else nn.Identity() + + # Projection matrices + self.query = nn.Linear(q_dim, qk_channels) + self.key = nn.Linear(kv_dim, qk_channels) + self.value = nn.Linear(kv_dim, v_channels) + + self.dropout = nn.Dropout(config.dropout_rate) + + # (Modified) Alibi positional bias + if config.position_encoding_type == 'alibi': + self.alibi_bias = AlibiPositionalBias(heads=num_heads, total_heads=num_heads, trainable_slope=False) + elif config.position_encoding_type == 'alibit': + self.alibi_bias = AlibiPositionalBias(heads=num_heads, total_heads=num_heads, trainable_slope=True) + else: + self.alibi_bias = None + # (Modified) RoPE + if config.position_encoding_type == 'rope': + assert rotary_emb is not None, "rotary_emb must be provided for RoPE." 
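As a quick sanity check on the trainable positional-embedding variants defined above ('tkd', 'td', 'tk', 'kdt'): every variant is expanded to a `(max_t, K, D)` tensor, and only 'kdt' additionally returns a separate `(max_t, D)` temporal embedding. A minimal sketch, assuming `amt/src` is on `PYTHONPATH` so that the module path `model.perceiver_mod` from this diff resolves:

```python
from model.perceiver_mod import PerceiverTFTrainablePE  # module path as added in this diff

max_t, k, d = 330, 24, 128
for pe_type in ['tkd', 'td', 'tk', 'kdt']:
    pe = PerceiverTFTrainablePE(pe_type, max_t, k, d)
    pos_emb, pos_emb_temporal = pe()
    assert pos_emb.shape == (max_t, k, d)                        # every variant broadcasts to (T, K, D)
    assert (pos_emb_temporal is not None) == (pe_type == 'kdt')  # only 'kdt' carries a (max_t, D) temporal PE
```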
+ self.rotary_emb = rotary_emb + else: + self.rotary_emb = None + self.rope_apply_to_keys = config.rope_apply_to_keys # False by default + + def transpose_for_scores(self, x, channels_per_head): + new_x_shape = x.size()[:-1] + (self.num_heads, channels_per_head) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + hidden_states = self.layernorm1(hidden_states) + inputs = self.layernorm2(inputs) + + # Project queries, keys and values to a common feature dimension. If this is instantiated as a cross-attention module, + # the keys and values come from the inputs; the attention mask needs to be such that the inputs's non-relevant tokens are not attended to. + is_cross_attention = inputs is not None + queries = self.query(hidden_states) + + if is_cross_attention: + keys = self.key(inputs) + values = self.value(inputs) + attention_mask = inputs_mask + else: + keys = self.key(hidden_states) + values = self.value(hidden_states) + + # Reshape channels for multi-head attention. + # We reshape from (batch_size, time, channels) to (batch_size, num_heads, time, channels per head) + queries = self.transpose_for_scores(queries, self.qk_channels_per_head) + keys = self.transpose_for_scores(keys, self.qk_channels_per_head) + values = self.transpose_for_scores(values, self.v_channels_per_head) + + # (Modified) RoPE + if self.rotary_emb is not None: + queries = self.rotary_emb.apply_rotary_custom(queries) + if self.rope_apply_to_keys is True: + keys = self.rotary_emb.apply_rotary_custom(keys) + + # Take the dot product between the queries and keys to get the raw attention scores. + attention_scores = torch.matmul(queries, keys.transpose(-1, -2)) + + # (Modified) Alibi positional bias + if self.alibi_bias is not None: + batch_size, num_heads, q_seq_len, k_seq_len = attention_scores.shape + attention_scores += self.alibi_bias(q_seq_len, + k_seq_len) # auto-broadcasting to (b, num_heads, q_seq_len, k_seq_len) + + _, _, _, q_head_dim = queries.shape + _, _, _, v_head_dim = values.shape + hiddens = self.num_heads * v_head_dim + + attention_scores = attention_scores / math.sqrt(q_head_dim) + + if attention_mask is not None: + # Apply the attention mask (precomputed for all layers in PerceiverModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. 
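For orientation, the score path implemented above reduces to the toy version below. Note that the forward adds the ALiBi bias before dividing by `sqrt(q_head_dim)`, so the effective slope magnitudes are scaled down by the same factor. The slopes here are placeholders, not the geometric schedule that `AlibiPositionalBias` actually computes:

```python
import math
import torch
import torch.nn.functional as F

b, h, q_len, k_len, d_head = 2, 4, 16, 16, 32
queries = torch.randn(b, h, q_len, d_head)
keys = torch.randn(b, h, k_len, d_head)

scores = queries @ keys.transpose(-1, -2)                              # (b, h, q_len, k_len)
slopes = torch.tensor([2.0 ** -(i + 1) for i in range(h)]).view(1, h, 1, 1)   # toy slopes only
distance = -(torch.arange(k_len) - torch.arange(q_len).unsqueeze(-1)).abs()   # symmetric (bi-directional) ALiBi
scores = scores + slopes * distance                                    # bias is added first ...
scores = scores / math.sqrt(d_head)                                    # ... then the sqrt(d) scaling, as above
probs = torch.softmax(scores, dim=-1)
probs = F.dropout(probs, p=0.1)                                        # dropout on attention probabilities
```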
+ attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, values) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (hiddens,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) + + return outputs + + +class PerceiverAlibiAttention(nn.Module): + """ + Attention module, including a dense block + Alibi + : modified from PerceiverAttention in modeling_perceiver.py to support Alibi positional bias + """ + + def __init__( + self, + config, + is_cross_attention=False, + qk_channels=None, + v_channels=None, + num_heads=1, + q_dim=None, + kv_dim=None, + use_query_residual=True, + rotary_emb=None, + ): + super().__init__() + # MultiHead attention + if is_cross_attention and qk_channels is None: + if config.cross_attention_shape_for_attention == "q": + qk_channels = q_dim + elif config.cross_attention_shape_for_attention == "kv": + qk_channels = kv_dim + else: + raise ValueError(f"Unknown value {config.cross_attention_shape_for_attention} for " + "cross_attention_shape_for_attention.") + else: + if qk_channels is None: + qk_channels = q_dim + if v_channels is None: + v_channels = qk_channels + self.self = PerceiverAlibiSelfAttention(config, + is_cross_attention=is_cross_attention, + qk_channels=qk_channels, + v_channels=v_channels, + num_heads=num_heads, + q_dim=q_dim, + kv_dim=kv_dim, + rotary_emb=rotary_emb) + # dense block + output_channels = None + if is_cross_attention: + output_channels = q_dim + else: + if output_channels is None: + output_channels = v_channels + self.output = PerceiverSelfOutput(config, input_channels=self.self.v_channels, output_channels=output_channels) + self.use_query_residual = use_query_residual + self.pruned_heads = set() + + def prune_heads(self, heads): + if len(heads) == 0: + return + heads, index = find_pruneable_heads_and_indices(heads, self.self.num_attention_heads, + self.self.attention_head_size, self.pruned_heads) + + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + + # Update hyper params and store pruned heads + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + self.pruned_heads = self.pruned_heads.union(heads) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + self_outputs = self.self( + hidden_states, + attention_mask, + head_mask, + inputs, + inputs_mask, + output_attentions, + ) + + # Output projection + attention_output = self.output(self_outputs[0]) + + # Optionally include a residual to the original queries. + # Consider omitting the residual if the semantics of query and output + # are different, e.g. if queries are positions and outputs are pixels. 
+ if self.use_query_residual: + attention_output = attention_output + hidden_states + + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class PerceiverAlibiLayer(nn.Module): + """Construct a single PerceiverTF layer with: + - Alibi positional bias + - RoPE + - Mixtral of Experts (MoE) feedforward layer + + """ + + def __init__( + self, + config, + is_cross_attention=False, + qk_channels=None, + v_channels=None, + num_heads=1, + q_dim=None, + kv_dim=None, + widening_factor=1, + use_query_residual=True, + rotary_emb=None, + ): + super().__init__() + self.chunk_size_feed_forward = config.chunk_size_feed_forward + self.seq_len_dim = 1 + self.attention = PerceiverAlibiAttention(config, + is_cross_attention=is_cross_attention, + qk_channels=qk_channels, + v_channels=v_channels, + num_heads=num_heads, + q_dim=q_dim, + kv_dim=kv_dim, + use_query_residual=use_query_residual, + rotary_emb=rotary_emb) + self.layernorm = get_layer_norm(q_dim, config.layer_norm_type, config.layer_norm_eps) + # self.layernorm = nn.LayerNorm(q_dim) + self.mlp = get_ff_layer(config, input_size=q_dim, widening_factor=widening_factor) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor]: + attention_outputs = self.attention( + hidden_states, + attention_mask, + head_mask, + inputs, + inputs_mask, + output_attentions, + ) + attention_output = attention_outputs[0] + + outputs = attention_outputs[1:] # add attentions if we output attention weights + """apply_chunking_to_forward: + This function chunks the input_tensors into smaller input tensor parts of size + chunk_size over the dimension chunk_dim. It then applies a layer forward_fn to + each chunk independently to save memory.If the forward_fn is independent across + the chunk_dim this function will yield the same result as not applying it. + """ + layer_output, router_logits = apply_chunking_to_forward(self.feed_forward_chunk, self.chunk_size_feed_forward, + self.seq_len_dim, attention_output) + + layer_output = layer_output + attention_output # residual connection + outputs = (layer_output,) + outputs + (router_logits,) # add router_logits to outputs + return outputs + + def feed_forward_chunk(self, attention_output): + layer_output = self.layernorm(attention_output) + layer_output, router_logits = self.mlp(layer_output) # router_logits is returned only when using MoE. + return layer_output, router_logits + + +class PerceiverTFEncoderBlock(nn.Module): + """Construct a single block of PerceiverTF encoder: + - Spectral Cross Attention (SCA) + - Local latent transformer layers + - Temporal transformer layers + - added Alibi positional bias, RoPE, gMLP and MoE feedforward layer + """ + + def __init__(self, + config: PerceiverTFConfig, + kv_dim: Optional[int] = None, + sca_use_query_residual: bool = True, + rotary_emb_sca: Optional[nn.Module] = None, + rotary_emb_latent: Optional[nn.Module] = None, + rotary_emb_temporal: Optional[nn.Module] = None): + super().__init__() + self.config = config + + # Check that we can use multihead-attention with these shapes. 
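The shape check referred to in the comment above is plain integer divisibility of the latent width by the head count; with values in the spirit of `test()` further below (`d_latents=128`), the per-head split looks like this:

```python
d_latents, num_heads = 128, 8                 # illustrative values, matching the scale used in test() below
assert d_latents % num_heads == 0             # the checks that follow raise ValueError when this fails
channels_per_head = d_latents // num_heads    # 16 query/key channels per head
```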
+ if config.d_latents % config.num_self_attention_heads != 0: + raise ValueError(f"num_z_channels ({config.d_latents}) must be divisible by" + f" num_self_attend_heads ({config.num_self_attention_heads}).") + if config.d_latents % config.num_cross_attention_heads != 0: + raise ValueError(f"num_z_channels ({config.d_latents}) must be divisible by" + f" num_cross_attend_heads ({config.num_cross_attention_heads}).") + + if kv_dim is None: + kv_dim = config.kv_dim + if sca_use_query_residual is None: + sca_use_query_residual = config.sca_use_query_residual + + # Spectral Cross Attention (SCA) layer. + self.sca_attention_to_channel = config.attention_to_channel + self.spectral_cross_attention = PerceiverAlibiAttention(config, + is_cross_attention=True, + qk_channels=config.qk_channels, + v_channels=config.v_channels, + num_heads=config.num_cross_attention_heads, + q_dim=config.d_latents, + kv_dim=kv_dim, + use_query_residual=sca_use_query_residual, + rotary_emb=rotary_emb_sca) # (Modified) RoPE + + # Local latent trasformer layers. + local_transformer_layers = [] + for _ in range(config.num_local_transformers_per_block): + layer = PerceiverAlibiLayer( + config, + is_cross_attention=False, + qk_channels=config.qk_channels, # projection dim for q and k. + v_channels=config.v_channels, # projection dim for v. + num_heads=config.num_self_attention_heads, + q_dim=config.d_model, + kv_dim=config.d_model, + widening_factor=config.ff_widening_factor, + use_query_residual=config.use_query_residual, + rotary_emb=rotary_emb_latent # (Modified) RoPE + ) + local_transformer_layers.append(layer) + self.local_transformer = nn.ModuleList(local_transformer_layers) + + # Temporal transformer layers. + temporal_transformer_layers = [] + for _ in range(config.num_temporal_transformers_per_block): + layer = PerceiverAlibiLayer( + config, + is_cross_attention=False, + qk_channels=config.qk_channels, # projection dim for q and k. + v_channels=config.v_channels, # projection dim for v. + num_heads=config.num_self_attention_heads, + q_dim=config.d_model, + kv_dim=config.d_model, + widening_factor=config.ff_widening_factor, + use_query_residual=config.use_query_residual, + rotary_emb=rotary_emb_temporal # (Modified) RoPE + ) + temporal_transformer_layers.append(layer) + self.temporal_transformer = nn.ModuleList(temporal_transformer_layers) + + def forward( + self, + hidden_states: torch.Tensor, + inputs: Optional[torch.FloatTensor] = None, + inputs_mask: Optional[torch.FloatTensor] = None, + local_attention_mask: Optional[torch.FloatTensor] = None, + temporal_attention_mask: Optional[torch.FloatTensor] = None, + local_head_mask: Optional[torch.FloatTensor] = None, + temporal_head_mask: Optional[torch.FloatTensor] = None, + pos_emb_temporal: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = False, + output_hidden_states: Optional[bool] = False, + output_router_logits: Optional[bool] = False, # Only used for MoE. + return_dict: Optional[bool] = True, + ) -> Union[Tuple, MoEModelOutputWithCrossAttentions]: + """ + Inputs: + hidden_states: (B, T, K, D) + inputs: (B, T, F, C) + Returns: + hidden_states: (B, T, K, D) + + Args: + hidden_states: + latent_array (B, T, num_latents, d_latents) for SCA. The latent array + with shape (B, K, D) is expanded by t, and positional embeddings are + added to it. + inputs: torch.FloatTensor + The input sequence of shape (B, T, F, C). + inputs_mask: torch.FloatTensor + Only used for SCA. By default, None. + local_attention_mask: + Used for local self-attention. 
By default, None. + temporal_attention_mask: + Used for temporal self-attention. By default, None. + local_head_mask: + By default, None. + temporal_head_mask: + By default, None. + pos_emb_temporal: + Optioanl. Used for temporal self-attention. By default, None. (max_t, num_latents, d_latents) + output_attentions: bool + Whether to return attentions weights. + output_hidden_states: bool + Whether to return all hidden states. If False, only last hidden + state is returned. + output_router_logits: bool + Whether to return router logits for MoE. If False, only last hidden + state is returned. + return_dict: bool + Whether to return a MoEModelOutputWithCrossAttentions instead of a tuple. + """ + + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + all_cross_attentions = () if output_attentions else None + all_router_logits = () if output_router_logits else None + + # Collect dimension info + batch_size, t, num_latents, d_latents = hidden_states.size() # (B, T, K, D) + + # if self.sca_attention_to_channel: + # _, _, _, f = inputs.size() # (B, T, C, F) + # assert d_latents == f, "d_latents must be equal to kv_dim, which is input frequency dim." + # else: + # _, _, _, c = inputs.size() # (B, T, F, C) + # assert d_latents == c, "d_latents must be equal to kv_dim, which is input channels." + + # Reshape (B, T, _, _) to (B*T, _, _) for SCA and local transformer. + hidden_states = rearrange(hidden_states, "b t k d -> (b t) k d") + inputs = rearrange(inputs, "b t f c -> (b t) f c") + + # Apply the SCA between the latents (hidden_states) and inputs: + layer_outputs = self.spectral_cross_attention( + hidden_states, + attention_mask=None, # Input_mask is used instead for cross-attention + inputs=inputs, + inputs_mask=inputs_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] # (B*T, K, D) + + if output_attentions: + all_cross_attentions = all_cross_attentions + (layer_outputs[1],) + + # Apply the block of local latent transformer layers. + for i, layer_module in enumerate(self.local_transformer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = local_head_mask[i] if local_head_mask is not None else None + layer_outputs = layer_module( + hidden_states, + attention_mask=local_attention_mask, + head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = layer_outputs[0] # (B*T, K, D) + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + if output_router_logits: + all_router_logits = all_router_logits + (layer_outputs[2],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # Reshape (B*T, K, D) to (B*K, T, D) for the temporal transformer. + hidden_states = rearrange(hidden_states, "(b t) k d -> (b k) t d", b=batch_size) + + # Apply the block of temporal transformer layers. + for i, layer_module in enumerate(self.temporal_transformer): + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + layer_head_mask = temporal_head_mask[i] if temporal_head_mask is not None else None + + if i == 0 and pos_emb_temporal is not None: + # Add temporal positional embeddings to the hidden_states. 
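As the comment above says, the temporal positional embedding is added right after the latents are regrouped along time. A minimal sketch of the two shapes involved (toy sizes; `max_t=340` is arbitrary, and `pos_emb_temporal` exists only for the 'kdt' PE type):

```python
import torch
from einops import rearrange

b, t, k, d = 2, 10, 24, 128
x = torch.randn(b * t, k, d)                       # latents after SCA / local layers: (B*T, K, D)
x = rearrange(x, "(b t) k d -> (b k) t d", b=b)    # temporal layers attend across T for each latent
pos_emb_temporal = torch.randn(340, d)             # (max_t, D)
x = x + pos_emb_temporal[:t]                       # (T, D) broadcasts over the leading B*K dim
assert x.shape == (b * k, t, d)
```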
+                hidden_states = hidden_states + pos_emb_temporal[:t]  # pos_emb_temporal: (T, D)
+
+            layer_outputs = layer_module(
+                hidden_states,
+                attention_mask=temporal_attention_mask,
+                head_mask=layer_head_mask,
+                output_attentions=output_attentions,
+            )
+            hidden_states = layer_outputs[0]
+            if output_attentions:
+                all_self_attentions = all_self_attentions + (layer_outputs[1],)
+            if output_router_logits:
+                all_router_logits = all_router_logits + (layer_outputs[2],)
+
+        if output_hidden_states:
+            all_hidden_states = all_hidden_states + (hidden_states,)
+
+        last_hidden_state = hidden_states
+        # Reshape (B*K, T, D) to (B, T, K, D) for the next block.
+        last_hidden_state = rearrange(last_hidden_state, "(b k) t d -> b t k d", b=batch_size)
+
+        # Prepare the outputs.
+        if not return_dict:
+            return tuple(
+                v for v in
+                [last_hidden_state, all_hidden_states, all_self_attentions, all_cross_attentions, all_router_logits]
+                if v is not None)
+        return MoEModelOutputWithCrossAttentions(
+            last_hidden_state=last_hidden_state,
+            hidden_states=all_hidden_states,
+            attentions=all_self_attentions,
+            cross_attentions=all_cross_attentions,
+            router_logits=all_router_logits,
+        )
+
+
+class PerceiverTFEncoder(PerceiverTFPreTrainedModel):
+    """PerceiverTFEncoder is an encoder model based on the Perceiver and Spectral Cross Attention (SCA).
+
+    position_encoding_type: str
+        The type of positional encoding to use. One of the following:
+        - 'trainable': trainable positional embeddings
+        - 'alibi': ALiBi positional bias
+        - 'alibit': ALiBi positional bias with trainable slopes for each head
+        - 'rope': RoPE (Rotary Positional Encoding)
+        (experimental w/ 'trainable')
+        - 'tkd': trainable PE (T,K,D) on latent (default for 'trainable')
+        - 'td': trainable PE (T,D) on latent
+        - 'tk': trainable PE (T,K) on latent
+        - 'kdt': trainable PE (K,D) on latent, and (T,) on temporal transformer
+
+    """
+
+    def __init__(self,
+                 config: PerceiverTFConfig,
+                 sca_use_query_residual: Optional[bool] = None,
+                 shared_emb: Optional[nn.Embedding] = None):
+        super().__init__(config)
+        self.config = config
+
+        if sca_use_query_residual is None:
+            self.sca_use_query_residual = config.sca_use_query_residual  # True by default
+        self.position_encoding_type = config.position_encoding_type
+        self.sca_attention_to_channel = config.attention_to_channel
+
+        # Construct a latent array.
+        self.latent_array = PerceiverEmbeddings(config)  # (num_latents, d_latents)
+
+        # Positional embeddings for the latent array.
+        if self.position_encoding_type == 'rope':
+            # (Modified) RoPE
+            self.rotary_emb_sca = get_rotary_emb(config.num_cross_attention_heads, config.rope_type_sca,
+                                                 config.rope_partial_pe, config.rope_trainable)
+            self.rotary_emb_latent = get_rotary_emb(config.num_cross_attention_heads, config.rope_type_latent,
+                                                    config.rope_partial_pe, config.rope_trainable)
+            self.rotary_emb_temporal = get_rotary_emb(config.num_cross_attention_heads, config.rope_type_temporal,
+                                                      config.rope_partial_pe, config.rope_trainable)
+        else:
+            self.rotary_emb_sca = None
+            self.rotary_emb_latent = None
+            self.rotary_emb_temporal = None
+
+        if self.position_encoding_type in ['alibi', 'alibit', 'rope', None]:
+            # ALiBi is implemented within PerceiverAlibiSelfAttention, and activated by config.
+            # RoPE is implemented without using self.pos_emb.
+            self.pos_emb = None
+        else:
+            k, d = self.latent_array.latents.size()
+            max_t = int(config.num_max_positions) + 10  # 10 is headroom for future task tokens...
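The encoder blocks above return their MoE routing scores through `router_logits` in `MoEModelOutputWithCrossAttentions`. How those logits are turned into a loss is not part of this diff (that would live in `ff_layer.py` or the training step), but a common formulation is a Switch-Transformer-style load-balancing term; the sketch below is an assumption, not the repository's actual loss:

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
    # router_logits: (num_tokens, num_experts); the shape is illustrative only.
    probs = F.softmax(router_logits, dim=-1)
    expert_load = F.one_hot(probs.argmax(dim=-1), num_experts).float().mean(dim=0)  # fraction of tokens per expert
    router_prob = probs.mean(dim=0)                                                 # mean routing prob per expert
    return num_experts * torch.sum(expert_load * router_prob)

aux_loss = load_balancing_loss(torch.randn(1024, 4), num_experts=4)
```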
+ self.pos_emb = PerceiverTFTrainablePE(self.position_encoding_type, max_t, k, d) + """ + self.pos_emb() returns: + pos_emb: (max_t, K, D) + pos_emb_temporal: (max_t, K, D) + """ + + # Construct the encoder blocks. + blocks = [] + for _ in range(config.num_blocks): + block = PerceiverTFEncoderBlock( + config, + kv_dim=config.kv_dim, + sca_use_query_residual=sca_use_query_residual, + rotary_emb_sca=self.rotary_emb_sca, # (Modified) RoPE + rotary_emb_latent=self.rotary_emb_latent, + rotary_emb_temporal=self.rotary_emb_temporal) + blocks.append(block) + self.blocks = nn.ModuleList(blocks) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.latent_array.latents + + def set_input_embeddings(self, value): + self.latent_array.latents = value + + """temporary fix for torch.compile issue""" + + def forward(self, **kwargs): + if self.training is True: + return self._forward_compile(**kwargs) + else: + return self._forward_no_compile(**kwargs) + + def _forward_no_compile(self, **kwargs): + return self._forward(**kwargs) + + @torch.compile + def _forward_compile(self, **kwargs): + return self._forward(**kwargs) + + def _forward( + self, + inputs: Optional[torch.FloatTensor] = None, # (B, T, F, kv_dim) + inputs_embeds: Optional[torch.FloatTensor] = None, # (B, T, F, kv_dim) + inputs_mask: Optional[torch.FloatTensor] = None, # (B, F) Mask freq. of inputs in SCA. + local_attention_mask: Optional[torch.FloatTensor] = None, # (B, K) + temporal_attention_mask: Optional[torch.FloatTensor] = None, # (B, T) + local_head_mask: Optional[torch.FloatTensor] = None, + temporal_head_mask: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, MoEModelOutputWithCrossAttentions]: + # Inputs and inputs_embeds are tied, and actually the same. (following T5 convention) + # Inputs are from convoulutional features from audio. + # Don't be confused with latent embeddings, which is `self.latent_array.latents`, and + # used as hidden_state of block. + if inputs is None and inputs_embeds is not None: + inputs = inputs_embeds + elif inputs is None and inputs_embeds is None: + raise ValueError("You must provide 'inputs' or 'inputs_embeds' argument.") + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size, t, _f, _c = inputs.size() + device = inputs.device + + # SCA attention to channels of inputs, instead of frequency bins. 
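The comment above refers to `config.attention_to_channel`. When it is set, the conv features are transposed so that the SCA keys and values index channels rather than frequency bins, which means `config.kv_dim` must equal the frequency dimension F. A minimal shape illustration with toy sizes:

```python
import torch
from einops import rearrange

b, t, f, c = 2, 10, 128, 128
inputs = torch.randn(b, t, f, c)                   # conv features: (B, T, F, C)
inputs = rearrange(inputs, "b t f c -> b t c f")   # attention_to_channel=True: last dim is now F,
                                                   # so config.kv_dim has to match F
```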
+ if self.sca_attention_to_channel is True: + inputs = rearrange(inputs, "b t f c -> b t c f") + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x n_heads x N x N + # input head_mask has shape [num_heads] or [num_blocks x num_heads] + # and head_mask is converted to shape [num_blocks x batch x num_heads x N x N] + local_head_mask = self.get_head_mask(local_head_mask, + self.config.num_blocks * self.config.num_local_transformers_per_block) + temporal_head_mask = self.get_head_mask( + temporal_head_mask, self.config.num_blocks * self.config.num_temporal_transformers_per_block) + + # Prepare attention mask: not implemented + + # Expand the latent embeddings by t: (B, K, D) --> (B, T, K, D) + latent_embeddings = self.latent_array(batch_size=batch_size) # (B, num_latents, d_latents) + expanded_latent_embeddings = latent_embeddings.unsqueeze(1).expand(-1, t, -1, -1) + + # Add positional embeddings to the expanded latent embeddings: (B, T, K, D) + if self.pos_emb is not None: + pos_emb_latent, pos_emb_temporal = self.pos_emb.forward() + expanded_latent_embeddings = expanded_latent_embeddings + pos_emb_latent[:t] + # (max_t, K, D) -> (T, K, D) -> (B, T, K, D) auto-broadcasting + else: + pos_emb_temporal = None + + # Lists to store intermediate outputs if required + all_hidden_states = [] + all_attentions = [] + all_cross_attentions = [] + all_router_logits = [] + + hidden_states = expanded_latent_embeddings + + # Forward-pass + for i, block in enumerate(self.blocks): + block_output = block(hidden_states=hidden_states, + inputs=inputs, + inputs_mask=inputs_mask, + local_attention_mask=local_attention_mask, + temporal_attention_mask=temporal_attention_mask, + local_head_mask=local_head_mask, + temporal_head_mask=temporal_head_mask, + pos_emb_temporal=pos_emb_temporal if i == 0 else None, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=True) + + # Update the hidden_states for the next block + hidden_states = block_output.last_hidden_state + + # Append to lists if required + if output_hidden_states: + all_hidden_states.append(hidden_states) + if output_attentions: + all_attentions.append(block_output.attentions) + all_cross_attentions.append(block_output.cross_attentions) + if output_router_logits: + all_router_logits.append(block_output.router_logits) + last_hidden_states = hidden_states + + # Prepare outputs + if not return_dict: + # Convert lists to tuples + return (last_hidden_states, tuple(all_hidden_states) if all_hidden_states else None, + tuple(all_attentions) if all_attentions else None, + tuple(all_cross_attentions) if all_cross_attentions else None, + tuple(all_router_logits) if all_router_logits else None) + + return MoEModelOutputWithCrossAttentions( + last_hidden_state=last_hidden_states, + hidden_states=tuple(all_hidden_states) if all_hidden_states else None, + attentions=tuple(all_attentions) if all_attentions else None, + cross_attentions=tuple(all_cross_attentions) if all_cross_attentions else None, + router_logits=tuple(all_router_logits) if all_router_logits else None) + + +def test(): + # In HuggingFace's Perceiver implementation: + # `q_dim` is the latent array dimension d_latents of ((B), num_latents, d_latents). 
+ # `kv_dim`os the actual input dimension D of (B, T, D) + # `qk_channels`, `v_channels`: are projection dimensions for attention, (B, T, C) + # (B, T, D) --> projection --> (B, T, C) + # However, PerceiverTF does not require projection: + # It takes as input a latent tensor (B, num_latents, d_latents) and a conv_feat tensor (T, B, F, C) + # The `spectral-cross-attention` and `local-self-attention-transformer` takes as input (B*T, F, C), + # and C=D=d_latents. + from model.ops import count_parameters + + # Test input + b = 2 # batch + t = 10 # time steps (330 for 6s in paper) + f = 128 # freq of conv_feat + c = 128 # channels of conv_feat + k = 24 # num_latents + d = 128 # d_latents + conv_feat = torch.randn(b, t, f, c) + + # construct PerceiverTFEncoder + config = PerceiverTFConfig() + pe_types = ['alibi', 'alibit', 'trainable', 'tkd', 'td', 'tk', 'kdt', None] + config.ff_layer_type = 'moe' + config.moe_num_experts = 4 + config.moe_topk = 2 + + for pe_type in pe_types: + config.position_encoding_type = pe_type # 'alibi', 'alibit', 'trainable', 'tkd', 'td', 'tk', 'kdt', None + config.num_latents = k + config.d_latents = d + config.kv_dim = c + config.qk_channels = d + config.v_channels = d + encoder = PerceiverTFEncoder(config) + encoder.eval() + assert encoder.latent_array.latents.size() == (k, d) + # forward + enc_hidden_state = encoder.forward(inputs_embeds=conv_feat).last_hidden_state + # print(enc_hidden_state.shape) # [2, 10, 24, 128] = [B, T, K, D] + n_param = count_parameters(encoder)[1] // 1000 + print(config.position_encoding_type, f'num_param: {n_param}K') + """ + PE type | num. param. + None | 1397K + alibi | 1397K + alibit (train slope) | 1397K + tkd | 2442K + td | 1441K + tk | 1405K + kdt | 1444K + + MLP | 2637K + MoE (4 experts) | 4411K + MoE (6 experts) | 5594K + """ diff --git a/amt/src/model/pitchshift_layer.py b/amt/src/model/pitchshift_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..27d85d26e7be165cfecab1c2cbacbecc99b94189 --- /dev/null +++ b/amt/src/model/pitchshift_layer.py @@ -0,0 +1,550 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""pitchshift.py""" +# import math +import numpy as np +# from scipy import special +from einops import rearrange +from typing import Optional, Literal, Dict, List, Tuple, Callable + +import torch +from torch import nn +import torchaudio +from torchaudio import transforms +# from torchaudio import functional as F +# from torchaudio.functional.functional import ( +# _fix_waveform_shape, +# _stretch_waveform, +# ) +# from model.ops import adjust_b_to_gcd, check_all_elements_equal + + +class PitchShiftLayer(nn.Module): + """Applying batch-wise pitch-shift to time-domain audio signals. + + Args: + pshift_range (List[int]): Range of pitch shift in semitones. Default: ``[-2, 2]``. + resample_source_fs (int): Default is 4000. + stretch_n_fft (int): Default is 2048. + window: (Optional[Literal['kaiser']]) Default is None. + beta: (Optional[float]): Parameter for 'kaiser' filter. Default: None. 
+ """ + + def __init__( + self, + pshift_range: List[int] = [-2, 2], + resample_source_fs: int = 4000, + strecth_n_fft: int = 512, + win_length: Optional[int] = None, + hop_length: Optional[int] = None, + window: Optional[Literal['kaiser']] = None, + beta: Optional[float] = None, + expected_input_shape: Optional[Tuple[int]] = None, + device: Optional[torch.device] = None, + **kwargs, + ) -> None: + super().__init__() + self.pshift_range = pshift_range + self.resample_source_fs = resample_source_fs + self.strecth_n_fft = strecth_n_fft + self.win_length = win_length + self.hop_length = hop_length + + if window is None: + self.window_fn = torch.hann_window + self.window_kwargs = None + elif 'kaiser' in window: + + def custom_kaiser_window(window_length, beta, **kwargs): + return torch.kaiser_window(window_length, periodic=True, beta=beta, **kwargs) + + self.window_fn = custom_kaiser_window + self.window_kwargs = {'beta': beta} + + # Initialize pitch shifters for every semitone + self.pshifters = None + self.frame_gaps = None + self._initialize_pshifters(expected_input_shape, device=device) + self.requires_grad_(False) + + def _initialize_pshifters(self, + expected_input_shape: Optional[Tuple[int]] = None, + device: Optional[torch.device] = None) -> None: + # DDP requires initializing parameters with a dummy input + if expected_input_shape is not None: + if device is not None: + dummy_input = torch.randn(expected_input_shape, requires_grad=False).to(device) + else: + dummy_input = torch.randn(expected_input_shape, requires_grad=False) + else: + dummy_input = None + + pshifters = nn.ModuleDict() + for semitone in range(self.pshift_range[0], self.pshift_range[1] + 1): + if semitone == 0: + # No need to shift and resample + pshifters[str(semitone)] = None + else: + pshifter = transforms.PitchShift(self.resample_source_fs, + n_steps=semitone, + n_fft=self.strecth_n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + window_fn=self.window_fn, + wkwargs=self.window_kwargs) + pshifters[str(semitone)] = pshifter + # Pass dummy input to initialize parameters + with torch.no_grad(): + if dummy_input is not None: + _ = pshifter.initialize_parameters(dummy_input) + self.pshifters = pshifters + + def calculate_frame_gaps(self) -> Dict[int, float]: + """Calculate the expected gap between the original and the stretched audio.""" + frame_gaps = {} # for debugging + for semitone in range(self.pshift_range[0], self.pshift_range[1] + 1): + if semitone == 0: + # No need to shift and resample + frame_gaps[semitone] = 0. + else: + pshifter = self.pshifters[str(semitone)] + gap_in_ms = 1000. 
* (pshifter.kernel.shape[2] - + pshifter.kernel.shape[0] / 2.0**(-float(semitone) / 12)) / self.resample_source_fs + frame_gaps[semitone] = gap_in_ms + return frame_gaps + + @torch.no_grad() + def forward(self, x: torch.Tensor, semitone: int) -> torch.Tensor: + """ + Args: + x (torch.Tensor): (B, 1, T) or (B, T) + Returns: + torch.Tensor: (B, 1, T) or (B, T) + """ + if semitone == 0: + return x + elif semitone >= min(self.pshift_range) and semitone <= max(self.pshift_range): + return self.pshifters[str(semitone)](x) + else: + raise ValueError(f"semitone must be in range {self.pshift_range}") + + +def test_resampler_sinewave(): + # x: {440Hz, 220Hz} sine wave at 16kHz + t = torch.arange(0, 2, 1 / 16000) # 2 seconds at 16kHz + x0 = torch.sin(2 * torch.pi * 440 * t) * 0.5 + x1 = torch.sin(2 * torch.pi * 220 * t) * 0.5 + x = torch.stack((x0, x1), dim=0) # (2, 32000) + + # Resample + psl = PitchShiftLayer(pshift_range=[-2, 2], resample_source_fs=4000) + y = psl(x, 2) # (2, 24000) + + # Export to wav + torchaudio.save("x.wav", x, 16000, bits_per_sample=16) + torchaudio.save("y.wav", y, 12000, bits_per_sample=16) + + +# class Resampler(nn.Module): +# """ +# Resampling using conv1d operations, more memory-efficient than torchaudio's resampler. + +# Based on Dan Povey's resampler.py: +# https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py +# """ + +# def __init__(self, +# input_sr: int, +# output_sr: int, +# dtype: torch.dtype = torch.float32, +# filter_width: int = 16, +# cutoff_ratio: float = 0.85, +# filter: Literal['kaiser', 'kaiser_best', 'kaiser_fast', 'hann'] = 'kaiser_fast', +# beta: float = 8.555504641634386) -> None: +# super().__init__() # init the base class +# """ +# Initialize the Resampler. + +# Args: +# - input_sr (int): Input sampling rate. +# - output_sr (int): Output sampling rate. +# - dtype (torch.dtype): Computation data type. Default: torch.float32. +# - filter_width (int): Number of zeros per side in the sinc function. Default: 16. +# - cutoff_ratio (float): Filter rolloff point as a fraction of Nyquist freq. Default: 0.95. +# - filter (str): Filter type. One of ['kaiser', 'kaiser_best', 'kaiser_fast', 'hann']. Default: 'kaiser_fast'. +# - beta (float): Parameter for 'kaiser' filter. Default: 8.555504641634386. + +# Note: Ratio between input_sr and output_sr should be reduced to simplest form. +# """ +# assert isinstance(input_sr, int) and isinstance(output_sr, int) +# if input_sr == output_sr: +# self.resample_type = 'trivial' +# return + +# d = math.gcd(input_sr, output_sr) +# input_sr, output_sr = input_sr // d, output_sr // d + +# assert dtype in [torch.float32, torch.float64] +# assert filter_width > 3 # a reasonable bare minimum +# np_dtype = np.float32 if dtype == torch.float32 else np.float64 + +# assert filter in ['hann', 'kaiser', 'kaiser_best', 'kaiser_fast'] + +# if filter == 'kaiser_best': +# filter_width = 64 +# beta = 14.769656459379492 +# cutoff_ratio = 0.9475937167399596 +# filter = 'kaiser' +# elif filter == 'kaiser_fast': +# filter_width = 16 +# beta = 8.555504641634386 +# cutoff_ratio = 0.85 +# filter = 'kaiser' +# """ +# - Define a sample 'block' correlating `input_sr` input samples to `output_sr` output samples. +# - Dividing samples into these blocks allows corresponding block alignment. +# - On average, `zeros_per_block` zeros per block are present in the sinc function. 
+# """ +# zeros_per_block = min(input_sr, output_sr) * cutoff_ratio +# """ +# - Define conv kernel size n = (blocks_per_side*2 + 1), adding blocks to each side of the center. +# - `blocks_per_side` blocks as window radius ensures each central block sample accesses its window. +# - `blocks_per_side` is determined, rounding up if needed, as 1 + int(filter_width / zeros_per_block). +# """ +# blocks_per_side = int(np.ceil(filter_width / zeros_per_block)) + +# kernel_width = 2 * blocks_per_side + 1 + +# # Shape of conv1d weights: (out_channels, in_channels, kernel_width) +# """ Time computations are in units of 1 block, aligning with the `canonical` time axis, +# since each block has input_sr input samples, adhering to our time unit.""" + +# window_radius_in_blocks = blocks_per_side +# """`times` will be sinc function arguments, expanding to shape (output_sr, input_sr, kernel_width) +# via broadcasting. Ensuring t == 0 along the central block diagonal (when input_sr == output_sr)""" +# times = ( +# np.arange(output_sr, dtype=np_dtype).reshape( +# (output_sr, 1, 1)) / output_sr - np.arange(input_sr, dtype=np_dtype).reshape( +# (1, input_sr, 1)) / input_sr - (np.arange(kernel_width, dtype=np_dtype).reshape( +# (1, 1, kernel_width)) - blocks_per_side)) + +# def hann_window(a): +# """ +# returning 0.5 + 0.5 cos(a*pi) on [-1,1] and 0 outside. +# """ +# return np.heaviside(1 - np.abs(a), 0.0) * (0.5 + 0.5 * np.cos(a * np.pi)) + +# def kaiser_window(a, beta): +# w = special.i0(beta * np.sqrt(np.clip(1 - ( +# (a - 0.0) / 1.0)**2.0, 0.0, 1.0))) / special.i0(beta) +# return np.heaviside(1 - np.abs(a), 0.0) * w + +# """The weights are computed as a sinc function times a Hann-window function, normalized by +# `zeros_per_block` (sinc) and `input_sr` (input function) to maintain integral and magnitude.""" +# if filter == 'hann': +# weights = ( +# np.sinc(times * zeros_per_block) * hann_window(times / window_radius_in_blocks) * +# zeros_per_block / input_sr) +# else: +# weights = ( +# np.sinc(times * zeros_per_block) * +# kaiser_window(times / window_radius_in_blocks, beta) * zeros_per_block / input_sr) + +# self.input_sr = input_sr +# self.output_sr = output_sr +# """If output_sr == 1, merge input_sr into kernel_width for weights (shape: output_sr, input_sr, +# kernel_width) to optimize convolution speed and avoid extra reshaping.""" + +# assert weights.shape == (output_sr, input_sr, kernel_width) +# if output_sr == 1: +# self.resample_type = 'integer_downsample' +# self.padding = input_sr * blocks_per_side +# weights = torch.tensor(weights, dtype=dtype, requires_grad=False) +# weights = weights.transpose(1, 2).contiguous().view(1, 1, input_sr * kernel_width) + +# elif input_sr == 1: +# # For conv_transpose, use weights as if input_sr and output_sr were swapped, simulating downsampling. +# self.resample_type = 'integer_upsample' +# self.padding = output_sr * blocks_per_side +# weights = torch.tensor(weights, dtype=dtype, requires_grad=False) +# weights = weights.flip(2).transpose(0, +# 2).contiguous().view(1, 1, output_sr * kernel_width) +# else: +# self.resample_type = 'general' +# self.reshaped = False +# self.padding = blocks_per_side +# weights = torch.tensor(weights, dtype=dtype, requires_grad=False) + +# self.weights = torch.nn.Parameter(weights, requires_grad=False) + +# @torch.no_grad() +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# """ +# Parameters: +# - x: torch.Tensor, with shape (minibatch_size, sequence_length), dtype should match the instance's dtype. 
+ +# Returns: +# - A torch.Tensor with shape (minibatch_size, (sequence_length//input_sr)*output_sr), dtype matching the input, +# and content resampled. +# """ +# if self.resample_type == 'trivial': +# return x +# elif self.resample_type == 'integer_downsample': +# (minibatch_size, seq_len) = x.shape # (B, in_C, L) with in_C == 1 +# x = x.unsqueeze(1) +# x = torch.nn.functional.conv1d( +# x, self.weights, stride=self.input_sr, padding=self.padding) # (B, out_C, L) +# return x.squeeze(1) # (B, L) + +# elif self.resample_type == 'integer_upsample': +# x = x.unsqueeze(1) +# x = torch.nn.functional.conv_transpose1d( +# x, self.weights, stride=self.output_sr, padding=self.padding) + +# return x.squeeze(1) +# else: +# assert self.resample_type == 'general' +# (minibatch_size, seq_len) = x.shape +# num_blocks = seq_len // self.input_sr +# if num_blocks == 0: +# # TODO: pad with zeros. +# raise RuntimeError("Signal is too short to resample") +# # Truncate input +# x = x[:, 0:(num_blocks * self.input_sr)].view(minibatch_size, num_blocks, self.input_sr) +# x = x.transpose(1, 2) # (B, in_C, L) +# x = torch.nn.functional.conv1d( +# x, self.weights, padding=self.padding) # (B, out_C, num_blocks) +# return x.transpose(1, 2).contiguous().view(minibatch_size, num_blocks * self.output_sr) + +# def test_resampler_sinewave(): +# import torchaudio +# # x: {440Hz, 220Hz} sine wave at 16kHz +# t = torch.arange(0, 2, 1 / 16000) # 2 seconds at 16kHz +# x0 = torch.sin(2 * torch.pi * 440 * t) * 0.5 +# x1 = torch.sin(2 * torch.pi * 220 * t) * 0.5 +# x = torch.stack((x0, x1), dim=0) # (2, 32000) + +# # Resample +# resampler = Resampler(input_sr=16000, output_sr=12000) +# y = resampler(x) # (2, 24000) + +# # Export to wav +# torchaudio.save("x.wav", x, 16000, bits_per_sample=16) +# torchaudio.save("y.wav", y, 12000, bits_per_sample=16) + +# def test_resampler_music(): +# import torchaudio +# # x: music at 16kHz +# x, _ = torchaudio.load("music.wav") +# slice_length = 32000 +# n_slices = 80 +# slices = [x[0, i * slice_length:(i + 1) * slice_length] for i in range(n_slices)] +# x = torch.stack(slices) # (80, 32000) + +# # Resample +# filter_width = 32 +# resampler = Resampler(16000, 12000, filter_width=filter_width) +# y = resampler(x) # (80, 24000) +# y = y.reshape(1, -1) # (1, 1920000) +# torchaudio.save(f"y_filter_width{filter_width}.wav", y, 12000, bits_per_sample=16) + +# class PitchShiftLayer(nn.Module): +# """Applying batch-wise pitch-shift to time-domain audio signals. + +# Args: +# expected_input_length (int): Expected input length. Default: ``32767``. +# pshift_range (List[int]): Range of pitch shift in semitones. Default: ``[-2, 2]``. +# min_gcd (int): Minimum GCD of input and output sampling rates for resampling. Setting high value can save GPU memory. Default: ``16``. +# max_timing_error (float): Maximum allowed timing error in seconds. Default: ``0.002``. +# fs (int): Sample rate of input waveform, x. Default: 16000. +# bins_per_octave (int, optional): The number of steps per octave (Default : ``12``). +# n_fft (int, optional): Size of FFT, creates ``n_fft // 2 + 1`` bins (Default: ``512``). +# win_length (int or None, optional): Window size. If None, then ``n_fft`` is used. (Default: ``None``). +# hop_length (int or None, optional): Length of hop between STFT windows. If None, then ``win_length // 4`` +# is used (Default: ``None``). +# window (Tensor or None, optional): Window tensor that is applied/multiplied to each frame/window. 
+# If None, then ``torch.hann_window(win_length)`` is used (Default: ``None``). + +# """ + +# def __init__( +# self, +# expected_input_length: int = 32767, +# pshift_range: List[int] = [-2, 2], +# min_gcd: int = 16, +# max_timing_error: float = 0.002, +# fs: int = 16000, +# bins_per_octave: int = 12, +# n_fft: int = 2048, +# win_length: Optional[int] = None, +# hop_length: Optional[int] = None, +# window: Optional[torch.Tensor] = None, +# filter_width: int = 16, +# filter: Literal['kaiser', 'kaiser_best', 'kaiser_fast', 'hann'] = 'kaiser_fast', +# cutoff_ratio: float = 0.85, +# beta: float = 8.555504641634386, +# **kwargs, +# ): +# super().__init__() +# self.expected_input_length = expected_input_length +# self.pshift_range = pshift_range +# self.min_gcd = min_gcd +# self.max_timing_error = max_timing_error +# self.fs = fs +# self.bins_per_octave = bins_per_octave +# self.n_fft = n_fft +# self.win_length = win_length +# self.hop_length = hop_length +# self.window = window +# self.resample_args = { +# "filter_width": filter_width, +# "filter": filter, +# "cutoff_ratio": cutoff_ratio, +# "beta": beta, +# } + +# # Initialize Resamplers +# self._initialize_resamplers() + +# def _initialize_resamplers(self): +# resamplers = nn.ModuleDict() +# self.frame_gaps = {} # for debugging +# for i in range(self.pshift_range[0], self.pshift_range[1] + 1): +# if i == 0: +# # No need to shift and resample +# resamplers[str(i)] = None +# else: +# # Find optimal reconversion frames meeting the min_gcd +# stretched_frames, recon_frames, gap = self._find_optimal_reconversion_frames(i) +# self.frame_gaps[i] = gap +# resamplers[str(i)] = Resampler(stretched_frames, recon_frames, **self.resample_args) +# self.resamplers = resamplers + +# def _find_optimal_reconversion_frames(self, semitone: int): +# """ +# Find the optimal reconversion frames for a given source sample rate, input length, and semitone for strech. + +# Parameters: +# - sr (int): Input audio sample rate, which should be power of 2 +# - n_step (int): The number of pitch-shift steps in semi-tone. +# - min_gcd (int): The minimum desired GCD, power of 2. Defaults to 16. 16 or 32 are good choices. +# - max_timing_error (float): The maximum allowed timing error, in seconds. Defaults to 5 ms + +# Returns: +# - int: The optimal target sample rate +# """ +# stretch_rate = 1 / 2.0**(-float(semitone) / self.bins_per_octave) +# stretched_frames = round(self.expected_input_length * stretch_rate) + +# gcd = math.gcd(self.expected_input_length, stretched_frames) +# if gcd >= self.min_gcd: +# return stretched_frames, self.expected_input_length, 0 +# else: +# reconversion_frames = adjust_b_to_gcd(stretched_frames, self.expected_input_length, +# self.min_gcd) +# gap = reconversion_frames - self.expected_input_length +# gap_sec = gap / self.fs +# if gap_sec > self.max_timing_error: +# # TODO: modifying vocoder of stretch_waveform to adjust pitch-shift rate in cents +# raise ValueError( +# gap_sec < self.max_timing_error, +# f"gap_sec={gap_sec} > max_timing_error={self.max_timing_error} with semitone={semitone}, stretched_frames={stretched_frames}, recon_frames={reconversion_frames}. Try adjusting input lenght or decreasing min_gcd." 
+# ) +# else: +# return stretched_frames, reconversion_frames, gap_sec + +# @torch.no_grad() +# def forward(self, +# x: torch.Tensor, +# semitone: int, +# resample: bool = True, +# fix_shape: bool = True) -> torch.Tensor: +# """ +# Args: +# x (torch.Tensor): (B, 1, T) +# Returns: +# torch.Tensor: (B, 1, T) +# """ +# if semitone == 0: +# return x +# elif semitone >= min(self.pshift_range) and semitone <= max(self.pshift_range): +# x = x.squeeze(1) # (B, T) +# original_x_size = x.size() +# x = _stretch_waveform( +# x, +# semitone, +# self.bins_per_octave, +# self.n_fft, +# self.win_length, +# self.hop_length, +# self.window, +# ) +# if resample: +# x = self.resamplers[str(semitone)].forward(x) +# # Fix waveform shape +# if fix_shape: +# if x.size(1) != original_x_size[1]: +# # print(f"Warning: {x.size(1)} != {original_x_length}") +# x = _fix_waveform_shape(x, original_x_size) +# return x.unsqueeze(1) # (B, 1, T) +# else: +# raise ValueError(f"semitone must be in range {self.pshift_range}") + +# def test_pitchshift_layer(): +# import torchaudio +# # music +# # x, _ = torchaudio.load("music.wav") +# # slice_length = 32767 +# # n_slices = 80 +# # slices = [x[0, i * slice_length:(i + 1) * slice_length] for i in range(n_slices)] +# # x = torch.stack(slices).unsqueeze(1) # (80, 1, 32767) + +# # sine wave +# t = torch.arange(0, 2.0479, 1 / 16000) # 2.05 seconds at 16kHz +# x = torch.sin(2 * torch.pi * 440 * t) * 0.5 +# x = x.reshape(1, 1, 32767).tile(80, 1, 1) + +# # Resample +# pos = 0 +# ps = PitchShiftLayer( +# pshift_range=[-3, 4], +# expected_input_length=32767, +# fs=16000, +# min_gcd=16, +# max_timing_error=0.002, +# # filter_width=64, +# filter='kaiser_fast', +# n_fft=2048) +# y = [] +# for i in range(-3, 4): +# y.append(ps(x[[pos], :, :], i, resample=False, fix_shape=False)[0, 0, :]) +# y = torch.cat(y).unsqueeze(0) # (1, 32767 * 7) +# torchaudio.save("y_2048_kaiser_fast.wav", y, 16000, bits_per_sample=16) + +# # TorchAudio PitchShifter fopr comparision +# y_ta = [] +# for i in range(-3, 4): +# ta_transform = torchaudio.transforms.PitchShift(16000, n_steps=i) +# y_ta.append(ta_transform(x[[pos], :, :])[0, 0, :]) +# y_ta = torch.cat(y_ta).unsqueeze(0) # (1, 32767 * 7) +# torchaudio.save("y_ta.wav", y_ta, 16000, bits_per_sample=16) + +# def test_min_gcd_mem_usage(): +# min_gcd = 16 +# for i in range(-3, 4): +# stretched_frames = _stretch_waveform(x, i).shape[1] +# adjusted = adjust_b_to_gcd(stretched_frames, 32767, min_gcd) +# gcd_val = math.gcd(adjusted, stretched_frames) +# gap = adjusted - 32767 +# gap_ms = (gap / 16000) * 1000 +# mem_mb = (stretched_frames / gcd_val) * (adjusted / gcd_val) * 3 * 4 / 1000 / 1000 +# print(f'\033[92mmin_gcd={min_gcd}\033[0m', f'ps={i}', f'frames={stretched_frames}', +# f'adjusted_frames={adjusted}', f'gap={gap}', f'\033[91mgap_ms={gap_ms}\033[0m', +# f'gcd={gcd_val}', f'mem_MB={mem_mb}') diff --git a/amt/src/model/positional_encoding.py b/amt/src/model/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6723f226882a666ce5c2385b28ec0801ccde8f --- /dev/null +++ b/amt/src/model/positional_encoding.py @@ -0,0 +1,288 @@ +"""positional_encoding.py """ +from typing import Optional, Literal +from inspect import isfunction +from math import log, log2, pi, floor + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn, einsum +from einops import rearrange, repeat + +from model.RoPE.RoPE import RotaryEmbedding + + +class AlibiPositionalBias(nn.Module): + """ + Alibi Positional Bias for 
Transformer Attention + : modified to support trainalbe slope similar to "little bird" paper, based on + https://github.com/lucidrains/x-transformers/ + https://github.com/ofirpress/attention_with_linear_biases/issues/5 + + This is Alibi positional bias extension for: + - bi-directional self/cross attention + - supporting extrapolation. + + References: + Ofir, Noah A. Smith, and Mike Lewis. "Train short, test long: Attention with linear + biases enables input length extrapolation." arXiv preprint arXiv:2108.12409 (2021). + + Lee, Minchul, Kijong Han, and Myeong Cheol Shin. "LittleBird: Efficient Faster & Longer + Transformer for Question Answering." arXiv preprint arXiv:2210.11870 (2022). + """ + + def __init__(self, + heads: int = 8, + total_heads: int = 8, + trainable_slope: bool = False, + trainable_slope_init: Literal['random', 'log'] = 'random', + **kwargs) -> None: + super().__init__() + self.heads = heads # number of heads to be activated + self.total_heads = total_heads # number of heads in attention module + self.trainable_slope = trainable_slope + self.trainable_slope_init = trainable_slope_init + + if trainable_slope: + self.slopes = nn.Parameter(torch.Tensor(heads, 1, 1), requires_grad=True) + else: + slopes = torch.Tensor(self._get_slopes(heads)) + slopes = rearrange(slopes, 'h -> h 1 1') + self.register_buffer('slopes', slopes, persistent=False) + + self.register_buffer('bias', None, persistent=False) + + def reset_parameters(self) -> None: + if self.trainable_slope: + if self.trainable_slope_init == 'random': + nn.init.normal_(self.slopes, -2, 1) + else: + raise NotImplementedError(f'Unknown trainable_slope_init: {self.trainable_slope_init}') + + def get_bias(self, i, j, device): + i_arange = torch.arange(j - i, j, device=device) + j_arange = torch.arange(j, device=device) + bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1')) + return bias + + @staticmethod + def _get_slopes(heads): + + def get_slopes_power_of_2(n): + start = (2**(-2**-(log2(n) - 3))) + ratio = start + return [start * ratio**i for i in range(n)] + + if log2(heads).is_integer(): + return get_slopes_power_of_2(heads) + + closest_power_of_2 = 2**floor(log2(heads)) + return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2( + 2 * closest_power_of_2)[0::2][:heads - closest_power_of_2] + + @staticmethod + def pad_at_dim(t, pad, dim=-1, value=0.): + dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = ((0, 0) * dims_from_right) + return F.pad(t, (*zeros, *pad), value=value) + + @property + def device(self): + if self.trainable_slope: + return self.slopes.device + else: + return next(self.buffers()).device + + def forward(self, i, j): + """ + Args: + i (int): end index of query + j (int): end index of key + + Returns: + torch.Tensor: (num_total_heads, i, j) positional bias for each head + + Usage: + >>> alibi_bias = AlibiPositionalBias(heads=8, total_heads=8, trainable_slope=False) + >>> pos_bias = alibi_bias(len(q), len(k)) + >>> q_dot_k = ... 
+ >>> q_dot_k += pos_bias + >>> q_dot_k = q_dot_k.softmax(dim=-1) + + """ + h, device = self.total_heads, self.device + if self.trainable_slope: + if self.bias is not None and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i: + bias = self.bias[..., :i, :j] + else: + bias = self.get_bias(i, j, device) + num_heads_unalibied = h - bias.shape[0] + bias = self.pad_at_dim(bias, (0, num_heads_unalibied), dim=0) + self.register_buffer('bias', bias, persistent=False) + + return self.bias * torch.sigmoid(self.slopes) + + else: + if self.bias is not None and self.bias.shape[-1] >= j and self.bias.shape[-2] >= i: + return self.bias[..., :i, :j] + + bias = self.get_bias(i, j, device) + bias = bias * self.slopes + + num_heads_unalibied = h - bias.shape[0] + bias = self.pad_at_dim(bias, (0, num_heads_unalibied), dim=0) + self.register_buffer('bias', bias, persistent=False) + + return self.bias + + +class FixedSinusoidalPositionalEmbedding(nn.Embedding): + """ + Sinusoidal Absolute Positional Embeddings (APE) of any length. + + Adapted from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding + + """ + + def __init__(self, num_positions: int, embedding_dim: int, padding_idx: Optional[int] = None): + super().__init__(num_positions, embedding_dim) + self.weight = self._init_weight(self.weight) + + @staticmethod + def _init_weight(out: nn.Parameter): + """ + Identical to the XLM create_sinusoidal_embeddings except features are not interleaved. The cos features are in + the 2nd half of the vector. [dim // 2:] + """ + n_pos, dim = out.shape + position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos) + ]) + out.requires_grad = False # set early to avoid an error in pytorch-1.8+ + sentinel = dim // 2 if dim % 2 == 0 else (dim // 2) + 1 + out[:, 0:sentinel] = torch.FloatTensor(np.sin(position_enc[:, 0::2])) + out[:, sentinel:] = torch.FloatTensor(np.cos(position_enc[:, 1::2])) + out.detach_() + return out + + @torch.no_grad() + def forward(self, seq_len: int, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + positions = torch.arange(past_key_values_length, + past_key_values_length + seq_len, + dtype=torch.long, + device=self.weight.device) + return super().forward(positions) + + +class Wav2Vec2ConformerRotaryPositionalEmbedding(nn.Module): + """Rotary positional embedding + Reference : https://blog.eleuther.ai/rotary-embeddings/ Paper: https://arxiv.org/pdf/2104.09864.pdf + """ + + def __init__(self, config): + super().__init__() + dim = config.d_model // config.num_heads + base = config.rotary_embedding_base + + inv_freq = 1.0 / (base**(torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self.cached_sequence_length = None + self.cached_rotary_positional_embedding = None + + def forward(self, hidden_states): + sequence_length = hidden_states.shape[1] + + if sequence_length == self.cached_sequence_length and self.cached_rotary_positional_embedding is not None: + return self.cached_rotary_positional_embedding + + self.cached_sequence_length = sequence_length + time_stamps = torch.arange(sequence_length).type_as(self.inv_freq) + freqs = torch.einsum("i,j->ij", time_stamps, self.inv_freq) + embeddings = torch.cat((freqs, freqs), dim=-1) + + cos_embeddings = embeddings.cos()[:, None, None, :] + sin_embeddings = embeddings.sin()[:, None, None, :] + self.cached_rotary_positional_embedding = torch.stack([cos_embeddings, sin_embeddings]) + return 
self.cached_rotary_positional_embedding + + +class Wav2Vec2ConformerRelPositionalEmbedding(nn.Module): + """Relative positional encoding module.""" + + def __init__(self, config): + super().__init__() + self.max_len = config.num_max_positions + self.d_model = config.d_model + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, self.max_len)) + + def extend_pe(self, x): + # Reset the positional encodings + if self.pe is not None: + # self.pe contains both positive and negative parts + # the length of self.pe is 2 * input_len - 1 + if self.pe.size(1) >= x.size(1) * 2 - 1: + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + # Suppose `i` is the position of query vector and `j` is the + # position of key vector. We use positive relative positions when keys + # are to the left (i>j) and negative relative positions otherwise (i>> m = GroupLinearFlatten(128, 512, 24) # + >>> input = torch.randn(16, 10, 24, 128) # (B, T, C, F) + >>> output = m(input) + >>> output.size() + torch.Size([16, 10, 512]) # (B, T, D) + """ + + def __init__(self, in_features, flatten_out_features, num_groups, use_bmm=True): + super().__init__() + self.in_features = in_features + self.flatten_out_features = flatten_out_features + self.num_groups = num_groups + self.use_bmm = use_bmm + + # Assuming flatten_out_features is divisible by num_groups + self.out_features_per_group = self.flatten_out_features // self.num_groups + + # Each group gets its own weights + self.weight = nn.Parameter(torch.Tensor(num_groups, self.out_features_per_group, in_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + + def forward(self, input): + # input shape: (batch, seq_length, groups, in_features) + # weight shape: (groups, out_features_per_group, in_features) + + batch_size, t, k, source_d = input.size() + + if self.use_bmm: + # Reshape input for bmm operation + input_reshaped = rearrange(input, 'b t k d -> k d (b t)') + + # Matrix multiplication: dot((k, out_features_per_group, d), (k, d, b*t)) -> (k, out_features_per_group, b*t) + output_bmm = torch.bmm(self.weight, input_reshaped) + + # Reshape back to original shape and flatten the group dimension + output = rearrange(output_bmm, 'k d_out (b t) -> b t (k d_out)', b=batch_size, t=t, k=k) + else: + output = torch.einsum('bsgi,goi->bsgo', input, self.weight) + output = rearrange(output, 'b t k d_out -> b t (k d_out)') + + return output + + +# class MultiChannelGroupLinear(nn.Module): +# """ Not Implemented Yet """ +# def __init__(self, in_ch=26, in_dim=128, out_ch=13, out_dim=512): +# super().__init__() + +# self.in_ch = in_ch +# self.in_dim = in_dim +# self.out_ch = out_ch +# self.out_dim = out_dim +# self.in_ch_per_group = in_ch // out_ch + +# self.layer = GroupLinearFlatten(in_features=) + + +class MultiChannelLinearProjection(nn.Module): + + def __init__(self, in_ch=26, in_dim=128, out_ch=13, out_dim=512): + super().__init__() + self.in_ch = in_ch + self.in_dim = in_dim + self.out_ch = out_ch + self.out_dim = out_dim + + self.in_ch_per_group = in_ch // out_ch + self.linear_in_ch = in_ch // self.in_ch_per_group + self.linear_in_dim = in_dim * self.in_ch_per_group + + # Reshaped Input shape: (b, t, in_dim//in_ch_per_group, in_dim*in_ch_per_group) + # Output shape: (b, t, out_ch, out_dim) + if in_dim * self.in_ch_per_group == out_dim: + self.linear = nn.Identity() + else: + self.linear = nn.Linear(in_features=self.linear_in_dim, 
out_features=out_dim, bias=False) + + def forward(self, x): + """ + Args: + x: (B, T, C, D) + + Returns: + x: (B, C_target, T, D_target) + """ + x = rearrange(x, 'b t (c1 c2) d -> b c1 t (c2 d)', c1=self.linear_in_ch, c2=self.in_ch_per_group) + return self.linear(x) + + +def get_multi_channel_projection_layer(input_shape: Tuple[int], output_shape: Tuple[int], proj_type: str) -> nn.Module: + """ This function returns one of the projection layers for multi-channel models.""" + in_ch = input_shape[-2] + in_dim = input_shape[-1] + out_ch = output_shape[-2] + out_dim = output_shape[-1] + + if proj_type == 'mc_shared_linear': + return MultiChannelLinearProjection(in_ch, in_dim, out_ch, out_dim) + + +def test_multi_channel_linear_projection(): + x = torch.randn(2, 10, 26, 128) # (b, t, c, d) + mclp = MultiChannelLinearProjection(in_ch=26, in_dim=128, out_ch=13, out_dim=256) # actually nn.Identity() + assert type(nn.Identity()) == type(mclp.linear) + assert mclp(x).shape == (2, 13, 10, 256) # (b, _c, t, _d) + + x = torch.randn(2, 10, 26, 128) # (b, t, c, d) + mclp = MultiChannelLinearProjection(in_ch=26, in_dim=128, out_ch=13, out_dim=512) # actually nn.Identity() + assert torch.nn.modules.linear.Linear == type(mclp.linear) + assert mclp(x).shape == (2, 13, 10, 512) # (b, _c, t, _d) + + +class FlattenMLP(nn.Module): + + def __init__(self, in_features, flatten_out_features, num_groups, hidden_dim=None, activation=None): + super().__init__() + + self.in_features = in_features + self.num_groups = num_groups + + # Calculate flattened input dimension + self.flat_in_dim = in_features * num_groups + if hidden_dim is None: + hidden_dim = self.flat_in_dim // 2 + self.hidden_dim = hidden_dim + + # Check if flatten_out_features is divisible by in_features + assert flatten_out_features % in_features == 0, "flatten_out_features should be divisible by in_features." 
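+        # Editor's illustration: with in_features=128 and num_groups=24 (the setting used in
+        # test_projection_layers below), flat_in_dim = 128 * 24 = 3072 and hidden_dim defaults to 1536.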
+ + # Define layers + self.layers = nn.Sequential(nn.Flatten(2, 3), nn.Linear(self.flat_in_dim, hidden_dim), nn.LayerNorm(hidden_dim), + activation() if activation else nn.Identity(), nn.Linear(hidden_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + activation() if activation else nn.Identity(), + nn.Linear(hidden_dim, flatten_out_features)) + + def forward(self, x): + # x shape: (batch, seq, num_groups, in_features) + return self.layers(x) + + +class LinearProjection(nn.Module): + + def __init__(self, in_features, flatten_out_features, num_groups): + super().__init__() + + # Calculate flattened input dimension + self.flat_in_dim = in_features * num_groups + self.projection_layer = nn.Linear(in_features=self.flat_in_dim, out_features=flatten_out_features, bias=False) + + def forward(self, x): + # x shape: (batch, seq, num_groups, in_features) + batch_size, t, _, _ = x.size() + x_flattened = x.reshape(batch_size, t, -1) # Flattening num_groups and in_features + return self.projection_layer(x_flattened) + + +class DepthwiseConvProjection(nn.Module): + + def __init__(self, in_features, flatten_out_features, num_groups, depth): + super().__init__() + d_out = flatten_out_features // in_features + + self.conv = nn.Conv2d(in_channels=num_groups, + out_channels=num_groups * d_out, + kernel_size=(1, depth), + groups=num_groups) + + self.fc = nn.Linear(num_groups * d_out * (in_features - depth + 1), flatten_out_features) + + def forward(self, x): + # Swap the dimensions of k and t to match expected input for depthwise convolution + x = x.permute(0, 2, 1, 3) # shape: (b, k, t, d) + + # Convolutional layer + x = self.conv(x) # shape: (b, k*d_out, t, d-depth+1) + + # Reshape the tensor for the Linear layer + batch_size, _, t, _ = x.size() + x = x.reshape(batch_size, t, -1) + return self.fc(x) + + +def get_projection_layer(input_shape: Tuple[int], output_shape: Tuple[int], proj_type: str) -> nn.Module: + """ This function returns one of the projection layers defined below. 
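+    Supported proj_type values (dispatched below): 'linear', 'mlp' (optionally containing 'gelu' or 'relu'),
+    'conv' / 'conv4' / 'conv16' / 'conv32' / 'conv64' (depthwise conv with the given depth), and 'group_linear'.
+    Editor's shape sketch: input_shape=(110, 24, 128), output_shape=(110, 512), proj_type='linear' returns
+    LinearProjection(in_features=128, flatten_out_features=512, num_groups=24), mapping (B, 110, 24, 128) -> (B, 110, 512).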
""" + if len(input_shape) == 2: + _, d_source = input_shape + elif len(input_shape) == 3: + _, k_source, d_source = input_shape + if len(output_shape) == 2: + _, d_target = output_shape + elif len(output_shape) == 3: + _, k_target, d_target = output_shape + + if 'linear' == proj_type: + return LinearProjection(in_features=d_source, flatten_out_features=d_target, num_groups=k_source) + elif 'mlp' in proj_type: + if 'gelu' in proj_type: + return FlattenMLP(in_features=d_source, + flatten_out_features=d_target, + num_groups=k_source, + activation=nn.GELU) + elif 'relu' in proj_type: + return FlattenMLP(in_features=d_source, + flatten_out_features=d_target, + num_groups=k_source, + activation=nn.ReLU) + else: + return FlattenMLP(in_features=d_source, flatten_out_features=d_target, num_groups=k_source, activation=None) + elif 'conv' in proj_type: + if 'conv4' == proj_type: + return DepthwiseConvProjection(in_features=d_source, + flatten_out_features=d_target, + num_groups=k_source, + depth=4) + elif 'conv16' == proj_type: + return DepthwiseConvProjection(in_features=d_source, + flatten_out_features=d_target, + num_groups=k_source, + depth=16) + elif 'conv32' == proj_type: + return DepthwiseConvProjection(in_features=d_source, + flatten_out_features=d_target, + num_groups=k_source, + depth=32) + elif 'conv64' == proj_type: + return DepthwiseConvProjection(in_features=d_source, + flatten_out_features=d_target, + num_groups=k_source, + depth=64) + else: # conv depth 1 + return DepthwiseConvProjection(in_features=d_source, + flatten_out_features=d_target, + num_groups=k_source, + depth=1) + elif 'group_linear' == proj_type: + assert d_source % k_source == 0, "d_source and k_source must be divisible for group_linear projection." + return GroupLinearFlatten(in_features=d_source, + flatten_out_features=d_target, + num_groups=k_source, + use_bmm=True) + else: + raise ValueError(f"Invalid projection type: {proj_type}") + + +def test_projection_layers(): + # encoder hidden states: (B, T, K, D) + b = 2 + t = 110 #10 + k = 24 #16 + d = 128 + enc_hs = torch.randn(b, t, k, d) + + # target shape: (B, T, K, D//4) + target_flatten_d = 512 + + # GroupLinear + gl = GroupLinearFlatten(in_features=d, flatten_out_features=target_flatten_d, num_groups=k, use_bmm=True) + enc_hs_hat = gl(enc_hs) + assert enc_hs_hat.shape == (b, t, target_flatten_d) + print('GroupLinear: ', f'{count_parameters(gl)//1000}k') # 65k + + # FlattenMLP + fm = FlattenMLP(in_features=d, + flatten_out_features=target_flatten_d, + num_groups=k, + hidden_dim=None, + activation=nn.GELU) + enc_hs_hat = fm(enc_hs) + assert enc_hs_hat.shape == (b, t, target_flatten_d) + print('FlattenMLP: ', f'{count_parameters(fm)//1000}k') # 3.6M + + # LinearProjection + lp = LinearProjection(in_features=d, flatten_out_features=target_flatten_d, num_groups=k) + enc_hs_hat = lp(enc_hs) + assert enc_hs_hat.shape == (b, t, target_flatten_d) + print('LinearProjection: ', f'{count_parameters(lp)//1000}k') # 1M + + # DepthwiseConvProjection + dc = DepthwiseConvProjection(in_features=d, flatten_out_features=target_flatten_d, num_groups=k, depth=16) + enc_hs_hat = dc(enc_hs) + assert enc_hs_hat.shape == (b, t, target_flatten_d) + print('DepthwiseConvProjection: ', f'{count_parameters(dc)//1000}k') # 4M diff --git a/amt/src/model/spectrogram.py b/amt/src/model/spectrogram.py new file mode 100644 index 0000000000000000000000000000000000000000..fb4b38373f9d73c63535df407007c956c3a49e7c --- /dev/null +++ b/amt/src/model/spectrogram.py @@ -0,0 +1,225 @@ +# Copyright 2024 The 
YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""spectrogram.py""" +import importlib +from typing import Optional, Literal, Dict, Tuple +from packaging.version import parse as VersionParse + +import torch +import torch.nn as nn +from einops import rearrange +from model.ops import minmax_normalize +from config.config import audio_cfg as default_audio_cfg +""" +Example usage: + +# MT3 setup +>>> hop = 8 ms or 128 samples +>>> melspec = Melspectrogram(sample_rate=16000, n_fft=2048, hop_length=128, + f_min=50, f_max=8000, n_mels=512) +>>> x = torch.randn(2, 1, 32767) # (B, C=1, T): 2.048 s +>>> y = melspec(x) # (2, 256, 512) (B, T, F) + +# PerceiverTF-like setup +>>> hop = 18.75 ms or 300 samples +>>> spec = Spectrogram(n_fft=2048, hop_length=300) + ) +>>> x = torch.randn(2, 1, 95999) # (B, C=1, T): 6.000 s +>>> y = spec(x) # (2, 320, 1024) (B, T, F) + +# Hybrid setup (2.048 seconds segment and spectrogram with hop=300) +>>> hop = 18.75 ms or 300 samples +>>> spec = Spectrogram(n_fft=2048, hop_length=300) +>>> x = torch.randn(2, 1, 32767) # (B, C=1, T): 2.048 s +>>> y = spec(x) # (2, 110, 1024) (B, T, F) + +# PerceiverTF-like setup, hop=256 +>>> hop = 16 ms or 256 samples +>>> spec256 = Spectrogram(sample_rate=16000, n_fft=2048, hop_length=256, + f_min=20, f_max=8000, n_mels=256) +>>> x = torch.randn(2, 1, 32767) # (B, C=1, T): 2.048 s +>>> y = spec256(x) # (2, 128, 1024) (B, T, F) +""" + + +def optional_compiler_disable(func): + if VersionParse(torch.__version__) >= VersionParse("2.1"): + # If the version is 2.1 or higher, apply the torch.compiler.disable decorator. + return torch.compiler.disable(func) + else: + # If the version is below 2.1, return the original function. 
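+        # (editor's note) torch.compiler.disable is only available from PyTorch 2.1; when applied, the
+        # decorated spectrogram forward() runs eagerly even if the surrounding model is compiled with torch.compile.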
+ return func + + +# ------------------------------------------------------------------------------------- +# Log-Mel spectrogram +# ------------------------------------------------------------------------------------- +class Melspectrogram(nn.Module): + + def __init__( + self, + audio_backend: Literal['torchaudio', 'nnaudio'] = 'torchaudio', + sample_rate: int = 16000, + n_fft: int = 2048, + hop_length: int = 128, + f_min: int = 50, # 20 Hz in the MT3 paper, but we can only use 20 Hz with nnAudio + f_max: Optional[int] = 8000, + n_mels: int = 512, + eps: float = 1e-5, + **kwargs, + ): + """ + Log-Melspectrogram + + Args: + audio_backend (str): 'torchaudio' or 'nnaudio' + sample_rate (int): sample rate in Hz + n_fft (int): FFT window size + hop_length (int): hop length in samples + f_min (int): minimum frequency in Hz + f_max (int): maximum frequency in Hz + n_mels (int): number of mel frequency bins + eps (float): epsilon for numerical stability + + """ + super(Melspectrogram, self).__init__() + self.audio_backend = audio_backend.lower() + + if audio_backend.lower() == 'torchaudio': + torchaudio = importlib.import_module('torchaudio') + self.mel_stft = torchaudio.transforms.MelSpectrogram( + sample_rate=sample_rate, + n_fft=n_fft, + hop_length=hop_length, + f_min=f_min, + f_max=f_max, + n_mels=n_mels, + ) + elif audio_backend.lower() == 'nnaudio': + nnaudio = importlib.import_module('nnAudio.features') + self.mel_stft_nnaudio = nnaudio.mel.MelSpectrogram( + sr=sample_rate, + win_length=n_fft, + n_mels=n_mels, + hop_length=hop_length, + fmin=20, #f_min, + fmax=f_max) + else: + raise NotImplementedError(audio_backend) + self.eps = eps + + @optional_compiler_disable + def forward(self, x: torch.Tensor) -> torch.Tensor: # (B, 1, T) + """ + Args: + x (torch.Tensor): (B, 1, T) + + Returns: + torch.Tensor: (B, T, F) + + """ + if self.audio_backend == 'torchaudio': + x = self.mel_stft(x) # (B, 1, F, T) + x = rearrange(x, 'b 1 f t -> b t f') + x = minmax_normalize(torch.log(x + self.eps)) + # some versions of torchaudio returns nan when input is all-zeros + return torch.nan_to_num(x) + + elif self.audio_backend == 'nnaudio': + x = self.mel_stft_nnaudio(x) # (B, F, T) + x = rearrange(x, 'b f t -> b t f') + x = minmax_normalize(torch.log(x + self.eps)) + return x + + +# ------------------------------------------------------------------------------------- +# Log-spectrogram +# ------------------------------------------------------------------------------------- +class Spectrogram(nn.Module): + + def __init__( + self, + audio_backend: Literal['torchaudio', 'nnaudio'] = 'torchaudio', + n_fft: int = 2048, + hop_length: int = 128, + eps: float = 1e-5, + **kwargs, + ): + """ + Log-Magnitude Spectrogram + + Args: + audio_backend (str): 'torchaudio' or 'nnaudio' + n_fft (int): FFT window size, creates n_fft // 2 + 1 freq-bins + hop_length (int): hop length in samples + eps (float): epsilon for numerical stability + + """ + super(Spectrogram, self).__init__() + self.audio_backend = audio_backend.lower() + + if audio_backend.lower() == 'torchaudio': + torchaudio = importlib.import_module('torchaudio') + self.stft = torchaudio.transforms.Spectrogram(n_fft=n_fft, + hop_length=hop_length, + window_fn=torch.hann_window, + power=1.) 
# (B, 1, F, T), remove DC component + elif audio_backend.lower() == 'nnaudio': + # TODO: nnAudio spectrogram + raise NotImplementedError(audio_backend) + else: + raise NotImplementedError(audio_backend) + self.eps = eps + + @optional_compiler_disable + def forward(self, x: torch.Tensor) -> torch.Tensor: # (B, 1, T) + """ + Args: + x (torch.Tensor): (B, 1, T) + + Returns: + torch.Tensor: (B, T, F) + + """ + if self.audio_backend == 'torchaudio': + x = self.stft(x)[:, :, 1:, :] # (B, 1, F, T) remove DC component + x = rearrange(x, 'b 1 f t -> b t f') + x = minmax_normalize(torch.log(x + self.eps)) + return torch.nan_to_num(x) # some versions of torchaudio returns nan when input is all-zeros + elif self.audio_backend == 'nnaudio': + raise NotImplementedError(self.audio_backend) + + +def get_spectrogram_layer_from_audio_cfg(audio_cfg: Optional[Dict] = None) -> Tuple[nn.Module, Tuple[int]]: + """Get mel-/spectrogram layer from config. + - Used by 'ymt3' to create a spectrogram layer. + - Returns output shape of spectrogram layer, which is used to determine input shape of model. + + Args: + audio_cfg (dict): see config/config.py + + Returns: + layer (nn.Module): mel-/spectrogram layer + output_shape (tuple): inferred output shape of layer excluding batch dim. (T, F) + """ + if audio_cfg is None: + audio_cfg = default_audio_cfg + + if audio_cfg['codec'] == 'melspec': + layer = Melspectrogram(**audio_cfg) + elif audio_cfg['codec'] == 'spec': + layer = Spectrogram(**audio_cfg) + else: + raise NotImplementedError(audio_cfg['codec']) + + # Infer output shape of the spectrogram layer + with torch.no_grad(): + output_shape = layer(torch.randn(1, 1, audio_cfg['input_frames'])).shape[1:] + return layer, output_shape diff --git a/amt/src/model/t5mod.py b/amt/src/model/t5mod.py new file mode 100644 index 0000000000000000000000000000000000000000..523a26019c0af5716428b1df9abd9449a6ffddb4 --- /dev/null +++ b/amt/src/model/t5mod.py @@ -0,0 +1,687 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +# ============================================================================== +# Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
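+# (Editor's summary) t5mod.py: T5 blocks/stacks modified for YMT3 -- fixed sinusoidal positional
+# encoding, optional MoE feed-forward layers, and single-/multi-channel decoder variants.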
+import copy +from typing import Optional, Tuple, Union, Dict +from einops import rearrange +from model.ops import count_parameters + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint +from transformers.utils import logging +from transformers.utils.model_parallel_utils import assert_device_map, get_device_map +from transformers.models.t5.modeling_t5 import (T5LayerNorm, T5LayerSelfAttention, T5LayerCrossAttention, T5LayerFF) +from transformers.modeling_outputs import (BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions) +from transformers import T5Config, T5PreTrainedModel +from model.positional_encoding import FixedSinusoidalPositionalEmbedding +from model.ff_layer import get_ff_layer + +logger = logging.get_logger(__name__) + + +class T5BlockYMT3(nn.Module): + """T5 Block, modified to allow using different types of FF layers.""" + + def __init__(self, config, has_relative_attention_bias=False): + super().__init__() + self.is_decoder = config.is_decoder + self.layer = nn.ModuleList() + self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) + if self.is_decoder: + self.layer.append(T5LayerCrossAttention(config)) + + # FF layer + if config.ff_layer_type == 't5_gmlp': + self.layer.append(T5LayerFF(config)) + elif config.ff_layer_type == 'moe': + config.moe_num_experts = 8 + config.moe_topk = 2 + config.hidden_act = 'silu' + moe = get_ff_layer(config, input_size=config.d_model, widening_factor=config.ff_widening_factor) + self.layer.append(moe) + else: + raise ValueError(f"Unknown FF layer type: {config.ff_layer_type}.") + self.ff_layer_type = config.ff_layer_type + + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + encoder_decoder_position_bias=None, + layer_head_mask=None, + cross_attn_layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + return_dict=True, + ): + if past_key_value is not None: + if not self.is_decoder: + logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.") + expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 + + if len(past_key_value) != expected_num_past_key_values: + raise ValueError( + f"There should be {expected_num_past_key_values} past states. " + f"{'2 (past / key) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" + f"Got {len(past_key_value)} past key / value states") + + self_attn_past_key_value = past_key_value[:2] + cross_attn_past_key_value = past_key_value[2:] + else: + self_attn_past_key_value, cross_attn_past_key_value = None, None + + self_attention_outputs = self.layer[0]( + hidden_states, + attention_mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=self_attn_past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states, present_key_value_state = self_attention_outputs[:2] + attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + do_cross_attention = self.is_decoder and encoder_hidden_states is not None + if do_cross_attention: + # the actual query length is unknown for cross attention + # if using past key value states. Need to inject it here + if present_key_value_state is not None: + query_length = present_key_value_state[0].shape[2] + else: + query_length = None + + cross_attention_outputs = self.layer[1]( + hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + position_bias=encoder_decoder_position_bias, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + query_length=query_length, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = cross_attention_outputs[0] + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + # Combine self attn and cross attn key value states + if present_key_value_state is not None: + present_key_value_state = present_key_value_state + cross_attention_outputs[1] + + # Keep cross-attention outputs and relative position weights + attention_outputs = attention_outputs + cross_attention_outputs[2:] + + # Apply Feed Forward layer - Modified for MoE + if self.ff_layer_type == 't5_gmlp': + hidden_states = self.layer[-1](hidden_states) + elif self.ff_layer_type == 'moe': + hidden_states = hidden_states + self.layer[-1](hidden_states)[0] # residual connection outside the MoE + else: + raise ValueError(f"Unknown FF layer type: {self.ff_layer_type}.") + + # clamp inf values to enable fp16 training + if hidden_states.dtype == torch.float16: + clamp_value = torch.where( + torch.isinf(hidden_states).any(), + torch.finfo(hidden_states.dtype).max - 1000, + torch.finfo(hidden_states.dtype).max, + ) + hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) + + outputs = (hidden_states,) + + if use_cache: + outputs = outputs + (present_key_value_state,) + attention_outputs + else: + outputs = outputs + attention_outputs + + return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + + +class T5StackYMT3(T5PreTrainedModel): + """ + T5Stack, modified 
for YMT3 with: + - absolute sinusoidal absolute positional encoding + """ + + def __init__( + self, + config, + ): + super().__init__(config) + self.is_decoder = config.is_decoder + + # Positional encoding (modified) + self.use_t5_trainable_pe = False + self.additive_pe = None + + pos_enc_type = getattr(config, 'position_encoding_type', 'sinusoidal') + if pos_enc_type in ['sinusoidal']: + self.additive_pe = FixedSinusoidalPositionalEmbedding(config.num_max_positions, + embedding_dim=config.d_model) + self.block = nn.ModuleList( + [T5BlockYMT3(config, has_relative_attention_bias=False) for i in range(config.num_layers)]) + elif pos_enc_type == 'trainable': + self.use_t5_trainable_pe = True + # Stack blocks + self.block = nn.ModuleList( + [T5BlockYMT3(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]) + + self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + self.dropout = nn.Dropout(config.dropout_rate) + + # Initialize weights and apply final processing + self.post_init() + # Model parallel + self.gradient_checkpointing = False + + def forward( + self, + # input_ids=None, + inputs_embeds=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = (output_hidden_states + if output_hidden_states is not None else self.config.output_hidden_states) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify {err_msg_prefix}inputs_embeds") + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + # mod: required for additive PE + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if use_cache is True: + assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder" + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones(batch_size, + encoder_seq_length, + device=inputs_embeds.device, + dtype=torch.long) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. 
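+        # (editor's note) get_extended_attention_mask() returns an additive mask broadcastable to
+        # (batch, num_heads, seq_length, seq_length): 0 where attention is allowed and a large negative
+        # value where masked (plus a causal component when config.is_decoder is True).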
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...") + use_cache = False + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + # mod: additive absolute PE (sinusoidal) + if self.additive_pe is not None: + inputs_embeds = inputs_embeds + self.additive_pe(inputs_embeds.shape[1], past_key_values_length) + else: + pass # trinable PE is implemented in T5Block + + hidden_states = self.dropout(inputs_embeds) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention 
weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] if v is not None) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) + + +class T5EncoderYMT3(T5PreTrainedModel): + # _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"] + + def __init__(self, encoder_config: Optional[Dict] = None, config: Optional[T5Config] = None): + if config is None: + config = T5Config() + if encoder_config is not None: + config = copy.deepcopy(config) + config.update(encoder_config) + + if hasattr(config, "ff_widening_factor"): + config.d_ff = int(config.d_model) * int(config.ff_widening_factor) + + config.is_decoder = False + config.use_cache = False + config.is_encoder_decoder = False + + super().__init__(config) + self.model_dim = config.d_model + + self.encoder = T5StackYMT3(config) + + # Initialize weights and apply final processing + self.post_init() + + """temporary fix for torch.compile issue""" + + def forward(self, **kwargs): + if self.training is True: + return self._forward_compile(**kwargs) + else: + return self._forward_no_compile(**kwargs) + + def _forward_no_compile(self, **kwargs): + return self._forward(**kwargs) + + @torch.compile + def _forward_compile(self, **kwargs): + return self._forward(**kwargs) + + def _forward( + self, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]: + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # Encode + encoder_outputs = self.encoder( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + head_mask=head_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return encoder_outputs + else: + return BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + +class T5DecoderYMT3(T5PreTrainedModel): + + def __init__(self, decoder_config: Optional[Dict] = None, config: Optional[T5Config] = None): + if config is None: + config = T5Config() + if decoder_config is not None: + config = copy.deepcopy(config) + 
config.update(decoder_config) + + if hasattr(config, "ff_widening_factor"): + config.d_ff = int(config.d_model) * int(config.ff_widening_factor) + + config.is_decoder = True + config.is_encoder_decoder = False + + super().__init__(config) + self.model_dim = config.d_model + + self.decoder = T5StackYMT3(config) + + # Initialize weights and apply final processing + self.post_init() + + """temporary fix for torch.compile issue""" + + def forward(self, **kwargs): + if self.training is True: + return self._forward_compile(**kwargs) + else: + return self._forward_no_compile(**kwargs) + + def _forward_no_compile(self, **kwargs): + return self._forward(**kwargs) + + @torch.compile + def _forward_compile(self, **kwargs): + return self._forward(**kwargs) + + def _forward( + self, + # input_ids: torch.LongTensor, # removed since embed_tokens is outside the decoder + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, # decoder_attention_mask + encoder_attention_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPastAndCrossAttentions]: + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if isinstance(encoder_hidden_states, BaseModelOutput): + encoder_hidden_states = encoder_hidden_states.last_hidden_state + + # Decode + decoder_outputs = self.decoder( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + else: + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=decoder_outputs[0], + past_key_values=decoder_outputs[1], + hidden_states=decoder_outputs[2] if len(decoder_outputs) > 2 else None, + attentions=decoder_outputs[3] if len(decoder_outputs) > 3 else None, + cross_attentions=decoder_outputs[4] if len(decoder_outputs) > 4 else None, + ) + + +class MultiChannelT5Decoder(T5PreTrainedModel): + + def __init__(self, decoder_config: Optional[Dict] = None, config: Optional[T5Config] = None): + if config is None: + config = T5Config() + if decoder_config is not None: + config = copy.deepcopy(config) + config.update(decoder_config) + + if hasattr(config, "ff_widening_factor"): + config.d_ff = int(config.d_model) * int(config.ff_widening_factor) + + config.is_decoder = True + config.is_encoder_decoder = False + + super().__init__(config) + self.model_dim = config.d_model + self.decoder = T5StackYMT3(config) + + # Multi-channel parameters + self.num_channels = config.num_channels + + # Initialize weights and apply final processing + self.post_init() + + """temporary fix for torch.compile issue""" + + def forward(self, **kwargs): + if self.training is True: + return self._forward_compile(**kwargs) + 
else: + return self._forward_no_compile(**kwargs) + + def _forward_no_compile(self, **kwargs): + return self._forward(**kwargs) + + @torch.compile + def _forward_compile(self, **kwargs): + return self._forward(**kwargs) + + def _forward( + self, + # input_ids: torch.LongTensor, # removed since embed_tokens is outside the decoder + inputs_embeds: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, # decoder_attention_mask + encoder_attention_mask: Optional[torch.FloatTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.FloatTensor], BaseModelOutputWithPastAndCrossAttentions]: + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + """ + Args: + inputs_embeds: torch.FloatTensor (B, K, T, D), where K is the number of channels + encoder_hidden_states: torch.FloatTensor (B, K, T, D), where K is the number of channels + + Returns: + decoder_outputs: BaseModelOutputWithPastAndCrossAttentions + last_hidden_state: torch.FloatTensor (B, K, T, D), where K is the number of channels + past_key_values: Tuple[Tuple[torch.Tensor]] + hidden_states: Tuple[torch.FloatTensor] + attentions: Tuple[torch.FloatTensor] + cross_attentions: Tuple[torch.FloatTensor] + + """ + if isinstance(encoder_hidden_states, BaseModelOutput): + encoder_hidden_states = encoder_hidden_states.last_hidden_state + + # Reshape input_embeds and encoder_hidden_states + b, k, t, d = inputs_embeds.size() + inputs_embeds = rearrange(inputs_embeds, 'b k t d -> (b k) t d') + encoder_hidden_states = rearrange(encoder_hidden_states, 'b k t d -> (b k) t d') + + # K-channel Decoding + decoder_outputs = self.decoder( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + head_mask=head_mask, + cross_attn_head_mask=cross_attn_head_mask, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=True, + ) + + # Reshape decoder_outputs + decoder_outputs['last_hidden_state'] = rearrange(decoder_outputs['last_hidden_state'], + '(b k) t d -> b k t d', + b=b, + k=k) + + if not return_dict: + # Collecting values from decoder_outputs in a specific order + outputs = ( + decoder_outputs['last_hidden_state'], + decoder_outputs.get('past_key_values', None), + decoder_outputs.get('hidden_states', None), + decoder_outputs.get('attentions', None), + decoder_outputs.get('cross_attentions', None), + ) + return tuple(v for v in outputs if v is not None) + else: + return decoder_outputs # ['last_hidden_state']: (B, K, T, D) + + +def test_multi_channel_t5_decoder(): + # Test multi-channel decoder + config = T5Config() + config.num_channels = 4 + config.d_model = 32 + config.num_layers = 2 + config.num_heads = 2 + config.num_max_positions = 64 # for positional encoding + + decoder = MultiChannelT5Decoder(decoder_config=None, config=config) + decoder.eval() + + input_emb = torch.rand(2, 4, 64, 32) # (B, K, T, D) + enc_hs = torch.rand(2, 4, 64, 32) # 
(B, K, T, D) + out = decoder(inputs_embeds=input_emb, encoder_hidden_states=enc_hs, return_dict=True) + # out['last_hidden_state']: (B, K, T, D) + # out['past_key_values']: Tuple[Tuple[torch.Tensor]] diff --git a/amt/src/model/t5mod_helper.py b/amt/src/model/t5mod_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..65468f316efcc8910f7cd1deb3b44db974990c3c --- /dev/null +++ b/amt/src/model/t5mod_helper.py @@ -0,0 +1,133 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""t5mod_helper.py""" +import torch +from torch import nn +from model.t5mod import T5DecoderYMT3, MultiChannelT5Decoder +from typing import Optional, Callable, Union, Literal + + +@torch.no_grad() +def task_cond_dec_generate(decoder: Union[T5DecoderYMT3, MultiChannelT5Decoder], + decoder_type: Literal["t5", "multi-t5"], + embed_tokens: nn.Embedding, + lm_head: nn.Module, + encoder_hidden_states: torch.FloatTensor, + shift_right_fn: Callable, + prefix_ids: Optional[torch.LongTensor] = None, + max_length: int = 1024, + stop_at_eos: bool = True, + eos_id: Optional[int] = 1, + pad_id: Optional[int] = 0, + decoder_start_token_id: Optional[int] = 0, + debug: bool = False) -> torch.LongTensor: + """ + Generate sequence by task conditioning on the decoder side + :An extension of transofrmers.generate() function for the model with + conditioning only on the decoder side. + + Args: + decoder: T5DecoderYMT3 or MultiChannelT5Decoder, any decoder model with T5Stack architecture + decoder_type: Literal["t5", "multi-t5"], type of decoder + embed_tokens: nn.Embedding, embedding layer for the decoder + lm_head: nn.Module, language model head + encoder_hidden_states: torch.FloatTensor, (B, T, D) or (B, K, T, D) last hidden states + shift_right_fn: Callable, shift_right function of the decoder + prefix_ids: torch.LongTensor, (B, prefix_len) prefix ids typically used as task conditioning to decoder. + max_length: int, max token length to generate (default is 1024) + stop_at_eos: bool, whether to early-stop when all predictions in the batch are the token. 
+ eos_id: int, the id of the token (default is 1) + pad_id: int, the id of the token (default is 0) + decoder_start_token_id: int, the id of the token (default is 0) + debug: bool, whether to print debug information + + Returns: + pred_ids: torch.LongTensor, (B, task_len + N) or (B, C, task_len + N) predicted token ids + """ + bsz = encoder_hidden_states.shape[0] + device = encoder_hidden_states.device + + # Prepare dec_input_shape: (B, 1) or (B, C, 1) + if decoder_type == "t5": + dec_input_shape = (bsz, 1) + elif decoder_type == "multi-t5": + dec_input_shape = (bsz, decoder.num_channels, 1) + else: + raise ValueError(f"decoder_type {decoder_type} is not supported.") + + # Prepare dec_input_ids: + task_prefix_token (B, prefix_len + 1) or (B, C, prefix_len + 1) + if prefix_ids is not None and prefix_ids.numel() > 0: + dec_input_ids = shift_right_fn(prefix_ids) + prefix_length = prefix_ids.shape[-1] + else: + # if prefix_ids is None, use as initial inSput + dec_input_ids = torch.tile(torch.LongTensor([decoder_start_token_id]).to(device), dec_input_shape) + prefix_length = 0 + dec_inputs_embeds = embed_tokens(dec_input_ids) # (B, L, D) or (B, C, L, D) + + # Generate decoder hidden state and past_key_values using prefix: + """ + - initial inputs_embeds can be a sequence, without using past_key_values + - dec_hs: (B, 1, D) + - past_key_values: Tuple of length M for M layers of decoder + - pred_ids: (B, prefix_len) where N is the length of prefix_ids + """ + dec_hs, past_key_values = decoder(inputs_embeds=dec_inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + return_dict=False) + logits = lm_head(dec_hs) # (b, T=1, vocab_size) or (b, C, T=1, vocab_size) + pred_ids = logits.argmax(-1) # (B, prefix_len + 1) or (B, C, prefix_len + 1) + + # keep track of which sequences are already finished + unfinished_sequences = torch.ones(dec_input_shape, dtype=torch.long, device=device) + + # Fast generation with past_key_values for the rest of the sequence + if decoder_type == "t5": + dec_input_ids = pred_ids[:, -1].unsqueeze(-1) # (B, 1) + elif decoder_type == "multi-t5": + dec_input_ids = pred_ids[:, :, -1].unsqueeze(-1) # (B, C, 1) + for i in range(max_length - prefix_length - 1): # -1 for token + if debug: + past_key_values_length = past_key_values[0][0].shape[ + 2] # past_key_values_length determines the positional embedding + print(f'i = {i}, past_key_values_length = {past_key_values_length}, pred_ids.shape = {pred_ids.shape}') + + # when past_key_values is provided, we use only the last token as input_ids + dec_inputs_embeds = embed_tokens(dec_input_ids) # (B, 1, D) or (B, C, 1, D) + dec_hs, _past_key_values = decoder(inputs_embeds=dec_inputs_embeds, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + return_dict=False) + logits = lm_head(dec_hs) # (b, 1, vocab_size) or (b, K, 1, vocab_size) + _pred_ids = logits.argmax(-1) # (B, 1) or (B, K, 1) + + # update input_ids and past_key_values for next iteration + dec_input_ids = _pred_ids.clone( + ) # (B, 1) or (B, C, 1), deepcopy of _pred_ids because _pred_ids will be modified for finished sentences + past_key_values = _past_key_values + + # finished sentences should have their next token be a padding token + if eos_id is not None: + if pad_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + _pred_ids = _pred_ids * unfinished_sequences + pad_id * (1 - unfinished_sequences) + + # update pred_ids + pred_ids = torch.cat((pred_ids, _pred_ids), dim=-1) # (B, T') 
or (B, C, T') with increasing T' + + # update state of unfinished_sequences + if eos_id is not None: + unfinished_sequences = unfinished_sequences * _pred_ids.ne(eos_id).long() + + # early-stop when each sentence is finished + if stop_at_eos is True and unfinished_sequences.max() == 0: + break + + return pred_ids # (B, L) or (B, C, L) diff --git a/amt/src/model/ymt3.py b/amt/src/model/ymt3.py new file mode 100644 index 0000000000000000000000000000000000000000..89d27654191f5c0de6d9599a1fb9a2626bbd258a --- /dev/null +++ b/amt/src/model/ymt3.py @@ -0,0 +1,967 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""ymt3.py""" +import os +from typing import Union, Optional, Tuple, Dict, List, Any +from collections import Counter + +import torch +import torch.nn as nn +from torch.nn import CrossEntropyLoss +import torchaudio # for debugging audio +import pytorch_lightning as pl +import numpy as np +import wandb +from einops import rearrange + +from transformers import T5Config +from model.t5mod import T5EncoderYMT3, T5DecoderYMT3, MultiChannelT5Decoder +from model.t5mod_helper import task_cond_dec_generate +from model.perceiver_mod import PerceiverTFEncoder +from model.perceiver_helper import PerceiverTFConfig +from model.conformer_mod import ConformerYMT3Encoder +from model.conformer_helper import ConformerYMT3Config +from model.lm_head import LMHead +from model.pitchshift_layer import PitchShiftLayer +from model.spectrogram import get_spectrogram_layer_from_audio_cfg +from model.conv_block import PreEncoderBlockRes3B +from model.conv_block import PreEncoderBlockHFTT, PreEncoderBlockRes3BHFTT # added for hFTT-like pre-encoder +from model.projection_layer import get_projection_layer, get_multi_channel_projection_layer +from model.optimizers import get_optimizer +from model.lr_scheduler import get_lr_scheduler + +from utils.note_event_dataclasses import Note +from utils.note2event import mix_notes +from utils.event2note import merge_zipped_note_events_and_ties_to_notes, DECODING_ERR_TYPES +from utils.metrics import compute_track_metrics +from utils.metrics import AMTMetrics +# from utils.utils import write_model_output_as_npy +from utils.utils import write_model_output_as_midi, create_inverse_vocab, write_err_cnt_as_json +from utils.utils import Timer +from utils.task_manager import TaskManager + +from config.config import audio_cfg as default_audio_cfg +from config.config import model_cfg as default_model_cfg +from config.config import shared_cfg as default_shared_cfg +from config.config import T5_BASE_CFG + + +class YourMT3(pl.LightningModule): + """YourMT3: + + Lightning wrapper for multi-task music transcription Transformer. 
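+    Data flow per forward pass (editor's summary): spectrogram -> pre_encoder -> encoder ->
+    pre_decoder (projection) -> decoder, with a shared embed_tokens / lm_head pair on the decoder side.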
+ + """ + + def __init__( + self, + audio_cfg: Optional[Dict] = None, + model_cfg: Optional[Dict] = None, + shared_cfg: Optional[Dict] = None, + pretrained: bool = False, + optimizer_name: str = 'adamwscale', + scheduler_name: str = 'cosine', + base_lr: float = None, # None: 'auto' for AdaFactor, 1e-3 for constant, 1e-2 for cosine + max_steps: Optional[int] = None, + weight_decay: float = 0.0, + init_factor: Optional[Union[str, float]] = None, + task_manager: TaskManager = TaskManager(), + eval_subtask_key: Optional[str] = "default", + eval_vocab: Optional[Dict] = None, + eval_drum_vocab: Optional[Dict] = None, + write_output_dir: Optional[str] = None, + write_output_vocab: Optional[Dict] = None, + onset_tolerance: float = 0.05, + add_pitch_class_metric: Optional[List[str]] = None, + add_melody_metric_to_singing: bool = True, + test_optimal_octave_shift: bool = False, + test_pitch_shift_layer: Optional[str] = None, + **kwargs: Any) -> None: + super().__init__() + if pretrained is True: + raise NotImplementedError("Pretrained model is not supported in this version.") + self.test_pitch_shift_layer = test_pitch_shift_layer # debug only + + # Config + if model_cfg is None: + model_cfg = default_model_cfg # default config, not overwritten by args of trainer + if audio_cfg is None: + audio_cfg = default_audio_cfg # default config, not overwritten by args of trainer + if shared_cfg is None: + shared_cfg = default_shared_cfg # default config, not overwritten by args of trainer + + # Spec Layer (need to define here to infer max token length) + self.spectrogram, spec_output_shape = get_spectrogram_layer_from_audio_cfg( + audio_cfg) # can be spec or melspec; output_shape is (T, F) + model_cfg["feat_length"] = spec_output_shape[0] # T of (T, F) + + # Task manger and Tokens + self.task_manager = task_manager + self.max_total_token_length = self.task_manager.max_total_token_length + + # Task Conditioning + self.use_task_cond_encoder = bool(model_cfg["use_task_conditional_encoder"]) + self.use_task_cond_decoder = bool(model_cfg["use_task_conditional_decoder"]) + + # Select Encoder type, Model-specific Config + assert model_cfg["encoder_type"] in ["t5", "perceiver-tf", "conformer"] + assert model_cfg["decoder_type"] in ["t5", "multi-t5"] + self.encoder_type = model_cfg["encoder_type"] # {"t5", "perceiver-tf", "conformer"} + self.decoder_type = model_cfg["decoder_type"] # {"t5", "multi-t5"} + encoder_config = model_cfg["encoder"][self.encoder_type] # mutable + decoder_config = model_cfg["decoder"][self.decoder_type] # mutable + + # Positional Encoding + if isinstance(model_cfg["num_max_positions"], str) and model_cfg["num_max_positions"] == 'auto': + encoder_config["num_max_positions"] = int(model_cfg["feat_length"] + + self.task_manager.max_task_token_length + 10) + decoder_config["num_max_positions"] = int(self.max_total_token_length + 10) + else: + assert isinstance(model_cfg["num_max_positions"], int) + encoder_config["num_max_positions"] = model_cfg["num_max_positions"] + decoder_config["num_max_positions"] = model_cfg["num_max_positions"] + + # Select Pre-Encoder and Pre-Decoder type + if model_cfg["pre_encoder_type"] == "default": + model_cfg["pre_encoder_type"] = model_cfg["pre_encoder_type_default"].get(model_cfg["encoder_type"], None) + elif model_cfg["pre_encoder_type"] in [None, "none", "None", "0"]: + model_cfg["pre_encoder_type"] = None + if model_cfg["pre_decoder_type"] == "default": + model_cfg["pre_decoder_type"] = model_cfg["pre_decoder_type_default"].get(model_cfg["encoder_type"]).get( + 
model_cfg["decoder_type"], None) + elif model_cfg["pre_decoder_type"] in [None, "none", "None", "0"]: + model_cfg["pre_decoder_type"] = None + self.pre_encoder_type = model_cfg["pre_encoder_type"] + self.pre_decoder_type = model_cfg["pre_decoder_type"] + + # Pre-encoder + self.pre_encoder = nn.Sequential() + if self.pre_encoder_type in ["conv", "conv1d_t", "conv1d_f"]: + kernel_size = (3, 3) + avp_kernel_size = (1, 2) + if self.pre_encoder_type == "conv1d_t": + kernel_size = (3, 1) + elif self.pre_encoder_type == "conv1d_f": + kernel_size = (1, 3) + self.pre_encoder.append( + PreEncoderBlockRes3B(1, + model_cfg["conv_out_channels"], + kernel_size=kernel_size, + avp_kernerl_size=avp_kernel_size, + activation="relu")) + pre_enc_output_shape = (spec_output_shape[0], spec_output_shape[1] // 2**3, model_cfg["conv_out_channels"] + ) # (T, F, C) excluding batch dim + elif self.pre_encoder_type == "hftt": + self.pre_encoder.append(PreEncoderBlockHFTT()) + pre_enc_output_shape = (spec_output_shape[0], spec_output_shape[1], 128) # (T, F, C) excluding batch dim + elif self.pre_encoder_type == "res3b_hftt": + self.pre_encoder.append(PreEncoderBlockRes3BHFTT()) + pre_enc_output_shape = (spec_output_shape[0], spec_output_shape[1] // 2**3, 128) + else: + pre_enc_output_shape = spec_output_shape # (T, F) excluding batch dim + + # Auto-infer `d_feat` and `d_model`, `vocab_size`, and `num_max_positions` + if isinstance(model_cfg["d_feat"], str) and model_cfg["d_feat"] == 'auto': + if self.encoder_type == "perceiver-tf" and encoder_config["attention_to_channel"] is True: + model_cfg["d_feat"] = pre_enc_output_shape[-2] # TODO: better readablity + else: + model_cfg["d_feat"] = pre_enc_output_shape[-1] # C of (T, F, C) or F or (T, F) + + if self.encoder_type == "perceiver-tf" and isinstance(encoder_config["d_model"], str): + if encoder_config["d_model"] == 'q': + encoder_config["d_model"] = encoder_config["d_latent"] + elif encoder_config["d_model"] == 'kv': + encoder_config["d_model"] = model_cfg["d_feat"] + else: + raise ValueError(f"Unknown d_model: {encoder_config['d_model']}") + + # # required for PerceiverTF with attention_to_channel option + # if self.encoder_type == "perceiver-tf": + # if encoder_config["attention_to_channel"] is True: + # encoder_config["kv_dim"] = model_cfg["d_feat"] # TODO: better readablity + # else: + # encoder_config["kv_dim"] = model_cfg["conv_out_channels"] + + if isinstance(model_cfg["vocab_size"], str) and model_cfg["vocab_size"] == 'auto': + model_cfg["vocab_size"] = task_manager.num_tokens + + if isinstance(model_cfg["num_max_positions"], str) and model_cfg["num_max_positions"] == 'auto': + model_cfg["num_max_positions"] = int( + max(model_cfg["feat_length"], model_cfg["event_length"]) + self.task_manager.max_task_token_length + 10) + + # Pre-decoder + self.pre_decoder = nn.Sequential() + if self.encoder_type == "perceiver-tf" and self.decoder_type == "t5": + t, f, c = pre_enc_output_shape # perceiver-tf: (110, 128, 128) for 2s + encoder_output_shape = (t, encoder_config["num_latents"], encoder_config["d_latent"]) # (T, K, D_source) + decoder_input_shape = (t, decoder_config["d_model"]) # (T, D_target) + proj_layer = get_projection_layer(input_shape=encoder_output_shape, + output_shape=decoder_input_shape, + proj_type=self.pre_decoder_type) + self.pre_encoder_output_shape = pre_enc_output_shape + self.encoder_output_shape = encoder_output_shape + self.decoder_input_shape = decoder_input_shape + self.pre_decoder.append(proj_layer) + elif self.encoder_type in ["t5", 
"conformer"] and self.decoder_type == "t5": + pass + elif self.encoder_type == "perceiver-tf" and self.decoder_type == "multi-t5": + # NOTE: this is experiemental, only for multi-channel decoding with 13 classes + assert encoder_config["num_latents"] % decoder_config["num_channels"] == 0 + encoder_output_shape = (encoder_config["num_latents"], encoder_config["d_model"]) + decoder_input_shape = (decoder_config["num_channels"], decoder_config["d_model"]) + proj_layer = get_multi_channel_projection_layer(input_shape=encoder_output_shape, + output_shape=decoder_input_shape, + proj_type=self.pre_decoder_type) + self.pre_decoder.append(proj_layer) + else: + raise NotImplementedError( + f"Encoder type {self.encoder_type} and decoder type {self.decoder_type} is not implemented yet.") + + # Positional Encoding, Vocab, etc. + if self.encoder_type in ["t5", "conformer"]: + encoder_config["num_max_positions"] = decoder_config["num_max_positions"] = model_cfg["num_max_positions"] + else: # perceiver-tf uses separate positional encoding + encoder_config["num_max_positions"] = model_cfg["feat_length"] + decoder_config["num_max_positions"] = model_cfg["num_max_positions"] + encoder_config["vocab_size"] = decoder_config["vocab_size"] = model_cfg["vocab_size"] + + # Print and save updated configs + self.audio_cfg = audio_cfg + self.model_cfg = model_cfg + self.shared_cfg = shared_cfg + self.save_hyperparameters() + if self.global_rank == 0: + print(self.hparams) + + # Encoder and Decoder and LM-head + self.encoder = None + self.decoder = None + self.lm_head = LMHead(decoder_config, 1.0, model_cfg["tie_word_embeddings"]) + self.embed_tokens = nn.Embedding(decoder_config["vocab_size"], decoder_config["d_model"]) + self.embed_tokens.weight.data.normal_(mean=0.0, std=1.0) + self.shift_right_fn = None + self.set_encoder_decoder() # shift_right_fn is also set here + + # Model as ModuleDict + # self.model = nn.ModuleDict({ + # "pitchshift": self.pitchshift, # no grad; created in setup() only for training, + # and called by training_step() + # "spectrogram": self.spectrogram, # no grad + # "pre_encoder": self.pre_encoder, + # "encoder": self.encoder, + # "pre_decoder": self.pre_decoder, + # "decoder": self.decoder, + # "embed_tokens": self.embed_tokens, + # "lm_head": self.lm_head, + # }) + + # Tables (for logging) + columns = ['Ep', 'Track ID', 'Pred Events', 'Actual Events', 'Pred Notes', 'Actual Notes'] + self.sample_table = wandb.Table(columns=columns) + + # Output MIDI + if write_output_dir is not None: + if write_output_vocab is None: + from config.vocabulary import program_vocab_presets + self.midi_output_vocab = program_vocab_presets["gm_ext_plus"] + else: + self.midi_output_vocab = write_output_vocab + self.midi_output_inverse_vocab = create_inverse_vocab(self.midi_output_vocab) + + def set_encoder_decoder(self) -> None: + """Set encoder, decoder, lm_head and emb_tokens from self.model_cfg""" + + # Generate and update T5Config + t5_basename = self.model_cfg["t5_basename"] + if t5_basename in T5_BASE_CFG.keys(): + # Load from pre-defined config in config.py + t5_config = T5Config(**T5_BASE_CFG[t5_basename]) + else: + # Load from HuggingFace hub + t5_config = T5Config.from_pretrained(t5_basename) + + # Create encoder, decoder, lm_head and embed_tokens + if self.encoder_type == "t5": + self.encoder = T5EncoderYMT3(self.model_cfg["encoder"]["t5"], t5_config) + elif self.encoder_type == "perceiver-tf": + perceivertf_config = PerceiverTFConfig() + 
perceivertf_config.update(self.model_cfg["encoder"]["perceiver-tf"]) + self.encoder = PerceiverTFEncoder(perceivertf_config) + elif self.encoder_type == "conformer": + conformer_config = ConformerYMT3Config() + conformer_config.update(self.model_cfg["encoder"]["conformer"]) + self.encoder = ConformerYMT3Encoder(conformer_config) + + if self.decoder_type == "t5": + self.decoder = T5DecoderYMT3(self.model_cfg["decoder"]["t5"], t5_config) + elif self.decoder_type == "multi-t5": + self.decoder = MultiChannelT5Decoder(self.model_cfg["decoder"]["multi-t5"], t5_config) + + # `shift_right` function for decoding + self.shift_right_fn = self.decoder._shift_right + + def setup(self, stage: str) -> None: + # Defining metrics + if self.hparams.eval_vocab is None: + extra_classes_per_dataset = [None] + else: + extra_classes_per_dataset = [ + list(v.keys()) if v is not None else None for v in self.hparams.eval_vocab + ] # e.g. [['Piano'], ['Guitar'], ['Piano'], ['Piano', 'Strings', 'Winds'], None] + + # For direct addition of extra metrics using full metric name + extra_metrics = None + if self.hparams.add_melody_metric_to_singing is True: + extra_metrics = ["melody_rpa_Singing Voice", "melody_rca_Singing Voice", "melody_oa_Singing Voice"] + + # Add pitch class metric + if self.hparams.add_pitch_class_metric is not None: + for sublist in extra_classes_per_dataset: + for name in self.hparams.add_pitch_class_metric: + if sublist is not None and name in sublist: + sublist += [name + "_pc"] + + extra_classes_unique = list( + set(item for sublist in extra_classes_per_dataset if sublist is not None + for item in sublist)) # e.g. ['Strings', 'Winds', 'Guitar', 'Piano'] + dm = self.trainer.datamodule + + # Train/Vaidation-only + if stage == "fit": + self.val_metrics_macro = AMTMetrics(prefix=f'validation/macro_', extra_classes=extra_classes_unique) + self.val_metrics = nn.ModuleList() # val_metric is a list of AMTMetrics objects + for i in range(dm.num_val_dataloaders): + self.val_metrics.append( + AMTMetrics(prefix=f'validation/({dm.get_val_dataset_name(i)})', + extra_classes=extra_classes_per_dataset[i], + error_types=DECODING_ERR_TYPES)) + + # Add pitchshift layer + if self.shared_cfg["AUGMENTATION"]["train_pitch_shift_range"] in [None, [0, 0]]: + self.pitchshift = None + else: + # torchaudio pitchshifter requires a dummy input for initialization in DDP + input_shape = (self.shared_cfg["BSZ"]["train_local"], 1, self.audio_cfg["input_frames"]) + self.pitchshift = PitchShiftLayer( + pshift_range=self.shared_cfg["AUGMENTATION"]["train_pitch_shift_range"], + expected_input_shape=input_shape, + device=self.device) + + # Test-only + elif stage == "test": + # self.test_metrics_macro = AMTMetrics( + # prefix=f'test/macro_', extra_classes=extra_classes_unique) + self.test_metrics = nn.ModuleList() + for i in range(dm.num_test_dataloaders): + self.test_metrics.append( + AMTMetrics(prefix=f'test/({dm.get_test_dataset_name(i)})', + extra_classes=extra_classes_per_dataset[i], + extra_metrics=extra_metrics, + error_types=DECODING_ERR_TYPES)) + + # Test pitch shift layer: debug only + if self.test_pitch_shift_layer is not None: + self.test_pitch_shift_semitone = int(self.test_pitch_shift_layer) + self.pitchshift = PitchShiftLayer( + pshift_range=[self.test_pitch_shift_semitone, self.test_pitch_shift_semitone]) + + def configure_optimizers(self) -> None: + """Configure optimizer and scheduler""" + optimizer, base_lr = get_optimizer(models_dict=self.named_parameters(), + optimizer_name=self.hparams.optimizer_name, + 
base_lr=self.hparams.base_lr, + weight_decay=self.hparams.weight_decay) + + if self.hparams.optimizer_name.lower() == 'adafactor' and self.hparams.base_lr == None: + print("Using AdaFactor with auto learning rate and no scheduler") + return [optimizer] + if self.hparams.optimizer_name.lower() == 'dadaptadam': + print("Using dAdaptAdam with auto learning rate and no scheduler") + return [optimizer] + elif self.hparams.base_lr == None: + print(f"Using default learning rate {base_lr} of {self.hparams.optimizer_name} as base learning rate.") + self.hparams.base_lr = base_lr + + scheduler_cfg = self.shared_cfg["LR_SCHEDULE"] + if self.hparams.max_steps != -1: + # overwrite total_steps + scheduler_cfg["total_steps"] = self.hparams.max_steps + _lr_scheduler = get_lr_scheduler(optimizer, + scheduler_name=self.hparams.scheduler_name, + base_lr=base_lr, + scheduler_cfg=scheduler_cfg) + + lr_scheduler = {'scheduler': _lr_scheduler, 'interval': 'step', 'frequency': 1} + return [optimizer], [lr_scheduler] + + def forward( + self, + x: torch.FloatTensor, + target_tokens: torch.LongTensor, + # task_tokens: Optional[torch.LongTensor] = None, + **kwargs) -> Dict: + """ Forward pass with teacher-forcing for training and validation. + Args: + x: (B, 1, T) waveform with default T=32767 + target_tokens: (B, C, N) tokenized sequence of length N=event_length + task_tokens: (B, C, task_len) tokenized task + + Returns: + { + 'logits': (B, N + task_len + 1, vocab_size) + 'loss': (1, ) + } + + NOTE: all the commented shapes are in the case of original MT3 setup. + """ + x = self.spectrogram(x) # mel-/spectrogram: (b, 256, 512) or (B, T, F) + x = self.pre_encoder(x) # projection to d_model: (B, 256, 512) + + # TODO: task_cond_encoder would not work properly because of 3-d task_tokens + # if task_tokens is not None and task_tokens.numel() > 0 and self.use_task_cond_encoder is True: + # # append task embedding to encoder input + # task_embed = self.embed_tokens(task_tokens) # (B, task_len, 512) + # x = torch.cat([task_embed, x], dim=1) # (B, task_len + 256, 512) + enc_hs = self.encoder(inputs_embeds=x)["last_hidden_state"] # (B, T', D) + enc_hs = self.pre_decoder(enc_hs) # (B, T', D) or (B, K, T, D) + + # if task_tokens is not None and task_tokens.numel() > 0 and self.use_task_cond_decoder is True: + # # append task token to decoder input and output label + # labels = torch.cat([task_tokens, target_tokens], dim=2) # (B, C, task_len + N) + # else: + # labels = target_tokens # (B, C, N) + labels = target_tokens # (B, C, N) + if labels.shape[1] == 1: # for single-channel decoders, e.g. t5. 
+ labels = labels.squeeze(1) # (B, N) + + dec_input_ids = self.shift_right_fn(labels) # t5:(B, N), multi-t5:(B, C, N) + dec_inputs_embeds = self.embed_tokens(dec_input_ids) # t5:(B, N, D), multi-t5:(B, C, N, D) + dec_hs, _ = self.decoder(inputs_embeds=dec_inputs_embeds, encoder_hidden_states=enc_hs, return_dict=False) + + if self.model_cfg["tie_word_embeddings"] is True: + dec_hs = dec_hs * (self.model_cfg["decoder"][self.decoder_type]["d_model"]**-0.5) + + logits = self.lm_head(dec_hs) + + loss = None + labels = labels.masked_fill(labels == 0, value=-100) # ignore pad tokens for loss + loss_fct = CrossEntropyLoss(ignore_index=-100) + loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1)) + return {"logits": logits, "loss": loss} + + def inference(self, + x: torch.FloatTensor, + task_tokens: Optional[torch.LongTensor] = None, + max_token_length: Optional[int] = None, + **kwargs: Any) -> torch.Tensor: + """ Inference from audio batch by cached autoregressive decoding. + Args: + x: (b, 1, t) waveform with t=32767 + task_token: (b, c, task_len) tokenized task. If None, will not append task embeddings (from task_tokens) to input. + max_length: Maximum length of generated sequence. If None, self.max_total_token_length. + **kwargs: https://huggingface.co/docs/transformers/v4.27.2/en/main_classes/text_generation#transformers.GenerationMixin.generate + + Returns: + res_tokens: (b, n) resulting tokenized sequence of variable length < max_length + """ + if self.test_pitch_shift_layer is not None: + x_ps = self.pitchshift(x, self.test_pitch_shift_semitone) + x = x_ps + + # From spectrogram to pre-decoder is the same pipeline as in forward() + x = self.spectrogram(x) # mel-/spectrogram: (b, 256, 512) or (B, T, F) + x = self.pre_encoder(x) # projection to d_model: (B, 256, 512) + if task_tokens is not None and task_tokens.numel() > 0 and self.use_task_cond_encoder is True: + # append task embedding to encoder input + task_embed = self.embed_tokens(task_tokens) # (B, task_len, 512) + x = torch.cat([task_embed, x], dim=1) # (B, task_len + 256, 512) + enc_hs = self.encoder(inputs_embeds=x)["last_hidden_state"] # (B, task_len + 256, 512) + enc_hs = self.pre_decoder(enc_hs) # (B, task_len + 256, 512) + + # Cached-autoregressive decoding with task token (can be None) as prefix + if max_token_length is None: + max_token_length = self.max_total_token_length + + pred_ids = task_cond_dec_generate(decoder=self.decoder, + decoder_type=self.decoder_type, + embed_tokens=self.embed_tokens, + lm_head=self.lm_head, + encoder_hidden_states=enc_hs, + shift_right_fn=self.shift_right_fn, + prefix_ids=task_tokens, + max_length=max_token_length) # (B, task_len + N) or (B, C, task_len + N) + if pred_ids.dim() == 2: + pred_ids = pred_ids.unsqueeze(1) # (B, 1, task_len + N) + + if self.test_pitch_shift_layer is None: + return pred_ids + else: + return pred_ids, x_ps + + def inference_file( + self, + bsz: int, + audio_segments: torch.FloatTensor, # (n_items, 1, segment_len): from a single file + note_token_array: Optional[torch.LongTensor] = None, + task_token_array: Optional[torch.LongTensor] = None, + # subtask_key: Optional[str] = "default" + ) -> Tuple[List[np.ndarray], Optional[torch.Tensor]]: + """ Inference from audio batch by autoregressive decoding: + Args: + bsz: batch size + audio_segments: (n_items, 1, segment_len): segmented audio from a single file + note_token_array: (n_items, max_token_len): Optional. If token_array is None, will not return loss. 
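+ task_token_array: (n_items, task_len): Optional. Task prefix tokens forwarded to inference(); if None or empty, task conditioning is skipped.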
+ subtask_key: (str): If None, not using subtask prefix. By default, using "default" defined in task manager. + """ + # if subtask_key is not None: + # _subtask_token = torch.LongTensor( + # self.task_manager.get_eval_subtask_prefix_dict()[subtask_key]).to(self.device) + + n_items = audio_segments.shape[0] + loss = 0. + pred_token_array_file = [] # each element is (B, C, L) np.ndarray + x_ps_concat = [] + + for i in range(0, n_items, bsz): + if i + bsz > n_items: # last batch can be smaller + x = audio_segments[i:n_items].to(self.device) + # if subtask_key is not None: + # b = n_items - i # bsz for the last batch + # task_tokens = _subtask_token.expand((b, -1)) # (b, task_len) + if note_token_array is not None: + target_tokens = note_token_array[i:n_items].to(self.device) + if task_token_array is not None and task_token_array.numel() > 0: + task_tokens = task_token_array[i:n_items].to(self.device) + else: + task_tokens = None + else: + x = audio_segments[i:i + bsz].to(self.device) # (bsz, 1, segment_len) + # if subtask_key is not None: + # task_tokens = _subtask_token.expand((bsz, -1)) # (bsz, task_len) + if note_token_array is not None: + target_tokens = note_token_array[i:i + bsz].to(self.device) # (bsz, token_len) + if task_token_array is not None and task_token_array.numel() > 0: + task_tokens = task_token_array[i:i + bsz].to(self.device) + else: + task_tokens = None + + # token prediction (fast-autoregressive decoding) + # if subtask_key is not None: + # preds = self.inference(x, task_tokens).detach().cpu().numpy() + # else: + # preds = self.inference(x).detach().cpu().numpy() + + if self.test_pitch_shift_layer is not None: # debug only + preds, x_ps = self.inference(x, task_tokens) + preds = preds.detach().cpu().numpy() + x_ps_concat.append(x_ps.detach().cpu()) + else: + preds = self.inference(x, task_tokens).detach().cpu().numpy() + if len(preds) != len(x): + raise ValueError(f'preds: {len(preds)}, x: {len(x)}') + pred_token_array_file.append(preds) + + # validation loss (by teacher forcing) + if note_token_array is not None: + loss_weight = x.shape[0] / n_items + loss += self(x, target_tokens)['loss'] * loss_weight + # loss += self(x, target_tokens, task_tokens)['loss'] * loss_weight + else: + loss = None + + if self.test_pitch_shift_layer is not None: # debug only + if self.hparams.write_output_dir is not None: + x_ps_concat = torch.cat(x_ps_concat, dim=0) + return pred_token_array_file, loss, x_ps_concat.flatten().unsqueeze(0) + else: + return pred_token_array_file, loss + + def training_step(self, batch, batch_idx) -> torch.Tensor: + # batch: { + # 'dataset1': [Tuple[audio_segments(b, 1, t), tokens(b, max_token_len), ...]] + # 'dataset2': [Tuple[audio_segments(b, 1, t), tokens(b, max_token_len), ...]] + # 'dataset3': ... 
+ # } + audio_segments, note_tokens, pshift_steps = [torch.cat(t, dim=0) for t in zip(*batch.values())] + + if self.pitchshift is not None: + # Pitch shift + n_groups = len(batch) + audio_segments = torch.chunk(audio_segments, n_groups, dim=0) + pshift_steps = torch.chunk(pshift_steps, n_groups, dim=0) + for p in pshift_steps: + assert p.eq(p[0]).all().item() + + audio_segments = torch.cat([self.pitchshift(a, p[0].item()) for a, p in zip(audio_segments, pshift_steps)], + dim=0) + + loss = self(audio_segments, note_tokens)['loss'] + self.log('train_loss', + loss, + on_step=True, + on_epoch=True, + prog_bar=True, + batch_size=note_tokens.shape[0], + sync_dist=True) + # print('lr', self.trainer.optimizers[0].param_groups[0]['lr']) + return loss + + def validation_step(self, batch, batch_idx, dataloader_idx=0) -> Dict: + # File-wise validation + if self.task_manager.num_decoding_channels == 1: + bsz = self.shared_cfg["BSZ"]["validation"] + else: + bsz = self.shared_cfg["BSZ"]["validation"] // self.task_manager.num_decoding_channels * 3 + # audio_segments, notes_dict, note_token_array, task_token_array = batch + audio_segments, notes_dict, note_token_array = batch + task_token_array = None + + # Loop through the tensor in chunks of bsz (=subbsz actually) + n_items = audio_segments.shape[0] + start_secs_file = [32767 * i / 16000 for i in range(n_items)] + with Timer() as t: + pred_token_array_file, loss = self.inference_file(bsz, audio_segments, note_token_array, task_token_array) + """ + notes_dict: # Ground truth notes + { + 'mtrack_id': int, + 'program': List[int], + 'is_drum': bool, + 'duration_sec': float, + 'notes': List[Note], + } + """ + # Process a list of channel-wise token arrays for a file + num_channels = self.task_manager.num_decoding_channels + pred_notes_in_file = [] + n_err_cnt = Counter() + for ch in range(num_channels): + pred_token_array_ch = [arr[:, ch, :] for arr in pred_token_array_file] # (B, L) + zipped_note_events_and_tie, list_events, ne_err_cnt = self.task_manager.detokenize_list_batches( + pred_token_array_ch, start_secs_file, return_events=True) + pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie) + pred_notes_in_file.append(pred_notes_ch) + n_err_cnt += n_err_cnt_ch + pred_notes = mix_notes(pred_notes_in_file) # This is the mixed notes from all channels + + if self.hparams.write_output_dir is not None: + track_info = [notes_dict[k] for k in notes_dict.keys() if k.endswith("_id")][0] + dataset_info = [k for k in notes_dict.keys() if k.endswith('_id')][0][:-3] + # write_model_output_as_npy(zipped_note_events_and_tie, self.hparams.write_output_dir, + # track_info) + write_model_output_as_midi(pred_notes, + self.hparams.write_output_dir, + track_info, + self.midi_output_inverse_vocab, + output_dir_suffix=str(dataset_info) + '_' + + str(self.hparams.eval_subtask_key)) + # generate sample text to display in log table + # pred_events_text = [str([list_events[0][:200]])] + # pred_notes_text = [str([pred_notes[:200]])] + + # this is local GPU metric per file, not global metric in DDP + drum_metric, non_drum_metric, instr_metric = compute_track_metrics( + pred_notes, + notes_dict['notes'], + eval_vocab=self.hparams.eval_vocab[dataloader_idx], + eval_drum_vocab=self.hparams.eval_drum_vocab, + onset_tolerance=self.hparams.onset_tolerance, + add_pitch_class_metric=self.hparams.add_pitch_class_metric) + self.val_metrics[dataloader_idx].bulk_update(drum_metric) + self.val_metrics[dataloader_idx].bulk_update(non_drum_metric) + 
self.val_metrics[dataloader_idx].bulk_update(instr_metric) + self.val_metrics_macro.bulk_update(drum_metric) + self.val_metrics_macro.bulk_update(non_drum_metric) + self.val_metrics_macro.bulk_update(instr_metric) + + # Log sample table: predicted notes and ground truth notes + # if batch_idx in (0, 1) and self.global_rank == 0: + # actual_notes_text = [str([notes_dict['notes'][:200]])] + # actual_tokens = token_array[0, :200].detach().cpu().numpy().tolist() + # actual_events_text = [str(self.tokenizer._decode(actual_tokens))] + # track_info = [notes_dict[k] for k in notes_dict.keys() if k.endswith("_id")] + # self.sample_table.add_data(self.current_epoch, track_info, pred_events_text, + # actual_events_text, pred_notes_text, actual_notes_text) + # self.logger.log_table('Samples', self.sample_table.columns, self.sample_table.data) + + decoding_time_sec = t.elapsed_time() + self.log('val_loss', loss, prog_bar=True, batch_size=n_items, sync_dist=True) + # self.val_metrics[dataloader_idx].bulk_update_errors({'decoding_time': decoding_time_sec}) + + def on_validation_epoch_end(self) -> None: + for val_metrics in self.val_metrics: + self.log_dict(val_metrics.bulk_compute(), sync_dist=True) + val_metrics.bulk_reset() + self.log_dict(self.val_metrics_macro.bulk_compute(), sync_dist=True) + self.val_metrics_macro.bulk_reset() + + def test_step(self, batch, batch_idx, dataloader_idx=0) -> Dict: + # File-wise evaluation + if self.task_manager.num_decoding_channels == 1: + bsz = self.shared_cfg["BSZ"]["validation"] + else: + bsz = self.shared_cfg["BSZ"]["validation"] // self.task_manager.num_decoding_channels * 3 + # audio_segments, notes_dict, note_token_array, task_token_array = batch + audio_segments, notes_dict, note_token_array = batch + task_token_array = None + + # Test pitch shift layer: debug only + if self.test_pitch_shift_layer is not None and self.test_pitch_shift_semitone != 0: + for n in notes_dict['notes']: + if n.is_drum == False: + n.pitch = n.pitch + self.test_pitch_shift_semitone + + # Loop through the tensor in chunks of bsz (=subbsz actually) + n_items = audio_segments.shape[0] + start_secs_file = [32767 * i / 16000 for i in range(n_items)] + + if self.test_pitch_shift_layer is not None and self.hparams.write_output_dir is not None: + pred_token_array_file, loss, x_ps = self.inference_file(bsz, audio_segments, None, None) + else: + pred_token_array_file, loss = self.inference_file(bsz, audio_segments, None, None) + if len(pred_token_array_file) > 0: + + # Process a list of channel-wise token arrays for a file + num_channels = self.task_manager.num_decoding_channels + pred_notes_in_file = [] + n_err_cnt = Counter() + for ch in range(num_channels): + pred_token_array_ch = [arr[:, ch, :] for arr in pred_token_array_file] # (B, L) + zipped_note_events_and_tie, list_events, ne_err_cnt = self.task_manager.detokenize_list_batches( + pred_token_array_ch, start_secs_file, return_events=True) + pred_notes_ch, n_err_cnt_ch = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie) + pred_notes_in_file.append(pred_notes_ch) + n_err_cnt += n_err_cnt_ch + pred_notes = mix_notes(pred_notes_in_file) # This is the mixed notes from all channels + + if self.test_pitch_shift_layer is not None and self.hparams.write_output_dir is not None: + # debug only + wav_output_dir = os.path.join(self.hparams.write_output_dir, f"model_output_{dataset_info}") + os.makedirs(wav_output_dir, exist_ok=True) + wav_output_file = os.path.join(wav_output_dir, 
f"{track_info}_ps_{self.test_pitch_shift_semitone}.wav") + torchaudio.save(wav_output_file, x_ps.squeeze(1), 16000, bits_per_sample=16) + + drum_metric, non_drum_metric, instr_metric = compute_track_metrics( + pred_notes, + notes_dict['notes'], + eval_vocab=self.hparams.eval_vocab[dataloader_idx], + eval_drum_vocab=self.hparams.eval_drum_vocab, + onset_tolerance=self.hparams.onset_tolerance, + add_pitch_class_metric=self.hparams.add_pitch_class_metric, + add_melody_metric=['Singing Voice'] if self.hparams.add_melody_metric_to_singing else None, + add_frame_metric=True, + add_micro_metric=True, + add_multi_f_metric=True) + + if self.hparams.write_output_dir is not None and self.global_rank == 0: + # write model output to file + track_info = [notes_dict[k] for k in notes_dict.keys() if k.endswith("_id")][0] + dataset_info = [k for k in notes_dict.keys() if k.endswith('_id')][0][:-3] + f_score = f"OnF{non_drum_metric['onset_f']:.2f}_MulF{instr_metric['multi_f']:.2f}" + write_model_output_as_midi(pred_notes, + self.hparams.write_output_dir, + track_info, + self.midi_output_inverse_vocab, + output_dir_suffix=str(dataset_info) + '_' + + str(self.hparams.eval_subtask_key) + '_' + f_score) + write_err_cnt_as_json(track_info, self.hparams.write_output_dir, + str(dataset_info) + '_' + str(self.hparams.eval_subtask_key) + '_' + f_score, + n_err_cnt, ne_err_cnt) + + # Test with optimal octave shift + if self.hparams.test_optimal_octave_shift: + track_info = [notes_dict[k] for k in notes_dict.keys() if k.endswith("_id")][0] + dataset_info = [k for k in notes_dict.keys() if k.endswith('_id')][0][:-3] + score = [instr_metric['onset_f_Bass']] + ref_notes_plus = [] + ref_notes_minus = [] + for note in notes_dict['notes']: + if note.is_drum == True: + ref_notes_plus.append(note) + ref_notes_minus.append(note) + else: + ref_notes_plus.append( + Note(is_drum=note.is_drum, + program=note.program, + onset=note.onset, + offset=note.offset, + pitch=note.pitch + 12, + velocity=note.velocity)) + ref_notes_minus.append( + Note(is_drum=note.is_drum, + program=note.program, + onset=note.onset, + offset=note.offset, + pitch=note.pitch - 12, + velocity=note.velocity)) + + drum_metric_plus, non_drum_metric_plus, instr_metric_plus = compute_track_metrics( + pred_notes, + ref_notes_plus, + eval_vocab=self.hparams.eval_vocab[dataloader_idx], + eval_drum_vocab=self.hparams.eval_drum_vocab, + onset_tolerance=self.hparams.onset_tolerance, + add_pitch_class_metric=self.hparams.add_pitch_class_metric) + drum_metric_minus, non_drum_metric_minus, instr_metric_minus = compute_track_metrics( + ref_notes_minus, + notes_dict['notes'], + eval_vocab=self.hparams.eval_vocab[dataloader_idx], + eval_drum_vocab=self.hparams.eval_drum_vocab, + onset_tolerance=self.hparams.onset_tolerance, + add_pitch_class_metric=self.hparams.add_pitch_class_metric) + + score.append(instr_metric_plus['onset_f_Bass']) + score.append(instr_metric_minus['onset_f_Bass']) + max_index = score.index(max(score)) + if max_index == 0: + print(f"ZERO: {track_info}, z/p/m: {score[0]:.2f}/{score[1]:.2f}/{score[2]:.2f}") + elif max_index == 1: + # plus + instr_metric['onset_f_Bass'] = instr_metric_plus['onset_f_Bass'] + print(f"PLUS: {track_info}, z/p/m: {score[0]:.2f}/{score[1]:.2f}/{score[2]:.2f}") + write_model_output_as_midi(ref_notes_plus, + self.hparams.write_output_dir, + track_info + '_ref_octave_plus', + self.midi_output_inverse_vocab, + output_dir_suffix=str(dataset_info) + '_' + + str(self.hparams.eval_subtask_key)) + else: + # minus + 
                    instr_metric['onset_f_Bass'] = instr_metric_minus['onset_f_Bass']
+                    print(f"MINUS: {track_info}, z/p/m: {score[0]:.2f}/{score[1]:.2f}/{score[2]:.2f}")
+                    write_model_output_as_midi(ref_notes_minus,
+                                               self.hparams.write_output_dir,
+                                               track_info + '_ref_octave_minus',
+                                               self.midi_output_inverse_vocab,
+                                               output_dir_suffix=str(dataset_info) + '_' +
+                                               str(self.hparams.eval_subtask_key))
+
+            self.test_metrics[dataloader_idx].bulk_update(drum_metric)
+            self.test_metrics[dataloader_idx].bulk_update(non_drum_metric)
+            self.test_metrics[dataloader_idx].bulk_update(instr_metric)
+            # self.test_metrics_macro.bulk_update(drum_metric)
+            # self.test_metrics_macro.bulk_update(non_drum_metric)
+            # self.test_metrics_macro.bulk_update(instr_metric)
+
+    def on_test_epoch_end(self) -> None:
+        # all_gather is done seamlessly by torchmetrics
+        for test_metrics in self.test_metrics:
+            self.log_dict(test_metrics.bulk_compute(), sync_dist=True)
+            test_metrics.bulk_reset()
+        # self.log_dict(self.test_metrics_macro.bulk_compute(), sync_dist=True)
+        # self.test_metrics_macro.bulk_reset()
+
+
+def test_case_forward_mt3():
+    import torch
+    from config.config import audio_cfg, model_cfg, shared_cfg
+    from model.ymt3 import YourMT3
+    model = YourMT3()
+    model.eval()
+    x = torch.randn(2, 1, 32767)
+    labels = torch.randint(0, 596, (2, 1, 1024), requires_grad=False)  # (B, C=1, T)
+    task_tokens = torch.LongTensor([])
+    output = model.forward(x, labels, task_tokens=task_tokens)  # passed as a keyword; forward() currently ignores task_tokens
+    logits, loss = output['logits'], output['loss']
+    assert logits.shape == (2, 1024, 596)  # (B, N, vocab_size)
+
+
+def test_case_inference_mt3():
+    import torch
+    from config.config import audio_cfg, model_cfg, shared_cfg
+    from model.ymt3 import YourMT3
+    model_cfg["num_max_positions"] = 1024 + 3 + 1
+    model = YourMT3(model_cfg=model_cfg)
+    model.eval()
+    x = torch.randn(2, 1, 32767)
+    task_tokens = torch.randint(0, 596, (2, 3), requires_grad=False)
+    pred_ids = model.inference(x, task_tokens, max_token_length=10)  # (2, 3, 9) (B, C, L-task_len)
+    # TODO: need to check the length of pred_ids when task_tokens is not None
+
+
+def test_case_forward_enc_perceiver_tf_dec_t5():
+    import torch
+    from model.ymt3 import YourMT3
+    from config.config import audio_cfg, model_cfg, shared_cfg
+    model_cfg["encoder_type"] = "perceiver-tf"
+    audio_cfg["codec"] = "spec"
+    audio_cfg["hop_length"] = 300
+
+    model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg)
+    model.eval()
+
+    x = torch.randn(2, 1, 32767)
+    labels = torch.randint(0, 596, (2, 1, 1024), requires_grad=False)
+
+    # forward
+    output = model.forward(x, labels)
+    logits, loss = output['logits'], output['loss']  # logits: (2, 1024, 596) (B, N, vocab_size)
+
+    # inference
+    pred_ids = model.inference(x, None, max_token_length=3)  # (2, 1, 3) (B, C, L)
+
+
+def test_case_forward_enc_conformer_dec_t5():
+    import torch
+    from model.ymt3 import YourMT3
+    from config.config import audio_cfg, model_cfg, shared_cfg
+    model_cfg["encoder_type"] = "conformer"
+    audio_cfg["codec"] = "melspec"
+    audio_cfg["hop_length"] = 128
+    model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg)
+    model.eval()
+
+    x = torch.randn(2, 1, 32767)
+    labels = torch.randint(0, 596, (2, 1024), requires_grad=False)
+
+    # forward
+    output = model.forward(x, labels)
+    logits, loss = output['logits'], output['loss']  # logits: (2, 1024, 596) (B, N, vocab_size)
+
+    # inference
+    pred_ids = model.inference(x, None, 20)  # (2, 1, 20) (B, C, L)
+
+
+def test_case_enc_perceiver_tf_dec_multi_t5():
+    import torch
+    from model.ymt3 import YourMT3
+    from config.config import 
audio_cfg, model_cfg, shared_cfg + model_cfg["encoder_type"] = "perceiver-tf" + model_cfg["decoder_type"] = "multi-t5" + model_cfg["encoder"]["perceiver-tf"]["attention_to_channel"] = True + model_cfg["encoder"]["perceiver-tf"]["num_latents"] = 26 + audio_cfg["codec"] = "spec" + audio_cfg["hop_length"] = 300 + model = YourMT3(audio_cfg=audio_cfg, model_cfg=model_cfg) + model.eval() + + x = torch.randn(2, 1, 32767) + labels = torch.randint(0, 596, (2, 13, 200), requires_grad=False) # (B, C, T) + + # x = model.spectrogram(x) + # x = model.pre_encoder(x) # (2, 110, 128, 128) (B, T, C, D) + # enc_hs = model.encoder(inputs_embeds=x)["last_hidden_state"] # (2, 110, 128, 128) (B, T, C, D) + # enc_hs = model.pre_decoder(enc_hs) # (2, 13, 110, 512) (B, C, T, D) + + # dec_input_ids = model.shift_right_fn(labels) # (2, 13, 200) (B, C, T) + # dec_inputs_embeds = model.embed_tokens(dec_input_ids) # (2, 13, 200, 512) (B, C, T, D) + # dec_hs, _ = model.decoder( + # inputs_embeds=dec_inputs_embeds, encoder_hidden_states=enc_hs, return_dict=False) + # logits = model.lm_head(dec_hs) # (2, 13, 200, 596) (B, C, T, vocab_size) + + # forward + x = torch.randn(2, 1, 32767) + labels = torch.randint(0, 596, (2, 13, 200), requires_grad=False) # (B, C, T) + output = model.forward(x, labels) + logits, loss = output['logits'], output['loss'] # (2, 13, 200, 596) (B, C, T, vocab_size) + + # inference + model.max_total_token_length = 123 # to save time.. + pred_ids = model.inference(x, None) # (2, 13, 123) (B, C, L) diff --git a/amt/src/pytest.ini b/amt/src/pytest.ini new file mode 100644 index 0000000000000000000000000000000000000000..03f586d41650da9718d8c6bdc6ada4a068cb96a8 --- /dev/null +++ b/amt/src/pytest.ini @@ -0,0 +1,2 @@ +[pytest] +pythonpath = . \ No newline at end of file diff --git a/amt/src/requirements.txt b/amt/src/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..039e7116c9888e9d31a17c255582b772e7ce8103 --- /dev/null +++ b/amt/src/requirements.txt @@ -0,0 +1,14 @@ +mirdata +mido +git+https://github.com/craffel/mir_eval.git +git+https://github.com/katsura-jp/pytorch-cosine-annealing-with-warmup.git +matplotlib +lightning>=2.2.1 +pytest-timeout +pytest +deprecated +librosa +einops +transformers +wandb +smart-open diff --git a/amt/src/test.py b/amt/src/test.py new file mode 100644 index 0000000000000000000000000000000000000000..869456ab5be65a8834651683edec4d18551b3af3 --- /dev/null +++ b/amt/src/test.py @@ -0,0 +1,183 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +""" test.py """ +import os +import pprint +import argparse +import torch + +from utils.data_modules import AMTDataModule +from utils.task_manager import TaskManager +from model.init_train import initialize_trainer, update_config +from model.ymt3 import YourMT3 +from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg +from config.vocabulary import drum_vocab_presets +from utils.utils import str2bool + +# yapf: disable +parser = argparse.ArgumentParser(description="YourMT3") +# General +parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. 
The "@" symbol can be used to load a specific checkpoint.') +parser.add_argument('-p', '--project', type=str, default='ymt3', help='project name') +parser.add_argument('-d', '--data-preset', type=str, default='musicnet_thickstun_ext_em', help='dataset preset (default=musicnet_thickstun_ext_em). See config/data.py for more options.') +# Audio configurations +parser.add_argument('-ac', '--audio-codec', type=str, default=None, help='audio codec (default=None). {"spec", "melspec"}. If None, default value defined in config.py will be used.') +parser.add_argument('-hop', '--hop-length', type=int, default=None, help='hop length in frames (default=None). {128, 300} 128 for MT3, 300 for PerceiverTFIf None, default value defined in config.py will be used.') +parser.add_argument('-nmel', '--n-mels', type=int, default=None, help='number of mel bins (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-if', '--input-frames', type=int, default=None, help='number of audio frames for input segment (default=None). If None, default value defined in config.py will be used.') +# Model configurations +parser.add_argument('-sqr', '--sca-use-query-residual', type=str2bool, default=None, help='sca use query residual flag. Default follows config.py') +parser.add_argument('-enc', '--encoder-type', type=str, default=None, help="Encoder type. 't5' or 'perceiver-tf' or 'conformer'. Default is 't5', following config.py.") +parser.add_argument('-dec', '--decoder-type', type=str, default=None, help="Decoder type. 't5' or 'multi-t5'. Default is 't5', following config.py.") +parser.add_argument('-preenc', '--pre-encoder-type', type=str, default='default', help="Pre-encoder type. None or 'conv' or 'default'. By default, t5_enc:None, perceiver_tf_enc:conv, conformer:None") +parser.add_argument('-predec', '--pre-decoder-type', type=str, default='default', help="Pre-decoder type. {None, 'linear', 'conv1', 'mlp', 'group_linear'} or 'default'. Default is {'t5': None, 'perceiver-tf': 'linear', 'conformer': None}.") +parser.add_argument('-cout', '--conv-out-channels', type=int, default=None, help='Number of filters for pre-encoder conv layer. Default follows "model_cfg" of config.py.') +parser.add_argument('-tenc', '--task-cond-encoder', type=str2bool, default=True, help='task conditional encoder (default=True). True or False') +parser.add_argument('-tdec', '--task-cond-decoder', type=str2bool, default=True, help='task conditional decoder (default=True). True or False') +parser.add_argument('-df', '--d-feat', type=int, default=None, help='Audio feature will be projected to this dimension for Q,K,V of T5 or K,V of Perceiver (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-pt', '--pretrained', type=str2bool, default=False, help='pretrained T5(default=False). True or False') +parser.add_argument('-b', '--base-name', type=str, default="google/t5-v1_1-small", help='base model name (default="google/t5-v1_1-small")') +parser.add_argument('-epe', '--encoder-position-encoding-type', type=str, default='default', help="Positional encoding type of encoder. By default, pre-defined PE for T5 or Perceiver-TF encoder in config.py. For T5: {'sinusoidal', 'trainable'}, conformer: {'rotary', 'trainable'}, Perceiver-TF: {'trainable', 'rope', 'alibi', 'alibit', 'None', '0', 'none', 'tkd', 'td', 'tk', 'kdt'}.") +parser.add_argument('-dpe', '--decoder-position-encoding-type', type=str, default='default', help="Positional encoding type of decoder. 
By default, pre-defined PE for T5 in config.py. {'sinusoidal', 'trainable'}.") +parser.add_argument('-twe', '--tie-word-embedding', type=str2bool, default=None, help='tie word embedding (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-el', '--event-length', type=int, default=None, help='event length (default=None). If None, default value defined in model cfg of config.py will be used.') +# Perceiver-TF configurations +parser.add_argument('-dl', '--d-latent', type=int, default=None, help='Latent dimension of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-nl', '--num-latents', type=int, default=None, help='Number of latents of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-dpm', '--perceiver-tf-d-model', type=int, default=None, help='Perceiver-TF d_model (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-npb', '--num-perceiver-tf-blocks', type=int, default=None, help='Number of blocks of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py.') +parser.add_argument('-npl', '--num-perceiver-tf-local-transformers-per-block', type=int, default=None, help='Number of local layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-npt', '--num-perceiver-tf-temporal-transformers-per-block', type=int, default=None, help='Number of temporal layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-atc', '--attention-to-channel', type=str2bool, default=None, help='Attention to channel flag of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-ln', '--layer-norm-type', type=str, default=None, help='Layer normalization type (default=None). {"layer_norm", "rms_norm"}. If None, default value defined in config.py will be used.') +parser.add_argument('-ff', '--ff-layer-type', type=str, default=None, help='Feed forward layer type (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.') +parser.add_argument('-wf', '--ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-nmoe', '--moe-num-experts', type=int, default=None, help='Number of experts for MoE (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-kmoe', '--moe-topk', type=int, default=None, help='Top-k for MoE (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-act', '--hidden-act', type=str, default=None, help='Hidden activation function (default=None). {"gelu", "silu", "relu", "tanh"}. If None, default value defined in config.py will be used.') +parser.add_argument('-rt', '--rotary-type', type=str, default=None, help='Rotary embedding type expressed in three letters. e.g. ppl: "pixel" for SCA and latents, "lang" for temporal transformer. 
If None, use config.') +parser.add_argument('-rk', '--rope-apply-to-keys', type=str2bool, default=None, help='Apply rope to keys (default=None). If None, use config.') +parser.add_argument('-rp', '--rope-partial-pe', type=str2bool, default=None, help='Whether to apply RoPE to partial positions (default=None). If None, use config.') +# Decoder configurations +parser.add_argument('-dff', '--decoder-ff-layer-type', type=str, default=None, help='Feed forward layer type of decoder (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.') +parser.add_argument('-dwf', '--decoder-ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for decoder MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.') +# Task and Evaluation configurations +parser.add_argument('-tk', '--task', type=str, default='mt3_full_plus', help='tokenizer type (default=mt3_full_plus). See config/task.py for more options.') +parser.add_argument('-epv', '--eval-program-vocab', type=str, default=None, help='evaluation vocabulary (default=None). If None, default vocabulary of the data preset will be used.') +parser.add_argument('-edv', '--eval-drum-vocab', type=str, default=None, help='evaluation vocabulary for drum (default=None). If None, default vocabulary of the data preset will be used.') +parser.add_argument('-etk', '--eval-subtask-key', type=str, default='default', help='evaluation subtask key (default=default). See config/task.py for more options.') +parser.add_argument('-t', '--onset-tolerance', type=float, default=0.05, help='onset tolerance (default=0.05).') +parser.add_argument('-os', '--test-octave-shift', type=str2bool, default=False, help='test optimal octave shift (default=False). True or False') +parser.add_argument('-w', '--write-model-output', type=str2bool, default=False, help='write model test output to file (default=False). True or False') +# Trainer configurations +parser.add_argument('-pr','--precision', type=str, default="bf16-mixed", help='precision (default="bf16-mixed") {32, 16, bf16, bf16-mixed}') +parser.add_argument('-st', '--strategy', type=str, default='auto', help='strategy (default=auto). auto or deepspeed or ddp') +parser.add_argument('-n', '--num-nodes', type=int, default=1, help='number of nodes (default=1)') +parser.add_argument('-g', '--num-gpus', type=str, default='auto', help='number of gpus (default="auto")') +parser.add_argument('-wb', '--wandb-mode', type=str, default=None, help='wandb mode for logging (default=None). "disabled" or "online" or "offline". If None, default value defined in config.py will be used.') +# Debug +parser.add_argument('-debug', '--debug-mode', type=str2bool, default=False, help='debug mode (default=False). True or False') +parser.add_argument('-tps', '--test-pitch-shift', type=int, default=None, help='use pitch shift when testing. debug-purpose only. (default=None). 
semitone in int.') +args = parser.parse_args() +# yapf: enable +if torch.__version__ >= "1.13": + torch.set_float32_matmul_precision("high") +args.epochs = None + +# Initialize trainer +trainer, wandb_logger, dir_info, shared_cfg = initialize_trainer(args, stage='test') + +# Update config with args, including augmentation settings +shared_cfg, audio_cfg, model_cfg = update_config(args, shared_cfg, stage='test') + + +def main(): + # Data preset + if args.data_preset in data_preset_single_cfg: + # convert single preset into multi preset format + data_preset = { + "presets": [args.data_preset], + "eval_vocab": data_preset_single_cfg[args.data_preset]["eval_vocab"], + } + for k in data_preset_single_cfg[args.data_preset].keys(): + if k in ["eval_drum_vocab", "add_pitch_class_metric"]: + data_preset[k] = data_preset_single_cfg[args.data_preset][k] + elif args.data_preset in data_preset_multi_cfg: + data_preset = data_preset_multi_cfg[args.data_preset] + else: + raise ValueError("Invalid data preset") + eval_drum_vocab = data_preset.get("eval_drum_vocab", None) + + if args.eval_drum_vocab != None: # override eval_drum_vocab + eval_drum_vocab = drum_vocab_presets[args.eval_drum_vocab] + + # Task manager + tm = TaskManager(task_name=args.task, + max_shift_steps=int(shared_cfg["TOKENIZER"]["max_shift_steps"]), + debug_mode=args.debug_mode) + print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}") + + results = [] + for i, preset in enumerate(data_preset["presets"]): + # sdp: unpacking multi preset as a list of single presets + sdp = { + "presets": [preset], + "eval_vocab": [data_preset["eval_vocab"][i]], + "eval_drum_vocab": eval_drum_vocab, + } + for k in data_preset.keys(): + if k not in ["presets", "eval_vocab"]: + sdp[k] = data_preset[k] + + dm = AMTDataModule(data_preset_multi=sdp, task_manager=tm, audio_cfg=audio_cfg) + + model = YourMT3( + audio_cfg=audio_cfg, + model_cfg=model_cfg, + shared_cfg=shared_cfg, + optimizer=None, + task_manager=tm, # tokenizer is a member of task_manager + eval_subtask_key=args.eval_subtask_key, + eval_vocab=args.eval_program_vocab if args.eval_program_vocab != None else sdp["eval_vocab"], + eval_drum_vocab=sdp["eval_drum_vocab"], + write_output_dir=dir_info["lightning_dir"] if args.write_model_output or args.test_octave_shift else None, + onset_tolerance=float(args.onset_tolerance), + add_pitch_class_metric=sdp.get("add_pitch_class_metric", None), + test_optimal_octave_shift=args.test_octave_shift, + test_pitch_shift_layer=args.test_pitch_shift) + + # load checkpoint & drop pitchshift from state_dict + checkpoint = torch.load(dir_info["last_ckpt_path"]) + state_dict = checkpoint['state_dict'] + new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k} + model.load_state_dict(new_state_dict, strict=False) + # if args.test_pitch_shift is None: + # new_state_dict = {k: v for k, v in state_dict.items() if 'pitchshift' not in k} + # model.load_state_dict(new_state_dict, strict=False) + # else: + # model.load_state_dict(state_dict, strict=False) + + results.append("-----------------------------------------------------------------") + results.append(sdp) + results.append(trainer.test(model, datamodule=dm)) + # TODO: directly load checkpoint including hyperparmeters https://lightning.ai/docs/pytorch/1.6.2/common/hyperparameters.html + + # save result + pp = pprint.PrettyPrinter(indent=4) + results_str = pp.pformat(results) + result_file = os.path.join(dir_info["lightning_dir"], + 
f"result_{args.task}_{args.eval_subtask_key}_{args.data_preset}.json") + with open(result_file, 'w') as f: + f.write(results_str) + print(f"Result is saved to {result_file}") + + +if __name__ == "__main__": + main() diff --git a/amt/src/tests/.DS_Store b/amt/src/tests/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..b4183d2deded642a50af6919af606d2574b28c88 Binary files /dev/null and b/amt/src/tests/.DS_Store differ diff --git a/amt/src/tests/assert_fns.py b/amt/src/tests/assert_fns.py new file mode 100644 index 0000000000000000000000000000000000000000..9e374bcd9ccddcb9d0cb279a21fefeca73390cad --- /dev/null +++ b/amt/src/tests/assert_fns.py @@ -0,0 +1,55 @@ +""" assert_fn.py """ +import numpy as np + + +def assert_notes_almost_equal(actual_notes, predicted_notes, delta=5e-3): + """ + Asserts that the given lists of Note instances are equal up to a small + floating-point tolerance, similar to `assertAlmostEqual` of `unittest`. + Tolerance is 5e-3 by default, which is 5 ms for 100 ticks-per-second. + """ + assert len(actual_notes) == len(predicted_notes) + for actual_note, predicted_note in zip(actual_notes, predicted_notes): + assert abs(actual_note.onset - predicted_note.onset) < delta + assert abs(actual_note.offset - predicted_note.offset) < delta + assert actual_note.pitch == predicted_note.pitch + if actual_note.is_drum is False and predicted_note.is_drum is False: + assert actual_note.program == predicted_note.program + assert actual_note.is_drum == predicted_note.is_drum + assert actual_note.velocity == predicted_note.velocity + + +def assert_note_events_almost_equal(actual_note_events, + predicted_note_events, + ignore_time=False, + ignore_activity=True, + delta=5.1e-3): + """ + Asserts that the given lists of Note instances are equal up to a small + floating-point tolerance, similar to `assertAlmostEqual` of `unittest`. + Tolerance is 5e-3 by default, which is 5 ms for 100 ticks-per-second. + + If `ignore_time` is True, then the time field is ignored. (useful for + comparing tie note events, default is False) + + If `ignore_activity` is True, then the activity field is ignored (default + is True). + """ + assert len(actual_note_events) == len(predicted_note_events) + for j, (actual_note_event, + predicted_note_event) in enumerate(zip(actual_note_events, predicted_note_events)): + if ignore_time is False: + assert abs(actual_note_event.time - predicted_note_event.time) <= delta + assert actual_note_event.is_drum == predicted_note_event.is_drum + if actual_note_event.is_drum is False and predicted_note_event.is_drum is False: + assert actual_note_event.program == predicted_note_event.program + assert actual_note_event.pitch == predicted_note_event.pitch + assert actual_note_event.velocity == predicted_note_event.velocity + if ignore_activity is False: + assert actual_note_event.activity == predicted_note_event.activity + + +def assert_track_metrics_score1(metrics) -> None: + for k, v in metrics.items(): + if np.isnan(v) is False: + assert v == 1.0 diff --git a/amt/src/tests/audio_test.py b/amt/src/tests/audio_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1d54319436aaa3727b657cdca930dca31f348ba5 --- /dev/null +++ b/amt/src/tests/audio_test.py @@ -0,0 +1,144 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""audio_test.py""" +import unittest +import os +import numpy as np +import wave +import tempfile +from utils.audio import load_audio_file +from utils.audio import get_audio_file_info +from utils.audio import slice_padded_array +from utils.audio import slice_padded_array_for_subbatch +from utils.audio import write_wav_file + + +class TestLoadAudioFile(unittest.TestCase): + + def create_temp_wav_file(self, duration: float, fs: int = 16000) -> str: + n_samples = int(duration * fs) + temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False) + temp_filename = temp_file.name + + data = np.random.randint(-2**15, 2**15, n_samples, dtype=np.int16) + + with wave.open(temp_filename, 'wb') as f: + f.setnchannels(1) + f.setsampwidth(2) + f.setframerate(fs) + f.writeframes(data.tobytes()) + + return temp_filename + + def test_load_audio_file(self): + duration = 3.0 + fs = 16000 + temp_filename = self.create_temp_wav_file(duration, fs) + + # Test load entire file + audio_data = load_audio_file(temp_filename, dtype=np.int16) + file_fs, n_frames, n_channels = get_audio_file_info(temp_filename) + + self.assertEqual(len(audio_data), n_frames) + self.assertEqual(file_fs, fs) + self.assertEqual(n_channels, 1) + + # Test load specific segment + seg_start_sec = 1.0 + seg_length_sec = 1.0 + audio_data = load_audio_file(temp_filename, seg_start_sec, seg_length_sec, dtype=np.int16) + + self.assertEqual(len(audio_data), int(seg_length_sec * fs)) + + # Test unsupported file extension + with self.assertRaises(NotImplementedError): + load_audio_file("unsupported.xyz") + + +class TestSliceArray(unittest.TestCase): + + def setUp(self): + self.x = np.random.randint(0, 10, size=(1, 10000)) + + def test_without_padding(self): + sliced_x = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=False) + self.assertEqual(sliced_x.shape, (199, 100)) + + def test_with_padding(self): + sliced_x = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=True) + self.assertEqual(sliced_x.shape, (199, 100)) + + def test_content(self): + sliced_x = slice_padded_array(self.x, slice_length=100, slice_hop=50, pad=True) + for i in range(sliced_x.shape[0] - 1): + np.testing.assert_array_equal(sliced_x[i, :], self.x[:, i * 50:i * 50 + 100].flatten()) + # Test the last slice separately to account for potential padding + last_slice = sliced_x[-1, :] + last_slice_no_padding = self.x[:, -100:].flatten() + np.testing.assert_array_equal(last_slice[:len(last_slice_no_padding)], last_slice_no_padding) + + +class TestSlicePadForSubbatch(unittest.TestCase): + + def test_slice_padded_array_for_subbatch(self): + input_array = np.random.randn(6, 10) + slice_length = 4 + slice_hop = 2 + pad = True + sub_batch_size = 4 + + expected_output_shape = (4, 4) + + # Call the slice_pad_for_subbatch function + result = slice_padded_array_for_subbatch(input_array, slice_length, slice_hop, pad, sub_batch_size) + + # Check if the output shape is correct + self.assertEqual(result.shape, expected_output_shape) + + # Check if the number of slices is divisible by sub_batch_size + self.assertEqual(result.shape[0] % sub_batch_size, 0) + + +class TestWriteWavFile(unittest.TestCase): + + def test_write_wav_file_z(self): + # Generate some test audio data + samplerate = 16000 + duration = 1 # 1 second + t = np.linspace(0, duration, int(samplerate * duration), endpoint=False) + x = np.sin(2 * np.pi * 440 * t) + 
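+ # the test signal is a 1-second 440 Hz sine in [-1.0, 1.0]; the assertions below expect it to be stored as 16-bit mono PCM at 16 kHz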
+ # Write the test audio data to a WAV file + filename = "extras/test.wav" + write_wav_file(filename, x, samplerate) + + # Read the written WAV file and check its contents + with wave.open(filename, "rb") as wav_file: + # Check the WAV file parameters + self.assertEqual(wav_file.getnchannels(), 1) + self.assertEqual(wav_file.getsampwidth(), 2) + self.assertEqual(wav_file.getframerate(), samplerate) + self.assertEqual(wav_file.getnframes(), len(x)) + + # Read the audio samples from the WAV file + data = wav_file.readframes(len(x)) + + # Convert the audio sample byte string to a NumPy array and normalize it to the range [-1, 1] + x_read = np.frombuffer(data, dtype=np.int16) / 32767.0 + + # Check that the audio samples read from the WAV file are equal to the original audio samples + np.testing.assert_allclose(x_read, x, atol=1e-4) + + # Delete the written WAV file + os.remove(filename) + + +if __name__ == '__main__': + unittest.main() diff --git a/amt/src/tests/event2note_test.py b/amt/src/tests/event2note_test.py new file mode 100644 index 0000000000000000000000000000000000000000..306d3de80743dab6f5ba2dd81149a62a60f34dca --- /dev/null +++ b/amt/src/tests/event2note_test.py @@ -0,0 +1,187 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""event2midi_test.py: + +This file contains tests for the following classes: +• event2note_event +• note_event2note + +""" +import unittest +import pytest +from numpy import random +from assert_fns import assert_notes_almost_equal +from utils.note_event_dataclasses import Event, Note, NoteEvent +from utils.event2note import event2note_event +from utils.event2note import note_event2note + + +# yapf: disable +class TestEvent2NoteEvent(unittest.TestCase): + def test_event2note_event(self): + events = [ + Event('program', 33), Event('pitch', 60), + Event('program', 52), Event('pitch', 40), + Event('tie', 0), + Event('shift', 20), Event('velocity', 1), Event('drum', 36), + Event('shift', 150), Event('program', 33), Event('velocity', 0), Event('pitch', 60), + Event('shift', 160), Event('velocity', 1), Event('pitch', 62), Event('program', 100), + Event('pitch', 77), + Event('shift', 200), Event('velocity', 0), Event('pitch', 77), + Event('shift', 250), Event('velocity', 1), Event('drum', 38), + Event('shift', 300), Event('velocity', 0), Event('program', 33), Event('pitch', 62) + ] + + note_events, tie_note_events, last_activity, err_cnt = event2note_event(events, start_time=0, sort=False, tps=100) + self.assertEqual(len(err_cnt), 0) + expected_note_events = [NoteEvent(True, 128, 0.2, 1, 36), + NoteEvent(False, 33, 1.5, 0, 60), + NoteEvent(False, 33, 1.6, 1, 62), + NoteEvent(False, 100, 1.6, 1, 77), + NoteEvent(False, 100, 2.0, 0, 77), + NoteEvent(True, 128, 2.5, 1, 38), + NoteEvent(False, 33, 3.0, 0, 62)] + expected_tie_note_events = [NoteEvent(False, 33, None, 1, 60), + NoteEvent(False, 52, None, 1, 40)] + expected_last_activity = [(52, 40)] + self.assertSequenceEqual(note_events, expected_note_events) + self.assertSequenceEqual(tie_note_events, expected_tie_note_events) + self.assertSequenceEqual(last_activity, expected_last_activity) + +class TestEvent2NoteEventInvalidInputWarn(unittest.TestCase): + def test_event2note_event_with_invalid_shift_value(self): + events = [Event('tie', 0), 
Event('shift', 0), Event('shift', 1050)] # shift: 0 <= value <= 1000 + _, _, _, err_cnt = event2note_event(events, start_time=0, sort=True, tps=100) + self.assertEqual(err_cnt['Err/Shift out of range'], 2) + + def test_event2note_event_with_invalid_pitch_event(self): + events = [Event('pitch', 60), Event('tie', 0)] # pitch event must follow a program event + _, _, _, err_cnt = event2note_event(events, start_time=0, sort=True, tps=100) + self.assertEqual(err_cnt['Err/Missing prg in tie'], 1) + + def test_event2note_event_with_invalid_tie_event(self): + events = [Event('shift', 10)] + _, _, _, err_cnt = event2note_event(events, start_time=0, sort=True, tps=100) + self.assertEqual(err_cnt['Err/Missing tie'], 1) + +class TestEvent2NoteEventSpecialEvent(unittest.TestCase): + def test_note_event2note_special_events(self): + events = [Event('program', 33), Event('pitch', 60), + Event('tie', 0), + Event('shift', 10), Event('program', 33), Event('velocity', 0), Event('pitch', 60), + Event('EOS', 0), Event('PAD', 0), # <- will stop decoding at this point... + Event('shift', 20), Event('velocity', 1), Event('pitch', 20), + Event('shift', 30), Event('velocity', 1), Event('pitch', 20),] + note_events, tie_note_events, _, err_cnt = event2note_event(events, start_time=0) + print(note_events) + self.assertEqual(len(note_events), 1) + self.assertEqual(len(tie_note_events), 1) + self.assertEqual(len(err_cnt), 0) + + +class TestNoteEvent2Note(unittest.TestCase): + + def test_note_event2note(self): + + note_events = [NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60), + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60), + NoteEvent(is_drum=False, program=33, time=1.6, velocity=1, pitch=62), + NoteEvent(is_drum=False, program=33, time=3.0, velocity=0, pitch=62), + NoteEvent(is_drum=False, program=100, time=1.6, velocity=1, pitch=77), + NoteEvent(is_drum=False, program=100, time=2.0, velocity=0, pitch=77), + NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36), + NoteEvent(is_drum=True, program=128, time=2.5, velocity=1, pitch=38) + ] + notes, err_cnt = note_event2note(note_events, sort=True) + + expected_notes = [ + Note(is_drum=False, program=33, onset=0, offset=1.5, pitch=60, velocity=1), + Note(is_drum=True, program=128, onset=0.2, offset=0.21, pitch=36, velocity=1), + Note(is_drum=False, program=33, onset=1.6, offset=3.0, pitch=62, velocity=1), + Note(is_drum=False, program=100, onset=1.6, offset=2.0, pitch=77, velocity=1), + Note(is_drum=True, program=128, onset=2.5, offset=2.51, pitch=38, velocity=1) + ] + self.assertEqual(len(err_cnt), 0) + assert_notes_almost_equal(notes, expected_notes, delta=5e-3) + + + def test_note_event2note_simple_cases(self): + # Case 1: Basic test case with two notes + note_events = [ + NoteEvent(is_drum=False, program=0, time=0.1, velocity=1, pitch=60), + NoteEvent(is_drum=False, program=0, time=0.5, velocity=0, pitch=60), + NoteEvent(is_drum=False, program=0, time=0.7, velocity=1, pitch=62), + NoteEvent(is_drum=False, program=0, time=1.5, velocity=0, pitch=62), + ] + + expected_notes = [ + Note(is_drum=False, program=0, onset=0.1, offset=0.5, pitch=60, velocity=1), + Note(is_drum=False, program=0, onset=0.7, offset=1.5, pitch=62, velocity=1), + ] + notes, err_cnt = note_event2note(note_events) + self.assertEqual(len(err_cnt), 0) + self.assertSequenceEqual(notes, expected_notes) + + # Case 2: Test with drum notes + note_events = [ + NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36), + NoteEvent(is_drum=True, 
program=128, time=0.3, velocity=1, pitch=38), + NoteEvent(is_drum=True, program=128, time=0.4, velocity=0, pitch=36), + NoteEvent(is_drum=True, program=128, time=0.5, velocity=0, pitch=38), + ] + + expected_notes = [ + Note(is_drum=True, program=128, onset=0.2, offset=0.21, pitch=36, velocity=1), + Note(is_drum=True, program=128, onset=0.3, offset=0.31, pitch=38, velocity=1), + ] + notes, err_cnt = note_event2note(note_events) + self.assertEqual(len(err_cnt), 0) + assert_notes_almost_equal(notes, expected_notes, delta=5.1e-3) + + + def test_note_event2note_multiple_overlapping_notes(self): + + note_events = [ + NoteEvent(is_drum=False, program=1, time=0.0, velocity=1, pitch=60), + NoteEvent(is_drum=False, program=1, time=0.5, velocity=0, pitch=60), + NoteEvent(is_drum=False, program=1, time=1.0, velocity=1, pitch=62), + NoteEvent(is_drum=False, program=1, time=1.5, velocity=0, pitch=62), + NoteEvent(is_drum=False, program=2, time=0.25, velocity=1, pitch=60), + NoteEvent(is_drum=False, program=2, time=0.75, velocity=0, pitch=60), + NoteEvent(is_drum=False, program=2, time=1.25, velocity=1, pitch=62), + NoteEvent(is_drum=False, program=2, time=1.75, velocity=0, pitch=62), + NoteEvent(is_drum=False, program=3, time=0.0, velocity=1, pitch=64), + NoteEvent(is_drum=False, program=3, time=1.0, velocity=0, pitch=64), + NoteEvent(is_drum=False, program=4, time=0.5, velocity=1, pitch=66), + NoteEvent(is_drum=False, program=4, time=1.5, velocity=0, pitch=66), + NoteEvent(is_drum=False, program=4, time=0.75, velocity=1, pitch=67), + NoteEvent(is_drum=False, program=4, time=1.75, velocity=0, pitch=67), + NoteEvent(is_drum=False, program=4, time=1.0, velocity=1, pitch=69), + NoteEvent(is_drum=False, program=4, time=2.0, velocity=0, pitch=69) + ] + + expected_notes = [ + Note(is_drum=False, program=1, onset=0.0, offset=0.5, pitch=60, velocity=1), + Note(is_drum=False, program=3, onset=0.0, offset=1.0, pitch=64, velocity=1), + Note(is_drum=False, program=2, onset=0.25, offset=0.75, pitch=60, velocity=1), + Note(is_drum=False, program=4, onset=0.5, offset=1.5, pitch=66, velocity=1), + Note(is_drum=False, program=4, onset=0.75, offset=1.75, pitch=67, velocity=1), + Note(is_drum=False, program=1, onset=1.0, offset=1.5, pitch=62, velocity=1), + Note(is_drum=False, program=4, onset=1.0, offset=2.0, pitch=69, velocity=1), + Note(is_drum=False, program=2, onset=1.25, offset=1.75, pitch=62, velocity=1) + ] + + notes, err_cnt = note_event2note(note_events) + self.assertEqual(len(err_cnt), 0) + assert_notes_almost_equal(notes, expected_notes, delta=5e-3) + +# yapf: enable +if __name__ == '__main__': + unittest.main() diff --git a/amt/src/tests/event_codec_test.py b/amt/src/tests/event_codec_test.py new file mode 100644 index 0000000000000000000000000000000000000000..68fbe94fb25d1918958299562f684b6b6443a043 --- /dev/null +++ b/amt/src/tests/event_codec_test.py @@ -0,0 +1,158 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. 
+"""event_codec_test.py: + +This file contains tests for the following classes: +• Event +• EventRange +• FastCodec equivalent to MT3 author's Codec + +See tokenizer_test.py for the FastCodec performance benchmark + +""" +import unittest +from utils.note_event_dataclasses import Event, EventRange +from utils.event_codec import FastCodec as Codec +# from utils.event_codec import Codec + + +class TestEvent(unittest.TestCase): + + def test_Event(self): + e = Event(type='shift', value=0) + self.assertEqual(e.type, 'shift') + self.assertEqual(e.value, 0) + + +class TestEventRange(unittest.TestCase): + + def test_EventRange(self): + er = EventRange('abc', min_value=0, max_value=500) + self.assertEqual(er.type, 'abc') + self.assertEqual(er.min_value, 0) + self.assertEqual(er.max_value, 500) + + +class TestEventCodec(unittest.TestCase): + + def test_event_codec(self): + ec = Codec( + special_tokens=['asd'], + max_shift_steps=1001, + event_ranges=[ + EventRange('pitch', min_value=0, max_value=127), + EventRange('velocity', min_value=0, max_value=1), + EventRange('tie', min_value=0, max_value=0), + EventRange('program', min_value=0, max_value=127), + EventRange('drum', min_value=0, max_value=127), + ], + ) + + events = [ + Event(type='shift', value=0), # actually not needed + Event(type='shift', value=1), # 10 ms shift + Event(type='shift', value=1000), # 10 s shift + Event(type='pitch', value=0), # lowest pitch 8.18 Hz + Event(type='pitch', value=60), # C4 or 261.63 Hz + Event(type='pitch', value=127), # highest pitch G9 or 12543.85 Hz + Event(type='velocity', value=0), # lowest velocity) + Event(type='velocity', value=1), # lowest velocity) + Event(type='tie', value=0), # tie + Event(type='program', value=0), # program + Event(type='program', value=127), # program + Event(type='drum', value=0), # drum + Event(type='drum', value=127), # drum + ] + + encoded = [ec.encode_event(e) for e in events] + decoded = [ec.decode_event_index(idx) for idx in encoded] + self.assertSequenceEqual(events, decoded) + + +class TestEventCodecErrorCases(unittest.TestCase): + + def setUp(self): + self.event_ranges = [ + EventRange("program", 0, 127), + EventRange("pitch", 0, 127), + EventRange("velocity", 0, 3), + EventRange("drum", 0, 127), + EventRange("tie", 0, 1), + ] + self.ec = Codec([], 1000, self.event_ranges) + + def test_encode_event_with_invalid_event_type(self): + with self.assertRaises(ValueError): + self.ec.encode_event(Event("unknown_event_type", 50)) + + def test_encode_event_with_invalid_event_value(self): + with self.assertRaises(ValueError): + self.ec.encode_event(Event("program", 200)) + + def test_event_type_range_with_invalid_event_type(self): + with self.assertRaises(ValueError): + self.ec.event_type_range("unknown_event_type") + + def test_decode_event_index_with_invalid_index(self): + with self.assertRaises(ValueError): + self.ec.decode_event_index(1000000) + + +class TestEventCodecVocabulary(unittest.TestCase): + + def test_encode_event_using_program_vocabulary(self): + prog_vocab = {"Piano": [0, 1, 2, 3, 4, 5, 6, 7], "xxx": [50, 30, 120]} + ec = Codec(special_tokens=['asd'], + max_shift_steps=1001, + event_ranges=[ + EventRange('pitch', min_value=0, max_value=127), + EventRange('velocity', min_value=0, max_value=1), + EventRange('tie', min_value=0, max_value=0), + EventRange('program', min_value=0, max_value=127), + EventRange('drum', min_value=0, max_value=127), + ], + program_vocabulary=prog_vocab) + + events = [ + Event(type='program', value=0), # 0 --> 0 + Event(type='program', value=7), # 
7 --> 0 + Event(type='program', value=111), # 111 --> 111 + Event(type='program', value=30), # 30 --> 50 + ] + encoded = [ec.encode_event(e) for e in events] + expected = [1133, 1133, 1244, 1183] + self.assertSequenceEqual(encoded, expected) + + def test_encode_event_using_drum_vocabulary(self): + drum_vocab = {"Kick": [50, 51, 52], "Snare": [53, 54]} + ec = Codec(special_tokens=['asd'], + max_shift_steps=1001, + event_ranges=[ + EventRange('pitch', min_value=0, max_value=127), + EventRange('velocity', min_value=0, max_value=1), + EventRange('tie', min_value=0, max_value=0), + EventRange('program', min_value=0, max_value=127), + EventRange('drum', min_value=0, max_value=127), + ], + drum_vocabulary=drum_vocab) + + events = [ + Event(type='drum', value=50), + Event(type='drum', value=51), + Event(type='drum', value=53), + Event(type='drum', value=54), + ] + encoded = [ec.encode_event(e) for e in events] + self.assertEqual(encoded[0], encoded[1]) + self.assertEqual(encoded[2], encoded[3]) + + +if __name__ == '__main__': + unittest.main() diff --git a/amt/src/tests/metrics_test.py b/amt/src/tests/metrics_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2ce80ad36151c50d85253c6e256608f9b8d8d883 --- /dev/null +++ b/amt/src/tests/metrics_test.py @@ -0,0 +1,118 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""metrics_test.py: + +This file contains tests for the following classes: +• AMTMetrics + +""" +import unittest +import warnings +import torch +import numpy as np +from utils.metrics import AMTMetrics +from utils.metrics import compute_track_metrics + + +class TestAMTMetrics(unittest.TestCase): + + def test_individual_attributes(self): + metric = AMTMetrics() + + # Test updating the metric using .update() method + metric.onset_f.update(0.5) + + # Test updating the metric using __call__() method + metric.onset_f(0.5) + + # Test updating the metric with a weight + metric.onset_f(0, weight=1.0) + + # Test computing the average value of the metric + computed_value = metric.onset_f.compute() + self.assertAlmostEqual(computed_value, 0.3333333333333333) + + # Test resetting the metric + metric.onset_f.reset() + with self.assertWarns(UserWarning): + torch._assert(metric.onset_f.compute(), torch.nan) + + # Test bulk_compute + with self.assertWarns(UserWarning): + computed_metrics = metric.bulk_compute() + + def test_bulk_update_and_compute(self): + metric = AMTMetrics() + + # Test bulk_update with values only + d1 = {'onset_f': 0.5, 'offset_f': 0.5} + metric.bulk_update(d1) + + # Test bulk_update with values and weights + d2 = {'onset_f': {'value': 0.5, 'weight': 1.0}, 'offset_f': {'value': 0.5, 'weight': 1.0}} + metric.bulk_update(d2) + + # Test bulk_compute + computed_metrics = metric.bulk_compute() + + # Ensure the 'onset_f' and 'offset_f' keys exist in the computed_metrics dictionary + self.assertIn('onset_f', computed_metrics) + self.assertIn('offset_f', computed_metrics) + + # Check the computed values + self.assertAlmostEqual(computed_metrics['onset_f'], 0.5) + self.assertAlmostEqual(computed_metrics['offset_f'], 0.5) + + def test_compute_track_metrics_singing(self): + from config.vocabulary import SINGING_SOLO_CLASS, GM_INSTR_CLASS_PLUS + from utils.event2note import note_event2note + + 
ref_notes_dict = np.load('extras/examples/singing_notes.npy', allow_pickle=True).tolist() + ref_note_events_dict = np.load('extras/examples/singing_note_events.npy', allow_pickle=True).tolist() + est_notes, _ = note_event2note(ref_note_events_dict['note_events']) + ref_notes = ref_notes_dict['notes'] + + metric = AMTMetrics(prefix=f'test/', extra_classes=[k for k in SINGING_SOLO_CLASS.keys()]) + drum_metric, non_drum_metric, instr_metric = compute_track_metrics(est_notes, + ref_notes, + eval_vocab=SINGING_SOLO_CLASS, + eval_drum_vocab=None, + onset_tolerance=0.05) + metric.bulk_update(drum_metric) + metric.bulk_update(non_drum_metric) + metric.bulk_update(instr_metric) + computed_metrics = metric.bulk_compute() + cnt = 0 + for k, v in computed_metrics.items(): + if 'Singing Voice' in k: + self.assertEqual(v, 1.0) + cnt += 1 + self.assertEqual(cnt, 6) + + metric = AMTMetrics(prefix=f'test/', extra_classes=[k for k in GM_INSTR_CLASS_PLUS.keys()]) + drum_metric, non_drum_metric, instr_metric = compute_track_metrics(est_notes, + ref_notes, + eval_vocab=GM_INSTR_CLASS_PLUS, + eval_drum_vocab=None, + onset_tolerance=0.05) + metric.bulk_update(drum_metric) + metric.bulk_update(non_drum_metric) + metric.bulk_update(instr_metric) + computed_metrics = metric.bulk_compute() + cnt = 0 + for k, v in computed_metrics.items(): + if 'Singing Voice' in k: + self.assertEqual(v, 1.0) + cnt += 1 + self.assertEqual(cnt, 6) + + +if __name__ == '__main__': + unittest.main() diff --git a/amt/src/tests/midi_test.py b/amt/src/tests/midi_test.py new file mode 100644 index 0000000000000000000000000000000000000000..877f99e9a2575a4b02a6a4fed5df476d838a99c6 --- /dev/null +++ b/amt/src/tests/midi_test.py @@ -0,0 +1,65 @@ +import unittest +from typing import List +from tempfile import NamedTemporaryFile +from assert_fns import assert_notes_almost_equal +from utils.note_event_dataclasses import Note + +from utils.midi import note_event2midi +from utils.midi import midi2note +from utils.note2event import note2note_event +# yapf: disable + +class TestNoteMidiConversion(unittest.TestCase): + + def test_note2midi2note_z(self): + original_notes = [ + Note(is_drum=False, program=3, onset=0., offset=1., pitch=60, velocity=1), + Note(is_drum=False, program=3, onset=1., offset=2., pitch=64, velocity=1), + ] + + with NamedTemporaryFile(suffix=".mid", delete=True) as temp_file: + # Convert original_notes to MIDI and save it to the temporary file + note_events = note2note_event(notes=original_notes, sort=True) + note_event2midi(note_events, temp_file.name, velocity=100) + + # Convert the MIDI back to notes + converted_notes, _ = midi2note(temp_file.name) + + # Compare original notes and converted notes + assert_notes_almost_equal(original_notes, converted_notes) + + def test_midi2note2midi2note_piano_z(self): + file = 'extras/examples/piano.mid' + # This MIDI file is missing the program change event, so we force it to be 0 + notes, _ = midi2note(file, quantize=False, force_all_program_to=0)[:1000] + note_events = note2note_event(notes=notes, sort=True) + note_event2midi(note_events, 'extras/examples/piano_converted.mid', velocity=100) + reconverted_notes, _ = midi2note('extras/examples/piano_converted.mid', quantize=False) + assert_notes_almost_equal(notes, reconverted_notes, delta=0.01) + + def test_midi2note2midi2note_force_drum_z(self): + file = 'extras/examples/drum.mid' + conv_file = 'extras/examples/drum_converted.mid' + # This MIDI file is missing the program change event, so we force it to be 0 + notes, _ = midi2note(file, 
quantize=True, force_all_drum=True)[:100] + note_events = note2note_event(notes=notes, sort=True) + note_event2midi(note_events, conv_file, velocity=100, ticks_per_beat=960) + reconverted_notes, _ = midi2note(conv_file, quantize=True, force_all_drum=True) + assert_notes_almost_equal(notes, reconverted_notes, delta=0.005) + + # In drum, this is very inaccurate. We should fix this in the future. + # Even for the first 100 notes, the timing is off by 170 ms. + + def test_midi2note_ignore_pedal_true_z(self): + file = 'extras/examples/piano.mid' + notes, _ = midi2note(file, quantize=False, ignore_pedal=True, force_all_program_to=0) + note_events = note2note_event(notes=notes, sort=True) + note_event2midi(note_events, 'extras/examples/piano_converted.mid', velocity=100) + reconverted_notes, _ = midi2note('extras/examples/piano_converted.mid', quantize=False) + assert_notes_almost_equal(notes, reconverted_notes, delta=0.01) + + +# yapf: enable + +if __name__ == '__main__': + unittest.main() diff --git a/amt/src/tests/model/ops_test.py b/amt/src/tests/model/ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..130285f212522b5d4092b672d0b018ffc1bc1227 --- /dev/null +++ b/amt/src/tests/model/ops_test.py @@ -0,0 +1,18 @@ +import unittest +import torch +import numpy as np +from model.ops import minmax_normalize + + +class TestMinMaxNormalize(unittest.TestCase): + + def test_minmax_normalize(self): + x = torch.tensor([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]) + x_norm = minmax_normalize(x) + x_norm_expected = torch.tensor([[[0.0, 0.2, 0.4], [0.6, 0.8, 1.0]]]) + + np.testing.assert_almost_equal(x_norm.numpy(), x_norm_expected.numpy(), decimal=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/amt/src/tests/model/spectrogram_test.py b/amt/src/tests/model/spectrogram_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c639aa2fce8716a094eb6822265db5f86018c7ac --- /dev/null +++ b/amt/src/tests/model/spectrogram_test.py @@ -0,0 +1,29 @@ +import torch +import unittest +from model.spectrogram import Melspectrogram + + +class TestMelspectrogram(unittest.TestCase): + + def test_melspectrogram(self): + # Create a Melspectrogram instance with default parameters + melspec = Melspectrogram() + + # Create a random input tensor (B, C, T) with T = 32767 samples for 2048 ms + x = torch.randn(2, 1, 32767) + + # Compute the Melspectrogram + y = melspec(x) + + # Check the output shape + self.assertEqual(y.shape, (2, 256, 512)) + + # Check if the output contains NaN values + self.assertFalse(torch.isnan(y).any()) + + # Check if the output contains infinite values + self.assertFalse(torch.isinf(y).any()) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/amt/src/tests/note2event_test.py b/amt/src/tests/note2event_test.py new file mode 100644 index 0000000000000000000000000000000000000000..634089325f7b910b81c6db84648be7e4a33358cf --- /dev/null +++ b/amt/src/tests/note2event_test.py @@ -0,0 +1,581 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. 
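+"""note2event_test.py:
+
+This file contains tests for the following functions:
+• note2note_event
+• slice_note_events_and_ties
+• slice_multiple_note_events_and_ties_to_bundle
+• trim_overlapping_notes
+• mix_notes
+• validate_notes
+• note_event2event
+
+"""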
+import unittest +import pytest +import warnings +from numpy import random + +from utils.note2event import note2note_event +from utils.note2event import slice_note_events_and_ties +from utils.note2event import slice_multiple_note_events_and_ties_to_bundle +from utils.note2event import trim_overlapping_notes +from utils.note2event import note_event2event +from utils.note2event import mix_notes +from utils.note2event import validate_notes +from utils.note_event_dataclasses import Note, NoteEvent, NoteEventListsBundle +from utils.note_event_dataclasses import Event + + +# yapf: disable +class TestNoteTools(unittest.TestCase): + + def test_trim_overlapping_notes(self): + notes = [ + Note(is_drum=False, program=1, onset=0.0, offset=1.0, pitch=60, velocity=100), + Note(is_drum=False, program=1, onset=0.5, offset=1.5, pitch=60, velocity=100), + Note(is_drum=False, program=1, onset=2.0, offset=3.0, pitch=60, velocity=100) + ] + expected_notes = [ + Note(is_drum=False, program=1, onset=0.0, offset=0.5, pitch=60, velocity=100), + Note(is_drum=False, program=1, onset=0.5, offset=1.5, pitch=60, velocity=100), + Note(is_drum=False, program=1, onset=2.0, offset=3.0, pitch=60, velocity=100) + ] + + trimmed_notes = trim_overlapping_notes(notes) + + self.assertEqual(len(expected_notes), len(trimmed_notes), "Number of notes should be equal") + for e_note, t_note in zip(expected_notes, trimmed_notes): + self.assertEqual(e_note, t_note, "Trimmed note should match the expected note") + + def test_mix_notes(self): + notes1 = [ + Note(is_drum=False, program=33, onset=0.0, offset=0.5, pitch=60, velocity=1), + Note(is_drum=False, program=33, onset=1.0, offset=1.5, pitch=62, velocity=1), + Note(is_drum=True, program=128, onset=2.0, offset=2.1, pitch=36, velocity=1) + ] + notes2 = [ + Note(is_drum=False, program=52, onset=0.5, offset=1.0, pitch=40, velocity=1), + Note(is_drum=False, program=100, onset=1.5, offset=2.0, pitch=77, velocity=1), + Note(is_drum=True, program=128, onset=2.5, offset=2.6, pitch=38, velocity=1) + ] + mixed_notes = mix_notes((notes1, notes2), sort=True, trim_overlap=True, fix_offset=True) + + expected_mixed_notes = [ + Note(is_drum=False, program=33, onset=0.0, offset=0.5, pitch=60, velocity=1), + Note(is_drum=False, program=52, onset=0.5, offset=1.0, pitch=40, velocity=1), + Note(is_drum=False, program=33, onset=1.0, offset=1.5, pitch=62, velocity=1), + Note(is_drum=False, program=100, onset=1.5, offset=2.0, pitch=77, velocity=1), + Note(is_drum=True, program=128, onset=2.0, offset=2.1, pitch=36, velocity=1), + Note(is_drum=True, program=128, onset=2.5, offset=2.6, pitch=38, velocity=1) + ] + self.assertSequenceEqual(mixed_notes, expected_mixed_notes) + + def test_validate_notes(self): + DRUM_OFFSET_TIME = 0.01 # in seconds + MINIMUM_OFFSET_TIME = 0.01 # this is used to avoid zero-length notes + + notes = [ + Note(is_drum=False, program=33, onset=0.0, offset=0.5, pitch=60, velocity=1), + Note(is_drum=False, program=33, onset=1.0, offset=0.9, pitch=62, velocity=1), + Note(is_drum=True, program=128, onset=2.0, offset=2.1, pitch=36, velocity=1), + Note(is_drum=False, program=100, onset=1.5, offset=1.4, pitch=77, velocity=1) + ] + with self.assertWarns(UserWarning): + validated_notes = validate_notes(notes, fix=True) + + expected_validated_notes = [ + Note(is_drum=False, program=33, onset=0.0, offset=0.5, pitch=60, velocity=1), + Note(is_drum=False, program=33, onset=1.0, offset=1.0 + MINIMUM_OFFSET_TIME, pitch=62, velocity=1), + Note(is_drum=True, program=128, onset=2.0, offset=2.1, pitch=36, 
velocity=1), + Note(is_drum=False, program=100, onset=1.5, offset=1.5 + MINIMUM_OFFSET_TIME, pitch=77, velocity=1) + ] + + self.assertSequenceEqual(validated_notes, expected_validated_notes) + + + +class TestNoteEvent(unittest.TestCase): + + def test_NoteEvent(self): + note_event = NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60) + self.assertEqual(note_event.is_drum, False) + self.assertEqual(note_event.program, 33) + self.assertEqual(note_event.time, 0) + self.assertEqual(note_event.velocity, 1) + self.assertEqual(note_event.pitch, 60) + + ne1 = NoteEvent(True, 64, 0.5, 0, 60) + ne2 = NoteEvent(True, 64, 0.5, 0, 61) + self.assertEqual(ne1.equals_except(ne2, "pitch"), True) + self.assertEqual(ne1.equals_except(ne2, "program"), False) + self.assertEqual(ne1.equals_except(ne2, "time", "pitch"), True) + + ne1 = NoteEvent(True, 64, 0.5, 1, 60) + ne2 = NoteEvent(True, 11, 0.5, 1, 61) + self.assertEqual(ne1.equals_only(ne2, "velocity"), True) + self.assertEqual(ne1.equals_only(ne2, "time", "velocity"), True) + self.assertEqual(ne1.equals_only(ne2, "program", "velocity"), False) + + +class TestNote2NoteEvent(unittest.TestCase): + def test_note2note_event(self): + notes = [ + Note(is_drum=False, program=33, onset=0, offset=1.5, pitch=60, velocity=1), + Note(is_drum=False, program=33, onset=1.6, offset=3.0, pitch=62, velocity=1), + Note(is_drum=False, program=100, onset=1.6, offset=2.0, pitch=77, velocity=1), + Note(is_drum=True, program=128, onset=0.2, offset=0.21, pitch=36, velocity=1), + Note(is_drum=True, program=128, onset=2.5, offset=2.51, pitch=38, velocity=1) + ] + + note_events = note2note_event(notes, sort=False, return_activity=False) + self.assertSequenceEqual(note_events, + [NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60), + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60), + NoteEvent(is_drum=False, program=33, time=1.6, velocity=1, pitch=62), + NoteEvent(is_drum=False, program=33, time=3.0, velocity=0, pitch=62), + NoteEvent(is_drum=False, program=100, time=1.6, velocity=1, pitch=77), + NoteEvent(is_drum=False, program=100, time=2.0, velocity=0, pitch=77), + NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36), + NoteEvent(is_drum=True, program=128, time=2.5, velocity=1, pitch=38) + ]) + + note_events = note2note_event(notes, sort=True, return_activity=True) + self.assertSequenceEqual(note_events, + [NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()), + NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.6, velocity=1, pitch=62, activity=set()), + NoteEvent(is_drum=False, program=100, time=1.6, velocity=1, pitch=77, activity={3}), + NoteEvent(is_drum=False, program=100, time=2.0, velocity=0, pitch=77, activity={3, 4}), + NoteEvent(is_drum=True, program=128, time=2.5, velocity=1, pitch=38, activity={3}), + NoteEvent(is_drum=False, program=33, time=3.0, velocity=0, pitch=62, activity={3}) + ]) + + def test_note2note_event_invalid_velocity_value(self): + notes = [Note(is_drum=0, program=1, onset=0, offset=127, pitch=60, velocity=100)] + with self.assertRaises(ValueError): + note2note_event(notes) + + def test_note2note_event_non_empty_notes_list(self): + notes = [ + Note(is_drum=0, program=1, onset=0, offset=127, pitch=60, velocity=1), + Note(is_drum=0, program=1, onset=20, offset=127, pitch=62, velocity=1), + 
Note(is_drum=0, program=1, onset=40, offset=127, pitch=64, velocity=1) + ] + note_events = note2note_event(notes) + assert len(note_events) == 6 + + def test_note2note_event_sort_parameter(self): + notes = [ + Note(is_drum=0, program=10, onset=0, offset=127, pitch=64, velocity=1), + Note(is_drum=0, program=10, onset=20, offset=127, pitch=60, velocity=1), + Note(is_drum=0, program=10, onset=0, offset=127, pitch=62, velocity=1) + ] + note_events = note2note_event(notes, sort=True) + sorted_note_events = sorted( + note_events, key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, \ + n_ev.velocity, n_ev.pitch)) + assert note_events == sorted_note_events + +class TestNoteEventTools(unittest.TestCase): + + def test_slice_note_events_and_ties(self): + note_events = [ + NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()), + NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.6, velocity=1, pitch=62, activity=set()), + NoteEvent(is_drum=False, program=100, time=1.6, velocity=1, pitch=77, activity={3}), + NoteEvent(is_drum=False, program=100, time=2.0, velocity=0, pitch=77, activity={3, 4}), + NoteEvent(is_drum=True, program=128, time=2.5, velocity=1, pitch=38, activity={3}), + NoteEvent(is_drum=False, program=33, time=3.5, velocity=0, pitch=62, activity={3}) + ] + start_time = 1.5 + end_time = 3.5 + + sliced_note_events, tie_note_events, _ = slice_note_events_and_ties(note_events, start_time, end_time) + assert len(sliced_note_events) == 5 + assert len(tie_note_events) == 1 + + # Check if the tie_note_events are as expected + expected_tie_note_events = [ + NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()), + ] + self.assertSequenceEqual(tie_note_events, expected_tie_note_events) + + # Check if the note_events are as expected + expected_sliced_note_events = [ + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.6, velocity=1, pitch=62, activity=set()), + NoteEvent(is_drum=False, program=100, time=1.6, velocity=1, pitch=77, activity={3}), + NoteEvent(is_drum=False, program=100, time=2.0, velocity=0, pitch=77, activity={3, 4}), + NoteEvent(is_drum=True, program=128, time=2.5, velocity=1, pitch=38, activity={3}) + ] + self.assertSequenceEqual(sliced_note_events, expected_sliced_note_events) + + def test_slice_note_events_and_ties_tidyup(self): + note_events = [ + NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()), + NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.6, velocity=1, pitch=62, activity=set()), + NoteEvent(is_drum=False, program=100, time=1.6, velocity=1, pitch=77, activity={3}), + NoteEvent(is_drum=False, program=100, time=2.0, velocity=0, pitch=77, activity={3, 4}), + NoteEvent(is_drum=True, program=128, time=2.5, velocity=1, pitch=38, activity={3}), + NoteEvent(is_drum=False, program=33, time=3.5, velocity=0, pitch=62, activity={3}) + ] + start_time = 1.5 + end_time = 3.5 + + sliced_note_events, tie_note_events, _ = slice_note_events_and_ties( + note_events, start_time, end_time, tidyup=True) + assert len(sliced_note_events) == 5 + assert len(tie_note_events) == 1 + + # Check if the 
tie_note_events are as expected + expected_tie_note_events = [ + NoteEvent(is_drum=False, program=33, time=None, velocity=1, pitch=60, activity=None), + ] + self.assertSequenceEqual(tie_note_events, expected_tie_note_events) + + # Check if the note_events are as expected + expected_sliced_note_events = [ + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity=None), + NoteEvent(is_drum=False, program=33, time=1.6, velocity=1, pitch=62, activity=None), + NoteEvent(is_drum=False, program=100, time=1.6, velocity=1, pitch=77, activity=None), + NoteEvent(is_drum=False, program=100, time=2.0, velocity=0, pitch=77, activity=None), + NoteEvent(is_drum=True, program=128, time=2.5, velocity=1, pitch=38, activity=None) + ] + self.assertSequenceEqual(sliced_note_events, expected_sliced_note_events) + + def test_slice_note_events_and_ties_empty_input(self): + note_events = [] + start_time = 1.0 + end_time = 2.5 + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + sliced_note_events, tie_note_events, _ = slice_note_events_and_ties( + note_events, start_time, end_time) + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "empty note_events as input" in str(w[-1].message) + + assert sliced_note_events == [] + assert tie_note_events == [] + + def test_slice_note_events_and_ties_index_out_of_range(self): + note_events = [ + NoteEvent(is_drum=False, program=33, time=0.1, velocity=1, pitch=60, activity=set()), + NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity={0}), + NoteEvent(is_drum=True, program=128, time=6, velocity=1, pitch=36, activity=set()) + ] + + start_time = 0 + end_time = 0.1 + sliced_note_events, tie_note_events, _ = slice_note_events_and_ties( + note_events, start_time, end_time) + self.assertEqual(len(sliced_note_events), 0) + self.assertEqual(len(tie_note_events), 0) + + start_time = 0.3 + end_time = 2 + sliced_note_events, tie_note_events, _ = slice_note_events_and_ties( + note_events, start_time, end_time) + self.assertEqual(len(sliced_note_events), 1) + self.assertEqual(len(tie_note_events), 1) # drum has no offset, and activity is not counted + + start_time = 0.3 + end_time = 1 + sliced_note_events, tie_note_events, _ = slice_note_events_and_ties( + note_events, start_time, end_time) + self.assertEqual(len(sliced_note_events), 0) + self.assertEqual(len(tie_note_events), 1) # drum has no offset, and activity is not counted + + start_time = 2 + end_time = 4 + sliced_note_events, tie_note_events, _ = slice_note_events_and_ties( + note_events, start_time, end_time) + self.assertEqual(len(sliced_note_events), 0) + self.assertEqual(len(tie_note_events), 0) + + start_time = 3 + end_time = 4 + sliced_note_events, tie_note_events, _ = slice_note_events_and_ties( + note_events, start_time, end_time) + self.assertEqual(len(sliced_note_events), 0) + self.assertEqual(len(tie_note_events), 0) + + start_time = 3 + end_time = 7 + sliced_note_events, tie_note_events, _ = slice_note_events_and_ties( + note_events, start_time, end_time) + self.assertEqual(len(sliced_note_events), 1) + self.assertEqual(len(tie_note_events), 0) + + start_time = 7 + end_time = 8 + sliced_note_events, tie_note_events, _ = slice_note_events_and_ties( + note_events, start_time, end_time) + self.assertEqual(len(sliced_note_events), 0) + self.assertEqual(len(tie_note_events), 0) + + +class TestNoteEventToolsMultiSlice(unittest.TestCase): 
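+    """Tests for slicing a shared note_event list (built in setUp) into multiple time windows and bundles."""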
+ + def setUp(self): + self.note_events = [ + NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()), + NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.6, velocity=1, pitch=62, activity=set()), + NoteEvent(is_drum=False, program=100, time=1.6, velocity=1, pitch=77, activity={3}), + NoteEvent(is_drum=False, program=100, time=2.0, velocity=0, pitch=77, activity={3, 4}), + NoteEvent(is_drum=True, program=128, time=2.5, velocity=1, pitch=38, activity={3}), + NoteEvent(is_drum=False, program=33, time=3.5, velocity=0, pitch=62, activity={3}), + NoteEvent(is_drum=False, program=33, time=4.0, velocity=0, pitch=62, activity={3}), + NoteEvent(is_drum=False, program=50, time=5.5, velocity=1, pitch=55, activity=set()), + NoteEvent(is_drum=False, program=33, time=6.1, velocity=1, pitch=64, activity={9}), + NoteEvent(is_drum=True, program=128, time=6.5, velocity=1, pitch=36, activity={9, 10}), + NoteEvent(is_drum=False, program=33, time=7.5, velocity=0, pitch=64, activity={9, 10}) + ] + + def test_slice_note_events_and_ties_continuous_slices(self): + start_times = [0., 2, 4, 6, 8] + end_times = [2., 4, 6, 8, 10] + sliced_note_events_list = [] + sliced_tie_note_events_list = [] + for start_time, end_time in zip(start_times, end_times): + sliced_note_events, tie_note_events, t = slice_note_events_and_ties( + self.note_events, start_time, end_time) + sliced_note_events_list.extend(sliced_note_events) # merge... + sliced_tie_note_events_list.append(tie_note_events) + self.assertSequenceEqual(sliced_note_events_list, self.note_events) + self.assertEqual(len(sliced_tie_note_events_list), 5) + self.assertEqual(sliced_tie_note_events_list[0], []) # first slice always empty + self.assertEqual(sliced_tie_note_events_list[4], []) # last slice is empty in this example + + def test_slice_multiple_note_events_and_ties_to_bundle(self): + start_times = [0., 1] + duration_sec = 2. 
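+        # Two overlapping 2-second windows starting at 0 s and 1 s.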
+ # Create a bundle from the sliced note events + ne_bundle = slice_multiple_note_events_and_ties_to_bundle( + self.note_events, start_times, duration_sec) + # ne_bundle = NoteEventListsBundle({'note_events': sliced_note_events_list, + # 'tie_note_events': sliced_tie_note_events_list, + # 'start_times': start_times}) + expected_ne_0 = [ + NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()), + NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.6, velocity=1, pitch=62, activity=set()), + NoteEvent(is_drum=False, program=100, time=1.6, velocity=1, pitch=77, activity={3})] + self.assertSequenceEqual(ne_bundle['note_events'][0], expected_ne_0) + expected_ne_1 = [ + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity={0}), + NoteEvent(is_drum=False, program=33, time=1.6, velocity=1, pitch=62, activity=set()), + NoteEvent(is_drum=False, program=100, time=1.6, velocity=1, pitch=77, activity={3}), + NoteEvent(is_drum=False, program=100, time=2.0, velocity=0, pitch=77, activity={3, 4}), + NoteEvent(is_drum=True, program=128, time=2.5, velocity=1, pitch=38, activity={3})] + self.assertSequenceEqual(ne_bundle['note_events'][1], expected_ne_1) + expected_tne_0 = [] + self.assertSequenceEqual(ne_bundle['tie_note_events'][0], expected_tne_0) + expected_tne_1 = [NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set())] + self.assertSequenceEqual(ne_bundle['tie_note_events'][1], expected_tne_1) + self.assertEqual(ne_bundle['start_times'], start_times) + + def test_slice_multiple_note_events_and_ties_to_bundle_overlength_case(self): + # This is a case where the last slices are intended to be empty as in datasets_eval.py + start_times = [10., 11.] + duration_sec = 2. 
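+        # Both windows start after the last note event in setUp (7.5 s), so the resulting slices are expected to be empty.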
+ # Create a bundle from the sliced note events + ne_bundle = slice_multiple_note_events_and_ties_to_bundle( + self.note_events, start_times, duration_sec) + expected_ne_0 = [] + expected_ne_1 = [] + self.assertSequenceEqual(ne_bundle['note_events'][0], expected_ne_0) + self.assertSequenceEqual(ne_bundle['note_events'][1], expected_ne_1) + + + + +class TestNoteEvent2Event(unittest.TestCase): + def test_note_event2event(self): + note_events = [NoteEvent(True, 128, 0.2, 1, 36), + NoteEvent(False, 33, 1.5, 0, 60), + NoteEvent(False, 33, 1.6, 1, 62), + NoteEvent(False, 100, 1.6, 1, 77), + NoteEvent(False, 100, 2.0, 0, 77), + NoteEvent(True, 128, 2.5, 1, 38), + NoteEvent(False, 33, 3.0, 0, 62)] + tie_note_events = [NoteEvent(False, 33, None, 1, 60), + NoteEvent(False, 52, None, 1, 40)] + start_time = 0.0 + tps = 100 + sort = False + events = note_event2event(note_events, tie_note_events, start_time, tps, sort) + + expected_events = [ + Event('program', 33), Event('pitch', 60), + Event('program', 52), Event('pitch', 40), + Event('tie', 0), + Event('shift', 20), Event('velocity', 1), Event('drum', 36), + Event('shift', 150), Event('program', 33), Event('velocity', 0), Event('pitch', 60), + Event('shift', 160), Event('velocity', 1), Event('pitch', 62), Event('program', 100), + Event('pitch', 77), + Event('shift', 200), Event('velocity', 0), Event('pitch', 77), + Event('shift', 250), Event('velocity', 1), Event('drum', 38), + Event('shift', 300), Event('program', 33), Event('velocity', 0), Event('pitch', 62) + ] + self.assertSequenceEqual(events, expected_events) + + def test_empty_input(self): + events = note_event2event([]) + expected_events = [Event('tie', 0)] + self.assertEqual(events, expected_events) + + events = note_event2event([], []) + self.assertEqual(events, expected_events) + + events = note_event2event([], [], 0) + self.assertEqual(events, expected_events) + + def test_single_note_event(self): + note_events = [NoteEvent(False, 33, 1.0, 1, 60)] + events = note_event2event(note_events) + expected_events = [ + Event('tie', 0), + Event('shift', 100), Event('program', 33), Event('velocity', 1), Event('pitch', 60) + ] + self.assertSequenceEqual(events, expected_events) + + def test_single_drum_event(self): + note_events = [NoteEvent(True, 128, 1.0, 1, 36)] + events = note_event2event(note_events) + expected_events = [ + Event('tie', 0), + Event('shift', 100), Event('velocity', 1), Event('drum', 36) + ] + self.assertSequenceEqual(events, expected_events) + + def test_multiple_drum_event(self): + note_events = [NoteEvent(is_drum=True, program=128, time=0.105, velocity=1, pitch=38, activity=set()), + NoteEvent(is_drum=True, program=128, time=0.11499999999999999, velocity=0, pitch=38, activity=set()), + NoteEvent(is_drum=True, program=128, time=1.5, velocity=1, pitch=38, activity=set()), + NoteEvent(is_drum=True, program=128, time=1.51, velocity=0, pitch=38, activity=set()), + NoteEvent(is_drum=True, program=128, time=2.886, velocity=1, pitch=38, activity=set()), + NoteEvent(is_drum=True, program=128, time=2.896, velocity=0, pitch=38, activity=set()), + NoteEvent(is_drum=True, program=128, time=5.528, velocity=1, pitch=38, activity=set()), + NoteEvent(is_drum=True, program=128, time=5.537999999999999, velocity=0, pitch=38, activity=set()), + NoteEvent(is_drum=True, program=128, time=7.641, velocity=1, pitch=38, activity=set()), + NoteEvent(is_drum=True, program=128, time=7.651, velocity=0, pitch=38, activity=set()), + NoteEvent(is_drum=True, program=128, time=10.413, velocity=1, pitch=38, 
activity=set()), + NoteEvent(is_drum=True, program=128, time=10.423, velocity=0, pitch=38, activity=set())] + + events = note_event2event(note_events, []) + expected_events = [Event(type='tie', value=0), + Event(type='shift', value=10), + Event(type='velocity', value=1), + Event(type='drum', value=38), + Event(type='shift', value=150), + Event(type='drum', value=38), + Event(type='shift', value=289), + Event(type='drum', value=38), + Event(type='shift', value=553), + Event(type='drum', value=38), + Event(type='shift', value=764), + Event(type='drum', value=38), + Event(type='shift', value=1041), + Event(type='drum', value=38)] + print(events) + self.assertSequenceEqual(events, expected_events) + + + def test_tie_note_events(self): + note_events = [NoteEvent(False, 33, 1.0, 1, 60), + NoteEvent(False, 33, 2.0, 0, 60)] + tie_note_events = [NoteEvent(False, 33, None, 1, 60)] + events = note_event2event(note_events, tie_note_events) + expected_events = [ + Event('program', 33), + Event('pitch', 60), + Event('tie', 0), + Event('shift', 100), + Event('velocity', 1), + Event('pitch', 60), + Event('shift', 200), + Event('velocity', 0), + Event('pitch', 60) + ] + self.assertSequenceEqual(events, expected_events) + + def test_rounding_behavior_in_shift(self): + note_events = [NoteEvent(False, 33, 1.000001, 1, 60), + NoteEvent(False, 12, 1.98900, 1, 60), # less than 10ms is ignored + NoteEvent(False, 11, 1.99000, 1, 60), # smaller program number first! + NoteEvent(False, 10, 1.99100, 1, 60), + NoteEvent(False, 33, 1.99999, 1, 60), # less than 10ms is ignored + NoteEvent(False, 33, 2.00001, 0, 60)] # offset first! + + events = note_event2event(note_events, sort=True) + expected_events = [ + Event('tie', 0), + Event('shift', 100), Event('program', 33), Event('velocity', 1), Event('pitch', 60), + Event('shift', 199), Event('program', 10), Event('pitch', 60), # smaller program number first! + Event('program', 11), Event('pitch', 60), + Event('program', 12), Event('pitch', 60), + Event('shift', 200), Event('program', 33), Event('velocity', 0), Event('pitch', 60), + Event('velocity', 1), Event('pitch', 60)] # offset first! + self.assertSequenceEqual(events, expected_events) + + def test_rounding_behavior_in_shift_without_sort(self): + note_events = [NoteEvent(False, 12, 1.98900, 1, 60), # less than 10ms is ignored + NoteEvent(False, 11, 1.99000, 1, 60)] + + # If sort=False, the order of events in quantized timing is not guaranteed. + # To avoid sort(), midi2note(..., quantize=True) is default. 
+        events = note_event2event(note_events, sort=False)
+        expected_events = [
+            Event('tie', 0),
+            Event('shift', 199), Event('program', 12), Event('velocity', 1), Event('pitch', 60),
+            Event('program', 11), Event('pitch', 60)]
+        self.assertSequenceEqual(events, expected_events)
+
+    def test_rounding_behavior_in_shift_without_sort_quantized_note_event(self):
+        note_events = [NoteEvent(False, 11, 1.99000, 1, 60),
+                       NoteEvent(False, 12, 1.99000, 1, 60)]
+        events = note_event2event(note_events, sort=False)
+        expected_events = [
+            Event('tie', 0),
+            Event('shift', 199), Event('program', 11), Event('velocity', 1), Event('pitch', 60),
+            Event('program', 12), Event('pitch', 60)]
+        self.assertSequenceEqual(events, expected_events)
+
+
+class TestNoteEvent2EventProcessTime(unittest.TestCase):
+
+    def setUp(self):
+        self.note_events = [
+            NoteEvent(is_drum=False, program=i % 128, time=i / 10.0, velocity=64, pitch=i % 128)
+            for i in range(333)
+        ]
+
+        self.rand_note_events = [
+            NoteEvent(is_drum=bool(random.randint(2)), program=random.randint(128),
+                      time=random.randint(500) / 10.0, velocity=random.randint(2),
+                      pitch=random.randint(128))
+            for i in range(333)
+        ]
+
+    @pytest.mark.timeout(0.1)  # Set a timeout of 100 ms
+    def test_large_note_event_list(self):
+        # B = 64, sequence_length = 333; 64 * 333 = 21312 note events with a single CPU process
+        for i in range(64):
+            events = note_event2event(self.note_events, sort=False)
+
+    @pytest.mark.timeout(0.1)  # Set a timeout of 100 ms
+    def test_large_random_note_event_list_with_sort(self):
+        # B = 64, sequence_length = 333; 64 * 333 = 21312 note events with a single CPU process
+        for i in range(64):
+            events = note_event2event(self.rand_note_events, sort=True)
+
+# yapf: enable
+if __name__ == '__main__':
+    unittest.main()
diff --git a/amt/src/tests/note_event_roundtrip_test.py b/amt/src/tests/note_event_roundtrip_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..393388da4ab5094bbbcc51e3295a2246f9daa7a9
--- /dev/null
+++ b/amt/src/tests/note_event_roundtrip_test.py
@@ -0,0 +1,249 @@
+# Copyright 2024 The YourMT3 Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Please see the details in the LICENSE file.
+""" note_event_roundtrip_test.py:
+This file contains tests for the round trip conversion between Note and
+NoteEvent and Event.
+
+Itinerary 1:
+    NoteEvent → Event → Token → Event → NoteEvent
+
+Itinerary 2:
+    Note → NoteEvent → Event → Token → Event → NoteEvent → Note
+
+Training:
+    (Dataloader) NoteEvent → (augmentation) → Event → Token
+
+Evaluation:
+    (Model side) Token → Event → NoteEvent → Note → (mir_eval)
+    (Ground Truth) Note → (mir_eval)
+
+  • This conversion may fail for unsorted and unquantized timing events.
+  • Activity attribute of NoteEvent is often ignorable.
+ +""" +import unittest +import numpy as np +from assert_fns import assert_notes_almost_equal +from assert_fns import assert_note_events_almost_equal +from assert_fns import assert_track_metrics_score1 + +from utils.note_event_dataclasses import Note, NoteEvent, Event +from utils.note2event import note2note_event, note_event2event +from utils.note2event import validate_notes, trim_overlapping_notes +from utils.event2note import event2note_event, note_event2note +from utils.tokenizer import EventTokenizer, NoteEventTokenizer +from utils.midi import note_event2midi +from utils.midi import midi2note +from utils.note2event import slice_multiple_note_events_and_ties_to_bundle +from utils.event2note import merge_zipped_note_events_and_ties_to_notes +from utils.metrics import compute_track_metrics +from config.vocabulary import GM_INSTR_FULL, SINGING_SOLO_CLASS +# yapf: disable + +class TestNoteEventRoundTrip1(unittest.TestCase): + + def setUp(self) -> None: + self.note_events = [ + NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()), + NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity=set()), + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity=set()), + NoteEvent(is_drum=False, program=33, time=1.6, velocity=1, pitch=62, activity=set()), + NoteEvent(is_drum=False, program=100, time=1.6, velocity=1, pitch=77, activity=set()), + NoteEvent(is_drum=False, program=100, time=2.0, velocity=0, pitch=77, activity=set()), + NoteEvent(is_drum=True, program=128, time=2.0, velocity=1, pitch=38, activity=set()), + NoteEvent(is_drum=False, program=33, time=2.0, velocity=0, pitch=62, activity=set()) + ] + self.tokenizer = EventTokenizer() + + def test_note_event_rt_ne2e2ne(self): + """ NoteEvent → Event → NoteEvent """ + note_events = self.note_events.copy() + events = note_event2event(note_events=note_events, + tie_note_events=None, + start_time=0, sort=True) + recon_note_events, unused_tie_note_events, unsued_last_activity, err_cnt = event2note_event( + events, start_time=0, sort=True, tps=100) + + self.assertSequenceEqual(note_events, recon_note_events) + self.assertEqual(len(err_cnt), 0) + + def test_note_event_rt_ne2e2t2e2ne(self): + """ NoteEvent → Event → Token → Event → NoteEvent """ + note_events = self.note_events.copy() + events = note_event2event( + note_events=note_events, tie_note_events=None, start_time=0, sort=True) + tokens = self.tokenizer.encode(events) + events = self.tokenizer.decode(tokens) + recon_note_events, unused_tie_note_events, unsued_last_activity, err_cnt = event2note_event( + events, start_time=0, sort=True, tps=100) + + self.assertSequenceEqual(note_events, recon_note_events) + self.assertEqual(len(err_cnt), 0) + +class TestNoteEvent2(unittest.TestCase): + + def setUp(self) -> None: + notes = [ + Note(is_drum=False, program=33, onset=0, offset=1.5, pitch=60, velocity=1), + Note(is_drum=True, program=128, onset=0.2, offset=0.21, pitch=36, velocity=1), + Note(is_drum=False, program=25, onset=0.4, offset=1.1, pitch=55, velocity=1), + Note(is_drum=True, program=128, onset=1, offset=1.01, pitch=42, velocity=1), + Note(is_drum=False, program=33, onset=1.2, offset=1.8, pitch=80, velocity=1), + Note(is_drum=False, program=33, onset=1.6, offset=2.0, pitch=62, velocity=1), + Note(is_drum=False, program=100, onset=1.6, offset=2.0, pitch=77, velocity=1), + Note(is_drum=False, program=98, onset=1.7, offset=2.0, pitch=77, velocity=1), + Note(is_drum=True, program=128, onset=1.9, offset=1.91, pitch=38, 
velocity=1) + ] + + # Validate and trim notes to make sure they are valid. + _notes = validate_notes(notes, fix=True) + self.assertSequenceEqual(notes, _notes) + _notes = trim_overlapping_notes(notes, sort=True) + self.assertSequenceEqual(notes, _notes) + + self.notes = notes + self.tokenizer = EventTokenizer() + + + def test_note_event_rt_n2ne2e2t2e2ne2n(self): + """ Note → NoteEvent → Event → Token → Event → NoteEvent → Note """ + notes = self.notes.copy() + note_events = note2note_event(notes=notes, sort=True) + events = note_event2event(note_events=note_events, + tie_note_events=None, + start_time=0, + tps=100, + sort=True) + tokens = self.tokenizer.encode(events) + events = self.tokenizer.decode(tokens) + recon_note_events, unused_tie_note_events, unsued_last_activity, err_cnt = event2note_event( + events, start_time=0, sort=True, tps=100) + self.assertEqual(len(err_cnt), 0) + + recon_notes, err_cnt = note_event2note(note_events=recon_note_events, sort=True) + self.assertEqual(len(err_cnt), 0) + assert_notes_almost_equal(notes, recon_notes, delta=5e-3) # 5 ms on/offset tolerance + + # def test_encoding_from_midi_without_slicing_zz(self): + # """ MIDI → Note → NoteEvent → Event → Token → Event → NoteEvent → Note → MIDI """ + # src_midi_file = 'extras/examples/1727.mid' + # notes, _ = midi2note(src_midi_file, quantize=False) + # note_events = note2note_event(notes=notes, sort=True) + # events = note_event2event(note_events=note_events, + # tie_note_events=None, + # start_time=0, + # tps=100, + # sort=True) + # # check acculuated time by all the shift events + # last_shift = 0 + # for ev in events: + # if ev.type == "shift": + # last_shift = ev.value + # last_shift_in_sec = last_shift / 100 # 447.04 + # assert last_shift_in_sec == 447.04 + # # compare with the last offset time) + # last_offset_time = 0. + # for n in notes: + # if last_offset_time < n.offset: + # last_offset_time = n.offset # 447.0395833... 
+ # self.assertAlmostEqual(last_shift_in_sec, last_offset_time, delta=1e-3) + + # tokens = self.tokenizer.encode(events) + # # reconustrction ----------------------------------------------------------- + # recon_events = self.tokenizer.decode(tokens) + # self.assertSequenceEqual(events, recon_events) + # recon_note_events, unused_tie_note_events, err_cnt = event2note_event(recon_events) + # self.assertEqual(len(err_cnt), 0) + # assert_note_events_almost_equal(note_events, recon_note_events) + # recon_notes, err_cnt = note_event2note(note_events=recon_note_events, sort=True, fix_offset=False) + # self.assertEqual(len(err_cnt), 0) + # assert_notes_almost_equal(notes, recon_notes, delta=5e-3) + # # evaluation without MIDI + # drum_metric, non_drum_metric, instr_metric = compute_track_metrics(recon_notes, notes, eval_vocab=GM_INSTR_FULL, onset_threshold=0.5) + # assert_track_metrics_score1(drum_metric) + # assert_track_metrics_score1(non_drum_metric) + # assert_track_metrics_score1(instr_metric) + + # # evaluation thourgh MIDI + # note_event2midi(recon_note_events, output_file='extras/examples/recon_1727.mid') + # re_recon_notes, _ = midi2note('extras/examples/recon_1727.mid', quantize=False) + # drum_metric, non_drum_metric, instr_metric = compute_track_metrics(re_recon_notes, notes, eval_vocab=GM_INSTR_FULL, onset_threshold=0.5) + # assert_track_metrics_score1(drum_metric) + # assert_track_metrics_score1(non_drum_metric) + # assert_track_metrics_score1(instr_metric) + + def test_encoding_from_midi_with_slicing_zz(self): + src_midi_file = 'extras/examples/2106.mid' # 'extras/examples/1727.mid'# 'extras/examples/1733.mid' # these are from musicnet_em + notes, max_time = midi2note(src_midi_file, quantize=False) + note_events = note2note_event(notes=notes, sort=True) + + # slice note events + num_segs = int(max_time * 16000 // 32757 + 1) + seg_len_sec = 32767 / 16000 + start_times = [i * seg_len_sec for i in range(num_segs)] + note_event_segments = slice_multiple_note_events_and_ties_to_bundle( + note_events, + start_times, + seg_len_sec, + ) + + # encode + tokenizer = NoteEventTokenizer() + token_array = np.zeros((num_segs, 1024), dtype=np.int32) + for i, tup in enumerate(list(zip(*note_event_segments.values()))): + padded_tokens = tokenizer.encode_plus(*tup) + token_array[i, :] = padded_tokens + + # decode: warning: Invalid pitch event without program or velocity --> solved + zipped_note_events_and_tie, list_events, err_cnt = tokenizer.decode_list_batches( + [token_array], start_times, return_events=True) + self.assertEqual(len(err_cnt), 0) + + # First check, the number of empty note_events and tie_note_events + cnt_org_empty = 0 + cnt_recon_empty = 0 + for i, (recon_note_events, recon_tie_note_events, recon_last_activity, recon_start_times) in enumerate(zipped_note_events_and_tie): + org_note_events = note_event_segments['note_events'][i] + org_tie_note_events = note_event_segments['tie_note_events'][i] + if org_note_events == []: + cnt_org_empty += 1 + if recon_note_events == []: + cnt_recon_empty += 1 + + assert len(org_note_events) == len(recon_note_events) # passed after bug fix + # self.assertEqual(len(org_tie_note_events), len(recon_tie_note_events)) + + + # Check the reconstruction of note_events + for i, (recon_note_events, recon_tie_note_events, recon_last_activity, recon_start_times) in enumerate(zipped_note_events_and_tie): + org_note_events = note_event_segments['note_events'][i] + org_tie_note_events = note_event_segments['tie_note_events'][i] + + org_note_events.sort(key=lambda 
n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) + org_tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) + recon_note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) + recon_tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) + + assert_note_events_almost_equal(org_note_events, recon_note_events) + assert_note_events_almost_equal(org_tie_note_events, recon_tie_note_events, ignore_time=True) + + # Check notes + recon_notes, err_cnt = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie, fix_offset=False) + self.assertEqual(len(err_cnt), 0) + assert_notes_almost_equal(notes, recon_notes, delta=5.1e-3) + + # Check metric + drum_metric, non_drum_metric, instr_metric = compute_track_metrics( + recon_notes, notes, eval_vocab=GM_INSTR_FULL, onset_tolerance=0.005) # 5ms + self.assertEqual(non_drum_metric['onset_f'], 1.0) +# yapf: enable + +if __name__ == '__main__': + unittest.main() diff --git a/amt/src/tests/tokenizer_test.py b/amt/src/tests/tokenizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..441a3f826253154ee704c6c1290fae8dad9234f0 --- /dev/null +++ b/amt/src/tests/tokenizer_test.py @@ -0,0 +1,126 @@ +import unittest +import pytest +from numpy import random + +from utils.note_event_dataclasses import NoteEvent, Event, EventRange +from utils.event_codec import FastCodec as Codec +from utils.tokenizer import EventTokenizer +from utils.tokenizer import NoteEventTokenizer + + +class TestEventTokenizerBase(unittest.TestCase): + + def test_encode_and_decode(self): + tokenizer = EventTokenizer() + events = [ + Event('pitch', 64), + Event('velocity', 1), + Event('tie', 0), + Event('program', 10), + Event('drum', 0) + ] + tokens = tokenizer.encode(events) + decoded_events = tokenizer.decode(tokens) + self.assertEqual(events, decoded_events) + + def test_unknown_codec_name(self): + with self.assertRaises(ValueError): + EventTokenizer(base_codec='unknown') + + def test_unknown_codec_type(self): + with self.assertRaises(TypeError): + EventTokenizer(base_codec=123) + + def test_encode_and_decode_with_custom_codec(self): + + special_tokens = ['PAD', 'EOS', 'SOS', 'T'] + max_shift_steps = 100 + event_ranges = [ + EventRange('eat', min_value=0, max_value=9), + EventRange('sleep', min_value=0, max_value=9), + EventRange('play', min_value=0, max_value=1) + ] + + my_codec = Codec(special_tokens, max_shift_steps, event_ranges) + tokenizer = EventTokenizer(my_codec) + events = [ + Event('eat', 3), + Event('shift', 9), + Event('sleep', 9), + Event('shift', 20), + Event('play', 1) + ] + tokens = tokenizer.encode(events) + + # 0~3: special tokens + # 4~103: shift tokens + # 104~112: eat tokens + # 113~121: sleep tokens + # 122~123: play tokens + expected_tokens = [107, 13, 123, 24, 125] + self.assertEqual(tokens, expected_tokens) + decoded_events = tokenizer.decode(tokens) + self.assertEqual(events, decoded_events) + + +class TestEventTokenizerBaseProcessTime(unittest.TestCase): + + def setUp(self) -> None: + self.tokenizer = EventTokenizer('mt3') + self.random_tokens = random.randint(0, 500, size=333) + self.events = [ + Event(type='pitch', value=60), + Event(type='velocity', value=1), + Event(type='program', value=0), + Event(type='shift', value=10), + Event(type='tie', value=0), + Event(type='drum', value=0), + ] * 55 + + @pytest.mark.timeout(0.008) # 32 ms --> 8 ms + def test_event_tokenizer_encode(self): + for i in range(64): + encoded = 
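As a worked example of the index arithmetic the custom-codec test above relies on: with four special tokens and 100 shift steps, the expected ids [107, 13, 123, 24, 125] imply the ranges eat 104–113, sleep 114–123, play 124–125. The sketch below reproduces those numbers under that assumed layout; it is not the actual FastCodec implementation.

```python
# Minimal sketch of the assumed token layout: specials, then shifts, then each EventRange in order.
special_tokens = ['PAD', 'EOS', 'SOS', 'T']      # ids 0..3
max_shift_steps = 100                            # shift ids 4..103
event_ranges = [('eat', 0, 9), ('sleep', 0, 9), ('play', 0, 1)]

offsets = {'shift': len(special_tokens)}         # shift tokens start right after the specials
cursor = len(special_tokens) + max_shift_steps
for name, lo, hi in event_ranges:
    offsets[name] = cursor - lo                  # token id = offset + value
    cursor += hi - lo + 1                        # eat: 104..113, sleep: 114..123, play: 124..125

def encode(events):
    return [offsets[name] + value for name, value in events]

events = [('eat', 3), ('shift', 9), ('sleep', 9), ('shift', 20), ('play', 1)]
print(encode(events))  # -> [107, 13, 123, 24, 125], matching expected_tokens in the test
```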
self.tokenizer.encode(self.events) + + @pytest.mark.timeout(0.01) # 40 ms --> 10 ms + def test_event_tokenizer_decode(self): + for i in range(64): + decoded = self.tokenizer.decode(self.random_tokens) + + +# yapf: disable +class NoteEventTokenizerTest(unittest.TestCase): + + def test_note_event_tokenizer_encode(self): + tokenizer = NoteEventTokenizer() + note_events = [ + NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()), + NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity=set()), + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity=set()) + ] + tokens = tokenizer.encode(note_events) + decoded_events, decoded_tie_events, last_activity, err_cnt = tokenizer.decode(tokens) + self.assertSequenceEqual(note_events, decoded_events) + self.assertSequenceEqual([], decoded_tie_events) + self.assertEqual(len(last_activity), 0) + self.assertEqual(len(err_cnt), 0) + + def test_note_event_tokenizer_encode_plus(self): + tokenizer = NoteEventTokenizer() + note_events = [ + NoteEvent(is_drum=False, program=33, time=0, velocity=1, pitch=60, activity=set()), + NoteEvent(is_drum=True, program=128, time=0.2, velocity=1, pitch=36, activity=set()), + NoteEvent(is_drum=False, program=33, time=1.5, velocity=0, pitch=60, activity=set()) + ] + tokens = tokenizer.encode_plus(note_events, max_length=30) + decoded_events, decoded_tie_events, last_activity, err_cnt = tokenizer.decode(tokens) + self.assertSequenceEqual(note_events, decoded_events) + self.assertSequenceEqual([], decoded_tie_events) + self.assertEqual(len(last_activity), 0) + self.assertEqual(len(err_cnt), 0) + + + +# yapf: enable +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/amt/src/tests/utils_test.py b/amt/src/tests/utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5c79d12e63f1712e79c7202d75f944c8dad8bf --- /dev/null +++ b/amt/src/tests/utils_test.py @@ -0,0 +1,124 @@ +import os +import unittest +from unittest.mock import patch +from io import BytesIO +from utils.utils import download_and_extract +from utils.utils import get_checksum +from utils.utils import merge_file_lists +from utils.utils import reindex_file_list_keys +from utils.utils import remove_ids_from_file_list +from utils.utils import deduplicate_splits + + +class TestMergeFileListFunctions(unittest.TestCase): + + def test_merge_file_lists(self): + # Define some example input dictionaries + file_list_1 = {1: 'file1.txt', 2: 'file2.txt'} + file_list_2 = {3: 'file3.txt', 4: 'file4.txt'} + file_list_3 = {5: 'file5.txt', 6: 'file6.txt'} + + # Call the merge_file_lists function with the example input + merged_file_list = merge_file_lists([file_list_1, file_list_2, file_list_3]) + + # Check that the merged dictionary has the correct length and keys/values + self.assertEqual(len(merged_file_list), 6) + self.assertEqual(merged_file_list[0], 'file1.txt') + self.assertEqual(merged_file_list[1], 'file2.txt') + self.assertEqual(merged_file_list[2], 'file3.txt') + self.assertEqual(merged_file_list[3], 'file4.txt') + self.assertEqual(merged_file_list[4], 'file5.txt') + self.assertEqual(merged_file_list[5], 'file6.txt') + + def test_reindex_file_list_keys(self): + file_list = {'a': {'id': 1, 'name': 'file1'}, 'b': {'id': 2, 'name': 'file2'}} + expected_reindexed = {0: {'id': 1, 'name': 'file1'}, 1: {'id': 2, 'name': 'file2'}} + reindexed = reindex_file_list_keys(file_list) + self.assertEqual(reindexed, expected_reindexed) + + def 
test_remove_ids_from_file_list(self): + file_list = { + 'a': { + 'music_id': 123, + 'name': 'file1' + }, + 'b': { + 'music_id': 222, + 'name': 'file2' + } + } + selected_ids = [123] + expected_filtered = {0: {'music_id': 222, 'name': 'file2'}} + filtered = remove_ids_from_file_list(file_list, selected_ids, reindex=True) + self.assertEqual(filtered, expected_filtered) + + +class TestGetChecksum(unittest.TestCase): + + def test_get_checksum_z(self): + # Create a temporary file with some content + file_name = "temp_test_file.txt" + with open(file_name, "w") as f: + f.write("This is a test file") + + # Calculate the expected checksum using an online md5 calculator or a known md5 value + expected_checksum = "0b26e313ed4a7ca6904b0e9369e5b957" + + # Call the get_checksum function + calculated_checksum = get_checksum(file_name) + + # Compare the expected and calculated checksums + self.assertEqual(expected_checksum, calculated_checksum) + + # Clean up the temporary file + os.remove(file_name) + + +class TestDeduplicateSplits(unittest.TestCase): + + def test_deduplicate_splits(self): + # Create sample file lists for splits A and B + file_list_a = { + 'split1': { + 'some_id': 1, + 'file_name': 'a.jpg' + }, + 'split2': { + 'some_id': 2, + 'file_name': 'b.jpg' + }, + 'split3': { + 'some_id': 3, + 'file_name': 'c.jpg' + } + } + file_list_b = { + 'split4': { + 'some_id': 2, + 'file_name': 'd.jpg' + }, + 'split5': { + 'some_id': 3, + 'file_name': 'e.jpg' + }, + 'split6': { + 'some_id': 6, + 'file_name': 'f.jpg' + } + } + + # Remove duplicates between split A and split B + filtered_file_list_a = deduplicate_splits(file_list_a, file_list_b, reindex=False) + + # Check that the correct IDs have been removed from split A + expected_file_list_a = { + 'split1': { + 'some_id': 1, + 'file_name': 'a.jpg' + }, + } + self.assertDictEqual(filtered_file_list_a, expected_file_list_a) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/amt/src/train.py b/amt/src/train.py new file mode 100644 index 0000000000000000000000000000000000000000..9981753ec1437a3b387a6823cd2e74205e34a527 --- /dev/null +++ b/amt/src/train.py @@ -0,0 +1,184 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +import argparse +# from packaging.version import parse as VersionParse + +import torch +from utils.data_modules import AMTDataModule +from utils.task_manager import TaskManager +from model.init_train import initialize_trainer, update_config +from model.ymt3 import YourMT3 +from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg +from config.vocabulary import program_vocab_presets +from utils.utils import str2bool + +# yapf: disable +parser = argparse.ArgumentParser(description="YourMT3") +# General +parser.add_argument('exp_id', type=str, help='A unique identifier for the experiment is used to resume training. The "@" symbol can be used to load a specific checkpoint.') +parser.add_argument('-p', '--project', type=str, default='ymt3', help='project name') +parser.add_argument('-d', '--data-preset', type=str, default='musicnet_thickstun_ext_em', help='dataset preset (default=musicnet_thickstun_ext_em). 
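The checksum test above compares against a fixed MD5 hex digest, but `utils.utils.get_checksum` itself is not part of this diff. A minimal stand-in consistent with that test (an assumption, not the project's actual helper) would hash the file in chunks:

```python
import hashlib

def md5_checksum(path: str, chunk_size: int = 8192) -> str:
    """Hypothetical stand-in for utils.utils.get_checksum: MD5 hex digest of a file,
    read in chunks so large files do not need to fit in memory."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# Usage mirroring TestGetChecksum:
# with open("temp_test_file.txt", "w") as f:
#     f.write("This is a test file")
# print(md5_checksum("temp_test_file.txt"))
```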
See config/data.py for more options.') +# Intra-stem augmentation +parser.add_argument('-amp', '--random-amp-range', nargs=2, type=float, default=None, help='random amp range for audio augmentation (default=None, using default value defined in config.py). In command line, use -amp 0.6 1.2') +parser.add_argument('-iaug', '--stem-iaug-prob', type=float, default=None, help='intra-stem augmentation probability (default follow config.py). p=1.0 means no intra-stem augmentation (no stems are dropped)') +# Cross-stem augmentation policy +parser.add_argument('-xk', '--xaug-max-k', type=int, default=None, help='max number of external sources used for cross-stem augmentations. Default follows config.py.') +parser.add_argument('-xtau', '--xaug-tau', type=float, default=None, help='exponential decay rate for cross-stem augmentation. Default follows config.py') +parser.add_argument('-xalpha', '--xaug-alpha', type=float, default=None, help='shape parameter for Weibull distribution. set 1.0 for exponential. Default follows config.py') +parser.add_argument('-xnio', '--xaug-no-instr-overlap', type=str2bool, default=None, help='No instrument overlap flag. Default follows config.py') +parser.add_argument('-xndo', '--xaug-no-drum-overlap', type=str2bool, default=None, help='No drum overlap flag. Default follows config.py') +parser.add_argument('-xiaug', '--uhat-intra-stem_augment', type=str2bool, default=None, help='uhat intra-stem augmentation flag. Default follows config.py') +# Post-mix augmentation (post-mixing) +parser.add_argument('-ps', '--pitch-shift-range', nargs=2, type=int, default=None, help='pitch shift range in semitones (default=None). If None, default value defined in config.py. [0, 0] disables pitch shift. In command line, use -ps -2 2') +# Audio configurations +parser.add_argument('-ac', '--audio-codec', type=str, default=None, help='audio codec (default=None). {"spec", "melspec"}. If None, default value defined in config.py will be used.') +parser.add_argument('-hop', '--hop-length', type=int, default=None, help='hop length in frames (default=None). {128, 300} 128 for MT3, 300 for PerceiverTFIf None, default value defined in config.py will be used.') +parser.add_argument('-nmel', '--n-mels', type=int, default=None, help='number of mel bins (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-if', '--input-frames', type=int, default=None, help='number of audio frames for input segment (default=None). If None, default value defined in config.py will be used.') +# Model configurations +parser.add_argument('-sqr', '--sca-use-query-residual', type=str2bool, default=None, help='sca use query residual flag. Default follows config.py') +parser.add_argument('-enc', '--encoder-type', type=str, default=None, help="Encoder type. 't5' or 'perceiver-tf' or 'conformer'. Default is 't5', following config.py.") +parser.add_argument('-dec', '--decoder-type', type=str, default=None, help="Decoder type. 't5' or 'multi-t5'. Default is 't5', following config.py.") +parser.add_argument('-preenc', '--pre-encoder-type', type=str, default='default', help="Pre-encoder type. None or 'conv' or 'default' or 'conv1d_t' or 'conv1d_f' or 'hftt' or 'res3b_hftt'. By default, t5_enc:None, perceiver_tf_enc:conv, conformer:None") +parser.add_argument('-predec', '--pre-decoder-type', type=str, default='default', help="Pre-decoder type. {None, 'linear', 'conv1', 'mlp', 'group_linear'} or 'default'. 
Default is {'t5': None, 'perceiver-tf': 'linear', 'conformer': None}.") +parser.add_argument('-cout', '--conv-out-channels', type=int, default=None, help='Number of filters for pre-encoder conv layer. Default follows "model_cfg" of config.py.') +parser.add_argument('-tenc', '--task-cond-encoder', type=str2bool, default=True, help='task conditional encoder (default=True). True or False') +parser.add_argument('-tdec', '--task-cond-decoder', type=str2bool, default=True, help='task conditional decoder (default=True). True or False') +parser.add_argument('-df', '--d-feat', type=int, default=None, help='Audio feature will be projected to this dimension for Q,K,V of T5 or K,V of Perceiver (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-pt', '--pretrained', type=str2bool, default=False, help='pretrained T5(default=False). True or False') +parser.add_argument('-b', '--base-name', type=str, default="google/t5-v1_1-small", help='base model name (default="google/t5-v1_1-small")') +parser.add_argument('-edr', '--encoder-dropout-rate', type=float, default=None, help='encoder dropout rate (default=None). If None, default rate defined in config will be used.') +parser.add_argument('-ddr', '--decoder-dropout-rate', type=float, default=None, help='decoder dropout rate (default=None). If None, default rate defined in config will be used.') +parser.add_argument('-epe', '--encoder-position-encoding-type', type=str, default='default', help="Positional encoding type of encoder. By default, pre-defined PE for T5 or Perceiver-TF encoder in config.py. For T5: {'sinusoidal', 'trainable'}, conformer: {'rotary', 'trainable'}, Perceiver-TF: {'trainable', 'rope', 'alibi', 'alibit', 'None', '0', 'none', 'tkd', 'td', 'tk', 'kdt'}.") +parser.add_argument('-dpe', '--decoder-position-encoding-type', type=str, default='default', help="Positional encoding type of decoder. By default, pre-defined PE for T5 in config.py. {'sinusoidal', 'trainable'}.") +parser.add_argument('-twe', '--tie-word-embedding', type=str2bool, default=None, help='tie word embedding (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-el', '--event-length', type=int, default=None, help='event length (default=None). If None, default value defined in model cfg of config.py will be used.') +# Perceiver-TF configurations +parser.add_argument('-dl', '--d-latent', type=int, default=None, help='Latent dimension of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-nl', '--num-latents', type=int, default=None, help='Number of latents of Perceiver. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-dpm', '--perceiver-tf-d-model', type=int, default=None, help='Perceiver-TF d_model (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-npb', '--num-perceiver-tf-blocks', type=int, default=None, help='Number of blocks of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py.') +parser.add_argument('-npl', '--num-perceiver-tf-local-transformers-per-block', type=int, default=None, help='Number of local layers per block of Perceiver-TF. On T5, this will be ignored (default=None). 
If None, default value defined in config.py will be used.') +parser.add_argument('-npt', '--num-perceiver-tf-temporal-transformers-per-block', type=int, default=None, help='Number of temporal layers per block of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-atc', '--attention-to-channel', type=str2bool, default=None, help='Attention to channel flag of Perceiver-TF. On T5, this will be ignored (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-ln', '--layer-norm-type', type=str, default=None, help='Layer normalization type (default=None). {"layer_norm", "rms_norm"}. If None, default value defined in config.py will be used.') +parser.add_argument('-ff', '--ff-layer-type', type=str, default=None, help='Feed forward layer type (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.') +parser.add_argument('-wf', '--ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-nmoe', '--moe-num-experts', type=int, default=None, help='Number of experts for MoE (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-kmoe', '--moe-topk', type=int, default=None, help='Top-k for MoE (default=None). If None, default value defined in config.py will be used.') +parser.add_argument('-act', '--hidden-act', type=str, default=None, help='Hidden activation function (default=None). {"gelu", "silu", "relu", "tanh"}. If None, default value defined in config.py will be used.') +parser.add_argument('-rt', '--rotary-type', type=str, default=None, help='Rotary embedding type expressed in three letters. e.g. ppl: "pixel" for SCA and latents, "lang" for temporal transformer. If None, use config.') +parser.add_argument('-rk', '--rope-apply-to-keys', type=str2bool, default=None, help='Apply rope to keys (default=None). If None, use config.') +parser.add_argument('-rp', '--rope-partial-pe', type=str2bool, default=None, help='Whether to apply RoPE to partial positions (default=None). If None, use config.') +# Decoder configurations +parser.add_argument('-dff', '--decoder-ff-layer-type', type=str, default=None, help='Feed forward layer type of decoder (default=None). {"mlp", "moe", "gmlp"}. If None, default value defined in config.py will be used.') +parser.add_argument('-dwf', '--decoder-ff-widening-factor', type=int, default=None, help='Feed forward layer widening factor for decoder MLP/MoE/gMLP (default=None). If None, default value defined in config.py will be used.') +# Task and Evaluation configurations +parser.add_argument('-tk', '--task', type=str, default='mt3_full_plus', help='tokenizer type (default=gm_ext_plus). See config/task.py for more options.') +parser.add_argument('-epv', '--eval-program-vocab', type=str, default=None, help='evaluation program vocabulary (default=None). If None, default vocabulary of the data preset will be used.') +parser.add_argument('-w', '--write-model-output', type=str2bool, default=False, help='write model test output to file (default=False). True or False') +# Trainer configurations +parser.add_argument('-bsz', '--train-batch-size', nargs=2, type=int, default=None, help='train batch size for sub and local (default=None) per GPU. e.g. "-bsz 6 12". 
If None, default value defined in config.py will be used.') +parser.add_argument('-pr','--precision', type=str, default="bf16-mixed", help='precision (default="bf16-mixed") {32, 16, bf16, bf16-mixed}') +parser.add_argument('-st', '--strategy', type=str, default='auto', help='strategy (default=auto). auto or deepspeed or ddp') +parser.add_argument('-sb', '--sync-batchnorm', type=str2bool, default=False, help='sync batchnorm (default=True). True or False') +parser.add_argument('-se', '--train-num-samples-per-epoch', type=int, default=90000, help='number of samples per epoch (default=96000). If None, use the total number of files in multi datasets.') +parser.add_argument('-e', '--max-epochs', type=int, default=None, help='number of max epochs (default is None, which is 1000).') +parser.add_argument('-it', '--max-steps', type=int, default=-1, help='number of max steps (default is -1, disabled). This overrides the number of total steps defined in config.') +parser.add_argument('-vit', '--val-interval', type=int, default=None, help='validation interval (default=None). If None, use the check_val_every_n_epoch defined in config.py') +parser.add_argument('-lr', '--base-learning-rate', type=float, default=None, help='base learning rate (default is 1e-03 for AdamW, and auto for AdaFactor)') +parser.add_argument('-o', '--optimizer', type=str, default='AdamWScale', help='optimizer (default=AdamWScale) or AdaFactor or AdamW or CPUAdam or DAdaptAdam. Only check lowercase.') +parser.add_argument('-s', '--scheduler', type=str, default='cosine', help='scheduler name (default=legacy), constant or legacy or cosine') +parser.add_argument('-n', '--num-nodes', type=int, default=1, help='number of nodes (default=1)') +parser.add_argument('-g', '--num-gpus', type=str, default='auto', help='number of gpus (default="auto")') +parser.add_argument('-wb', '--wandb-mode', type=str, default=None, help='wandb mode for logging (default=None). "disabled" or "online" or "offline". 
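Many of the flags above are parsed with `str2bool`, which is imported from `utils.utils` but not shown in this diff. A typical argparse-compatible implementation looks like the sketch below; the project's actual version may differ.

```python
import argparse

def str2bool(v):
    # Assumed behaviour: accept common truthy/falsy strings, otherwise raise so
    # argparse reports a clear error. Not necessarily identical to utils.utils.str2bool.
    if isinstance(v, bool):
        return v
    if v.lower() in ("yes", "true", "t", "y", "1"):
        return True
    if v.lower() in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"Boolean value expected, got {v!r}")
```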
If None, default value defined in config.py will be used.') +args = parser.parse_args() +# yapf: enable +if torch.__version__ >= "1.13": + torch.set_float32_matmul_precision("high") + +# Initialize trainer +trainer, wandb_logger, dir_info, shared_cfg = initialize_trainer(args, stage='train') + +# Update config with args, including augmentation settings +shared_cfg, audio_cfg, model_cfg = update_config(args, shared_cfg, stage='train') + + +def main(): + # Data preset + if args.data_preset in data_preset_single_cfg: + # convert single preset into multi preset format + data_preset = { + "presets": [args.data_preset], + "eval_vocab": data_preset_single_cfg[args.data_preset]["eval_vocab"], + } + for k in data_preset_single_cfg[args.data_preset].keys(): + if k in ["eval_drum_vocab", "add_pitch_class_metric"]: + data_preset[k] = data_preset_single_cfg[args.data_preset][k] + elif args.data_preset in data_preset_multi_cfg: + data_preset = data_preset_multi_cfg[args.data_preset] + else: + raise ValueError(f"Invalid data preset: {args.data_preset}") + + # Task manager + tm = TaskManager(task_name=args.task, max_shift_steps=int(shared_cfg["TOKENIZER"]["max_shift_steps"])) + print(f"Task: {tm.task_name}, Max Shift Steps: {tm.max_shift_steps}") + + # Vocabulary for validation + if args.eval_program_vocab != None: + eval_program_vocab = program_vocab_presets[args.eval_program_vocab] + else: + eval_program_vocab = data_preset["eval_vocab"] + eval_drum_vocab = data_preset.get("eval_drum_vocab", None) + + dm = AMTDataModule(data_preset_multi=data_preset, + task_manager=tm, + train_num_samples_per_epoch=args.train_num_samples_per_epoch, + audio_cfg=audio_cfg, + **shared_cfg["AUGMENTATION"]) + + model = YourMT3( + audio_cfg=audio_cfg, + model_cfg=model_cfg, + shared_cfg=shared_cfg, + pretrained=args.pretrained, + optimizer_name="CPUAdam" if "offload" in args.strategy.lower() else args.optimizer, + scheduler_name=args.scheduler.lower(), + base_lr=float(args.base_learning_rate) if args.base_learning_rate != None else None, + max_steps=int(args.max_steps), + task_manager=tm, # tokenizer is a member of task_manager + eval_vocab=eval_program_vocab, + eval_drum_vocab=eval_drum_vocab, + write_output_dir=dir_info["lightning_dir"] if args.write_model_output else None, + add_pitch_class_metric=data_preset.get("add_pitch_class_metric", None)) + + # if VersionParse(torch.__version__) >= VersionParse("2.1"): + # model = torch.compile(model, mode="reduce-overhead") + + # Logging config updated by args + if trainer.global_rank == 0: + wandb_logger.experiment.config.update({"audio_cfg": model.audio_cfg}, allow_val_change=True) + wandb_logger.experiment.config.update({"model_cfg": model.model_cfg}, allow_val_change=True) + wandb_logger.experiment.config.update(model.shared_cfg, allow_val_change=True) + + wandb_logger.watch(model, log='gradients', log_freq=5000) + + # last_ckpt_path can be None + if dir_info["last_ckpt_path"] is not None: + checkpoint = torch.load(dir_info["last_ckpt_path"]) + state_dict = checkpoint['state_dict'] + model.load_state_dict(state_dict, strict=False) + trainer.fit(model, datamodule=dm) + else: + trainer.fit(model, ckpt_path=dir_info["last_ckpt_path"], datamodule=dm) + + +if __name__ == "__main__": + main() diff --git a/amt/src/utils/README.md b/amt/src/utils/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1879d8153c72960ec86687fa8202589274ccf8c9 --- /dev/null +++ b/amt/src/utils/README.md @@ -0,0 +1,22 @@ +# YourMT3: Utils + + +## CachedAudioDataset + +```mermaid 
+graph TB + A[Call __getitem__]:::main --> B1(Update cache):::process + A --> B2(Get segments from cache):::process + B1 --> C1[Load & cut audio]:::subprocess + C1 --> C2[Load & cut note events]:::subprocess + C2 --> C3[Augment data]:::subprocess + C3 --> C4[Tokenize & pad events]:::subprocess + C4 --> C5[Save to cache]:::subprocess + B2 --> D1[Return audio segments]:::output + B2 --> D2[Return tokens]:::output + + classDef main fill:#FED7E2,stroke:#000000; + classDef process fill:#FEE2E2,stroke:#000000; + classDef subprocess fill:#E0F0F4,stroke:#000000; + classDef output fill:#F0E6EF,stroke:#000000; +``` diff --git a/amt/src/utils/audio.py b/amt/src/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..02f0ac6edc6df50fd08734d72f7d281625364df3 --- /dev/null +++ b/amt/src/utils/audio.py @@ -0,0 +1,309 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""audio.py""" +import os +import subprocess +import numpy as np +import wave +import math +from typing import Tuple, List +from numpy.lib.stride_tricks import as_strided + + +def load_audio_file(filename: str, + seg_start_sec: float = 0., + seg_length_sec: float = 0., + fs: int = 16000, + dtype: np.dtype = np.float64) -> np.ndarray: + """Load audio file and return the segment of audio.""" + start_frame_idx = int(np.floor(seg_start_sec * fs)) + seg_length_frame = int(np.floor(seg_length_sec * fs)) + end_frame_idx = start_frame_idx + seg_length_frame + + file_ext = filename[-3:] + + if file_ext == 'wav': + with wave.open(filename, 'r') as f: + f.setpos(start_frame_idx) + if seg_length_sec == 0: + x = f.readframes(f.getnframes()) + else: + x = f.readframes(end_frame_idx - start_frame_idx) + + if dtype == np.float64: + x = np.frombuffer(x, dtype=np.int16) / 2**15 + elif dtype == np.float32: + x = np.frombuffer(x, dtype=np.int16) / 2**15 + x = x.astype(np.float32) + elif dtype == np.int16: + x = np.frombuffer(x, dtype=np.int16) + elif dtype is None: + pass + else: + raise NotImplementedError(f"Unsupported dtype: {dtype}") + else: + raise NotImplementedError(f"Unsupported file extension: {file_ext}") + + return x + + +def get_audio_file_info(filename: str) -> Tuple[int, int, int]: + """Get audio file info. + + Args: + filename: path to the audio file + Returns: + fs: sampling rate + n_frames: number of frames + n_channels: number of channels + + """ + file_ext = filename[-3:] + + if file_ext == 'wav': + with wave.open(filename, 'r') as f: + fs = f.getframerate() + n_frames = f.getnframes() + n_channels = f.getnchannels() + else: + raise NotImplementedError(f"Unsupported file extension: {file_ext}") + + return fs, n_frames, n_channels + + +def get_segments_from_numpy_array(arr: np.ndarray, + slice_length: int, + start_frame_indices: List[int], + dtype: np.dtype = np.float32) -> np.ndarray: + """Get random audio slices from numpy array. 
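To make the segment-gathering behaviour of `get_segments_from_numpy_array` concrete, here is a self-contained, copy-based sketch (pure NumPy, not the library function itself): take m fixed-length windows from a `(channels, n_frames)` array at the given start indices.

```python
import numpy as np

def take_segments(arr: np.ndarray, slice_length: int, starts):
    """Sketch: stack fixed-length windows from a (c, n_frames) array.
    Returns an array of shape (len(starts), c, slice_length)."""
    c, n_frames = arr.shape
    out = np.zeros((len(starts), c, slice_length), dtype=np.float32)
    for i, s in enumerate(starts):
        assert s + slice_length <= n_frames, "segment must lie inside the array"
        out[i] = arr[:, s:s + slice_length]
    return out

x = np.arange(2 * 20, dtype=np.float32).reshape(2, 20)   # toy 2-channel signal, 20 frames
segs = take_segments(x, slice_length=5, starts=[0, 7, 15])
print(segs.shape)  # (3, 2, 5)
```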
+ + Args: + arr: numpy array of shape (c, n_frames) + slice_length: length of the slice + start_frame_indices: list of m start frames + Returns: + slices: numpy array of shape (m, c, slice_length) + """ + c, max_length = arr.shape + max_length = arr.shape[1] + m = len(start_frame_indices) + + slices = np.zeros((m, c, slice_length), dtype=dtype) + for i, start_frame in enumerate(start_frame_indices): + end_frame = start_frame + slice_length + assert (end_frame <= max_length - 1) + slices[i, :, :] = arr[:, start_frame:end_frame].astype(dtype) + return slices + + +def slice_padded_array(x: np.ndarray, slice_length: int, slice_hop: int, pad: bool = True) -> np.ndarray: + """ + Slices the input array into overlapping windows based on the given slice length and slice hop. + + Args: + x: The input array to be sliced. + slice_length: The length of each slice. + slice_hop: The number of elements between the start of each slice. + pad: If True, the last slice will be padded with zeros if necessary. + + Returns: + A numpy array with shape (n_slices, slice_length) containing the slices. + """ + num_slices = (x.shape[1] - slice_length) // slice_hop + 1 + remaining = (x.shape[1] - slice_length) % slice_hop + + if pad and remaining > 0: + padding = np.zeros((x.shape[0], slice_length - remaining)) + x = np.hstack((x, padding)) + num_slices += 1 + + shape: Tuple[int, int] = (num_slices, slice_length) + strides: Tuple[int, int] = (slice_hop * x.strides[1], x.strides[1]) + sliced_x = as_strided(x, shape=shape, strides=strides) + + return sliced_x + + +def slice_padded_array_for_subbatch(x: np.ndarray, + slice_length: int, + slice_hop: int, + pad: bool = True, + sub_batch_size: int = 1, + dtype: np.dtype = np.float32) -> np.ndarray: + """ + Slices the input array into overlapping windows based on the given slice length and slice hop, + and pads it to make the output divisible by the sub_batch_size. + + NOTE: This method is currently not used. + + Args: + x: The input array to be sliced, such as (1, n_frames). + slice_length: The length of each slice. + slice_hop: The number of elements between the start of each slice. + pad: If True, the last slice will be padded with zeros if necessary. + sub_batch_size: The desired number of slices to be divisible by. + + Returns: + A numpy array with shape (n_slices, slice_length) containing the slices. + """ + num_slices = (x.shape[1] - slice_length) // slice_hop + 1 + remaining = (x.shape[1] - slice_length) % slice_hop + + if pad and remaining > 0: + padding = np.zeros((x.shape[0], slice_length - remaining), dtype=dtype) + x = np.hstack((x, padding)) + num_slices += 1 + + # Adjust the padding to make n_slices divisible by sub_batch_size + if pad and num_slices % sub_batch_size != 0: + additional_padding_needed = (sub_batch_size - (num_slices % sub_batch_size)) * slice_hop + additional_padding = np.zeros((x.shape[0], additional_padding_needed), dtype=dtype) + x = np.hstack((x, additional_padding)) + num_slices += (sub_batch_size - (num_slices % sub_batch_size)) + + shape: Tuple[int, int] = (num_slices, slice_length) + strides: Tuple[int, int] = (slice_hop * x.strides[1], x.strides[1]) + sliced_x = as_strided(x, shape=shape, strides=strides) + + return sliced_x + + +def pitch_shift_audio(src_audio_file: os.PathLike, + min_pitch_shift: int = -5, + max_pitch_shift: int = 6, + random_microshift_range: tuple[int, int] = (-10, 11)): + """ + Pitch shift audio file using the Sox command-line tool. + + NOTE: This method is currently not used. 
Previously, we used this for + offline augmentation for GuitarSet. + + Args: + src_audio_file: Path to the input audio file. + min_pitch_shift: Minimum pitch shift in semitones. + max_pitch_shift: Maximum pitch shift in semitones. + random_microshift_range: Range of random microshifts to apply in tenths of a semitone. + + Returns: + None + + Raises: + CalledProcessError: If the Sox command fails to execute. + + """ + + # files + src_audio_dir = os.path.dirname(src_audio_file) + src_audio_filename = os.path.basename(src_audio_file).split('.')[0] + + # load source audio + try: + audio = load_audio_file(src_audio_file, dtype=np.int16) + audio = audio / 2**15 + audio = audio.astype(np.float16) + except Exception as e: + print(f"Failed to load audio file: {src_audio_file}. {e}") + return + + # pitch shift audio for each semitone in the range + for pitch_shift in range(min_pitch_shift, max_pitch_shift): + if pitch_shift == 0: + continue + + # pitch shift audio by sox + dst_audio_file = os.path.join(src_audio_dir, f'{src_audio_filename}_pshift{pitch_shift}.wav') + shift_semitone = 100 * pitch_shift + np.random.randint(*random_microshift_range) + + # build Sox command + command = ['sox', src_audio_file, '-r', '16000', dst_audio_file, 'pitch', str(shift_semitone)] + + try: + # execute Sox command and check for errors + subprocess.run(command, check=True) + print(f"Created {dst_audio_file}") + except subprocess.CalledProcessError as e: + print(f"Failed to pitch shift audio file: {src_audio_file}, pitch_shift: {pitch_shift}. {e}") + + +def write_wav_file(filename: str, x: np.ndarray, samplerate: int = 16000) -> None: + """ + Write a mono PCM WAV file from a NumPy array of audio samples. + + Args: + filename (str): The name of the WAV file to be created. + x (np.ndarray): A 1D NumPy array containing the audio samples to be written to the WAV file. + The audio samples should be in the range [-1, 1]. + samplerate (int): The sample rate (in Hz) of the audio samples. 
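The hop-based slicing that `slice_padded_array` (defined earlier in audio.py) performs with `as_strided` can be illustrated with a plain copy-based version. This sketch pads the tail so the final window is complete; the zero-copy striding and the exact padding formula of the original are intentionally left out.

```python
import numpy as np

def slice_with_hop(x: np.ndarray, slice_length: int, slice_hop: int) -> np.ndarray:
    """Copy-based sketch of hop-based slicing over a (1, n_frames) array.
    Pads the tail with zeros so the last window is full-length."""
    n = x.shape[1]
    remaining = (n - slice_length) % slice_hop
    if remaining > 0:
        x = np.hstack([x, np.zeros((x.shape[0], slice_hop - remaining), dtype=x.dtype)])
    starts = range(0, x.shape[1] - slice_length + 1, slice_hop)
    return np.stack([x[0, s:s + slice_length] for s in starts])

audio = np.arange(11, dtype=np.float32)[None, :]          # (1, 11)
windows = slice_with_hop(audio, slice_length=4, slice_hop=3)
print(windows.shape)  # (4, 4): windows start at 0, 3, 6, 9; the last one is zero-padded
```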
+ + Returns: + None + """ + # Set the WAV file parameters + nchannels = 1 # Mono + sampwidth = 2 # 16-bit + framerate = samplerate + nframes = len(x) + + # Scale the audio samples to the range [-32767, 32767] + x_scaled = np.array(x * 32767, dtype=np.int16) + + # Set the buffer size for writing the WAV file + BUFFER_SIZE = 1024 + + # Open the WAV file for writing + with wave.open(filename, "wb") as wav_file: + # Set the WAV file parameters + wav_file.setparams((nchannels, sampwidth, framerate, nframes, "NONE", "NONE")) + + # Write the audio samples to the file in chunks + for i in range(0, len(x_scaled), BUFFER_SIZE): + # Get the next chunk of audio samples + chunk = x_scaled[i:i + BUFFER_SIZE] + + # Convert the chunk of audio samples to a byte string and write it to the WAV file + wav_file.writeframes(chunk.tobytes()) + + # Close the WAV file + wav_file.close() + + +def guess_onset_offset_by_amp_envelope(x, fs=16000, onset_threshold=0.05, offset_threshold=0.02, frame_size=256): + """ Guess onset/offset from audio signal x """ + amp_env = [] + num_frames = math.floor(len(x) / frame_size) + for t in range(num_frames): + lower = t * frame_size + upper = (t + 1) * frame_size - 1 + # Find maximum of each frame and add it to our array + amp_env.append(np.max(x[lower:upper])) + amp_env = np.array(amp_env) + # Find the first index where the amplitude envelope is greater than the threshold + onset = np.where(amp_env > onset_threshold)[0][0] * frame_size + offset = (len(amp_env) - 1 - np.where(amp_env[::-1] > offset_threshold)[0][0]) * frame_size + return onset, offset, amp_env + + +# from pydub import AudioSegment +# def convert_flac_to_wav(input_path, output_path): +# # Load FLAC file using Pydub +# sound = AudioSegment.from_file(input_path, format="flac") + +# # Set the parameters for the output WAV file +# channels = 1 # mono +# sample_width = 2 # 16-bit +# frame_rate = 16000 + +# # Convert the input sound to the specified format +# sound = sound.set_frame_rate(frame_rate) +# sound = sound.set_channels(channels) +# sound = sound.set_sample_width(sample_width) + +# # Save the output WAV file to the specified path +# sound.export(output_path, format="wav") diff --git a/amt/src/utils/augment.py b/amt/src/utils/augment.py new file mode 100644 index 0000000000000000000000000000000000000000..52774016c48d48525f4d4df1c259c52356ecb24a --- /dev/null +++ b/amt/src/utils/augment.py @@ -0,0 +1,743 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. 
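The amplitude-envelope heuristic in `guess_onset_offset_by_amp_envelope` (audio.py above) is easy to check on a synthetic signal: take the per-frame maximum, then threshold it from both ends. A minimal sketch of that idea, with the thresholds and frame size taken from the function's defaults:

```python
import numpy as np

def guess_onset_offset(x, onset_thr=0.05, offset_thr=0.02, frame_size=256):
    """Sketch: per-frame max envelope; first frame above onset_thr gives the onset,
    last frame above offset_thr gives the offset (both in samples)."""
    n_frames = len(x) // frame_size
    amp_env = np.array([np.max(np.abs(x[t * frame_size:(t + 1) * frame_size])) for t in range(n_frames)])
    onset = np.argmax(amp_env > onset_thr) * frame_size
    offset = (len(amp_env) - 1 - np.argmax(amp_env[::-1] > offset_thr)) * frame_size
    return onset, offset, amp_env

fs = 16000
t = np.arange(fs) / fs
x = np.zeros(fs, dtype=np.float32)
x[4000:12000] = 0.5 * np.sin(2 * np.pi * 440 * t[4000:12000])   # a tone from ~0.25 s to ~0.75 s
onset, offset, _ = guess_onset_offset(x)
print(onset / fs, offset / fs)   # roughly 0.24 and 0.74 (frame-quantized)
```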
+"""augment.py""" +import numpy as np +import random +from collections import defaultdict +from typing import Optional, Tuple, Union, Callable, Literal, DefaultDict, Set, Any, Dict, List +from utils.note_event_dataclasses import NoteEvent, NoteEventListsBundle +from utils.note2event import check_event_len_from_bundle, mix_note_event_lists_bundle, separate_by_subunit_programs_from_note_event_lists_bundle +from utils.utils import dict_iterator, extend_dict +from copy import deepcopy + +EPS = 1e-7 +DRUM_PROGRAM = 128 +UNANNOTATED_PROGRAM = 129 + +# ------------------------------------------------------------------------------------- +# shared augmentation helper functions +# ------------------------------------------------------------------------------------- + + +def audio_random_submix_fn(x: np.ndarray, + random_amp_range: Optional[List[float]] = None, + mask: Optional[np.ndarray] = None, + normalize: bool = True, + dtype: np.dtype = np.float32) -> Tuple[np.ndarray, np.ndarray]: + """ + Randomly submix audio. This function supports batch-wise matrix processing. + + Parameters: + - x (np.ndarray): Input audio tensor with shape (b, c, t). + - random_amp_range (List[float], optional): A list containing [min_amp, max_amp]. + Defaults to [0.6, 1.2]. + - mask (np.ndarray, optional): Mask tensor with shape (b, c). Defaults to None. + - dtype (np.dtype): Data type for computations. Defaults to np.float32. + + Returns: + - Tuple[np.ndarray, np.ndarray]: Processed audio (stems, mix). + """ + b, c, t = x.shape + + if random_amp_range is None: + random_amp_range = [0.6, 1.2] + + if len(random_amp_range) == 2: + min_w, max_w = random_amp_range + ws = np.random.uniform(min_w, max_w, size=(b, c)).astype(dtype) + else: + raise ValueError( + f"random_amp_range should be a list of two floats, [min_amp, max_amp] or None, but got {random_amp_range}") + + if mask is not None: + ws *= mask # (b, c) + + processed_audio_stems = x * ws[:, :, np.newaxis] # (b, c, t) + processed_audio_mix = np.sum(processed_audio_stems, axis=1, keepdims=True) # (b, 1, t) + + # Normalize + if normalize is True: + norm_factors = np.max(np.abs(processed_audio_mix), axis=2, keepdims=True) + EPS # (b, 1, 1) + processed_audio_stems /= norm_factors # (b, c, t) + processed_audio_mix /= norm_factors # (b, 1, t) + else: + pass + return processed_audio_stems, processed_audio_mix + + +def audio_random_submix_processor(sampled_data: Dict[str, Any], + random_amp_range: List[float] = [0.6, 1.2], + audio_masks: Optional[List[Optional[np.ndarray]]] = None, + update_audio_segments: bool = True, + create_processed_audio_array: bool = True) -> None: + """Randomly submix audio from sampled data + + Args: + sampled_data: a dictionary containing sampled data. + ['audio_segments']: a list of audio segments with length B, each element with shape (1, num_stems, T) + random_amp_range: a list of two floats, [min_amp, max_amp] + audio_masks: a list of masks. Each mask is binary vector with shape (num_stems,). + update_audio_segments: if True (default), update sampled_data["audio_segments"] in-place. + create_processed_audio_array: if True (default), create a new key "processed_audio_array" in sampled_data for mix audio. + + Returns: + None (processed audio is stored in sampled_data["processed_audio_array"]) + + NOTE: + - This function creates a new key "processed_audio_array" in sampled_data, in-place of `sampled_data`. + - Input audio should exist in sampled_data["audio_segments"]. 
+ - The created sampled_data["processed_audio_array"] has shape of (B, 1, T) + """ + if update_audio_segments is False and create_processed_audio_array is False: + raise ValueError("At least one of update_audio_segments and create_processed_audio_mix should be True.") + + # create a new key "processed_audio" in sampled_data + b = len(sampled_data["audio_segments"]) # sub-batch size + t = sampled_data["audio_segments"][0].shape[2] # audio length + + if create_processed_audio_array is True: + sampled_data["processed_audio_array"] = np.zeros((b, 1, t), dtype=np.float32) + + # loop over each audio segment + if audio_masks is None: + # no audio mask is provided, randomly submix all audio segments + for i, audio_segment in enumerate(sampled_data["audio_segments"]): + processed_audio_stems, processed_audio_mix = audio_random_submix_fn(x=audio_segment, + random_amp_range=random_amp_range, + mask=None) + if create_processed_audio_array is True: + sampled_data["processed_audio_array"][i, :, :] = processed_audio_mix + if update_audio_segments is True: + sampled_data["audio_segments"][i] = processed_audio_stems + + else: + # audio mask is provided, randomly submix audio segments based on the audio mask + for i, (audio_segment, mask) in enumerate(zip(sampled_data["audio_segments"], audio_masks)): + processed_audio_stems, processed_audio_mix = audio_random_submix_fn(x=audio_segment, + random_amp_range=random_amp_range, + mask=mask) + if create_processed_audio_array is True: + sampled_data["processed_audio_array"][i, :, :] = processed_audio_mix + if update_audio_segments is True: + sampled_data["audio_segments"][i] = processed_audio_stems + + +def drop_random_stems_from_bundle(sampled_data: Dict[str, Any], prob: float = 0.7) -> None: + """ + Drop stems with a probability of `prob` from a bundle containing `note_event_segments` and + `audio_segments`. It also update `programs`, and add `has_unannotated` info. This function + serves as a utility for stem-based data augmentation used by `intra_stem_augment_processor` + and `cross_stem_augment_processor`. + + Args: + sampled_data: A dict of sampled data. + prob: The probability of dropping stems from the data. + + Returns: + None. The processed data is stored in-place within the `sampled_data` dictionary. + + Update keys in sampled_data (in-place): + sampled_data["note_event_segments"]: NoteEventListsBundle + sampled_data["audio_segments"]: NoteEventListsBundle + sampled_data["programs_segments"]: a list of list, drum program is 128. updated. + sampled_data["has_unannotated_segments"]: a list of bool, True if unannotated program 129 is in use. Newly added. + + + Removed kyes in sampled_data (in-place): + all other keys except for the above are removed. + + Function execution time: 16ms for bsz=36 with single worker + """ + # Create a deep copy to avoid modifying the original data. 
+ note_event_segments = deepcopy(sampled_data["note_event_segments"]) + has_unannotated = [] # List of bool, True if unannotated program 129 is in use + + for i, (has_stems, note_events, tie_note_events, audio_segment, programs, is_drum) in enumerate( + zip(sampled_data["has_stems_segments"], note_event_segments['note_events'], + note_event_segments['tie_note_events'], sampled_data["audio_segments"], + sampled_data["programs_segments"], sampled_data["is_drum_segments"])): + + # Make sure that programs is np.ndarray + if not isinstance(programs, np.ndarray): + programs = np.array(programs) + + if has_stems is True and UNANNOTATED_PROGRAM not in programs: + # Get unique and actual presence of instruments. 128 means drums, 129 means unannotated. + uniq_programs = np.unique([ne.program if not ne.is_drum else 128 for ne in (tie_note_events + note_events)]) + + # Debug + if DRUM_PROGRAM in uniq_programs: + assert DRUM_PROGRAM in programs, "Drum program 128 not in programs" + if is_drum.any(): + assert DRUM_PROGRAM in programs, "Drum program 128 not in programs" + + # Vectorized random choice for each unique_program + rand_sel_prgs = uniq_programs[np.random.rand(len(uniq_programs)) < prob] + if len(rand_sel_prgs) == 0 and len(uniq_programs) != 0: # Make sure at least one program is active + rand_sel_prgs = np.random.choice(uniq_programs, size=1) + programs_mask = np.isin(programs, rand_sel_prgs).astype(np.int32) + drums_mask = programs_mask * is_drum # NOTE: if drums are not annotated as program 128, this would not work properly + _programs_in_use = programs[programs_mask == 1] + _drum_in_use = np.any(drums_mask == 1) # True if any drum is in use + + # Drop note_events and tie_note_events in-place + note_events[:] = [ + ne for ne in note_events + if (not ne.is_drum and ne.program in _programs_in_use) or (ne.is_drum and _drum_in_use) + ] + tie_note_events[:] = [ne for ne in tie_note_events if ne.program in _programs_in_use] + + # Drop stems from audio_segments, update programs_segments + sampled_data["audio_segments"][i] = audio_segment[:, programs_mask == 1, :] + sampled_data["programs_segments"][i] = programs[programs_mask == 1] + + # Create has_unannotated + has_unannotated.append(False) + + elif has_stems is True and UNANNOTATED_PROGRAM in programs: + # If unannotated program is included in programs, we only drop 129 with a probability of `prob`. + # `note_event_segments` remains the same. + # TODO: Actually, we can drop any annoated programs, but current datasets are not the case. + uniq_programs = np.unique([ne.program if not ne.is_drum else 128 for ne in (tie_note_events + note_events)]) + if np.random.rand() > prob: + # keep unannotated program, and this will not allow further cross-stem augmentation. + has_unannotated.append(True) + else: + # drop unannotated program + assert UNANNOTATED_PROGRAM not in uniq_programs # 129 is not included here... + sampled_data["audio_segments"][i] = audio_segment[:, programs != 129, :] + sampled_data["programs_segments"][i] = programs[programs != 129] + has_unannotated.append(False) + + elif has_stems is False and UNANNOTATED_PROGRAM in programs: + # No stems, but has unannoted program: cannot be used for cross-stem augmentation. + has_unannotated.append(True) + + else: + # No stems, no unannotated program: nothing to do. 
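The keep-rule at the heart of `drop_random_stems_from_bundle` can be shown in isolation: every unique program survives independently with probability `prob`, and if nothing survives, one program is resampled so at least one stem remains. A small sketch of just that rule (the bundle bookkeeping is omitted):

```python
import numpy as np

def sample_surviving_programs(uniq_programs: np.ndarray, prob: float = 0.7) -> np.ndarray:
    """Sketch: keep each unique program with probability `prob`, never return empty."""
    keep = uniq_programs[np.random.rand(len(uniq_programs)) < prob]
    if len(keep) == 0 and len(uniq_programs) > 0:
        keep = np.random.choice(uniq_programs, size=1)
    return keep

programs = np.array([0, 33, 48, 128])   # e.g. piano, bass, strings, drums (128 = drum "program")
for _ in range(3):
    print(sample_surviving_programs(programs, prob=0.7))
```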
+ has_unannotated.append(False) + + # Update sampled_data in-place + sampled_data["note_event_segments"] = note_event_segments + sampled_data["has_unannotated_segments"] = has_unannotated + + # Remove all other keys except for the above, because they are not used in the downstream pipeline. + keys_to_remove = ['is_drum_segments', 'has_stems_segments'] + for key in keys_to_remove: + del sampled_data[key] + + +# ------------------------------------------------------------------------------------- +# intra stem augmentation processor +# ------------------------------------------------------------------------------------- +def intra_stem_augment_processor(sampled_data: Dict[str, Any], + random_amp_range: List[float] = [0.6, 1.2], + prob: float = 0.7, + update_audio_segments: bool = True, + submix_audio: bool = True) -> None: + """ + Intra_stem_augmentation + + Shape of input: + sampled_data: + ['note_event_segments']['note_events']: + List[List[NoteEvent]] with length B, each element is a list of NoteEvent + with length num_notes + ['note_event_segments']['tie_note_events']: + List[List[NoteEvent]] with length B, each element is a list of NoteEvent + with length num_tie_notes + ['note_event_segments']['start_times']: + List[float] with length B + + ['audio_segments']: + np.ndarray with shape(B, num_stems, T) + ['programs_segments']: + np.ndarray with shape(num_stems,) + ['is_drum_segments']: + np.ndarray with shape(num_stems,) + ['has_stems_segments']: + List[bool] with length B + + Output (modified in-place): + sampled_data: + ['note_event_segments']: + ['note_events']: + ['tie_note_events']: + ['start_times']: (not modified) + ['audio_segments']: + np.ndarray with shape(1, num_stems, T) + ['processed_audio_array']: # if submix_audio is True + np.ndarray with shape(B, 1, T) + ['programs_segments']: + List[np.ndarray] with length B, each element is a np.ndarray with shape(num_stems,) + ['has_unannotated_segments']: + List[bool] with length B + Execution time: 27 ms for bsz=36 with single worker, including submix audio + """ + + # Randomly drop stems: + # - p (0. < p <= 1.) chances to keep each stem, at least one non-drum is guaranteed to be kept. + # - This method modifies the input 'note_event_segments' in-place. + drop_random_stems_from_bundle(sampled_data, prob=prob) + + # Audio processing + if submix_audio is True: + # Randomly submix audio, and update audio_segments in-place with random amplitude applied. + audio_random_submix_processor(sampled_data=sampled_data, + random_amp_range=random_amp_range, + audio_masks=None, + update_audio_segments=True, + create_processed_audio_array=True) # mix + # assert "processed_audio_array" in sampled_data.keys() + else: + # NOTE: This is used within the cross-stem augmentation pipeline. + pass + + +# ------------------------------------------------------------------------------------- +# cross-stem augmentation helper functions +# ------------------------------------------------------------------------------------- +def combined_survival_and_stop(max_k: int = 5, tau: float = 0.3, alpha: float = 1.0) -> Tuple[np.ndarray, np.ndarray]: + """ + Compute the survival function and prob_stop for exponential or Weibull distributions based on the value of alpha. + - S(k) represents the probability of "surviving" up to k-th trial. + - P_stop(k), the stopping probability at trial k is the difference between the survival probabilities at + k-1 and k. + + Parameters: + - max_k (int) : Maximum number of trials. k=0, 1, ..., max_k. k=0 means no cross-stem augmentation. 
+ - tau (float) : Scale parameter. Represents average time to the first failure for exponential distribution. + For Weibull distribution, it influences the spread and shape of the distribution. + - alpha (float) : Shape parameter. If alpha=1, the function reduces to exponential distribution. + Otherwise, it represents the Weibull distribution. + + Returns: + - survival (array-like) : Computed survival function values. + - prob_stop (array-like) : Computed stop probabilities. + + Example 1: + >>> survival_exp, stop_exp = combined_survival_and_stop(max_k=5, tau=0.3, alpha=1.0) + Exponential Survival: [1. 0.74081822 0.54881164 0.40656966 0.30119421 0.22313016] + Exponential Stop Prob: [0.22313016 0.25918178 0.19200658 0.14224198 0.10537545 0.07806405] + + Example 2: + max_k = 5 + survival_exp, stop_exp_03 = combined_survival_and_stop(max_k, 0.3, 1) + survival_weibull, stop_weibull = combined_survival_and_stop(max_k, 0.3, 1.5) + + import matplotlib.pyplot as plt + plt.plot(range(max_k+1), list(stop_exp_03), 'o-', label='Exponential (tau=0.3)') + plt.plot(range(max_k+1), list(stop_weibull), 's-', label='Weibull (tau=0.3, alpha=1.5)') + plt.title("Stop Probabilities"); plt.xlabel("k"); plt.ylabel("Probability") + plt.legend(); plt.grid(True); plt.show() + + References: + - Weibull, Waloddi. "A statistical distribution function of wide applicability." Journal of applied mechanics (1951). + + """ + + # Generate k values based on max_k + k_values = np.arange(max_k + 1) + + # Calculate survival function + if alpha == 1: + survival = np.exp(-k_values * tau) + else: + survival = np.exp(-np.power(k_values * tau, alpha)) + + # Calculate prob_stop and normalize + prob_stop_at_k = -np.diff(np.append(survival, 0.)) + return survival, prob_stop_at_k # (max_k+1,), (max_k+1,) + + +def deterministic_random_ux_sampler(prob_stop_at_k, bsz) -> np.ndarray: + """ + Deterministic random sampler for sampling U\X for cross-stem augmentation. + + Args: + prob_stop_at_k (array-like): Probabilities of stopping at k-th trial. + bsz (int) : Batch size. Usually local batch size. + + Returns: + ux_count_per_item (array-like): Number of U\X to sample for each item in the batch. + + Example: + >>> max_k = 5; tau = 0.3; alpha = 1.0; bsz = 20 + >>> _, prob_stop_at_k = combined_survival_and_stop(max_k, tau, alpha) + prob_stop_at_k: [0.22313016 0.25918178 0.19200658 0.14224198 0.10537545 0.07806405] + >>> np.random.choice(np.arange(max_k+1), size=bsz, p=prob_stop_at_k) + array([1, 4, 1, 3, 0, 3, 0, 2, 5, 0]) + + """ + ux_count_per_item = np.random.choice(np.arange(len(prob_stop_at_k)), size=bsz, p=prob_stop_at_k) + return ux_count_per_item + + +def check_programs_overlap(list_programs: List[np.ndarray], programs: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Check if there is any instrument overlap between two lists of programs. 
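Returning to `combined_survival_and_stop` above, the numbers are easy to reproduce: the survival curve is S(k) = exp(-(k·tau)^alpha) and the stop probabilities come from differencing S with a trailing zero, so entry k can be read as the probability of drawing exactly k extra sources, with all mass beyond max_k folded into the last entry. The snippet below mirrors the source computation and feeds the result to the same kind of sampler as `deterministic_random_ux_sampler`.

```python
import numpy as np

max_k, tau, alpha = 5, 0.3, 1.0
k = np.arange(max_k + 1)
survival = np.exp(-np.power(k * tau, alpha))      # S(0..5): 1.0, 0.741, 0.549, 0.407, 0.301, 0.223
prob_stop = -np.diff(np.append(survival, 0.0))    # S(k) - S(k+1); the tail mass S(max_k) ends up last

print(np.round(survival, 3))
print(np.round(prob_stop, 3), prob_stop.sum())    # sums to 1.0, so it is a valid sampling distribution

# Number of external sources per batch item, as in deterministic_random_ux_sampler:
counts = np.random.choice(np.arange(max_k + 1), size=8, p=prob_stop)
print(counts)                                      # e.g. [1 0 2 0 3 0 1 5]
```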
+ + Example: + >>> list_programs = np.array([np.array([1,2,3]), np.array([5,6])], dtype=object) + >>> print(check_programs_overlap(list_programs, np.array([np.array([1,7])], dtype=object))) # Expected [1] + >>> print(check_programs_overlap(list_programs, np.array([np.array([])], dtype=object))) # Expected [] + """ + list_programs_set = set(item for sublist in list_programs for item in sublist) + overlaps = [p for p in programs if p in list_programs_set] + uniq_prg_mask = np.array([p not in list_programs_set for p in programs]) + return np.array(overlaps), uniq_prg_mask + + +def regroup_program_and_audio_by_minimal_shared_subunits( + gathered_programs: List[np.ndarray], + gathered_audio_array: List[np.ndarray], + max_num_groups: Optional[int] = None +) -> Tuple[List[List[int]], DefaultDict[Tuple[int, ...], List[Tuple[int, int]]]]: + # Check if each audio has stems + gathered_has_stem = [ + audio_array.shape[1] > 1 for programs, audio_array in zip(gathered_programs, gathered_audio_array) + ] + + # Create a dictionary for mapping audio to programs + audio2prg = defaultdict(list) + for i, programs in enumerate(gathered_programs): + for j, value in enumerate(programs): + if gathered_has_stem[i] is True: + audio2prg[(i, j)].append(value) + else: + audio2prg[(i, 0)].append(value) + grouped_prg2audio = defaultdict(list) + for k_tuple, v_list in audio2prg.items(): + grouped_prg2audio[tuple(sorted(v_list))].append(k_tuple) + # defaultdict(list, + # {(61, 69, 71, 72): [(0, 0)], + # (128,): [(1, 0)], ...} + + # Limit the number of groups + if max_num_groups is not None: + # randomly merge groups + while len(grouped_prg2audio) > max_num_groups: + # randomly select two groups to merge + k1, k2 = random.sample(list(grouped_prg2audio.keys()), 2) + grouped_prg2audio[k1].extend(grouped_prg2audio[k2]) + del grouped_prg2audio[k2] + + grouped_programs = list(grouped_prg2audio.keys()) + return grouped_programs, grouped_prg2audio # (List[Tuple[int]], DefaultDict[Tuple[int], List[int]]) + + +def audio_random_submix_by_regroup_program_processor(gathered_programs: List[np.ndarray], + gathered_audio_array: np.ndarray, + submix_random_amp_range: List[float] = [0.9, 1.0], + max_num_stems: int = 12) -> Tuple[List[Tuple[int]], np.ndarray]: + """Regroup programs into subunit programs, and submix regrouped audio arrays + Return: + grouped_programs: List[Tuple[int]] + submix_audio_array: np.ndarray with shape (1, num_grouped_submix_audio, T) + """ + + # Regroup programs into subunit programs + grouped_programs, grouped_prg2audio = regroup_program_and_audio_by_minimal_shared_subunits( + gathered_programs, gathered_audio_array, max_num_groups=max_num_stems) + + # Submix subunit audio arrays, based on the regrouped programs + n_frames = gathered_audio_array[0].shape[2] + submix_audio_array = np.zeros((1, max_num_stems, n_frames), dtype=np.float32) + for i, prgs in enumerate(grouped_programs): + audio_ids = grouped_prg2audio[prgs] # id of gathered_audio_array, e.g.:[(i,j),...] 
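The overlap test used when gathering external sources (`check_programs_overlap` above) reduces to a set intersection plus a boolean mask over the incoming programs. A standalone sketch with the same inputs as the docstring example:

```python
import numpy as np

def programs_overlap(gathered: list, incoming: np.ndarray):
    """Sketch: programs in `incoming` that already appear in any gathered array,
    plus a mask selecting the incoming programs that are still unique."""
    seen = set(int(p) for arr in gathered for p in arr)
    overlaps = np.array([p for p in incoming if int(p) in seen])
    unique_mask = np.array([int(p) not in seen for p in incoming])
    return overlaps, unique_mask

gathered = [np.array([1, 2, 3]), np.array([5, 6])]
print(programs_overlap(gathered, np.array([1, 7])))    # (array([1]), array([False,  True]))
print(programs_overlap(gathered, np.array([128, 2])))  # (array([2]), array([ True, False]))
```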
+ if len(audio_ids) == 1: + # no need to submix, already subunits + src_idx, stem_idx = audio_ids[0] + submix_audio_array[:, i, :] = gathered_audio_array[src_idx][:, [stem_idx], :] + else: + # submix audio from elements of subunit programs + _submix_audio_list = [gathered_audio_array[src_idx][:, [stem_idx], :] for (src_idx, stem_idx) in audio_ids] + _submix_audio_arr = np.concatenate(_submix_audio_list, axis=1, dtype=np.float32) # (1, C, T) + _, _submix_audio_arr = audio_random_submix_fn(_submix_audio_arr, + random_amp_range=submix_random_amp_range, + normalize=False) + submix_audio_array[:, i, :] = _submix_audio_arr + return [list(prgs) for prgs in grouped_programs], submix_audio_array + + +# ------------------------------------------------------------------------------------- +# cross stem augmentation processor +# ------------------------------------------------------------------------------------- +def cross_stem_augment_processor( + sampled_data: Dict[str, Any], + sampled_ids: np.ndarray, + get_rand_segments_from_cache_fn: Callable, + random_amp_range: List[float] = [0.6, 1.2], + stem_iaug_prob: float = 0.7, + stem_xaug_policy: Dict = { + "max_k": 3, # max number of external sources used for cross-stem augmentations + "tau": 0.3, # exponential decay rate for cross-stem augmentation + "alpha": 1.0, # shape parameter for Weibull distribution. set 1.0 for exponential. + "max_subunit_stems": 12, # the number of subunit stems to be reduced to + "p_include_singing": + 0.8, # probability of including singing for cross augmented examples. if None, use base probaility. + "no_instr_overlap": True, + "no_drum_overlap": True, + "uhat_intra_stem_augment": True, + }, + max_l: int = 1024, + precomputed_prob_stop_at_k: Optional[np.array] = None, + mix_audio: bool = True, + create_subunit_note_events: bool = False) -> None: + """ + Cross-stem augmentation + + Args: + sampled_data: a dictionary containing sampled data. + ['note_event_segments']: a list of NoteEventListsBundle with length B + ['audio_segments']: a list of audio segments with length B, each element with shape (1, num_stems, T) + ['programs_segments']: a list of programs with length B, each element with shape (num_stems,) + ['has_unannotated_segments']: a list of bool with length B + sampled_ids: a numpy array of sampled ids used in sampled_data. (B,) + get_rand_segments_from_cache_fn: a function for getting random segments from cache. + random_amp_range: a list of two floats, [min_amp, max_amp] + stem_iaug_prob: a float, probability of intra-stem augmentation + stem_xaug_policy: a dictionary of cross-stem augmentation policy + - max_k (int) : Maximum number of trials. k=0, 1, ..., max_k. k=0 means no cross-stem augmentation. + - tau (float) : Scale parameter. Represents average time to the first failure for exponential distribution. + For Weibull distribution, it influences the spread and shape of the distribution. + - alpha (float) : Shape parameter. If alpha=1, the function reduces to exponential distribution. + Otherwise, it represents the Weibull distribution. + - max_subunit_stems (int): Maximum number of subunit stems. If larger, they are reduced to this number + by submix. Default: 12 + - p_include_singing (float): Probability of including singing for cross augmented examples. If None, use + base probaility. + - no_instr_overlap (bool): If True, do not allow instrument overlap between X and U\X. + - no_drum_overlap (bool): If True, do not allow drum overlap between X and U\X. 
+            - uhat_intra_stem_augment (bool): If True, apply intra-stem augmentation to U\X.
+        max_l: an int, maximum number of note events in a note event list. Default: 1024
+        precomputed_prob_stop_at_k: a numpy array of precomputed stop probabilities. If None, it is computed on every call.
+        mix_audio: a bool, if True, mix audio from X and U\X. Default: True
+        create_subunit_note_events: a bool, if True, create subunit note events. This is necessary for multi-channel
+            decoder training. Default is False.
+
+    Returns:
+        None (processed data is stored in-place within the `sampled_data` dictionary)
+
+    Update keys in sampled_data (in-place):
+        sampled_data["subunit_programs_segments"]: List[List[np.ndarray]], with length B
+        sampled_data["subunit_note_event_segments"]: List[NoteEventListsBundle], with length B
+        sampled_data["subunit_audio_array"]: np.ndarray with shape (B, max_subunit_stems, T)
+        sampled_data["programs_segments"]: List[np.ndarray], with length B
+        sampled_data["note_event_segments"]: NoteEventListsBundle
+        sampled_data["has_unannotated_segments"]: List[bool], with length B
+        sampled_data["processed_audio_array"]: np.ndarray with shape (B, 1, T)
+
+    Removed keys in sampled_data (in-place):
+        all other keys except for the above are removed.
+    """
+    # Setup parameters
+    max_k = stem_xaug_policy["max_k"]
+    tau = stem_xaug_policy["tau"]
+    alpha = stem_xaug_policy.get("alpha", 1.0)
+    max_subunit_stems = stem_xaug_policy.get("max_subunit_stems", 12)
+    p_include_singing = stem_xaug_policy.get("p_include_singing", None)
+    no_instr_overlap = stem_xaug_policy["no_instr_overlap"]
+    no_drum_overlap = stem_xaug_policy["no_drum_overlap"]
+    uhat_intra_stem_augment = stem_xaug_policy["uhat_intra_stem_augment"]
+    bsz = len(sampled_ids)  # local batch size
+    n_frames = sampled_data["audio_segments"][0].shape[2]
+
+    if precomputed_prob_stop_at_k is None:
+        _, prob_stop_at_k = combined_survival_and_stop(max_k, tau, alpha)
+    else:
+        prob_stop_at_k = precomputed_prob_stop_at_k
+
+    ux_count_per_item = deterministic_random_ux_sampler(prob_stop_at_k, bsz)
+    ux_count_sum = int(np.sum(ux_count_per_item))
+
+    # X_in: sampled_data, to which intra-stem augmentation has already been applied
+
+    # U\X: ux_sampled_data, complement of X in U
+    ux_sampled_data, _ = get_rand_segments_from_cache_fn(
+        num_segments=ux_count_sum,
+        use_ordered_read_pos=False,  # fully random sampling of segments from the cache
+        sample_excluding_ids=sampled_ids)
+
+    # Randomly drop stems from U\X, and update audio stems without submixing audio.
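+    # (Illustration only, values hypothetical: with bsz=4 and max_k=3, deterministic_random_ux_sampler
+    #  might return ux_count_per_item = [2, 0, 1, 1], so ux_count_sum = 4 extra segments are drawn
+    #  above, and each of them goes through the same intra-stem augmentation here before being
+    #  attached to its target item of X.)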
+ if uhat_intra_stem_augment is True: + intra_stem_augment_processor(sampled_data=ux_sampled_data, + random_amp_range=random_amp_range, + prob=stem_iaug_prob, + update_audio_segments=True, + submix_audio=False) + + # Loop for creating X_hat + iter_ux = iter( + zip( + ux_sampled_data['audio_segments'], + dict_iterator(ux_sampled_data['note_event_segments']), + ux_sampled_data['programs_segments'], + ux_sampled_data['has_unannotated_segments'], + )) + iter_x_in = iter( + zip( + sampled_data['audio_segments'], + dict_iterator(sampled_data['note_event_segments']), + sampled_data['programs_segments'], + sampled_data['has_unannotated_segments'], + )) + x_hat = { + "subunit_programs_segments": [], # List[List[np.ndarray]], with length B + "subunit_note_event_segments": [], # List[NoteEventListsBundle], with length B + "subunit_audio_array": np.zeros((bsz, max_subunit_stems, n_frames), + dtype=np.float32), # (B, max_submix_stems, T) + "programs_segments": [], # List[np.ndarray], with length B + "note_event_segments": { + "note_events": [], + "tie_note_events": [], + "start_times": [] + }, # NoteEventListsBundle + "has_unannotated_segments": [], # List[bool], with length B + "processed_audio_array": np.zeros((bsz, 1, n_frames), dtype=np.float32), # mixed audio array, B, 1, T) + } + + for i, (audio_array, ne_bundle, programs, has_unannotated) in enumerate(iter_x_in): + num_ux_samples = ux_count_per_item[i] + if num_ux_samples > 0 and has_unannotated is False: + # gather the main source and k external sources + gathered_programs = [programs] + gathered_ne_bundle = ne_bundle # mutable, but ok because `dict_iterator` yields new dict + gathered_audio_array = [audio_array] + + for k in range(num_ux_samples): + # Get next external source + ex_audio_array, ex_ne_bundle, ex_programs, ex_has_unannotated = next(iter_ux) + ex_prg_mask = None # None: no need to mask external programs + ex_has_stem = bool(ex_audio_array.shape[1] > 1) + """Criteria for skipping sources""" + if ex_has_unannotated is True: + continue + """Criteria for instrument overlap and drum overlap """ + instr_overlap, uniq_ex_prg_mask = check_programs_overlap(gathered_programs, ex_programs) + if no_instr_overlap is True and len(instr_overlap) > 0: + if np.any(uniq_ex_prg_mask) and ex_has_stem is True: + # mask out non-unique external programs + ex_prg_mask = uniq_ex_prg_mask + else: + # print(i, k, num_ux_samples, ex_programs, + # 'Warning: no unique external programs, skip this source') + continue # no unique external programs, skip this source + else: + # programs is already unique or don't care about overlap + pass + + if no_drum_overlap is True and no_instr_overlap is False and DRUM_PROGRAM in instr_overlap: + non_drum_ex_prg_mask = np.array([prg != DRUM_PROGRAM for prg in ex_programs]) + if np.any(non_drum_ex_prg_mask): + # mask only drum external programs + ex_prg_mask = non_drum_ex_prg_mask + else: + # print(i, k, num_ux_samples, ex_programs, + # 'Warning: no non-drum external programs, skip this source') + continue # drum overlapped, but no non-drum programs, skip this source + else: + pass + """Criteria for stopping iteration with respect to max length""" + if check_event_len_from_bundle(gathered_ne_bundle, ex_ne_bundle, max_len=max_l) is False: + # print(i, k, num_ux_samples, 'Warning: max length reached, stop iteration') + break + + # Apply mask and gather + if ex_prg_mask is None: + gathered_programs.append(ex_programs) + extend_dict(gathered_ne_bundle, ex_ne_bundle) + gathered_audio_array.append(ex_audio_array) + else: + # apply 
mask to external programs, and add to list + ex_programs = ex_programs[ex_prg_mask] + gathered_programs.append(ex_programs) + + # drop note_events with masked programs, and extend dictionary + _ex_has_drum = np.any(ex_programs == DRUM_PROGRAM) + ex_ne_bundle["note_events"][0] = [ + ne for ne in ex_ne_bundle["note_events"][0] + if (not ne.is_drum and ne.program in ex_programs) or (ne.is_drum and _ex_has_drum) + ] + ex_ne_bundle["tie_note_events"][0] = [ + ne for ne in ex_ne_bundle["tie_note_events"][0] if ne.program in ex_programs + ] + extend_dict(gathered_ne_bundle, ex_ne_bundle) + + # apply mask to external audio_array, and add to list + gathered_audio_array.append(ex_audio_array[:, ex_prg_mask, :]) + + # print(gathered_programs) + # Regroup gathered programs, and cresate submix by subunits programs + subunit_programs, subunit_audio_array = audio_random_submix_by_regroup_program_processor( + gathered_programs, gathered_audio_array, max_num_stems=max_subunit_stems) + mixed_ne_bundle = mix_note_event_lists_bundle(gathered_ne_bundle, + sort=True, + start_time_to_zero=True, + use_deepcopy=True) #False) + + if create_subunit_note_events is True: + subunit_ne_bundle = separate_by_subunit_programs_from_note_event_lists_bundle(mixed_ne_bundle, + subunit_programs, + start_time_to_zero=False, + sort=True) + else: + subunit_ne_bundle = None + x_hat["subunit_note_event_segments"].append(subunit_ne_bundle) + + x_hat["subunit_programs_segments"].append(subunit_programs) + x_hat["subunit_audio_array"][i, :subunit_audio_array.shape[1], :] = subunit_audio_array # (B, C, T) + + x_hat["programs_segments"].append(np.concatenate(gathered_programs, axis=0)) + extend_dict(x_hat["note_event_segments"], mixed_ne_bundle) + x_hat["has_unannotated_segments"].append(has_unannotated) + else: + num_stems = audio_array.shape[1] + if num_stems > max_subunit_stems: + # If num_stems exceeds max_subunit_stems, randomly select max_subunit_stems stems + subunit_programs, subunit_audio_array = audio_random_submix_by_regroup_program_processor( + [programs], [audio_array], max_num_stems=max_subunit_stems) + else: + subunit_programs = [programs] + subunit_audio_array = audio_array + x_hat["subunit_programs_segments"].append(subunit_programs) + x_hat["subunit_audio_array"][i, :subunit_audio_array.shape[1], :] = subunit_audio_array + + if create_subunit_note_events is True: + subunit_ne_bundle = separate_by_subunit_programs_from_note_event_lists_bundle(ne_bundle, + subunit_programs, + start_time_to_zero=True, + sort=True) + else: + subunit_ne_bundle = None + x_hat["subunit_note_event_segments"].append(subunit_ne_bundle) + + x_hat["programs_segments"].append(programs) + extend_dict(x_hat["note_event_segments"], ne_bundle) + x_hat["has_unannotated_segments"].append(has_unannotated) + + # Mix subunit audio and update subunit audio arrays + if mix_audio is True: + amp_applied_stem_arr, mix_audio_arr = audio_random_submix_fn(x_hat["subunit_audio_array"], + random_amp_range=random_amp_range, + mask=None, + normalize=True) + x_hat["subunit_audio_array"] = amp_applied_stem_arr # (B, C, T) + x_hat["processed_audio_array"] = mix_audio_arr # (B, 1, T) + + # Update sampled_data in-place + sampled_data["subunit_programs_segments"] = x_hat["subunit_programs_segments"] + sampled_data["subunit_note_event_segments"] = x_hat["subunit_note_event_segments"] + sampled_data["subunit_audio_array"] = x_hat["subunit_audio_array"] + sampled_data["programs_segments"] = x_hat["programs_segments"] + sampled_data["note_event_segments"] = 
x_hat["note_event_segments"] + sampled_data["has_unannotated_segments"] = x_hat["has_unannotated_segments"] + sampled_data["processed_audio_array"] = x_hat["processed_audio_array"] + del sampled_data["audio_segments"] diff --git a/amt/src/utils/data_modules.py b/amt/src/utils/data_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b482a8b5e7464272c3a0e9c77bdb1f9191dd39 --- /dev/null +++ b/amt/src/utils/data_modules.py @@ -0,0 +1,204 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +""" data_modules.py """ +from typing import Optional, Dict, List, Any +import os +import numpy as np +from pytorch_lightning import LightningDataModule +from pytorch_lightning.utilities import CombinedLoader +from utils.datasets_train import get_cache_data_loader +from utils.datasets_eval import get_eval_dataloader +from utils.datasets_helper import create_merged_train_dataset_info, get_list_of_weighted_random_samplers +from utils.task_manager import TaskManager +from config.config import shared_cfg +from config.config import audio_cfg as default_audio_cfg +from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg + + +class AMTDataModule(LightningDataModule): + + def __init__( + self, + data_home: Optional[os.PathLike] = None, + data_preset_multi: Dict[str, Any] = { + "presets": ["musicnet_mt3_synth_only"], + }, # only allowing multi_preset_cfg. single_preset_cfg should be converted to multi_preset_cfg + task_manager: TaskManager = TaskManager(task_name="mt3_full_plus"), + train_num_samples_per_epoch: Optional[int] = None, + train_random_amp_range: List[float] = [0.6, 1.2], + train_stem_iaug_prob: Optional[float] = 0.7, + train_stem_xaug_policy: Optional[Dict] = { + "max_k": 3, + "tau": 0.3, + "alpha": 1.0, + "max_subunit_stems": 12, # the number of subunit stems to be reduced to this number of stems + "p_include_singing": + 0.8, # probability of including singing for cross augmented examples. if None, use base probaility. + "no_instr_overlap": True, + "no_drum_overlap": True, + "uhat_intra_stem_augment": True, + }, + train_pitch_shift_range: Optional[List[int]] = None, + audio_cfg: Optional[Dict] = None) -> None: + super().__init__() + + # check path existence + if data_home is None: + data_home = shared_cfg["PATH"]["data_home"] + if os.path.exists(data_home): + self.data_home = data_home + else: + raise ValueError(f"Invalid data_home: {data_home}") + self.preset_multi = data_preset_multi + self.preset_singles = [] + # e.g. 
[{"dataset_name": ..., "train_split": ..., "validation_split":...,}, {...}] + for dp in self.preset_multi["presets"]: + if dp not in data_preset_single_cfg.keys(): + raise ValueError("Invalid data_preset") + self.preset_singles.append(data_preset_single_cfg[dp]) + + # task manager + self.task_manager = task_manager + + # train num samples per epoch, passed to the sampler + self.train_num_samples_per_epoch = train_num_samples_per_epoch + assert shared_cfg["BSZ"]["train_local"] % shared_cfg["BSZ"]["train_sub"] == 0 + self.num_train_samplers = shared_cfg["BSZ"]["train_local"] // shared_cfg["BSZ"]["train_sub"] + + # train augmentation parameters + self.train_random_amp_range = train_random_amp_range + self.train_stem_iaug_prob = train_stem_iaug_prob + self.train_stem_xaug_policy = train_stem_xaug_policy + self.train_pitch_shift_range = train_pitch_shift_range + + # train data info + self.train_data_info = None # to be set in setup() + + # validation/test max num of files + self.val_max_num_files = data_preset_multi.get("val_max_num_files", None) + self.test_max_num_files = data_preset_multi.get("test_max_num_files", None) + + # audio config + self.audio_cfg = audio_cfg if audio_cfg is not None else default_audio_cfg + + def set_merged_train_data_info(self) -> None: + """Collect train datasets and create info... + + self.train_dataset_info = { + "n_datasets": 0, + "n_notes_per_dataset": [], + "n_files_per_dataset": [], + "dataset_names": [], # dataset names by order of merging file lists + "train_split_names": [], # train split names by order of merging file lists + "index_ranges": [], # index ranges of each dataset in the merged file list + "dataset_weights": [], # pre-defined list of dataset weights for sampling, if available + "merged_file_list": {}, + } + """ + self.train_data_info = create_merged_train_dataset_info(self.preset_multi) + print( + f"AMTDataModule: Added {len(self.train_data_info['merged_file_list'])} files from {self.train_data_info['n_datasets']} datasets to the training set." + ) + + def setup(self, stage: str): + """ + Prepare data args for the dataloaders to be used on each stage. + `stage` is automatically passed by pytorch lightning Trainer. 
+ """ + if stage == "fit": + # Set up train data info + self.set_merged_train_data_info() + + # Distributed Weighted random sampler for training + actual_train_num_samples_per_epoch = self.train_num_samples_per_epoch // shared_cfg["BSZ"][ + "train_local"] if self.train_num_samples_per_epoch else None + samplers = get_list_of_weighted_random_samplers(num_samplers=self.num_train_samplers, + dataset_weights=self.train_data_info["dataset_weights"], + dataset_index_ranges=self.train_data_info["index_ranges"], + num_samples_per_epoch=actual_train_num_samples_per_epoch) + # Train dataloader arguments + self.train_data_args = [] + for sampler in samplers: + self.train_data_args.append({ + "dataset_name": None, + "split": None, + "file_list": self.train_data_info["merged_file_list"], + "sub_batch_size": shared_cfg["BSZ"]["train_sub"], + "task_manager": self.task_manager, + "random_amp_range": self.train_random_amp_range, # "0.1,0.5 + "stem_iaug_prob": self.train_stem_iaug_prob, + "stem_xaug_policy": self.train_stem_xaug_policy, + "pitch_shift_range": self.train_pitch_shift_range, + "shuffle": True, + "sampler": sampler, + "audio_cfg": self.audio_cfg, + }) + + # Validation dataloader arguments + self.val_data_args = [] + for preset_single in self.preset_singles: + if preset_single["validation_split"] != None: + self.val_data_args.append({ + "dataset_name": preset_single["dataset_name"], + "split": preset_single["validation_split"], + "task_manager": self.task_manager, + # "tokenizer": self.task_manager.get_tokenizer(), + "max_num_files": self.val_max_num_files, + "audio_cfg": self.audio_cfg, + }) + + if stage == "test": + self.test_data_args = [] + for preset_single in self.preset_singles: + if preset_single["test_split"] != None: + self.test_data_args.append({ + "dataset_name": preset_single["dataset_name"], + "split": preset_single["test_split"], + "task_manager": self.task_manager, + "max_num_files": self.test_max_num_files, + "audio_cfg": self.audio_cfg, + }) + + def train_dataloader(self) -> Any: + loaders = {} + for i, args_dict in enumerate(self.train_data_args): + loaders[f"data_loader_{i}"] = get_cache_data_loader(**args_dict, dataloader_config=shared_cfg["DATAIO"]) + return CombinedLoader(loaders, mode="min_size") # size is always identical + + def val_dataloader(self) -> Any: + loaders = {} + for args_dict in self.val_data_args: + dataset_name = args_dict["dataset_name"] + loaders[dataset_name] = get_eval_dataloader(**args_dict, dataloader_config=shared_cfg["DATAIO"]) + return loaders + + def test_dataloader(self) -> Any: + loaders = {} + for args_dict in self.test_data_args: + dataset_name = args_dict["dataset_name"] + loaders[dataset_name] = get_eval_dataloader(**args_dict, dataloader_config=shared_cfg["DATAIO"]) + return loaders + + """CombinedLoader in "sequential" mode returns dataloader_idx to the + trainer, which is used to get the dataset name in the logger. 
""" + + @property + def num_val_dataloaders(self) -> int: + return len(self.val_data_args) + + @property + def num_test_dataloaders(self) -> int: + return len(self.test_data_args) + + def get_val_dataset_name(self, dataloader_idx: int) -> str: + return self.val_data_args[dataloader_idx]["dataset_name"] + + def get_test_dataset_name(self, dataloader_idx: int) -> str: + return self.test_data_args[dataloader_idx]["dataset_name"] diff --git a/amt/src/utils/datasets_eval.py b/amt/src/utils/datasets_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..1d2d29053d5621e31a62ffa053292cd06145ed65 --- /dev/null +++ b/amt/src/utils/datasets_eval.py @@ -0,0 +1,214 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +import json +import os +from typing import Dict, Any, Union, Tuple, Optional + +import torch +import numpy as np +from einops import rearrange +from torch.utils.data import DataLoader, Dataset +from utils.audio import load_audio_file, slice_padded_array +from utils.tokenizer import EventTokenizerBase, NoteEventTokenizer +from utils.note2event import slice_multiple_note_events_and_ties_to_bundle +from utils.note_event_dataclasses import Note, NoteEvent, NoteEventListsBundle +from utils.task_manager import TaskManager +from config.config import shared_cfg +from config.config import audio_cfg as default_audio_cfg + +UNANNOTATED_PROGRAM = 129 + + +class AudioFileDataset(Dataset): + """ + 🎧 AudioFileDataset for validation/test: + + This dataset class is designed to be used ONLY with `batch_size=None` and + returns sliced audio segments and unsliced notes and sliced note events for + a single song when `__getitem__` is called. + + Args: + file_list (Union[str, bytes, os.PathLike], optional): + Path to the file list. e.g. "../../data/yourmt3_indexes/slakh_validation_file_list.json" + task_manager (TaskManager, optional): TaskManager instance. Defaults to TaskManager(). + fs (int, optional): Sampling rate. Defaults to 16000. + seg_len_frame (int, optional): Segment length in frames. Defaults to 32767. + seg_hop_frame (int, optional): Segment hop in frames. Defaults to 32767. + sub_batch_size (int, optional): Sub-batch size that will be used in + generation of tokens. Defaults to 32. + max_num_files (int, optional): Maximum number of files to be loaded. Defaults to None. 
+ + + Variables: + file_list: + '{dataset_name}_{split}_file_list.json' has the following keys: + { + 'index': + { + 'mtrack_id': mtrack_id, + 'n_frames': n of audio frames + 'stem_file': Dict of stem audio file info + 'mix_audio_file': mtrack.mix_path, + 'notes_file': available only for 'validation' and 'test' + 'note_events_file': available only for 'train' and 'validation' + 'midi_file': mtrack.midi_path + } + } + + __getitem__(index) returns: + + audio_segment: + torch.FloatTensor: (nearest_N_divisable_by_sub_batch_size, 1, seg_len_frame) + + notes_dict: + { + 'mtrack_id': int, + 'program': List[int], + 'is_drum': bool, + 'duration_sec': float, + 'notes': List[Note], + } + + token_array: + torch.LongTensor: (n_segments, seg_len_frame) + + """ + + def __init__( + self, + file_list: Union[str, bytes, os.PathLike], + task_manager: TaskManager = TaskManager(), + # tokenizer: Optional[EventTokenizerBase] = None, + fs: int = 16000, + seg_len_frame: int = 32767, + seg_hop_frame: int = 32767, + max_num_files: Optional[int] = None) -> None: + + # load the file list + with open(file_list, 'r') as f: + fl = json.load(f) + file_list = {int(key): value for key, value in fl.items()} + if max_num_files: # reduce the number of files + self.file_list = dict(list(file_list.items())[:max_num_files]) + else: + self.file_list = file_list + + self.fs = fs + self.seg_len_frame = seg_len_frame + self.seg_len_sec = seg_len_frame / fs + self.seg_hop_frame = seg_hop_frame + self.task_manager = task_manager + + def __getitem__(self, index: int) -> Tuple[np.ndarray, Dict, NoteEventListsBundle]: + # get metadata + metadata = self.file_list[index] + audio_file = metadata['mix_audio_file'] + notes_file = metadata['notes_file'] + note_events_file = metadata['note_events_file'] + + # load the audio + audio = load_audio_file(audio_file, dtype=np.int16) # returns bytes + audio = audio / 2**15 + audio = audio.astype(np.float32) + audio = audio.reshape(1, -1) + audio_segments = slice_padded_array( + audio, + self.seg_len_frame, + self.seg_hop_frame, + pad=True, + ) # (n_segs, seg_len_frame) + audio_segments = rearrange(audio_segments, 'n t -> n 1 t').astype(np.float32) + num_segs = audio_segments.shape[0] + + # load all notes and from a file (of a single song) + notes_dict = np.load(notes_file, allow_pickle=True, fix_imports=False).tolist() + + # TODO: add midi_file path in preprocessing instead of here + notes_dict['midi_file'] = metadata['midi_file'] + + # tokenize note_events + note_events_dict = np.load(note_events_file, allow_pickle=True, fix_imports=False).tolist() + + if self.task_manager.tokenizer is not None: + # not using seg_len_sec to avoid accumulated rounding errors + start_times = [i * self.seg_hop_frame / self.fs for i in range(num_segs)] + note_event_segments = slice_multiple_note_events_and_ties_to_bundle( + note_events_dict['note_events'], + start_times, + self.seg_len_sec, + ) + + # Support for multi-channel decoding + if UNANNOTATED_PROGRAM in notes_dict['program']: + has_unannotated_segments = [True] * num_segs + else: + has_unannotated_segments = [False] * num_segs + + token_array = self.task_manager.tokenize_note_events_batch(note_event_segments, + start_time_to_zero=False, + sort=True) + # note_token_array = self.task_manager.tokenize_note_events_batch(note_event_segments, + # start_time_to_zero=False, + # sort=True) + # task_token_array = self.task_manager.tokenize_task_events_batch(note_event_segments, + # has_unannotated_segments) + + # Shape: + # processed_audio_array: (num_segs, 1, nframe) + # 
notes_dict: Dict + # note_token_array: (num_segs, decoding_ch, max_note_token_len) + # task_token_array: (num_segs, decoding_ch, max_task_token_len) + # return torch.from_numpy(audio_segments), notes_dict, torch.from_numpy( + # note_token_array).long(), torch.from_numpy(task_token_array).long() + return torch.from_numpy(audio_segments), notes_dict, torch.from_numpy(token_array).long() + + # # Tokenize/pad note_event_segments -> array of token and mask + # max_len = self.tokenizer.max_length + # token_array = np.zeros((num_segs, max_len), dtype=np.int32) + + # for i, tup in enumerate(list(zip(*note_event_segments.values()))): + # padded_tokens = self.tokenizer.encode_plus(*tup) + # token_array[i, :] = padded_tokens + # return torch.from_numpy(audio_segments), notes_dict, torch.from_numpy(token_array).long() + + def __len__(self) -> int: + return len(self.file_list) + + +def get_eval_dataloader( + dataset_name: str, + split: str = 'validation', + dataloader_config: Dict = {"num_workers": 0}, + task_manager: TaskManager = TaskManager(), + # tokenizer: Optional[EventTokenizerBase] = NoteEventTokenizer('mt3'), + max_num_files: Optional[int] = None, + audio_cfg: Optional[Dict] = None, +) -> DataLoader: + """ + 🎧 get_audio_file_dataloader: + + This function returns a dataloader for AudioFileDataset that returns padded slices + of audio samples with the divisable number of sub-batch size. + """ + data_home = shared_cfg["PATH"]["data_home"] + file_list = f"{data_home}/yourmt3_indexes/{dataset_name}_{split}_file_list.json" + + if audio_cfg is None: + audio_cfg = default_audio_cfg + + ds = AudioFileDataset( + file_list, + task_manager=task_manager, + # tokenizer=tokenizer, + seg_len_frame=int(audio_cfg["input_frames"]), # Default: 32767 + seg_hop_frame=int(audio_cfg["input_frames"]), # Default: 32767 + max_num_files=max_num_files) + dl = DataLoader(ds, batch_size=None, collate_fn=lambda k: k, **dataloader_config) + return dl diff --git a/amt/src/utils/datasets_helper.py b/amt/src/utils/datasets_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..1cfc866d76d34cc34154ec5a7c9048826109161c --- /dev/null +++ b/amt/src/utils/datasets_helper.py @@ -0,0 +1,287 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +import os +import json +import torch +import numpy as np +from torch.utils.data import DistributedSampler +from torch.utils.data import Dataset, Sampler +from torch.utils.data import RandomSampler, WeightedRandomSampler +from operator import itemgetter +from typing import List, Tuple, Union, Iterator, Optional +from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg +from config.config import shared_cfg + + +class DatasetFromSampler(Dataset): + """Dataset to create indexes from `Sampler`. From catalyst library. + + Args: + sampler: PyTorch sampler + """ + + def __init__(self, sampler: Sampler): + """Initialisation for DatasetFromSampler.""" + self.sampler = sampler + self.sampler_list = None + + def __getitem__(self, index: int): + """Gets element of the dataset. 
+ + Args: + index: index of the element in the dataset + + Returns: + Single element by index + """ + if self.sampler_list is None: + self.sampler_list = list(self.sampler) + return self.sampler_list[index] + + def __len__(self) -> int: + """ + Returns: + int: length of the dataset + """ + return len(self.sampler) + + +class DistributedSamplerWrapper(DistributedSampler): + """ + Wrapper over `Sampler` for distributed training. + Allows to use any sampler in distributed mode. + From https://github.com/catalyst-team/catalyst/blob/master/catalyst/data/sampler.py + + It is especially useful in conjunction with + `torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSamplerWrapper instance as a DataLoader + sampler, and load a subset of subsampled data of the original dataset + that is exclusive to it. + + .. note:: + Sampler is assumed to be of constant size. + """ + + def __init__( + self, + sampler, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + ): + """ + + Args: + sampler: Sampler used for subsampling + num_replicas (int, optional): Number of processes participating in + distributed training + rank (int, optional): Rank of the current process + within ``num_replicas`` + shuffle (bool, optional): If true (default), + sampler will shuffle the indices + """ + super(DistributedSamplerWrapper, self).__init__( + DatasetFromSampler(sampler), + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) + self.sampler = sampler + + def __iter__(self) -> Iterator[int]: + """Iterate over sampler. + + Returns: + python iterator + """ + self.dataset = DatasetFromSampler(self.sampler) + indexes_of_indexes = super().__iter__() + subsampler_indexes = self.dataset + return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) + + +def discount_to_target(samples: np.ndarray, target_sum: int) -> np.ndarray: + """Discounts samples to target sum. + + NOTE: this function is deprecated. + + This function adjusts an array of sample values so that their sum equals a target sum, while ensuring + that each element remains greater than or equal to 1 and attempting to maintain a distribution similar + to the original. + + Example 1: + samples = np.array([3, 1, 1, 1, 1, 1]) + target_sum = 7 + discounted_samples = discount_to_target(samples, target_sum) + # [2, 1, 1, 1, 1, 1] + + Example 2: + samples = np.array([3,1, 10, 1, 1, 1]) + target_sum = 7 + # [1, 1, 2, 1, 1, 1] + + Parameters: + samples (np.ndarray): Original array of sample values. + target_sum (int): The desired sum of the sample array. + + Returns: + np.ndarray: Adjusted array of sample values whose sum should equal the target sum, + and where each element is greater than or equal to 1. 
+ """ + samples = samples.copy().astype(int) + if samples.sum() <= target_sum: + samples[0] += 1 + return samples + + while samples.sum() > target_sum: + # indices of all elements larger than 1 + indices_to_discount = np.where(samples > 1)[0] + if indices_to_discount.size == 0: + # No elements left to discount, we cannot reach target_sum without going below 1 + print("Cannot reach target sum without going below 1 for some elements.") + return samples + discount_count = int(min(len(indices_to_discount), samples.sum() - target_sum)) + indices_to_discount = indices_to_discount[:discount_count] + samples[indices_to_discount] -= 1 + return samples + + +def create_merged_train_dataset_info(data_preset_multi: dict, data_home: Optional[os.PathLike] = None): + """Create merged dataset info from data preset multi. + Args: + data_preset_multi (dict): data preset multi + data_home (os.PathLike, optional): path to data home. If None, used the path defined + in config/config.py. + + Returns: + dict: merged dataset info + """ + train_dataset_info = { + "n_datasets": 0, + "n_notes_per_dataset": None, # TODO: not implemented yet... + "n_files_per_dataset": [], + "dataset_names": [], # dataset names by order of merging file lists + "data_split_names": [], # dataset names by order of merging file lists + "index_ranges": [], # index ranges of each dataset in the merged file list + "dataset_weights": None, # pre-defined list of dataset weights for sampling, if available + "merged_file_list": {}, + } + + if data_home is None: + data_home = shared_cfg["PATH"]["data_home"] + assert os.path.exists(data_home) + + for dp in data_preset_multi["presets"]: + train_dataset_info["n_datasets"] += 1 + + dataset_name = data_preset_single_cfg[dp]["dataset_name"] + train_dataset_info["dataset_names"].append(dataset_name) + train_dataset_info["data_split_names"].append(dp) + + # load file list for train split + if isinstance(data_preset_single_cfg[dp]["train_split"], str): + train_split_name = data_preset_single_cfg[dp]["train_split"] + file_list_path = os.path.join(data_home, 'yourmt3_indexes', + f'{dataset_name}_{train_split_name}_file_list.json') + # check if file list exists + if not os.path.exists(file_list_path): + raise ValueError(f"File list {file_list_path} does not exist.") + _file_list = json.load(open(file_list_path, 'r')) + elif isinstance(data_preset_single_cfg[dp]["train_split"], dict): + _file_list = data_preset_single_cfg[dp]["train_split"] + else: + raise ValueError("Invalid train split.") + + # merge file list + start_idx = len(train_dataset_info["merged_file_list"]) + for i, v in enumerate(_file_list.values()): + train_dataset_info["merged_file_list"][start_idx + i] = v + train_dataset_info["n_files_per_dataset"].append(len(_file_list)) + train_dataset_info["index_ranges"].append((start_idx, start_idx + len(_file_list))) + + # set dataset weights + if "weights" in data_preset_multi.keys() and data_preset_multi["weights"] is not None: + train_dataset_info["dataset_weights"] = data_preset_multi["weights"] + assert len(train_dataset_info["dataset_weights"]) == train_dataset_info["n_datasets"] + else: + train_dataset_info["dataset_weights"] = np.ones(train_dataset_info["n_datasets"]) + print("No dataset weights specified, using equal weights for all datasets.") + return train_dataset_info + + +def get_random_sampler(dataset, num_samples): + if torch.distributed.is_initialized(): + return DistributedSamplerWrapper(sampler=RandomSampler(dataset, num_samples=num_samples)) + else: + return RandomSampler(dataset, 
num_samples=num_samples) + + +def get_weighted_random_sampler(dataset_weights: List[float], + dataset_index_ranges: List[Tuple[int]], + num_samples_per_epoch: Optional[int] = None, + replacement: bool = True) -> torch.utils.data.sampler.Sampler: + """Get distributed weighted random sampler. + Args: + dataset_weights (List[float]): list of dataset weights of n length for n_datasets + dataset_index_ranges (List[Tuple[int]]): list of dataset index ranges + n_samples_per_epoch (Optional[int]): number of samples per epoch, typically length of + entire dataset. Defaults to None. If None, the total number of samples is calculated. + replacement (bool, optional): replacement. Defaults to True. + Returns: + (distributed) weighted random sampler + """ + assert len(dataset_weights) == len(dataset_index_ranges) + + sample_weights = [] + n_total_samples_in_datasets = dataset_index_ranges[-1][1] + if len(dataset_weights) > 1 and len(dataset_index_ranges) > 1: + for dataset_weight, index_range in zip(dataset_weights, dataset_index_ranges): + assert dataset_weight >= 0 + n_samples_in_dataset = index_range[1] - index_range[0] + sample_weight = dataset_weight * (1 - n_samples_in_dataset / n_total_samples_in_datasets) + # repeat the same weight for the number of samples in the dataset + sample_weights += [sample_weight] * (index_range[1] - index_range[0]) + elif len(dataset_weights) == 1 and len(dataset_index_ranges) == 1: + # Single dataset + sample_weights = [1] * n_total_samples_in_datasets + + if num_samples_per_epoch is None: + num_samples_per_epoch = n_total_samples_in_datasets + + sampler = WeightedRandomSampler(weights=sample_weights, num_samples=num_samples_per_epoch, replacement=replacement) + + if torch.distributed.is_initialized(): + return DistributedSamplerWrapper(sampler=sampler) + else: + return sampler + + +def get_list_of_weighted_random_samplers(num_samplers: int, + dataset_weights: List[float], + dataset_index_ranges: List[Tuple[int]], + num_samples_per_epoch: Optional[int] = None, + replacement: bool = True) -> List[torch.utils.data.sampler.Sampler]: + """Get list of distributed weighted random samplers. + Args: + dataset_weights (List[float]): list of dataset weights of n length for n_datasets + dataset_index_ranges (List[Tuple[int]]): list of dataset index ranges + n_samples_per_epoch (Optional[int]): number of samples per epoch, typically length of + entire dataset. Defaults to None. If None, the total number of samples is calculated. + replacement (bool, optional): replacement. Defaults to True. + + Returns: + List[(distributed) weighted random sampler] + """ + assert num_samplers > 0 + samplers = [] + for i in range(num_samplers): + samplers.append( + get_weighted_random_sampler(dataset_weights, dataset_index_ranges, num_samples_per_epoch, replacement)) + return samplers diff --git a/amt/src/utils/datasets_train.py b/amt/src/utils/datasets_train.py new file mode 100644 index 0000000000000000000000000000000000000000..9de23e5f54db01c8c8a9454c44ff446ed5c9a2fa --- /dev/null +++ b/amt/src/utils/datasets_train.py @@ -0,0 +1,660 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. 
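+""" datasets_train.py:
+Training datasets and dataloaders. CachedAudioDataset keeps a fixed-size cache of recently
+loaded files and draws sub-batches from it with stem-level augmentation and random submixing;
+see the class and function docstrings below for details.
+"""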
+import json +import os +import warnings +import random +from collections import OrderedDict +from itertools import cycle +from typing import Any, Dict, List, Literal, Optional, Tuple, Union + +import numpy as np +import torch +from einops import rearrange +import scipy.stats as stats +from torch.utils.data import DataLoader, Dataset, Sampler + +from config.config import shared_cfg +from config.config import audio_cfg as default_audio_cfg +from utils.audio import get_segments_from_numpy_array, load_audio_file +from utils.augment import (audio_random_submix_processor, combined_survival_and_stop, cross_stem_augment_processor, + intra_stem_augment_processor) +from utils.note2event import slice_multiple_note_events_and_ties_to_bundle, slice_note_events_and_ties +from utils.note2event import pitch_shift_note_events +from utils.note_event_dataclasses import NoteEventListsBundle +from utils.task_manager import TaskManager +from utils.utils import Timer + +UNANNOTATED_PROGRAM = 129 + + +class FixedSizeOrderedDict(OrderedDict): + """ + Dequeue-dict: If the dictionary reaches its maximum size, it will + automatically remove the oldest key-value pair. + """ + + def __init__(self, max_size: int): + super().__init__() + self.max_size: int = max_size + self._id_set: set = set() + self._id_counter: int = 0 + + def __setitem__(self, key: Any, value: Any) -> None: + if key not in self: + if len(self) >= self.max_size: + oldest_key, _ = self.popitem(last=False) + self._id_set.remove(oldest_key) + super().__setitem__(key, value) + self._id_set.add(key) + + def generate_unique_id(self) -> int: + while self._id_counter in self._id_set: + self._id_counter = (self._id_counter + 1) % (self.max_size * 100) + # max_size * 100 is arbitrary, but to ensure that there are enough + # unique ids available when the dictionary is full. + unique_id: int = self._id_counter + return unique_id + + +class CachedAudioDataset(Dataset): + """ + 🎧 CachedAudioDataset: + + This dataset subsamples from a temporal cache of audio data to improve efficiency + during training. + + - The dataset uses a fixed size cache and subsamples from the N most recent batches + stored in the cache. + - This design can help alleviate the disk I/O bottleneck that can occur during + random access of audio multi-track data for augmentation. + + Tips: + - The '__getitem__()' method returns a sub-batch of samples from the cache with a + size specified by the 'subbatch_size' parameter. + - Use 'collate_fn' in your dataloader to get the final batch size + (num_workers * subbatch_size). + - Larger 'subbatch_size' will result in more efficient parallelization. + + 👀 See '_update_cache()' for customized data processing. + + """ + + def __init__( + self, + file_list: Union[str, os.PathLike, Dict], + task_manager: TaskManager = TaskManager(), + num_samples_per_epoch: Optional[int] = None, + fs: int = 16000, + seg_len_frame: int = 32767, + sub_batch_size: int = 16, + num_files_cache: Optional[int] = None, + sample_index_for_init_cache: Optional[List[int]] = None, + random_amp_range: Optional[List[float]] = [0.6, 1.2], + pitch_shift_range: Optional[List[int]] = None, + stem_iaug_prob: Optional[float] = 0.7, + stem_xaug_policy: Optional[Dict] = { + "max_k": 3, # max number of external sources used for cross-stem augmentations + "tau": 0.3, # exponential decay rate for cross-stem augmentation + "alpha": 1.0, # shape parameter for Weibull distribution. Set to 1.0 for exponential distribution. 
+ "max_subunit_stems": 12, # the number of subunit stems to be reduced to this number of stems + "p_include_singing": + 0.8, # probability of including singing for cross augmented examples. if None, use base probaility. + "no_instr_overlap": True, + "no_drum_overlap": True, + "uhat_intra_stem_augment": True, + } + ) -> None: + """ + Args: + file_list: Path to the file list, or a dictionary of file list. e.g. "../../data/yourmt3_indexes/slakh_train_file_list.json", + task_manager: Task manager. + fs: Sampling frequency. + seg_len_frame: Length of the audio segment in frames. + sub_batch_size: Number of segments per sub-batch. + num_files_cache: Number of files to cache. + - If None, max(4, cross_stem_aug_max_k) * sub_batch_size files will be cached. + - When manually setting, it is recommended to use a number larger than the sub_batch_size. + - When using `cross_stem_aug`, it is recommended to set num_files_cache to a + multiple of sub_batch_size for diversity of cross-batch samples. + random_amp_range: Random amplitude range. Default: [0.6, 1.2]. + pitch_shift_range: Pitch shift range. Default: [-2, 2]. If None or [0, 0], pitch shift is disabled. + stem_iaug_prob: Probability of intra-stem augmentation. Bernoulli(p). Default: 0.7. + If None or 1, intra-stem augmentation is disabled. If 0, only one stem is randomly + selected. + stem_xaug_policy: Policy for cross-stem augmentation. If None, cross-stem augmentation + is disabled. Default: { + "max_k": 5, (Max number of external sources used for cross-stem augmentations. If 0, no cross-stem augmentation) + "no_instr_overlap": True, + "no_drum_overlap": True, + "uhat_intra_stem_augment": False, + } + """ + + # load the file list + if isinstance(file_list, dict): + self.file_list = file_list + elif isinstance(file_list, str) or isinstance(file_list, os.PathLike): + with open(file_list, 'r') as f: + fl = json.load(f) + self.file_list = {int(key): value for key, value in fl.items()} + else: + raise ValueError(f'📕 file_list must be a dictionary or a path to \ + a json file.') + + self.num_samples_per_epoch = num_samples_per_epoch + self.fs = fs + self.seg_len_frame = seg_len_frame + self.seg_len_sec = seg_len_frame / fs + + # Task manager + self.task_manager = task_manager # task_manager includes the tokenizer + self.num_decoding_channels = task_manager.num_decoding_channels # By default 1, but can be > 1 for multi-channel decoding + + # Augmentation + self.random_amp_range = random_amp_range + self.stem_iaug_prob = stem_iaug_prob + self.stem_xaug_policy = stem_xaug_policy + if stem_xaug_policy is not None: + # precompute the probability distribution of stopping at each k + self.precomputed_prob_stop_at_k = combined_survival_and_stop(max_k=stem_xaug_policy["max_k"], + tau=stem_xaug_policy["tau"], + alpha=stem_xaug_policy["alpha"])[1] + if pitch_shift_range is not None or pitch_shift_range != [0, 0]: + self.pitch_shift_range = pitch_shift_range + else: + self.pitch_shift_range = None + + # determine the number of samples per file & the number of files to cache + self.sub_batch_size = sub_batch_size + if num_files_cache is None: + if stem_xaug_policy is None: + self.num_files_cache = 4 * sub_batch_size + else: + self.num_files_cache = max(4, stem_xaug_policy["max_k"] + 1) * sub_batch_size + elif isinstance(num_files_cache, int): + if sub_batch_size > num_files_cache: + raise ValueError( + f'📙 num_files_cache {num_files_cache} must be equal or larger than sub_batch_size {sub_batch_size}.' 
+ ) # currently, we do not support sub_batch_size > num_files_cache + if stem_xaug_policy is not None and (sub_batch_size * 2 > num_files_cache): + warnings.warn( + f'📙 When cross_stem_aug_k is not None, sub_batch_size {sub_batch_size} * 2 > num_files_cache {num_files_cache} will decrease diversity in training examples.' + ) + self.num_files_cache = num_files_cache + else: + raise ValueError(f'📙 num_files_cache must be an integer or None. Got {num_files_cache}.') + + self.seg_read_size = 1 # np.ceil(sub_batch_size / num_files_cache).astype(int) + self.num_cached_seg_per_file = sub_batch_size + print(f'📘 caching {self.num_cached_seg_per_file} segments per file.') + + # initialize cache + self._init_cache(index_for_init_cache=sample_index_for_init_cache) + + def __getitem__(self, index: int) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor]: + # update cache with new stem and token segments + self._update_cache(index) + + # get sub-batch note_events and audio for segments from the cache + sampled_data, sampled_ids = self._get_rand_segments_from_cache( + num_segments=self.sub_batch_size) # sampled_data is deepcopy of sampled cached instances + + # Stem augmentation and audio submix: processing sampled_data in-place + self._augment_stems_and_audio_submix_in_place(sampled_data, sampled_ids) + # assert "processed_audio_array" in sampled_data.keys() + + # Post-mix augmentation: pitch shift (per-batch) + self._post_mix_augment(sampled_data) + # assert "pitch_shift_steps" in sampled_data.keys() + + # Prepare sub-batch + processed_audio_array = sampled_data['processed_audio_array'] + token_array = self.task_manager.tokenize_task_and_note_events_batch( + programs_segments=sampled_data['programs_segments'], + has_unannotated_segments=sampled_data['has_unannotated_segments'], + note_event_segments=sampled_data['note_event_segments'], + subunit_programs_segments=None, # using subunit is TODO + subunit_note_event_segments=None # using subunit is TODO + ) + + # note_token_array = self.task_manager.tokenize_note_events_batch(sampled_data['note_event_segments']) + # task_token_array = self.task_manager.tokenize_task_events_batch(sampled_data['programs_segments'], + # sampled_data['has_unannotated_segments']) + pitch_shift_steps = sampled_data['pitch_shift_steps'] + + # Shape: + # processed_audio_array: (sub_b, 1, nframe) + # note_token_array: (sub_b, decoding_ch, max_note_token_len) + # task_token_array: (sub_b, decoding_ch, max_task_token_len) + # pitch_shift_steps: (sub_b,) + return torch.FloatTensor(processed_audio_array), torch.LongTensor(token_array), torch.LongTensor( + pitch_shift_steps) + + # Shape: + # processed_audio_array: (sub_b, 1, nframe) + # note_token_array: (sub_b, decoding_ch, max_note_token_len) + # task_token_array: (sub_b, decoding_ch, max_task_token_len) + # pitch_shift_steps: (sub_b,) + # return torch.FloatTensor(processed_audio_array), torch.LongTensor(note_token_array), torch.LongTensor( + # task_token_array), torch.LongTensor(pitch_shift_steps) + + def _post_mix_augment(self, sampled_data: Dict[str, Any]) -> None: + """Post-mix augmentation""" + + if self.pitch_shift_range is None: + sampled_data['pitch_shift_steps'] = [0] * self.sub_batch_size + return + else: + """random pitch shift on note events only. 
audio will be transformer in the model's layer""" + # random pitch shift steps + sampled_data['pitch_shift_steps'] = np.random.randint( + self.pitch_shift_range[0], self.pitch_shift_range[1] + 1) * np.ones(self.sub_batch_size) + # n_choices = self.pitch_shift_range[1] - self.pitch_shift_range[ + # 0] + 1 + # zero_index = np.argmax( + # np.arange(self.pitch_shift_range[0], + # self.pitch_shift_range[1] + 1) == 0) + # p = np.ones(n_choices) + # p[zero_index] = n_choices * 2 + # sampled_data['pitch_shift_steps'] = np.full( + # self.sub_batch_size, + # np.random.choice(n_choices, 1, p=p / p.sum())[0] + + # self.pitch_shift_range[0], + # dtype=np.int32 + # ) # p = [0.07142857 0.07142857 0.71428571 0.07142857 0.07142857] + + # apply pitch shift to note events and tie note events (in-place) + note_event_segments = sampled_data['note_event_segments'] + for i, (note_events, tie_note_events, start_time) in enumerate(list(zip(*note_event_segments.values()))): + note_events = pitch_shift_note_events(note_events, + sampled_data['pitch_shift_steps'][i], + use_deepcopy=True) + tie_note_events = pitch_shift_note_events(tie_note_events, + sampled_data['pitch_shift_steps'][i], + use_deepcopy=True) + + def _augment_stems_and_audio_submix_in_place(self, sampled_data: Dict[str, Any], sampled_ids: np.ndarray) -> None: + """Augment stems and submix audio""" + + if self.stem_iaug_prob is None or self.stem_iaug_prob == 1.: + # no augmentation at all + audio_random_submix_processor(sampled_data=sampled_data, random_amp_range=self.random_amp_range) + return + elif self.stem_xaug_policy is None or self.stem_xaug_policy["max_k"] == 0: + # intra-stem augmentation only + intra_stem_augment_processor(sampled_data=sampled_data, + random_amp_range=self.random_amp_range, + prob=self.stem_iaug_prob, + submix_audio=True) + return + elif self.stem_xaug_policy is not None and self.stem_xaug_policy["max_k"] > 0: + intra_stem_augment_processor( + sampled_data=sampled_data, + random_amp_range=self.random_amp_range, + prob=self.stem_iaug_prob, + submix_audio=False) # submix_audio=False to postpone audio mixing until cross-stem augmentation + cross_stem_augment_processor( + sampled_data=sampled_data, # X_hat + sampled_ids=sampled_ids, # indices of X, to exclude X from U + get_rand_segments_from_cache_fn=self._get_rand_segments_from_cache, + random_amp_range=self.random_amp_range, + stem_iaug_prob=self.stem_iaug_prob, + stem_xaug_policy=self.stem_xaug_policy, + max_l=self.task_manager.max_note_token_length, + precomputed_prob_stop_at_k=self.precomputed_prob_stop_at_k, + mix_audio=True, + create_subunit_note_events=False) + # assert "subunit_programs_segments" in sampled_data.keys() + # assert "subunit_audio_array" in sampled_data.keys() + # assert "subunit_note_event_segments" in sampled_data.keys() + # assert "programs_segments" in sampled_data.keys() + # assert "note_event_segments" in sampled_data.keys() + # assert "has_unannotated_segments" in sampled_data.keys() + # assert "processed_audio_array" in sampled_data.keys() + else: + raise ValueError(f"Invalid stem_xaug_policy: {self.stem_xaug_policy}") + + def __len__(self): + return len(self.file_list) + + def _get_rand_segments_from_cache( + self, + num_segments: Union[int, Literal["max"]], + use_ordered_read_pos: bool = True, + sample_excluding_ids: Optional[np.ndarray] = None) -> Tuple[NoteEventListsBundle, np.ndarray]: + """Get sampled segments from the cache, accessed by file ids and read positions. 
+ Args: + use_ordered_read_pos: Whether to use the oredered read position generator. Default: True. + If False, the read position is randomly selected. This is used for cross-stem augmentation + source samples. + sample_excluding_ids: IDs to exclude files from sampling. + num_segments: Number of segments to sample. If None, sub_batch_size * cross_stem_aug_max_k. + Returns: + sampled_data: Dict + + Function execution time: 60 µs for sub_bsz=36 with single worker + + NOTE: This function returns mutable instances of the cached data. If you want to modify the + data, make sure to deepcopy the returned data, as in the augment.py/drop_random_stems_from_bundle() + """ + # construct output dict + sampled_data = { + 'audio_segments': [], # list of (1, n_stems, n_frame) with len = sub_batch_size + 'note_event_segments': { + 'note_events': [], # list of List[NoteEvent] + 'tie_note_events': [], # list of List[NoteEvent] + 'start_times': [], # [float, float, ...] + }, # NoteEventBundle dataclass + 'programs_segments': [], # list of List[int] + 'is_drum_segments': [], # list of List[bool] + 'has_stems_segments': [], # List[bool] + 'has_unannotated_segments': [], # List[bool] + } + + # random choice of files from cache + if num_segments == "max": + n = self.sub_batch_size * self.stem_xaug_policy["max_k"] + elif isinstance(num_segments, int): + n = num_segments + else: + raise ValueError(f"num_segments must be int or 'max', but got {num_segments}") + cache_values = np.array(list(self.cache.values())) + if sample_excluding_ids is None: + sampled_ids = np.random.choice( + self.num_files_cache, n, replace=False + ) # The ids are not exactly the keys() of cache, since we reindexed them in the range(0,N) by np.array(dict.values()) + else: + sampled_ids = np.random.permutation(list(set(np.arange(self.num_files_cache)) - + set(sample_excluding_ids)))[:n] + selected_files = cache_values[sampled_ids] + + if use_ordered_read_pos is True: + start = self._get_read_pos() + end = start + self.seg_read_size + for d in selected_files: + if use_ordered_read_pos is False: + start = np.random.randint(0, self.num_cached_seg_per_file - self.seg_read_size + 1) + end = start + self.seg_read_size + sampled_data['audio_segments'].append(d['audio_array'][start:end]) + sampled_data['note_event_segments']['note_events'].extend( + d['note_event_segments']['note_events'][start:end]) + sampled_data['note_event_segments']['tie_note_events'].extend( + d['note_event_segments']['tie_note_events'][start:end]) + sampled_data['note_event_segments']['start_times'].extend( + d['note_event_segments']['start_times'][start:end]) + sampled_data['programs_segments'].append(d['programs']) + sampled_data['is_drum_segments'].append(d['is_drum']) + sampled_data['has_stems_segments'].append(d['has_stems']) + sampled_data['has_unannotated_segments'].append(d['has_unannotated']) + return sampled_data, sampled_ids # Note that the data returned is mutable instance. 
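+
+    # Sketch of the cache access above (hypothetical sizes): with sub_batch_size=2 and
+    # num_files_cache=8, a call like
+    #   sampled_data, sampled_ids = self._get_rand_segments_from_cache(num_segments=2)
+    # picks 2 of the 8 cached files and reads `seg_read_size` segment(s) from each, either at the
+    # current ordered read position or, for cross-stem sources (use_ordered_read_pos=False), at a
+    # random offset within each cached file.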
+ + def _update_cache(self, index) -> None: + data = { + 'programs': None, + 'is_drum': None, + 'has_stems': None, + 'has_unannotated': None, + 'audio_array': None, # (n_segs, n_stems, n_frames): non-stem dataset has n_stems=1 + 'note_event_segments': None, # NoteEventBundle dataclass + } + + # Load Audio stems -> slice -> (audio_segments, start_times) + if 'stem_file' in self.file_list[index].keys() and \ + self.file_list[index]['stem_file'] != None: + audio_data = np.load(self.file_list[index]['stem_file'], + allow_pickle=True).tolist() # dict with 'audio_array' having shape (n_stems, n_frames) + data['has_stems'] = True + elif 'mix_audio_file' in self.file_list[index].keys(): + wav_data = load_audio_file(self.file_list[index]['mix_audio_file'], fs=self.fs, dtype=np.float32) + audio_data = { + 'audio_array': wav_data[np.newaxis, :], # (1, n_frames) + 'n_frames': len(wav_data), + 'program': np.array(self.file_list[index]['program'], dtype=np.int32), + 'is_drum': np.array(self.file_list[index]['is_drum'], dtype=np.int32), + } + data['has_stems'] = False + else: + raise ValueError(f'📕 No stem_file or mix_audio_file found in the file list.') + + if UNANNOTATED_PROGRAM in audio_data['program']: + data['has_unannotated'] = True + + # Pad audio data shorter than the segment length + if audio_data['audio_array'].shape[1] < self.seg_len_frame + 2000: + audio_data['audio_array'] = np.pad(audio_data['audio_array'], + ((0, 0), + (0, self.seg_len_frame + 2000 - audio_data['audio_array'].shape[1])), + mode='constant') + audio_data['n_frames'] = audio_data['audio_array'].shape[1] + + data['programs'] = audio_data['program'] + data['is_drum'] = audio_data['is_drum'] + + # Randomly select start frame indices and filtering out empty note_event segments + note_event_data = np.load(self.file_list[index]['note_events_file'], allow_pickle=True).tolist() + note_event_segments = NoteEventListsBundle({'note_events': [], 'tie_note_events': [], 'start_times': []}) + start_frame_indices = [] + attempt = 0 + while len(start_frame_indices) < self.num_cached_seg_per_file and attempt < 5: + sampled_indices = random.sample(range(audio_data['n_frames'] - self.seg_len_frame), + self.num_cached_seg_per_file) + for idx in sampled_indices: + _start_time = idx / self.fs + _end_time = _start_time + self.seg_len_sec + sliced_note_events, sliced_tie_note_events, _ = slice_note_events_and_ties( + note_event_data['note_events'], _start_time, _end_time, False) + if len(sliced_note_events) + len(sliced_tie_note_events) > 0 or attempt == 4: + # non-empty segment or last attempt + start_frame_indices.append(idx) + note_event_segments['note_events'].append(sliced_note_events) + note_event_segments['tie_note_events'].append(sliced_tie_note_events) + note_event_segments['start_times'].append(_start_time) + if len(start_frame_indices) == self.num_cached_seg_per_file: + break + attempt += 1 + assert len(start_frame_indices) == self.num_cached_seg_per_file + + # start_frame_indices = np.random.choice(audio_data['n_frames'] - self.seg_len_frame, + # size=self.num_cached_seg_per_file, + # replace=False) + # start_times = start_frame_indices / self.fs + + # # Load Note events -> slice -> note_event_segments, tie_note_event_segments + # note_event_data = np.load(self.file_list[index]['note_events_file'], allow_pickle=True).tolist() + + # # Extract note event segments for the audio segments, returning a dictionary + # # with keys: 'note_events', 'tie_note_events', and 'start_times'. 
+ # note_event_segments = slice_multiple_note_events_and_ties_to_bundle( + # note_event_data['note_events'], + # start_times, + # self.seg_len_sec, + # ) # note_event_segments: see NoteEventBundle dataclass... + + audio_segments = get_segments_from_numpy_array(audio_data['audio_array'], + self.seg_len_frame, + start_frame_indices=start_frame_indices, + dtype=np.float32) # audio_segments: (n_segs, n_stems, n_frames) + + # Add audio and note events of the sliced segments to data + data['audio_array'] = audio_segments # (n_segs, n_stems, n_frames) + data['note_event_segments'] = note_event_segments # NoteEventBundle dataclass + + # Update the cache + unique_id = self.cache.generate_unique_id() + self.cache[unique_id] = data # push + + def _init_cache(self, index_for_init_cache: Optional[List[int]] = None): + with Timer() as t: + self.cache = FixedSizeOrderedDict(max_size=self.num_files_cache) + print(f'💿 Initializing cache with max_size={self.cache.max_size}') + if index_for_init_cache is not None: + assert len(index_for_init_cache) >= self.num_files_cache + for i in index_for_init_cache[-self.num_files_cache:]: + self._update_cache(i) + else: + rand_ids = np.random.choice(np.arange(len(self)), size=self.num_files_cache, replace=False) + for i in rand_ids: + self._update_cache(i) + + # Initialize an infinite cache read position generator + self._cache_read_pos_generator = cycle(np.arange(0, self.num_cached_seg_per_file, self.seg_read_size)) + t.print_elapsed_time() + + def _get_read_pos(self): + return next(self._cache_read_pos_generator) + + +def collate_fn(batch: Tuple[torch.FloatTensor, torch.LongTensor], + local_batch_size: int) -> Tuple[torch.FloatTensor, torch.LongTensor]: + """ + This function is used to get the final batch size + batch: (np.ndarray of shape (B, b, 1, T), np.ndarray of shape (B, b, T)) + where b is the sub-batch size and B is the batch size. + """ + audio_segments = torch.vstack([b[0] for b in batch]) + note_tokens = torch.vstack([b[1] for b in batch]) + return (audio_segments, note_tokens) + + +# def collate_fn(batch: Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor], +# local_batch_size: int) -> Tuple[torch.FloatTensor, torch.LongTensor, torch.LongTensor]: +# """ +# This function is used to get the final batch size +# batch: (np.ndarray of shape (B, b, 1, T), np.ndarray of shape (B, b, T)) +# where b is the sub-batch size and B is the batch size. +# """ +# audio_segments = torch.vstack([b[0] for b in batch]) +# note_tokens = torch.vstack([b[1] for b in batch]) +# task_tokens = torch.vstack([b[2] for b in batch]) +# return (audio_segments, note_tokens, task_tokens) + + +def get_cache_data_loader( + dataset_name: Optional[str] = None, + split: Optional[str] = None, + file_list: Optional[Dict] = None, + sub_batch_size: int = 32, + task_manager: TaskManager = TaskManager(), + stem_iaug_prob: Optional[float] = 0.7, + stem_xaug_policy: Optional[Dict] = { + "max_k": 3, + "tau": 0.3, + "alpha": 1.0, + "max_subunit_stems": 12, + "p_include_singing": 0.8, + "no_instr_overlap": True, + "no_drum_overlap": True, + "uhat_intra_stem_augment": True, + }, + random_amp_range: Optional[List[float]] = [0.6, 1.2], + pitch_shift_range: Optional[List[int]] = None, + shuffle: Optional[bool] = True, + sampler: Optional[Sampler] = None, + audio_cfg: Optional[Dict] = None, + dataloader_config: Dict = {"num_workers": 0}) -> DataLoader: + """ + This function returns a DataLoader object that can be used to iterate over the dataset. + Args: + dataset_name: str, name of the dataset. 
+ split: str, name of the split. + - dataset_name and split are used to load the file list. + - if file_list is not None, and dataset_name and split should be None, it will be used to load the dataset. + file_list: dict, file list of the dataset. + sub_batch_size: int, number of segments per sub-batch. + task_manager: TaskManager, See `utils/task_manager.py`. + stem_iaug_prob: float, probability of intra-stem augmentation. Bernoulli(p). Default: 0.7. + If None or 1, intra-stem augmentation is disabled. If 0, only one stem is randomly selected. + stem_xaug_policy: dict, policy for cross-stem augmentation. If None, cross-stem augmentation + is disabled. + random_amp_range: list, random amplitude range. Default: [0.6, 1.2]. + pitch_shift_range: list, pitch shift range. Default: [-2, 2]. None or [0, 0] for no pitch shift. + shuffle (bool): whether to shuffle the dataset. Default: True. However, shuffle is ignored when sampler is specified. + sampler: Sampler, defines the strategy to draw samples from the dataset. If specified, shuffle must be False. + audio_cfg: dict, audio configuration. + dataloader_config: dict, other arguments for PyTorch native DataLoader class. + + Returns: + DataLoader object. + """ + if dataset_name is None and split is None and file_list is None: + raise ValueError("Error: all arguments cannot be None.") + elif (dataset_name is not None and split is not None and file_list is None) and isinstance( + split, str) and isinstance(dataset_name, str): + data_home = shared_cfg["PATH"]["data_home"] + file_list = f"{data_home}/yourmt3_indexes/{dataset_name}_{split}_file_list.json" + assert os.path.exists(file_list) + elif (dataset_name is None and split is None and file_list is not None) and isinstance(file_list, dict): + pass + else: + raise ValueError("Error: invalid combination of arguments.") + + # If sampler is specified, initialize cache using sampler, otherwise random initialization. 
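+    # Aside, an illustrative call (sketch): with dataset_name/split given, the file list is
+    # resolved from `{data_home}/yourmt3_indexes/...` as above. The 'train' split name and the
+    # (audio, tokens) pair unpacking are assumptions; actual yielded shapes depend on the
+    # task manager / tokenizer, as in collate_fn above.
+    #
+    #     dl = get_cache_data_loader(dataset_name='slakh', split='train', sub_batch_size=32,
+    #                                dataloader_config={"num_workers": 4})
+    #     audio_segments, note_tokens = next(iter(dl))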
+ if sampler is not None: + sample_index_for_init_cache = list(sampler) + else: + sample_index_for_init_cache = None + + if audio_cfg is None: + audio_cfg = default_audio_cfg + + ds = CachedAudioDataset( + file_list, + task_manager=task_manager, + seg_len_frame=int(audio_cfg['input_frames']), + sub_batch_size=sub_batch_size, + num_files_cache=None, # auto + random_amp_range=random_amp_range, + pitch_shift_range=pitch_shift_range, + stem_iaug_prob=stem_iaug_prob, + stem_xaug_policy=stem_xaug_policy, + sample_index_for_init_cache=sample_index_for_init_cache, + ) + batch_size = None + _collate_fn = None + + return DataLoader(ds, + batch_size=batch_size, + collate_fn=_collate_fn, + sampler=sampler, + shuffle=None if sampler is not None else shuffle, + **dataloader_config) + + +# def speed_benchmark_cache_audio_dataset(): +# # ds = CachedAudioDataset(sub_batch_size=32, num_files_cache=8) +# # +# # Audio-only w/ single worker: +# # %timeit ds.__getitem__(0) # 61.9ms b16 c16; 76.1ms b16 c4; 77.2ms b16 c1 +# # %timeit ds.__getitem__(0) # 133ms b64 c64; 118ms b64 c32; 114ms b64 c16 +# # %timeit ds.__getitem__(0) # 371ms b128 c128; 205ms b128 c64; 200ms b128 c32; 165ms b128 c16 +# # +# ds = CachedAudioDataset(sub_batch_size=128, num_files_cache=16, tokenizer=NoteEventTokenizer()) +# # Audio + Tokenization w/ single worker: +# # %timeit ds.__getitem__(0) # 91.2ms b16 c16; 90.8ms b16 c4; 98.6ms b16 c1 +# # %timeit ds.__getitem__(0) # 172ms b64 c64; 158ms b64 c32; 158ms b64 c16 +# # %timeit ds.__getitem__(0) # 422ms b128 c128; 278ms b128 c64; 280ms b128 c32; 269ms b128 c16 + +# # dl = DataLoader( +# # ds, batch_size=None, shuffle=True, collate_fn=collate_fn, num_workers=0) + +# dl = get_cache_data_loader( +# 'slakh', +# tokenizer=NoteEventTokenizer('mt3'), +# sub_batch_size=32, +# global_batch_size=32, +# num_workers=0) + +# with Timer() as t: +# for i, data in enumerate(dl): +# if i > 4: +# break +# print(i) +# t.print_elapsed_time() diff --git a/amt/src/utils/event2note.py b/amt/src/utils/event2note.py new file mode 100644 index 0000000000000000000000000000000000000000..9df38be5093ea2c41aa2232a2b777a15173a278d --- /dev/null +++ b/amt/src/utils/event2note.py @@ -0,0 +1,300 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""event2note.py: + +Event to NoteEvent: +• event2note_event + +NoteEvent to Note: +• note_event2note +• merge_zipped_note_events_and_ties_to_notes + +""" +import warnings +from collections import Counter +from typing import List, Tuple, Optional, Dict, Counter + +from utils.note_event_dataclasses import Note, NoteEvent +from utils.note_event_dataclasses import Event +from utils.note2event import validate_notes, trim_overlapping_notes + +MINIMUM_OFFSET_SEC = 0.01 + +DECODING_ERR_TYPES = [ + 'decoding_time', 'Err/Missing prg in tie', 'Err/Missing tie', 'Err/Shift out of range', 'Err/Missing prg', + 'Err/Missing vel', 'Err/Multi-tie type 1', 'Err/Multi-tie type 2', 'Err/Unknown event', 'Err/onset not found', + 'Err/active ne incomplete', 'Err/merging segment tie', 'Err/long note > 10s' +] + + +def event2note_event(events: List[Event], + start_time: float = 0.0, + sort: bool = True, + tps: int = 100) -> Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], Counter[str]]: + """Convert events to note events. 
+ + Args: + events: A list of events. + start_time: The start time of the segment. + sort: Whether to sort the note events. + tps: Ticks per second. + + Returns: + List[NoteEvent]: A list of note events. + List[NoteEvent]: A list of tie note events. + List[Tuple[int]]: A list of last activity of segment. [(program, pitch), ...]. This is useful + for validating notes within a batch of segments extracted from a file. + Counter[str]: A dictionary of error counters. + """ + assert (start_time >= 0.) + + # Collect tie events + tie_index = program_state = None + tie_note_events = [] + last_activity = [] # For activity check and last activity of segment. [(program, pitch), ...] + error_counter = {} # Add a dictionary to count the errors by their types + + for i, e in enumerate(events): + try: + if e.type == 'tie': + tie_index = i + break + if e.type == 'shift': + break + elif e.type == 'program': + program_state = e.value + elif e.type == 'pitch': + if program_state is None: + raise ValueError('Err/Missing prg in tie') + tie_note_events.append( + NoteEvent(is_drum=False, program=program_state, time=None, velocity=1, pitch=e.value)) + last_activity.append((program_state, e.value)) # (program, pitch) + except ValueError as ve: + error_type = str(ve) + error_counter[error_type] = error_counter.get(error_type, 0.) + 1 + + try: + if tie_index is None: + raise ValueError('Err/Missing tie') + else: + events = events[tie_index + 1:] + except ValueError as ve: + error_type = str(ve) + error_counter[error_type] = error_counter.get(error_type, 0.) + 1 + return [], [], [], error_counter + + # Collect main events: + note_events = [] + velocity_state = None + start_tick = round(start_time * tps) + tick_state = start_tick + # keep the program_state of last tie event... + + for e in events: + try: + if e.type == 'shift': + if e.value <= 0 or e.value > 1000: + raise ValueError('Err/Shift out of range') + # tick_state += e.value + tick_state = start_tick + e.value + elif e.type == 'drum': + note_events.append( + NoteEvent(is_drum=True, program=128, time=tick_state / tps, velocity=1, pitch=e.value)) + elif e.type == 'program': + program_state = e.value + elif e.type == 'velocity': + velocity_state = e.value + elif e.type == 'pitch': + if program_state is None: + raise ValueError('Err/Missing prg') + elif velocity_state is None: + raise ValueError('Err/Missing vel') + # Check activity + if velocity_state > 0: + last_activity.append((program_state, e.value)) # (program, pitch) + elif velocity_state == 0 and (program_state, e.value) in last_activity: + last_activity.remove((program_state, e.value)) + else: + # print(f'tick_state: {tick_state}') # <-- This displays unresolved offset errors!! + raise ValueError('Err/Note off without note on') + note_events.append( + NoteEvent(is_drum=False, + program=program_state, + time=tick_state / tps, + velocity=velocity_state, + pitch=e.value)) + elif e.type == 'EOS': + break + elif e.type == 'PAD': + continue + elif e.type == 'UNK': + continue + elif e.type == 'tie': + if tick_state == start_tick: + raise ValueError('Err/Multi-tie type 1') + else: + raise ValueError('Err/Multi-tie type 2') + else: + raise ValueError(f'Err/Unknown event') + except ValueError as ve: + error_type = str(ve) + error_counter[error_type] = error_counter.get(error_type, 0.) 
+ 1 + + if sort: + note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) + tie_note_events.sort(key=lambda n_ev: (n_ev.is_drum, n_ev.program, n_ev.pitch)) + + return note_events, tie_note_events, last_activity, error_counter + + +def note_event2note( + note_events: List[NoteEvent], + tie_note_events: Optional[List[NoteEvent]] = None, + sort: bool = True, + fix_offset: bool = True, + trim_overlap: bool = True, +) -> Tuple[List[Note], Counter[str]]: + """Convert note events to notes. + + Returns: + List[Note]: A list of merged note events. + Counter[str]: A dictionary of error counters. + """ + + notes = [] + active_note_events = {} + + error_counter = {} # Add a dictionary to count the errors by their types + + if tie_note_events is not None: + for ne in tie_note_events: + active_note_events[(ne.pitch, ne.program)] = ne + + if sort: + note_events.sort(key=lambda ne: (ne.time, ne.is_drum, ne.pitch, ne.velocity, ne.program)) + + for ne in note_events: + try: + if ne.time == None: + continue + elif ne.is_drum: + if ne.velocity == 1: + notes.append( + Note(is_drum=True, + program=128, + onset=ne.time, + offset=ne.time + MINIMUM_OFFSET_SEC, + pitch=ne.pitch, + velocity=1)) + else: + continue + elif ne.velocity == 1: + active_ne = active_note_events.get((ne.pitch, ne.program)) + if active_ne is not None: + active_note_events.pop((ne.pitch, ne.program)) + notes.append( + Note(False, active_ne.program, active_ne.time, ne.time, active_ne.pitch, active_ne.velocity)) + active_note_events[(ne.pitch, ne.program)] = ne + + elif ne.velocity == 0: + active_ne = active_note_events.pop((ne.pitch, ne.program), None) + if active_ne is not None: + notes.append( + Note(False, active_ne.program, active_ne.time, ne.time, active_ne.pitch, active_ne.velocity)) + else: + raise ValueError('Err/onset not found') + except ValueError as ve: + error_type = str(ve) + error_counter[error_type] = error_counter.get(error_type, 0.) + 1 + + for ne in active_note_events.values(): + try: + if ne.velocity == 1: + if ne.program == None or ne.pitch == None: + raise ValueError('Err/active ne incomplete') + elif ne.time == None: + continue + else: + notes.append( + Note(is_drum=False, + program=ne.program, + onset=ne.time, + offset=ne.time + MINIMUM_OFFSET_SEC, + pitch=ne.pitch, + velocity=1)) + except ValueError as ve: + error_type = str(ve) + error_counter[error_type] = error_counter.get(error_type, 0.) + 1 + + if fix_offset: + for n in list(notes): + try: + if n.offset - n.onset > 10: + n.offset = n.onset + MINIMUM_OFFSET_SEC + raise ValueError('Err/long note > 10s') + except ValueError as ve: + error_type = str(ve) + error_counter[error_type] = error_counter.get(error_type, 0.) + 1 + + if sort: + notes.sort(key=lambda note: (note.onset, note.is_drum, note.program, note.velocity, note.pitch)) + + if fix_offset: + notes = validate_notes(notes, fix=True) + + if trim_overlap: + notes = trim_overlapping_notes(notes, sort=True) + + return notes, error_counter + + +def merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_ties, + force_note_off_missing_tie=True, + fix_offset=True) -> Tuple[List[Note], Counter[str]]: + """Merge zipped note events and ties. + + Args: + zipped_note_events_and_ties: A list of tuples of (note events, tie note events, last_activity, start time). + force_note_off_missing_tie: Whether to force note off for missing tie note events. + fix_offset: Whether to fix the offset of notes. + + Returns: + List[Note]: A list of merged note events. 
+ Counter[str]: A dictionary of error counters. + """ + merged_note_events = [] + prev_last_activity = None + seg_merge_err_cnt = Counter() + for nes, tie_nes, last_activity, start_time in zipped_note_events_and_ties: + if prev_last_activity is not None and force_note_off_missing_tie: + # Check mismatch between prev_last_activity and current tie_note_events + prog_pitch_tie = set([(ne.program, ne.pitch) for ne in tie_nes]) + for prog_pitch_pla in prev_last_activity: # (program, pitch) of previous last active notes + if prog_pitch_pla not in prog_pitch_tie: + # last acitve notes of previous segment is missing in tie information. + # We create a note off event for these notes at the beginning of current note events. + merged_note_events.append( + NoteEvent(is_drum=False, + program=prog_pitch_pla[0], + time=start_time, + velocity=0, + pitch=prog_pitch_pla[1])) + seg_merge_err_cnt['Err/merging segment tie'] += 1 + else: + pass + merged_note_events += nes + prev_last_activity = last_activity + + # merged_note_events to notes + notes, err_cnt = note_event2note(merged_note_events, tie_note_events=None, fix_offset=fix_offset) + + # gather error counts + err_cnt.update(seg_merge_err_cnt) + return notes, err_cnt diff --git a/amt/src/utils/event_codec.py b/amt/src/utils/event_codec.py new file mode 100644 index 0000000000000000000000000000000000000000..a0915e39cbaf960b43851a6e400f93527b0a05e5 --- /dev/null +++ b/amt/src/utils/event_codec.py @@ -0,0 +1,297 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +""" event_codec.py: Encodes and decodes events to/from indices + +🚀 Improvements: + +• Encoding uses a precomputed dictionary in Python. This achieves a time + complexity of O(1). +• Decoding has time complexity of O(1), while the original code from MT3 + (Gardner et al.) has a time complexity of O(n). + +In practice, the performance of this optimized code was 4x faster for encoding + and decoding compared to the original code. + +""" +from typing import List, Dict, Tuple, Optional +from utils.note_event_dataclasses import Event, EventRange +# from bisect import bisect_right + + +class FastCodec: + """ Fast Encoding and decoding Event. """ + + def __init__(self, + special_tokens: List[str], + max_shift_steps: int, + event_ranges: List[EventRange], + program_vocabulary: Optional[Dict] = None, + drum_vocabulary: Optional[Dict] = None, + extra_tokens: List[str] = [], + name: Optional[str] = None): + """ Initializes the FastCodec object. + + :param special_tokens: List of special tokens to include in the vocabulary. + :param max_shift_steps: The maximum number of steps to shift. + :param event_ranges: List of EventRange objects. + :param instr_vocabulary: A dictionary of instrument groups. Please see config/vocabulary.py + We apply vocabulary only for encoding in training. + :param drum_vocabulary: A dictionary of drum mapping. Please see config/vocabulary.py + We apply vocabulary only for encoding in training. + + :param name: Name of the codec. + """ + # Store the special tokens and event ranges. 
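+        # The token-id space built below is laid out contiguously as
+        #   [special tokens][shift 0 .. max_shift_steps-1][event ranges][extra tokens],
+        # and the encode/decode dictionaries are precomputed over this layout for O(1) lookups.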
+ self.special_tokens = special_tokens + self._special_token_ranges = [] + self._extra_token_ranges = [] + + for token in special_tokens: + self._special_token_ranges.append(EventRange(token, 0, 0)) + for token in extra_tokens: + self._extra_token_ranges.append(EventRange(token, 0, 0)) + self._shift_range = EventRange(type='shift', min_value=0, max_value=max_shift_steps - 1) + self._event_ranges = self._special_token_ranges + [self._shift_range] + event_ranges + self._extra_token_ranges + # Ensure all event types have unique names. + assert len(self._event_ranges) == len(set([er.type for er in self._event_ranges])) + + # Store the name of the codec, so that we can identify it in tokenizer. + self._name = name + + # Create dictionary for decoding + self._decode_dict = {} + self._encode_dict = {} + self._event_type_range_dict = {} + idx = 0 + for er in self._event_ranges: + start_idx = idx + for value in range(er.min_value, er.max_value + 1): + self._decode_dict[idx] = Event(type=er.type, value=value) + self._encode_dict[(er.type, value)] = idx + idx += 1 + end_idx = idx - 1 + self._event_type_range_dict[er.type] = (start_idx, end_idx) + + self._num_classes = idx + + # Create inverse vocabulary for instrument groups + if program_vocabulary is not None: + self.inverse_vocab_program = {} + self._create_inverse_vocab_program(program_vocabulary) + else: + self.inverse_vocab_program = None + + # Create inverse vocabulary for drum mapping + if drum_vocabulary is not None: + self.inverse_vocab_drum = {} + self._create_inverse_vocab_drum(drum_vocabulary) + else: + self.inverse_vocab_drum = None + + @property + def num_classes(self) -> int: + return self._num_classes + + def _create_inverse_vocab_program(self, vocab): + for key, values in vocab.items(): + for value in values: + self.inverse_vocab_program[value] = values[0] + + def _create_inverse_vocab_drum(self, vocab): + for key, values in vocab.items(): + for value in values: + self.inverse_vocab_drum[value] = values[0] + + def encode_event(self, event: Event) -> int: + """Encode an event to an index.""" + if (event.type, event.value) not in self._encode_dict: + raise ValueError(f'Unknown event type: {event.type} or value: {event.value}') + + if event.type == 'program' and self.inverse_vocab_program is not None: + # If the event value is not in the vocabulary, use the original value + _event_value = self.inverse_vocab_program.get(event.value, event.value) + return self._encode_dict[(event.type, _event_value)] + elif event.type == 'drum' and self.inverse_vocab_drum is not None: + _event_value = self.inverse_vocab_drum.get(event.value, event.value) + return self._encode_dict[(event.type, _event_value)] + else: + return self._encode_dict[(event.type, event.value)] + + def event_type_range(self, event_type: str) -> Tuple[int, int]: + """Return [min_id, max_id] for an event type.""" + if event_type not in self._event_type_range_dict: + raise ValueError(f'Unknown event type: {event_type}') + + return self._event_type_range_dict[event_type] + + def decode_event_index(self, index: int) -> Event: + """Decode an event index to an Event.""" + if index < 0 or index >= self.num_classes: + raise ValueError(f'Unknown event index: {index}') + decoded_event = self._decode_dict[index] + + # Create a new event with the same type and value + return Event(type=decoded_event.type, value=decoded_event.value) + + +# class FastCodec: +# """ Fast Encoding and decoding Event. 
""" + +# def __init__(self, +# special_tokens: List[str], +# max_shift_steps: int, +# event_ranges: List[EventRange], +# name: Optional[str] = None): +# """ Initializes the FastCodec object. + +# :param special_tokens: List of special tokens to include in the vocabulary. +# :param max_shift_steps: The maximum number of steps to shift. +# :param event_ranges: List of EventRange objects. +# """ +# # Store the special tokens and event ranges. +# self.special_tokens = special_tokens +# self._special_token_ranges = [] + +# for token in special_tokens: +# self._special_token_ranges.append(EventRange(token, 0, 0)) +# self._shift_range = EventRange( +# type='shift', min_value=0, max_value=max_shift_steps - 1) +# self._event_ranges = self._special_token_ranges + [self._shift_range +# ] + event_ranges +# # Ensure all event types have unique names. +# assert len(self._event_ranges) == len( +# set([er.type for er in self._event_ranges])) + +# # Precompute cumulative offsets. +# self._cumulative_offsets = [0] +# for er in self._event_ranges: +# self._cumulative_offsets.append(self._cumulative_offsets[-1] + +# er.max_value - er.min_value + 1) + +# # Create event type to range and offset mapping. +# self._event_type_to_range_offset = {} +# for er, offset in zip(self._event_ranges, self._cumulative_offsets): +# self._event_type_to_range_offset[er.type] = (er, offset) + +# # Store the name of the codec, so that we can identify it in tokenizer. +# self._name = name + +# @property +# def num_classes(self) -> int: +# return self._cumulative_offsets[-1] + +# def encode_event(self, event: Event) -> int: +# """Encode an event to an index.""" +# if event.type not in self._event_type_to_range_offset: +# raise ValueError(f'Unknown event type: {event.type}') + +# er, offset = self._event_type_to_range_offset[event.type] + +# if not er.min_value <= event.value <= er.max_value: +# raise ValueError( +# f'Event value {event.value} is not within valid range ' +# f'[{er.min_value}, {er.max_value}] for type {event.type}') +# return offset + event.value - er.min_value + +# def event_type_range(self, event_type: str) -> Tuple[int, int]: +# """Return [min_id, max_id] for an event type.""" +# offset = 0 +# for er in self._event_ranges: +# if event_type == er.type: +# return offset, offset + (er.max_value - er.min_value) +# offset += er.max_value - er.min_value + 1 + +# raise ValueError(f'Unknown event type: {event_type}') + +# def decode_event_index(self, index: int) -> Event: +# """Decode an event index to an Event.""" +# if index < 0 or index >= self.num_classes: +# raise ValueError(f'Unknown event index: {index}') + +# # Find the event range using binary search. +# range_idx = bisect_right(self._cumulative_offsets, index) - 1 +# er = self._event_ranges[range_idx] +# offset = self._cumulative_offsets[range_idx] + +# return Event(type=er.type, value=er.min_value + index - offset) + +# Original code +# +# https://github.com/magenta/mt3/blob/main/mt3/event_codec.py +# Copyright 2022 The MT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# +# class Codec: +# """Encode and decode events.""" +# +# def __init__(self, special_tokens: List[str], max_shift_steps: int, +# event_ranges: List[EventRange]): +# """Define Codec. +# """ +# self._special_token_ranges = [] +# for token in special_tokens: +# self._special_token_ranges.append(EventRange(token, 0, 0)) +# self._shift_range = EventRange( +# type='shift', min_value=0, max_value=max_shift_steps - 1) +# self._event_ranges = self._special_token_ranges + [self._shift_range +# ] + event_ranges +# # Ensure all event types have unique names. +# assert len(self._event_ranges) == len( +# set([er.type for er in self._event_ranges])) + +# @property +# def num_classes(self) -> int: +# return sum(er.max_value - er.min_value + 1 for er in self._event_ranges) + +# def encode_event(self, event: Event) -> int: +# """Encode an event to an index.""" +# offset = 0 +# for er in self._event_ranges: +# if event.type == er.type: +# if not er.min_value <= event.value <= er.max_value: +# raise ValueError( +# f'Event value {event.value} is not within valid range ' +# f'[{er.min_value}, {er.max_value}] for type {event.type}' +# ) +# return offset + event.value - er.min_value +# offset += er.max_value - er.min_value + 1 + +# raise ValueError(f'Unknown event type: {event.type}') + +# def event_type_range(self, event_type: str) -> Tuple[int, int]: +# """Return [min_id, max_id] for an event type.""" +# offset = 0 +# for er in self._event_ranges: +# if event_type == er.type: +# return offset, offset + (er.max_value - er.min_value) +# offset += er.max_value - er.min_value + 1 + +# raise ValueError(f'Unknown event type: {event_type}') + +# def decode_event_index(self, index: int) -> Event: +# """Decode an event index to an Event.""" +# offset = 0 +# for er in self._event_ranges: +# if offset <= index <= offset + er.max_value - er.min_value: +# return Event(type=er.type, value=er.min_value + index - offset) +# offset += er.max_value - er.min_value + 1 + +# raise ValueError(f'Unknown event index: {index}') diff --git a/amt/src/utils/metrics.py b/amt/src/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..85d940ea8b945d38b89263809cf99344fa51f3b6 --- /dev/null +++ b/amt/src/utils/metrics.py @@ -0,0 +1,477 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""metrics.py""" + +from typing import List, Any, Dict, Optional, Tuple, Union +import numpy as np +import copy +from torch.nn import Module + +from utils.note_event_dataclasses import NoteEvent, Note +from utils.note2event import sort_notes, notes2pc_notes +from utils.event2note import note_event2note +from sklearn.metrics import average_precision_score +from utils.metrics_helper import (f1_measure, round_float, mir_eval_note_f1, mir_eval_frame_f1, mir_eval_melody_metric, + extract_pitches_intervals_from_notes, extract_frame_time_freq_from_notes) +from torchmetrics import MeanMetric, SumMetric + + +class UpdatedMeanMetric(MeanMetric): + """ + A wrapper of torchmetrics.MeanMetric to support reset and update separately. 
+ """ + + def __init__(self, nan_strategy: str = 'ignore', **kwargs) -> None: + super().__init__(nan_strategy=nan_strategy, **kwargs) + self._updated = False + + def update(self, *args, **kwargs): + super().update(*args, **kwargs) + self._updated = True + + def is_updated(self): + return self._updated + + +class UpdatedSumMetric(SumMetric): + """ + A wrapper of torchmetrics.SumMetric to support reset and update separately. + """ + + def __init__(self, nan_strategy: str = 'ignore', **kwargs) -> None: + super().__init__(nan_strategy=nan_strategy, **kwargs) + self._updated = False + + def update(self, *args, **kwargs): + super().update(*args, **kwargs) + self._updated = True + + def is_updated(self): + return self._updated + + +class AMTMetrics(Module): + """ + Automatic music transcription (AMT) evaluation metrics for music transcription + tasks with DDP support, following the convention of AMT. The average of file-wise + metrics is calculated. + + Metrics: + -------- + + Instrument-agnostic note onset and note offset metrics: + (Drum notes are generally excluded) + + - onset_f: the most conventional, often called Note F1 + - offset_f: a pair of onset + offset matching metric + + Multi-instrument note on-offset Macro-micro F1 metric, multi-F1 (of MT3): + + - multi_f: counts for onset + offset + program (instrument class) matching. + For drum notes, we only count onset. macro-micro means that we + calculate weighted precision and recall by counting each note + instrument class per file, and calcualte micro F1. We then + calculate average F1 for all files with equal weights (Macro). + + Instrument-group note onset and offset metrics are defined by extra_classes: + + e.g. extra_classes = ['piano', 'guitar'] + - onset_f_piano: piano instrument + - onset_f_guitar: guitar instrument + - offset_f_piano: piano instrument + - offset_f_guitar: guitar instrument + also p, r metrics follow... + + + Usage: + ------ + + Each metric instance can be individually updated and reset for computation. + + ``` + my_metric = AMTMetrics() + my_metric.onset_f.update(0.5) + my_metric.onset_f(0.5) # same + my_metric.onset_f(0, weight=1.0) # same and weighted by 1.0 (default) + my_metric.onset_f.compute() # return 0.333.. + my_metric.onset_f.reset() # reset the metric + + ``` + • {attribute}.update(value: float, weight: Optional[float]): Here weight is an + optional argument for weighted average. + • {attribute}.(...): Same as update method. + • {attribute}.compute(): Return the average value of the metric. + • {attribute}.reset(): Reset the metric. + + Class methods: + --------------- + + ``` + d = {'onset_f': 0.5, 'offset_f': 0.5} + my_metric.bulk_update(d) + d = {'onset_f': {'value': 0.5, 'weight': 1.0}, 'offset_f': {'value': 0.5, 'weight': 1.0}} + my_metric.onset_f.update(d) + ``` + + • bulk_update(metrics: Dict[str, Union[float, Dict[str, float]]]): Update metrics with a + dictionary as an argument. + • bulk_compute(): Return a dictionary of any non-empty metrics with average values. + • bulk_reset(): Reset all metrics. + + """ + + def __init__(self, + prefix: str = '', + nan_strategy: str = 'ignore', + extra_classes: Optional[List[str]] = None, + extra_metrics: Optional[List[str]] = None, + error_types: Optional[List[str]] = None, + **kwargs) -> None: + """ + Args: + suffix: prefix for the metric name, e.g. 'val' or 'test'. '_' will be added automatically. 
+ nan_strategy: 'warn' or 'raise' or 'ignore' + + """ + super().__init__(**kwargs) + self._prefix = prefix + self.nan_strategy = nan_strategy + + # Instrument-agnostic Note onsets and Note on-offset metrics for non-drum notes + self.onset_f = UpdatedMeanMetric(nan_strategy=nan_strategy) + self.offset_f = UpdatedMeanMetric(nan_strategy=nan_strategy) + + # Instrument-agnostic Frame F1 (skip in validation) + self.frame_f = UpdatedMeanMetric(nan_strategy=nan_strategy) + self.frame_f_pc = UpdatedMeanMetric(nan_strategy=nan_strategy) + + # Drum Onset metrics + self.onset_f_drum = UpdatedMeanMetric(nan_strategy=nan_strategy) + + # Multi F1 (Macro-micro F1 of MT3) + self.multi_f = UpdatedMeanMetric(nan_strategy=nan_strategy) + + # Initialize extra metrics for instrument macro F1 + self.extra_classes = extra_classes + if extra_classes is not None: + for class_name in extra_classes: + if not hasattr(self, class_name): + for onoff in ['onset', 'offset']: + for fpr in ['f']: + setattr(self, onoff + '_' + fpr + '_' + class_name, + UpdatedMeanMetric(nan_strategy=nan_strategy)) + # setattr(self, class_name, UpdatedMeanMetric(nan_strategy=nan_strategy)) + else: + raise ValueError(f"Metric '{class_name}' already exists.") + + # Initialize extra metrics for instruments(F is computed later) + self.extra_classes = extra_classes + if extra_classes is not None: + for class_name in extra_classes: + if not hasattr(self, class_name): + for onoff in ['micro_onset', 'micro_offset']: + for fpr in ['p', 'r']: + setattr(self, onoff + '_' + fpr + '_' + class_name, + UpdatedMeanMetric(nan_strategy=nan_strategy)) + # setattr( + # self, onoff + '_f_' + class_name, None + # ) # micro_onset_f and micro_offset_f for each instrument + else: + raise ValueError(f"Metric '{class_name}' already exists.") + + # Initialize drum micro P,R (F is computed later) + self.micro_onset_p_drum = UpdatedMeanMetric(nan_strategy=nan_strategy) + self.micro_onset_r_drum = UpdatedMeanMetric(nan_strategy=nan_strategy) + + # Initialize extra metrics directly + if extra_metrics is not None: + for metric_name in extra_metrics: + setattr(self, metric_name, UpdatedMeanMetric(nan_strategy=nan_strategy)) + + # Initialize error counters + self.error_types = error_types + if error_types is not None: + for error_type in error_types: + setattr(self, error_type, UpdatedMeanMetric(nan_strategy=nan_strategy)) + + def bulk_update(self, metrics: Dict[str, Union[float, Dict[str, float], Tuple[float, ...]]]) -> None: + """ Update metrics with a dictionary as an argument. + + metrics: + {'onset_f': 0.5, 'offset_f': 0.5} + or {'onset_f': {'value': 0.5, 'weight': 1.0}, 'offset_f': {'value': 0.5, 'weight': 1.0}} + or {'onset_p': (0.3, 5)} + + """ + for k, v in metrics.items(): + if isinstance(v, dict): + getattr(self, k).update(**v) + elif isinstance(v, tuple): + getattr(self, k).update(*v) + else: + getattr(self, k).update(v) + + def bulk_update_errors(self, errors: Dict[str, Union[int, float]]) -> None: + """ Update error counts with a dictionary as an argument. 
+ + errors: + {'error_type_or_message_1': (int | float) count, + 'error_type_or_message_2': (int | float) count,} + + """ + for error_type, count in errors.items(): + # Update the error count + if isinstance(count, int) or isinstance(count, float): + getattr(self, error_type).update(count) + else: + raise ValueError(f"Count of error type '{error_type}' must be an integer or a float.") + + def bulk_compute(self) -> Dict[str, float]: + computed_metrics = {} + for k, v in self._modules.items(): + if isinstance(v, UpdatedMeanMetric) and v.is_updated(): + computed_metrics[self._prefix + k] = v.compute() + # Create micro onset F1 for each instrument. Only when micro metrics are updated. + extra_classes = self.extra_classes if self.extra_classes is not None else [] + for class_name in extra_classes + ['drum']: + # micro onset F1 for each instrument. + _micro_onset_p_instr = computed_metrics.get(self._prefix + 'micro_onset_p_' + class_name, None) + _micro_onset_r_instr = computed_metrics.get(self._prefix + 'micro_onset_r_' + class_name, None) + if _micro_onset_p_instr is not None and _micro_onset_r_instr is not None: + computed_metrics[self._prefix + 'micro_onset_f_' + class_name] = f1_measure( + _micro_onset_p_instr.item(), _micro_onset_r_instr.item()) + # micro offset F1 for each instrument. 'drum' is usually not included. + _micro_offset_p_instr = computed_metrics.get(self._prefix + 'micro_offset_p_' + class_name, None) + _micro_offset_r_instr = computed_metrics.get(self._prefix + 'micro_offset_r_' + class_name, None) + if _micro_offset_p_instr is not None and _micro_offset_r_instr is not None: + computed_metrics[self._prefix + 'micro_offset_f_' + class_name] = f1_measure( + _micro_offset_p_instr.item(), _micro_offset_r_instr.item()) + + # Remove micro onset and offset P,R (Now we have F1) + for class_name in extra_classes + ['drum']: + for onoff in ['micro_onset', 'micro_offset']: + for pr in ['p', 'r']: + computed_metrics.pop(self._prefix + onoff + '_' + pr + '_' + class_name, None) + + return computed_metrics + + def bulk_reset(self) -> None: + for k, v in self._modules.items(): + if isinstance(v, UpdatedMeanMetric): + v.reset() + v._updated = False + + +def compute_track_metrics(pred_notes: List[Note], + ref_notes: List[Note], + eval_vocab: Optional[Dict] = None, + eval_drum_vocab: Optional[Dict] = None, + onset_tolerance: float = 0.05, + add_pitch_class_metric: Optional[List[str]] = None, + add_melody_metric: Optional[List[str]] = None, + add_frame_metric: bool = False, + add_micro_metric: bool = False, + add_multi_f_metric: bool = False, + extra_info: Optional[Any] = None): + """ Track metrics + + Args: + pred_notes: (List[Note]) predicted sequence of notes for a track + ref_notes: (List[Note]) reference sequence of notes for a track + return_instr_metric: (bool) return instrument-specific metrics + eval_vocab: (Dict or None) program group for instrument-specific metrics + { + instrument_or_group_name: + [program_number_0, program_number_1 ...] + } + If None, use default GM instruments. + + ex) eval_vocab = {"piano": np.arange(0, 8), ...} + drum_vocab: (Dict or None) note (pitch) group for drum-specific metrics + { + instrument_or_group_name: + [note_number_0, note_number_1 ...] + } + add_pitch_class_metric: (List[str] or None) add pitch class metrics for the + given instruments. The instrument names are defined in config/vocabulrary.py. + ex) ['Bass', 'Guitar'] + add_singing_oa_metric: (bool) add melody overall accuracy for tje given instruments. 
+ The instrument names are defined in config/vocabulrary.py. + ex) ['Singing Voice'] + (https://craffel.github.io/mir_eval/#mir_eval.melody.overall_accuracy + add_frame_metric: (bool) add frame-wise metrics + extra_info: (Any) extra information for debugging. Currently not implemented + + Returns: + metrics: (Dict) track metrics in the AMTMetric format with attribute names such as 'onset_f_{instrument_or_group_name}' + + + @dataclass + class Note: + is_drum: bool + program: int + onset: float + offset: float + pitch: int + velocity: int + + Caution: Note is mutable instance, even if we use copy(). + + """ + + # Extract drum and non-drum notes + def extract_drum_and_non_drum_notes(notes: List[Note]): + drum_notes, non_drum_notes = [], [] + for note in notes: + if note.is_drum: + drum_notes.append(note) + else: + non_drum_notes.append(note) + return drum_notes, non_drum_notes + + pns_drum, pns_non_drum = extract_drum_and_non_drum_notes(pred_notes) + rns_drum, rns_non_drum = extract_drum_and_non_drum_notes(ref_notes) + + # Reduce drum notes to drum vocab + def reduce_drum_notes_to_drum_vocab(notes: List[Note], drum_vocab: Dict): + reduced_notes = [] + for note in notes: + for drum_name, pitches in drum_vocab.items(): + if note.pitch in pitches: + new_note = copy.deepcopy(note) + new_note.pitch = pitches[0] + reduced_notes.append(new_note) + return sort_notes(reduced_notes) + + if eval_drum_vocab != None: + pns_drum = reduce_drum_notes_to_drum_vocab(pns_drum, eval_drum_vocab) + rns_drum = reduce_drum_notes_to_drum_vocab(rns_drum, eval_drum_vocab) + + # Extract Pitches (freq) and Intervals + pns_drum_pi = extract_pitches_intervals_from_notes(pns_drum, is_drum=True) + pns_non_drum_pi = extract_pitches_intervals_from_notes(pns_non_drum) + rns_drum_pi = extract_pitches_intervals_from_notes(rns_drum, is_drum=True) + rns_non_drum_pi = extract_pitches_intervals_from_notes(rns_non_drum) + + # Compute file-wise PRF for drums + drum_metric = mir_eval_note_f1(pns_drum_pi['pitches'], + pns_drum_pi['intervals'], + rns_drum_pi['pitches'], + rns_drum_pi['intervals'], + onset_tolerance=onset_tolerance, + is_drum=True, + add_micro_metric=add_micro_metric) + + # Compute file-wise PRF for non-drums + non_drum_metric = mir_eval_note_f1(pns_non_drum_pi['pitches'], + pns_non_drum_pi['intervals'], + rns_non_drum_pi['pitches'], + rns_non_drum_pi['intervals'], + onset_tolerance=onset_tolerance, + is_drum=False) + + # Compute file-wise frame PRF for non-drums + if add_frame_metric is True: + # Extract frame-level Pitches (freq) and Intervals + pns_non_drum_tf = extract_frame_time_freq_from_notes(pns_non_drum) + rns_non_drum_tf = extract_frame_time_freq_from_notes(rns_non_drum) + + res = mir_eval_frame_f1(pns_non_drum_tf, rns_non_drum_tf) + non_drum_metric = {**non_drum_metric, **res} # merge dicts + + ############## Compute instrument-wise PRF for non-drums ############## + + if eval_vocab is None: + return drum_metric, non_drum_metric, {} + else: + instr_metric = {} + for group_name, programs in eval_vocab.items(): + # Extract notes for each instrument + # bug fix for piano/drum overlap on slakh + pns_group = [note for note in pns_non_drum if note.program in programs] + rns_group = [note for note in rns_non_drum if note.program in programs] + + # Compute PC instrument-wise PRF using pitch class (currently for bass) + if add_pitch_class_metric is not None: + if group_name.lower() in [g.lower() for g in add_pitch_class_metric]: + # pc: pitch information is converted to pitch classe e.g. 
0-11 + pns_pc_group = extract_pitches_intervals_from_notes(notes2pc_notes(pns_group)) + rns_pc_group = extract_pitches_intervals_from_notes(notes2pc_notes(rns_group)) + + _instr_pc_metric = mir_eval_note_f1(pns_pc_group['pitches'], + pns_pc_group['intervals'], + rns_pc_group['pitches'], + rns_pc_group['intervals'], + onset_tolerance=onset_tolerance, + is_drum=False, + add_micro_metric=add_micro_metric, + suffix=group_name + '_pc') + # Add to instrument-wise PRF + for k, v in _instr_pc_metric.items(): + instr_metric[k] = v + + # Extract Pitches (freq) and Intervals + pns_group = extract_pitches_intervals_from_notes(pns_group) + rns_group = extract_pitches_intervals_from_notes(rns_group) + + # Compute instrument-wise PRF + _instr_metric = mir_eval_note_f1(pns_group['pitches'], + pns_group['intervals'], + rns_group['pitches'], + rns_group['intervals'], + onset_tolerance=onset_tolerance, + is_drum=False, + add_micro_metric=add_micro_metric, + suffix=group_name) + + # Merge instrument-wise PRF + for k, v in _instr_metric.items(): + instr_metric[k] = v + + # Optionally compute melody metrics: RPA, RCA, OA + if add_melody_metric is not None: + if group_name.lower() in [g.lower() for g in add_melody_metric]: + _melody_metric = mir_eval_melody_metric(pns_group['pitches'], + pns_group['intervals'], + rns_group['pitches'], + rns_group['intervals'], + cent_tolerance=50, + suffix=group_name) + for k, v in _melody_metric.items(): + instr_metric[k] = v + + # Calculate multi_f metric for this track + if add_multi_f_metric is True: + drum_micro_onset_tp_sum, drum_micro_onset_tpfp_sum, drum_micro_onset_tpfn_sum = 0., 0., 0. + non_drum_micro_offset_tp_sum, non_drum_micro_offset_tpfp_sum, non_drum_micro_offset_tpfn_sum = 0., 0., 0. + # Collect offset metric for non-drum notes + for k, v in instr_metric.items(): + if 'micro_offset_p_' in k and not np.isnan(v['value']): + non_drum_micro_offset_tp_sum += v['value'] * v['weight'] + non_drum_micro_offset_tpfp_sum += v['weight'] + if 'micro_offset_r_' in k and not np.isnan(v['value']): + non_drum_micro_offset_tpfn_sum += v['weight'] + # Collect onset metric for drum notes + for k, v in drum_metric.items(): + if 'micro_onset_p_drum' in k and not np.isnan(v['value']): + drum_micro_onset_tp_sum += v['value'] * v['weight'] + drum_micro_onset_tpfp_sum += v['weight'] + if 'micro_onset_r_drum' in k and not np.isnan(v['value']): + drum_micro_onset_tpfn_sum += v['weight'] + + tp = non_drum_micro_offset_tp_sum + drum_micro_onset_tp_sum + tpfp = non_drum_micro_offset_tpfp_sum + drum_micro_onset_tpfp_sum + tpfn = non_drum_micro_offset_tpfn_sum + drum_micro_onset_tpfn_sum + multi_p_track = tp / tpfp if tpfp > 0 else np.nan + multi_r_track = tp / tpfn if tpfn > 0 else np.nan + multi_f_track = f1_measure(multi_p_track, multi_r_track) + instr_metric['multi_f'] = multi_f_track + + return drum_metric, non_drum_metric, instr_metric diff --git a/amt/src/utils/metrics_helper.py b/amt/src/utils/metrics_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..887b5098ac91c64d5f15ff5b84f27d6245a28373 --- /dev/null +++ b/amt/src/utils/metrics_helper.py @@ -0,0 +1,322 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. 
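+# Note on the helpers below (illustrative): f1_measure(p, r) is the harmonic mean of precision
+# and recall with an epsilon guard, so f1_measure(0., 0.) evaluates to 0. without special-casing
+# zero inputs; e.g. f1_measure(1.0, 0.5) gives roughly 0.667.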
+import sys +from typing import Tuple, Dict, List, Optional, Any +import numpy as np +from collections import Counter +from scipy.stats import hmean +from mir_eval.transcription import precision_recall_f1_overlap +from mir_eval.multipitch import evaluate +from mir_eval.melody import to_cent_voicing, raw_pitch_accuracy, raw_chroma_accuracy, overall_accuracy +from mir_eval.util import midi_to_hz +from utils.note_event_dataclasses import Note + +EPS = sys.float_info.epsilon + + +def f1_measure(p, r): + return hmean([p + EPS, r + EPS]) - EPS + + +def round_float(l=[], ndigits=4): + return [round(x, ndigits) for x in l] + + +# Convert Notes to pitches and intervals for mir_eval note-wise evaluation +def extract_pitches_intervals_from_notes(notes: List[Note], is_drum: bool = False) -> Dict[str, np.ndarray]: + # drum offsets will be ignored anyways... + pitches = [midi_to_hz(n.pitch) for n in notes] + if is_drum: + intervals = [[n.onset, n.onset + 0.008] for n in notes] + else: + intervals = [[n.onset, n.offset] for n in notes] + return { + "pitches": np.array(pitches), # (L,) + "intervals": np.array(intervals), # (L, 2) + } + + +# Convert Notes to time and freqs for mir_eval frame-wise evaluation +def extract_frame_time_freq_from_notes(notes: List[Note], + is_drum: bool = False, + hop_size_sec: float = 0.0625) -> Dict[str, np.ndarray]: + if len(notes) == 0: + return { + "time": np.array([]), + "freqs": [[]], + "roll": np.zeros((0, 128)), + } + + # drum offsets will be ignored anyways... + note_pitches = [n.pitch for n in notes] + last_offset = max([n.offset for n in notes[-20:]]) + shape = (int(last_offset / hop_size_sec), 128) + roll = np.zeros(shape) + if is_drum: + frame_intervals = [[int(n.onset / hop_size_sec), int(n.onset / hop_size_sec) + 1] for n in notes] + else: + frame_intervals = [[ + int(n.onset / hop_size_sec), + max(int(n.offset / hop_size_sec), + int(n.onset / hop_size_sec) + 1) + ] for n in notes] + # create frame-level piano-roll + for note_pitch, (frame_onset, frame_offset) in zip(note_pitches, frame_intervals): + roll[frame_onset:frame_offset, note_pitch] = 1 + + # take frequency in the range of [16, 110] due to the limitation of mir_eval + roll[:, :16] = 0 + roll[:, 110:] = 0 + + time = np.arange(shape[0]) + frame_pitches = [roll[t, :].nonzero()[0] for t in time] + return { + "time": time * hop_size_sec, + "freqs": [np.array([midi_to_hz(p) for p in pitches]) for pitches in frame_pitches], + "roll": roll, + } + + +# Evaluation: Single instrument Note Onset F1 & OnsetOffset F1 +def mir_eval_note_f1(est_pitches: np.ndarray, + est_intervals: np.ndarray, + ref_pitches: np.ndarray, + ref_intervals: np.ndarray, + is_drum: bool = False, + add_micro_metric: bool = False, + suffix: Optional[str] = None, + onset_tolerance: float = 0.05) -> Dict[str, Any]: + """ Instrument-agnostic Note F1 score + + Args: + est_pitches (np.ndarray): Estimated pitches (Hz) shape=(n,) + est_intervals (np.ndarray): Estimated intervals (seconds) shape=(n, 2) + ref_pitches (np.ndarray): Reference pitches (Hz) shape=(n,) + ref_intervals (np.ndarray): Reference intervals (seconds) shape=(n, 2) + is_drum (bool, optional): Whether the instrument is drum. Defaults to False. + suffix (Optional[str], optional): Suffix to add to the metric names. Defaults to None. + + Returns: + Dict[str, Any]: Instrument-agnostic Note F1 score. np.nan if empty. 
+ + """ + if len(ref_pitches) == 0 and len(est_pitches) == 0: + metrics = { + 'onset_f': np.nan, + 'offset_f': np.nan, + } + onset_p, onset_r, offset_p, offset_r = np.nan, np.nan, np.nan, np.nan + elif len(ref_pitches) == 0 and len(est_pitches) != 0: + metrics = { + 'onset_f': np.nan, # No false negatives, recall and F1 will be NaN + 'offset_f': np.nan, # No false negatives, recall and F1 will be NaN + } + onset_p, onset_r, offset_p, offset_r = 0., np.nan, 0., np.nan + # Add the following elif case to handle the situation when there are reference pitches but no estimated pitches + elif len(ref_pitches) != 0 and len(est_pitches) == 0: + metrics = { + 'onset_f': 0., # No false positives, precision is NaN. recall and F1 are 0. + 'offset_f': 0., # No false positives, precision is NaN. recall and F1 are 0. + } + onset_p, onset_r, offset_p, offset_r = np.nan, 0., np.nan, 0. + else: + metrics = {} + onset_p, onset_r, metrics['onset_f'], _ = precision_recall_f1_overlap(ref_intervals, + ref_pitches, + est_intervals, + est_pitches, + onset_tolerance=onset_tolerance, + pitch_tolerance=50., + offset_ratio=None) + if is_drum is not True: + offset_p, offset_r, metrics['offset_f'], _ = precision_recall_f1_overlap(ref_intervals, + ref_pitches, + est_intervals, + est_pitches, + onset_tolerance=onset_tolerance, + pitch_tolerance=50., + offset_ratio=0.2) + + if add_micro_metric is True: + metrics['micro_onset_p'] = {'value': onset_p, 'weight': len(est_pitches)} + metrics['micro_onset_r'] = {'value': onset_r, 'weight': len(ref_pitches)} + if is_drum is not True: + metrics['micro_offset_p'] = {'value': offset_p, 'weight': len(est_pitches)} + metrics['micro_offset_r'] = {'value': offset_r, 'weight': len(ref_pitches)} + + if is_drum: + # remove offset metrics, and add suffix '_drum' for drum onset metrics + metrics = {k + '_drum' if 'onset' in k else k: v for k, v in metrics.items() if 'offset' not in k} + + if suffix: + metrics = {k + '_' + suffix: v for k, v in metrics.items()} + + return metrics + + +# Evaluation: Frame F1 +def mir_eval_frame_f1(est_time_freqs: Dict[str, List[np.ndarray]], + ref_time_freqs: Dict[str, List[np.ndarray]], + suffix: Optional[str] = None) -> Dict[str, float]: + """ Instrument-agnostic Note F1 score + + Args: + est_time_freqs Dict[str, List[np.ndarray]]: Estimated time, freqs and piano-roll + { + 'time': np.ndarray, Estimated time indices in seconds. + 'freqs': List[np.ndarray], Estimated frequencies in Hz. + 'roll': np.ndarray, Estimated piano-roll. + } + ref_time_freqs Dict[str, List[np.ndarray]]: Reference time, freqs and piano-roll + { + 'time': np.ndarray, Reference time indices in seconds. + 'freqs': List[np.ndarray], Reference frequencies in Hz. + 'roll': np.ndarray, Reference piano-roll. + } + suffix (Optional[str], optional): Suffix to add to the metric names. Defaults to None. + + Returns: + Tuple[Counter, Dict]: Instrument-agnostic Note F1 score + + """ + if np.sum(ref_time_freqs['roll']) == 0 and np.sum(est_time_freqs['roll']) == 0: + metrics = { + 'frame_f': np.nan, + 'frame_f_pc': np.nan, + } + elif np.sum(ref_time_freqs['roll']) == 0 and np.sum(est_time_freqs['roll']) != 0: + metrics = { + 'frame_f': np.nan, # F1-score will be NaN + 'frame_f_pc': np.nan, + } + # Add the following elif case to handle the situation when there are reference pitches but no estimated pitches + elif np.sum(ref_time_freqs['roll']) != 0 and np.sum(est_time_freqs['roll']) == 0: + metrics = { + 'frame_f': 0., # F1-score will be 0. 
+ 'frame_f_pc': 0., + } + else: + # frame-wise evaluation + res = evaluate(ref_time=ref_time_freqs['time'], + ref_freqs=ref_time_freqs['freqs'], + est_time=est_time_freqs['time'], + est_freqs=est_time_freqs['freqs']) + frame_f = f1_measure(res['Precision'], res['Recall']) + frame_f_pc = f1_measure(res['Chroma Precision'], res['Chroma Recall']) + metrics = { + 'frame_f': frame_f, + 'frame_f_pc': frame_f_pc, + } + + if suffix: + metrics = {k + '_' + suffix: v for k, v in metrics.items()} + + return metrics + + +# Evaluation: Melody metrics +def mir_eval_melody_metric(est_pitches: np.ndarray, + est_intervals: np.ndarray, + ref_pitches: np.ndarray, + ref_intervals: np.ndarray, + cent_tolerance: float = 50, + suffix: Optional[str] = None) -> Dict[str, Any]: + """ Melody metrics: Raw Pitch Accuracy, Raw Chroma Accuracy, Overall Accuracy + + Args: + est_pitches (np.ndarray): Estimated pitches (Hz) shape=(n,) + est_intervals (np.ndarray): Estimated intervals (seconds) shape=(n, 2) + ref_pitches (np.ndarray): Reference pitches (Hz) shape=(n,) + ref_intervals (np.ndarray): Reference intervals (seconds) shape=(n, 2) + cent_tolerance (float, optional): Cent tolerance. Defaults to 50. + suffix (Optional[str], optional): Suffix to add to the metric names. Defaults to None. + + Returns: + Dict[str, Any]: RPA, RCA, OA + + """ + try: + (ref_v, ref_c, est_v, est_c) = to_cent_voicing(ref_intervals[:, 0:1], + ref_pitches, + est_intervals[:, 0:1], + est_pitches, + hop=0.01) + # Your code here to calculate rpa based on the outputs of to_cent_voicing + except Exception as e: + print(f"Error occurred: {e}") + return { + 'melody_rpa' + ('_' + suffix if suffix else ''): np.nan, + 'melody_rca' + ('_' + suffix if suffix else ''): np.nan, + 'melody_oa' + ('_' + suffix if suffix else ''): np.nan, + } + + rpa = raw_pitch_accuracy(ref_v, ref_c, est_v, est_c, cent_tolerance) + rca = raw_chroma_accuracy(ref_v, ref_c, est_v, est_c, cent_tolerance) + oa = overall_accuracy(ref_v, ref_c, est_v, est_c, cent_tolerance) + return { + 'melody_rpa' + ('_' + suffix if suffix else ''): rpa, + 'melody_rca' + ('_' + suffix if suffix else ''): rca, + 'melody_oa' + ('_' + suffix if suffix else ''): oa, + } + + +def test(): + ref_pitches = np.array([100, 100, 200, 300]) # in Hz + ref_intervals = np.array([ + [0, 1], # in seconds + [2, 3], + [5, 12], + [1, 10] + ]) + est_pitches = ref_pitches.copy() + est_intervals = ref_intervals.copy() + mir_eval_note_f1(ref_pitches, ref_intervals, ref_pitches, ref_intervals) + """ + result: + + (Counter({ + 'note_onset/precision': 1.0, + 'note_onset/recall': 1.0, + 'note_onset/f1': 1.0, + 'note_offset/precision': 1.0, + 'note_offset/recall': 1.0, + 'note_offset/f1': 1.0 + }) + """ + + est_pitches = np.array([101, 100, 200, 300]) # in Hz + est_intervals = np.array([ + [0.3, 1], # wrong onset, thus on-offset incorrect too. + [2, 3], + [5, 12], + [1, 10] + ]) + mir_eval_note_f1(est_pitches, est_intervals, ref_pitches, ref_intervals) + # note_onset/f1': 0.75, 'note_offset/f1': 0.75}), + + est_pitches = np.array([101, 100, 200, 300]) # in Hz + est_intervals = np.array([ + [0, 0.5], # correct onset, on-offset incorrect + [2, 3], + [5, 12], + [1, 10] + ]) + mir_eval_note_f1(est_pitches, est_intervals, ref_pitches, ref_intervals) + # 'note_onset/f1': 1.0, 'note_offset/f1': 0.75}), + """ Duplicated notes """ + est_pitches = ref_pitches.copy() + est_intervals = ref_intervals.copy() + np.append(est_pitches, 100) # ref has 4 notes, while est has correct 4 notes + another 1 note. 
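+    # (note: np.append returns a new array and does not modify its input in place, so this call
+    #  and the next leave est_pitches / est_intervals unchanged unless the results are assigned
+    #  back, e.g. est_pitches = np.append(est_pitches, 100).)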
+ np.append(est_intervals, [1.5, 2.5]) + mir_eval_note_f1(est_pitches, est_intervals, ref_pitches, ref_intervals) + # 'note_onset/f1': 1.0, 'note_offset/f1': 1.0}), + # The duplicated note is not counted as a false positive + # and thus we do not need to post-process multi-instrument tokens + # to remove duplicated notes in instrument-agnostic metrics. diff --git a/amt/src/utils/midi.py b/amt/src/utils/midi.py new file mode 100644 index 0000000000000000000000000000000000000000..b397bddd55721ee60370d6af40103aa0d8616a03 --- /dev/null +++ b/amt/src/utils/midi.py @@ -0,0 +1,423 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""midi.py + +MIDI <-> Note +• midi2note: convert a MIDI file to a list of Note instances. +• note2midi: convert a list of Note instances to a MIDI file. + +""" +import os +import copy +import warnings +import numpy as np +from typing import List, Dict, Optional, Tuple, Union +from mido import MetaMessage, Message, MidiFile, MidiTrack, second2tick +from utils.note_event_dataclasses import Note, NoteEvent +from utils.note2event import validate_notes, trim_overlapping_notes +from utils.note2event import note2note_event +""" midi2note: +Convert a MIDI file to a list of Note instances. + +About new implementation: + + The widely used MIDI parsers (implementations from pretty_midi, +onset_and_frames, reconvat, and mir_data) implementations used a method of +applying the offset to the nearest previous note when note overlaps occurred. + + We often found issues with this lazy-processing approach, where the length of +the overlapped notes later in the sequence would become extremely short. + + This code has been re-implemented to address these issues by keeping note +activations in channel-specific buffers, similar to actual DAWs, +allowing for the application of the sustain pedal effect in multi-channel +tracks. + +Example from Slkah,'Track00805-S00' (bass stem): + +(onset, offset) + + +(8.83, 9.02*) * first note's offset is later than second note's onset, so overlap occurs. +(9.0, 9.55) + + +(8.83, 9.0) +(9.0, 9.02*) * second note is too short, because first note's offset is applied to second note. + + +(8.83, 8.84*) * due to reverse search, first note's offset is missing, so minimum offset is applied. 
+(9.0, 9.55) + + +(8.83, 9.0) +(9.0, 9.55) + +""" +DRUM_PROGRAM = 128 + + +def find_channel_of_track_name(midi_file: os.PathLike, track_name_keywords: List[str]) -> Optional[int]: + mid = MidiFile(midi_file) + found_channels = [] + + for track in mid.tracks: + track_name_found = False + for msg in track: + if msg.type == 'track_name': + for k in track_name_keywords: + if k.lower() == msg.name.lower(): # exact match only + track_name_found = True + break + + if track_name_found and msg.type in ['note_on', 'note_off']: + found_channels.append(msg.channel) + break + + return list(set(found_channels)) + + +def midi2note(file: Union[os.PathLike, str], + binary_velocity: bool = True, + ch_9_as_drum: bool = False, + force_all_drum: bool = False, + force_all_program_to: Optional[int] = None, + track_name_to_program: Optional[Dict] = None, + trim_overlap: bool = True, + fix_offset: bool = True, + quantize: bool = True, + verbose: int = 0, + minimum_offset_sec: float = 0.01, + drum_offset_sec: float = 0.01, + ignore_pedal: bool = False, + return_programs: bool = False) -> Tuple[List[Note], float]: + midi = MidiFile(file) + max_time = midi.length # in seconds + + finished_notes = [] + program_state = [None] * 16 # program_number = program_state[ch] + sustain_state = [None] * 16 # sustain_state[ch] = True if sustain is on + active_notes = [[] for i in range(16)] # active notes by channel(0~15). active_notes[ch] = [Note1, Note_2,..] + sustained_notes = [[] for i in range(16) + ] # offset is passed, but sustain is applied. sustained_notes[ch] = [Note1, Note_2,..] + + # Mapping track name to program (for geerdes data) + reserved_channels = [] + if track_name_to_program is not None: + for key in track_name_to_program.keys(): + found_channels = find_channel_of_track_name(file, [key]) + if len(found_channels) > 0: + for ch in found_channels: + program_state[ch] = track_name_to_program[key] + reserved_channels.append(ch) + if ch_9_as_drum is True: + program_state[9] = DRUM_PROGRAM + reserved_channels.append(9) + + current_time = 0. + for i, msg in enumerate(midi): + current_time += msg.time + if msg.type == 'program_change' and msg.channel not in reserved_channels: + program_state[msg.channel] = msg.program + elif msg.type == 'control_change' and msg.control == 64 and not ignore_pedal: + if msg.value >= 64: + sustain_state[msg.channel] = True + else: + sustain_state[msg.channel] = False + for note in sustained_notes[msg.channel]: + note.offset = current_time + finished_notes.append(note) + sustained_notes[msg.channel] = [] + elif msg.type == 'note_on' and msg.velocity > 0: + if program_state[msg.channel] == None: + if force_all_program_to == None: + raise ValueError( + '📕 midi2note: program_change message is missing. Use `force_all_program_to` option') + else: + program_state[msg.channel] = force_all_program_to + # if (ch_9_as_drum and msg.channel == 9) or force_all_drum: + if program_state[msg.channel] == DRUM_PROGRAM or force_all_drum: + # drum's offset, active_notes, sustained_notes are not tracked. 
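+ # Each drum hit is finalized immediately with a fixed nominal duration
+ # (drum_offset_sec, 10 ms by default), so it never enters the per-channel
+ # active-note or sustain buffers used for pitched notes.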
+ new_note = Note(is_drum=True, + program=program_state[msg.channel], + onset=current_time, + offset=current_time + drum_offset_sec, + pitch=msg.note, + velocity=msg.velocity) + finished_notes.append(new_note) + else: + new_note = Note(is_drum=False, + program=program_state[msg.channel], + onset=current_time, + offset=None, + pitch=msg.note, + velocity=msg.velocity) + active_notes[msg.channel].append(new_note) + elif msg.type == 'note_off' or (msg.type == 'note_on' and msg.velocity == 0): + temp_active_notes = active_notes.copy() + offset_done_flag = False + for note in active_notes[msg.channel]: + if note.pitch == msg.note: + if sustain_state[msg.channel]: + sustained_notes[msg.channel].append(note) + temp_active_notes[msg.channel].remove(note) + elif offset_done_flag == False: + note.offset = current_time + finished_notes.append(note) + temp_active_notes[msg.channel].remove(note) + offset_done_flag = True + # fix: note_off message is only for the oldest note_on message + else: + pass + active_notes = temp_active_notes + + # Handle any still-active notes (e.g., if the file ends without note_off messages) + for ch_notes in active_notes: + for note in ch_notes: + note.offset = min(current_time, note.onset + minimum_offset_sec) + finished_notes.append(note) + for ch_notes in sustained_notes: + for note in ch_notes: + note.offset = min(current_time, note.onset + minimum_offset_sec) + finished_notes.append(note) + + notes = finished_notes + + if binary_velocity: + for note in notes: + note.velocity = 1 if note.velocity > 0 else 0 + + notes.sort(key=lambda note: (note.onset, note.is_drum, note.program, note.velocity, note.pitch)) + + # Quantize notes to 10 ms + if quantize: + for note in notes: + note.onset = round(note.onset * 100) / 100. + note.offset = round(note.offset * 100) / 100. + + # Trim overlapping notes + if trim_overlap: + notes = trim_overlapping_notes(notes, sort=True) + + # fix offset >= onset the Note instances + if fix_offset: + notes = validate_notes(notes, fix=True) + + # Print some statistics + has_drum = False + for note in notes: + if note.is_drum: + has_drum = True + break + num_instr = sum([int(c is not None) for c in program_state]) + if verbose > 0: + print( + f'parsed {file}: midi_type={midi.type}, num_notes={len(notes)}, num_instr={num_instr}, has_drum={has_drum}') + if return_programs: + return notes, max_time, program_state + else: + return notes, max_time + + +def note_event2midi(note_events: List[NoteEvent], + output_file: Optional[os.PathLike] = None, + velocity: int = 100, + ticks_per_beat: int = 480, + tempo: int = 500000, + singing_program_mapping: int = 65, + singing_chorus_program_mapping: int = 53, + output_inverse_vocab: Optional[Dict] = None) -> None: + """Converts a list of Note instances to a MIDI file. + + List[NoteEvent]: + [NoteEvent(is_drum: bool, program: int, time: Optional[float], velocity: int, + pitch: int, activity: Optional[Set[int]] = {}) + + Example usage: + + note_event2midi(note_events, 'output.mid') + + """ + midi = MidiFile(ticks_per_beat=ticks_per_beat, type=0) + midi.type = 1 + track = MidiTrack() + midi.tracks.append(track) + + # Set tempo + # track.append(mido.MetaMessage('set_tempo', tempo=tempo)) + + # Assign channels to programs + programs = set() + for ne in note_events: + if ne.program == 128 or ne.is_drum == True: + programs.add(128) # 128 represents drum here... 
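+ # Drums are normalized to the internal program number 128 here; in the channel
+ # assignment below, program 128 is always routed to MIDI channel 9 (the GM drum channel).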
+ ne.program = 128 # internally we use 128 for drum + else: + programs.add(ne.program) + programs = sorted(programs) + + program_to_channel = {} + available_channels = list(range(0, 9)) + list(range(10, 16)) + for prg in programs: + if prg == 128: + program_to_channel[prg] = 9 + else: + try: + program_to_channel[prg] = available_channels.pop(0) + except IndexError: + warnings.warn(f'not available channels for program {prg}, share channel 16') + program_to_channel[prg] = 15 + + # notes to note_events (this is simpler) + drum_offset_events = [] # for drum notes, we need to add an offset event + for ne in note_events: + if ne.is_drum: + drum_offset_events.append( + NoteEvent(is_drum=True, program=ne.program, time=ne.time + 0.01, pitch=ne.pitch, velocity=0)) + note_events += drum_offset_events + note_events.sort(key=lambda ne: (ne.time, ne.is_drum, ne.program, ne.velocity, ne.pitch)) + + # Add note events to multitrack + for program in programs: + # Create a track for each program + track = MidiTrack() + midi.tracks.append(track) + + # Add track name + if program == 128: + program_name = 'Drums' + elif output_inverse_vocab is not None: + program_name = output_inverse_vocab.get(program, (program, f'Prg. {str(program)}'))[1] + else: + program_name = f'Prg. {str(program)}' + track.append(MetaMessage('track_name', name=program_name, time=0)) + + # Channel is determined by the program + channel = program_to_channel[program] + + # Some special treatment for singing voice and drums + if program == 128: # drum + # set 0 but it is ignored in drum channel + track.append(Message('program_change', program=0, time=0, channel=channel)) + elif program == 100: # singing voice --> Alto Sax + track.append(Message('program_change', program=singing_program_mapping, time=0, channel=channel)) + elif program == 101: # singing voice (chrous) --> Voice Oohs + track.append(Message('program_change', program=singing_chorus_program_mapping, time=0, channel=channel)) + else: + track.append(Message('program_change', program=program, time=0, channel=channel)) + + current_tick = int(0) + for ne in note_events: + if ne.program == program: + absolute_tick = round(second2tick(ne.time, ticks_per_beat, tempo)) + if absolute_tick == current_tick: + delta_tick = int(0) + elif absolute_tick < current_tick: + # this should not happen after sorting + raise ValueError( + f'at ne.time {ne.time}, absolute_tick {absolute_tick} < current_tick {current_tick}') + else: + # Convert time shift value from seconds to ticks + delta_tick = absolute_tick - current_tick + current_tick += delta_tick + + # Create a note on or note off message + msg_note = 'note_on' if ne.velocity > 0 else 'note_off' + msg_velocity = velocity if ne.velocity > 0 else 0 + new_msg = Message(msg_note, note=ne.pitch, velocity=msg_velocity, time=delta_tick, channel=channel) + + track.append(new_msg) + + # Save MIDI file + if output_file != None: + midi.save(output_file) + + +def get_pitch_range_from_midi(midi_file: os.PathLike) -> Tuple[int, int]: + """Returns the pitch range of a MIDI file. + + Args: + midi_file (os.PathLike): Path to a MIDI file. + + Returns: + Tuple[int, int]: The lowest and highest notes in the MIDI file. 
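+
+ Note:
+ midi2note() returns a (notes, max_time) tuple (or (notes, max_time, programs)
+ with return_programs=True), so the result should be unpacked, e.g.
+ notes, _ = midi2note(midi_file, quantize=False, trim_overlap=False),
+ before iterating over the Note instances.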
+ """ + notes = midi2note(midi_file, quantize=False, trim_overlap=False) + pitches = [n.pitch for n in notes] + return min(pitches), max(pitches) + + +def pitch_shift_midi(src_midi_file: os.PathLike, + min_pitch_shift: int = -5, + max_pitch_shift: int = 6, + write_midi_file: bool = True, + write_notes_file: bool = True, + write_note_events_file: bool = True) -> None: + """Pitch shifts a MIDI file and write it as MIDI. + + Args: + src_midi_file (os.PathLike): Path to a MIDI file. + min_pitch_shift (int): The number of semitones to shift. + max_pitch_shift (int): The number of semitones to shift. + + Writes: + dst_midi_file (os.PathLike): {src_midi_filename}_pshift_{i}.mid, where i can be [...,-1, 1, 2,...] + dst_notes : List[Note] + dst_note_events: List[NoteEvent] + """ + # source file + src_midi_dir = os.path.dirname(src_midi_file) + src_midi_filename = os.path.basename(src_midi_file).split('.')[0] + src_notes_file = os.path.join(src_midi_dir, f'{src_midi_filename}_notes.npy') + src_note_events_file = os.path.join(src_midi_dir, f'{src_midi_filename}_note_events.npy') + src_notes, _ = midi2note(src_midi_file) + # src_note_events = note2note_event(src_notes) + + for pitch_shift in range(min_pitch_shift, max_pitch_shift): + if pitch_shift == 0: + continue + + # destination file + dst_midi_file = os.path.join(src_midi_dir, f'{src_midi_filename}_pshift{pitch_shift}.mid') + dst_notes_file = os.path.join(src_midi_dir, f'{src_midi_filename}_pshift{pitch_shift}_notes.npy') + dst_note_events_file = os.path.join(src_midi_dir, f'{src_midi_filename}_pshift{pitch_shift}_note_events.npy') + + dst_notes = [] + for note in src_notes: + dst_note = copy.deepcopy(note) + dst_note.pitch += pitch_shift + dst_notes.append(dst_note) + + dst_note_events = note2note_event(dst_notes) + + # write midi file + if write_midi_file: + note_event2midi(dst_note_events, dst_midi_file) + print(f'Created {dst_midi_file}') + + # write notes file + if write_notes_file: + # get metadata for notes + src_notes_metadata = np.load(src_notes_file, allow_pickle=True).tolist() + dst_notes_metadata = src_notes_metadata + dst_notes_metadata['pitch_shift'] = pitch_shift + dst_notes_metadata['notes'] = dst_notes + np.save(dst_notes_file, dst_notes_metadata, allow_pickle=True, fix_imports=False) + print(f'Created {dst_notes_file}') + + # write note events file + if write_note_events_file: + # get metadata for note events + src_note_events_metadata = np.load(src_note_events_file, allow_pickle=True).tolist() + dst_note_events_metadata = src_note_events_metadata + dst_note_events_metadata['pitch_shift'] = pitch_shift + dst_note_events_metadata['note_events'] = dst_note_events + np.save(dst_note_events_file, dst_note_events_metadata, allow_pickle=True, fix_imports=False) + print(f'Created {dst_note_events_file}') diff --git a/amt/src/utils/mirdata_dev/.DS_Store b/amt/src/utils/mirdata_dev/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fd7ae2684d6bd9ed3c3ade9b8bf608f42014f026 Binary files /dev/null and b/amt/src/utils/mirdata_dev/.DS_Store differ diff --git a/amt/src/utils/mirdata_dev/datasets/slakh16k.py b/amt/src/utils/mirdata_dev/datasets/slakh16k.py new file mode 100644 index 0000000000000000000000000000000000000000..465a8944f0a327f8d29ddb6df14035d00516c8e5 --- /dev/null +++ b/amt/src/utils/mirdata_dev/datasets/slakh16k.py @@ -0,0 +1,455 @@ +"""slakh Dataset Loader + +.. admonition:: Dataset Info + :class: dropdown + + • This code is modified to use the Slakh2100 dataset converted into 16k. 
+ • Unlike slakh, this version treats drum tracks as pitched instruments (80 notes appears). + See Line 243, 356. + + The Synthesized Lakh (Slakh) Dataset is a dataset of multi-track audio and aligned + MIDI for music source separation and multi-instrument automatic transcription. + Individual MIDI tracks are synthesized from the Lakh MIDI Dataset v0.1 using + professional-grade sample-based virtual instruments, and the resulting audio is + mixed together to make musical mixtures. + + The original release of Slakh, called Slakh2100, + contains 2100 automatically mixed tracks and accompanying, aligned MIDI files, + synthesized from 187 instrument patches categorized into 34 classes, totaling + 145 hours of mixture data. + + This loader supports two versions of Slakh: + - Slakh2100-redux: a deduplicated version of slakh2100 containing 1710 multitracks + - baby-slakh: a mini version with 16k wav audio and only the first 20 tracks + + This dataset was created at Mitsubishi Electric Research Labl (MERL) and + Interactive Audio Lab at Northwestern University by Ethan Manilow, + Gordon Wichern, Prem Seetharaman, and Jonathan Le Roux. + + For more information see http://www.slakh.com/ + +""" +import os +from typing import BinaryIO, Optional, Tuple + +from deprecated.sphinx import deprecated +import librosa +import numpy as np +import pretty_midi +from smart_open import open +import yaml + +from mirdata import io, download_utils, jams_utils, core, annotations + +BIBTEX = """ +@inproceedings{manilow2019cutting, + title={Cutting Music Source Separation Some {Slakh}: A Dataset to Study the Impact of Training Data Quality and Quantity}, + author={Manilow, Ethan and Wichern, Gordon and Seetharaman, Prem and Le Roux, Jonathan}, + booktitle={Proc. IEEE Workshop on Applications of Signal Processing to Audio and Acoustics (WASPAA)}, + year={2019}, + organization={IEEE} +} +""" + +INDEXES = { + "default": + "2100-yourmt3-16k", + "test": + "baby", + "2100-yourmt3-16k": + core.Index( + filename="slakh_index_2100-yourmt3-16k.json", + url="https://zenodo.org/record/7717249/files/slakh_index_2100-yourmt3-16k.json?download=1", + checksum="fab898bd82827ddc4c3e4dbd7b7fcbd9", + partial_download=["2100-yourmt3-16k"]), + "2100-redux": + core.Index(filename="slakh_index_2100-redux.json", partial_download=["2100-redux"]), + "baby": + core.Index(filename="slakh_index_baby.json", partial_download=["baby"]), +} + +REMOTES = { + "2100-yourmt3-16k": + download_utils.RemoteFileMetadata( + filename="slakh2100_yourmt3_16k.tar.gz", + url="https://zenodo.org/record/7717249/files/slakh2100_yourmt3_16k.tar.gz?download=1", + checksum="c44f9bcba07b3c6ddeaf604f45dc61c5", + ), + "2100-redux": + download_utils.RemoteFileMetadata( + filename="slakh2100_flac_redux.tar.gz", + url="https://zenodo.org/record/4599666/files/slakh2100_flac_redux.tar.gz?download=1", + checksum="f4b71b6c45ac9b506f59788456b3f0c4", + ), + "baby": + download_utils.RemoteFileMetadata( + filename="babyslakh_16k.tar.gz", + url="https://zenodo.org/record/4603870/files/babyslakh_16k.tar.gz?download=1", + checksum="311096dc2bde7d61c97e930edbfc7f78", + ), +} + +LICENSE_INFO = """ +Creative Commons Attribution 4.0 International +""" + +SPLITS = ["train", "validation", "test", "omitted"] +SPLITS_16K = ["train", "validation", "test"] + +#: Mixing group to program number mapping +MIXING_GROUPS = { + "piano": [0, 1, 2, 3, 4, 5, 6, 7], + "guitar": [24, 25, 26, 27, 28, 29, 30, 31], + "bass": [32, 33, 34, 35, 36, 37, 38, 39], + "drums": [128], +} + + +class Track(core.Track): + 
"""slakh Track class, for individual stems + + Attributes: + audio_path (str or None): path to the track's audio file. For some unusual tracks, + such as sound effects, there is no audio and this attribute is None. + split (str or None): one of 'train', 'validation', 'test', or 'omitted'. + 'omitted' tracks are part of slakh2100-redux which were found to be + duplicates in the original slakh2011. + In baby slakh there are no splits, so this attribute is None. + data_split (str or None): equivalent to split (deprecated in 0.3.6) + metadata_path (str): path to the multitrack's metadata file + midi_path (str or None): path to the track's midi file. For some unusual tracks, + such as sound effects, there is no midi and this attribute is None. + mtrack_id (str): the track's multitrack id + track_id (str): track id + instrument (str): MIDI instrument class, see link for details: + https://en.wikipedia.org/wiki/General_MIDI#Program_change_events + integrated_loudness (float): integrated loudness (dB) of this track + as calculated by the ITU-R BS.1770-4 spec + is_drum (bool): whether the "drum" flag is true for this MIDI track + midi_program_name (str): MIDI instrument program name + plugin_name (str): patch/plugin name that rendered the audio file + mixing_group (str): which mixing group the track belongs to. + One of MIXING_GROUPS. + program_number (int): MIDI instrument program number + + Cached Properties: + midi (PrettyMIDI): midi data used to generate the audio + notes (NoteData or None): note representation of the midi data. + If there are no notes in the midi file, returns None. + multif0 (MultiF0Data or None): multif0 representaation of the midi data. + If there are no notes in the midi file, returns None. + + """ + + def __init__(self, track_id, data_home, dataset_name, index, metadata): + + super().__init__( + track_id, + data_home, + dataset_name=dataset_name, + index=index, + metadata=metadata, + ) + + self.mtrack_id = self.track_id.split("-")[0] + self.audio_path = self.get_path("audio") + self.midi_path = self.get_path("midi") + self.metadata_path = self.get_path("metadata") + + # split (train/validation/test/omitted) is part of the relative filepath in the index + self.split = None # for baby_slakh, there are no data splits - set to None + # if index["version"] == "2100-redux": + if "2100-redux" in index["version"]: + self.split = self._track_paths["metadata"][0].split(os.sep)[1] + assert (self.split in SPLITS), "{} not a valid split - should be one of {}.".format( + self.split, SPLITS) + elif "2100-yourmt3" in index["version"]: + self.split = self._track_paths["metadata"][0].split(os.sep)[1] + assert (self.split in SPLITS_16K), "{} not a valid split - should be one of {}.".format( + self.split, SPLITS_16K) + + self.data_split = self.split # deprecated in 0.3.6 + + @core.cached_property + def _track_metadata(self) -> dict: + try: + with open(self.metadata_path, "r") as fhandle: + metadata = yaml.safe_load(fhandle) + except FileNotFoundError: + raise FileNotFoundError( + f"track metadata for {self.track_id} not found. 
Did you run .download()?") + return metadata["stems"][self.track_id.split("-")[1]] + + @property + def instrument(self) -> Optional[str]: + return self._track_metadata.get("inst_class") + + @property + def integrated_loudness(self) -> Optional[float]: + return self._track_metadata.get("integrated_loudness") + + @property + def is_drum(self) -> Optional[bool]: + return self._track_metadata.get("is_drum") + + @property + def midi_program_name(self) -> Optional[str]: + return self._track_metadata.get("midi_program_name") + + @property + def plugin_name(self) -> Optional[str]: + return self._track_metadata.get("plugin_name") + + @property + def program_number(self) -> Optional[int]: + return self._track_metadata.get("program_num") + + @property + def mixing_group(self) -> Optional[str]: + group = [k for k, v in MIXING_GROUPS.items() if self.program_number in v] + if len(group) == 0: + return None + return group[0] + + @core.cached_property + def midi(self) -> Optional[pretty_midi.PrettyMIDI]: + return io.load_midi(self.midi_path) + + @core.cached_property + def notes(self) -> Optional[annotations.NoteData]: + return io.load_notes_from_midi(self.midi_path, self.midi, skip_drums=False) + + @core.cached_property + def multif0(self) -> Optional[annotations.MultiF0Data]: + return io.load_multif0_from_midi( + self.midi_path, self.midi, skip_drums=True, pitch_bend=False) + + @property + def audio(self) -> Optional[Tuple[np.ndarray, float]]: + """The track's audio + + Returns: + * np.ndarray - audio signal + * float - sample rate + + """ + return load_audio(self.audio_path) + + def to_jams(self): + """Jams: the track's data in jams format""" + return jams_utils.jams_converter( + audio_path=self.audio_path, + note_data=[(self.notes, "Notes")], + ) + + +class MultiTrack(core.MultiTrack): + """slakh multitrack class, containing information about the mix and + the set of associated stems + + Attributes: + mtrack_id (str): track id + tracks (dict): {track_id: Track} + track_audio_property (str): the name of the attribute of Track which + returns the audio to be mixed + mix_path (str): path to the multitrack mix audio + midi_path (str): path to the full midi data used to generate the mixture + metadata_path (str): path to the multitrack metadata file + split (str or None): one of 'train', 'validation', 'test', or 'omitted'. + 'omitted' tracks are part of slakh2100-redux which were found to be + duplicates in the original slakh2011. 
+ data_split (str or None): equivalent to split (deprecated in 0.3.6) + uuid (str): File name of the original MIDI file from Lakh, sans extension + lakh_midi_dir (str): Path to the original MIDI file from a fresh download of Lakh + normalized (bool): whether the mix and stems were normalized according to the ITU-R BS.1770-4 spec + overall_gain (float): gain applied to every stem to make sure mixture does not clip when stems are summed + + Cached Properties: + midi (PrettyMIDI): midi data used to generate the mixture audio + notes (NoteData): note representation of the midi data + multif0 (MultiF0Data): multif0 representation of the midi data + + """ + + def __init__(self, mtrack_id, data_home, dataset_name, index, track_class, metadata): + super().__init__( + mtrack_id=mtrack_id, + data_home=data_home, + dataset_name=dataset_name, + index=index, + track_class=track_class, + metadata=metadata, + ) + self.mix_path = self.get_path("mix") + self.midi_path = self.get_path("midi") + self.metadata_path = self.get_path("metadata") + + # split (train/validation/test) is determined by the relative filepath in the index + self.split = None # for baby_slakh, there are no data splits - set to None + # if index["version"] == "2100-redux": + if "2100-redux" in index["version"]: + self.split = self._multitrack_paths["mix"][0].split(os.sep)[1] + assert self.split in SPLITS, "{} not in SPLITS".format(self.split) + elif "2100-yourmt3" in index["version"]: + self.split = self._multitrack_paths["mix"][0].split(os.sep)[1] + assert self.split in SPLITS_16K, "{} not in SPLITS".format(self.split) + + self.data_split = self.split # deprecated in 0.3.6 + + @property + def track_audio_property(self) -> str: + return "audio" + + @core.cached_property + def _multitrack_metadata(self) -> dict: + try: + with open(self.metadata_path, "r") as fhandle: + metadata = yaml.safe_load(fhandle) + except FileNotFoundError: + raise FileNotFoundError("Metadata not found. Did you run .download()?") + return metadata + + @property + def uuid(self) -> Optional[str]: + return self._multitrack_metadata.get("UUID") + + @property + def lakh_midi_dir(self) -> Optional[str]: + return self._multitrack_metadata.get("lmd_midi_dir") + + @property + def normalized(self) -> Optional[bool]: + return self._multitrack_metadata.get("normalized") + + @property + def overall_gain(self) -> Optional[float]: + return self._multitrack_metadata.get("overall_gain") + + @core.cached_property + def midi(self) -> Optional[pretty_midi.PrettyMIDI]: + return io.load_midi(self.midi_path) + + @core.cached_property + def notes(self) -> Optional[annotations.NoteData]: + return io.load_notes_from_midi(self.midi_path, self.midi, skip_drums=False) + + @core.cached_property + def multif0(self) -> Optional[annotations.MultiF0Data]: + # TODO: setting pitch_bend to False by default, but there are some + # patches that render pitch bend in the audio. + return io.load_multif0_from_midi( + self.midi_path, self.midi, skip_drums=False, pitch_bend=False) + + @property + def audio(self) -> Optional[Tuple[np.ndarray, float]]: + """The track's audio + + Returns: + * np.ndarray - audio signal + * float - sample rate + + """ + return load_audio(self.mix_path) + + def to_jams(self): + """Jams: the track's data in jams format""" + return jams_utils.jams_converter( + audio_path=self.mix_path, + note_data=[(self.notes, "Notes")], + ) + + def get_submix_by_group(self, target_groups): + """Create submixes grouped by instrument type. 
Creates one submix + per target group, plus one additional "other" group for any remaining sources. + Only tracks with available audio are mixed. + + Args: + target_groups (list): List of target groups. Elements should be one of + MIXING_GROUPS, e.g. ["bass", "guitar"] + + Returns: + * submixes (dict): {group: audio_signal} of submixes + * groups (dict): {group: list of track ids} of submixes + + """ + groups = {} + submixes = {} + tracks_with_audio = [track for track in self.tracks.values() if track.audio_path] + in_group = [] + for group in target_groups: + groups[group] = [ + track.track_id for track in tracks_with_audio if track.mixing_group == group + ] + in_group.extend(groups[group]) + + submixes[group] = (None if len(groups[group]) == 0 else self.get_target(groups[group])) + + groups["other"] = [ + track.track_id for track in tracks_with_audio if track.track_id not in in_group + ] + submixes["other"] = (None + if len(groups["other"]) == 0 else self.get_target(groups["other"])) + return submixes, groups + + +@io.coerce_to_bytes_io +def load_audio(fhandle: BinaryIO) -> Tuple[np.ndarray, float]: + """Load a slakh audio file. + + Args: + fhandle (str or file-like): path or file-like object pointing to an audio file + + Returns: + * np.ndarray - the audio signal + * float - The sample rate of the audio file + + """ + return librosa.load(fhandle, sr=None, mono=False) + + +@core.docstring_inherit(core.Dataset) +class Dataset(core.Dataset): + """ + The slakh dataset + """ + + def __init__(self, data_home=None, version="default"): + super().__init__( + data_home, + version, + name="slakh", + track_class=Track, + multitrack_class=MultiTrack, + bibtex=BIBTEX, + indexes=INDEXES, + remotes=REMOTES, + license_info=LICENSE_INFO, + ) + + @deprecated( + reason="Use mirdata.datasets.slakh.load_audio", + version="0.3.4", + ) + def load_audio(self, *args, **kwargs): + return load_audio(*args, **kwargs) + + @deprecated( + reason="Use mirdata.datasets.slakh.load_midi", + version="0.3.4", + ) + def load_midi(self, *args, **kwargs): + return io.load_midi(*args, **kwargs) + + @deprecated( + reason="Use mirdata.io.load_notes_from_midi", + version="0.3.4", + ) + def load_notes_from_midi(self, *args, **kwargs): + return io.load_notes_from_midi(*args, **kwargs) + + @deprecated( + reason="Use mirdata.io.load_multif0_from_midi", + version="0.3.4", + ) + def load_multif0_from_midi(self, *args, **kwargs): + return io.load_multif0_from_midi(*args, **kwargs) \ No newline at end of file diff --git a/amt/src/utils/mirdata_dev/scripts/make_slakh_index.py b/amt/src/utils/mirdata_dev/scripts/make_slakh_index.py new file mode 100644 index 0000000000000000000000000000000000000000..751a914359d68f1344f6ba344deab1a41678e1f2 --- /dev/null +++ b/amt/src/utils/mirdata_dev/scripts/make_slakh_index.py @@ -0,0 +1,118 @@ +"""make_slakh16k_index.py +USAGE: + python tasks/utils/mirdata_dev/scripts/make_slakh_index.py '../data' '2100-yourmt3-16k' + +""" +import argparse +import glob +import json +import os +import yaml +from mirdata.validate import md5 + + +def get_file_info(path): + if os.path.exists(path): + return [path, md5(path)] + else: + print("warning: {} not found. 
check metadata for omitted files.".format( + path)) + return [None, None] + + +def make_dataset_index(dataset_data_path, version): + curr_dir = os.getcwd() + os.chdir(dataset_data_path) + + dataset_index_path = os.path.join(dataset_data_path, "mirdata_indexes", + f"slakh_index_{version}.json") + + if version == "baby": + splits = [""] + topdir = "babyslakh_16k" + fmt = "wav" + elif version == "2100-yourmt3-16k": + splits = ["train", "validation", "test"] + topdir = "slakh2100_yourmt3_16k" + fmt = "wav" + elif version == "2100-redux": + splits = ["train", "validation", "test", "omitted"] + topdir = "slakh2100_flac_redux" + fmt = "flac" + multitrack_index = {} + track_index = {} + + for split in splits: + mtrack_ids = sorted([ + os.path.basename(folder) + for folder in glob.glob(os.path.join(topdir, split, "Track*")) + ]) + for mtrack_id in mtrack_ids: + print(f'indexing multitrack: {mtrack_id}') + mtrack_path = os.path.join(topdir, split, mtrack_id) + metadata_path = os.path.join(mtrack_path, "metadata.yaml") + with open(metadata_path, "r") as fhandle: + metadata = yaml.safe_load(fhandle) + + mtrack_midi_path = os.path.join(mtrack_path, "all_src.mid") + mix_path = os.path.join(mtrack_path, "mix.{}".format(fmt)) + + track_ids = [] + for track_id in metadata["stems"].keys(): + if metadata["stems"][track_id]["audio_rendered"] is not True: + continue # <-- modified by @mimbres to avoid missing audio error + if metadata["stems"][track_id]["midi_saved"] is not True: + continue # <-- modified by @mimbres to avoid missing audio error + audio_path = os.path.join(mtrack_path, "stems", + "{}.{}".format(track_id, fmt)) + midi_path = os.path.join(mtrack_path, "MIDI", + "{}.mid".format(track_id)) + midi_file_info = get_file_info(midi_path) + # skip tracks where there is no midi information (and thus no audio) + if midi_file_info[0] is None: + continue + if get_file_info(audio_path)[0] is None: + continue # <-- modified by @mimbres to avoid missing audio error + track_id = "{}-{}".format(mtrack_id, track_id) + track_ids.append(track_id) + track_index[track_id] = { + "audio": get_file_info(audio_path), + "midi": [midi_file_info[0], midi_file_info[1]], + "metadata": get_file_info(metadata_path), + } + + multitrack_index[mtrack_id] = { + "tracks": track_ids, + "midi": get_file_info(mtrack_midi_path), + "mix": get_file_info(mix_path), + "metadata": get_file_info(metadata_path), + } + + # top-key level version + dataset_index = { + "version": version, + "tracks": track_index, + "multitracks": multitrack_index, + } + + os.chdir(curr_dir) + with open(dataset_index_path, "w") as fhandle: + json.dump(dataset_index, fhandle, indent=2) + + +def main(args): + make_dataset_index(args.dataset_data_path, args.version) + print( + f"A new index file is copied to {args.dataset_data_path}/mirdata_indexes/" + ) + + +if __name__ == "__main__": + PARSER = argparse.ArgumentParser(description="Make dataset index file.") + PARSER.add_argument( + "dataset_data_path", type=str, help="Path to dataset data folder.") + PARSER.add_argument( + "version", + type=str, + help="Dataset version. baby or 2100-redux or 2100-yourmt3-16k") + main(PARSER.parse_args()) \ No newline at end of file diff --git a/amt/src/utils/note2event.py b/amt/src/utils/note2event.py new file mode 100644 index 0000000000000000000000000000000000000000..6ed2ce770795c0dc3a3eaaa9788d6054da542b8c --- /dev/null +++ b/amt/src/utils/note2event.py @@ -0,0 +1,761 @@ +# Copyright 2024 The YourMT3 Authors. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +""" note2event.py + +Note tools: +• mix_notes(notes_to_mix, sort, trim_overlap, fix_offset) + -> List[Note] +• validate_notes(notes, fix) + -> List[Note] +• trim_overlapping_notes(notes, sort) + -> List[Note] +• sort_notes(notes) + -> List[Note] +• notes2pc_notes(notes, note_offs) + -> List[Note] +• extract_program_from_notes(notes) + -> Set[int] +• extract_notes_selected_by_programs(notes, programs, sort) + -> List[Note] + +Note to NoteEvent +• note2note_event(notes, sort, return_activity) + -> List[NoteEvent] + +NoteEvent tools: +• slice_note_events_and_ties(note_events, start_time, end_time, tidyup) + -> Tuple[List[NoteEvent], List[NoteEvent], int]) +• slice_multiple_note_events_and_ties_to_bundle(note_events, start_times, duration_sec, tidyup) + -> List[List[NoteEvent], List[NoteEvent], int]] # Note implmented yet.. +• mix_note_event_lists_bundle(note_events_to_mix, sort, start_time_to_zero) + -> NoteEventListsBundle +• pitch_shift_note_events(note_events, semitone, use_deepcopy) + -> List[NoteEvent] +• separate_by_subunit_programs_from_note_event_lists_bundle( + source_note_event_lists_bundle, + subunit_programs) + -> NoteEventListsBundle: +• separate_channel_by_program_group_from_note_event_lists_bundle( + source_note_event_lists_bundle, + num_program_groups, + program2channel_vocab) + -> List[NoteEventListsBundle]: + +NoteEvent to Event: +• note_event2event(note_events, tie_note_events, start_time, tps, sort) + -> List[Event] + +Event tools: +• check_event_len_from_bundle(note_events_dic_a, note_events_dic_b, max_len, fast_check) + -> bool +""" +import warnings +from copy import deepcopy +from itertools import chain +from typing import Optional, Tuple, Union, List, Set, Dict, Any + +import numpy as np +from utils.note_event_dataclasses import Note, NoteEvent, NoteEventListsBundle +from utils.note_event_dataclasses import Event + +DRUM_OFFSET_TIME = 0.01 # in seconds +MINIMUM_OFFSET_TIME = 0.01 # this is used to avoid zero-length notes +DRUM_PROGRAM = 128 + + +def mix_notes(notes_to_mix: Tuple[List[Note]], + sort: bool = True, + trim_overlap: bool = True, + fix_offset: bool = True) -> List[Note]: + """ + mix_notes: + Mixes a tuple of many lists of Note instances into a single list of Note + instances. This processes 'notes1 + notes2 + ... + notesN' faster. + Because Note instances use absolute timing, the Note instances in the + same timiming will be sorted by increasing order of program and pitch. + + Args: + - notes_to_mix (tuple[list[Note]]): A tuple of lists of Note instances. + - sort (bool): If True, sort the Note instances by increasing order of + onsets, and at the same timing, by increasing order of program and pitch. + Default is True. + + Returns: + - notes (list[Note]): A list of Note instances. 
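+
+ Example (illustrative; the input note lists are placeholders):
+ mixed = mix_notes((piano_notes, bass_notes, drum_notes),
+ sort=True, trim_overlap=True, fix_offset=True)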
+ """ + mixed_notes = list(chain(*notes_to_mix)) + if sort and len(mixed_notes) > 0: + mixed_notes.sort( + key=lambda note: (note.onset, note.is_drum, note.program, note.velocity, note.pitch, note.offset)) + + # Trim overlapping notes + if trim_overlap: + mixed_notes = trim_overlapping_notes(mixed_notes, sort=sort) + + # fix offset >= onset the Note instances + if fix_offset: + mixed_notes = validate_notes(mixed_notes, fix=True) + return mixed_notes + + +def validate_notes(notes: Tuple[List[Note]], minimum_offset: Optional[bool] = 0.01, fix: bool = True) -> List[Note]: + """ validate and fix unrealistic notes """ + if len(notes) > 0: + for note in list(notes): + if note.onset == None: + if fix: + notes.remove(note) + continue + elif note.offset == None: + if fix: + note.offset = note.onset + MINIMUM_OFFSET_TIME + elif note.onset > note.offset: + warnings.warn(f'📙 Note at {note} has onset > offset.') + if fix: + note.offset = max(note.offset, note.onset + MINIMUM_OFFSET_TIME) + print(f'✅\033[92m Fixed! Setting offset to onset + {MINIMUM_OFFSET_TIME}.\033[0m') + elif note.is_drum is False and note.offset - note.onset < 0.01: + # fix 13 Oct: too short notes issue for the dataset with non-MIDI annotations + # warnings.warn(f'📙 Note at {note} has offset - onset < 0.01.') + if fix: + note.offset = note.onset + MINIMUM_OFFSET_TIME + # print(f'✅\033[92m Fixed! Setting offset to onset + {MINIMUM_OFFSET_TIME}.\033[0m') + + return notes + + +def trim_overlapping_notes(notes: List[Note], sort: bool = True) -> List[Note]: + """ Trim overlapping notes and dropping zero-length notes. + https://github.com/magenta/mt3/blob/3deffa260ba7de3cf03cda1ea513a4d7ba7144ca/mt3/note_sequences.py#L52 + + Trimming was only applied to train set, not test set in MT3. + """ + if len(notes) <= 1: + return notes + + trimmed_notes = [] + channels = set((note.pitch, note.program, note.is_drum) for note in notes) + + for pitch, program, is_drum in channels: + channel_notes = [ + note for note in notes if note.pitch == pitch and note.program == program and note.is_drum == is_drum + ] + sorted_notes = sorted(channel_notes, key=lambda note: note.onset) + + for i in range(1, len(sorted_notes)): + if sorted_notes[i - 1].offset > sorted_notes[i].onset: + sorted_notes[i - 1].offset = sorted_notes[i].onset + + # Filter out zero-length notes + valid_notes = [note for note in sorted_notes if note.onset < note.offset] + + trimmed_notes.extend(valid_notes) + + if sort: + trimmed_notes.sort(key=lambda note: (note.onset, note.is_drum, note.program, note.velocity, note.pitch)) + return trimmed_notes + + +def sort_notes(notes: List[Note]) -> List[Note]: + """ Sort notes by increasing order of onsets, and at the same timing, by increasing order of program and pitch. """ + if len(notes) > 0: + notes.sort(key=lambda note: (note.onset, note.is_drum, note.program, note.velocity, note.pitch, note.offset)) + return notes + + +def notes2pc_notes(notes: List[Note], note_offset: int = 64) -> List[Note]: + """ Convert a list of Note instances to a list of Pitch Class Set (PCS) instances. + This method is implemented for octave-ignore evaluation cases. 
""" + pc_notes = deepcopy(notes) + for note in pc_notes: + note.pitch = note.pitch % 12 + note_offset + return pc_notes + + +def extract_program_from_notes(notes: List[Note]) -> Set[int]: + """ Extract program numbers from a list of Note instances.""" + prg = set() + for note in notes: + if note.program not in prg: + prg.add(note.program) + return prg + + +def extract_notes_selected_by_programs(notes: List[Note], programs: Set[int], sort: bool = True) -> List[Note]: + """ Extract notes selected by program numbers from a list of Note instances.""" + selected_notes = [] + for note in notes: + if note.program in programs: + selected_notes.append(note) + if sort: + selected_notes.sort(key=lambda note: (note.onset, note.is_drum, note.program, note.velocity, note.pitch)) + return selected_notes + + +""" +NoteEvent data class: + +Combines NoteEvent and NoteActivity for onset and offset events during Note to Event conversion. + +Features: + +Trackable: follow note activity by index +Sliceable: extract time ranges; time is absolute +Mergeable: combine two NoteEvent instances (re-index needed) +Mutable: mute events by program number, pitch +Transferable: easily convert to Note or Event tokens +""" + + +def note2note_event(notes: List[Note], sort: bool = True, return_activity: bool = True) -> List[NoteEvent]: + """ + note2note_event: + Converts a list of Note instances to a list of NoteEvent instances. + + Args: + - notes (List[Note]): A list of Note instances. + - sort (bool): Sort the NoteEvent instances by increasing order of onsets, + and at the same timing, by increasing order of program and pitch. + Default is True. If return_activity is set to True, NoteEvent instances + are sorted regardless of this argument. + - return_activity (bool): If True, return a list of NoteActivity instances + + Returns: + - note_events (List[NoteEvent]): A list of NoteEvent instances. + + """ + note_events = [] + for note in notes: + # for each note, add onset and offset events + note_events.append(NoteEvent(note.is_drum, note.program, note.onset, note.velocity, note.pitch)) + if note.is_drum == 0: # (drum has no offset!) + note_events.append(NoteEvent(note.is_drum, note.program, note.offset, 0, note.pitch)) + + if sort or return_activity: + note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) + + if return_activity: + # activity stores the indices of previous notes that are still active + activity = set() # mutable class + for i, ne in enumerate(note_events): + # set a copy of the activity set ti the current note event + ne.activity = activity.copy() + + if ne.is_drum: + continue # drum's offset and activity are not tracked + elif ne.velocity == 1: + activity.add(i) + elif ne.velocity == 0: + # search for the index of matching onset event + matched_onset_event_index = None + for j in activity: + if note_events[j].equals_only(ne, 'is_drum', 'program', 'pitch'): + matched_onset_event_index = j + break + if matched_onset_event_index is not None: + activity.remove(matched_onset_event_index) + else: + raise ValueError(f'📕 note2note_event: no matching onset event for {ne}') + else: + raise ValueError(f'📕 Invalid velocity: {ne.velocity} expected 0 or 1') + if len(activity) > 0: + # if there are still active notes at the end of the sequence + warnings.warn(f'📙 note2note_event: {len(activity)} notes are still \ + active at the end of the sequence. Please validate \ + the input Note instances. 
') + return note_events + + +def slice_note_events_and_ties(note_events: List[NoteEvent], + start_time: float, + end_time: float, + tidyup: bool = False) -> Tuple[List[NoteEvent], List[NoteEvent], int]: + """ + Extracts a specific subsequence of note events and tie note events for the + first note event in the subsequence. + + Args: + - note_events (List[NoteEvent]): List of NoteEvent instances. + - start_time (float): The start time of the subsequence in seconds. + - end_time (float): The end time of the subsequence in seconds. + - tidyup (Optional[bool]): If True, sort the resulting lists of NoteEvents, + and remove the activity attribute of sliced_note_event, and remove the + time and activity attributes of tie_note_events. Default is False. + Avoid using tidyup=True without deepcopying the original note_events. + + Note: + - The activity attribute of returned sliced_note_events, and the time and + activity attributes of tie_note_events are not valid after slicing. + Thus, they should be ignored in the downstream processing. + + Returns: + - sliced_note_events (List[NoteEvent]): List of NoteEvent instances in the + specified range. + - tie_note_events (List[NoteEvent]): List of NoteEvent instances that are + active (tie) at start_time. + - start_time (float): Just bypass the start time from the input argument. + """ + if start_time > end_time: + raise ValueError(f'📕 slice_note_events: start_time {start_time} \ + is greater than end_time {end_time}') + elif len(note_events) == 0: + warnings.warn('📙 slice_note_events: empty note_events as input') + return [], [], start_time + + # Get start_index and end_index + start_index, end_index = None, None + found_start = False + for i, ne in enumerate(note_events): + if not found_start and ne.time >= start_time and ne.time < end_time: + start_index = i + found_start = True + + if ne.time >= end_time: + end_index = i + break + + # Get tie_note_events + if start_index == None: + if end_index == 0: + tie_note_events = [] + elif end_index == None: + tie_note_events = [] + else: + tie_note_events = [note_events[i] for i in note_events[end_index].activity] + else: + tie_note_events = [note_events[i] for i in note_events[start_index].activity] + """ modifying note events here is dangerous, due to mutability of original note_events!! """ + if tidyup: + for tne in tie_note_events: + tne.time = None + tne.activity = None + + tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) + + # Get sliced note_events + if start_index is None: + sliced_note_events = [] + else: + sliced_note_events = note_events[start_index:end_index] + + if tidyup: + for sne in sliced_note_events: + sne.activity = None + + sliced_note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) + return sliced_note_events, tie_note_events, start_time + + +""" +class NoteEventListsBundle(TypedDict): + note_events: List[List[NoteEvent]] + tie_note_events: List[List[NoteEvent]] + start_time: List[int] +""" + + +def slice_multiple_note_events_and_ties_to_bundle(note_events: List[NoteEvent], + start_times: List[float], + duration_sec: float, + tidyup: bool = False) -> NoteEventListsBundle: + """ + Extracts N subsequence of note events and tie-note events by taking + a list of N start_time and a list of N end_time. 
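+
+ Args:
+ - note_events (List[NoteEvent]): A list of NoteEvent instances.
+ - start_times (List[float]): Start times of the N segments, in seconds.
+ - duration_sec (float): Duration of each segment in seconds; segment i covers
+ [start_times[i], start_times[i] + duration_sec).
+ - tidyup (bool): Passed through to slice_note_events_and_ties(). Default is False.
+
+ Returns:
+ - NoteEventListsBundle: A dictionary with keys 'note_events', 'tie_note_events'
+ and 'start_times', each holding one entry per requested segment.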
+ """ + sliced_note_events_list = [] + sliced_tie_note_events_list = [] + for start_time in start_times: + end_time = start_time + duration_sec + sliced_note_events, tie_note_events, _ = slice_note_events_and_ties(note_events, start_time, end_time, tidyup) + sliced_note_events_list.append(sliced_note_events) + sliced_tie_note_events_list.append(tie_note_events) + return NoteEventListsBundle({ + 'note_events': sliced_note_events_list, + 'tie_note_events': sliced_tie_note_events_list, + 'start_times': start_times + }) + + +def mix_note_event_lists_bundle( + note_event_lists_bundle_to_mix: NoteEventListsBundle, + sort: bool = True, + start_time_to_zero: bool = True, + use_deepcopy: bool = False, +) -> NoteEventListsBundle: + """ + Mixes a tuple of many lists of NoteEvent instances into a single list of NoteEvent + instances. This processes 'note_events1 + note_events2 + ... + note_eventsN'. + Because each NoteEvent list instance may have different start time, it is recommended + to set start_time_to_zero to True. + + Known issue: + - Solution for overlapping note_events is not implemented yet. + - Currently, it is assumed that programs have no overlap among note_events_to_mix. + - For faster processing, use_deepcopy is set to False by default. + + Args: + - note_events_bundle_to_mix (NoteEventListsBundle): + A dictionary with keys ('note_events', 'tie_note_events', 'start_time'). + See NoteEventListsBundle in utils/note_event_dataclasses.py for more details. + - sort (bool): If True, sort the NoteEvent instances by increasing order of onsets, + and at the same timing, by increasing order of program and pitch. + Default is True. + - start_time_to_zero (bool): If True, set the start time of each list of NoteEvents to 0. + Default is True. + - use_deepcopy (bool): If True, use deepcopy() to avoid modifying the original NoteEvent + + Returns: + - mixed_note_events_dic (NoteEventListsBundle): A dictionary with keys ('note_events', 'tie_note_events', 'start_time'). + """ + if use_deepcopy is True: + note_events_to_mix = deepcopy(note_event_lists_bundle_to_mix["note_events"]) + tie_note_events_to_mix = deepcopy(note_event_lists_bundle_to_mix["tie_note_events"]) + else: + note_events_to_mix = note_event_lists_bundle_to_mix["note_events"] + tie_note_events_to_mix = note_event_lists_bundle_to_mix["tie_note_events"] + start_times = note_event_lists_bundle_to_mix["start_times"] + + # Reset start time to zero + if start_time_to_zero is True: + for note_events, tie_note_events, start_time in zip(note_events_to_mix, tie_note_events_to_mix, start_times): + for ne in note_events: + ne.time -= start_time + assert ne.time >= 0, f'📕 mix_note_events: negative time {ne.time}' + """modifying tie note events here is dangerous, due to mutability of linked note_events""" + # for tne in tie_note_events: + # tne.time = None + # tne.activity = None + + # Mix + mixed_note_events = list(chain(*note_events_to_mix)) + mixed_tie_note_events = list(chain(*tie_note_events_to_mix)) + + # Sort + if sort is True: + mixed_note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) + mixed_tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) + + mixed_note_events_dic = NoteEventListsBundle({ + 'note_events': [mixed_note_events], + 'tie_note_events': [mixed_tie_note_events], + 'start_times': [0.] 
+ }) + return mixed_note_events_dic + + +def pitch_shift_note_events(note_events: List[NoteEvent], semitone: int, use_deepcopy: bool = False) -> List[NoteEvent]: + """ + Apply pitch shift to NoteEvent instances: + + Args: + - note_events (List[NoteEvent]): A list of NoteEvent instances. Typically 'note_events' or + 'tie_note_events' can be an input. + - semitone (int): The number of semitones to shift. Positive value shifts up, negative value + - use_deepcopy (bool): If True, use deepcopy() to avoid modifying the original NoteEvent + + Returns: + - note_events (List[NoteEvent]): A list of NoteEvent instances with pitch shifted. Drums are + excluded from pitch shift processing. + """ + if semitone == 0: + return note_events + + if use_deepcopy is True: + note_events = deepcopy(note_events) + + for ne in note_events: + if ne.is_drum is False: + new_pitch = ne.pitch + semitone + if new_pitch >= 0 and new_pitch < 128: + ne.pitch = new_pitch + return note_events + + +def separate_by_subunit_programs_from_note_event_lists_bundle(source_note_event_lists_bundle: NoteEventListsBundle, + subunit_programs: List[List[int]], + start_time_to_zero: bool = True, + sort: bool = True) -> NoteEventListsBundle: + src_note_events = source_note_event_lists_bundle['note_events'] + src_tie_note_events = source_note_event_lists_bundle['tie_note_events'] + src_start_times = source_note_event_lists_bundle['start_times'] + + # Reset start time to zero + if start_time_to_zero is True and not all(t == 0. for t in src_start_times): + for nes, tnes, start_time in zip(src_note_events, src_tie_note_events, src_start_times): + for ne in nes: + ne.time -= start_time + assert ne.time >= 0, f'📕 mix_note_events: negative time {ne.time}' + for tne in tnes: + tne.time = None + tne.activity = None + src_start_times = [0. for i in range(len(src_start_times))] + + num_subunits = len(subunit_programs) + result_note_events = [[] for _ in range(num_subunits)] + result_tie_note_events = [[] for _ in range(num_subunits)] + result_start_times = [0. 
for _ in range(num_subunits)] + + # Convert subunit_programs to list of sets for faster lookups + subunit_program_sets = [set(sp) for sp in subunit_programs] + + for nes, tnes in zip(src_note_events, src_tie_note_events): + for ne in nes: + if ne.is_drum: + target_indices = [i for i, sp_set in enumerate(subunit_program_sets) if DRUM_PROGRAM in sp_set] + else: + target_indices = [i for i, sp_set in enumerate(subunit_program_sets) if ne.program in sp_set] + for i in target_indices: + result_note_events[i].append(ne) + + for tne in tnes: + target_indices = [i for i, sp_set in enumerate(subunit_program_sets) if tne.program in sp_set] + for i in target_indices: + result_tie_note_events[i].append(tne) + + # Sort + if sort is True: + for nes, tnes in zip(result_note_events, result_tie_note_events): + nes.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) + tnes.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) + + return { + 'note_events': result_note_events, # List[List[NoteEvent]] + 'tie_note_events': result_tie_note_events, # List[List[NoteEvent]] + 'start_times': result_start_times, # List[float] + } + + +def separate_channel_by_program_group_from_note_event_lists_bundle(source_note_event_lists_bundle: NoteEventListsBundle, + num_program_groups: int, + program2channel_vocab: Dict[int, Dict[str, Any]], + start_time_to_zero: bool = False, + sort: bool = True) -> List[NoteEventListsBundle]: + """ + Args: + - source_note_event_lists_bundle (NoteEventListsBundle): + A dictionary with keys ('note_events', 'tie_note_events', 'start_time'). + See NoteEventListsBundle in utils/note_event_dataclasses.py for more details. + - num_program_groups (int): The number of program groups to separate. Typically this is the length + of program_vocab + 1 (for drums). + - program2channel_vocab (Dict[int, Dict[str, Union[List[int], np.ndarray]]]): + A dictionary with keys (program, channel, instrument_group, primary_program). + See program2channel_vocab in utils/utils.py, create_program2channel_vocab() for more details. + example: + program2channel_vocab[program_int] = { + "channel": (int), + "instrument_group": (str), + "primary_program": (int), + } + - start_time_to_zero (bool): If True, set the start time of each list of NoteEvents to 0. + Default is False. + - sort (bool): If True, sort the NoteEvent instances by increasing order of onsets, + and at the same timing, by increasing order of program and pitch. + Default is True. + + Returns: + - result_list_bundle List[NoteEventListsBundle]: A list of NoteEventListsBundle instances with length + of batch_sz. + NoteEventListsBundle is a dictionary with keys ('note_events', 'tie_note_events', 'start_time'). + See NoteEventListsBundle in utils/note_event_dataclasses.py for more details. + + """ + src_note_events = source_note_event_lists_bundle['note_events'] + src_tie_note_events = source_note_event_lists_bundle['tie_note_events'] + src_start_times = source_note_event_lists_bundle['start_times'] + + # Reset start time to zero + if start_time_to_zero is True and not all(t == 0. for t in src_start_times): + for nes, tnes, start_time in zip(src_note_events, src_tie_note_events, src_start_times): + """modifying time of note events is only for mixing events within training. 
test set should keep the original time""" + for ne in nes: + ne.time -= start_time + assert ne.time >= 0, f'📕 mix_note_events: negative time {ne.time}' + """modifying tie note events here is dangerous, due to mutability of linked note_events""" + # for tne in tnes: + # tne.time = None + # tne.activity = None + src_start_times = [0. for i in range(len(src_start_times))] + + batch_sz = len(src_note_events) + result_list_bundle = [{ + "note_events": [[] for _ in range(num_program_groups)], + "tie_note_events": [[] for _ in range(num_program_groups)], + "start_times": [src_start_times[b] for _ in range(num_program_groups)], + } for b in range(batch_sz)] + """ Example of program2channel_vocab + { + 0: {'channel': 0, 'instrument_group': 'Piano', 'primary_program': 0}, + 1: {'channel': 1, 'instrument_group': 'Chromatic Percussion', 'primary_program': 8}, + ... + 100: {'channel': 11, 'instrument_group': 'Singing Voice', 'primary_program': 100}, + 128: {'channel': 12, 'instrument_group': 'Drums', 'primary_program': 128} + } + """ + # Separate by program_vocab + for b, (nes, tnes) in enumerate(zip(src_note_events, src_tie_note_events)): + for ne in nes: + program = DRUM_PROGRAM if ne.is_drum else ne.program + mapping_info = program2channel_vocab.get(program, None) + if mapping_info is not None: + ch = mapping_info["channel"] + result_list_bundle[b]["note_events"][ch].append(ne) + else: + # Temporary fix for program > 95, such as gunshot and FX. TODO: FX class + pass + + for tne in tnes: + mapping_info = program2channel_vocab.get(tne.program) + if mapping_info is not None: + ch = mapping_info["channel"] + result_list_bundle[b]["tie_note_events"][ch].append(tne) + else: + # Temporary fix for program > 95, such as gunshot and FX. TODO: FX class + pass + + # Sort + if sort: + for ch in range(num_program_groups): + result_list_bundle[b]["note_events"][ch].sort( + key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) + result_list_bundle[b]["tie_note_events"][ch].sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) + + return result_list_bundle # List[NoteEventListsBundle] with length of batch_sz + + +def note_event2event(note_events: List[NoteEvent], + tie_note_events: Optional[List[NoteEvent]] = None, + start_time: float = 0., + tps: int = 100, + sort: bool = True) -> List[Event]: + """ note_event2event: + Converts a list of NoteEvent instances to a list of Event instances. + - NoteEvent instances have absolute time within a file, while Event instances + have 'shift' events of absolute time within a segment. + - Tie NoteEvent instances are prepended to output list of Event instances, + and closed by a 'tie' event. + - If start_time is not provided, start_time=0 in seconds by default. + - If there is non-tie note_event instances before the start_time, raises an error. + + Args: + - note_events (list[NoteEvent]): A list of NoteEvent instances. + - tie_note_events (Optional[list[NoteEvent]]): A list of tie NoteEvent instances. + See slice_note_events_and_ties() for more details. Default is None. + - start_time (float): Start time in seconds. Default is 0. Any non-tie NoteEvent + instances should have time >= start_time. + - tps (Optional[int]): Ticks per second. Default is 100. + - sort (bool): If True, sort the Event instances by increasing order of + onsets, and at the same timing, by increasing order of program and pitch. + Default is False. + + Returns: + - events (list[Event]): A list of Event instances. 
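+
+ Example (illustrative, assuming tps=100, start_time=0.0 and no tie note events):
+ A single pitched note (program 0, pitch 60, velocity 1) with onset 0.0 s and
+ offset 0.5 s is converted to:
+ [Event(type='tie', value=0), Event(type='program', value=0),
+ Event(type='velocity', value=1), Event(type='pitch', value=60),
+ Event(type='shift', value=50), Event(type='velocity', value=0),
+ Event(type='pitch', value=60)]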
+ """ + if sort: + if tie_note_events != None: + tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) + note_events.sort( + key=lambda n_ev: (round(n_ev.time * tps), n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) + + # Initialize event list and state variables + events = [] + start_tick = round(start_time * tps) + tick_state = start_tick + + program_state = None + + # Prepend tie events + if tie_note_events: + for tne in tie_note_events: + if tne.program != program_state: + events.append(Event(type='program', value=tne.program)) + program_state = tne.program + events.append(Event(type='pitch', value=tne.pitch)) + + # Any tie events (can be empty) are closed by a 'tie' event + events.append(Event(type='tie', value=0)) + + # Translate NoteEvent to Event in the list + velocity_state = None # reset state variables + for ne in note_events: + if ne.is_drum and ne.velocity == 0: # <-- bug fix + continue # drum's offset should be ignored, and should not cause shift + + # Process time shift and update tick_state + ne_tick = round(ne.time * tps) + if ne_tick > tick_state: + # shift_ticks = ne_tick - tick_state + shift_ticks = ne_tick - start_tick + events.append(Event(type='shift', value=shift_ticks)) + tick_state = ne_tick + elif ne_tick == tick_state: + pass + else: + raise ValueError( + f'NoteEvent tick_state {ne_tick} of time {ne.time} is smaller than tick_state {tick_state}.') + + # Process program change and update program_state + if ne.is_drum and ne.velocity == 1: + # drum events have no program and offset but velocity 1 + if velocity_state != 1 or velocity_state == None: + events.append(Event(type='velocity', value=1)) + velocity_state = 1 + events.append(Event(type='drum', value=ne.pitch)) + else: + if ne.program != program_state or program_state == None: + events.append(Event(type='program', value=ne.program)) + program_state = ne.program + + if ne.velocity != velocity_state or velocity_state == None: + events.append(Event(type='velocity', value=ne.velocity)) + velocity_state = ne.velocity + + events.append(Event(type='pitch', value=ne.pitch)) + + return events + + +def check_event_len_from_bundle(note_events_dic_a: Dict, + note_events_dic_b: Dict, + max_len: int, + fast_check: bool = True) -> bool: + """ + Check if the total length of events converted from note_events_dic exceeds the max length. + This is used in cross augmentation. See augment.py for more the usage. + + Args: + - note_events_dic_a (Dict): A dictionary with keys ('note_events', 'tie_note_events', 'start_time'). + - note_events_dic_b (Dict): A dictionary with keys ('note_events', 'tie_note_events', 'start_time'). + - max_len (int): Maximum length of events. + - fast_check (bool): If True, check the total length of note_events only. Default is True. 
+ + Returns: + - bool: True (passed) or False (failed) + """ + if fast_check is True: + ne_len_a = sum([len(ne) for ne in note_events_dic_a['note_events']]) + ne_len_b = sum([len(ne) for ne in note_events_dic_b['note_events']]) + total_note_events_len = ne_len_a + ne_len_b + + if fast_check is False or total_note_events_len >= max_len // 3: + event_len_a = 0 + for ne, tne, start_time in zip(note_events_dic_a['note_events'], note_events_dic_a['tie_note_events'], + note_events_dic_a['start_times']): + event_len_a += len(note_event2event(ne, tne, start_time)) + + event_len_b = 0 + for ne, tne, start_time in zip(note_events_dic_b['note_events'], note_events_dic_b['tie_note_events'], + note_events_dic_b['start_times']): + event_len_b += len(note_event2event(ne, tne, start_time)) + + total_events_len = event_len_a + event_len_b + if total_events_len >= max_len: + return False # failed + else: + return True # passed diff --git a/amt/src/utils/note_event_dataclasses.py b/amt/src/utils/note_event_dataclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..f98ebf41c4a634af48fff2ed9bce514804bcafeb --- /dev/null +++ b/amt/src/utils/note_event_dataclasses.py @@ -0,0 +1,85 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +import sys +import importlib +from dataclasses import dataclass, field +from typing import Set, List, Optional + +if sys.version_info >= (3, 8): + typing_module = importlib.import_module("typing") +else: + typing_module = importlib.import_module("typing_extensions") +TypedDict = typing_module.TypedDict + + +@dataclass +class Note: + is_drum: bool + program: int # MIDI program number (0-127) + onset: float # onset time in seconds + offset: float # offset time in seconds + pitch: int # MIDI note number (0-127) + velocity: int # (0-1) if ignore_velocity is True, otherwise (0-127) + + +@dataclass +class NoteEvent: + is_drum: bool + program: int # [0, 127], 128 for drum but ignored in tokenizer + time: Optional[float] # absolute time. allow None for tie note events + velocity: int # currently 1 for onset, 0 for offset, drum has no offset + pitch: int # MIDI pitch + activity: Optional[Set[int]] = field(default_factory=set) + + def equals_except(self, note_event, *excluded_attrs) -> bool: + """ Check if two NoteEvent instances are equal EXCEPT for the + specified attributes. """ + if not isinstance(note_event, NoteEvent): + return False + + for attr, value in self.__dict__.items(): + if attr not in excluded_attrs and value != note_event.__dict__.get(attr): + return False + return True + + def equals_only(self, note_event, *included_attrs) -> bool: + """ Check if two NoteEvent instances are equal for the + specified attributes. """ + if not isinstance(note_event, NoteEvent): + return False + + for attr in included_attrs: + if self.__dict__.get(attr) != note_event.__dict__.get(attr): + return False + return True + + +class NoteEventListsBundle(TypedDict): + """ NoteEventListsBundle: + + A TypedDict class instance that contains multiple lists of NoteEvents for multiple segments. 
+ + """ + note_events: List[List[NoteEvent]] + tie_note_events: List[List[NoteEvent]] + start_times: List[float] + + +@dataclass +class EventRange: + type: str + min_value: int + max_value: int + + +@dataclass +class Event: + type: str + value: int diff --git a/amt/src/utils/preprocess/dataset_stats.py b/amt/src/utils/preprocess/dataset_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..f1625c33b96ae4db7fa886468230dd7635162e03 --- /dev/null +++ b/amt/src/utils/preprocess/dataset_stats.py @@ -0,0 +1,41 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""dataset_stats.py""" +import os +import json +import glob +import numpy as np +from typing import Optional, List + +STAT_FILE_NAME = "dataset_stats.json" + + +def generate_dataset_stats(data_home: os.PathLike, dataset_name: Optional[str] = None) -> None: + """Generate dataset stats for a given dataset. + + Args: + data_home: Path to the data directory. + dataset_name: Name of the dataset to (re)generate stats for. If None, generate MISSING stats for all + datasets. + """ + stat_file = os.path.join(data_home, 'yourmt3_indexes', STAT_FILE_NAME) + if os.path.exists(stat_file): + print(f"Loading existing dataset stats file: {stat_file}") + with open(stat_file, 'r') as f: + stats = json.load(f).items() + else: + print(f"Creating new dataset stats file: {stat_file}") + stats = {} + + # Collect all existing yourmt3 indexes + indexes = glob.glob(os.path.join(data_home, 'yourmt3_indexes', '*_file_list.json')) + for index_file in indexes: + dataset_name = os.path.basename(index_file).split('_')[0] + split_name = os.path.basename(index_file).split('_')[1] diff --git a/amt/src/utils/preprocess/generate_dataset_stats.py b/amt/src/utils/preprocess/generate_dataset_stats.py new file mode 100644 index 0000000000000000000000000000000000000000..09fca2632d2137220a685f8b8b22e4d89d08deb7 --- /dev/null +++ b/amt/src/utils/preprocess/generate_dataset_stats.py @@ -0,0 +1,63 @@ +""" +Generate Dataset Stats from yourmt3_indexes. + +Usage: python count_events_and_files_from_dataset.py + +""" + +import os +import sys +import glob +import json +import numpy as np +from utils.note2event import note_event2event + + +def generate_dataset_stats_for_all_datasets(dataset_path: os.PathLike): + """ Count the number of notes and files in the dataset. 
""" + yourmt3_index_dir = os.path.join(dataset_path, "yourmt3_indexes") + index_files = glob.glob(os.path.join(yourmt3_index_dir, '*.json')) + output_file = os.path.join(yourmt3_index_dir, 'dataset_stats.json') + if output_file in index_files: + index_files.remove(output_file) # remove output file from index files + + counts = {} + for index_file in index_files: + # Split-wise counts + split_file_name = os.path.basename(index_file) + + with open(index_file, 'r') as f: + index = json.load(f) + + file_cnt = len(index) + event_cnt = 0 + for file_dict in index.values(): + ne_file = file_dict['note_events_file'] + note_events_dict = np.load(ne_file, allow_pickle=True).tolist() + note_events = note_events_dict['note_events'] + events = note_event2event(note_events) + event_cnt += len(events) + + # Update counts + counts[split_file_name] = {'num_files': file_cnt, 'num_events': event_cnt} + print(split_file_name, f'num_files: {file_cnt}, num_events: {event_cnt}') + + # Save counts as json + with open(output_file, 'w') as f: + json.dump(counts, f, indent=4) + print(f'Saved data counts to {output_file}') + + +def update_dataset_stats_for_new_dataset(dataset_path: os.PathLike, split_file_name: str): + """ Update the number of notes and files of specific dataset. """ + raise NotImplementedError + + +if __name__ == '__main__': + if sys.argv == 2: + dataset_path = sys.argv[1] + generate_dataset_stats_for_all_datasets(dataset_path) + else: + print('Usage: generate_dataset_stats.py ') + print('Example: python count_events_and_files_from_dataset.py ../../data') + sys.exit(1) \ No newline at end of file diff --git a/amt/src/utils/preprocess/preprocess_cmedia.py b/amt/src/utils/preprocess/preprocess_cmedia.py new file mode 100644 index 0000000000000000000000000000000000000000..403a3fc4e2bd84a27572a92bf567fe625e2e61b0 --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_cmedia.py @@ -0,0 +1,280 @@ +"""preprocess_cmedia.py""" +import os +import glob +import re +import json +import numpy as np +from copy import deepcopy +from typing import Dict +from collections import Counter + +from utils.audio import get_audio_file_info, load_audio_file +from utils.midi import midi2note, note_event2midi +from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes +from utils.event2note import event2note_event +from utils.note_event_dataclasses import Note, NoteEvent +from utils.utils import note_event2token2note_event_sanity_check + +SINGING_WITH_UNANNOTATED_PROGRAM = [100, 129] # 100 for singing voice, 129 for unannotated +SINGING_ONLY_PROGRAM = [100] +# Corrected track 20: [165.368664, 165.831662, 62] to [165.368664, 165.831662, 62] +# Corrected track 20: [272.338528, 272.801526, 62] to [272.338528, 272.801526, 62] +# Corrected track 20: [287.092992, 287.55599, 63] to [287.092992, 287.55599, 63] +# Corrected track 20: [294.451973, 294.915932, 63] to [294.451973, 294.915932, 63] +# Corrected track 23: [185.887641, 186.133542, 62] to [185.887641, 186.133542, 62] +# Corrected track 25: [139.003042, 139.295517, 67] to [139.003042, 139.295517, 67] +# Corrected track 25: [180.361032, 180.433848, 52] to [180.361032, 180.433848, 52] +# Corrected track 41: [60.986724, 61.312811, 61] to [60.986724, 61.312811, 61] +# Corrected track 87: [96.360656, 96.519258, 67] to [96.360656, 96.519258, 67] +# Corrected track 87: [240.265161, 240.474838, 68] to [240.265161, 240.474838, 68] + + +def check_file_existence(file: str) -> bool: + """Checks if file exists.""" + res = True + if not 
os.path.exists(file): + res = False + elif get_audio_file_info(file)[1] < 10 * 16000: + print(f'File {file} is too short.') + res = False + return res + + +def create_spleeter_audio_stem(vocal_audio_file, accomp_audio_file, cmedia_id) -> Dict: + program = SINGING_WITH_UNANNOTATED_PROGRAM + is_drum = [0, 0] + + audio_tracks = [] # multi-channel audio array (C, T) + vocal_audio = load_audio_file(vocal_audio_file, dtype=np.int16) / 2**15 # returns bytes + audio_tracks.append(vocal_audio.astype(np.float16)) + accomp_audio = load_audio_file(accomp_audio_file, dtype=np.int16) / 2**15 # returns bytes + audio_tracks.append(accomp_audio.astype(np.float16)) + max_length = max(len(vocal_audio), len(accomp_audio)) + + # collate all the audio tracks into a single array + n_tracks = 2 + audio_array = np.zeros((n_tracks, max_length), dtype=np.float16) + for j, audio in enumerate(audio_tracks): + audio_array[j, :len(audio)] = audio + + stem_content = { + 'cmedia_id': cmedia_id, + 'program': np.array(program, dtype=np.int64), + 'is_drum': np.array(is_drum, dtype=np.int64), + 'n_frames': max_length, # int + 'audio_array': audio_array # (n_tracks, n_frames) + } + return stem_content + + +def create_note_note_event_midi_from_cmedia_annotation(ann, midi_file, cmedia_id): + """ + Args: + ann: List[List[float, float, float]] # [onset, offset, pitch] + cmedia_id: str + Returns: + notes: List[Note] + note_events: List[NoteEvent] + midi: List[List[int]] + """ + notes = [] + for onset, offset, pitch in ann: + # # fix 13 Oct: too short notes issue + # if offset - onset < 0.01: # < 10ms + # offset = onset + 0.01 + notes.append( + Note( + is_drum=False, + program=100, + onset=float(onset), + offset=float(offset), + pitch=int(pitch), + velocity=1)) + notes = sort_notes(notes) + notes = validate_notes(notes) # <-- # fix 13 Oct: too short notes issue + notes = trim_overlapping_notes(notes) + note_events = note2note_event(notes) + + # Write midi file + note_event2midi(note_events, midi_file) + print(f"Created {midi_file}") + + return { # notes + 'cmedia_id': cmedia_id, + 'program': SINGING_ONLY_PROGRAM, + 'is_drum': [0, 0], + 'duration_sec': note_events[-1].time, + 'notes': notes, + }, { # note_events + 'cmedia_id': cmedia_id, + 'program': SINGING_ONLY_PROGRAM, + 'is_drum': [0, 0], + 'duration_sec': note_events[-1].time, + 'note_events': note_events, + } + + +def correct_ann(ann_all: Dict, fix_offset: bool = False, max_dur: float = 0.5): + """ correct too short notes that are actully sung in legato """ + for i in range(1, 101): + for j, v in enumerate(ann_all[str(i)]): + dur = v[1] - v[0] + if dur < 0.01: + next_onset = ann_all[str(i)][j + 1][0] + dist_to_next_onset = next_onset - v[1] + if fix_offset is True: + if dist_to_next_onset < max_dur: + # correct the offset + v_old = deepcopy(v) + ann_all[str(i)][j][1] = next_onset + print(f'Corrected track {i}: {v_old} to {ann_all[str(i)][j]}') + + else: + print(v, ann_all[str(i)][j + 1], f'dist_to_next_onset: {dist_to_next_onset}') + + +def preprocess_cmedia_16k(data_home: os.PathLike, + dataset_name='cmedia', + apply_correction=True, + sanity_check=False) -> None: + """ + Splits: + - train: 100 files + - train_vocal + - train_stem + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'cmedia_id': cmedia_id, + 'n_frames': (int), + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', + 'program': List[int], 100 for 
singing voice, and 129 for unannotated + 'is_drum': List[int], # [0] or [1] + } + } + """ + + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Load annotation json file as dictionary + ann_file = os.path.join(base_dir, 'Cmedia-train', 'Cmedia_train_gt.json') + with open(ann_file, 'r') as f: + ann_all = json.load(f) # index "1" to "100" + + # Correction for Cmedia-train + correct_ann(ann_all, fix_offset=apply_correction, max_dur=0.5) + + # write ann + ann_file = os.path.join(base_dir, 'Cmedia-train', 'Cmedia_train_gt_corrected.json') + with open(ann_file, 'w') as f: + json.dump(ann_all, f) + + # Check missing audio files and create a dictionary + audio_all = {} # except for missing files + audio_missing = {'train': []} + for i in range(1, 101): + split = 'train' # no split + audio_file = os.path.join(base_dir, f'{split}', f'{i}', 'converted_Mixture.wav') + audio_vocal_file = os.path.join(base_dir, f'{split}', f'{i}', 'vocals.wav') + audio_acc_file = os.path.join(base_dir, f'{split}', f'{i}', 'accompaniment.wav') + if check_file_existence(audio_file) and check_file_existence( + audio_vocal_file) and check_file_existence(audio_acc_file): + audio_all[str(i)] = audio_file + else: + audio_missing[split].append(i) + + assert len(audio_all.keys()) == 100 + + # Track ids + ids_all = audio_all.keys() + ids_train = audio_all.keys() + + # Create notes, note_events, and MIDI from annotation + total_err = Counter() + for id in ids_all: + ann = ann_all[id] + split = 'train' + midi_file = os.path.join(base_dir, f'{split}', id, 'singing.mid') + notes, note_events = create_note_note_event_midi_from_cmedia_annotation(ann, midi_file, id) + + notes_file = midi_file.replace('.mid', '_notes.npy') + note_events_file = midi_file.replace('.mid', '_note_events.npy') + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f"Created {notes_file}") + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + print(f"Created {note_events_file}") + + if sanity_check: + # sanity check + print(f'Sanity check for {id}...') + err_cnt = note_event2token2note_event_sanity_check( + note_events['note_events'], notes['notes'], report_err_cnt=True) + total_err += err_cnt + if sanity_check: + print(total_err) + if sum(total_err.values()) > 0: + raise Exception("Sanity check failed. 
Please check the error messages above.") + else: + print("Sanity check passed.") + + # Process audio files + for id in ids_all: + split = 'train' + audio_vocal_file = os.path.join(base_dir, f'{split}', id, 'vocals.wav') + audio_acc_file = os.path.join(base_dir, f'{split}', id, 'accompaniment.wav') + stem_file = os.path.join(base_dir, f'{split}', id, 'stem.npy') + stem_content = create_spleeter_audio_stem(audio_vocal_file, audio_acc_file, id) + # write audio stem + np.save(stem_file, stem_content, allow_pickle=True, fix_imports=False) + print(f"Created {stem_file}") + + # Create file_list.json + ids_by_split = {'train': ids_train, 'train_vocal': ids_train, 'train_stem': ids_train} + + for split in ['train', 'train_vocal', 'train_stem']: + file_list = {} + for i, id in enumerate(ids_by_split[split]): + wav_file = audio_all[id] + n_frames = get_audio_file_info(wav_file)[1] + if 'vocal' in split: + stem_file = None + wav_file = wav_file.replace('converted_Mixture.wav', 'vocals.wav') + program = SINGING_ONLY_PROGRAM + is_drum = [0] + elif 'stem' in split: + stem_file = wav_file.replace('converted_Mixture.wav', 'stem.npy') + program = SINGING_WITH_UNANNOTATED_PROGRAM + is_drum = [0, 0] + else: + stem_file = None + program = SINGING_WITH_UNANNOTATED_PROGRAM + is_drum = [0, 0] + + mid_file = os.path.join(os.path.dirname(wav_file), 'singing.mid') + file_list[i] = { + 'cmedia_id': id, + 'n_frames': n_frames, + 'stem_file': stem_file, + 'mix_audio_file': wav_file, + 'notes_file': mid_file.replace('.mid', '_notes.npy'), + 'note_events_file': mid_file.replace('.mid', '_note_events.npy'), + 'midi_file': mid_file, + 'program': program, + 'is_drum': is_drum, + } + if stem_file is None: + del file_list[i]['stem_file'] + + output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') + with open(output_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_file}') \ No newline at end of file diff --git a/amt/src/utils/preprocess/preprocess_egmd.py b/amt/src/utils/preprocess/preprocess_egmd.py new file mode 100644 index 0000000000000000000000000000000000000000..29c47b285884029bb36655fee7723a1f04ab7dcb --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_egmd.py @@ -0,0 +1,186 @@ +"""preprocess_egmd.py""" +import os +import csv +import glob +import re +import json +from typing import Dict, List, Tuple +import numpy as np +from utils.audio import get_audio_file_info +from utils.midi import midi2note, note_event2midi +from utils.note2event import note2note_event, note_event2event +from utils.event2note import event2note_event +from utils.note_event_dataclasses import Note, NoteEvent +from utils.utils import note_event2token2note_event_sanity_check +# from utils.utils import assert_note_events_almost_equal + + +def create_note_event_and_note_from_midi(mid_file: str, id: str) -> Tuple[Dict, Dict]: + """Extracts note or note_event and metadata from midi: + + Returns: + notes (dict): note events and metadata. + note_events (dict): note events and metadata. 
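+        Example (an illustrative call; the file path and id are hypothetical):
+            notes, note_events = create_note_event_and_note_from_midi(
+                'e-gmd-v1.0.0/drummer1/session1/1_funk_80_beat_4-4.midi', '1_funk_80_beat_4-4')
+            # notes['program'] == [128] and notes['is_drum'] == [1] (drum-only)
+            # note_events['note_events'] is the List[NoteEvent] returned by note2note_event()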
+ """ + notes, dur_sec = midi2note( + mid_file, + binary_velocity=True, + ch_9_as_drum=True, + force_all_drum=True, + trim_overlap=True, + fix_offset=True, + quantize=True, + verbose=0, + minimum_offset_sec=0.01, + drum_offset_sec=0.01, + ignore_pedal=True) + return { # notes + 'egmd_id': id, + 'program': [128], + 'is_drum': [1], + 'duration_sec': dur_sec, + 'notes': notes, + }, { # note_events + 'maps_id': id, + 'program': [128], + 'is_drum': [1], + 'duration_sec': dur_sec, + 'note_events': note2note_event(notes), + } + + +def preprocess_egmd16k(data_home: os.PathLike, dataset_name='egmd') -> None: + """ + Splits: + - train: 35217 files + - validation: 5031 files + - test: 5289 files + - test_reduced: 246 files that contain '_5.midi' or '_10.midi' in the filename + + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'egmd_id': egmd_id, # filename wihout extension + 'n_frames': (int), + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', + 'program': List[int], + 'is_drum': List[int], # 0 or 1 + } + } + """ + + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Load csv file and create a dictionary + csv_file = os.path.join(base_dir, 'e-gmd-v1.0.0.csv') + with open(csv_file, 'r') as f: + csv_dict_reader = csv.DictReader(f) + egmd_dict_list_all = list(csv_dict_reader) + assert len(egmd_dict_list_all) == 45537 + + # Process MIDI files + for d in egmd_dict_list_all: + emgd_id = d['midi_filename'].split('.')[0] + midi_file = os.path.join(base_dir, d['midi_filename']) + notes, note_events = create_note_event_and_note_from_midi(midi_file, emgd_id) + + # Write notes and note_events + notes_file = midi_file.replace('.midi', '_notes.npy') + note_events_file = midi_file.replace('.midi', '_note_events.npy') + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f"Created {notes_file}") + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + print(f"Created {note_events_file}") + + # rewrite 120 bpm quantized midi file + quantized_midi_file = midi_file.replace('.midi', '_quantized_120bpm.mid') + note_event2midi(note_events['note_events'], quantized_midi_file) + print(f'Wrote {quantized_midi_file}') + + # Process audio files + pass + + # Create index files + for split in ['train', 'validation', 'test']: + file_list = {} + i = 0 + for d in egmd_dict_list_all: + if d['split'] == split: + egmd_id = d['midi_filename'].split('.')[0] + mix_audio_file = os.path.join(base_dir, d['audio_filename']) + n_frames = get_audio_file_info(mix_audio_file)[1] + midi_file = os.path.join(base_dir, d['midi_filename']) + notes_file = midi_file.replace('.midi', '_notes.npy') + note_events_file = midi_file.replace('.midi', '_note_events.npy') + + # check file existence + assert os.path.exists(mix_audio_file) + assert os.path.exists(midi_file) + assert os.path.exists(notes_file) + assert os.path.exists(note_events_file) + + # create file list + file_list[i] = { + 'egmd_id': egmd_id, + 'n_frames': n_frames, + 'mix_audio_file': mix_audio_file, + 'notes_file': notes_file, + 'note_events_file': note_events_file, + 'midi_file': midi_file, + 'program': [128], + 'is_drum': [1], + } + i += 1 + else: + pass + + # Write file list + output_file = 
os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') + with open(output_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Wrote {output_file}') + if split == 'train': + assert len(file_list) == 35217 + elif split == 'validation': + assert len(file_list) == 5031 + elif split == 'test': + assert len(file_list) == 5289 + + # Create reduced test index file + split = 'test_reduced' + file_list = {} + i = 0 + for d in egmd_dict_list_all: + if d['split'] == 'test': + midi_file = os.path.join(base_dir, d['midi_filename']) + if '_5.midi' in midi_file or '_10.midi' in midi_file: + egmd_id = d['midi_filename'].split('.')[0] + mix_audio_file = os.path.join(base_dir, d['audio_filename']) + n_frames = get_audio_file_info(mix_audio_file)[1] + notes_file = midi_file.replace('.midi', '_notes.npy') + note_events_file = midi_file.replace('.midi', '_note_events.npy') + file_list[i] = { + 'egmd_id': egmd_id, + 'n_frames': n_frames, + 'mix_audio_file': mix_audio_file, + 'notes_file': notes_file, + 'note_events_file': note_events_file, + 'midi_file': midi_file, + 'program': [128], + 'is_drum': [1], + } + i += 1 + output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') + with open(output_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Wrote {output_file}') + assert len(file_list) == 246 diff --git a/amt/src/utils/preprocess/preprocess_enstdrums.py b/amt/src/utils/preprocess/preprocess_enstdrums.py new file mode 100644 index 0000000000000000000000000000000000000000..225e0e5d748cd8edef69f7daa48387a48ed1f982 --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_enstdrums.py @@ -0,0 +1,437 @@ +""" preprocess_enstdrums.py """ +import os +import re +import glob +import copy +import json +import numpy as np +from typing import Dict +from utils.note_event_dataclasses import Note +from utils.audio import get_audio_file_info, load_audio_file, write_wav_file +from utils.midi import note_event2midi +from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes, mix_notes +from config.vocabulary import ENST_DRUM_NOTES + +DRUM_OFFSET = 0.01 + + +def create_enst_audio_stem(drum_audio_file, accomp_audio_file, enst_id) -> Dict: + program = [128, 129] + is_drum = [1, 0] + + audio_tracks = [] # multi-channel audio array (C, T) + drum_audio = load_audio_file(drum_audio_file, dtype=np.int16) / 2**15 # returns bytes + audio_tracks.append(drum_audio.astype(np.float16)) + accomp_audio = load_audio_file(accomp_audio_file, dtype=np.int16) / 2**15 # returns bytes + audio_tracks.append(accomp_audio.astype(np.float16)) + max_length = max(len(drum_audio), len(accomp_audio)) + + # collate all the audio tracks into a single array + n_tracks = 2 + audio_array = np.zeros((n_tracks, max_length), dtype=np.float16) + for j, audio in enumerate(audio_tracks): + audio_array[j, :len(audio)] = audio + + stem_content = { + 'enstdrums_id': enst_id, + 'program': program, + 'is_drum': is_drum, + 'n_frames': max_length, # int + 'audio_array': audio_array # (n_tracks, n_frames) + } + return stem_content + + +def create_note_note_event_midi_from_enst_annotation(ann_file, enst_id): + """ + Args: + ann_file: 'path/to/annotation.txt' + enst_id: str + Returns: + notes: List[Note] + note_events: List[NoteEvent] + midi: List[List[int]] + """ + # Read the text file and split each line into timestamp and drum instrument name + with open(ann_file, 'r') as f: + lines = f.readlines() # ignore additional annotations like rc2, sd- + anns = 
[(float(line.split()[0]), re.sub('[^a-zA-Z]', '', line.split()[1])) for line in lines] + + # Convert ann to notes by ENST_DRUM_NOTES vocabulary + notes = [] + for time, drum_name in anns: + if drum_name not in ENST_DRUM_NOTES.keys(): + raise ValueError(f"Drum name {drum_name} is not in ENST_DRUM_NOTES") + notes.append( + Note( + is_drum=True, + program=128, + onset=float(time), + offset=float(time) + DRUM_OFFSET, + pitch=ENST_DRUM_NOTES[drum_name][0], + velocity=1)) + + notes = sort_notes(notes) + notes = validate_notes(notes) + notes = trim_overlapping_notes(notes) + note_events = note2note_event(notes) + + # Write midi file + midi_file = ann_file.replace('.txt', '.mid') + note_event2midi(note_events, midi_file) + print(f"Created {midi_file}") + + program = [128] + is_drum = [1] + + return { # notes + 'enstdrums_id': enst_id, + 'program': program, + 'is_drum': is_drum, + 'duration_sec': note_events[-1].time, + 'notes': notes, + }, { # note_events + 'enstdrums_id': enst_id, + 'program': program, + 'is_drum': is_drum, + 'duration_sec': note_events[-1].time, + 'note_events': note_events, + } + + +def preprocess_enstdrums16k(data_home: os.PathLike, dataset_name='enstdrums') -> None: + """ + Some tracks ('minus-one' in the file name) of ENST-drums contain accompaniments. + 'stem file' will contain these accompaniments. 'mix_audio_file' will contain the + mix of the drums and accompaniments. + + Splits: + - drummer_1, drummer_2, drummer_3, all + - drummer3_dtd, drummer3_dtp, drummer3_dtm_r1, drummer3_dtm_r2 (for validation/test) + + DTD means drum track only, DTP means drum track plus percussions, DTM means + drum track plus music. + + DTM r1 and r2 are two different versions of the mixing tracks derived from listening + test in [Gillet 2008, Paulus 2009], and used in [Wu 2018]. + r1 uses 1:3 ratio of accompaniment to drums, and r2 uses 2:3 ratio. + + O. Gillet and G. Richard, “Transcription and separation of drum signals from polyphonic music,” + IEEE Trans. Audio, Speech, Lang. Process., vol. 16, no. 3, pp. 529–540, Mar. 2008. + J. Paulus and A. Klapuri, “Drum sound detection in polyphonic mu- sic with hidden Markov models,” + EURASIP J. Audio, Speech, Music Process., vol. 2009, no. 14, 2009, Art. no. 14. + C. -W. Wu et al., "A Review of Automatic Drum Transcription," + IEEE/ACM TASLP, vol. 26, no. 9, pp. 1457-1483, Sept. 
2018, + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'enstdrums_id': {drummer_id}_{3-digit-track-id} + 'n_frames': (int), + 'stem_file': Dict of stem audio file with metadata, + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', + 'program': List[int], # 128 for drums, 129 for unannotated (accompaniment) + 'is_drum': List[int], # 0 or 1 + } + } + """ + # Directory and file path + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Gather info + enst_ids = [] + enst_info = {} + for i in [1, 2, 3]: + drummer_files = sorted( + glob.glob(os.path.join(base_dir, f'drummer_{i}', 'annotation/*.txt'))) + for file in drummer_files: + track_id = os.path.basename(file).split('_')[0] + enst_id = f'{i}_{track_id}' + enst_ids.append(enst_id) + + # Create notes, note_events, and MIDI from annotation + ann_file = file + assert os.path.exists(ann_file), f'{ann_file} does not exist' + notes, note_events = create_note_note_event_midi_from_enst_annotation(ann_file, enst_id) + notes_file = ann_file.replace('.txt', '_notes.npy') + note_events_file = ann_file.replace('.txt', '_note_events.npy') + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f"Created {notes_file}") + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + print(f"Created {note_events_file}") + + # Create stem file from audio for accompaniment + drum_audio_file = os.path.join(base_dir, f'drummer_{i}', 'audio', 'wet_mix', + os.path.basename(file).replace('.txt', '.wav')) + assert os.path.exists(drum_audio_file), f'{drum_audio_file} does not exist' + + if 'minus-one' in file: # unannotated accompaniment exists + # 129: Unannotated accompaniment exists + accomp_audio_file = os.path.join(base_dir, f'drummer_{i}', 'audio', 'accompaniment', + os.path.basename(file).replace('.txt', '.wav')) + assert os.path.exists(accomp_audio_file), f'{accomp_audio_file} does not exist' + os.makedirs(os.path.join(base_dir, f'drummer_{i}', 'audio', 'stem'), exist_ok=True) + stem_file = os.path.join(base_dir, f'drummer_{i}', 'audio', 'stem', + os.path.basename(file).replace('.txt', '_stem.npy')) + stem_content = create_enst_audio_stem(drum_audio_file, accomp_audio_file, enst_id) + # write audio stem + np.save(stem_file, stem_content, allow_pickle=True, fix_imports=False) + print(f"Created {stem_file}") + + # create (drum + accompaniment) mix audio file. r1 + os.makedirs( + os.path.join(base_dir, f'drummer_{i}', 'audio', 'accompaniment_mix_r1'), + exist_ok=True) + accomp_mix_audio_file_r1 = os.path.join( + base_dir, f'drummer_{i}', 'audio', 'accompaniment_mix_r1', + os.path.basename(file).replace('.txt', '.wav')) + accomp_mix_audio_r1 = stem_content['audio_array'][0] / np.max( + np.abs(stem_content['audio_array'][0])) * 0.75 + stem_content['audio_array'][ + 1] / np.max(np.abs(stem_content['audio_array'][1])) * 0.25 + accomp_mix_audio_r1 = accomp_mix_audio_r1 / np.max(np.abs(accomp_mix_audio_r1)) + write_wav_file(accomp_mix_audio_file_r1, accomp_mix_audio_r1, 16000) + print(f"Created {accomp_mix_audio_file_r1}") + + # create (drum + accompaniment) mix audio file. 
r1 + os.makedirs( + os.path.join(base_dir, f'drummer_{i}', 'audio', 'accompaniment_mix_r2'), + exist_ok=True) + accomp_mix_audio_file_r2 = os.path.join( + base_dir, f'drummer_{i}', 'audio', 'accompaniment_mix_r2', + os.path.basename(file).replace('.txt', '.wav')) + accomp_mix_audio_r2 = stem_content['audio_array'][0] / np.max( + np.abs(stem_content['audio_array'][0])) * 0.6 + stem_content['audio_array'][ + 1] / np.max(np.abs(stem_content['audio_array'][1])) * 0.4 + accomp_mix_audio_r2 = accomp_mix_audio_r2 / np.max(np.abs(accomp_mix_audio_r2)) + write_wav_file(accomp_mix_audio_file_r2, accomp_mix_audio_r2, 16000) + print(f"Created {accomp_mix_audio_file_r2}") + n_frames = len(accomp_mix_audio_r2) + + # use r2 for training... + mix_audio_file = accomp_mix_audio_file_r2 + else: + # No unannotated accompaniment + stem_file = None + mix_audio_file = drum_audio_file + n_frames = get_audio_file_info(drum_audio_file)[1] + + # Create index, this is based on dtm setup + enst_info[enst_id] = { + 'enstdrums_id': enst_id, + 'n_frames': n_frames, + 'stem_file': stem_file, + 'mix_audio_file': mix_audio_file, + 'notes_file': notes_file, + 'note_events_file': note_events_file, + 'midi_file': ann_file.replace('.txt', '.mid'), + 'program': stem_content['program'] if 'minus-one' in file else notes['program'], + 'is_drum': stem_content['is_drum'] if 'minus-one' in file else notes['is_drum'], + } + + # Write index + for split in [ + 'drummer_1_dtm', 'drummer_2_dtm', 'all_dtm', 'drummer_1_dtp', 'drummer_2_dtp', + 'all_dtp', 'drummer_3_dtd', 'drummer_3_dtp', 'drummer_3_dtm_r1', 'drummer_3_dtm_r2' + ]: + # splits for training + file_list = {} + i = 0 + if split == 'drummer_1_dtm': + for enst_id in enst_ids: + if enst_id.startswith('1_'): + file_list[str(i)] = enst_info[enst_id] + i += 1 + assert len(file_list) == 97 + elif split == 'drummer_2_dtm': + for enst_id in enst_ids: + if enst_id.startswith('2_'): + file_list[str(i)] = enst_info[enst_id] + i += 1 + assert len(file_list) == 105 + elif split == 'all_dtm': + for enst_id in enst_ids: + file_list[str(i)] = enst_info[enst_id] + i += 1 + assert len(file_list) == 318 + elif split == 'drummer_1_dtp': + for enst_id in enst_ids: + if enst_id.startswith('1_'): + file_list[str(i)] = copy.deepcopy(enst_info[enst_id]) + file_list[str(i)]['stem_file'] = None + file_list[str(i)]['mix_audio_file'] = file_list[str( + i)]['mix_audio_file'].replace('accompaniment_mix_r2', 'wet_mix') + file_list[str(i)]['program'] = [128] + file_list[str(i)]['is_drum'] = [1] + i += 1 + assert len(file_list) == 97 + elif split == 'drummer_2_dtp': + for enst_id in enst_ids: + if enst_id.startswith('2_'): + file_list[str(i)] = copy.deepcopy(enst_info[enst_id]) + file_list[str(i)]['stem_file'] = None + file_list[str(i)]['mix_audio_file'] = file_list[str( + i)]['mix_audio_file'].replace('accompaniment_mix_r2', 'wet_mix') + file_list[str(i)]['program'] = [128] + file_list[str(i)]['is_drum'] = [1] + i += 1 + assert len(file_list) == 105 + elif split == 'all_dtp': + for enst_id in enst_ids: + file_list[str(i)] = copy.deepcopy(enst_info[enst_id]) + file_list[str(i)]['stem_file'] = None + file_list[str(i)]['mix_audio_file'] = file_list[str(i)]['mix_audio_file'].replace( + 'accompaniment_mix_r2', 'wet_mix') + file_list[str(i)]['program'] = [128] + file_list[str(i)]['is_drum'] = [1] + i += 1 + assert len(file_list) == 318 + elif split == 'drummer_3_dtd': + for enst_id in enst_ids: + if enst_id.startswith('3_') and len(enst_info[enst_id]['program']) == 1: + assert enst_info[enst_id]['stem_file'] == None + 
file_list[str(i)] = enst_info[enst_id] + i += 1 + assert len(file_list) == 95 + elif split == 'drummer_3_dtp': + for enst_id in enst_ids: + if enst_id.startswith('3_') and len(enst_info[enst_id]['program']) == 2: + file_list[str(i)] = copy.deepcopy(enst_info[enst_id]) + file_list[str(i)]['stem_file'] = None + # For DTP, we use the drum audio file as the mix audio file + file_list[str(i)]['mix_audio_file'] = file_list[str( + i)]['mix_audio_file'].replace('accompaniment_mix_r2', 'wet_mix') + file_list[str(i)]['program'] = [128] + file_list[str(i)]['is_drum'] = [1] + i += 1 + assert len(file_list) == 21 + elif split == 'drummer_3_dtm_r1': + for enst_id in enst_ids: + if enst_id.startswith('3_') and len(enst_info[enst_id]['program']) == 2: + file_list[str(i)] = copy.deepcopy(enst_info[enst_id]) + file_list[str(i)]['stem_file'] = None + file_list[str(i)]['mix_audio_file'] = file_list[str( + i)]['mix_audio_file'].replace('accompaniment_mix_r2', + 'accompaniment_mix_r1') + i += 1 + assert len(file_list) == 21 + elif split == 'drummer_3_dtm_r2': + for enst_id in enst_ids: + if enst_id.startswith('3_') and len(enst_info[enst_id]['program']) == 2: + file_list[str(i)] = copy.deepcopy(enst_info[enst_id]) + file_list[str(i)]['stem_file'] = None + i += 1 + assert len(file_list) == 21 + + # final check for file existence + for k, v in file_list.items(): + if v['stem_file'] is not None: + assert os.path.exists(v['stem_file']) + assert os.path.exists(v['mix_audio_file']) + assert os.path.exists(v['notes_file']) + assert os.path.exists(v['note_events_file']) + assert os.path.exists(v['midi_file']) + + # write json file + output_index_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') + with open(output_index_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f"Created {output_index_file}") + + +def create_filelist_dtm_random_enstdrums16k(data_home: os.PathLike, + dataset_name: str = 'enstdrums') -> None: + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Load all filelist + file_list_all_dtm_path = os.path.join(output_index_dir, + f'{dataset_name}_all_dtm_file_list.json') + file_list_all_dtp_path = os.path.join(output_index_dir, + f'{dataset_name}_all_dtp_file_list.json') + + # Collect dtm tracks + with open(file_list_all_dtm_path, 'r') as f: + fl = json.load(f) + fl_dtm = {} + i = 0 + for v in fl.values(): + if 129 in v['program']: + fl_dtm[i] = copy.deepcopy(v) + i += 1 + # Collect dtd tracks + fl_dtd = {} + i = 0 + for v in fl.values(): + if 129 not in v['program']: + fl_dtd[i] = copy.deepcopy(v) + i += 1 + + # Split: 70, 15, 15 + # rand_idx = np.random.permutation(len(fl_dtm)) + idx = {} + idx['train_dtm'] = [ + 47, 58, 14, 48, 60, 44, 34, 31, 5, 62, 46, 12, 9, 26, 57, 11, 16, 22, 33, 3, 6, 55, 50, 32, + 52, 53, 10, 28, 24, 41, 63, 51, 43, 49, 54, 15, 20, 1, 27, 2, 23, 45, 38, 37 + ] + idx['validation_dtm'] = [39, 4, 19, 59, 61, 17, 56, 36, 29, 0] + idx['test_dtm'] = [18, 7, 42, 25, 40, 8, 30, 21, 13, 35] + idx['train_dtp'] = idx['train_dtm'] + idx['validation_dtp'] = idx['validation_dtm'] + idx['test_dtp'] = idx['test_dtm'] + + for split in [ + 'train_dtm', + 'validation_dtm', + 'test_dtm', + 'train_dtp', + 'validation_dtp', + 'test_dtp', + 'all_dtd', + ]: + file_list = {} + i = 0 + if 'dtm' in split: + for k, v in fl_dtm.items(): + if int(k) in idx[split]: + file_list[i] = copy.deepcopy(v) + i += 1 + if 
split == 'test_dtm' or split == 'validation_dtm': + # add r1 mix tracks + for k, v in fl_dtm.items(): + if int(k) in idx[split]: + _v = copy.deepcopy(v) + _v['mix_audio_file'] = _v['mix_audio_file'].replace( + 'accompaniment_mix_r2', 'accompaniment_mix_r1') + file_list[i] = _v + i += 1 + elif 'dtp' in split: + for k, v in fl_dtm.items(): + if int(k) in idx[split]: + _v = copy.deepcopy(v) + _v['stem_file'] = None + _v['mix_audio_file'] = _v['mix_audio_file'].replace( + 'accompaniment_mix_r2', 'wet_mix') + _v['program'] = [128] # bug fixed.. + _v['is_drum'] = [1] # bug fixed.. + file_list[i] = _v + i += 1 + elif 'dtd' in split: + for k, v in fl_dtd.items(): + file_list[i] = copy.deepcopy(v) + i += 1 + else: + raise ValueError(f'Unknown split: {split}') + + output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') + with open(output_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_file}') diff --git a/amt/src/utils/preprocess/preprocess_geerdes.py b/amt/src/utils/preprocess/preprocess_geerdes.py new file mode 100644 index 0000000000000000000000000000000000000000..e57b401951ec75e7e7ea2b7b77440b05dd51fe8f --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_geerdes.py @@ -0,0 +1,463 @@ +"""preprocess_geerdes.py""" +import os +import glob +import re +import json +import csv +import logging +import random +from typing import Dict, List, Tuple +from copy import deepcopy + +import numpy as np +from utils.audio import get_audio_file_info, load_audio_file +from utils.midi import midi2note, note_event2midi +from utils.note2event import (note2note_event, sort_notes, validate_notes, trim_overlapping_notes, + extract_program_from_notes, extract_notes_selected_by_programs) +from utils.event2note import event2note_event +from utils.note_event_dataclasses import Note, NoteEvent +from utils.utils import note_event2token2note_event_sanity_check, create_inverse_vocab +from config.vocabulary import MT3_FULL_PLUS + +GEERDES_DATA_CSV_FILENAME = 'geerdes_data_final.csv' +DRUM_CHANNEL = 9 # all drums are in channel 9 in geerdes dataset +DRUM_PROGRAM = 128 +SINGING_VOICE_PROGRAM = 100 +SINGING_VOICE_CHORUS_PROGRAM = 101 # representing backup vocals and choir +TRACK_NAME_TO_PROGRAM_MAP = { # compared by exact match of lowercase + "vocal": SINGING_VOICE_PROGRAM, + "vocalist": SINGING_VOICE_PROGRAM, + "2nd Vocals/backings/harmony": SINGING_VOICE_CHORUS_PROGRAM, + "backvocals": SINGING_VOICE_CHORUS_PROGRAM, +} + + +def format_number(n, width=5): + """ + Format a number to a fixed width string, padding with leading zeros if needed. + + Parameters: + - n (int): The number to be formatted. + - width (int, optional): The desired fixed width for the resulting string. Default is 5. + + Returns: + - str: The formatted string representation of the number. 
+ + Example: + >>> format_number(123) + '00123' + >>> format_number(7, 3) + '007' + """ + return f"{int(n):0{width}}" + + +def find_index_with_key(lst, key): + # only checks alphanumeric characters, ignoring upper/lower case + def filter_string(s): + return re.sub(r'[^a-zA-Z0-9]', '', s) + + filtered_key = filter_string(key).lower() + indices = [ + index for index, value in enumerate(lst) if filtered_key in filter_string(value.lower()) + ] + + if len(indices) > 1: + raise ValueError(f"'{key}'has more than two matching song titles.") + elif len(indices) == 1: + return indices[0] + else: + return None + + +"""Code below was used to generate the "geerdes_data_final.csv" file for the Geerdes dataset split info.""" +# def split_and_generate_data_info_csv(data_home=os.PathLike, dataset_name='geerdes') -> None: +# """Preprocess Geerdes dataset.""" +# # Directory and file paths +# base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') +# output_index_dir = os.path.join(data_home, 'yourmt3_indexes') +# os.makedirs(output_index_dir, exist_ok=True) + +# # Setup logger +# log_file = os.path.join(base_dir, 'log.txt') +# logger = logging.getLogger('my_logger') +# logger.setLevel(logging.DEBUG) +# file_handler = logging.FileHandler(log_file) +# formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') +# file_handler.setFormatter(formatter) +# if not logger.handlers: +# logger.addHandler(file_handler) +# console_handler = logging.StreamHandler() +# console_handler.setLevel(logging.DEBUG) +# console_formatter = logging.Formatter('%(levelname)s - %(message)s') +# console_handler.setFormatter(console_formatter) +# logger.addHandler(console_handler) + +# # Load CSV: construct id to midi/wav dictionary +# csv_file = os.path.join(base_dir, 'tracks_title_corrected.csv') +# tracks_all = {} +# with open(csv_file, 'r') as f: +# reader = csv.reader(f) +# next(reader) # skip header + +# for row in reader: +# geerdes_id = format_number(row[0]) +# title = row[1] +# artist = row[2] +# link = row[6] +# tracks_all[geerdes_id] = {'title': title} +# tracks_all[geerdes_id]['artist'] = artist +# tracks_all[geerdes_id]['link'] = link +# logger.info(f'Loaded {len(tracks_all)} tracks from {csv_file}.') + +# # Search existing audio files +# audio_dir = os.path.join(base_dir, 'audio_16k_final') +# _audio_files = glob.glob(os.path.join(audio_dir, '*.wav')) +# audio_files = [ +# file for file in _audio_files +# if not file.endswith('_vocals.wav') and not file.endswith('_accompaniment.wav') +# ] +# gid_no_audio = [] +# gid_has_audio = [] +# audio_matched = set() +# audio_no_match = set() + +# for geerdes_id in tracks_all.keys(): +# title = tracks_all[geerdes_id]['title'] +# artist = tracks_all[geerdes_id]['artist'] +# # Find matching audio file +# audio_file_id = find_index_with_key(audio_files, title) +# if audio_file_id is not None: +# # add audio file to tracks_all +# audio_file = audio_files[audio_file_id] +# tracks_all[geerdes_id]['audio_file'] = audio_file +# gid_has_audio.append(geerdes_id) +# audio_matched.add(audio_file) +# else: +# logger.info(f'No matching audio file found for {artist} - {title}.') +# gid_no_audio.append(geerdes_id) +# continue + +# audio_no_match = set(audio_files) - audio_matched +# logger.info( +# f'Found {len(audio_files)} audio files. {len(gid_no_audio)} geerdes_ids have no audio files. {gid_no_audio}' +# ) +# logging.warning( +# f'{len(audio_no_match)} audio files have no matching geerdes_id. 
{audio_no_match}') + +# # Search existing midi files +# midi_dir = os.path.join(base_dir, 'aligned_midifiles_corrected') +# midi_files = glob.glob(os.path.join(midi_dir, '*.mid')) + glob.glob( +# os.path.join(midi_dir, '*.MID')) +# logger.info(f'Found {len(midi_files)} midi files in {midi_dir}.') + +# # Construct id to midi/wav dictionary +# gid_no_midi = [] +# gid_has_midi = [] +# for geerdes_id in tracks_all.keys(): +# expected_midi_file = os.path.join(midi_dir, geerdes_id + 'T.MID') +# if os.path.exists(expected_midi_file): +# gid_has_midi.append(geerdes_id) +# tracks_all[geerdes_id]['midi_file'] = expected_midi_file +# else: +# artist = tracks_all[geerdes_id]['artist'] +# title = tracks_all[geerdes_id]['title'] +# logging.warning( +# f'No matching midi file found for {expected_midi_file}, {artist} - {title}') +# tracks_all[geerdes_id]['midi_file'] = expected_midi_file +# gid_no_midi.append(geerdes_id) + +# # Final dictionary where audio and midi files are matched +# gid_has_midi_and_audio = set(gid_has_midi) & set(gid_has_audio) +# gid_midi_or_audio_missing = set(gid_no_midi).union(set(gid_no_audio)) +# assert len(gid_has_midi_and_audio) + len(gid_midi_or_audio_missing) == len(tracks_all) +# logger.info(f'Found {len(gid_has_midi_and_audio)} tracks with both midi and audio files.') +# logging.warning( +# f'Found {len(gid_midi_or_audio_missing)} tracks with either midi or audio files missing.') + +# for gid in gid_midi_or_audio_missing: +# tracks_all.pop(gid) +# logger.info(f'Final number of tracks: {len(tracks_all)}.') + +# # Stratified split using artist name 5:5 +# artist_groups = {} +# for id, info in tracks_all.items(): +# artist = info['artist'] +# if artist not in artist_groups: +# artist_groups[artist] = [] +# artist_groups[artist].append((id, info)) + +# train_set = {} +# test_set = {} +# for artist, tracks in artist_groups.items(): +# if len(tracks) == 1: +# if random.random() < 0.5: +# train_set[tracks[0][0]] = tracks[0][1] +# else: +# test_set[tracks[0][0]] = tracks[0][1] +# else: +# split_index = len(tracks) // 2 +# for id, info in tracks[:split_index]: +# train_set[id] = info +# for id, info in tracks[split_index:]: +# test_set[id] = info +# logger.info("Train Set:", len(train_set)) +# logger.info("Test Set:", len(test_set)) +# gid_train = list(train_set.keys()) +# gid_validation = list(test_set.keys()) + +# # Create split information +# gid_all = np.random.permutation(list(tracks_all.keys())) +# gid_train = gid_all[:50] +# gid_validation = gid_all[50:] +# for k, v in tracks_all.items(): +# if k in gid_train: +# v['split_half'] = 'train' +# elif k in gid_validation: +# v['split_half'] = 'validation' +# else: +# raise ValueError(f'Invalid split for {k}.') +# logger.info( +# f'Split information created.\ngid_train: {gid_train}\n gid_validation: {gid_validation}.') + +# # Remove base_dir from audio_file and midi_file +# for v in tracks_all.values(): +# v['audio_file'] = v['audio_file'].replace(base_dir + '/', '') +# v['midi_file'] = v['midi_file'].replace(base_dir + '/', '') + +# Write a new csv file +# output_csv_file = os.path.join(base_dir, 'geerdes_data_final.csv') +# with open(output_csv_file, mode='w', newline='', encoding='utf-8') as file: +# writer = csv.writer(file) +# headers = ['id', 'split_half', 'title', 'artist', 'audio_file', 'midi_file', 'link'] +# writer.writerow(headers) + +# for id, info in tracks_all.items(): +# row = [ +# id, info['split_half'], info['title'], info['artist'], info['audio_file'], +# info['midi_file'], info['link'] +# ] +# 
writer.writerow(row) +# logger.info(f'Wrote {len(tracks_all)} rows to {output_csv_file}.') +# logger.info(f'Finished creating split and basic info file.') + + +def create_note_event_and_note_from_midi(mid_file: str, + id: str, + ch_9_as_drum: bool = True, + track_name_to_program: Dict = None, + ignore_pedal: bool = False) -> Tuple[Dict, Dict]: + """Create note_events and notes from midi file.""" + + # Load midi file + notes, dur_sec, program = midi2note( + mid_file, + ch_9_as_drum=ch_9_as_drum, + track_name_to_program=track_name_to_program, + binary_velocity=True, + ignore_pedal=ignore_pedal, + return_programs=True) + program = [x for x in set(program) if x is not None] # remove None and duplicates + return { # notes + 'geerdes_id': id, + 'program': program, + 'is_drum': [1 if p == DRUM_PROGRAM else 0 for p in program], + 'duration_sec': dur_sec, + 'notes': notes, + }, { # note_events + 'geerdes_id': id, + 'program': program, + 'is_drum': [1 if p == DRUM_PROGRAM else 0 for p in program], + 'duration_sec': dur_sec, + 'note_events': note2note_event(notes), + } + + +def preprocess_geerdes16k(data_home=os.PathLike, + dataset_name='geerdes', + sanity_check=False) -> None: + """Preprocess Geerdes dataset.""" + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Setup logger + log_file = os.path.join(base_dir, 'log.txt') + logger = logging.getLogger('my_logger') + logger.setLevel(logging.DEBUG) + file_handler = logging.FileHandler(log_file) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + file_handler.setFormatter(formatter) + if not logger.handlers: + logger.addHandler(file_handler) + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + console_formatter = logging.Formatter('%(levelname)s - %(message)s') + console_handler.setFormatter(console_formatter) + logger.addHandler(console_handler) + + # Load CSV: construct id to midi/wav dictionary + ymt3_geerdes_csv_file = os.path.join(base_dir, GEERDES_DATA_CSV_FILENAME) + tracks_all = {} + with open(ymt3_geerdes_csv_file, mode='r', encoding='utf-8') as file: + reader = csv.DictReader(file) + for row in reader: + geerdes_id = row['id'] + tracks_all[geerdes_id] = row + # append base_dir to audio_file and midi_file + for v in tracks_all.values(): + v['audio_file'] = os.path.join(base_dir, v['audio_file']) + v['midi_file'] = os.path.join(base_dir, v['midi_file']) + logger.info(f'Loaded {len(tracks_all)} tracks from {ymt3_geerdes_csv_file}.') + + # Process midi files + note_processed_dir = os.path.join(base_dir, 'note_processed') + os.makedirs(note_processed_dir, exist_ok=True) + + for geerdes_id, v in tracks_all.items(): + midi_file = v['midi_file'] + + # create notes and note_events + notes, note_events = create_note_event_and_note_from_midi( + mid_file=midi_file, + id=geerdes_id, + ch_9_as_drum=True, + track_name_to_program=TRACK_NAME_TO_PROGRAM_MAP, + ignore_pedal=False) + + # sanity check + if sanity_check is True: + err_cnt = note_event2token2note_event_sanity_check(note_events['note_events'], + notes['notes']) + if len(err_cnt) > 0: + logging.warning(f'Found {err_cnt} errors in {geerdes_id}.') + + # save notes and note_events + notes_file = os.path.join(note_processed_dir, geerdes_id + '_notes.npy') + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + logger.info(f'Created {notes_file}.') + + note_events_file = 
os.path.join(note_processed_dir, geerdes_id + '_note_events.npy') + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + logger.info(f'Created {note_events_file}.') + + # save reconstructed midi file + recon_midi_file = os.path.join(note_processed_dir, geerdes_id + '_recon.mid') + inverse_vocab = create_inverse_vocab(MT3_FULL_PLUS) + note_event2midi( + note_events['note_events'], recon_midi_file, output_inverse_vocab=inverse_vocab) + logger.info(f'Created {recon_midi_file}.') + + # add file paths and info to tracks_all + tracks_all[geerdes_id]['notes_file'] = notes_file + tracks_all[geerdes_id]['note_events_file'] = note_events_file + tracks_all[geerdes_id]['recon_midi_file'] = recon_midi_file + tracks_all[geerdes_id]['program'] = notes['program'] + tracks_all[geerdes_id]['is_drum'] = notes['is_drum'] + + # save extract main_vocal/vocal_and_chorus/accompaniment only notes and note_events + notes_voc = deepcopy(notes) + notes_voc['notes'] = extract_notes_selected_by_programs( + notes['notes'], [SINGING_VOICE_PROGRAM, SINGING_VOICE_CHORUS_PROGRAM]) + notes_voc['program'] = list(extract_program_from_notes(notes_voc['notes'])) + notes_voc['is_drum'] = [1 if p == DRUM_PROGRAM else 0 for p in notes_voc['program']] + notes_voc_file = os.path.join(note_processed_dir, geerdes_id + '_notes_voc.npy') + np.save(notes_voc_file, notes_voc, allow_pickle=True, fix_imports=False) + + note_events_voc = deepcopy(note_events) + note_events_voc['note_events'] = note2note_event(notes_voc['notes']) + note_events_voc['program'] = deepcopy(notes_voc['program']) + note_events_voc['is_drum'] = deepcopy(notes_voc['is_drum']) + note_events_voc_file = os.path.join(note_processed_dir, geerdes_id + '_note_events_voc.npy') + np.save(note_events_voc_file, note_events_voc, allow_pickle=True, fix_imports=False) + + notes_acc = deepcopy(notes) + notes_acc['notes'] = extract_notes_selected_by_programs(notes['notes'], [ + p for p in notes['program'] + if p not in [SINGING_VOICE_PROGRAM, SINGING_VOICE_CHORUS_PROGRAM] + ]) + notes_acc['program'] = list(extract_program_from_notes(notes_acc['notes'])) + notes_acc['is_drum'] = [1 if p == DRUM_PROGRAM else 0 for p in notes_acc['program']] + notes_acc_file = os.path.join(note_processed_dir, geerdes_id + '_notes_acc.npy') + np.save(notes_acc_file, notes_acc, allow_pickle=True, fix_imports=False) + + note_events_acc = deepcopy(note_events) + note_events_acc['note_events'] = note2note_event(notes_acc['notes']) + note_events_acc['program'] = deepcopy(notes_acc['program']) + note_events_acc['is_drum'] = deepcopy(notes_acc['is_drum']) + note_events_acc_file = os.path.join(note_processed_dir, geerdes_id + '_note_events_acc.npy') + np.save(note_events_acc_file, note_events_acc, allow_pickle=True, fix_imports=False) + + tracks_all[geerdes_id]['notes_file_voc'] = notes_voc_file + tracks_all[geerdes_id]['note_events_file_voc'] = note_events_voc_file + tracks_all[geerdes_id]['program_voc'] = notes_voc['program'] + tracks_all[geerdes_id]['is_drum_voc'] = notes_voc['is_drum'] + tracks_all[geerdes_id]['notes_file_acc'] = notes_acc_file + tracks_all[geerdes_id]['note_events_file_acc'] = note_events_acc_file + tracks_all[geerdes_id]['program_acc'] = notes_acc['program'] + tracks_all[geerdes_id]['is_drum_acc'] = notes_acc['is_drum'] + + # Process or check audio files + for geerdes_id, v in tracks_all.items(): + v['mix_audio_file'] = v['audio_file'] + v['mix_audio_file_voc'] = v['audio_file'].replace('.wav', '_vocals.wav') + v['mix_audio_file_acc'] = 
v['audio_file'].replace('.wav', '_accompaniment.wav') + assert os.path.exists(v['mix_audio_file']) + assert os.path.exists(v['mix_audio_file_voc']) + assert os.path.exists(v['mix_audio_file_acc']) + v['n_frames'] = get_audio_file_info(v['mix_audio_file'])[1] + logger.info(f'Checked audio files. All audio files exist.') + + # Create file_list.json + splits = ['train', 'validation', 'all'] + task_suffixes = ['', '_sep'] + + for task_suffix in task_suffixes: + for split in splits: + # NOTE: We use spleeter files as the mix audio files, since partial stems (for accomp.) are not implemented yet + file_list = {} + cur_idx = 0 + for geerdes_id, v in tracks_all.items(): + if v['split_half'] == split or split == 'all': + if task_suffix == '': + file_list[cur_idx] = { + 'geerdes_id': geerdes_id, + 'n_frames': v['n_frames'], + 'mix_audio_file': v['mix_audio_file'], + 'notes_file': v['notes_file'], + 'note_events_file': v['note_events_file'], + 'midi_file': v['midi_file'], + 'program': v['program'], + 'is_drum': v['is_drum'], + } + cur_idx += 1 + elif task_suffix == '_sep': + file_list[cur_idx] = { + 'geerdes_id': geerdes_id, + 'n_frames': v['n_frames'], + 'mix_audio_file': v['mix_audio_file_voc'], + 'notes_file': v['notes_file_voc'], + 'note_events_file': v['note_events_file_voc'], + 'midi_file': v['midi_file'], + 'program': v['program_voc'], + 'is_drum': v['is_drum_voc'], + } + cur_idx += 1 + file_list[cur_idx] = { + 'geerdes_id': geerdes_id, + 'n_frames': v['n_frames'], + 'mix_audio_file': v['mix_audio_file_acc'], + 'notes_file': v['notes_file_acc'], + 'note_events_file': v['note_events_file_acc'], + 'midi_file': v['midi_file'], + 'program': v['program_acc'], + 'is_drum': v['is_drum_acc'], + } + cur_idx += 1 + + file_list_file = os.path.join(output_index_dir, + f'{dataset_name}_{split}{task_suffix}_file_list.json') + with open(file_list_file, 'w') as f: + json.dump(file_list, f, indent=4) + logger.info(f'Created {file_list_file}.') \ No newline at end of file diff --git a/amt/src/utils/preprocess/preprocess_guitarset.py b/amt/src/utils/preprocess/preprocess_guitarset.py new file mode 100644 index 0000000000000000000000000000000000000000..ff3110f9c979bf473df33d933b2f6f895394fe3c --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_guitarset.py @@ -0,0 +1,386 @@ +""" preprocess_guitarset.py """ +import os +import glob +import copy +import json +from typing import Dict, List, Tuple, Optional +import numpy as np +import jams +from utils.note_event_dataclasses import Note, NoteEvent +from utils.audio import get_audio_file_info, pitch_shift_audio +from utils.midi import note_event2midi, pitch_shift_midi +from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes + + +def create_note_event_and_note_from_jam(jam_file: str, id: str) -> Tuple[Dict, Dict]: + jam = jams.load(jam_file) + notes = [] + for ann in jam.annotations: + for obs in ann.data: + if isinstance(obs.value, float): + if obs.confidence == None: + note = Note(is_drum=False, + program=24, + onset=obs.time, + offset=obs.time + obs.duration, + pitch=round(obs.value), + velocity=1) + notes.append(note) + # Sort, validate, and trim notes + notes = sort_notes(notes) + notes = validate_notes(notes) + notes = trim_overlapping_notes(notes) + + return { # notes + 'guitarset_id': id, + 'program': [24], + 'is_drum': [0], + 'duration_sec': jam.file_metadata.duration, + 'notes': notes, + }, { # note_events + 'guitarset_id': id, + 'program': [24], + 'is_drum': [0], + 'duration_sec': jam.file_metadata.duration, + 
'note_events': note2note_event(notes), + } + + +def generate_pitch_shifted_wav_and_midi(file_list: Dict, min_pitch_shift: int = -5, max_pitch_shift: int = 6): + for key in file_list.keys(): + midi_file = file_list[key]['midi_file'] + audio_file = file_list[key]['mix_audio_file'] + + # Write midi, notes, and note_events with pitch shift + pitch_shift_midi(src_midi_file=midi_file, + min_pitch_shift=min_pitch_shift, + max_pitch_shift=max_pitch_shift, + write_midi_file=True, + write_notes_file=True, + write_note_events_file=True) + + # Write wav with pitch shift + pitch_shift_audio(src_audio_file=audio_file, + min_pitch_shift=min_pitch_shift, + max_pitch_shift=max_pitch_shift, + random_microshift_range=(-10, 11)) + + +def preprocess_guitarset16k(data_home: os.PathLike, + dataset_name: str = 'guitarset', + pitch_shift_range: Optional[Tuple[int, int]] = (-5, 6)) -> None: + """ + Splits: + - progression_1, progression_2, progression_3 + - train, validation, test (by random selection [4,1,1] player) + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'guitarset_id': guitarset_id, + 'n_frames': (int), + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', + 'program': List[int], + 'is_drum': List[int], # 0 or 1 + } + } + """ + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Process annotations + all_ann_files = glob.glob(os.path.join(base_dir, 'annotation/*.jams'), recursive=True) + assert len(all_ann_files) == 360 + notes_files = {} + note_events_files = {} + midi_files = {} + for ann_file in all_ann_files: + # Convert all annotations to notes and note events + guitarset_id = os.path.basename(ann_file).split('.')[0] + notes, note_events = create_note_event_and_note_from_jam(ann_file, guitarset_id) + + notes_file = ann_file.replace('.jams', '_notes.npy') + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f'Created {notes_file}') + + note_events_file = ann_file.replace('.jams', '_note_events.npy') + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + print(f'Created {note_events_file}') + + # Create a midi file from the note_events + midi_file = ann_file.replace('.jams', '.mid') + note_event2midi(note_events=note_events['note_events'], output_file=midi_file) + print(f'Created {midi_file}') + + notes_files[guitarset_id] = notes_file + note_events_files[guitarset_id] = note_events_file + midi_files[guitarset_id] = midi_file + + # Process audio files + pass + + # Create file_list.json + guitarset_ids_by_split = { + 'progression_1': [], + 'progression_2': [], + 'progression_3': [], + 'player_0': [], + 'player_1': [], + 'player_2': [], + 'player_3': [], + 'player_4': [], + 'player_5': [], + 'train': [], # random selection of 4 players for each style + 'validation': [], # random selection of 1 player for each style + 'test': [], # random selection of 1 player for each style + 'all': [], + } + # by progressions, players and all + for ann_file in all_ann_files: + guitarset_id = os.path.basename(ann_file).split('.')[0] + progression = int(guitarset_id.split('_')[1].split('-')[0][-1]) + player = int(guitarset_id.split('_')[0]) + + # all + guitarset_ids_by_split['all'].append(guitarset_id) + + # progression + if progression == 1: + 
guitarset_ids_by_split['progression_1'].append(guitarset_id) + elif progression == 2: + guitarset_ids_by_split['progression_2'].append(guitarset_id) + elif progression == 3: + guitarset_ids_by_split['progression_3'].append(guitarset_id) + else: + raise ValueError(f'Invalid progression: {guitarset_id}') + + # player + if player == 0: + guitarset_ids_by_split['player_0'].append(guitarset_id) + elif player == 1: + guitarset_ids_by_split['player_1'].append(guitarset_id) + elif player == 2: + guitarset_ids_by_split['player_2'].append(guitarset_id) + elif player == 3: + guitarset_ids_by_split['player_3'].append(guitarset_id) + elif player == 4: + guitarset_ids_by_split['player_4'].append(guitarset_id) + elif player == 5: + guitarset_ids_by_split['player_5'].append(guitarset_id) + else: + raise ValueError(f'Invalid player: {guitarset_id}') + + # sort + for key in guitarset_ids_by_split.keys(): + guitarset_ids_by_split[key] = sorted(guitarset_ids_by_split[key]) + for i in range(6): + assert len(guitarset_ids_by_split[f'player_{i}']) == 60 + + # train/valid/test by random player + for i in range(60): + rand_sel = np.random.choice(6, size=6, replace=False) + player_train = rand_sel[:4] + player_valid = rand_sel[4] + player_test = rand_sel[5] + for player in player_train: + guitarset_ids_by_split['train'].append(guitarset_ids_by_split[f'player_{player}'][i]) + guitarset_ids_by_split['validation'].append(guitarset_ids_by_split[f'player_{player_valid}'][i]) + guitarset_ids_by_split['test'].append(guitarset_ids_by_split[f'player_{player_test}'][i]) + + assert len(guitarset_ids_by_split['train']) == 240 + assert len(guitarset_ids_by_split['validation']) == 60 + assert len(guitarset_ids_by_split['test']) == 60 + + # Create file_list.json + for split in ['progression_1', 'progression_2', 'progression_3', 'train', 'validation', 'test', 'all']: + file_list = {} + for i, gid in enumerate(guitarset_ids_by_split[split]): + # Check if wav files exist for the 4 versions + wav_file = {} + wav_file['hex'] = os.path.join(base_dir, 'audio_hex-pickup_original', gid + '_' + 'hex' + '.wav') + wav_file['hex_cln'] = os.path.join(base_dir, 'audio_hex-pickup_debleeded', gid + '_' + 'hex_cln' + '.wav') + wav_file['mic'] = os.path.join(base_dir, 'audio_mono-mic', gid + '_' + 'mic' + '.wav') + wav_file['mix'] = os.path.join(base_dir, 'audio_mono-pickup_mix', gid + '_' + 'mix' + '.wav') + for ver in wav_file: + assert os.path.exists(wav_file[ver]) + + for ver in ['mic', 'mix']: #'hex', 'hex_cln', + file_list[i, ver] = { + 'guitarset_id': gid + '_' + ver, + 'n_frames': get_audio_file_info(wav_file[ver])[1], + 'mix_audio_file': wav_file[ver], + 'notes_file': notes_files[gid], + 'note_events_file': note_events_files[gid], + 'midi_file': midi_files[gid], + 'program': [24], + 'is_drum': [0], + } + + # Reindexing file_list + _file_list = {} + for i, v in enumerate(file_list.values()): + _file_list[i] = v + file_list = _file_list + + # Write json + output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') + with open(output_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_file}') + + if pitch_shift_range == None: + return + else: + min_pitch_shift, max_pitch_shift = pitch_shift_range + + # Generate pitch shifted wav and MIDI + file_list_all_path = os.path.join(output_index_dir, f'{dataset_name}_all_file_list.json') + with open(file_list_all_path, 'r') as f: + fl = json.load(f) + file_list_all = {int(key): value for key, value in fl.items()} + 
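    # Note (assumption based on the assertions below): pitch_shift_audio()/pitch_shift_midi()
    # (utils.audio / utils.midi, not shown in this diff) are expected to write '_pshift{n}'
    # siblings next to each source file, e.g.
    #   <guitarset_id>_mic.wav    ->  <guitarset_id>_mic_pshift-3.wav
    #   <guitarset_id>.mid        ->  <guitarset_id>_pshift-3.mid
    #   <guitarset_id>_notes.npy  ->  <guitarset_id>_pshift-3_notes.npy
    # for every n in range(min_pitch_shift, max_pitch_shift) except n == 0, which reuses the
    # originals; the index-building loop below asserts that each of these files exists before
    # adding it to the pshift file list.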
generate_pitch_shifted_wav_and_midi(file_list_all, min_pitch_shift=min_pitch_shift, max_pitch_shift=max_pitch_shift) + + # Create file_list.json for pitch shifted data + for split in ['progression_1', 'progression_2', 'progression_3', 'train', 'all']: + src_file_list_path = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') + with open(src_file_list_path, 'r') as f: + fl = json.load(f) + src_file_list = {int(key): value for key, value in fl.items()} + + file_list = {} + for k, v in src_file_list.items(): + for pitch_shift in range(min_pitch_shift, max_pitch_shift): + if pitch_shift == 0: + file_list[k, 0] = copy.deepcopy(v) + else: + file_list[k, pitch_shift] = copy.deepcopy(v) + shifted_audio_file = v['mix_audio_file'].replace('.wav', f'_pshift{pitch_shift}.wav') + assert os.path.isfile(shifted_audio_file) == True + file_list[k, pitch_shift]['mix_audio_file'] = shifted_audio_file + file_list[k, pitch_shift]['n_frames'] = get_audio_file_info(shifted_audio_file)[1] + file_list[k, pitch_shift]['pitch_shift'] = pitch_shift + + shifted_midi_file = v['midi_file'].replace('.mid', f'_pshift{pitch_shift}.mid') + shifted_notes_file = v['notes_file'].replace('_notes', f'_pshift{pitch_shift}_notes') + shifted_note_events_file = v['note_events_file'].replace('_note_events', + f'_pshift{pitch_shift}_note_events') + assert os.path.isfile(shifted_midi_file) == True + assert os.path.isfile(shifted_notes_file) == True + assert os.path.isfile(shifted_note_events_file) == True + file_list[k, pitch_shift]['midi_file'] = shifted_midi_file + file_list[k, pitch_shift]['notes_file'] = shifted_notes_file + file_list[k, pitch_shift]['note_events_file'] = shifted_note_events_file + assert len(file_list) == len(src_file_list) * (max_pitch_shift - min_pitch_shift) + + # Reindexing file_list + _file_list = {} + for i, v in enumerate(file_list.values()): + _file_list[i] = v + file_list = _file_list + + # Write json + output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_pshift_file_list.json') + with open(output_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_file}') + + +def create_filelist_by_style_guitarset16k(data_home: os.PathLike, dataset_name: str = 'guitarset') -> None: + + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Load filelist, pshift_all for train + file_list_pshift_all_path = os.path.join(output_index_dir, f'{dataset_name}_all_pshift_file_list.json') + with open(file_list_pshift_all_path, 'r') as f: + fl_pshift = json.load(f) + assert len(fl_pshift) == 7920 + + # Load filelist, all for test + file_list_all_path = os.path.join(output_index_dir, f'{dataset_name}_all_file_list.json') + with open(file_list_all_path, 'r') as f: + fl = json.load(f) + assert len(fl) == 720 + + # Create file_list.json for training each style using pitch shifted data + styles = ['BN', 'Funk', 'SS', 'Jazz', 'Rock'] + for style in styles: + # Create and write pshift file list + train_file_list = {} + i = 0 + for v in fl_pshift.values(): + if style in v['guitarset_id']: + train_file_list[i] = copy.deepcopy(v) + i += 1 + output_file = os.path.join(output_index_dir, f'{dataset_name}_{style}_pshift_file_list.json') + with open(output_file, 'w') as f: + json.dump(train_file_list, f, indent=4) + print(f'Created {output_file}') + + test_file_list = {} + i = 0 + for v in fl.values(): + if style in 
v['guitarset_id']: + test_file_list[i] = copy.deepcopy(v) + i += 1 + output_file = os.path.join(output_index_dir, f'{dataset_name}_{style}_file_list.json') + with open(output_file, 'w') as f: + json.dump(test_file_list, f, indent=4) + print(f'Created {output_file}') + + +# BASIC_PITCH_VALIDATION_IDS = [ +# "05_Funk2-108-Eb_comp", "04_BN2-131-B_comp", "04_Jazz3-150-C_solo", "05_Rock2-85-F_solo", +# "05_Funk3-98-A_comp", "05_BN3-119-G_comp", "02_SS2-107-Ab_solo", "01_BN2-131-B_solo", +# "00_BN2-166-Ab_comp", "04_SS1-100-C#_solo", "01_BN2-166-Ab_solo", "01_Rock1-130-A_solo", +# "04_Funk2-119-G_solo", "01_SS2-107-Ab_comp", "05_Funk3-98-A_solo", "05_Funk1-114-Ab_comp", +# "05_Jazz2-187-F#_solo", "05_SS1-100-C#_comp", "00_Rock3-148-C_solo", "02_Rock3-117-Bb_comp", +# "01_BN1-147-Gb_solo", "01_Rock1-90-C#_solo", "01_SS2-107-Ab_solo", "02_Jazz3-150-C_solo", +# "00_Funk1-97-C_solo", "05_SS3-98-C_solo", "03_Rock3-148-C_comp", "03_Rock3-117-Bb_solo", +# "04_Jazz2-187-F#_solo", "05_Jazz2-187-F#_comp", "02_SS1-68-E_solo", "04_SS2-88-F_solo", +# "04_BN2-131-B_solo", "04_Jazz3-137-Eb_comp", "00_SS2-107-Ab_comp", "01_Rock1-130-A_comp", +# "00_Jazz1-130-D_comp", "04_Funk2-108-Eb_comp", "05_BN2-166-Ab_comp" +# ] + +# BASIC_PITCH_TEST_IDS = [ +# "04_SS3-84-Bb_solo", "02_Funk1-114-Ab_solo", "05_Funk1-114-Ab_solo", "05_Funk1-97-C_solo", +# "00_Rock3-148-C_comp", "00_Jazz3-137-Eb_comp", "00_Jazz1-200-B_comp", "03_SS3-98-C_solo", +# "05_Jazz1-130-D_comp", "00_Jazz2-110-Bb_comp", "02_Funk3-98-A_comp", "04_Rock1-130-A_comp", +# "03_BN1-129-Eb_comp", "03_Funk2-119-G_comp", "05_BN1-147-Gb_comp", "02_Rock1-90-C#_comp", +# "00_Funk3-98-A_solo", "01_SS1-100-C#_comp", "00_Funk3-98-A_comp", "02_BN3-154-E_comp", +# "01_Jazz3-137-Eb_comp", "00_BN2-131-B_comp", "04_SS1-68-E_solo", "05_Funk1-97-C_comp", +# "04_Jazz3-137-Eb_solo", "05_Rock2-142-D_solo", "02_BN3-119-G_solo", "02_Rock2-142-D_solo", +# "01_BN1-129-Eb_solo", "00_Rock2-85-F_comp", "00_Rock1-130-A_solo" +# ] + +# def create_filelist_for_basic_pitch_benchmark_guitarset16k(data_home: os.PathLike, +# dataset_name: str = 'guitarset') -> None: + +# # Directory and file paths +# base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') +# output_index_dir = os.path.join(data_home, 'yourmt3_indexes') +# os.makedirs(output_index_dir, exist_ok=True) + +# # Load filelist, pshift_all for train +# file_list_pshift_all_path = os.path.join(output_index_dir, +# f'{dataset_name}_all_pshift_file_list.json') +# with open(file_list_pshift_all_path, 'r') as f: +# fl_pshift = json.load(f) +# assert len(fl_pshift) == 7920 + +# # Load filelist, all without pshift +# file_list_all_path = os.path.join(output_index_dir, f'{dataset_name}_all_file_list.json') +# with open(file_list_all_path, 'r') as f: +# fl = json.load(f) +# assert len(fl) == 720 + +# # This is abandoned, because the split is not official one. 
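# Illustrative usage (a minimal sketch; the data_home path is a placeholder, not part of the
# original pipeline). preprocess_guitarset16k() converts the .jams annotations, writes the split
# indexes under {data_home}/yourmt3_indexes/, and optionally generates pitch-shifted audio/MIDI;
# create_filelist_by_style_guitarset16k() then derives per-style (BN/Funk/SS/Jazz/Rock) file lists
# from the '*_all_pshift_file_list.json' and '*_all_file_list.json' indexes, so it must run after
# a full preprocessing pass with pitch_shift_range enabled.
if __name__ == '__main__':
    data_home = '/path/to/data'  # placeholder: expects {data_home}/guitarset_yourmt3_16k/ to exist
    preprocess_guitarset16k(data_home, dataset_name='guitarset', pitch_shift_range=(-5, 6))
    create_filelist_by_style_guitarset16k(data_home, dataset_name='guitarset')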
diff --git a/amt/src/utils/preprocess/preprocess_idmt_smt_bass.py b/amt/src/utils/preprocess/preprocess_idmt_smt_bass.py new file mode 100644 index 0000000000000000000000000000000000000000..aa9ac42f3338d1b929186b2e8f07a2708f4f4115 --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_idmt_smt_bass.py @@ -0,0 +1,267 @@ +""" preprocess_idmt_smt_bass.py """ +import os +import glob +import json +import wave +import numpy as np +from typing import Dict, Tuple +from sklearn.model_selection import train_test_split +from utils.audio import get_audio_file_info, load_audio_file, write_wav_file, guess_onset_offset_by_amp_envelope +from utils.midi import midi2note, note_event2midi +from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes +from utils.event2note import event2note_event +from utils.note_event_dataclasses import Note, NoteEvent +from utils.utils import assert_note_events_almost_equal + +SPLIT_INFO_FILE = 'stratified_split_crepe_smt.json' + +# Plucking style to GM program +PS2program = { + "FS": 33, # Fingered Elec Bass + "MU": 33, # Muted Elec Bass + "PK": 34, # Picked Elec Bass + "SP": 36, # Slap-Pluck Elec Bass + "ST": 37, # Salp-Thumb Elec Bass +} +PREPEND_SILENCE = 1.8 # seconds +APPEND_SILENCE = 1.8 # seconds + + +def bass_string_to_midi_pitch(string_number: int, fret: int, string_pitches=[28, 33, 38, 43, 48]): + """ sring_number: 1, 2, 3, 4, fret: 0, 1, 2, ...""" + return string_pitches[string_number - 1] + fret + + +def regenerate_stratified_split(audio_files_dict): + train_ids_dict = {} + val_ids_dict = {} + offset = 0 + + for key, files in audio_files_dict.items(): + ids = np.arange(len(files)) + offset + train_ids, val_ids = train_test_split( + ids, test_size=0.2, random_state=42, stratify=np.zeros_like(ids)) + train_ids_dict[key] = train_ids + val_ids_dict[key] = val_ids + offset += len(files) + + train_ids = np.concatenate(list(train_ids_dict.values())) + val_ids = np.concatenate(list(val_ids_dict.values())) + assert len(train_ids) == 1872 and len(val_ids) == 470 + return train_ids, val_ids + + +def create_note_event_and_note_from_midi(mid_file: str, + id: str, + program: int = 0, + ignore_pedal: bool = True) -> Tuple[Dict, Dict]: + """Extracts note or note_event and metadata from midi: + + Returns: + notes (dict): note events and metadata. + note_events (dict): note events and metadata. 
+ """ + notes, dur_sec = midi2note( + mid_file, + binary_velocity=True, + force_all_program_to=program, + fix_offset=True, + quantize=True, + verbose=0, + minimum_offset_sec=0.01, + ignore_pedal=ignore_pedal) + return { # notes + 'idmt_smt_bass_id': str(id), + 'program': [program], + 'is_drum': [0], + 'duration_sec': dur_sec, + 'notes': notes, + }, { # note_events + 'idmt_smt_bass_id': str(id), + 'program': [0], + 'is_drum': [0], + 'duration_sec': dur_sec, + 'note_events': note2note_event(notes), + } + + +def preprocess_idmt_smt_bass_16k(data_home=os.PathLike, + dataset_name='idmt_smt_bass', + sanity_check=True, + edit_audio=True, + regenerate_split=False) -> None: + """ + Splits: stratified by plucking style + 'train': 1872 + 'validation': 470 + Total: 2342 + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'idmt_smt_bass_id': idmt_smt_bass_id, + 'n_frames': (int), + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', + 'program': List[int], see PS2program above + 'is_drum': List[int], # always [0] for this dataset + } + } + """ + + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # # audio file list + # FS_audio_pattern = os.path.join(base_dir, 'PS/FS/*.wav') + # MU_audio_pattern = os.path.join(base_dir, 'PS/MU/*.wav') + # PK_audio_pattern = os.path.join(base_dir, 'PS/PK/*.wav') + # SP_audio_pattern = os.path.join(base_dir, 'PS/SP/*.wav') + # ST_audio_pattern = os.path.join(base_dir, 'PS/ST/*.wav') + # FS_audio_files = sorted(glob.glob(FS_audio_pattern, recursive=False)) + # MU_audio_files = sorted(glob.glob(MU_audio_pattern, recursive=False)) + # PK_audio_files = sorted(glob.glob(PK_audio_pattern, recursive=False)) + # SP_audio_files = sorted(glob.glob(SP_audio_pattern, recursive=False)) + # ST_audio_files = sorted(glob.glob(ST_audio_pattern, recursive=False)) + # assert len(FS_audio_files) == 469 + # assert len(MU_audio_files) == 468 + # assert len(PK_audio_files) == 468 + # assert len(SP_audio_files) == 469 + # assert len(ST_audio_files) == 468 + # audio_files_dict = { + # 'FS': FS_audio_files, + # 'MU': MU_audio_files, + # 'PK': PK_audio_files, + # 'SP': SP_audio_files, + # 'ST': ST_audio_files + # } + + # splits: + split_info_file = os.path.join(base_dir, SPLIT_INFO_FILE) + with open(split_info_file, 'r') as f: + split_info = json.load(f) + + all_info_dict = {} + id = 0 + for split in ['train', 'validation']: + for file_path in split_info[split]: + audio_file = os.path.join(base_dir, file_path) + assert os.path.exists(audio_file) + all_info_dict[id] = { + 'idmt_smt_bass_id': id, + 'n_frames': None, + 'mix_audio_file': audio_file, + 'notes_file': None, + 'note_events_file': None, + 'midi_file': None, + 'program': None, + 'is_drum': [0] + } + id += 1 + train_ids = np.arange(len(split_info['train'])) + val_ids = np.arange(len(split_info['validation'])) + len(train_ids) + # if regenerate_split is True: + # train_ids, val_ids = regenerate_stratified_split(audio_files_dict) + # else: + # val_ids = VALIDATION_IDS + # train_ids = [i for i in range(len(all_info_dict)) if i not in val_ids] + + # Audio processing: prepend/append 1.8s silence + if edit_audio is True: + for v in all_info_dict.values(): + audio_file = v['mix_audio_file'] + fs, x_len, _ = 
get_audio_file_info(audio_file) + x = load_audio_file(audio_file) # (T,) + prefix_len = int(fs * PREPEND_SILENCE) + suffix_len = int(fs * APPEND_SILENCE) + x_new_len = prefix_len + x_len + suffix_len + x_new = np.zeros(x_new_len) + x_new[prefix_len:prefix_len + x_len] = x + + # overwrite audio file + print(f'Overwriting {audio_file} with silence prepended/appended') + write_wav_file(audio_file, x_new, fs) + + # Guess Program/Pitch/Onset/Offset and Generate Notes/NoteEvents/MIDI + for id in all_info_dict.keys(): + audio_file = all_info_dict[id]['mix_audio_file'] + + # Guess program/pitch from audio file name + _, _, _, _, pluck_style, _, string_num, fret_num = os.path.basename(audio_file).split( + '.')[0].split('_') + program = PS2program[pluck_style] + pitch = bass_string_to_midi_pitch(int(string_num), int(fret_num)) + + # Guess onset/offset from audio signal x + fs, n_frames, _ = get_audio_file_info(audio_file) + x = load_audio_file(audio_file, fs=fs) + onset, offset, _ = guess_onset_offset_by_amp_envelope( + x, fs=fs, onset_threshold=0.05, offset_threshold=0.02, frame_size=256) + onset = round((onset / fs) * 1000) / 1000 + offset = round((offset / fs) * 1000) / 1000 + + # Notes and NoteEvents + notes = [ + Note( + is_drum=False, + program=program, + onset=onset, + offset=offset, + pitch=pitch, + velocity=1, + ) + ] + note_events = note2note_event(notes) + + # Write MIDI + midi_file = audio_file.replace('.wav', '.mid') + note_event2midi(note_events, midi_file) + + # Reconvert MIDI to Notes/NoteEvents, and validate + notes_dict, note_events_dict = create_note_event_and_note_from_midi( + midi_file, id, program=program, ignore_pedal=True) + if sanity_check: + assert_note_events_almost_equal(note_events_dict['note_events'], note_events) + + # Write notes and note_events + notes_file = audio_file.replace('.wav', '_notes.npy') + note_events_file = audio_file.replace('.wav', '_note_events.npy') + np.save(notes_file, notes_dict, allow_pickle=True, fix_imports=False) + np.save(note_events_file, note_events_dict, allow_pickle=True, fix_imports=False) + print(f'Created {notes_file}') + print(f'Created {note_events_file}') + + # Update all_info_dict + all_info_dict[id]['n_frames'] = n_frames + all_info_dict[id]['notes_file'] = notes_file + all_info_dict[id]['note_events_file'] = note_events_file + all_info_dict[id]['midi_file'] = midi_file + all_info_dict[id]['program'] = [program] + + # Save index + ids = {'train': train_ids, 'validation': val_ids, 'all': list(all_info_dict.keys())} + for split in ['train', 'validation']: + fl = {} + for i, id in enumerate(ids[split]): + fl[i] = all_info_dict[id] + output_index_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') + with open(output_index_file, 'w') as f: + json.dump(fl, f, indent=4) + print(f'Created {output_index_file}') + + +def test_guess_onset_offset_by_amp_envelope(all_info_dict): + import matplotlib.pyplot as plt + id = np.random.randint(0, 2300) + x = load_audio_file(all_info_dict[id]['mix_audio_file']) + onset, offset, amp_env = guess_onset_offset_by_amp_envelope(x) + plt.plot(x) + plt.axvline(x=onset, color='r', linestyle='--', label='onset') + plt.axvline(x=offset, color='g', linestyle='--', label='offset') + plt.show() \ No newline at end of file diff --git a/amt/src/utils/preprocess/preprocess_maestro.py b/amt/src/utils/preprocess/preprocess_maestro.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea0d1b0c01c3afada00685e0bd985d293f8fdf8 --- /dev/null +++ 
b/amt/src/utils/preprocess/preprocess_maestro.py @@ -0,0 +1,176 @@ +"""preprocess_maestro.py""" +import os +import glob +import re +import json +from typing import Dict, List, Tuple +import numpy as np +from utils.audio import get_audio_file_info +from utils.midi import midi2note, note_event2midi +from utils.note2event import note2note_event, note_event2event +from utils.event2note import event2note_event +from utils.note_event_dataclasses import Note, NoteEvent +from utils.utils import note_event2token2note_event_sanity_check +from utils.utils import assert_note_events_almost_equal + + +def create_note_event_and_note_from_midi(mid_file: str, + id: str, + ignore_pedal: bool = False) -> Tuple[Dict, Dict]: + """Extracts note or note_event and metadata from midi: + + Returns: + notes (dict): note events and metadata. + note_events (dict): note events and metadata. + """ + notes, dur_sec = midi2note( + mid_file, + binary_velocity=True, + ch_9_as_drum=False, + force_all_drum=False, + force_all_program_to=0, # always piano + trim_overlap=True, + fix_offset=True, + quantize=True, + verbose=0, + minimum_offset_sec=0.01, + drum_offset_sec=0.01, + ignore_pedal=ignore_pedal) + return { # notes + 'maps_id': id, + 'program': [0], + 'is_drum': [0], + 'duration_sec': dur_sec + 0.01, + 'notes': notes, + }, { # note_events + 'maps_id': id, + 'program': [0], + 'is_drum': [0], + 'duration_sec': dur_sec + 0.01, + 'note_events': note2note_event(notes), + } + + +def note_event2event_sanity_check(note_events: List[NoteEvent]): + """Sanity check for note events.""" + events = note_event2event(note_events, None) + note_events2, _, _ = event2note_event(events) + assert_note_events_almost_equal(note_events, note_events2) + + +def preprocess_maestro16k(data_home=os.PathLike, + dataset_name='maestro', + ignore_pedal=False, + sanity_check=False) -> None: + """ + Splits: + - train: 962 files + - validation: 137 files + - test: 177 files + - all: 1276 file + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'maestro_id': maestro_id, + 'n_frames': (int), + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', + 'program': List[int], + 'is_drum': List[int], # 0 or 1 + } + } + """ + + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Get metadata + metadata_file = os.path.join(base_dir, 'maestro-v3.0.0.json') + with open(metadata_file, 'r') as f: + _metadata = json.load(f) + metadata = {} + ids_all = list(range(len(_metadata['canonical_composer']))) + assert len(ids_all) == 1276 + for i in ids_all: + metadata[i] = {} + for key in ['split', 'midi_filename', 'audio_filename', 'duration']: + metadata[i][key] = _metadata[key][str(i)] + + # Collect ids and prepend base_dir to filenames + ids = {'all': ids_all, 'train': [], 'validation': [], 'test': []} + for i in ids_all: + m = metadata[i] + ids[m['split']].append(i) + # Prepend base_dir + m['midi_filename'] = os.path.join(base_dir, m['midi_filename']) + m['audio_filename'] = os.path.join(base_dir, m['audio_filename']) + + # Rename '.midi' to '.mid' + if '.midi' in m['midi_filename'] and not os.path.exists(m['midi_filename'].replace( + '.midi', '.mid')): + os.rename(m['midi_filename'], m['midi_filename'].replace('.midi', '.mid')) + m['midi_filename'] = 
m['midi_filename'].replace('.midi', '.mid') + + # File sanity check + assert os.path.exists(m['midi_filename']) and '.mid' == m['midi_filename'][-4:] + assert os.path.exists(m['audio_filename']) and '.wav' in m['audio_filename'] + + assert len(ids['train']) == 962 + assert len(ids['validation']) == 137 + assert len(ids['test']) == 177 + + # Create 'all' filelist, and process MIDI + file_list = {} + for i in ids['all']: + m = metadata[i] + mix_audio_file = m['audio_filename'] + fs, n_frames, n_channels = get_audio_file_info(mix_audio_file) + assert fs == 16000 and n_channels == 1 + n_frames = min(int(m['duration'] * 16000), n_frames) + assert n_frames > 32001 + + notes_file = m['midi_filename'].replace('.mid', '_notes.npy') + note_events_file = m['midi_filename'].replace('.mid', '_note_events.npy') + midi_file = m['midi_filename'] + + file_list[i] = { + 'maestro_id': i, + 'n_frames': n_frames, + 'mix_audio_file': mix_audio_file, + 'notes_file': notes_file, + 'note_events_file': note_events_file, + 'midi_file': midi_file, + 'program': [0], + 'is_drum': [0], + } + + # Process MIDI + notes, note_events = create_note_event_and_note_from_midi( + mid_file=midi_file, id=i, ignore_pedal=ignore_pedal) + + if sanity_check: + # sanity check + print(f'Sanity check for {i}: {midi_file}') + note_event2token2note_event_sanity_check(note_events['note_events'], notes['notes']) + + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f'Created {notes_file}') + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + print(f'Created {note_events_file}') + + # Save index + for split in ['all', 'train', 'validation', 'test']: + fl = {} + for i, maestro_id in enumerate(ids[split]): + fl[i] = file_list[maestro_id] + output_index_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') + with open(output_index_file, 'w') as f: + json.dump(fl, f, indent=4) + print(f'Created {output_index_file}') diff --git a/amt/src/utils/preprocess/preprocess_maps.py b/amt/src/utils/preprocess/preprocess_maps.py new file mode 100644 index 0000000000000000000000000000000000000000..2337de4585beac53d1f01d2773e9adeab0b92a9f --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_maps.py @@ -0,0 +1,183 @@ +"""preprocess_maps.py""" +import os +import glob +import re +import json +from typing import Dict, List, Tuple +import numpy as np +from utils.audio import get_audio_file_info +from utils.midi import midi2note, note_event2midi +from utils.note2event import note2note_event, note_event2event +from utils.event2note import event2note_event +from utils.note_event_dataclasses import Note, NoteEvent +from utils.utils import note_event2token2note_event_sanity_check +# from utils.utils import assert_note_events_almost_equal + + +def create_note_event_and_note_from_midi(mid_file: str, + id: str, + ignore_pedal: bool = False) -> Tuple[Dict, Dict]: + """Extracts note or note_event and metadata from midi: + + Returns: + notes (dict): note events and metadata. + note_events (dict): note events and metadata. 
+ """ + notes, dur_sec = midi2note( + mid_file, + binary_velocity=True, + ch_9_as_drum=False, + force_all_drum=False, + force_all_program_to=0, # always piano + trim_overlap=True, + fix_offset=True, + quantize=True, + verbose=0, + minimum_offset_sec=0.01, + drum_offset_sec=0.01, + ignore_pedal=ignore_pedal) + return { # notes + 'maps_id': id, + 'program': [0], + 'is_drum': [0], + 'duration_sec': dur_sec, + 'notes': notes, + }, { # note_events + 'maps_id': id, + 'program': [0], + 'is_drum': [0], + 'duration_sec': dur_sec, + 'note_events': note2note_event(notes), + } + + +def rewrite_midi_120bpm(file: os.PathLike, note_events: List[NoteEvent]): + """Rewrite midi file with 120 bpm.""" + note_event2midi(note_events, file) + return + + +# def note_event2event_sanity_check(note_events: List[NoteEvent]): +# """Sanity check for note events.""" +# events = note_event2event(note_events, None) +# note_events2, _, _ = event2note_event(events) +# assert_note_events_almost_equal(note_events, note_events2) + + +def preprocess_maps16k(data_home=os.PathLike, + dataset_name='maps', + ignore_pedal=False, + sanity_check=False) -> None: + """ + Splits: + - train: following the convention described in Cheuk et al. (2021), + we filter out the songs overlapping with the MAPS test set. + 139 pieces from MUS folder are left for training. + - test: 60 files (MUS) + - all: 270 files including (unfiltered) train and test. This is used + for the evaluation on the MusicNet test set. + + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'maps_id': maps_id, + 'n_frames': (int), + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', + 'program': List[int], + 'is_drum': List[int], # 0 or 1 + } + } + """ + + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Search for files with .mid and .wav (synth / acoustic) extensions + train_mid_pattern = os.path.join(base_dir, 'train/**/MUS/*.mid') + test_mid_pattern = os.path.join(base_dir, 'test/**/MUS/*.mid') + all_mid_pattern = os.path.join(base_dir, '**/MUS/*.mid') + + train_mid_files = glob.glob(train_mid_pattern, recursive=True) + test_mid_files = glob.glob(test_mid_pattern, recursive=True) + all_mid_files = glob.glob(all_mid_pattern, recursive=True) + + # Discard duplicated songs from train and test sets (reduce train set) + songnames_in_test_files = [] + for file in test_mid_files: + filename = os.path.basename(file) + match = re.search(r"MAPS_MUS-([\w-]+)_", filename) + if match: + songnames_in_test_files.append(match.group(1)) + + filtered_train_mid_files = [] + filtered_train_wav_files = [] + for train_file in train_mid_files: + if not any( + songname in os.path.basename(train_file) for songname in songnames_in_test_files): + filtered_train_mid_files.append(train_file) + filtered_train_wav_files.append(train_file.replace('.mid', '.wav')) + assert len(filtered_train_mid_files) == len(filtered_train_wav_files) == 139 + + # Process MIDI files + for i, mid_file in enumerate(all_mid_files): + maps_id = os.path.basename(mid_file)[:-4] + notes, note_events = create_note_event_and_note_from_midi( + mid_file=mid_file, id=maps_id, ignore_pedal=ignore_pedal) + + if sanity_check: + # sanity check + print(f'Sanity check for {i}: {maps_id}...') + 
note_event2token2note_event_sanity_check(note_events['note_events'], notes['notes']) + + notes_file = mid_file.replace('.mid', '_notes.npy') + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f'Created {notes_file}') + + note_events_file = mid_file.replace('.mid', '_note_events.npy') + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + print(f'Created {note_events_file}') + + # overwrite midi file with 120 bpm + rewrite_midi_120bpm(mid_file, note_events['note_events']) + print(f'Overwrote {mid_file} with 120 bpm') + + # Process audio files + pass + + # Create file_list.json + mid_files_by_split = { + 'train': filtered_train_mid_files, + 'test': test_mid_files, + 'all': all_mid_files, + } + + for split in ['train', 'test', 'all']: + file_list = {} + for i, mid_file in enumerate(mid_files_by_split[split]): + # check if wav file exists + wav_file = mid_file.replace('.mid', '.wav') + if not os.path.exists(wav_file): + raise FileNotFoundError(f'Wav file not found: {wav_file}') + + file_list[i] = { + 'maps_id': os.path.basename(mid_file)[:-4], + 'n_frames': get_audio_file_info(wav_file)[1], + 'mix_audio_file': wav_file, + 'notes_file': mid_file.replace('.mid', '_notes.npy'), + 'note_events_file': mid_file.replace('.mid', '_note_events.npy'), + 'midi_file': mid_file, + 'program': [0], + 'is_drum': [0], + } + output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') + with open(output_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_file}') diff --git a/amt/src/utils/preprocess/preprocess_mir1k.py b/amt/src/utils/preprocess/preprocess_mir1k.py new file mode 100644 index 0000000000000000000000000000000000000000..66db0ac8a480dea48991793a692a952f64ad7101 --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_mir1k.py @@ -0,0 +1,112 @@ +"""preprocess_mir1k.py""" +import os +import glob +import re +import json +from typing import Dict, List, Tuple +import numpy as np +from utils.audio import get_audio_file_info, load_audio_file +from utils.midi import midi2note, note_event2midi +from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes +from utils.event2note import event2note_event +from utils.note_event_dataclasses import Note, NoteEvent +from utils.utils import note_event2token2note_event_sanity_check + +# def create_spleeter_audio_stem(vocal_audio_file, accomp_audio_file, mir_st500_id) -> Dict: +# program = MIR_ST500_PROGRAM +# is_drum = [0, 0] + +# audio_tracks = [] # multi-channel audio array (C, T) +# vocal_audio = load_audio_file(vocal_audio_file, dtype=np.int16) / 2**15 # returns bytes +# audio_tracks.append(vocal_audio.astype(np.float16)) +# accomp_audio = load_audio_file(accomp_audio_file, dtype=np.int16) / 2**15 # returns bytes +# audio_tracks.append(accomp_audio.astype(np.float16)) +# max_length = max(len(vocal_audio), len(accomp_audio)) + +# # collate all the audio tracks into a single array +# n_tracks = 2 +# audio_array = np.zeros((n_tracks, max_length), dtype=np.float16) +# for j, audio in enumerate(audio_tracks): +# audio_array[j, :len(audio)] = audio + +# stem_content = { +# 'mir_st500_id': mir_st500_id, +# 'program': np.array(program, dtype=np.int64), +# 'is_drum': np.array(is_drum, dtype=np.int64), +# 'n_frames': max_length, # int +# 'audio_array': audio_array # (n_tracks, n_frames) +# } +# return stem_content + +# def create_note_note_event_midi_from_mir1k_annotation(ann, midi_file, mir_st500_id): +# """ +# Args: +# ann: 
List[List[float, float, float]] # [onset, offset, pitch] +# mir_st500_id: str +# Returns: +# notes: List[Note] +# note_events: List[NoteEvent] +# midi: List[List[int]] +# """ +# notes = [] +# for onset, offset, pitch in ann: +# notes.append( +# Note( +# is_drum=False, +# program=100, +# onset=float(onset), +# offset=float(offset), +# pitch=int(pitch), +# velocity=1)) +# notes = sort_notes(notes) +# notes = validate_notes(notes) +# notes = trim_overlapping_notes(notes) +# note_events = note2note_event(notes) + +# # Write midi file +# note_event2midi(note_events, midi_file) +# print(f"Created {midi_file}") + +# return { # notes +# 'mir_st500_id': mir_st500_id, +# 'program': MIR_ST500_PROGRAM, +# 'is_drum': [0, 0], +# 'duration_sec': note_events[-1].time, +# 'notes': notes, +# }, { # note_events +# 'mir_st500_id': mir_st500_id, +# 'program': MIR_ST500_PROGRAM, +# 'is_drum': [0, 0], +# 'duration_sec': note_events[-1].time, +# 'note_events': note_events, +# } + + +def preprocess_mir1k_16k(data_home=os.PathLike, dataset_name='mir1k', sanity_check=False) -> None: + """ + Splits: + - train: index 1 to 400, 346 files (54 files missing) + - test: index 401 to 500, 94 files (6 files missing) + - all: 440 files (60 files missing) + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'mir_st500_id': mir_st500_id, + 'n_frames': (int), + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', + 'program': List[int], # [100, 129], 100 for singing voice, and 129 for unannotated + 'is_drum': List[int], # [0] or [1] + } + } + """ + + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) \ No newline at end of file diff --git a/amt/src/utils/preprocess/preprocess_mir_st500.py b/amt/src/utils/preprocess/preprocess_mir_st500.py new file mode 100644 index 0000000000000000000000000000000000000000..0d3741f8747cb38d013d42c60c5648e55da41868 --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_mir_st500.py @@ -0,0 +1,274 @@ +"""preprocess_mir_st500.py""" +import os +import json +from typing import Dict +import numpy as np +from utils.audio import get_audio_file_info, load_audio_file +from utils.midi import midi2note, note_event2midi +from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes +from utils.event2note import event2note_event +from utils.note_event_dataclasses import Note, NoteEvent +from utils.utils import note_event2token2note_event_sanity_check + +SINGING_WITH_UNANNOTATED_PROGRAM = [100, 129] # 100 for singing voice, 129 for unannotated +SINGING_ONLY_PROGRAM = [100] + + +def check_file_existence(file: str) -> bool: + """Checks if file exists.""" + res = True + if not os.path.exists(file): + res = False + elif get_audio_file_info(file)[1] < 10 * 16000: + print(f'File {file} is too short.') + res = False + return res + + +def create_spleeter_audio_stem(vocal_audio_file, accomp_audio_file, mir_st500_id) -> Dict: + program = SINGING_WITH_UNANNOTATED_PROGRAM + is_drum = [0, 0] + + audio_tracks = [] # multi-channel audio array (C, T) + vocal_audio = load_audio_file(vocal_audio_file, dtype=np.int16) / 2**15 # returns bytes + audio_tracks.append(vocal_audio.astype(np.float16)) + accomp_audio = load_audio_file(accomp_audio_file, dtype=np.int16) / 2**15 # returns 
bytes + audio_tracks.append(accomp_audio.astype(np.float16)) + max_length = max(len(vocal_audio), len(accomp_audio)) + + # collate all the audio tracks into a single array + n_tracks = 2 + audio_array = np.zeros((n_tracks, max_length), dtype=np.float16) + for j, audio in enumerate(audio_tracks): + audio_array[j, :len(audio)] = audio + + stem_content = { + 'mir_st500_id': mir_st500_id, + 'program': np.array(program, dtype=np.int64), + 'is_drum': np.array(is_drum, dtype=np.int64), + 'n_frames': max_length, # int + 'audio_array': audio_array # (n_tracks, n_frames) + } + return stem_content + + +def create_note_note_event_midi_from_mir_st500_annotation(ann, midi_file, mir_st500_id): + """ + Args: + ann: List[List[float, float, float]] # [onset, offset, pitch] + mir_st500_id: str + Returns: + notes: List[Note] + note_events: List[NoteEvent] + midi: List[List[int]] + """ + notes = [] + for onset, offset, pitch in ann: + notes.append( + Note( + is_drum=False, + program=100, + onset=float(onset), + offset=float(offset), + pitch=int(pitch), + velocity=1)) + notes = sort_notes(notes) + notes = validate_notes(notes) + notes = trim_overlapping_notes(notes) + note_events = note2note_event(notes) + + # Write midi file + note_event2midi(note_events, midi_file) + print(f"Created {midi_file}") + + return { # notes + 'mir_st500_id': mir_st500_id, + 'program': SINGING_ONLY_PROGRAM, + 'is_drum': [0, 0], + 'duration_sec': note_events[-1].time, + 'notes': notes, + }, { # note_events + 'mir_st500_id': mir_st500_id, + 'program': SINGING_ONLY_PROGRAM, + 'is_drum': [0, 0], + 'duration_sec': note_events[-1].time, + 'note_events': note_events, + } + + +def correct_ann(ann_all: Dict, fix_offset: bool = False, max_dur: float = 0.5): + """ correct too short notes that are actully sung in legato """ + for i in range(1, 101): + for j, v in enumerate(ann_all[str(i)]): + dur = v[1] - v[0] + if dur < 0.01: + next_onset = ann_all[str(i)][j + 1][0] + dist_to_next_onset = next_onset - v[1] + if fix_offset is True: + if dist_to_next_onset < max_dur: + # correct the offset + ann_all[str(i)][j][1] = next_onset + print(f'Corrected track {i}: {v} to {ann_all[str(i)][j]}') + else: + print(v, ann_all[str(i)][j + 1], f'dist_to_next_onset: {dist_to_next_onset}') + + +def preprocess_mir_st500_16k(data_home=os.PathLike, + dataset_name='mir_st500', + apply_correction=False, + sanity_check=False) -> None: + """ + Splits: + 'train', + 'train_vocal', + 'train_stem', + 'test', + 'test_vocal', + 'all', + 'all_vocal', + 'all_stem' + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'mir_st500_id': mir_st500_id, + 'n_frames': (int), + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', + 'program': List[int], 100 for singing voice, and 129 for unannotated + 'is_drum': List[int], # [0] or [1] + } + } + """ + + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Load annotation json file as dictionary + ann_file = os.path.join(base_dir, 'MIR-ST500_20210206', 'MIR-ST500_corrected.json') + with open(ann_file, 'r') as f: + ann_all = json.load(f) # index "1" to "500" + + # Correction for annotation + correct_ann(ann_all, fix_offset=apply_correction, max_dur=0.5) + + # Check missing audio files and create a dictionary + audio_all = 
{} # except for missing files + audio_missing = {'train': [], 'test': []} + for i in range(1, 501): + split = 'train' if i < 401 else 'test' + audio_file = os.path.join(base_dir, f'{split}', f'{i}', 'converted_Mixture.wav') + audio_vocal_file = os.path.join(base_dir, f'{split}', f'{i}', 'vocals.wav') + audio_acc_file = os.path.join(base_dir, f'{split}', f'{i}', 'accompaniment.wav') + if check_file_existence(audio_file) and check_file_existence( + audio_vocal_file) and check_file_existence(audio_acc_file): + audio_all[str(i)] = audio_file + else: + audio_missing[split].append(i) + print( + f'Number of missing audio files: train = {len(audio_missing["train"])}, test = {len(audio_missing["test"])}' + ) + assert len(audio_all.keys()) == 500 + + # Track ids + ids_all = audio_all.keys() + ids_train = [] + ids_test = [] + for i in ids_all: + if int(i) < 401: + ids_train.append(i) + else: + ids_test.append(i) + # assert len(ids_train) == 346 and len(ids_test) == 94 + assert len(ids_train) == 400 and len(ids_test) == 100 + + # Create notes, note_events, and MIDI from annotation + for id in ids_all: + ann = ann_all[id] + split = 'train' if int(id) < 401 else 'test' + midi_file = os.path.join(base_dir, f'{split}', id, 'singing.mid') + notes, note_events = create_note_note_event_midi_from_mir_st500_annotation( + ann, midi_file, id) + + notes_file = midi_file.replace('.mid', '_notes.npy') + note_events_file = midi_file.replace('.mid', '_note_events.npy') + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f"Created {notes_file}") + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + print(f"Created {note_events_file}") + + if sanity_check: + # sanity check + print(f'Sanity check for {id}...') + note_event2token2note_event_sanity_check(note_events['note_events'], notes['notes']) + + # Process audio files + for id in ids_all: + split = 'train' if int(id) < 401 else 'test' + audio_vocal_file = os.path.join(base_dir, f'{split}', id, 'vocals.wav') + audio_acc_file = os.path.join(base_dir, f'{split}', id, 'accompaniment.wav') + stem_file = os.path.join(base_dir, f'{split}', id, 'stem.npy') + stem_content = create_spleeter_audio_stem(audio_vocal_file, audio_acc_file, id) + # write audio stem + np.save(stem_file, stem_content, allow_pickle=True, fix_imports=False) + print(f"Created {stem_file}") + + # Create file_list.json + ids_by_split = { + 'train': ids_train, + 'train_vocal': ids_train, + 'train_stem': ids_train, + 'test': ids_test, + 'test_vocal': ids_test, + 'all': ids_all, + 'all_vocal': ids_all, + 'all_stem': ids_all + } + + for split in [ + 'train', 'train_vocal', 'train_stem', 'test', 'test_vocal', 'all', 'all_vocal', + 'all_stem' + ]: + file_list = {} + for i, id in enumerate(ids_by_split[split]): + wav_file = audio_all[id] + n_frames = get_audio_file_info(wav_file)[1] + if 'vocal' in split: + stem_file = None + wav_file = wav_file.replace('converted_Mixture.wav', 'vocals.wav') + program = SINGING_ONLY_PROGRAM + is_drum = [0] + elif 'stem' in split: + stem_file = wav_file.replace('converted_Mixture.wav', 'stem.npy') + program = SINGING_WITH_UNANNOTATED_PROGRAM + is_drum = [0, 0] + else: + stem_file = None + program = SINGING_WITH_UNANNOTATED_PROGRAM + is_drum = [0, 0] + + mid_file = os.path.join(os.path.dirname(wav_file), 'singing.mid') + file_list[i] = { + 'mir_st500_id': id, + 'n_frames': n_frames, + 'stem_file': stem_file, + 'mix_audio_file': wav_file, + 'notes_file': mid_file.replace('.mid', '_notes.npy'), + 'note_events_file': 
mid_file.replace('.mid', '_note_events.npy'), + 'midi_file': mid_file, + 'program': program, + 'is_drum': is_drum, + } + if stem_file is None: + del file_list[i]['stem_file'] + + output_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') + with open(output_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_file}') \ No newline at end of file diff --git a/amt/src/utils/preprocess/preprocess_musicnet.py b/amt/src/utils/preprocess/preprocess_musicnet.py new file mode 100644 index 0000000000000000000000000000000000000000..3490fa2df31932b7df88fe50a06b2c6f0fc3e568 --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_musicnet.py @@ -0,0 +1,486 @@ +"""preprocess_musicnet.py""" +import os +import glob +import csv +import json +from typing import Dict, List, Tuple +import numpy as np +from utils.audio import get_audio_file_info +from utils.midi import midi2note +from utils.note2event import note2note_event +from utils.note_event_dataclasses import Note + +# yapf: disable +MUSICNET_SPLIT_INFO = { + 'train_mt3': [], # the first 300 songs are synth dataset, while the remaining 300 songs are acoustic dataset. + 'train_mt3_synth' : [], # Note: this is not the synthetic dataset of EM (MIDI Pop 80K) nor pitch-augmented. Just recording of MusicNet MIDI, split by MT3 author's split. But not sure if they used this (maybe not). + 'train_mt3_acoustic': [], + 'validation_mt3': [1733, 1765, 1790, 1818, 2160, 2198, 2289, 2300, 2308, 2315, 2336, 2466, 2477, 2504, 2611], + 'validation_mt3_synth': [1733, 1765, 1790, 1818, 2160, 2198, 2289, 2300, 2308, 2315, 2336, 2466, 2477, 2504, 2611], + 'validation_mt3_acoustic': [1733, 1765, 1790, 1818, 2160, 2198, 2289, 2300, 2308, 2315, 2336, 2466, 2477, 2504, 2611], + 'test_mt3_acoustic': [1729, 1776, 1813, 1893, 2118, 2186, 2296, 2431, 2432, 2487, 2497, 2501, 2507, 2537, 2621], + 'train_thickstun': [], # the first 320 songs are synth dataset, while the remaining 320 songs are acoustic dataset. + 'test_thickstun': [1819, 2303, 2382], + 'test_thickstun_em': [1819, 2303, 2382], + 'test_thickstun_ext': [1759, 1819, 2106, 2191, 2298, 2303, 2382, 2416, 2556, 2628], + 'test_thickstun_ext_em': [1759, 1819, 2106, 2191, 2298, 2303, 2382, 2416, 2556, 2628], + 'train_mt3_em': [], # 300 synth + 293 tracks for MT3 acoustic train set - 7 EM tracks are missing: [2194, 2211, 2227, 2230, 2292, 2305, 2310]. + 'train_thickstun_em': [], # 320 synth + 313 tracks for Thickstun acoustic train set - 7 EM tracks are missing. + 'validation_mt3_em': [1733, 1765, 1790, 1818, 2160, 2198, 2289, 2300, 2308, 2315, 2336, 2466, 2477, 2504, 2611], # ours + 'test_mt3_em': [1729, 1776, 1813, 1893, 2118, 2186, 2296, 2431, 2432, 2487, 2497, 2501, 2507, 2537, 2621], # ours + 'test_em_table2' : [2191, 2628, 2106, 2298, 1819, 2416], # strings and winds from Cheuk's split, using EM annotations + 'test_cheuk_table2' : [2191, 2628, 2106, 2298, 1819, 2416], # strings and winds from Cheuk's split, using Thickstun's annotations + 'test_thickstun_ext_em': [1759, 1819, 2106, 2191, 2298, 2303, 2382, 2416, 2556, 2628], +} +# Table 4 of EM is not included here. + +# yapf: enable +MUSICNET_DISCARD_INFO = ['test_labels_midi/1759.mid', + 'test_labels_midi/1819.mid'] # duplicated midi files +MUSICNET_EM_MISSING_IDS = set(['2194', '2211', '2227', '2230', '2292', '2305', '2310']) + +MUSICNET_FS = 44100 + + +def create_note_event_and_note_from_label(label_file: str, id: str): + """Extracts note or note_event and metadata from a label file. 
+ + Returns: + notes (dict): note events and metadata. + note_events (dict): note events and metadata. + """ + program_numbers = set() + notes = [] + with open(label_file, 'r', newline='', encoding='utf-8') as c: + csv_reader = csv.reader(c) + for i, row in enumerate(csv_reader): + if i == 0: + continue + start_frame, end_frame, program, pitch, _, _, _ = row + new_note = Note( + is_drum=False, + program=int(program), + onset=float(start_frame) / MUSICNET_FS, + offset=float(end_frame) / MUSICNET_FS, + pitch=int(pitch), + velocity=1) + notes.append(new_note) + program_numbers.add(int(program)) + program_numbers = list(program_numbers) + + return { # notes + 'musicnet_id': id, + 'program': program_numbers, + 'is_drum': [0]*len(program_numbers), + 'duration_sec': notes[0].offset, + 'notes': notes, + }, { # note_events + 'musicnet_id': id, + 'program': program_numbers, + 'is_drum': [0]*len(program_numbers), + 'duration_sec': notes[0].offset, + 'note_events': note2note_event(notes), + } + + +def create_note_event_and_note_from_midi(mid_file: str, id: str) -> Tuple[Dict, Dict]: + """Extracts note or note_event and metadata from midi: + + Returns: + notes (dict): note events and metadata. + note_events (dict): note events and metadata. + """ + notes, dur_sec = midi2note( + mid_file, + binary_velocity=True, + ch_9_as_drum=False, + force_all_drum=False, + force_all_program_to=None, + trim_overlap=True, + fix_offset=True, + quantize=True, + verbose=0, + minimum_offset_sec=0.01, + drum_offset_sec=0.01) + return { # notes + 'musicnet_id': id, + 'program': [], + 'is_drum': [], + 'duration_sec': dur_sec, + 'notes': notes, + }, { # note_events + 'musicnet_id': id, + 'program': [], + 'is_drum': [], + 'duration_sec': dur_sec, + 'note_events': note2note_event(notes), + } + + +def preprocess_musicnet16k(data_home=os.PathLike, dataset_name='musicnet') -> None: + """ + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'musicnet_id': musicnet_id, + 'n_frames': (int), + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', + 'program': List[int], + 'is_drum': List[int], # 0 or 1 + } + } + """ + + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Search for files with .mid and .wav (synth / acoustic) extensions + label_pattern = os.path.join(base_dir, '*_labels', '*.csv') + mid_em_pattern = os.path.join(base_dir, '*_em', + '*.mid') # EM annotations for real performances (wav) + mid_pattern = os.path.join(base_dir, '*_midi', '*.mid') + wav_synth_pattern = os.path.join(base_dir, '*_synth', '*.wav') + wav_acoustic_pattern = os.path.join(base_dir, '*_data', '*.wav') + + label_files = glob.glob(label_pattern, recursive=True) + mid_em_files = glob.glob(mid_em_pattern, recursive=True) # 323 files, not 330! 
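    # 330 - 7 = 323: the tracks without EM annotations are the seven ids in MUSICNET_EM_MISSING_IDS
    # defined at the top of this file; their 'mid_em_file' entry stays None below, and they are
    # subtracted from the *_em training splits when split_dict is completed.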
+ mid_files = glob.glob(mid_pattern, recursive=True) + wav_synth_files = glob.glob(wav_synth_pattern, recursive=True) + wav_acoustic_files = glob.glob(wav_acoustic_pattern, recursive=True) + + # Discard duplicated files + for file in MUSICNET_DISCARD_INFO: + mid_files.remove(os.path.join(base_dir, file)) + assert (len(mid_files) == len(label_files) == len(wav_synth_files) == len(wav_acoustic_files) == + 330) + + # Sort files by id + musicnet_ids = [] + for label_file in label_files: + musicnet_ids.append(os.path.basename(label_file).split('.')[0]) + musicnet_ids.sort() + assert (len(musicnet_ids) == 330) + + musicnet_em_ids = [] + for mid_em_file in mid_em_files: + musicnet_em_ids.append(os.path.basename(mid_em_file).split('.')[0]) + assert (len(musicnet_em_ids) == 323) + + def search_file_by_musicnet_id(musicnet_id, files): + file_found = [f for f in files if musicnet_id in f + ] # this only works in 4-digits file names of MusicNet + assert (len(file_found) == 1) + return file_found[0] + + # yapf: disable + musicnet_dict = {} + for i in musicnet_ids: + musicnet_dict[i] = { + 'wav_acoustic_file': search_file_by_musicnet_id(i, wav_acoustic_files), + 'wav_synth_file': search_file_by_musicnet_id(i, wav_synth_files), + 'mid_file': search_file_by_musicnet_id(i, mid_files), + 'mid_em_file': search_file_by_musicnet_id(i, mid_em_files) if i in musicnet_em_ids else None, + 'label_file': search_file_by_musicnet_id(i, label_files), + 'program': [], + 'is_drum': [], + 'duration_sec': 0., + 'notes_file_acoustic': '', + 'note_events_file_acoustic': '', + 'notes_file_synth': '', + 'note_events_file_synth': '', + 'notes_file_em': '', + 'note_events_file_em': '', + } + # yapf: enable + + # Process label files + for i in musicnet_ids: + notes, note_events = create_note_event_and_note_from_label( + label_file=musicnet_dict[i]['label_file'], id=i) + + notes_file = os.path.join(musicnet_dict[i]['label_file'][:-4] + '_notes.npy') + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f'Created {notes_file}') + + note_events_file = os.path.join(musicnet_dict[i]['label_file'][:-4] + '_note_events.npy') + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + print(f'Created {note_events_file}') + + # update musicnet_dict + musicnet_dict[i]['program'] = notes['program'] + musicnet_dict[i]['is_drum'] = notes['is_drum'] + musicnet_dict[i]['duration_sec'] = notes['duration_sec'] + musicnet_dict[i]['notes_file_acoustic'] = notes_file + musicnet_dict[i]['note_events_file_acoustic'] = note_events_file + + # Process MIDI files + for i in musicnet_ids: + # musicnet + notes, note_events = create_note_event_and_note_from_midi( + mid_file=musicnet_dict[i]['mid_file'], id=i) + notes['program'] = musicnet_dict[i]['program'].copy() + notes['is_drum'] = musicnet_dict[i]['is_drum'].copy() + notes_file = os.path.join(musicnet_dict[i]['mid_file'][:-4] + '_notes.npy') + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f'Created {notes_file}') + + note_events['program'] = musicnet_dict[i]['program'].copy() + note_events['is_drum'] = musicnet_dict[i]['is_drum'].copy() + note_events_file = os.path.join(musicnet_dict[i]['mid_file'][:-4] + '_note_events.npy') + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + print(f'Created {note_events_file}') + + # update musicnet_dict + musicnet_dict[i]['duration_sec'] = max(notes['duration_sec'], + musicnet_dict[i]['duration_sec']) + musicnet_dict[i]['notes_file_synth'] = notes_file + 
musicnet_dict[i]['note_events_file_synth'] = note_events_file + + # musicnet_em + if i in musicnet_em_ids: + notes, note_events = create_note_event_and_note_from_midi( + mid_file=musicnet_dict[i]['mid_em_file'], id=i) + notes['program'] = musicnet_dict[i]['program'].copy() + notes['is_drum'] = musicnet_dict[i]['is_drum'].copy() + notes_file = os.path.join(musicnet_dict[i]['mid_em_file'][:-4] + '_notes.npy') + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f'Created {notes_file}') + + note_events['program'] = musicnet_dict[i]['program'].copy() + note_events['is_drum'] = musicnet_dict[i]['is_drum'].copy() + note_events_file = os.path.join(musicnet_dict[i]['mid_em_file'][:-4] + + '_note_events.npy') + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + print(f'Created {note_events_file}') + + # update musicnet_dict: use the longest duration + musicnet_dict[i]['duration_sec'] = max(notes['duration_sec'], + musicnet_dict[i]['duration_sec']) + musicnet_dict[i]['notes_file_em'] = notes_file + musicnet_dict[i]['note_events_file_em'] = note_events_file + + # Process audio files + pass + + # Complete split dictionary + split_dict = MUSICNET_SPLIT_INFO.copy() + + # Convert each list in the dictionary to a list of strings + for key in split_dict: + split_dict[key] = [str(item) for item in split_dict[key]] + + # Convert each list to a sorted tuple of strings to preserve the original order + for key in split_dict: + split_dict[key] = tuple(sorted(split_dict[key])) + + # Create sets and subtract sets to create new sets + whole_set = set(musicnet_ids) + split_dict['train_mt3'] = whole_set - set(split_dict['validation_mt3']) - set( + split_dict['test_mt3_acoustic']) + split_dict['train_mt3_synth'] = split_dict['train_mt3'] + split_dict['train_mt3_acoustic'] = split_dict['train_mt3'] + split_dict['train_thickstun'] = whole_set - set(split_dict['test_thickstun_ext']) + split_dict['train_thickstun_synth'] = split_dict['train_thickstun'] + split_dict['train_mt3_em'] = whole_set - set(split_dict['validation_mt3']) - set( + split_dict['test_mt3_acoustic']) - MUSICNET_EM_MISSING_IDS + split_dict['train_thickstun_em'] = whole_set - set( + split_dict['test_thickstun_ext']) - MUSICNET_EM_MISSING_IDS + # Convert each tuple back to a list of strings + for key in split_dict: + split_dict[key] = [str(item) for item in split_dict[key]] + + # Write MT3 file_list + for split in ('train_mt3_synth', 'validation_mt3_synth'): + file_list = {} + for i, musicnet_id in enumerate(split_dict[split]): + file_list[i] = { + 'musicnet_id': musicnet_id, + 'n_frames': get_audio_file_info(musicnet_dict[musicnet_id]['wav_synth_file'])[1], + 'mix_audio_file': musicnet_dict[musicnet_id]['wav_synth_file'], + 'notes_file': musicnet_dict[musicnet_id]['notes_file_synth'], + 'note_events_file': musicnet_dict[musicnet_id]['note_events_file_synth'], + 'midi_file': musicnet_dict[musicnet_id]['mid_file'], + 'program': musicnet_dict[musicnet_id]['program'], + 'is_drum': musicnet_dict[musicnet_id]['is_drum'], + } + assert (len(file_list) == len(split_dict[split])) + output_index_file = os.path.join(output_index_dir, f'musicnet_{split}_file_list.json') + with open(output_index_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_index_file}') + + for split in ('train_mt3_acoustic', 'validation_mt3_acoustic', 'test_mt3_acoustic'): + file_list = {} + for i, musicnet_id in enumerate(split_dict[split]): + file_list[i] = { + 'musicnet_id': musicnet_id, + 'n_frames': 
get_audio_file_info(musicnet_dict[musicnet_id]['wav_acoustic_file'])[1], + 'mix_audio_file': musicnet_dict[musicnet_id]['wav_acoustic_file'], + 'notes_file': musicnet_dict[musicnet_id]['notes_file_acoustic'], + 'note_events_file': musicnet_dict[musicnet_id]['note_events_file_acoustic'], + 'midi_file': musicnet_dict[musicnet_id]['mid_file'], + 'program': musicnet_dict[musicnet_id]['program'], + 'is_drum': musicnet_dict[musicnet_id]['is_drum'], + } + assert (len(file_list) == len(split_dict[split])) + output_index_file = os.path.join(output_index_dir, f'musicnet_{split}_file_list.json') + with open(output_index_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_index_file}') + + split = 'train_mt3' + merged_file_list = {} + index = 0 + file_list_train_mt3_synth = json.load( + open(os.path.join(output_index_dir, 'musicnet_train_mt3_synth_file_list.json'))) + file_list_train_mt3_acoustic = json.load( + open(os.path.join(output_index_dir, 'musicnet_train_mt3_acoustic_file_list.json'))) + for d in [file_list_train_mt3_synth, file_list_train_mt3_acoustic]: + for key, value in d.items(): + new_key = f'{index}' + merged_file_list[new_key] = value + index += 1 + assert (len(merged_file_list) == 600) + output_index_file = os.path.join(output_index_dir, f'musicnet_{split}_file_list.json') + with open(output_index_file, 'w') as f: + json.dump(merged_file_list, f, indent=4) + print(f'Created {output_index_file}') + + # Write ThickStun file_list + split = 'train_thickstun' + file_list = {} + for i, musicnet_id in enumerate(split_dict[split]): + file_list[i] = { + 'musicnet_id': musicnet_id, + 'n_frames': get_audio_file_info(musicnet_dict[musicnet_id]['wav_synth_file'])[1], + 'mix_audio_file': musicnet_dict[musicnet_id]['wav_synth_file'], + 'notes_file': musicnet_dict[musicnet_id]['notes_file_synth'], + 'note_events_file': musicnet_dict[musicnet_id]['note_events_file_synth'], + 'midi_file': musicnet_dict[musicnet_id]['mid_file'], + 'program': musicnet_dict[musicnet_id]['program'], + 'is_drum': musicnet_dict[musicnet_id]['is_drum'], + } + file_list[i + 327] = { + 'musicnet_id': musicnet_id, + 'n_frames': get_audio_file_info(musicnet_dict[musicnet_id]['wav_acoustic_file'])[1], + 'mix_audio_file': musicnet_dict[musicnet_id]['wav_acoustic_file'], + 'notes_file': musicnet_dict[musicnet_id]['notes_file_acoustic'], + 'note_events_file': musicnet_dict[musicnet_id]['note_events_file_acoustic'], + 'midi_file': musicnet_dict[musicnet_id]['mid_file'], + 'program': musicnet_dict[musicnet_id]['program'], + 'is_drum': musicnet_dict[musicnet_id]['is_drum'], + } + assert (len(file_list) == len(split_dict[split]) * 2) + output_index_file = os.path.join(output_index_dir, f'musicnet_{split}_file_list.json') + with open(output_index_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_index_file}') + + for split in ('test_thickstun', 'test_thickstun_ext'): + file_list = {} + for i, musicnet_id in enumerate(split_dict[split]): + file_list[i] = { + 'musicnet_id': musicnet_id, + 'n_frames': get_audio_file_info(musicnet_dict[musicnet_id]['wav_acoustic_file'])[1], + 'mix_audio_file': musicnet_dict[musicnet_id]['wav_acoustic_file'], + 'notes_file': musicnet_dict[musicnet_id]['notes_file_acoustic'], + 'note_events_file': musicnet_dict[musicnet_id]['note_events_file_acoustic'], + 'midi_file': musicnet_dict[musicnet_id]['mid_file'], + 'program': musicnet_dict[musicnet_id]['program'], + 'is_drum': musicnet_dict[musicnet_id]['is_drum'], + } + assert (len(file_list) == 
len(split_dict[split])) + output_index_file = os.path.join(output_index_dir, f'musicnet_{split}_file_list.json') + with open(output_index_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_index_file}') + + # Write EM file_list + for split in ('train_thickstun_em', 'train_mt3_em'): + file_list = {} + for i, musicnet_id in enumerate(split_dict[split]): + file_list[i] = { + 'musicnet_id': musicnet_id, + 'n_frames': get_audio_file_info(musicnet_dict[musicnet_id]['wav_acoustic_file'])[1], + 'mix_audio_file': musicnet_dict[musicnet_id]['wav_acoustic_file'], + 'notes_file': musicnet_dict[musicnet_id]['notes_file_em'], + 'note_events_file': musicnet_dict[musicnet_id]['note_events_file_em'], + 'midi_file': musicnet_dict[musicnet_id]['mid_em_file'], + 'program': musicnet_dict[musicnet_id]['program'], + 'is_drum': musicnet_dict[musicnet_id]['is_drum'], + } + synth_ids = split_dict['train_mt3'] if split == 'train_mt3_em' else split_dict[ + 'train_thickstun'] + for i, musicnet_id in enumerate(synth_ids): + file_list[i + len(split_dict[split])] = { + 'musicnet_id': musicnet_id, + 'n_frames': get_audio_file_info(musicnet_dict[musicnet_id]['wav_synth_file'])[1], + 'mix_audio_file': musicnet_dict[musicnet_id]['wav_synth_file'], + 'notes_file': musicnet_dict[musicnet_id]['notes_file_synth'], + 'note_events_file': musicnet_dict[musicnet_id]['note_events_file_synth'], + 'midi_file': musicnet_dict[musicnet_id]['mid_file'], + 'program': musicnet_dict[musicnet_id]['program'], + 'is_drum': musicnet_dict[musicnet_id]['is_drum'], + } + if split == 'train_thickstun_em': + assert (len(file_list) == 320 + 313) + if split == 'train_mt3_em': + assert (len(file_list) == 300 + 293) + output_index_file = os.path.join(output_index_dir, f'musicnet_{split}_file_list.json') + with open(output_index_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_index_file}') + + for split in ('validation_mt3_em', 'test_mt3_em', 'test_em_table2', 'test_thickstun_em', + 'test_thickstun_ext_em'): + file_list = {} + for i, musicnet_id in enumerate(split_dict[split]): + file_list[i] = { + 'musicnet_id': musicnet_id, + 'n_frames': get_audio_file_info(musicnet_dict[musicnet_id]['wav_acoustic_file'])[1], + 'mix_audio_file': musicnet_dict[musicnet_id]['wav_acoustic_file'], + 'notes_file': musicnet_dict[musicnet_id]['notes_file_em'], + 'note_events_file': musicnet_dict[musicnet_id]['note_events_file_em'], + 'midi_file': musicnet_dict[musicnet_id]['mid_em_file'], + 'program': musicnet_dict[musicnet_id]['program'], + 'is_drum': musicnet_dict[musicnet_id]['is_drum'], + } + assert (len(file_list) == len(split_dict[split])) + output_index_file = os.path.join(output_index_dir, f'musicnet_{split}_file_list.json') + with open(output_index_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_index_file}') + + # Write Cheuk file_list + for split in ['test_cheuk_table2']: + file_list = {} + for i, musicnet_id in enumerate(split_dict[split]): + file_list[i] = { + 'musicnet_id': musicnet_id, + 'n_frames': get_audio_file_info(musicnet_dict[musicnet_id]['wav_acoustic_file'])[1], + 'mix_audio_file': musicnet_dict[musicnet_id]['wav_acoustic_file'], + 'notes_file': musicnet_dict[musicnet_id]['notes_file_acoustic'], + 'note_events_file': musicnet_dict[musicnet_id]['note_events_file_acoustic'], + 'midi_file': musicnet_dict[musicnet_id]['mid_file'], + 'program': musicnet_dict[musicnet_id]['program'], + 'is_drum': musicnet_dict[musicnet_id]['is_drum'], + } + assert (len(file_list) == 
len(split_dict[split])) + output_index_file = os.path.join(output_index_dir, f'musicnet_{split}_file_list.json') + with open(output_index_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_index_file}') + + +if __name__ == '__main__': + from config.config import shared_cfg + data_home = shared_cfg['PATH']['data_home'] + preprocess_musicnet16k(data_home=data_home, dataset_name='musicnet') \ No newline at end of file diff --git a/amt/src/utils/preprocess/preprocess_rnsynth.py b/amt/src/utils/preprocess/preprocess_rnsynth.py new file mode 100644 index 0000000000000000000000000000000000000000..641d03826217a23b48442b443555fdd6c2b8ac70 --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_rnsynth.py @@ -0,0 +1,332 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +"""preprocess_rnsynth.py + +RNSynth: Randomly generated note sequences using the NSynth dataset. + +""" +import os +import random +import glob +import json +import logging +import numpy as np +from typing import Dict, Literal, Optional +from utils.note_event_dataclasses import Note +from utils.audio import get_audio_file_info, load_audio_file, write_wav_file, guess_onset_offset_by_amp_envelope +from utils.midi import note_event2midi +from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes, mix_notes + +# yapf: disable +QUALITY_VOCAB = [ + 'bright', 'dark', 'distortion', 'fast_decay', 'long_release', 'multiphonic', 'nonlinear_env', + 'percussive', 'reverb', 'tempo-synced' +] + +INSTRUMENT_FAMILY_VOCAB = [ + 'bass', 'brass', 'flute', 'guitar', 'keyboard', 'mallet', 'organ', 'reed', 'string', 'vocal', + 'synth_lead' +] + +INSTRUMENT_SOURCE_VOCAB = ['acoustic', 'electronic', 'synthetic'] + +INSTRUMENT_MAPPING = { + # key: (instrument_family, instrument_source) + ('bass', 'acoustic'): {'program': 32, 'channel': 0, 'allow_poly': False,}, + ('bass', 'electronic'): {'program': 33, 'channel': 0, 'allow_poly': False,}, + ('bass', 'synthetic'): {'program': 38, 'channel': 0, 'allow_poly': False,}, + ('brass', 'acoustic'): {'program': 61, 'channel': 1, 'allow_poly': True,}, + ('brass', 'electronic'): {'program': 62, 'channel': 1, 'allow_poly': True,}, + ('brass', 'synthetic'): {'program': 62, 'channel': 1, 'allow_poly': True, }, + ('flute', 'acoustic'): {'program': 73, 'channel': 2, 'allow_poly': False,}, + ('flute', 'electronic'): {'program': 76, 'channel': 2, 'allow_poly': False,}, + ('flute', 'synthetic'): {'program': 76, 'channel': 2, 'allow_poly': False,}, + ('guitar', 'acoustic'): {'program': 24, 'channel': 3, 'allow_poly': True,}, + ('guitar', 'electronic'): {'program': 27, 'channel': 3, 'allow_poly': True,}, + ('guitar', 'synthetic'): {'program': 27, 'channel': 3, 'allow_poly': True,}, + ('keyboard', 'acoustic'): {'program': 0, 'channel': 4, 'allow_poly': True,}, + ('keyboard', 'electronic'): {'program': 4, 'channel': 4, 'allow_poly': True,}, + ('keyboard', 'synthetic'): {'program': 80, 'channel': 4, 'allow_poly': True,}, + ('mallet', 'acoustic'): {'program': 12, 'channel': 5, 'allow_poly': True,}, + ('mallet', 'electronic'): {'program': 12, 'channel': 5, 'allow_poly': True,}, + ('mallet', 'synthetic'): {'program': 12, 'channel': 5, 'allow_poly': True,}, + ('organ', 'acoustic'): {'program': 16, 
'channel': 6, 'allow_poly': True,}, + ('organ', 'electronic'): {'program': 18, 'channel': 6, 'allow_poly': True,}, + ('organ', 'synthetic'): {'program': 18, 'channel': 6, 'allow_poly': True,}, + ('reed', 'acoustic'): {'program': 65, 'channel': 7, 'allow_poly': True,}, + ('reed', 'electronic'): {'program': 83, 'channel': 7, 'allow_poly': True,}, + ('reed', 'synthetic'): {'program': 83, 'channel': 7, 'allow_poly': True,}, + ('string', 'acoustic'): {'program': 48, 'channel': 8, 'allow_poly': True,}, + ('string', 'electronic'): {'program': 50, 'channel': 8, 'allow_poly': True,}, + ('string', 'synthetic'): {'program': 50, 'channel': 8, 'allow_poly': True,}, + # ('vocal', 'acoustic'): [56], + # ('vocal', 'electronic'): [56], + # ('vocal', 'synthetic'): [56], + ('synth_lead', 'acoustic'): {'program': 80, 'channel': 9, 'allow_poly': True,}, + ('synth_lead', 'electronic'): {'program': 80, 'channel': 9, 'allow_poly': True,}, + ('synth_lead', 'synthetic'): {'program': 80, 'channel': 9, 'allow_poly': True,}, +} + + +CHANNEL_INFO = { + 0: {'name': 'bass', 'max_poly': 1}, + 1: {'name': 'brass', 'max_poly': 4}, + 2: {'name': 'flute', 'max_poly': 1}, + 3: {'name': 'guitar', 'max_poly': 6}, + 4: {'name': 'keyboard', 'max_poly': 8}, + 5: {'name': 'mallet', 'max_poly': 4}, + 6: {'name': 'organ', 'max_poly': 8}, + 7: {'name': 'reed', 'max_poly': 2}, + 8: {'name': 'string', 'max_poly': 4}, + 9: {'name': 'synth_lead', 'max_poly': 2}, +} +# yapf: enable + + +class RandomNSynthGenerator(object): + + def __init__(self, channel_info: Dict=CHANNEL_INFO): + self.num_channels = len(channel_info) + self.channel_info = channel_info + self.channel_max_poly = [channel_info[ch]['max_poly'] for ch in range(self.num_channels)] + + # channel_space_left[ch]: current state of empty space for notes left in channel + self.channel_space_left = [0] * self.num_channels + for ch in range(self.num_channels): + self.reset_space_left(ch) + + def reset_space_left(self, ch: int): + max_poly = self.channel_max_poly[ch] + if max_poly == 1: + self.channel_space_left[ch] = 1 + else: + self.channel_space_left[ch] = np.random.randint(1, max_poly + 1 ) + + + + +def setup_logger(log_file: str) -> logging.Logger: + logger = logging.getLogger('my_logger') + logger.setLevel(logging.DEBUG) + file_handler = logging.FileHandler(log_file) + formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s') + file_handler.setFormatter(formatter) + if not logger.handlers: + logger.addHandler(file_handler) + console_handler = logging.StreamHandler() + console_handler.setLevel(logging.DEBUG) + console_formatter = logging.Formatter('%(levelname)s - %(message)s') + console_handler.setFormatter(console_formatter) + logger.addHandler(console_handler) + return logger + + +def get_duration_by_detecting_offset(audio_file: os.PathLike, + side_info: Optional[str] = None, + offset_threshold: float = 0.02) -> float: + fs, n_frames, _ = get_audio_file_info(audio_file) + x = load_audio_file(audio_file, fs=fs) + if side_info is not None and 'fast_decay' in side_info or 'percussive' in side_info: + x = x[:int(fs * 2.0)] # limit to 1.5 sec + + _, offset, _ = guess_onset_offset_by_amp_envelope( + x, fs=fs, onset_threshold=0., offset_threshold=offset_threshold, frame_size=128) + offset = min(offset, n_frames) + dur_sec = np.floor((offset / fs) * 1000) / 1000 + return dur_sec + + +def random_key_cycle(d: Dict): + keys = list(d.keys()) + while True: + random.shuffle(keys) + for i, key in enumerate(keys): + is_last_element = (i == len(keys) - 1) # Check if it's the last 
element in the cycle + yield (d[key], is_last_element) + + +def create_sound_info(base_dir: os.PathLike, logger: logging.Logger, + split: Literal['train', 'validation', 'test'], metadata_file: os.PathLike): + """Create a dictionary of sound info from the metadata file.""" + with open(metadata_file, 'r') as f: + metadata = json.load(f) + logger.info(f'Loaded {metadata_file}. Number of examples: {len(metadata)}') + + # Create a sound_info dictionary + sound_info = {} # key: nsynth_id, value: dictionary of sound info + count_skipped = 0 + skipped_instrument_family = set() + for i, (k, v) in enumerate(metadata.items()): + if i % 5000 == 0: + print(f'Creating sound info {i} / {len(metadata)}') + nsynth_id = v['note'] + instrument_family = v['instrument_family_str'] + instrument_source = v['instrument_source_str'] + audio_file = os.path.join(base_dir, split, 'audio', k + '.wav') + if not os.path.exists(audio_file): + raise FileNotFoundError(audio_file) + dur_sec = get_duration_by_detecting_offset( + audio_file, side_info=v['qualities_str'], offset_threshold=0.001) + + if INSTRUMENT_MAPPING.get((instrument_family, instrument_source), None) is not None: + sound_info[nsynth_id] = { + 'audio_file': + audio_file, + 'program': + INSTRUMENT_MAPPING[instrument_family, instrument_source]['program'], + 'pitch': + int(v['pitch']), + 'velocity': + int(v['velocity']), + 'channel_group': + INSTRUMENT_MAPPING[instrument_family, instrument_source]['channel'], + 'dur_sec': + dur_sec, + } + else: + count_skipped += 1 + skipped_instrument_family.add(instrument_family) + logger.info(f'Created sound info. Number of examples: {len(sound_info)}') + logger.info(f'Number of skipped examples: {count_skipped}, {skipped_instrument_family}') + del metadata + + # Regroup sound_info by channel_group + sound_info_by_channel_group = {} # key: channel_group, value: list of sound_info + num_channel_groups = 10 + for i in range(num_channel_groups): + sound_info_by_channel_group[i] = {} + for nsynth_id, info in sound_info.items(): + channel_group = info['channel_group'] + sound_info_by_channel_group[channel_group][nsynth_id] = info + del sound_info + channel_group_counts = [ + (CHANNEL_INFO[k]['name'], len(v)) for k, v in sound_info_by_channel_group.items() + ] + logger.info('Count of sound_info in each channel_group: {}'.format(channel_group_counts)) + return sound_info_by_channel_group, num_channel_groups + + + + + + +def random_nsynth_generator(data_home: os.PathLike, + dataset_name: str = 'random_nsynth', + generation_minutes_per_file: float = 4.0) -> None: + """ + Splits: + 'train' + 'validation' + 'test' + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'random_nsynth_id': random_nsynth_id, # = nsynth_id + 'n_frames': (int), + 'stem_file': 'path/to/stem.npy', + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', # this is 120bpm converted midi file from note_events + 'program': List[int], + 'is_drum': List[int], # [0] or [1] + } + } + """ + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Setup logger + log_file = os.path.join(base_dir, 'sound_genetation_log.txt') + logger = setup_logger(log_file) + + # Load annotation json file as dictionary + split = 'validation' + metadata_file = 
os.path.join(base_dir, split, 'examples.json') + + # Create a sound_info dictionary + sound_info_by_channel_group, num_channel_groups = create_sound_info( + base_dir, logger, split, metadata_file) + + # Gnenerate random note sequences + max_frames_per_file = int(generation_minutes_per_file * 60 * 16000) + sound_gens = [ + random_key_cycle(sound_info_by_channel_group[key]) + for key in sorted(sound_info_by_channel_group.keys()) + ] + + # 5-minute audio generation + notes = [] + y = np.zeros((num_channel_groups, max_frames_per_file), dtype=np.float32) # (C, L) + bass_channel = 0 # loop for a cycle of bass channel generation + cur_frame = 0 + # is_last_element_bass = False + #while cur_frame < max_frames_per_file and is_last_element_bass == False: + + # x: source audio, y: target audio for each channel + x_info, is_last_element = next(sound_gens[ch]) + if ch == bass_channel: + is_last_element = is_last_element_bass + + # info about this channel + onset_in_frame = cur_frame + offset_in_frame = cur_frame + int(x_info['dur_sec'] * 16000) + + x = load_audio_file(x_info['audio_file'], fs=16000) + x = x[:int(x_info['dur_sec'] * 16000)] + y[ch, :] = 0 + + + + +def preprocess_random_nsynth_16k(data_home=os.PathLike, dataset_name='random_nsynth') -> None: + """ + Splits: + 'train' + 'validation' + 'test' + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'random_nsynth_id': random_nsynth_id, # = nsynth_id + 'n_frames': (int), + 'stem_file': 'path/to/stem.npy', + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', # this is 120bpm converted midi file from note_events + 'program': List[int], + 'is_drum': List[int], # [0] or [1] + } + } + """ + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Setup logger + log_file = os.path.join(base_dir, 'log.txt') + logger = setup_logger(log_file) + + # Load annotation json file as dictionary + split = 'validation' + metadata_file = os.path.join(base_dir, split, 'examples.json') + with open(metadata_file, 'r') as f: + metadata = json.load(f) + logger.info(f'Loaded {metadata_file}. 
Number of examples: {len(metadata)}') \ No newline at end of file diff --git a/amt/src/utils/preprocess/preprocess_rwc_pop.py b/amt/src/utils/preprocess/preprocess_rwc_pop.py new file mode 100644 index 0000000000000000000000000000000000000000..0cbac1a86bf35cf45fe2bab1edbdeba8562c468e --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_rwc_pop.py @@ -0,0 +1,155 @@ +"""preprocess_rwc_pop.py""" +import os +import json +import csv +from typing import Dict, List, Tuple +import numpy as np +from utils.audio import get_audio_file_info, load_audio_file +from utils.midi import midi2note, note_event2midi +from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes, extract_program_from_notes +from utils.event2note import event2note_event +from utils.note_event_dataclasses import Note, NoteEvent +from utils.utils import note_event2token2note_event_sanity_check +from mido import Message, MidiFile + +ID_NO_BASS = ['071', '072', '073', '074', '075', '076', '077', '078', '079', '080'] # 10 files + + +def check_file_existence(file: str) -> bool: + """Checks if file exists.""" + res = True + if not os.path.exists(file): + res = False + elif get_audio_file_info(file)[1] < 10 * 16000: + print(f'File {file} is too short.') + res = False + return res + + +def create_note_event_and_note_from_midi(mid_file: str, + id: str, + ignore_pedal: bool = True) -> Tuple[Dict, Dict]: + """Extracts note or note_event and metadata from midi: + + Returns: + notes (dict): note events and metadata. + note_events (dict): note events and metadata. + """ + notes, dur_sec, programs = midi2note( + mid_file, + binary_velocity=True, + ch_9_as_drum=True, + trim_overlap=True, + fix_offset=True, + quantize=True, + verbose=0, + minimum_offset_sec=0.01, + drum_offset_sec=0.01, + ignore_pedal=ignore_pedal, + return_programs=True) + + # Check drum availability + has_drum = False + for note in notes: + if note.is_drum: + has_drum = True + is_drum = [0] * len(programs) + if has_drum: + is_drum[9] = 1 + + return { # notes + 'rwc_pop_id': id, + 'program': programs, + 'is_drum': is_drum, + 'duration_sec': dur_sec, + 'notes': notes, + }, { # note_events + 'rwc_pop_id': id, + 'program': programs, + 'is_drum': is_drum, + 'duration_sec': dur_sec, + 'note_events': note2note_event(notes), + } + + +def preprocess_rwc_pop16k(data_home=os.PathLike, dataset_name='rwc_pop') -> None: + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Load CSV: construct id to midi/wav dictionary + csv_file = os.path.join(base_dir, 'wav_to_midi_filename_mapping.csv') + rwc_bass = {} + with open(csv_file, 'r') as f: + reader = csv.reader(f) + headers = next(reader) + + for row in reader: + id = row[2] + # Skip unused ids + # if id in UNUSED_IDS: + # continue + # if id in MULTI_BASS_IDS: + # continue + + mix_audio_file = os.path.join(base_dir, headers[0] + row[0], + row[1] + ' ' + headers[1] + '.wav') + assert check_file_existence(mix_audio_file) + # mid_file = os.path.join(base_dir, 'MIDI', id + '.mid') + mid_file = os.path.join(base_dir, 'MIDI-Bass-Octave-fixed-v2', id + '_bass.mid') + # assert os.path.exists(mid_file) + if not os.path.exists(mid_file): + print(mid_file, "does not exist") + continue + + notes_file = mid_file.replace('.mid', '_notes.npy') + note_events_file = mid_file.replace('.mid', '_note_events.npy') + + rwc_bass[id] = { + 'rwc_pop_id': id, + 'n_frames': 
get_audio_file_info(mix_audio_file)[1], + 'mix_audio_file': mix_audio_file, + 'notes_file': notes_file, + 'note_events_file': note_events_file, + 'midi_file': mid_file, + 'program': None, + 'is_drum': None, + } + assert len(rwc_bass) == 90 + + # Create note and note_event files + for id in rwc_bass.keys(): + midi_file = rwc_bass[id]['midi_file'] + notes_file = rwc_bass[id]['notes_file'] + note_events_file = rwc_bass[id]['note_events_file'] + + # Create note and note_event files + notes, note_events = create_note_event_and_note_from_midi(midi_file, id, ignore_pedal=True) + + # Update programs and is_drum + rwc_bass[id]['program'] = notes['program'] + rwc_bass[id]['is_drum'] = notes['is_drum'] + + # Save note and note_event files + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f'Created {notes_file}') + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + print(f'Created {note_events_file}') + + # saving bpm 120 midi files + bpm120_midi_file = midi_file.replace('.mid', '_bpm120.mid') + note_event2midi(note_events['note_events'], bpm120_midi_file) + print(f'Created {bpm120_midi_file}') + + # Save index file + split = 'bass' + output_index_file = os.path.join(output_index_dir, f'rwc_pop_{split}_file_list.json') + + file_list = {} + for i, id in enumerate(rwc_bass.keys()): + file_list[i] = rwc_bass[id] + + with open(output_index_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_index_file}') diff --git a/amt/src/utils/preprocess/preprocess_rwc_pop_full.py b/amt/src/utils/preprocess/preprocess_rwc_pop_full.py new file mode 100644 index 0000000000000000000000000000000000000000..0cdf585f00c11ccb7de2f7daa856f3004bbfdbff --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_rwc_pop_full.py @@ -0,0 +1,326 @@ +"""preprocess_rwc_pop.py""" +import os +import glob +import re +import json +import csv +from typing import Dict, List, Any, Tuple +import numpy as np +from utils.audio import get_audio_file_info, load_audio_file +from utils.midi import midi2note, note_event2midi +from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes, extract_program_from_notes +from utils.event2note import event2note_event +from utils.note_event_dataclasses import Note, NoteEvent +from utils.utils import note_event2token2note_event_sanity_check +from mido import MetaMessage, Message, MidiFile, MidiTrack + +# UNUSED_IDS = ["010", "071", "099", "023", "034", "036", "038", "049", "060", "062"] +# UNUSED_IDS = ["071", "099", "049", "060", "062"] +UNUSED_IDS = [] + +DRUM_CHANNEL = 9 # all drums are in channel 9 in geerdes dataset +DRUM_PROGRAM = 128 +SINGING_VOICE_PROGRAM = 100 +SINGING_VOICE_CHORUS_PROGRAM = 101 +TRACK_NAME_TO_PROGRAM_MAP = { # compared by exact match of lowercase + "Singing Voice": SINGING_VOICE_PROGRAM, + "Singing Voice (Chorus)": SINGING_VOICE_CHORUS_PROGRAM, + "Drums": DRUM_PROGRAM, +} + +# yapf: disable +TRACK_NAME_FILTERS = { + SINGING_VOICE_PROGRAM: {"include": ["MELO", "VOCAL"], "exclude": ["SUB", "GT"]}, + SINGING_VOICE_CHORUS_PROGRAM: {"include": ["CHORUS", "SUB VOCAL", "SUB MELO"], + "exclude": ["/", "GT"]}, + DRUM_PROGRAM: {"include": ["DRUMS", "DR", "HIHAT", "BD&SD", "TOM", "KICK"], + "exclude": ["ATOMOS"], "exact": ["DS"]}, + 0: {"include": ["P.F.", "PF", "PIANO", "A.P", "CLAV", "CEMBAL", "HARPSI"], "exclude": ["E.PIANO", "MARIMBA"]}, + 2: {"include": ["E.P"], "exclude": []}, + 8: {"include": ["GLOCKEN", "VIBRA", "VIBE", "MARIMBA", "BELL", "CHIME", "CHAIM", "KALIMB", "CHIMRE", 
"MALLET"], + "exclude": []}, + 16: {"include": ["ORG", "HAMO", "HARMONICA", "ACCORD"], "exclude": []}, + 24: {"include": ["MANDORIN", "AG", "NYLON", "AC.G", "GUITAR", "A.G", "E.G", "GT", "G. SOLO", "CLEAN LEAD", "SITAR", "ATOMOS", "ATMOS", + "CLEAN"], + "exclude": ["DIST", "DIS.", "D.", "E.G SOLO", "E.G.SOLO"]}, + 30: {"include": ["OD L", "OD R", "DIS.", "DIST GT", "D.G", "DIST", "DIS.SOLO", "E.GUITAR (SOLO)", "E.G SOLO", "LEAD", "E.G.SOLO", "EG", "GT MELO"], + "exclude": ["PAD","SYN.LEAD"]}, + 33: {"include": ["BASS"], "exclude": []}, + 48: {"include": ["OR 2", "ST", "STR", "ORCH", "PIZZ", "HIT", "TIMPANI", "VIORA", "VIOLA", "VIOLIN", "VN", "VA", "VC", "HARP", "LO FI", "CHO", "VLN", "CELLO"], + "exclude": ["CHORUS", "HARPSI", "STEEL", "GUITAR", "PAD", "BRASS", "GT", "HORN"], + "exact": ["OR"]}, + 56: {"include": ["BRAS", "TRUMP", "TP", "TB", "TROM", "HORN", "FLUGEL"], "exclude": []}, + 64: {"include": ["SAX", "OBOE", "BASS"], "exclude": ["SYNSAX"]}, + 72: {"include": ["FLUTE", "PICO", "BOTTLE", "GAYA"], "exclude": []}, + 80: {"include": ["S SOLO", "SYN SOLO", "SOLO SYNTH", "SYNTH SOLO", "SYN.LEAD", "SYNTH(SEQ)", "PORTASYN", "SQ", "SEQ", "VOICE"], "exclude": []}, + 88: {"include": ["SYNTH", "SYN", "PAD", "FANTASIA", "BRIGHTNESS", "FANTASY"], "exclude": ["SYNBELL", "PORTA", "SOLO", "SEQ", "LEAD", "ORGAN", "BRAS", "BASS", "TROM"]}, + None: {"include": ["INTRO SE", "WOW", "PERC", "EXC", "REVERSE", "GONG", "PER.", "RAP", "REV", "S.E", "LASER", + "LESER", "TAMBOURINE", "KANE", "PER", "SHAKER", "RWC-MDB"], + "exclude": [], + "exact": ["SE", "EX", "808", "ICERAIN"]}, + "USE RWC PROGRAM MAP": {"include": ["KIRA", "KILA", "ETHNIC&GK"], "exclude": [], "exact": ["FUE", "OU-01A"]}, +} +# yapf: enable +RWC_PROGRAM_MAP = { + 9: 8, + 11: 8, + 74: 72, + 94: 80, + 98: 88, + 100: 88, +} + +PRG2CH = { + 0: (0, "Acoustic Piano"), + 2: (1, "Electric Piano"), + 8: (2, "Chromatic Percussion"), + 16: (3, "Organ"), + 24: (4, "Guitar (clean)"), + 30: (5, "Guitar (distortion)"), + 33: (6, "Bass"), + 48: (7, "Strings"), + 56: (8, "Brass"), + DRUM_PROGRAM: (9, "Drums"), + 64: (10, "Reed"), + 72: (11, "Pipe"), + 80: (12, "Synth Lead"), + 88: (13, "Synth Pad"), + SINGING_VOICE_PROGRAM: (14, "Singing Voice"), + SINGING_VOICE_CHORUS_PROGRAM: (15, "Singing Voice (Chorus)"), +} + + +def find_matching_filters(input_text, filters): + input_text = input_text.upper() + + def text_matches_filter(text, filter_dict): + matchness = False + if "exact" in filter_dict: + for keyword in filter_dict["exact"]: + if keyword == text: + matchness = True + break + for keyword in filter_dict["include"]: + if keyword in text: + matchness = True + break + for keyword in filter_dict["exclude"]: + if keyword in text: + matchness = False + break + return matchness + + matching_filters = [] + for filter_name, filter_dict in filters.items(): + if text_matches_filter(input_text, filter_dict): + matching_filters.append(filter_name) + return matching_filters + + +def generate_corrected_midi(org_mid_file: os.PathLike, + new_mid_file: os.PathLike, + filters: Dict[Any, Dict[str, List]], + prg2ch: Dict[int, Tuple[int, str]]): + # Load original MIDI file + org_mid = MidiFile(org_mid_file) + + # Create a new MIDI file + new_mid = MidiFile(ticks_per_beat=org_mid.ticks_per_beat) + + # Extract global messages from the first track (usually the master track) + global_messages = [msg for msg in org_mid.tracks[0] if msg.is_meta] + global_track = MidiTrack(global_messages) + new_mid.tracks.append(global_track) + + # Loop over all tracks + for track in 
org_mid.tracks[1:]: + # Get track name + track_name = None + for msg in track: + if msg.type == 'track_name': + track_name = msg.name + break + if track_name is None: + raise ValueError('track name not found in midi file') + + # Get program number from track name + matching_filters = find_matching_filters(track_name, filters) + assert (len(matching_filters) != 0) + if isinstance(matching_filters[0], int): + program = matching_filters[0] + elif matching_filters[0] == "USE RWC PROGRAM MAP": + for msg in track: + if msg.type == 'program_change': + program = RWC_PROGRAM_MAP.get(msg.program, msg.program) + break + elif matching_filters[0] == None: + continue + + # Get channel and new track name + ch, new_track_name = prg2ch[program] + + # Copy messages to new track with new program, channel, and track_name + new_track = MidiTrack() + new_track.append(MetaMessage('track_name', name=new_track_name, + time=0)) + if program == DRUM_PROGRAM: + new_track.append( + Message('program_change', program=0, time=0, channel=9)) + else: + new_track.append( + Message('program_change', program=program, time=0, channel=ch)) + new_mid.tracks.append(new_track) + + for msg in track: + if msg.type in ['track_name', 'instrument_name', 'program_change']: + continue + else: + new_msg = msg.copy() + if hasattr(msg, 'channel'): + new_msg.channel = ch + new_track.append(new_msg) + + # Save new MIDI file + new_mid.save(new_mid_file) + print(f'Created {new_mid_file}') + + +def check_file_existence(file: str) -> bool: + """Checks if file exists.""" + res = True + if not os.path.exists(file): + res = False + elif get_audio_file_info(file)[1] < 10 * 16000: + print(f'File {file} is too short.') + res = False + return res + + +def create_note_event_and_note_from_midi( + mid_file: str, + id: str, + ch_9_as_drum: bool = False, + track_name_to_program: Dict = None, + ignore_pedal: bool = False) -> Tuple[Dict, Dict]: + """Create note_events and notes from midi file.""" + + # Load midi file + notes, dur_sec, program = midi2note( + mid_file, + ch_9_as_drum=ch_9_as_drum, + track_name_to_program=track_name_to_program, + binary_velocity=True, + ignore_pedal=ignore_pedal, + return_programs=True) + program = [x for x in set(program) + if x is not None] # remove None and duplicates + return { # notes + 'rwc_pop_id': id, + 'program': program, + 'is_drum': [1 if p == DRUM_PROGRAM else 0 for p in program], + 'duration_sec': dur_sec, + 'notes': notes, + }, { # note_events + 'rwc_pop_id': id, + 'program': program, + 'is_drum': [1 if p == DRUM_PROGRAM else 0 for p in program], + 'duration_sec': dur_sec, + 'note_events': note2note_event(notes), + } + + +def preprocess_rwc_pop_full16k(data_home='../../data', + dataset_name='rwc_pop') -> None: + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Load CSV: construct id to midi/wav dictionary + csv_file = os.path.join(base_dir, 'wav_to_midi_filename_mapping.csv') + rwc_all = {} + with open(csv_file, 'r') as f: + reader = csv.reader(f) + headers = next(reader) + + for row in reader: + id = row[2] + mix_audio_file = os.path.join(base_dir, headers[0] + row[0], + row[1] + ' ' + headers[1] + '.wav') + assert check_file_existence(mix_audio_file) + mid_file = os.path.join(base_dir, 'MIDI', id + '.mid') + assert os.path.exists(mid_file) + notes_file = mid_file.replace('.mid', '_notes.npy') + note_events_file = mid_file.replace('.mid', 
'_note_events.npy') + + rwc_all[id] = { + 'rwc_pop_id': id, + 'n_frames': get_audio_file_info(mix_audio_file)[1], + 'mix_audio_file': mix_audio_file, + 'notes_file': notes_file, + 'note_events_file': note_events_file, + 'midi_file': mid_file, + 'program': None, + 'is_drum': None, + } + assert len(rwc_all) == 100 + + # Generate corrected MIDI files by reassigning program numbers + os.makedirs(os.path.join(base_dir, 'MIDI_full_corrected'), exist_ok=True) + for id, info in rwc_all.items(): + org_mid_file = info['midi_file'] + new_mid_file = org_mid_file.replace('/MIDI/', '/MIDI_full_corrected/') + generate_corrected_midi(org_mid_file, + new_mid_file, + filters=TRACK_NAME_FILTERS, + prg2ch=PRG2CH) + # Update file path with corrected MIDI file + rwc_all[id]['midi_file'] = new_mid_file + rwc_all[id]['notes_file'] = new_mid_file.replace('.mid', '_notes.npy') + rwc_all[id]['note_events_file'] = new_mid_file.replace( + '.mid', '_note_events.npy') + + # Unused ids + for id in UNUSED_IDS: + rwc_all.pop(str(id)) + print(f'Number of used IDs: {len(rwc_all)}, Unused ids: {UNUSED_IDS}') + + # Create note and note_event files + for id in rwc_all.keys(): + midi_file = rwc_all[id]['midi_file'] + notes_file = rwc_all[id]['notes_file'] + note_events_file = rwc_all[id]['note_events_file'] + + # Create note and note_event files + notes, note_events = create_note_event_and_note_from_midi( + midi_file, + id, + ch_9_as_drum=False, # we will use track_name_to_program instead + track_name_to_program=TRACK_NAME_TO_PROGRAM_MAP, + ignore_pedal=False) + + # Update programs and is_drum + rwc_all[id]['program'] = notes['program'] + rwc_all[id]['is_drum'] = notes['is_drum'] + + # Save note and note_event files + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f'Created {notes_file}') + np.save(note_events_file, + note_events, + allow_pickle=True, + fix_imports=False) + print(f'Created {note_events_file}') + + # Save index file + split = 'full' + output_index_file = os.path.join(output_index_dir, + f'rwc_pop_{split}_file_list.json') + + file_list = {} + for i, id in enumerate(rwc_all.keys()): + file_list[i] = rwc_all[id] + + with open(output_index_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'Created {output_index_file}') diff --git a/amt/src/utils/preprocess/preprocess_slakh.py b/amt/src/utils/preprocess/preprocess_slakh.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa3a037d7a8478e3464d0fdf0fd669138dc8984 --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_slakh.py @@ -0,0 +1,239 @@ +""" preprocess_mtrack_slakh.py + +""" +import os +import time +import json +from typing import Dict, List, Tuple +import numpy as np +from utils.audio import get_audio_file_info, load_audio_file +from utils.midi import midi2note +from utils.note2event import note2note_event, mix_notes +import mirdata +from utils.mirdata_dev.datasets import slakh16k + + +def create_audio_stem_from_mtrack(ds: mirdata.core.Dataset, + mtrack_id: str, + delete_source_files: bool = False) -> Dict: + """Extracts audio stems and metadata from a multitrack.""" + mtrack = ds.multitrack(mtrack_id) + track_ids = mtrack.track_ids + max_length = 0 + program_numbers = [] + is_drum = [] + audio_tracks = [] # multi-channel audio array (C, T) + + # collect all the audio tracks and their metadata + for track_id in track_ids: + track = ds.track(track_id) + audio_file = track.audio_path + program_numbers.append(track.program_number) + is_drum.append(1) if track.is_drum else is_drum.append(0) + + fs, n_frames, 
n_channels = get_audio_file_info(audio_file) + assert (fs == 16000 and n_channels == 1) + max_length = n_frames if n_frames > max_length else max_length + audio = load_audio_file(audio_file, dtype=np.int16) # returns bytes + audio = audio / 2**15 + audio = audio.astype(np.float16) + audio_tracks.append(audio) + if delete_source_files: + print(f'🗑️ Deleting {audio_file} ...') + os.remove(audio_file) + + # collate all the audio tracks into a single array + n_tracks = len(track_ids) + audio_array = np.zeros((n_tracks, max_length), dtype=np.float16) + for j, audio in enumerate(audio_tracks): + audio_array[j, :len(audio)] = audio + + stem_content = { + 'mtrack_id': mtrack_id, # str + 'program': np.array(program_numbers, dtype=np.int64), + 'is_drum': np.array(is_drum, dtype=np.int64), + 'n_frames': max_length, # int + 'audio_array': audio_array # (n_tracks, n_frames) + } + return stem_content + + +def create_note_event_and_note_from_mtrack_mirdata( + ds: mirdata.core.Dataset, + mtrack_id: str, + fix_bass_octave: bool = True) -> Tuple[Dict, Dict]: + """Extracts note or note_event and metadata from a multitrack: + Args: + ds (mirdata.core.Dataset): Slakh dataset. + mtrack_id (str): multitrack id. + Returns: + notes (dict): note events and metadata. + note_events (dict): note events and metadata. + """ + mtrack = ds.multitrack(mtrack_id) + track_ids = mtrack.track_ids + program_numbers = [] + is_drum = [] + mixed_notes = [] + duration_sec = 0. + + # mix notes from all stem midi files + for track_id in track_ids: + track = ds.track(track_id) + stem_midi_file = track.midi_path + notes, dur_sec = midi2note( + stem_midi_file, + binary_velocity=True, + ch_9_as_drum=False, # checked safe to set to False in Slakh + force_all_drum=True if track.is_drum else False, + force_all_program_to=None, # Slakh always has program number + trim_overlap=True, + fix_offset=True, + quantize=True, + verbose=0, + minimum_offset_sec=0.01, + drum_offset_sec=0.01) + + if fix_bass_octave == True and track.program_number in np.arange(32, 40): + if track.plugin_name == 'scarbee_jay_bass_slap_both.nkm': + pass + else: + for note in notes: + note.pitch -= 12 + print("Fixed bass octave for track", track_id) + + mixed_notes = mix_notes((mixed_notes, notes), True, True, True) + program_numbers.append(track.program_number) + is_drum.append(1) if track.is_drum else is_drum.append(0) + duration_sec = max(duration_sec, dur_sec) + + # convert mixed notes to note events + mixed_note_events = note2note_event(mixed_notes, sort=True, return_activity=True) + return { # notes + 'mtrack_id': mtrack_id, # str + 'program': np.array(program_numbers, dtype=np.int64), # (n,) + 'is_drum': np.array(is_drum, dtype=np.int64), # (n,) with 1 is drum + 'duration_sec': duration_sec, # float + 'notes': mixed_notes # list of Note instances + }, { # note_events + 'mtrack_id': mtrack_id, # str + 'program': np.array(program_numbers, dtype=np.int64), # (n,) + 'is_drum': np.array(is_drum, dtype=np.int64), # (n,) with 1 is drum + 'duration_sec': duration_sec, # float + 'note_events': mixed_note_events # list of NoteEvent instances + } + + +def preprocess_slakh16k(data_home: str, + run_checksum: bool = False, + delete_source_files: bool = False, + fix_bass_octave: bool = True) -> None: + """ + Processes the Slakh dataset and extracts stems for each multitrack. + + Args: + data_home (str): path to the Slakh data. + run_checksum (bool): if True, validates the dataset using its checksum. Default is False. 
+ delete_source_files (bool): if True, deletes original audio files. Default is False. + fix_bass_octave (bool): if True, fixes the bass to be -1 octave. Slakh bass is annotated as +1 octave. Default is True. + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + 'mtrack_id': mtrack_id, + 'n_frames': n of audio frames + 'stem_file': Dict of stem audio file info + 'mix_audio_file': mtrack.mix_path, + 'notes_file': available only for 'validation' and 'test' + 'note_events_file': available only for 'train' and 'validation' + 'midi_file': mtrack.midi_path + } + """ + start_time = time.time() + + ds = slakh16k.Dataset(data_home=data_home, version='2100-yourmt3-16k') + if run_checksum: + print('Checksum for slakh dataset...') + ds.validate() + print('Preprocessing slakh dataset...') + + mtrack_split_dict = ds.get_mtrack_splits() + for split in ['train', 'validation', 'test']: + file_list = {} # write a file list for each split + mtrack_ids = mtrack_split_dict[split] + + for i, mtrack_id in enumerate(mtrack_ids): + print(f'🏃🏻‍♂️: processing {mtrack_id} ({i+1}/{len(mtrack_ids)} in {split})') + mtrack = ds.multitrack(mtrack_id) + output_dir = os.path.dirname(mtrack.mix_path) # same as mtrack + """Audio: get stems (as array) and metadata from the multitrack""" + stem_content = create_audio_stem_from_mtrack(ds, mtrack_id, delete_source_files) + + # save the audio array and metadata to disk + stem_file = os.path.join(output_dir, mtrack_id + '_stem.npy') + np.save(stem_file, stem_content) + print(f'💿 Created {stem_file}') + + # no preprocessing for mix audio + """MIDI: pre-process and get metadata from the multitrack""" + notes, note_events = create_note_event_and_note_from_mtrack_mirdata( + ds, mtrack_id, fix_bass_octave=fix_bass_octave) + # save the note events and metadata to disk + notes_file = os.path.join(output_dir, mtrack_id + '_notes.npy') + np.save(notes_file, notes, allow_pickle=True, \ + fix_imports=False) + print(f'🎹 Created {notes_file}') + + note_events_file = os.path.join(output_dir, mtrack_id + '_note_events.npy') + np.save(note_events_file, note_events, allow_pickle=True, \ + fix_imports=False) + print(f'🎹 Created {note_events_file}') + + # add to the file list of the split + file_list[i] = { + 'mtrack_id': mtrack_id, + 'n_frames': stem_content['n_frames'], # n of audio frames + 'stem_file': stem_file, + 'mix_audio_file': mtrack.mix_path, + 'notes_file': notes_file, + 'note_events_file': note_events_file,\ + 'midi_file': mtrack.midi_path + } + # By split, save a file list as json + summary_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(summary_dir, exist_ok=True) + summary_file = os.path.join(summary_dir, f'slakh_{split}_file_list.json') + with open(summary_file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'💾 Created {summary_file}') + + elapsed_time = time.time() - start_time + print( + f"⏰: {int(elapsed_time // 3600):02d}h {int(elapsed_time % 3600 // 60):02d}m {elapsed_time % 60:.2f}s" + ) + """ end of preprocess_slakh16k """ + + +def add_program_and_is_drum_info_to_file_list(data_home: str): + + for split in ['train', 'validation', 'test']: + file_list_dir = os.path.join(data_home, 'yourmt3_indexes') + file = os.path.join(file_list_dir, f'slakh_{split}_file_list.json') + with open(file, 'r') as f: + file_list = json.load(f) + + for v in file_list.values(): + stem_file = v['stem_file'] + stem_content = np.load(stem_file, allow_pickle=True).item() + v['program'] = stem_content['program'].tolist() + v['is_drum'] = 
stem_content['is_drum'].tolist() + + with open(file, 'w') as f: + json.dump(file_list, f, indent=4) + print(f'💾 Added program and drum info to {file}') + + +if __name__ == '__main__': + from config.config import shared_cfg + data_home = shared_cfg['PATH']['data_home'] + preprocess_slakh16k(data_home=data_home, delete_source_files=False) \ No newline at end of file diff --git a/amt/src/utils/preprocess/preprocess_urmp.py b/amt/src/utils/preprocess/preprocess_urmp.py new file mode 100644 index 0000000000000000000000000000000000000000..eeca6a3e71bcec730a1015e52686fef6f2af6bf3 --- /dev/null +++ b/amt/src/utils/preprocess/preprocess_urmp.py @@ -0,0 +1,264 @@ +"""preprocess_mir1k.py""" +import os +import shutil +import glob +import re +import json +from typing import Dict, List, Tuple +import numpy as np +from utils.audio import get_audio_file_info, load_audio_file +from utils.midi import midi2note, note_event2midi +from utils.note2event import note2note_event, mix_notes, sort_notes, validate_notes, trim_overlapping_notes +from utils.event2note import event2note_event +from utils.note_event_dataclasses import Note, NoteEvent +from utils.utils import note_event2token2note_event_sanity_check, freq_to_midi + +MT3_TEST_IDS = [1, 2, 12, 13, 24, 25, 31, 38, 39] +PROGRAM_STR2NUM = { + 'vn': 40, + 'va': 41, + 'vc': 42, + 'db': 43, + 'fl': 73, + 'ob': 68, + 'cl': 71, + 'sax': 65, # The type of sax used in the dataset is not clear. We guess it would be alto sax. + 'bn': 70, + 'tpt': 56, + 'hn': 60, # Just annotated as horn. We guess it would be french horn, due to the pitch range. + 'tbn': 57, + 'tba': 58, +} + + +def delete_hidden_files(base_dir): + for hidden_file in glob.glob(os.path.join(base_dir, '**/.*'), recursive=True): + os.remove(hidden_file) + print(f"Deleted: {hidden_file}") + + +def convert_annotation_to_notes(id, program, ann_files): + notes = [] + for ann_file, prog in zip(ann_files, program): + data = np.loadtxt(ann_file) + onset = data[:, 0] + freq = data[:, 1] + duration = data[:, 2] + + notes_by_instr = [] + for o, f, d in zip(onset, freq, duration): + notes_by_instr.append( + Note( + is_drum=False, + program=prog, + onset=o, + offset=o + d, + pitch=freq_to_midi(f), + velocity=1)) + notes = mix_notes([notes, notes_by_instr], sort=True, trim_overlap=True, fix_offset=True) + notes = sort_notes(notes) + note_events = note2note_event(notes, sort=True) + duration_sec = note_events[-1].time + 0.01 + return { # notes + 'urmp_id': id, + 'program': program, + 'is_drum': [0] * len(program), + 'duration_sec': duration_sec, + 'notes': notes, + }, { # note_events + 'guitarset_id': id, + 'program': program, + 'is_drum': [0] * len(program), + 'duration_sec': duration_sec, + 'note_events': note_events, + } + + +def create_audio_stem(audio_tracks, id, program, n_frames): + max_length = max([len(tr) for tr in audio_tracks]) + max_length = max(max_length, n_frames) + n_tracks = len(audio_tracks) + audio_array = np.zeros((n_tracks, max_length), dtype=np.float16) + for j, audio in enumerate(audio_tracks): + audio_array[j, :len(audio)] = audio + + return { + 'urmp_id': id, + 'program': np.array(program), + 'is_drum': np.array([0] * len(program), dtype=np.int64), + 'n_frames': n_frames, # int + 'audio_array': audio_array # (n_tracks, n_frames) + } + + +def data_bug_fix(base_dir): + files = glob.glob(os.path.join(base_dir, '15_Surprise_tpt_tpt_tbn', '*3_tpt*.*')) + for file in files: + new_file = file.replace('3_tpt', '3_tbn') + shutil.move(file, new_file) + print(f"Renamed: {file} -> {new_file}") + + +def 
preprocess_urmp16k(data_home=os.PathLike, + dataset_name='urmp', + delete_source_files: bool = False, + sanity_check=True) -> None: + """ + URMP dataset does not have official split information. We follow the split used in MT3 paper. + + About: + - 44 pieces of classical music + - Duet, Trio, Quartet, Quintet of strings or winds or mixed + - Multi-stem audio + - MIDI file is unaligned, it is for score + - Annotation (10ms hop) is provided. + - There is timing issue for annotation + - We do not use video + + Splits: + - train: 35 files, following MT3 + - test: 9 files, follwing MT3 + - all: 44 files + + Writes: + - {dataset_name}_{split}_file_list.json: a dictionary with the following keys: + { + index: + { + 'urmp_id': urmp_id, + 'n_frames': (int), + 'stem_file': 'path/to/stem.npy', + 'mix_audio_file': 'path/to/mix.wav', + 'notes_file': 'path/to/notes.npy', + 'note_events_file': 'path/to/note_events.npy', + 'midi_file': 'path/to/midi.mid', # this is 120bpm converted midi file from note_events + 'program': List[int], # + 'is_drum': List[int], # [0] or [1] + } + } + """ + + # Directory and file paths + base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k') + output_index_dir = os.path.join(data_home, 'yourmt3_indexes') + os.makedirs(output_index_dir, exist_ok=True) + + # Databug fix + data_bug_fix(base_dir) + + # Delete hidden files + delete_hidden_files(base_dir) + + # Create file list for split==all + file_list = dict() + for dir_name in sorted(os.listdir(base_dir)): + if dir_name.startswith('.'): + continue + if 'Supplementary' in dir_name: + continue + # urmp_id + id = dir_name.split('_')[0] + title = dir_name.split('_')[1] + + # program + program_strings = dir_name.split('_')[2:] + program = [PROGRAM_STR2NUM[p] for p in program_strings] + + # is_drum + is_drum = [0] * len(program) + + # file paths + stem_file = os.path.join(base_dir, dir_name, 'stem.npy') + mix_audio_file = glob.glob(os.path.join(base_dir, dir_name, 'AuMix*.wav'))[0] + notes_file = os.path.join(base_dir, dir_name, 'notes.npy') + note_events_file = os.path.join(base_dir, dir_name, 'note_events.npy') + midi_file = os.path.join(base_dir, dir_name, f'{str(id)}_120bpm_converted.mid') + + # n_frames + fs, n_frames, n_channels = get_audio_file_info(mix_audio_file) + assert fs == 16000 and n_channels == 1 + + # Fill out a file list + file_list[id] = { + 'urmp_id': id, + 'n_frames': n_frames, + 'stem_file': stem_file, + 'mix_audio_file': mix_audio_file, + 'notes_file': notes_file, + 'note_events_file': note_events_file, + 'midi_file': midi_file, + 'program': program, + 'is_drum': is_drum, + } + + # Process Annotations + ann_files = [ + os.path.join(base_dir, dir_name, f'Notes_{i+1}_{p}_{str(id)}_{title}.txt') + for i, p in enumerate(program_strings) + ] + + # Check if all files exist + for ann_file in ann_files: + assert os.path.exists(ann_file), f"{ann_file} does not exist." 
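+        # Each stem has exactly one 'Notes_{idx}_{instr}_{id}_{title}.txt' annotation;
+        # its (onset, frequency, duration) rows are converted into Note objects by
+        # convert_annotation_to_notes() above, so the counts must agree: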
+ assert len(program) == len(ann_files) + + # Create and save notes and note_events from annotation + notes, note_events = convert_annotation_to_notes(id, program, ann_files) + np.save(notes_file, notes, allow_pickle=True, fix_imports=False) + print(f'Created {notes_file}') + np.save(note_events_file, note_events, allow_pickle=True, fix_imports=False) + print(f'Created {note_events_file}') + + # Create 120bpm MIDI file from note_events + note_event2midi(note_events['note_events'], midi_file) + print(f'Created {midi_file}') + + # Process Audio + audio_tracks = [] + for i, p in enumerate(program_strings): + audio_sep_file = os.path.join(base_dir, dir_name, f'AuSep_{i+1}_{p}_{id}_{title}.wav') + audio_track = load_audio_file(audio_sep_file, dtype=np.int16) / 2**15 # returns bytes + audio_tracks.append(audio_track.astype(np.float16)) + if delete_source_files: + os.remove(audio_sep_file) + + stem_content = create_audio_stem(audio_tracks, id, program, n_frames) + np.save(stem_file, stem_content, allow_pickle=True, fix_imports=False) + print(f'Created {stem_file}') + + # Sanity check + if sanity_check: + recon_notes, _ = midi2note(midi_file) + recon_note_events = note2note_event(recon_notes) + note_event2token2note_event_sanity_check(recon_note_events, notes['notes']) + + # File existence check + assert os.path.exists(mix_audio_file) + + # Create index for splits + file_list_all = {} + for i, key in enumerate(file_list.keys()): + file_list_all[i] = file_list[key] + + file_list_train = {} + i = 0 + for key in file_list.keys(): + if int(key) not in MT3_TEST_IDS: + file_list_train[i] = file_list[key] + i += 1 + + file_list_test = {} + i = 0 + for key in file_list.keys(): + if int(key) in MT3_TEST_IDS: + file_list_test[i] = file_list[key] + i += 1 + + all_fl = {'all': file_list_all, 'train': file_list_train, 'test': file_list_test} + + # Save index + for split in ['all', 'train', 'test']: + output_index_file = os.path.join(output_index_dir, f'{dataset_name}_{split}_file_list.json') + with open(output_index_file, 'w') as f: + json.dump(all_fl[split], f, indent=4) + print(f'Created {output_index_file}') \ No newline at end of file diff --git a/amt/src/utils/task_manager.py b/amt/src/utils/task_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..5911b621a244039127028083b917d2444a7a9303 --- /dev/null +++ b/amt/src/utils/task_manager.py @@ -0,0 +1,398 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. 
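+# Usage sketch (illustrative only; `note_event_segments` is assumed to be a NoteEventListsBundle
+# prepared by the data pipeline, which is not shown here):
+#
+#   tm = TaskManager(task_name="mt3_full_plus")
+#   token_array = tm.tokenize_note_events_batch(note_event_segments)  # np.ndarray of shape (B, C, L)
+#   note_events, tie_note_events, last_activity, err_cnt = tm.detokenize(token_array[0, 0].tolist())
+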
+import numpy as np
+from typing import Optional, Union, Tuple, Dict, Any, List, Counter
+from utils.note_event_dataclasses import NoteEvent, Event, NoteEventListsBundle
+from config.task import task_cfg
+from config.config import model_cfg
+from utils.tokenizer import NoteEventTokenizer
+from utils.utils import create_program2channel_vocab
+from utils.note2event import separate_channel_by_program_group_from_note_event_lists_bundle
+
+SINGING_PROGRAM = 100
+DRUM_PROGRAM = 128
+UNANNOTATED_PROGRAM = 129
+
+# import random
+# class RandomProgramSampler:
+#     def __init__(self, program_vocab: Dict[str, int], max_n: int = 7):
+#         for key, values in program_vocab.items():
+#             for value in values:
+#                 self.inverse_vocab_program[value] = values[0]
+#         self.max_n = max_n
+#         self.shuffled_
+
+#     def sample(self):
+
+#     def shuffle_and_repeat_randomly(lst, max_n=5):
+#         shuffled = lst.copy()
+#         random.shuffle(shuffled)
+#         index = 0
+
+#     while True:
+#         if index >= len(shuffled):  # once every element of the list has been used, reshuffle
+#             random.shuffle(shuffled)
+#             index = 0
+
+#         n = random.randint(1, max_n)  # pick a random count between 1 and max_n
+#         end_index = index + n
+
+#         if end_index > len(shuffled):  # if it would run past the end of the list, return only up to the end
+#             yield shuffled[index:]
+#             index = len(shuffled)
+#         else:
+#             yield shuffled[index:end_index]
+#             index = end_index
+
+
+class TaskManager:
+    """
+    The TaskManager class manages tasks for training. It is initialized with a task name and retrieves
+    the corresponding configuration from the task_cfg dictionary defined in config/task.py.
+
+    Attributes:
+        # Basic
+        task_name (str): The name of the task being managed.
+        base_codec (str): The base codec associated with the task.
+        train_program_vocab (dict): The program vocabulary used for training.
+        train_drum_vocab (dict): The drum vocabulary used for training.
+        subtask_tokens (list): Additional tokens specific to subtasks, if any.
+        extra_tokens (list): Extra tokens used in the task, including subtask tokens.
+        ignore_decoding_tokens (list): Tokens to ignore during decoding.
+        ignore_decoding_tokens_from_and_to (Optional, list[str, str]): Delimiter pair; tokens between them are ignored during decoding. Default is None.
+        tokenizer (NoteEventTokenizer): An instance of the NoteEventTokenizer class for tokenizing note events.
+        eval_subtask_prefix (dict): A dictionary mapping evaluation subtask prefixes to tokens.
+
+        # Multi-channel decoding task exclusive
+        num_decoding_channels (int): The number of decoding channels.
+        max_note_token_length_per_ch (int): The maximum note token length per channel.
+        mask_loss_strategy (str): The mask loss strategy to use. NOT IMPLEMENTED YET.
+        program2channel_vocab (dict): A dictionary mapping programs to channels.
+
+    Methods:
+        get_tokenizer(): Returns the tokenizer instance associated with the task.
+        set_tokenizer(): Initializes the tokenizer using the NoteEventTokenizer class with the appropriate parameters.
+    """
+
+    def __init__(self, task_name: str = "mt3_full_plus", max_shift_steps: int = 206, debug_mode: bool = False):
+        """
+        Initializes a TaskManager object with the specified task name.
+
+        Args:
+            task_name (str): The name of the task to manage.
+            max_shift_steps (int): The maximum shift steps for the tokenizer. Default is 206. Definable in config/config.py.
+            debug_mode (bool): Whether to enable debug mode. Default is False.
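+
+        Raises:
+            ValueError: If `task_name` is not a key of task_cfg defined in config/task.py.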
+ """ + self.debug_mode = debug_mode + self.task_name = task_name + + if task_name not in task_cfg.keys(): + raise ValueError("Invalid task name") + else: + self.task = task_cfg[task_name] + + # Basic task parameters + self.base_codec = self.task.get("base_codec", "mt3") + self.train_program_vocab = self.task["train_program_vocab"] + self.train_drum_vocab = self.task["train_drum_vocab"] + self.subtask_tokens = self.task.get("subtask_tokens", []) + self.extra_tokens = self.subtask_tokens + self.task.get("extra_tokens", []) + self.ignore_decoding_tokens = self.task.get("ignore_decoding_tokens", []) + self.ignore_decoding_tokens_from_and_to = self.task.get("ignore_decoding_tokens_from_and_to", None) + self.max_note_token_length = self.task.get("max_note_token_length", model_cfg["event_length"]) + self.max_task_token_length = self.task.get("max_task_token_length", 0) + self.padding_task_token = self.task.get("padding_task_token", False) + self._eval_subtask_prefix = self.task.get("eval_subtask_prefix", None) + self.eval_subtask_prefix_dict = {} + + # Multi-channel decoding exclusive parameters + self.num_decoding_channels = self.task.get("num_decoding_channels", 1) + if self.num_decoding_channels > 1: + program2channel_vocab_source = self.task.get("program2channel_vocab_source", None) + if program2channel_vocab_source is None: + program2channel_vocab_source = self.train_program_vocab + + # Create an inverse mapping of program to channel + if self.num_decoding_channels == len(program2channel_vocab_source) + 1: + self.program2channel_vocab, _ = create_program2channel_vocab(program2channel_vocab_source) + else: + raise ValueError("Invalid num_decoding_channels, or program2channel_vocab not provided") + + self.max_note_token_length_per_ch = self.task.get("max_note_token_length_per_ch") + self.mask_loss_strategy = self.task.get("mask_loss_strategy", None) # Not implemented yet + else: + self.max_note_token_length_per_ch = self.max_note_token_length + + # Define max_total_token_length + self.max_total_token_length = self.max_note_token_length_per_ch + self.max_task_token_length + + # Max shift steps for the tokenizer + self.max_shift_steps = max_shift_steps + + # Initialize a tokenizer + self.set_tokenizer() + self.set_eval_task_prefix() + self.num_tokens = self.tokenizer.num_tokens + self.inverse_vocab_program = self.tokenizer.codec.inverse_vocab_program + + def set_eval_task_prefix(self) -> None: + """ + Sets the evaluation task prefix for the task. + + Example: + self.eval_task_prefix_dict = { + "default": [Event("transcribe_all", 0), Event("task", 0)], + "singing-only": [Event("transcribe_singing", 0), Event("task", 0)] + } + """ + if self._eval_subtask_prefix is not None: + assert "default" in self._eval_subtask_prefix.keys() + for key, val in self._eval_subtask_prefix.items(): + if self.padding_task_token: + self.eval_subtask_prefix_dict[key] = self.tokenizer.encode_task( + val, max_length=self.max_task_token_length) + else: + self.eval_subtask_prefix_dict[key] = self.tokenizer.encode_task(val) + else: + self.eval_subtask_prefix_dict["default"] = [] + + def get_eval_subtask_prefix_dict(self) -> dict: + return self.eval_subtask_prefix_dict + + def get_tokenizer(self) -> NoteEventTokenizer: + """ + Returns the tokenizer instance associated with the task. + + Returns: + NoteEventTokenizer: The tokenizer instance. + """ + return self.tokenizer + + def set_tokenizer(self) -> None: + """ + Initializes the tokenizer using the NoteEventTokenizer class with the appropriate parameters. 
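+
+        The tokenizer vocabulary is built from train_program_vocab, train_drum_vocab, extra_tokens,
+        max_shift_steps and the special tokens ['PAD', 'EOS', 'UNK'].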
+ """ + self.tokenizer = NoteEventTokenizer(base_codec=self.base_codec, + max_length=self.max_total_token_length, + program_vocabulary=self.train_program_vocab, + drum_vocabulary=self.train_drum_vocab, + special_tokens=['PAD', 'EOS', 'UNK'], + extra_tokens=self.extra_tokens, + max_shift_steps=self.max_shift_steps, + ignore_decoding_tokens=self.ignore_decoding_tokens, + ignore_decoding_tokens_from_and_to=self.ignore_decoding_tokens_from_and_to, + debug_mode=self.debug_mode) + + # Newly implemented for exclusive transcription task + def tokenize_task_and_note_events_batch( + self, + programs_segments: List[List[int]], + has_unannotated_segments: List[bool], + note_event_segments: NoteEventListsBundle, + subunit_programs_segments: Optional[List[List[np.ndarray]]] = None, # TODO + subunit_note_event_segments: Optional[List[NoteEventListsBundle]] = None, # TODO + stage: str = 'train' # 'train' or 'eval' + ): + """Tokenizes a batch of note events into a batch of encoded tokens. + Optionally, appends task tokens to the note event tokens. + + Args: + programs_segments (List[int]): A list of program numbers. + has_unannotated_segments (bool): Whether the batch has unannotated segments. + note_event_segments (NoteEventListsBundle): A bundle of note events. + subunit_programs_segments (Optional[List[List[np.ndarray]]]): A list of subunit programs. + subunit_note_event_segments (Optional[List[NoteEventListsBundle]]): A list of subunit note events. + + Returns: + np.ndarray: A batch of encoded tokens, with shape (B, C, L). + """ + if self.task_name == 'exclusive': + # batch_sz = len(programs_segments) + # token_array = np.zeros((batch_sz, self.num_decoding_channels, self.max_note_token_length_per_ch), + # dtype=np.int32) + + # for programs, has_unannotated, note_events, tie_note_events, start_times in zip( + # programs_segments, has_unannotated_segments, note_event_segments['note_events'], + # note_event_segments['tie_note_events'], note_event_segments['start_times']): + # if has_unannotated: + # annotated_programs = [p for p in programs if p != UNANNOTATED_PROGRAM] + # note_token_array = self.tokenizer.encode_plus(note_events, + # tie_note_events, + # start_times, + # pad_to_max_length=False) # will append EOS token + # task_token_array = self.tokenizer.encode_task(task_events) + # else: + # annotated_programs = programs + + # task_events = [Event('transcribe_all', 0), Event('task', 0)] + # note_token_array = self.tokenize_note_events_batch(note_events) + # task_token_array = self.tokenize_task_events(annotated_programs, has_unannotated) + # return [] + raise NotImplementedError("Exclusive transcription task is not implemented yet.") + else: + # Default task: single or multi-channel decoding, without appending task tokens + return self.tokenize_note_events_batch(note_event_segments) # (B, C, L) + # Exclusive transcription task + # if has_unannotated_segments: + # annotated_programs = [p for p in programs_segments if p != UNANNOTATED_PROGRAM] + # else: + # annotated_programs = programs_segments + + # # Main task: transcribe all + # main_task_events = self.task.get("eval_subtask_prefix") + + def tokenize_note_events_batch(self, + note_event_segments: NoteEventListsBundle, + start_time_to_zero: bool = False, + sort: bool = True) -> np.ndarray: + """Tokenizes a batch of note events into a batch of encoded tokens. + + Args: + note_event_segments (NoteEventListsBundle): A bundle of note events. + + Returns: + np.ndarray: A batch of encoded tokens, with shape (B, C, L). 
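+            Here, B is the batch size, C is num_decoding_channels, and L is max_note_token_length_per_ch.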
+ """ + batch_sz = len(note_event_segments["note_events"]) + note_token_array = np.zeros((batch_sz, self.num_decoding_channels, self.max_note_token_length_per_ch), + dtype=np.int32) + + if self.num_decoding_channels == 1: + # Single-channel decoding task + zipped_events = list(zip(*note_event_segments.values())) + for b in range(batch_sz): + note_token_array[b, 0, :] = self.tokenizer.encode_plus(*zipped_events[b], + max_length=self.max_note_token_length, + pad_to_max_length=True) + elif self.num_decoding_channels > 1: + # Multi-channel decoding task + ch_sep_ne_bundle = separate_channel_by_program_group_from_note_event_lists_bundle( + source_note_event_lists_bundle=note_event_segments, + num_program_groups=self.num_decoding_channels, + program2channel_vocab=self.program2channel_vocab, + start_time_to_zero=start_time_to_zero, + sort=sort) # (batch_sz,) + + for b in range(batch_sz): + zipped_channel = list(zip(*ch_sep_ne_bundle[b].values())) + for c in range(self.num_decoding_channels): + note_token_array[b, c, :] = self.tokenizer.encode_plus(*zipped_channel[c], + max_length=self.max_note_token_length_per_ch, + pad_to_max_length=True) + return note_token_array # (B, C, L) + + def tokenize_note_events(self, + note_events: List[NoteEvent], + tie_note_events: Optional[List[NoteEvent]] = None, + start_time: float = 0., + **kwargs: Any) -> List[int]: + """(Deprecated) Tokenizes a sequence of note events into a sequence of encoded tokens.""" + return self.tokenizer.encode_plus(note_events, tie_note_events, start_time, **kwargs) + + +# # This will be deprecated, currently used by datasets_eval.py + +# def tokenize_task_events_batch(self, programs_segments: List[int], +# has_unannotated_segments: List[bool]) -> List[int]: +# """Tokenizes batch of task tokens from annotation info. + +# Args: +# programs_segments (List[int]): A list of program numbers. +# has_unannotated_segments (bool): Whether the batch has unannotated segments. + +# Returns: +# np.ndarray: Shape (B, C, L). + +# """ +# batch_sz = len(programs_segments) +# task_token_array = np.zeros((batch_sz, self.num_decoding_channels, self.max_task_token_length), dtype=np.int32) + +# if self.max_task_token_length == 0: +# return task_token_array + +# if self.num_decoding_channels == 1: +# for b in range(batch_sz): +# task_token_array[b, 0, :] = self.tokenize_task_events(programs_segments[b], has_unannotated_segments[b]) +# elif self.num_decoding_channels > 1: +# for b in range(batch_sz): +# task_token_array[b, :, :] = self.tokenize_task_events(programs_segments[b], has_unannotated_segments[b]) +# return task_token_array # (B, C, L) + + def tokenize_task_events(self, programs: List[int], has_unannotated: bool) -> List[int]: + """Tokenizes a sequence of programs into a sequence of encoded tokens. 
Used for training.""" + if self.task_name == 'singing_drum_v1': + if has_unannotated: + if SINGING_PROGRAM in programs: + task_events = [Event('transcribe_singing', 0), Event('task', 0)] + elif DRUM_PROGRAM in programs: + task_events = [Event('transcribe_drum', 0), Event('task', 0)] + else: + task_events = [Event('transcribe_all', 0), Event('task', 0)] + else: + return [] + + if self.padding_task_token: + return self.tokenizer.encode_task(task_events, max_length=self.max_task_token_length) + else: + return self.tokenizer.encode_task(task_events) + + def detokenize( + self, + tokens: List[int], + start_time: float = 0., + return_events: bool = False + ) -> Union[Tuple[List[NoteEvent], List[NoteEvent]], Tuple[List[NoteEvent], List[NoteEvent], List[Event], int]]: + """Decodes a sequence of tokens into note events, ignoring specific token IDs. + Returns: + Union[Tuple[List[NoteEvent], List[NoteEvent]], + Tuple[List[NoteEvent], List[NoteEvent], List[Event], int]]: The decoded note events. + If `return_events` is False, the returned tuple contains `note_events`, `tie_note_events`, + `last_activity`, and `err_cnt`. + If `return_events` is True, the returned tuple contains `note_events`, `tie_note_events`, + `last_activity`, `events`, and `err_cnt`. + + Notes: + This decoding process ignores specific token IDs based on `self.ids_to_ignore_decoding` attribute. + """ + return self.tokenizer.decode(tokens=tokens, start_time=start_time, return_events=return_events) + + def detokenize_list_batches( + self, + list_batch_tokens: Union[List[List[List[int]]], List[np.ndarray]], + list_start_times: Union[List[List[float]], List[float]], + return_events: bool = False + ) -> Union[Tuple[List[List[Tuple[List[NoteEvent], List[NoteEvent], int, float]]], Counter[str]], Tuple[ + List[List[Tuple[List[NoteEvent], List[NoteEvent], int, float]]], List[List[Event]], Counter[str]]]: + """ Decodes a list of variable size batches of token array to a list of + zipped note_events and tie_note_events. + + Args: + list_batch_tokens: List[np.ndarray], where array shape is (batch_size, variable_length) + list_start_times: List[float], where the length is sum of all batch_sizes. + return_events: bool + + Returns: + list_list_zipped_note_events_and_tie: + List[ + Tuple[ + List[NoteEvent]: A list of note events. + List[NoteEvent]: A list of tie note events. + List[Tuple[int]]: A list of last activity of segment. [(program, pitch), ...]. This is useful + for validating notes within a batch of segments extracted from a file. + List[float]: A list of segment start times. + ] + ] + (Optional) list_events: + List[List[Event]] + total_err_cnt: + Counter[str]: error counter. + + """ + return self.tokenizer.decode_list_batches(list_batch_tokens, list_start_times, return_events) diff --git a/amt/src/utils/tokenizer.py b/amt/src/utils/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc309601e3f0f82503ccb5b0154b9cfe652a552 --- /dev/null +++ b/amt/src/utils/tokenizer.py @@ -0,0 +1,410 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +""" tokenizer.py: Encodes and decodes events to/from tokens. 
""" +import numpy as np +import warnings +from abc import ABC, abstractmethod +from utils.note_event_dataclasses import Event, EventRange, Note #, Codec +from utils.event_codec import FastCodec as Codec +from utils.note_event_dataclasses import NoteEvent +from utils.note2event import note_event2event +from utils.event2note import event2note_event, note_event2note +from typing import List, Optional, Union, Tuple, Dict, Counter + + +#TODO: Too complex to be an abstract class. +class EventTokenizerBase(ABC): + """ + A base class for encoding and decoding events to and from tokens. + """ + + def __init__( + self, + base_codec: Union[Codec, str] = 'mt3', + special_tokens: List[str] = ['PAD', 'EOS', 'UNK'], + extra_tokens: List[str] = [], + max_shift_steps: int = 206, # 1001 in Gardner et al. + program_vocabulary: Optional[Dict] = None, + drum_vocabulary: Optional[Dict] = None, + ) -> None: + """ + Initializes the EventTokenizerBase object. + + :param base_codec: The codec to use for encoding and decoding. + :param special_tokens: None or list of special tokens to include in the vocabulary. + :param extra_tokens: None or list of tokens to be treated as additional special tokens. + :param program_vocabulary: None or a dictionary mapping program names to program indices. + :param drum_vocabulary: None or a dictionary mapping drum names to drum indices. + :param max_shift_steps: The maximum number of shift steps to use for the codec. + """ + # Initialize the codec attribute based on the input codec parameter. + if isinstance(base_codec, str): + # If codec is a string, initialize codec with the appropriate Codec object. + if base_codec.lower() == 'mt3': + event_ranges = [ + EventRange('pitch', min_value=0, max_value=127), + EventRange('velocity', min_value=0, max_value=1), + EventRange('tie', min_value=0, max_value=0), + EventRange('program', min_value=0, max_value=127), + EventRange('drum', min_value=0, max_value=127), + ] + else: + raise ValueError(f'Unknown codec name: {base_codec}') + + # Initialize codec + self.codec = Codec(special_tokens=special_tokens + extra_tokens, + max_shift_steps=max_shift_steps, + event_ranges=event_ranges, + program_vocabulary=program_vocabulary, + drum_vocabulary=drum_vocabulary, + name='mt3') + + elif isinstance(base_codec, Codec): + # If codec is a Codec object, store it directly. + self.codec = base_codec + if program_vocabulary is not None or drum_vocabulary is not None: + print('') + warnings.warn("Vocabulary cannot be applied when using a custom codec.") + else: + # If codec is neither a string nor a Codec object, raise a NotImplementedError. + raise TypeError(f'Unknown codec type: {type(base_codec)}') + self.num_tokens = self.codec._num_classes + + def _encode(self, events: List[Event]) -> List[int]: + return [self.codec.encode_event(e) for e in events] + + def _decode(self, tokens: List[int]) -> List[Event]: + return [self.codec.decode_event_index(idx) for idx in tokens] + + @abstractmethod + def encode(self): + """ Encode your custom events to tokens. """ + pass + + @abstractmethod + def decode(self): + """ Decode your custom tokens to events.""" + pass + + +class EventTokenizer(EventTokenizerBase): + """ + Eencoding and decoding events to and from tokens. 
+ """ + + def __init__(self, + base_codec: Union[Codec, str] = 'mt3', + special_tokens: List[str] = ['PAD', 'EOS', 'UNK'], + extra_tokens: List[str] = [], + max_shift_steps: int = 206, + program_vocabulary: Optional[Dict] = None, + drum_vocabulary: Optional[Dict] = None) -> None: + """ + Initializes the EventTokenizerBase object. + + :param codec: The codec to use for encoding and decoding. + :param special_tokens: None or list of special tokens to include in the vocabulary. + :param extra_tokens: None or list of tokens to be treated as additional special tokens. + :param program_vocabulary: None or a dictionary mapping program names to program indices. + :param drum_vocabulary: None or a dictionary mapping drum names to drum indices. + :param max_shift_steps: The maximum number of shift steps to use for the codec. + """ + # Initialize the codec attribute based on the input codec parameter. + super().__init__( + base_codec=base_codec, + special_tokens=special_tokens, + extra_tokens=extra_tokens, + max_shift_steps=max_shift_steps, + program_vocabulary=program_vocabulary, + drum_vocabulary=drum_vocabulary, + ) + + def encode(self, events): + """ Encode your custom events to tokens. """ + return super()._encode(events) + + def decode(self, tokens): + """ Decode your custom tokens to events.""" + return super()._decode(tokens) + + +class NoteEventTokenizer(EventTokenizerBase): + """ Encodes and decodes note events to/from tokens. """ + + def __init__( + self, + base_codec: Union[Codec, str] = 'mt3', + max_length: int = 1024, # max length of tokens + tps: int = 100, + sort_note_event: bool = True, + special_tokens: List[str] = ['PAD', 'EOS', 'UNK'], + extra_tokens: List[str] = [], + max_shift_steps: int = 206, + program_vocabulary: Optional[Dict] = None, + drum_vocabulary: Optional[Dict] = None, + ignore_decoding_tokens: List[str] = [], + ignore_decoding_tokens_from_and_to: Optional[List[str]] = None, + debug_mode: bool = False) -> None: + """ + Initializes the TaskEventNoteTokenizer object. + + List[NoteEvent] -> encdoe_note_events -> np.ndarray[int] + + np.ndarray[int] -> decode_note_events -> Tuple[List[NoteEvent], List[NoteEvent]] + + :param codec: The codec to use for encoding and decoding. + :param special_tokens: None or list of special tokens to include in the vocabulary. + :param extra_tokens: None or list of tokens to be treated as additional special tokens. + :param program_vocabulary: None or a dictionary mapping program names to program indices. + :param drum_vocabulary: None or a dictionary mapping drum names to drum indices. + :param max_shift_steps: The maximum number of shift steps to use for the codec. + + :param ignore_decoding_tokens: List of tokens to ignore during decoding. + :param ignore_decoding_tokens_from_and_to: List of tokens to ignore during decoding. [from, to] + """ + super().__init__(base_codec=base_codec, + special_tokens=special_tokens, + extra_tokens=extra_tokens, + max_shift_steps=max_shift_steps, + program_vocabulary=program_vocabulary, + drum_vocabulary=drum_vocabulary) + self.max_length = max_length + self.tps = tps + self.sort = sort_note_event + + # Prepare prefix, suffix and pad tokens. 
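+        # The 'EOS' id is collected as a suffix appended by encode_plus(); the 'PAD' id is used to
+        # right-pad token sequences up to max_length; 'UNK' is registered in the codec but unused here.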
+        self._prefix = []
+        self._suffix = []
+        for stk in self.codec.special_tokens:
+            if stk == 'EOS':
+                self._suffix.append(self.codec.special_tokens.index('EOS'))
+            elif stk == 'PAD':
+                self._zero_pad = [0] * 1024
+            elif stk == 'UNK':
+                pass
+            else:
+                pass
+                # raise NotImplementedError(f'Unknown special token: {stk}')
+        self.eos_id = self.codec.special_tokens.index('EOS')
+        self.pad_id = self.codec.special_tokens.index('PAD')
+        self._zero_pad_task = []  # resized on demand in encode_task()
+        self.ids_to_ignore_decoding = [self.codec.special_tokens.index(t) for t in ignore_decoding_tokens]
+        self.ignore_tokens_from_and_to = ignore_decoding_tokens_from_and_to
+        self.debug_mode = debug_mode
+
+    def _decode(self, tokens):
+        # This is the event detokenizer, not note_event. It is required for displaying events in the validation dashboard.
+        return super()._decode(tokens)
+
+    def encode(
+        self,
+        note_events: List[NoteEvent],
+        tie_note_events: Optional[List[NoteEvent]] = None,
+        start_time: float = 0.,
+    ) -> List[int]:
+        """ Encodes note events and tie note events to tokens. """
+        events = note_event2event(
+            note_events=note_events,
+            tie_note_events=tie_note_events,
+            start_time=start_time,  # required for calculating relative time
+            tps=self.tps,
+            sort=self.sort)
+        return super()._encode(events)
+
+    def encode_plus(
+        self,
+        note_events: List[NoteEvent],
+        tie_note_events: Optional[List[NoteEvent]] = None,
+        start_times: float = 0.,  # Fixing bug: start_time --> start_times
+        add_special_tokens: Optional[bool] = True,
+        max_length: Optional[int] = None,  # if None, use self.max_length
+        pad_to_max_length: Optional[bool] = True,
+        return_attention_mask: bool = False) -> Union[List[int], Tuple[List[int], List[int]]]:
+        """ Encodes note events and tie note info to padded tokens. """
+        encoded = self.encode(note_events, tie_note_events, start_times)
+
+        # if task_events:
+        #     encoded = super()._encode(task_events) + encoded
+        if add_special_tokens:
+            if self._prefix:
+                encoded = self._prefix + encoded
+            if self._suffix:
+                encoded = encoded + self._suffix
+
+        if max_length is None:
+            max_length = self.max_length
+
+        length = len(encoded)
+        if length >= max_length:
+            encoded = encoded[:max_length]
+            length = max_length
+
+        if return_attention_mask:
+            attention_mask = [1] * length
+
+        #
+        if pad_to_max_length is True:
+            if len(self._zero_pad) != max_length:
+                self._zero_pad = [self.pad_id] * max_length
+            if return_attention_mask:
+                attention_mask += self._zero_pad[length:]
+            encoded = encoded + self._zero_pad[length:]
+
+        if return_attention_mask:
+            return encoded, attention_mask
+
+        return encoded
+
+    def encode_task(self, task_events: List[Event], max_length: Optional[int] = None) -> List[int]:
+        # NOTE: This is an event tokenizer that generates task ids, not the list of note_event objects.
+        encoded = super()._encode(task_events)
+
+        #
+        if max_length is not None:
+            if len(self._zero_pad_task) != max_length:
+                self._zero_pad_task = [self.pad_id] * max_length
+            length = len(encoded)
+            encoded = encoded + self._zero_pad_task[length:]
+
+        return encoded
+
+    def decode(
+        self,
+        tokens: List[int],
+        start_time: float = 0.,
+        return_events: bool = False,
+    ) -> Union[Tuple[List[NoteEvent], List[NoteEvent]], Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]],
+                                                              List[Event], int]]:
+        """Decodes a sequence of tokens into note events.
+
+        Args:
+            tokens (List[int]): The list of tokens to be decoded.
+            start_time (float, optional): The starting time for the note events. Defaults to 0.
+ return_events (bool, optional): Indicates whether to include the raw events in the return value. + Defaults to False. + + Returns: + Union[Tuple[List[NoteEvent], List[NoteEvent]], + Tuple[List[NoteEvent], List[NoteEvent], List[Event], int]]: The decoded note events. + If `return_events` is False, the returned tuple contains `note_events`, `tie_note_events`, + `last_activity`, and `err_cnt`. + If `return_events` is True, the returned tuple contains `note_events`, `tie_note_events`, + `last_activity`, `events`, and `err_cnt`. + """ + if self.debug_mode: + ignored_tokens_from_input = [t for t in tokens if t in self.ids_to_ignore_decoding] + print(ignored_tokens_from_input) + + if self.ids_to_ignore_decoding: + tokens = [t for t in tokens if t not in self.ids_to_ignore_decoding] + + events = super()._decode(tokens) + note_events, tie_note_events, last_activity, err_cnt = event2note_event(events, start_time, True, self.tps) + if return_events: + return note_events, tie_note_events, last_activity, events, err_cnt + else: + return note_events, tie_note_events, last_activity, err_cnt + + def decode_batch( + self, + batch_tokens: Union[List[List[int]], np.ndarray], + start_times: List[float], + return_events: bool = False + ) -> Union[Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[float]]], int], + Tuple[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[float]]], List[List[Event]], + Counter[str]]]: + """ + Decodes a batch of tokens to note_events and tie_note_events. + + Args: + batch_tokens (List[List[int]] or np.ndarray): Tokens to be decoded. + start_times (List[float]): List of start times for each token set. + return_events (bool, optional): Flag to determine if events should be returned. Defaults to False. + + """ + if isinstance(batch_tokens, np.ndarray): + batch_tokens = batch_tokens.tolist() + + if len(batch_tokens) != len(start_times): + raise ValueError('The length of batch_tokens and start_times must be same.') + + zipped_note_events_and_tie = [] + list_events = [] + total_err_cnt = 0 + + for tokens, start_time in zip(batch_tokens, start_times): + if return_events: + note_events, tie_note_events, last_activity, events, err_cnt = self.decode( + tokens, start_time, return_events) + list_events.append(events) + else: + note_events, tie_note_events, last_activity, err_cnt = self.decode(tokens, start_time, return_events) + + zipped_note_events_and_tie.append((note_events, tie_note_events, last_activity, start_time)) + total_err_cnt += err_cnt + + if return_events: + return zipped_note_events_and_tie, list_events, total_err_cnt + else: + return zipped_note_events_and_tie, total_err_cnt + + def decode_list_batches( + self, + list_batch_tokens: Union[List[List[List[int]]], List[np.ndarray]], + list_start_times: Union[List[List[float]], List[float]], + return_events: bool = False + ) -> Union[Tuple[List[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[float]]]], Counter[str]], + Tuple[List[List[Tuple[List[NoteEvent], List[NoteEvent], List[Tuple[int]], List[float]]]], + List[List[Event]], Counter[str]]]: + """ + Decodes a list of variable-size batches of token array to a list of + zipped note_events and tie_note_events. + + Args: + list_batch_tokens: List[np.ndarray], where array shape is (batch_size, variable_length) + list_start_times: List[float], where the length is sum of all batch_sizes. + return_events: bool, Defaults to False. 
+ + Returns: + list_list_zipped_note_events_and_tie: + List[ + Tuple[ + List[NoteEvent]: A list of note events. + List[NoteEvent]: A list of tie note events. + List[Tuple[int]]: A list of last activity of segment. [(program, pitch), ...]. This is useful + for validating notes within a batch of segments extracted from a file. + List[float]: A list of segment start times. + ] + ] + (Optional) list_events: + List[List[Event]] + total_err_cnt: + Counter[str]: error counter. + """ + list_tokens = [] + for arr in list_batch_tokens: + for tokens in arr: + list_tokens.append(tokens) + assert (len(list_tokens) == len(list_start_times)) + + zipped_note_events_and_tie = [] + list_events = [] + total_err_cnt = Counter() + for tokens, start_time in zip(list_tokens, list_start_times): + note_events, tie_note_events, last_activity, events, err_cnt = self.decode( + tokens, start_time, return_events) + zipped_note_events_and_tie.append((note_events, tie_note_events, last_activity, start_time)) + if return_events: + list_events.append(events) + total_err_cnt += err_cnt + + if return_events: + return zipped_note_events_and_tie, list_events, total_err_cnt + else: + return zipped_note_events_and_tie, total_err_cnt diff --git a/amt/src/utils/utils.py b/amt/src/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae6ce712694e7c5dd3ea461b8d4958a933e5a7aa --- /dev/null +++ b/amt/src/utils/utils.py @@ -0,0 +1,560 @@ +# Copyright 2024 The YourMT3 Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Please see the details in the LICENSE file. +import os +import json +import time +import hashlib +import requests +import tarfile +import warnings +import argparse +from typing import Tuple, Union, Optional, List, Dict, Any +from tqdm import tqdm +import numpy as np +from collections import Counter +from utils.note_event_dataclasses import Note +from utils.note2event import note2note_event +from utils.midi import note_event2midi +from utils.note2event import slice_multiple_note_events_and_ties_to_bundle +from utils.event2note import merge_zipped_note_events_and_ties_to_notes +from utils.metrics import compute_track_metrics +from utils.tokenizer import EventTokenizer, NoteEventTokenizer +from utils.note_event_dataclasses import Note, NoteEvent, Event +from config.vocabulary import GM_INSTR_FULL, GM_INSTR_CLASS_PLUS +from config.config import shared_cfg + + +def get_checksum(file_path: os.PathLike, buffer_size: int = 65536) -> str: + md5 = hashlib.md5() + with open(file_path, "rb") as f: + while True: + data = f.read(buffer_size) + if not data: + break + md5.update(data) + return md5.hexdigest() + + +def download_and_extract(data_home: os.PathLike, + url: str, + remove_tar_file: bool = True, + check_sum: Optional[str] = None, + zenodo_token: Optional[str] = None) -> None: + + file_name = url.split("/")[-1].split("?")[0] + tar_path = os.path.join(data_home, file_name) + + if not os.path.exists(data_home): + os.makedirs(data_home) + + if zenodo_token is not None: + url_with_token = f"{url}&token={zenodo_token}" if "?download=1" in url else f"{url}?token={zenodo_token}" + else: + url_with_token = url + + response = requests.get(url_with_token, stream=True) + + # Check HTTP Status + if response.status_code != 200: + print(f"Failed to download file. 
Status code: {response.status_code}") + return + + total_size = int(response.headers.get('content-length', 0)) + + with open(tar_path, "wb") as f: + for chunk in tqdm(response.iter_content(chunk_size=8192), total=total_size // 8192, unit='KB', desc=file_name): + f.write(chunk) + + _check_sum = get_checksum(tar_path) + print(f"Checksum (md5): {_check_sum}") + + if check_sum is not None and check_sum != _check_sum: + raise ValueError(f"Checksum doesn't match! Expected: {check_sum}, Actual: {_check_sum}") + + with tarfile.open(tar_path, "r:gz") as tar: + tar.extractall(data_home) + + if remove_tar_file: + os.remove(tar_path) + + +def create_inverse_vocab(vocab: Dict) -> Dict: + inverse_vocab = {} + for k, vnp in vocab.items(): + for v in vnp: + inverse_vocab[v] = (vnp[0], k) # (program, str_instrument_name) + return inverse_vocab + + +def create_program2channel_vocab(program_vocab: Dict, drum_program: int = 128, force_assign_13_ch: bool = False): + """ + Create a direct map for programs to indices, instrument groups, and primary programs. + + Args: + program_vocab (dict): A dictionary of program vocabularies. + drum_program (int): The program number for drums. Default: 128. + + Returns: + program2channel_vocab (dict): A dictionary of program to indices, instrument groups, and primary programs. + e.g. { + 0: {'channel': 0, 'instrument_group': 'Piano', 'primary_program': 0}, + 1: {'channel': 1, 'instrument_group': 'Chromatic Percussion', 'primary_program': 8}, + ... + 100: {'channel': 11, 'instrument_group': 'Singing Voice', 'primary_program': 100}, + 128: {'channel': 12, 'instrument_group': 'Drums', 'primary_program': 128} + } + "primary_program" is not used now. + + num_channels (int): The number of channels. Typically length of program vocab + 1 (for drums) + + """ + num_channels = len(program_vocab) + 1 + program2channel_vocab = {} + for idx, (instrument_group, programs) in enumerate(program_vocab.items()): + if idx > num_channels: + raise ValueError( + f"📕 The number of channels ({num_channels}) is less than the number of instrument groups ({idx})") + for program in programs: + if program in program2channel_vocab: + raise ValueError(f"📕 program {program} is duplicated in program_vocab") + else: + program2channel_vocab[program] = { + "channel": int(idx), + "instrument_group": str(instrument_group), + "primary_program": int(programs[0]), + } + + # Add drums + if drum_program in program2channel_vocab.keys(): + raise ValueError( + f"📕 drum_program {drum_program} is duplicated in program_vocab. 
program_vocab should not include drum or program 128" + ) + else: + program2channel_vocab[drum_program] = { + "channel": idx + 1, + "instrument_group": "Drums", + "primary_program": drum_program, + } + return program2channel_vocab, num_channels + + +def write_model_output_as_npy(data, output_dir, track_id): + output_dir = os.path.join(output_dir, "model_output") + os.makedirs(output_dir, exist_ok=True) + output_file = os.path.join(output_dir, f"output_{track_id}.npy") + np.save(output_file, data, allow_pickle=True) + + +def write_model_output_as_midi(notes: List[Note], + output_dir: os.PathLike, + track_id: str, + output_inverse_vocab: Optional[Dict] = None, + output_dir_suffix: Optional[str] = None) -> None: + + if output_dir_suffix is not None: + output_dir = os.path.join(output_dir, f"model_output/{output_dir_suffix}") + else: + output_dir = os.path.join(output_dir, "model_output") + os.makedirs(output_dir, exist_ok=True) + output_file = os.path.join(output_dir, f"{track_id}.mid") + + if output_inverse_vocab is not None: + # Convert the note events to the output vocabulary + new_notes = [] + for note in notes: + if note.is_drum: + new_notes.append(note) + else: + new_notes.append( + Note(is_drum=note.is_drum, + program=output_inverse_vocab.get(note.program, [note.program])[0], + onset=note.onset, + offset=note.offset, + pitch=note.pitch, + velocity=note.velocity)) + + note_events = note2note_event(new_notes, return_activity=False) + note_event2midi(note_events, output_file, output_inverse_vocab=output_inverse_vocab) + + +def write_err_cnt_as_json( + track_id: str, + output_dir: os.PathLike, + output_dir_suffix: Optional[str] = None, + note_err_cnt: Optional[Counter] = None, + note_event_err_cnt: Optional[Counter] = None, +): + + if output_dir_suffix is not None: + output_dir = os.path.join(output_dir, f"model_output/{output_dir_suffix}") + else: + output_dir = os.path.join(output_dir, "model_output") + os.makedirs(output_dir, exist_ok=True) + output_file = os.path.join(output_dir, f"error_count_{track_id}.json") + + output_dict = {} + if note_err_cnt is not None: + output_dict['note_err_cnt'] = dict(note_err_cnt) + if note_event_err_cnt is not None: + output_dict['note_event_err_cnt'] = dict(note_event_err_cnt) + output_str = json.dumps(output_dict, indent=4) + + with open(output_file, 'w') as json_file: + json_file.write(output_str) + + +class Timer: + """A simple timer class to measure elapsed time. 
+ Usage: + + with Timer() as t: + # Your code here + time.sleep(2) + t.print_elapsed_time() + + """ + + def __init__(self) -> None: + self.start_time = None + self.end_time = None + + def start(self) -> None: + self.start_time = time.time() + + def stop(self) -> None: + self.end_time = time.time() + + def elapsed_time(self) -> float: + if self.start_time is None: + raise ValueError("Timer has not been started yet.") + if self.end_time is None: + raise ValueError("Timer has not been stopped yet.") + return self.end_time - self.start_time + + def print_elapsed_time(self, message: Optional[str] = None) -> float: + elapsed_seconds = self.elapsed_time() + minutes, seconds = divmod(elapsed_seconds, 60) + milliseconds = (elapsed_seconds % 1) * 1000 + if message is not None: + text = f"⏰ {message}: {int(minutes)}m {int(seconds)}s {milliseconds:.2f}ms" + else: + text = f"⏰ elapse time: {int(minutes)}m {int(seconds)}s {milliseconds:.2f}ms" + print(text) + return elapsed_seconds + + def reset(self) -> None: + self.start_time = None + self.end_time = None + + def __enter__(self) -> 'Timer': + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.stop() + + +def merge_file_lists(file_lists: List[Dict]) -> Dict[int, Any]: + """ Merge file lists from different datasets, and return a reindexed + dictionary of file list.""" + merged_file_list = {} + index = 0 + for file_list in file_lists: + for v in file_list.values(): + merged_file_list[index] = v + index += 1 + return merged_file_list + + +def merge_splits(splits: List[str], dataset_name: Union[str, List[str]]) -> Dict[int, Any]: + """ + merge_splits: + - Merge multiple splits from different datasets, and return a reindexed + dictionary of file list. + - It is also possible to merge splits from different datasets. + + """ + n_splits = len(splits) + if n_splits > 1 and isinstance(dataset_name, str): + dataset_name = [dataset_name] * n_splits + elif n_splits > 1 and isinstance(dataset_name, list) and len(dataset_name) != n_splits: + raise ValueError("The number of dataset names in list must be equal to the number of splits.") + else: + pass + + # load file_list dictionaries + data_home = shared_cfg['PATH']['data_home'] + file_lists = [] # list of dictionaries + for i, s in enumerate(splits): + json_file = (f"{data_home}/yourmt3_indexes/{dataset_name[i]}_{s}_file_list.json") + + # Fix for missing file_list with incomplete dataset package + if not os.path.exists(json_file): + warnings.warn( + f"File list {json_file} does not exist. If you don't have a complete package of dataset, ignore this warning..." 
+ ) + return {} + + with open(json_file, 'r') as j: + file_lists.append(json.load(j)) + merged_file_list = merge_file_lists(file_lists) # reindexed, merged file list + return merged_file_list + + +def reindex_file_list_keys(file_list: Dict[str, Any]) -> Dict[int, Any]: + """ Reindex file list keys from 0 to total count.""" + reindexed_file_list = {} + for i, (k, v) in enumerate(file_list.items()): + reindexed_file_list[i] = v + return reindexed_file_list + + +def remove_ids_from_file_list(file_list: Dict[str, Any], + selected_ids: List[int], + reindex: bool = True) -> Dict[int, Any]: + """ Remove selected ids from file list.""" + key = None + for v in file_list.values(): + # search keys that contain 'id' + for k in v.keys(): + if 'id' in k: + key = k + break + if key: + break + + if key is None: + raise ValueError("No key contains 'id'.") + + # generate new filelist by removing selected ids + selected_ids = [str(id) for id in selected_ids] # ids to str + file_list = {k: v for k, v in file_list.items() if str(v[key]) not in selected_ids} + if reindex: + return reindex_file_list_keys(file_list) + else: + return file_list + + +def deduplicate_splits(split_a: Union[str, Dict], + split_b: Union[str, Dict], + dataset_name: Optional[str] = None, + reindex: bool = True) -> Dict[int, Any]: + """Remove overlapping splits in file_list A with splits from file_list B, + and return a reindexed dictionary of files.""" + data_home = shared_cfg['PATH']['data_home'] + + if isinstance(split_a, str): + json_file_a = (f"{data_home}/yourmt3_indexes/{dataset_name}_{split_a}_file_list.json") + with open(json_file_a, 'r') as j: + file_list_a = json.load(j) + elif isinstance(split_a, dict): + file_list_a = split_a + + if isinstance(split_b, str): + json_file_b = (f"{data_home}/yourmt3_indexes/{dataset_name}_{split_b}_file_list.json") + with open(json_file_b, 'r') as j: + file_list_b = json.load(j) + elif isinstance(split_b, dict): + file_list_b = split_b + + # Get the key that contains 'id' from file_list_a splits + id_key = None + for v in file_list_a.values(): + for k in v.keys(): + if 'id' in k: + id_key = k + break + if id_key: + break + if id_key is None: + raise ValueError("No key contains 'id' in file_list_a.") + + # Get IDs from file_list_b splits + ids_b = set(str(v.get(id_key, '')) for v in file_list_b.values()) + + # Extract IDs from file_list_a splits + ids_a = [str(v.get(id_key, '')) for v in file_list_a.values()] + + # Remove IDs from file_list_a that are also in file_list_b + ids_to_remove = list(set(ids_a).intersection(ids_b)) + filtered_file_list_a = remove_ids_from_file_list(file_list_a, ids_to_remove, reindex) + + return filtered_file_list_a + + +def merge_vocab(vocab_list: List[Dict]) -> Dict[str, Any]: + """ Merge file lists from different datasets, and return a reindexed + dictionary of file list.""" + merged_vocab = {} + for vocab in vocab_list: + for k, v in vocab.items(): + if k not in merged_vocab.keys(): + merged_vocab[k] = v + return merged_vocab + + +def assert_note_events_almost_equal(actual_note_events, + predicted_note_events, + ignore_time=False, + ignore_activity=True, + delta=5.1e-3): + """ + Asserts that the given lists of Note instances are equal up to a small + floating-point tolerance, similar to `assertAlmostEqual` of `unittest`. + Tolerance is 5e-3 by default, which is 5 ms for 100 ticks-per-second. + + If `ignore_time` is True, then the time field is ignored. 
(useful for + comparing tie note events, default is False) + + If `ignore_activity` is True, then the activity field is ignored (default + is True). + """ + assert len(actual_note_events) == len(predicted_note_events) + for j, (actual_note_event, predicted_note_event) in enumerate(zip(actual_note_events, predicted_note_events)): + if ignore_time is False: + assert abs(actual_note_event.time - predicted_note_event.time) <= delta, (j, actual_note_event, + predicted_note_event) + assert actual_note_event.is_drum == predicted_note_event.is_drum, (j, actual_note_event, predicted_note_event) + assert actual_note_event.program == predicted_note_event.program, (j, actual_note_event, predicted_note_event) + assert actual_note_event.pitch == predicted_note_event.pitch, (j, actual_note_event, predicted_note_event) + assert actual_note_event.velocity == predicted_note_event.velocity, (j, actual_note_event, predicted_note_event) + if ignore_activity is False: + assert actual_note_event.activity == predicted_note_event.activity, (j, actual_note_event, + predicted_note_event) + + +def note_event2token2note_event_sanity_check(note_events: List[NoteEvent], + notes: List[Note], + report_err_cnt=False) -> Counter: + # slice note events + max_time = note_events[-1].time + num_segs = int(max_time * 16000 // 32757 + 1) + seg_len_sec = 32767 / 16000 + start_times = [i * seg_len_sec for i in range(num_segs)] + note_event_segments = slice_multiple_note_events_and_ties_to_bundle( + note_events, + start_times, + seg_len_sec, + ) + + # encode + tokenizer = NoteEventTokenizer() + token_array = np.zeros((num_segs, 1024), dtype=np.int32) + for i, tup in enumerate(list(zip(*note_event_segments.values()))): + padded_tokens = tokenizer.encode_plus(*tup) + token_array[i, :] = padded_tokens + + # decode: warning: Invalid pitch event without program or velocity --> solved + zipped_note_events_and_tie, list_events, err_cnt = tokenizer.decode_list_batches([token_array], + start_times, + return_events=True) + if report_err_cnt: + # report error and do not break.. + err_cnt_all = err_cnt + else: + assert len(err_cnt) == 0 + err_cnt_all = Counter() + + # First check, the number of empty note_events and tie_note_events + cnt_org_empty = 0 + cnt_recon_empty = 0 + for i, (recon_note_events, recon_tie_note_events, _, _) in enumerate(zipped_note_events_and_tie): + org_note_events = note_event_segments['note_events'][i] + org_tie_note_events = note_event_segments['tie_note_events'][i] + if org_note_events == []: + cnt_org_empty += 1 + if recon_note_events == []: + cnt_recon_empty += 1 + + # assert len(org_note_events) == len(recon_note_events) # passed after bug fix + + # Check the reconstruction of note_events + for i, (recon_note_events, recon_tie_note_events, _, _) in enumerate(zipped_note_events_and_tie): + org_note_events = note_event_segments['note_events'][i] + org_tie_note_events = note_event_segments['tie_note_events'][i] + + org_note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) + org_tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) + recon_note_events.sort(key=lambda n_ev: (n_ev.time, n_ev.is_drum, n_ev.program, n_ev.velocity, n_ev.pitch)) + recon_tie_note_events.sort(key=lambda n_ev: (n_ev.program, n_ev.pitch)) + + #assert_note_events_almost_equal(org_note_events, recon_note_events) + # assert_note_events_almost_equal( + # org_tie_note_events, recon_tie_note_events, ignore_time=True) + + # Check notes: of course this fails.. 
and a lot of warning for cut off 20s + recon_notes, err_cnt = merge_zipped_note_events_and_ties_to_notes(zipped_note_events_and_tie, fix_offset=False) + # assert len(err_cnt) == 0 # this error is due to the cut off 5 seconds... + + # Check metric + drum_metric, non_drum_metric, instr_metric = compute_track_metrics(recon_notes, + notes, + eval_vocab=GM_INSTR_FULL, + onset_tolerance=0.005) # 5ms + if not np.isnan(non_drum_metric['offset_f']) and non_drum_metric['offset_f'] != 1.0: + warnings.warn(f"non_drum_metric['offset_f'] = {non_drum_metric['offset_f']}") + assert non_drum_metric['onset_f'] > 0.99 + if not np.isnan(drum_metric['onset_f_drum']) and non_drum_metric['offset_f'] != 1.0: + warnings.warn(f"drum_metric['offset_f'] = {drum_metric['offset_f']}") + assert drum_metric['offset_f'] > 0.99 + return err_cnt_all + Counter(err_cnt) + + +def str2bool(v): + """ + Converts a string value to a boolean value. + + Args: + v: The string value to convert. + + Returns: + The boolean value equivalent of the input string. + + Raises: + ArgumentTypeError: If the input string is not a valid boolean value. + """ + if v.lower() in ('yes', 'true', 't', 'y', '1'): + return True + elif v.lower() in ('no', 'false', 'f', 'n', '0'): + return False + else: + raise argparse.ArgumentTypeError('Boolean value expected.') + + +def freq_to_midi(freq): + return round(69 + 12 * np.log2(freq / 440)) + + +def dict_iterator(d: Dict): + """ + This function is used to iterate over a dictionary of lists. + As an output, it yields a newly created instance of a dictionary + """ + for values in zip(*d.values()): + yield {k: [v] for k, v in zip(d.keys(), values)} + + +def extend_dict(dict1: dict, dict2: dict) -> None: + """ + Extends the lists in dict1 with the corresponding lists in dict2. + Modifies dict1 in-place and does not return anything. + + Args: + dict1 (dict): The dictionary to be extended. + dict2 (dict): The dictionary with lists to extend dict1. + + Example: + dict1 = {'a': [1,2,3], 'b':[4,5,6]} + dict2 = {'a':[10], 'b':[17]} + extend_dict_in_place(dict1, dict2) + print(dict1) # Outputs: {'a': [1, 2, 3, 10], 'b': [4, 5, 6, 17]} + """ + for key in dict1: + dict1[key].extend(dict2[key])
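+
+
+# Worked example (illustrative sketch) for the dict-of-lists helpers above:
+#
+#   d1 = {'a': [1, 2], 'b': [3, 4]}
+#   d2 = {'a': [5], 'b': [6]}
+#   extend_dict(d1, d2)       # d1 -> {'a': [1, 2, 5], 'b': [3, 4, 6]}
+#   list(dict_iterator(d1))   # [{'a': [1], 'b': [3]}, {'a': [2], 'b': [4]}, {'a': [5], 'b': [6]}]
+#
+# freq_to_midi() uses the standard 440 Hz = MIDI 69 reference, e.g. freq_to_midi(440.0) == 69 and
+# freq_to_midi(261.63) == 60 (middle C).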