# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
import json
import os
from typing import Dict, Any, Union, Tuple, Optional

import torch
import numpy as np
from einops import rearrange
from torch.utils.data import DataLoader, Dataset

from utils.audio import load_audio_file, slice_padded_array
from utils.tokenizer import EventTokenizerBase, NoteEventTokenizer
from utils.note2event import slice_multiple_note_events_and_ties_to_bundle
from utils.note_event_dataclasses import Note, NoteEvent, NoteEventListsBundle
from utils.task_manager import TaskManager
from config.config import shared_cfg
from config.config import audio_cfg as default_audio_cfg

UNANNOTATED_PROGRAM = 129
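# `UNANNOTATED_PROGRAM` uses 129 because General MIDI programs span 0-127,
# so the sentinel can never collide with a real program number.
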
class AudioFileDataset(Dataset):
    """
    🎧 AudioFileDataset for validation/test:

    This dataset class is designed to be used ONLY with `batch_size=None`. For
    a single song, `__getitem__` returns the sliced audio segments, the
    unsliced notes, and the sliced note events.

    Args:
        file_list (Union[str, bytes, os.PathLike], optional):
            Path to the file list, e.g. "../../data/yourmt3_indexes/slakh_validation_file_list.json"
        task_manager (TaskManager, optional): TaskManager instance. Defaults to TaskManager().
        fs (int, optional): Sampling rate. Defaults to 16000.
        seg_len_frame (int, optional): Segment length in frames. Defaults to 32767.
        seg_hop_frame (int, optional): Segment hop in frames. Defaults to 32767.
        max_num_files (int, optional): Maximum number of files to load. Defaults to None.

    Variables:
        file_list:
            '{dataset_name}_{split}_file_list.json' has the following structure:
            {
                'index': {
                    'mtrack_id': mtrack_id,
                    'n_frames': number of audio frames,
                    'stem_file': dict of stem audio file info,
                    'mix_audio_file': mtrack.mix_path,
                    'notes_file': available only for 'validation' and 'test',
                    'note_events_file': available only for 'train' and 'validation',
                    'midi_file': mtrack.midi_path,
                }
            }

    __getitem__(index) returns:
        audio_segments:
            torch.FloatTensor: (n_segments, 1, seg_len_frame)
        notes_dict:
            {
                'mtrack_id': int,
                'program': List[int],
                'is_drum': bool,
                'duration_sec': float,
                'notes': List[Note],
            }
        token_array:
            torch.LongTensor: (n_segments, decoding_ch, max_token_len)
    """
    def __init__(
            self,
            file_list: Union[str, bytes, os.PathLike],
            task_manager: TaskManager = TaskManager(),
            # tokenizer: Optional[EventTokenizerBase] = None,
            fs: int = 16000,
            seg_len_frame: int = 32767,
            seg_hop_frame: int = 32767,
            max_num_files: Optional[int] = None) -> None:

        # load the file list
        with open(file_list, 'r') as f:
            fl = json.load(f)
        file_list = {int(key): value for key, value in fl.items()}

        if max_num_files:  # reduce the number of files
            self.file_list = dict(list(file_list.items())[:max_num_files])
        else:
            self.file_list = file_list

        self.fs = fs
        self.seg_len_frame = seg_len_frame
        self.seg_len_sec = seg_len_frame / fs
        self.seg_hop_frame = seg_hop_frame
        self.task_manager = task_manager
    def __getitem__(self, index: int) -> Tuple[torch.Tensor, Dict, torch.Tensor]:
        # get metadata
        metadata = self.file_list[index]
        audio_file = metadata['mix_audio_file']
        notes_file = metadata['notes_file']
        note_events_file = metadata['note_events_file']

        # load the audio and normalize int16 samples to float32 in [-1, 1)
        audio = load_audio_file(audio_file, dtype=np.int16)
        audio = audio / 2**15
        audio = audio.astype(np.float32)
        audio = audio.reshape(1, -1)
        audio_segments = slice_padded_array(
            audio,
            self.seg_len_frame,
            self.seg_hop_frame,
            pad=True,
        )  # (n_segs, seg_len_frame)
        audio_segments = rearrange(audio_segments, 'n t -> n 1 t').astype(np.float32)
        num_segs = audio_segments.shape[0]
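        # Illustrative numbers only: with the defaults
        # (seg_len_frame = seg_hop_frame = 32767, fs = 16000), a 60 s mix
        # (960000 frames) yields ceil(960000 / 32767) = 30 padded segments,
        # so audio_segments has shape (30, 1, 32767).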
        # load all notes from a file (of a single song)
        notes_dict = np.load(notes_file, allow_pickle=True, fix_imports=False).tolist()

        # TODO: add midi_file path in preprocessing instead of here
        notes_dict['midi_file'] = metadata['midi_file']

        # tokenize note_events
        note_events_dict = np.load(note_events_file, allow_pickle=True, fix_imports=False).tolist()

        if self.task_manager.tokenizer is not None:
            # not using seg_len_sec to avoid accumulated rounding errors
            start_times = [i * self.seg_hop_frame / self.fs for i in range(num_segs)]
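            # e.g. with seg_hop_frame=32767 and fs=16000, segment i starts at
            # exactly i * 2.0479375 s; repeatedly adding a rounded seg_len_sec
            # instead would let floating-point error accumulate across segments.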
            note_event_segments = slice_multiple_note_events_and_ties_to_bundle(
                note_events_dict['note_events'],
                start_times,
                self.seg_len_sec,
            )

            # Support for multi-channel decoding
            if UNANNOTATED_PROGRAM in notes_dict['program']:
                has_unannotated_segments = [True] * num_segs
            else:
                has_unannotated_segments = [False] * num_segs

            token_array = self.task_manager.tokenize_note_events_batch(note_event_segments,
                                                                       start_time_to_zero=False,
                                                                       sort=True)
            # note_token_array = self.task_manager.tokenize_note_events_batch(note_event_segments,
            #                                                                 start_time_to_zero=False,
            #                                                                 sort=True)
            # task_token_array = self.task_manager.tokenize_task_events_batch(note_event_segments,
            #                                                                 has_unannotated_segments)

            # Shapes:
            #   audio_segments: (num_segs, 1, seg_len_frame)
            #   notes_dict: Dict
            #   note_token_array: (num_segs, decoding_ch, max_note_token_len)
            #   task_token_array: (num_segs, decoding_ch, max_task_token_len)
            # return torch.from_numpy(audio_segments), notes_dict, torch.from_numpy(
            #     note_token_array).long(), torch.from_numpy(task_token_array).long()
            return torch.from_numpy(audio_segments), notes_dict, torch.from_numpy(token_array).long()
            # # Tokenize/pad note_event_segments -> array of tokens and masks
            # max_len = self.tokenizer.max_length
            # token_array = np.zeros((num_segs, max_len), dtype=np.int32)
            # for i, tup in enumerate(list(zip(*note_event_segments.values()))):
            #     padded_tokens = self.tokenizer.encode_plus(*tup)
            #     token_array[i, :] = padded_tokens
            # return torch.from_numpy(audio_segments), notes_dict, torch.from_numpy(token_array).long()
    def __len__(self) -> int:
        return len(self.file_list)
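
# A minimal sketch of direct dataset access (the index path below follows the
# docstring example; any '{dataset_name}_{split}_file_list.json' works):
#
#   ds = AudioFileDataset("../../data/yourmt3_indexes/slakh_validation_file_list.json")
#   audio_segments, notes_dict, token_array = ds[0]
#   # audio_segments: (n_segs, 1, 32767) float32
#   # token_array: (n_segs, decoding_ch, max_token_len) int64
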
def get_eval_dataloader(
    dataset_name: str,
    split: str = 'validation',
    dataloader_config: Dict = {"num_workers": 0},
    task_manager: TaskManager = TaskManager(),
    # tokenizer: Optional[EventTokenizerBase] = NoteEventTokenizer('mt3'),
    max_num_files: Optional[int] = None,
    audio_cfg: Optional[Dict] = None,
) -> DataLoader:
    """
    🎧 get_eval_dataloader:

    Returns a DataLoader over AudioFileDataset. With `batch_size=None`, each
    item is one song: its padded audio slices, notes dict, and token array.
    """
    data_home = shared_cfg["PATH"]["data_home"]
    file_list = f"{data_home}/yourmt3_indexes/{dataset_name}_{split}_file_list.json"

    if audio_cfg is None:
        audio_cfg = default_audio_cfg

    ds = AudioFileDataset(
        file_list,
        task_manager=task_manager,
        # tokenizer=tokenizer,
        seg_len_frame=int(audio_cfg["input_frames"]),  # Default: 32767
        seg_hop_frame=int(audio_cfg["input_frames"]),  # Default: 32767
        max_num_files=max_num_files)

    dl = DataLoader(ds, batch_size=None, collate_fn=lambda k: k, **dataloader_config)
    return dl
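

if __name__ == "__main__":
    # A minimal smoke-test sketch, not part of the library API. It assumes a
    # 'slakh' validation index exists under shared_cfg["PATH"]["data_home"];
    # substitute any dataset you have indexed.
    dl = get_eval_dataloader("slakh", split="validation", max_num_files=1)
    for audio_segments, notes_dict, token_array in dl:
        print(audio_segments.shape)  # (n_segs, 1, input_frames)
        print(notes_dict["mtrack_id"], len(notes_dict["notes"]), "notes")
        print(token_array.shape)  # (n_segs, decoding_ch, max_token_len)
        break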