# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
""" data_modules.py """
from typing import Optional, Dict, List, Any
import os
import numpy as np
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import CombinedLoader
from utils.datasets_train import get_cache_data_loader
from utils.datasets_eval import get_eval_dataloader
from utils.datasets_helper import create_merged_train_dataset_info, get_list_of_weighted_random_samplers
from utils.task_manager import TaskManager
from config.config import shared_cfg
from config.config import audio_cfg as default_audio_cfg
from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg


class AMTDataModule(LightningDataModule):
    """Lightning data module that builds train/validation/test dataloaders from data presets.

    Training draws from one merged file list built from every preset in
    `data_preset_multi["presets"]`; several weighted random samplers index into it and
    their loaders are combined with `CombinedLoader`. Validation and test build one eval
    dataloader per single preset that defines the corresponding split.
    """

    def __init__(
            self,
            data_home: Optional[os.PathLike] = None,
            data_preset_multi: Dict[str, Any] = {
                "presets": ["musicnet_mt3_synth_only"],
            },  # only allowing multi_preset_cfg. single_preset_cfg should be converted to multi_preset_cfg
            task_manager: TaskManager = TaskManager(task_name="mt3_full_plus"),
            train_num_samples_per_epoch: Optional[int] = None,
            train_random_amp_range: List[float] = [0.6, 1.2],
            train_stem_iaug_prob: Optional[float] = 0.7,
            train_stem_xaug_policy: Optional[Dict] = {
                "max_k": 3,
                "tau": 0.3,
                "alpha": 1.0,
                "max_subunit_stems": 12,  # the number of subunit stems to be reduced to this number of stems
                "p_include_singing": 0.8,  # probability of including singing for cross augmented examples. if None, use base probability.
                "no_instr_overlap": True,
                "no_drum_overlap": True,
                "uhat_intra_stem_augment": True,
            },
            train_pitch_shift_range: Optional[List[int]] = None,
            audio_cfg: Optional[Dict] = None) -> None:
        """
        Args:
            data_home: Root data directory. Defaults to `shared_cfg["PATH"]["data_home"]`.
            data_preset_multi: Multi-preset config. `"presets"` lists keys of
                `data_preset_single_cfg`. Optional `"val_max_num_files"` and
                `"test_max_num_files"` are forwarded to the eval dataloaders as
                `max_num_files`.
            task_manager: Task manager forwarded to every train and eval dataloader.
            train_num_samples_per_epoch: Total training samples per epoch. It is divided by
                `shared_cfg["BSZ"]["train_local"]` before being given to each sampler;
                None (or 0) passes None to the samplers.
            train_random_amp_range: Forwarded to the train loaders as `random_amp_range`.
            train_stem_iaug_prob: Forwarded to the train loaders as `stem_iaug_prob`.
            train_stem_xaug_policy: Forwarded to the train loaders as `stem_xaug_policy`.
            train_pitch_shift_range: Forwarded to the train loaders as `pitch_shift_range`.
            audio_cfg: Audio config forwarded to all loaders. Defaults to
                `config.config.audio_cfg`.

        Raises:
            ValueError: If `data_home` does not exist or a preset name is not a key of
                `data_preset_single_cfg`.
            AssertionError: If `shared_cfg["BSZ"]["train_local"]` is not divisible by
                `shared_cfg["BSZ"]["train_sub"]`.

        NOTE: The dict, list and TaskManager defaults are evaluated once, when the class is
        defined, and are shared by every instance that relies on them; do not mutate them
        in place.
        """
        super().__init__()

        # check path existence
        if data_home is None:
            data_home = shared_cfg["PATH"]["data_home"]
        if os.path.exists(data_home):
            self.data_home = data_home
        else:
            raise ValueError(f"Invalid data_home: {data_home}")

        # Resolve each preset name to its single-preset config,
        # e.g. [{"dataset_name": ..., "train_split": ..., "validation_split": ...}, {...}]
        self.preset_multi = data_preset_multi
        self.preset_singles = []
        for dp in self.preset_multi["presets"]:
            if dp not in data_preset_single_cfg:
                raise ValueError(f"Invalid data_preset: {dp}")
            self.preset_singles.append(data_preset_single_cfg[dp])

        # task manager
        self.task_manager = task_manager

        # train num samples per epoch, passed to the sampler
        self.train_num_samples_per_epoch = train_num_samples_per_epoch
        # One weighted random sampler (and train dataloader) per sub-batch in a local batch.
        assert shared_cfg["BSZ"]["train_local"] % shared_cfg["BSZ"]["train_sub"] == 0, (
            "shared_cfg['BSZ']['train_local'] must be divisible by shared_cfg['BSZ']['train_sub']")
        self.num_train_samplers = shared_cfg["BSZ"]["train_local"] // shared_cfg["BSZ"]["train_sub"]

        # train augmentation parameters
        self.train_random_amp_range = train_random_amp_range
        self.train_stem_iaug_prob = train_stem_iaug_prob
        self.train_stem_xaug_policy = train_stem_xaug_policy
        self.train_pitch_shift_range = train_pitch_shift_range

        # train data info
        self.train_data_info = None  # to be set in setup()

        # validation/test max num of files
        self.val_max_num_files = data_preset_multi.get("val_max_num_files", None)
        self.test_max_num_files = data_preset_multi.get("test_max_num_files", None)

        # audio config
        self.audio_cfg = audio_cfg if audio_cfg is not None else default_audio_cfg

    def set_merged_train_data_info(self) -> None:
        """Collect train datasets and store the merged info in `self.train_data_info`.

        The result comes from `create_merged_train_dataset_info(self.preset_multi)`; the
        keys read by this class are "n_datasets", "merged_file_list", "dataset_weights"
        and "index_ranges". The expected structure is:

        self.train_data_info = {
            "n_datasets": 0,
            "n_notes_per_dataset": [],
            "n_files_per_dataset": [],
            "dataset_names": [],  # dataset names by order of merging file lists
            "train_split_names": [],  # train split names by order of merging file lists
            "index_ranges": [],  # index ranges of each dataset in the merged file list
            "dataset_weights": [],  # pre-defined list of dataset weights for sampling, if available
            "merged_file_list": {},
        }
        """
        self.train_data_info = create_merged_train_dataset_info(self.preset_multi)
        print(
            f"AMTDataModule: Added {len(self.train_data_info['merged_file_list'])} files from {self.train_data_info['n_datasets']} datasets to the training set."
        )

    def setup(self, stage: str):
        """Prepare dataloader kwargs for the given stage.

        `stage` is passed automatically by the Lightning Trainer:
          - "fit": merges the train datasets and builds train and validation kwargs.
          - "validate": builds validation kwargs only, so `val_dataloader()` also works
            when validation is run on its own.
          - "test": builds test kwargs.
        Any other stage is a no-op.
        """
        if stage == "fit":
            # Set up train data info, then one train-loader kwargs dict per sampler.
            self.set_merged_train_data_info()
            self.train_data_args = self._build_train_data_args()

        if stage in ("fit", "validate"):
            self.val_data_args = self._build_eval_data_args("validation_split", self.val_max_num_files)

        if stage == "test":
            self.test_data_args = self._build_eval_data_args("test_split", self.test_max_num_files)

    def _build_train_data_args(self) -> List[Dict[str, Any]]:
        """Create distributed weighted random samplers and return one train-loader kwargs dict per sampler."""
        # The per-epoch budget is split by the local batch size before it reaches each sampler.
        if self.train_num_samples_per_epoch:
            num_samples_per_sampler = self.train_num_samples_per_epoch // shared_cfg["BSZ"]["train_local"]
        else:
            num_samples_per_sampler = None

        samplers = get_list_of_weighted_random_samplers(num_samplers=self.num_train_samplers,
                                                        dataset_weights=self.train_data_info["dataset_weights"],
                                                        dataset_index_ranges=self.train_data_info["index_ranges"],
                                                        num_samples_per_epoch=num_samples_per_sampler)

        return [{
            "dataset_name": None,
            "split": None,
            "file_list": self.train_data_info["merged_file_list"],
            "sub_batch_size": shared_cfg["BSZ"]["train_sub"],
            "task_manager": self.task_manager,
            "random_amp_range": self.train_random_amp_range,
            "stem_iaug_prob": self.train_stem_iaug_prob,
            "stem_xaug_policy": self.train_stem_xaug_policy,
            "pitch_shift_range": self.train_pitch_shift_range,
            "shuffle": True,
            "sampler": sampler,
            "audio_cfg": self.audio_cfg,
        } for sampler in samplers]

    def _build_eval_data_args(self, split_key: str, max_num_files: Optional[int]) -> List[Dict[str, Any]]:
        """Return eval-loader kwargs for every single preset whose `split_key` entry is not None."""
        return [{
            "dataset_name": preset_single["dataset_name"],
            "split": preset_single[split_key],
            "task_manager": self.task_manager,
            "max_num_files": max_num_files,
            "audio_cfg": self.audio_cfg,
        } for preset_single in self.preset_singles if preset_single[split_key] is not None]

    @staticmethod
    def _make_eval_loaders(data_args: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Create one eval dataloader per kwargs dict, keyed by dataset name.

        NOTE(review): two presets sharing a dataset_name would overwrite each other here,
        and the positions in `data_args` used by `get_*_dataset_name` would no longer line
        up with the returned loaders.
        """
        return {
            args_dict["dataset_name"]: get_eval_dataloader(**args_dict, dataloader_config=shared_cfg["DATAIO"])
            for args_dict in data_args
        }

    def train_dataloader(self) -> Any:
        """Return all train dataloaders combined into a single `CombinedLoader`."""
        loaders = {
            f"data_loader_{i}": get_cache_data_loader(**args_dict, dataloader_config=shared_cfg["DATAIO"])
            for i, args_dict in enumerate(self.train_data_args)
        }
        # "min_size" stops at the shortest loader; the samplers are all built with the
        # same num_samples_per_epoch, so the loaders are expected to have identical sizes.
        return CombinedLoader(loaders, mode="min_size")

    def val_dataloader(self) -> Any:
        """Return a dict of validation dataloaders keyed by dataset name."""
        return self._make_eval_loaders(self.val_data_args)

    def test_dataloader(self) -> Any:
        """Return a dict of test dataloaders keyed by dataset name."""
        return self._make_eval_loaders(self.test_data_args)

    # CombinedLoader in "sequential" mode passes dataloader_idx to the trainer, which is
    # used with the helpers below to look up the dataset name for logging.

    @property
    def num_val_dataloaders(self) -> int:
        return len(self.val_data_args)

    @property
    def num_test_dataloaders(self) -> int:
        return len(self.test_data_args)

    def get_val_dataset_name(self, dataloader_idx: int) -> str:
        return self.val_data_args[dataloader_idx]["dataset_name"]

    def get_test_dataset_name(self, dataloader_idx: int) -> str:
        return self.test_data_args[dataloader_idx]["dataset_name"]