# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""preprocess_rnsynth.py

RNSynth: Randomly generated note sequences using the NSynth dataset.
"""
import os
import random
import glob
import json
import logging
import numpy as np

from typing import Dict, Literal, Optional
from utils.note_event_dataclasses import Note
from utils.audio import get_audio_file_info, load_audio_file, write_wav_file, guess_onset_offset_by_amp_envelope
from utils.midi import note_event2midi
from utils.note2event import note2note_event, sort_notes, validate_notes, trim_overlapping_notes, mix_notes

# yapf: disable
QUALITY_VOCAB = [
    'bright', 'dark', 'distortion', 'fast_decay', 'long_release', 'multiphonic', 'nonlinear_env',
    'percussive', 'reverb', 'tempo-synced'
]

INSTRUMENT_FAMILY_VOCAB = [
    'bass', 'brass', 'flute', 'guitar', 'keyboard', 'mallet', 'organ', 'reed', 'string', 'vocal',
    'synth_lead'
]

INSTRUMENT_SOURCE_VOCAB = ['acoustic', 'electronic', 'synthetic']

INSTRUMENT_MAPPING = {
    # key: (instrument_family, instrument_source)
    ('bass', 'acoustic'): {'program': 32, 'channel': 0, 'allow_poly': False},
    ('bass', 'electronic'): {'program': 33, 'channel': 0, 'allow_poly': False},
    ('bass', 'synthetic'): {'program': 38, 'channel': 0, 'allow_poly': False},
    ('brass', 'acoustic'): {'program': 61, 'channel': 1, 'allow_poly': True},
    ('brass', 'electronic'): {'program': 62, 'channel': 1, 'allow_poly': True},
    ('brass', 'synthetic'): {'program': 62, 'channel': 1, 'allow_poly': True},
    ('flute', 'acoustic'): {'program': 73, 'channel': 2, 'allow_poly': False},
    ('flute', 'electronic'): {'program': 76, 'channel': 2, 'allow_poly': False},
    ('flute', 'synthetic'): {'program': 76, 'channel': 2, 'allow_poly': False},
    ('guitar', 'acoustic'): {'program': 24, 'channel': 3, 'allow_poly': True},
    ('guitar', 'electronic'): {'program': 27, 'channel': 3, 'allow_poly': True},
    ('guitar', 'synthetic'): {'program': 27, 'channel': 3, 'allow_poly': True},
    ('keyboard', 'acoustic'): {'program': 0, 'channel': 4, 'allow_poly': True},
    ('keyboard', 'electronic'): {'program': 4, 'channel': 4, 'allow_poly': True},
    ('keyboard', 'synthetic'): {'program': 80, 'channel': 4, 'allow_poly': True},
    ('mallet', 'acoustic'): {'program': 12, 'channel': 5, 'allow_poly': True},
    ('mallet', 'electronic'): {'program': 12, 'channel': 5, 'allow_poly': True},
    ('mallet', 'synthetic'): {'program': 12, 'channel': 5, 'allow_poly': True},
    ('organ', 'acoustic'): {'program': 16, 'channel': 6, 'allow_poly': True},
    ('organ', 'electronic'): {'program': 18, 'channel': 6, 'allow_poly': True},
    ('organ', 'synthetic'): {'program': 18, 'channel': 6, 'allow_poly': True},
    ('reed', 'acoustic'): {'program': 65, 'channel': 7, 'allow_poly': True},
    ('reed', 'electronic'): {'program': 83, 'channel': 7, 'allow_poly': True},
    ('reed', 'synthetic'): {'program': 83, 'channel': 7, 'allow_poly': True},
    ('string', 'acoustic'): {'program': 48, 'channel': 8, 'allow_poly': True},
    ('string', 'electronic'): {'program': 50, 'channel': 8, 'allow_poly': True},
    ('string', 'synthetic'): {'program': 50, 'channel': 8, 'allow_poly': True},
    # ('vocal', 'acoustic'): [56],
    # ('vocal', 'electronic'): [56],
    # ('vocal', 'synthetic'): [56],
    ('synth_lead', 'acoustic'): {'program': 80, 'channel': 9, 'allow_poly': True},
    ('synth_lead', 'electronic'): {'program': 80, 'channel': 9, 'allow_poly': True},
    ('synth_lead', 'synthetic'): {'program': 80, 'channel': 9, 'allow_poly': True},
}

CHANNEL_INFO = {
    0: {'name': 'bass', 'max_poly': 1},
    1: {'name': 'brass', 'max_poly': 4},
    2: {'name': 'flute', 'max_poly': 1},
    3: {'name': 'guitar', 'max_poly': 6},
    4: {'name': 'keyboard', 'max_poly': 8},
    5: {'name': 'mallet', 'max_poly': 4},
    6: {'name': 'organ', 'max_poly': 8},
    7: {'name': 'reed', 'max_poly': 2},
    8: {'name': 'string', 'max_poly': 4},
    9: {'name': 'synth_lead', 'max_poly': 2},
}
# yapf: enable
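
# A minimal illustrative sketch (not used by the pipeline below): an
# (instrument_family, instrument_source) pair from the NSynth metadata resolves to a
# GM program number and a channel group via INSTRUMENT_MAPPING, and CHANNEL_INFO
# bounds the polyphony of that channel group. The default arguments are arbitrary
# example values.
def _example_instrument_lookup(instrument_family: str = 'guitar',
                               instrument_source: str = 'acoustic') -> Optional[Dict]:
    mapping = INSTRUMENT_MAPPING.get((instrument_family, instrument_source))
    if mapping is None:
        return None  # e.g. the 'vocal' family is currently excluded from the mapping
    ch = mapping['channel']
    return {
        'program': mapping['program'],  # 24 (acoustic guitar) for the default arguments
        'channel_group': ch,  # 3
        'max_poly': CHANNEL_INFO[ch]['max_poly'],  # 6
    }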


class RandomNSynthGenerator(object):

    def __init__(self, channel_info: Dict = CHANNEL_INFO):
        self.num_channels = len(channel_info)
        self.channel_info = channel_info
        self.channel_max_poly = [channel_info[ch]['max_poly'] for ch in range(self.num_channels)]
        # channel_space_left[ch]: current number of empty note slots left in channel ch
        self.channel_space_left = [0] * self.num_channels
        for ch in range(self.num_channels):
            self.reset_space_left(ch)

    def reset_space_left(self, ch: int):
        max_poly = self.channel_max_poly[ch]
        if max_poly == 1:
            self.channel_space_left[ch] = 1
        else:
            self.channel_space_left[ch] = np.random.randint(1, max_poly + 1)
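
# An illustrative usage sketch of RandomNSynthGenerator (a hedged example, not called
# elsewhere in this file): reset_space_left() re-draws how many simultaneous notes a
# channel group may still accept, up to its max_poly.
def _example_space_left() -> list:
    gen = RandomNSynthGenerator(CHANNEL_INFO)
    gen.reset_space_left(ch=4)  # keyboard channel group: draws a value in [1, 8]
    return gen.channel_space_left  # e.g. [1, 3, 1, 2, 8, 4, 5, 1, 2, 1]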


def setup_logger(log_file: str) -> logging.Logger:
    logger = logging.getLogger('my_logger')
    logger.setLevel(logging.DEBUG)
    file_handler = logging.FileHandler(log_file)
    formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
    file_handler.setFormatter(formatter)
    if not logger.handlers:
        logger.addHandler(file_handler)
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.DEBUG)
        console_formatter = logging.Formatter('%(levelname)s - %(message)s')
        console_handler.setFormatter(console_formatter)
        logger.addHandler(console_handler)
    return logger


def get_duration_by_detecting_offset(audio_file: os.PathLike,
                                     side_info: Optional[str] = None,
                                     offset_threshold: float = 0.02) -> float:
    fs, n_frames, _ = get_audio_file_info(audio_file)
    x = load_audio_file(audio_file, fs=fs)
    if side_info is not None and ('fast_decay' in side_info or 'percussive' in side_info):
        x = x[:int(fs * 2.0)]  # limit to 2.0 sec for fast-decaying or percussive sounds
    _, offset, _ = guess_onset_offset_by_amp_envelope(
        x, fs=fs, onset_threshold=0., offset_threshold=offset_threshold, frame_size=128)
    offset = min(offset, n_frames)
    dur_sec = np.floor((offset / fs) * 1000) / 1000  # floor to millisecond precision
    return dur_sec


def random_key_cycle(d: Dict):
    keys = list(d.keys())
    while True:
        random.shuffle(keys)
        for i, key in enumerate(keys):
            is_last_element = (i == len(keys) - 1)  # check if it's the last element in the cycle
            yield (d[key], is_last_element)
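
# A small hedged sketch of random_key_cycle(): it yields (value, is_last_in_cycle)
# pairs forever, reshuffling the key order after every full pass over the dict.
# The toy dict below stands in for one channel group's sound_info.
def _example_random_key_cycle() -> list:
    cycle = random_key_cycle({'a': 1, 'b': 2, 'c': 3})
    first_pass = [next(cycle) for _ in range(3)]
    # e.g. [(2, False), (1, False), (3, True)]; the True flag marks the end of a pass
    return first_pass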


def create_sound_info(base_dir: os.PathLike, logger: logging.Logger,
                      split: Literal['train', 'validation', 'test'], metadata_file: os.PathLike):
    """Create a dictionary of sound info from the metadata file."""
    with open(metadata_file, 'r') as f:
        metadata = json.load(f)
    logger.info(f'Loaded {metadata_file}. Number of examples: {len(metadata)}')

    # Create a sound_info dictionary
    sound_info = {}  # key: nsynth_id, value: dictionary of sound info
    count_skipped = 0
    skipped_instrument_family = set()
    for i, (k, v) in enumerate(metadata.items()):
        if i % 5000 == 0:
            print(f'Creating sound info {i} / {len(metadata)}')
        nsynth_id = v['note']
        instrument_family = v['instrument_family_str']
        instrument_source = v['instrument_source_str']
        audio_file = os.path.join(base_dir, split, 'audio', k + '.wav')
        if not os.path.exists(audio_file):
            raise FileNotFoundError(audio_file)
        dur_sec = get_duration_by_detecting_offset(
            audio_file, side_info=v['qualities_str'], offset_threshold=0.001)
        if INSTRUMENT_MAPPING.get((instrument_family, instrument_source), None) is not None:
            sound_info[nsynth_id] = {
                'audio_file': audio_file,
                'program': INSTRUMENT_MAPPING[instrument_family, instrument_source]['program'],
                'pitch': int(v['pitch']),
                'velocity': int(v['velocity']),
                'channel_group': INSTRUMENT_MAPPING[instrument_family, instrument_source]['channel'],
                'dur_sec': dur_sec,
            }
        else:
            count_skipped += 1
            skipped_instrument_family.add(instrument_family)
    logger.info(f'Created sound info. Number of examples: {len(sound_info)}')
    logger.info(f'Number of skipped examples: {count_skipped}, {skipped_instrument_family}')
    del metadata

    # Regroup sound_info by channel_group
    sound_info_by_channel_group = {}  # key: channel_group, value: dict of sound_info per nsynth_id
    num_channel_groups = 10
    for i in range(num_channel_groups):
        sound_info_by_channel_group[i] = {}
    for nsynth_id, info in sound_info.items():
        channel_group = info['channel_group']
        sound_info_by_channel_group[channel_group][nsynth_id] = info
    del sound_info

    channel_group_counts = [
        (CHANNEL_INFO[k]['name'], len(v)) for k, v in sound_info_by_channel_group.items()
    ]
    logger.info('Count of sound_info in each channel_group: {}'.format(channel_group_counts))
    return sound_info_by_channel_group, num_channel_groups
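
# A hedged sketch of how create_sound_info() is wired up downstream (the path below is
# a placeholder, and real NSynth data must exist on disk): each returned channel group
# becomes an infinite shuffled generator via random_key_cycle(), mirroring
# random_nsynth_generator() below.
def _example_sound_generators(base_dir: str = '/path/to/random_nsynth_yourmt3_16k') -> list:
    logger = setup_logger(os.path.join(base_dir, 'example_log.txt'))
    metadata_file = os.path.join(base_dir, 'validation', 'examples.json')
    info_by_group, num_groups = create_sound_info(base_dir, logger, 'validation', metadata_file)
    return [random_key_cycle(info_by_group[ch]) for ch in range(num_groups)]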


def random_nsynth_generator(data_home: os.PathLike,
                            dataset_name: str = 'random_nsynth',
                            generation_minutes_per_file: float = 4.0) -> None:
    """
    Splits:
        'train'
        'validation'
        'test'

    Writes:
        - {dataset_name}_{split}_file_list.json: a dictionary with the following keys:
          {
              index: {
                  'random_nsynth_id': random_nsynth_id,  # = nsynth_id
                  'n_frames': (int),
                  'stem_file': 'path/to/stem.npy',
                  'mix_audio_file': 'path/to/mix.wav',
                  'notes_file': 'path/to/notes.npy',
                  'note_events_file': 'path/to/note_events.npy',
                  'midi_file': 'path/to/midi.mid',  # 120 bpm MIDI converted from note_events
                  'program': List[int],
                  'is_drum': List[int],  # [0] or [1]
              }
          }
    """
    # Directory and file paths
    base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    # Setup logger
    log_file = os.path.join(base_dir, 'sound_generation_log.txt')
    logger = setup_logger(log_file)

    # Load annotation json file as dictionary
    split = 'validation'
    metadata_file = os.path.join(base_dir, split, 'examples.json')

    # Create a sound_info dictionary
    sound_info_by_channel_group, num_channel_groups = create_sound_info(
        base_dir, logger, split, metadata_file)

    # Generate random note sequences
    max_frames_per_file = int(generation_minutes_per_file * 60 * 16000)
    sound_gens = [
        random_key_cycle(sound_info_by_channel_group[key])
        for key in sorted(sound_info_by_channel_group.keys())
    ]

    # Audio generation for one file of `generation_minutes_per_file` minutes
    notes = []
    y = np.zeros((num_channel_groups, max_frames_per_file), dtype=np.float32)  # (C, L)
    bass_channel = 0  # generation runs for one full cycle of the bass channel
    cur_frame = 0
    is_last_element_bass = False
    while cur_frame < max_frames_per_file and not is_last_element_bass:
        for ch in range(num_channel_groups):  # iterate over channel groups
            # x: source audio, y: target audio for each channel
            x_info, is_last_element = next(sound_gens[ch])
            if ch == bass_channel:
                is_last_element_bass = is_last_element

            # info about this channel
            onset_in_frame = cur_frame
            offset_in_frame = cur_frame + int(x_info['dur_sec'] * 16000)
            x = load_audio_file(x_info['audio_file'], fs=16000)
            x = x[:int(x_info['dur_sec'] * 16000)]
            y[ch, :] = 0


def preprocess_random_nsynth_16k(data_home: os.PathLike, dataset_name: str = 'random_nsynth') -> None:
    """
    Splits:
        'train'
        'validation'
        'test'

    Writes:
        - {dataset_name}_{split}_file_list.json: a dictionary with the following keys:
          {
              index: {
                  'random_nsynth_id': random_nsynth_id,  # = nsynth_id
                  'n_frames': (int),
                  'stem_file': 'path/to/stem.npy',
                  'mix_audio_file': 'path/to/mix.wav',
                  'notes_file': 'path/to/notes.npy',
                  'note_events_file': 'path/to/note_events.npy',
                  'midi_file': 'path/to/midi.mid',  # 120 bpm MIDI converted from note_events
                  'program': List[int],
                  'is_drum': List[int],  # [0] or [1]
              }
          }
    """
    # Directory and file paths
    base_dir = os.path.join(data_home, dataset_name + '_yourmt3_16k')
    output_index_dir = os.path.join(data_home, 'yourmt3_indexes')
    os.makedirs(output_index_dir, exist_ok=True)

    # Setup logger
    log_file = os.path.join(base_dir, 'log.txt')
    logger = setup_logger(log_file)

    # Load annotation json file as dictionary
    split = 'validation'
    metadata_file = os.path.join(base_dir, split, 'examples.json')
    with open(metadata_file, 'r') as f:
        metadata = json.load(f)
    logger.info(f'Loaded {metadata_file}. Number of examples: {len(metadata)}')