Spaces:
Sleeping
Sleeping
# Copyright 2024 The YourMT3 Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Please see the details in the LICENSE file. | |
""" data_modules.py """ | |
from typing import Optional, Dict, List, Any | |
import os | |
import numpy as np | |
from pytorch_lightning import LightningDataModule | |
from pytorch_lightning.utilities import CombinedLoader | |
from utils.datasets_train import get_cache_data_loader | |
from utils.datasets_eval import get_eval_dataloader | |
from utils.datasets_helper import create_merged_train_dataset_info, get_list_of_weighted_random_samplers | |
from utils.task_manager import TaskManager | |
from config.config import shared_cfg | |
from config.config import audio_cfg as default_audio_cfg | |
from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg | |
class AMTDataModule(LightningDataModule):
    """Lightning data module providing the train/validation/test dataloaders for AMT.

    Training samples come from the merged file list of every preset named in
    `data_preset_multi["presets"]`. They are drawn by `num_train_samplers` weighted
    random samplers (one dataloader per sampler), combined with a `CombinedLoader`.
    Validation and test use one dataloader per single preset that defines a
    `validation_split` / `test_split`.
    """

    def __init__(
            self,
            data_home: Optional[os.PathLike] = None,
            data_preset_multi: Dict[str, Any] = {
                "presets": ["musicnet_mt3_synth_only"],
            },  # only allowing multi_preset_cfg. single_preset_cfg should be converted to multi_preset_cfg
            task_manager: TaskManager = TaskManager(task_name="mt3_full_plus"),
            train_num_samples_per_epoch: Optional[int] = None,
            train_random_amp_range: List[float] = [0.6, 1.2],
            train_stem_iaug_prob: Optional[float] = 0.7,
            train_stem_xaug_policy: Optional[Dict] = {
                "max_k": 3,
                "tau": 0.3,
                "alpha": 1.0,
                "max_subunit_stems": 12,  # the number of subunit stems to be reduced to this number of stems
                # probability of including singing for cross augmented examples. if None, use base probability.
                "p_include_singing": 0.8,
                "no_instr_overlap": True,
                "no_drum_overlap": True,
                "uhat_intra_stem_augment": True,
            },
            train_pitch_shift_range: Optional[List[int]] = None,
            audio_cfg: Optional[Dict] = None) -> None:
        """Validate the configuration and store the arguments later used by `setup()`.

        Args:
            data_home: Root data directory. Defaults to `shared_cfg["PATH"]["data_home"]`
                and must exist.
            data_preset_multi: Dict whose "presets" entry lists keys of
                `data_preset_single_cfg`. Optional "val_max_num_files" /
                "test_max_num_files" entries are forwarded to the eval dataloaders
                as `max_num_files`.
            task_manager: Passed to every train/validation/test dataloader.
            train_num_samples_per_epoch: Total training samples per epoch. It is
                divided by `shared_cfg["BSZ"]["train_local"]` to get the per-sampler
                count. A falsy value (None or 0) passes None to the sampler factory.
            train_random_amp_range, train_stem_iaug_prob, train_stem_xaug_policy,
            train_pitch_shift_range: Augmentation settings forwarded (as private
                copies where mutable) to `get_cache_data_loader` as `random_amp_range`,
                `stem_iaug_prob`, `stem_xaug_policy` and `pitch_shift_range`.
            audio_cfg: Audio config forwarded to every dataloader. Defaults to
                `config.config.audio_cfg`.

        Raises:
            ValueError: If `data_home` does not exist, a preset name is unknown, or
                `BSZ.train_local` is not a multiple of `BSZ.train_sub`.
        """
        import copy  # local import: only needed for the defensive copies below

        super().__init__()

        # Check path existence.
        if data_home is None:
            data_home = shared_cfg["PATH"]["data_home"]
        if not os.path.exists(data_home):
            raise ValueError(f"Invalid data_home: {data_home}")
        self.data_home = data_home

        # The mutable defaults in the signature are evaluated once, when the class is
        # defined, and shared by every instance. Copy them (and any caller-provided
        # objects) before handing them to the project helpers used in setup().
        self.preset_multi = copy.deepcopy(data_preset_multi)

        # Resolve single presets,
        # e.g. [{"dataset_name": ..., "train_split": ..., "validation_split":...,}, {...}]
        self.preset_singles = []
        for preset_name in self.preset_multi["presets"]:
            if preset_name not in data_preset_single_cfg:
                raise ValueError(f"Invalid data_preset: {preset_name}")
            self.preset_singles.append(data_preset_single_cfg[preset_name])

        # Task manager.
        # NOTE(review): the default TaskManager is built once, at class-definition time,
        # so every instance relying on the default shares that single object.
        self.task_manager = task_manager

        # Train num samples per epoch, passed to the sampler.
        self.train_num_samples_per_epoch = train_num_samples_per_epoch
        train_local_bsz = shared_cfg["BSZ"]["train_local"]
        train_sub_bsz = shared_cfg["BSZ"]["train_sub"]
        # Raise instead of `assert`: asserts are stripped under `python -O`.
        if train_local_bsz % train_sub_bsz != 0:
            raise ValueError(f"BSZ.train_local ({train_local_bsz}) must be divisible by "
                             f"BSZ.train_sub ({train_sub_bsz})")
        # One sampler / dataloader per sub-batch of the local batch.
        self.num_train_samplers = train_local_bsz // train_sub_bsz

        # Train augmentation parameters.
        self.train_random_amp_range = copy.deepcopy(train_random_amp_range)
        self.train_stem_iaug_prob = train_stem_iaug_prob
        self.train_stem_xaug_policy = copy.deepcopy(train_stem_xaug_policy)
        self.train_pitch_shift_range = copy.deepcopy(train_pitch_shift_range)

        # Train data info, filled in by set_merged_train_data_info() during setup().
        self.train_data_info = None

        # Dataloader kwargs, populated by setup() for the relevant stage.
        self.train_data_args: List[Dict[str, Any]] = []
        self.val_data_args: List[Dict[str, Any]] = []
        self.test_data_args: List[Dict[str, Any]] = []

        # Validation/test max num of files.
        self.val_max_num_files = self.preset_multi.get("val_max_num_files", None)
        self.test_max_num_files = self.preset_multi.get("test_max_num_files", None)

        # Audio config.
        self.audio_cfg = audio_cfg if audio_cfg is not None else default_audio_cfg

    def set_merged_train_data_info(self) -> None:
        """Collect the train datasets of all presets into `self.train_data_info`.

        `self.train_data_info` is the dict returned by
        `create_merged_train_dataset_info`. The keys read in this module are
        "merged_file_list", "n_datasets", "dataset_weights" and "index_ranges".
        """
        self.train_data_info = create_merged_train_dataset_info(self.preset_multi)
        print(
            f"AMTDataModule: Added {len(self.train_data_info['merged_file_list'])} files from {self.train_data_info['n_datasets']} datasets to the training set."
        )

    def setup(self, stage: str):
        """Prepare the dataloader kwargs needed for `stage`.

        `stage` is passed automatically by the PyTorch Lightning Trainer:
        - "fit" builds train and validation kwargs;
        - "validate" builds validation kwargs (needed by `trainer.validate()`);
        - "test" builds test kwargs.
        Other stages are ignored.
        """
        if stage == "fit":
            self._setup_train_data_args()
        if stage in ("fit", "validate"):
            self.val_data_args = self._collect_eval_data_args("validation_split", self.val_max_num_files)
        if stage == "test":
            self.test_data_args = self._collect_eval_data_args("test_split", self.test_max_num_files)

    def _setup_train_data_args(self) -> None:
        """Merge the train datasets and build one dataloader-kwargs dict per weighted sampler."""
        self.set_merged_train_data_info()

        # Per-sampler sample count = epoch total // local batch size.
        # A falsy total passes None through to the sampler factory.
        if self.train_num_samples_per_epoch:
            samples_per_sampler = self.train_num_samples_per_epoch // shared_cfg["BSZ"]["train_local"]
        else:
            samples_per_sampler = None

        samplers = get_list_of_weighted_random_samplers(num_samplers=self.num_train_samplers,
                                                        dataset_weights=self.train_data_info["dataset_weights"],
                                                        dataset_index_ranges=self.train_data_info["index_ranges"],
                                                        num_samples_per_epoch=samples_per_sampler)

        self.train_data_args = [{
            "dataset_name": None,
            "split": None,
            "file_list": self.train_data_info["merged_file_list"],
            "sub_batch_size": shared_cfg["BSZ"]["train_sub"],
            "task_manager": self.task_manager,
            "random_amp_range": self.train_random_amp_range,
            "stem_iaug_prob": self.train_stem_iaug_prob,
            "stem_xaug_policy": self.train_stem_xaug_policy,
            "pitch_shift_range": self.train_pitch_shift_range,
            "shuffle": True,
            "sampler": sampler,
            "audio_cfg": self.audio_cfg,
        } for sampler in samplers]

    def _collect_eval_data_args(self, split_key: str, max_num_files: Optional[int]) -> List[Dict[str, Any]]:
        """Return eval dataloader kwargs for every single preset whose `split_key` is not None."""
        return [{
            "dataset_name": preset_single["dataset_name"],
            "split": preset_single[split_key],
            "task_manager": self.task_manager,
            "max_num_files": max_num_files,
            "audio_cfg": self.audio_cfg,
        } for preset_single in self.preset_singles if preset_single[split_key] is not None]

    @staticmethod
    def _build_eval_loaders(data_args: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build one eval dataloader per entry of `data_args`, keyed by dataset name.

        If a dataset name repeats (e.g. the same dataset listed by two presets), the
        later key gets its position appended, so no loader is silently overwritten.
        The result therefore has one entry per element of `data_args`, in the same
        order. `num_*_dataloaders()` and `get_*_dataset_name()` rely on that
        alignment.
        """
        loaders = {}
        for idx, args_dict in enumerate(data_args):
            key = args_dict["dataset_name"]
            if key in loaders:
                key = f"{key}_{idx}"
            loaders[key] = get_eval_dataloader(**args_dict, dataloader_config=shared_cfg["DATAIO"])
        return loaders

    def train_dataloader(self) -> Any:
        """Return the per-sampler training dataloaders combined into one `CombinedLoader`."""
        loaders = {
            f"data_loader_{i}": get_cache_data_loader(**args_dict, dataloader_config=shared_cfg["DATAIO"])
            for i, args_dict in enumerate(self.train_data_args)
        }
        # All samplers are created with the same num_samples_per_epoch, so the loaders
        # are presumably equal in length and "min_size" drops nothing.
        return CombinedLoader(loaders, mode="min_size")

    def val_dataloader(self) -> Any:
        """Return a dict of validation dataloaders, one per preset with a validation split."""
        return self._build_eval_loaders(self.val_data_args)

    def test_dataloader(self) -> Any:
        """Return a dict of test dataloaders, one per preset with a test split."""
        return self._build_eval_loaders(self.test_data_args)

    # The helpers below map a validation/test `dataloader_idx` back to its dataset name.
    # They assume loaders are indexed in the order of `val_data_args` / `test_data_args`,
    # which is the insertion order of the dicts returned above.
    def num_val_dataloaders(self) -> int:
        return len(self.val_data_args)

    def num_test_dataloaders(self) -> int:
        return len(self.test_data_args)

    def get_val_dataset_name(self, dataloader_idx: int) -> str:
        return self.val_data_args[dataloader_idx]["dataset_name"]

    def get_test_dataset_name(self, dataloader_idx: int) -> str:
        return self.test_data_args[dataloader_idx]["dataset_name"]