# Spaces:
# Running
# Running
import os | |
import random | |
from copy import deepcopy | |
import pandas as pd | |
import logging | |
from tqdm import tqdm | |
import json | |
import glob | |
import re | |
from resemblyzer import VoiceEncoder | |
import traceback | |
import numpy as np | |
import pretty_midi | |
import librosa | |
from scipy.interpolate import interp1d | |
import torch | |
from textgrid import TextGrid | |
from utils.hparams import hparams | |
from data_gen.tts.data_gen_utils import build_phone_encoder, get_pitch | |
from utils.pitch_utils import f0_to_coarse | |
from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError | |
from data_gen.tts.binarizer_zh import ZhBinarizer | |
from data_gen.tts.txt_processors.zh_g2pM import ALL_YUNMU | |
from vocoders.base_vocoder import VOCODERS | |
class SingingBinarizer(BaseBinarizer):
    """Binarizer for singing data laid out as ``<data_dir>/<song_dir>/<piece>_wf0.wav``.

    Every ``<piece>_wf0.wav`` is paired with sibling ``<piece>.txt`` (text),
    ``<piece>_ph.txt`` (space-separated phonemes) and ``<piece>.TextGrid``
    files. The validation split and the test split are the same item list.
    """

    def __init__(self, processed_data_dir=None):
        """Create the empty per-item lookup tables.

        Args:
            processed_data_dir: Comma-separated list of data directories;
                defaults to ``hparams['processed_data_dir']``.
        """
        if processed_data_dir is None:
            processed_data_dir = hparams['processed_data_dir']
        self.processed_data_dirs = processed_data_dir.split(",")
        self.binarization_args = hparams['binarization_args']
        self.pre_align_args = hparams['pre_align_args']
        self.item2txt = {}
        self.item2ph = {}
        self.item2wavfn = {}
        self.item2f0fn = {}
        self.item2tgfn = {}
        self.item2spk = {}

    def split_train_test_set(self, item_names):
        """Return ``(train_item_names, test_item_names)``.

        An item is a test item when any entry of ``hparams['test_prefixes']``
        occurs anywhere in its name (substring match). Input order is kept in
        both lists.
        """
        item_names = deepcopy(item_names)
        test_prefixes = hparams['test_prefixes']
        test_item_names = [x for x in item_names if any(ts in x for ts in test_prefixes)]
        # Build the membership set once; rebuilding it per item was O(n^2).
        test_item_set = set(test_item_names)
        train_item_names = [x for x in item_names if x not in test_item_set]
        logging.info("train {}".format(len(train_item_names)))
        logging.info("test {}".format(len(test_item_names)))
        return train_item_names, test_item_names

    def load_meta_data(self):
        """Scan every data dir, fill the ``item2*`` tables and split the items.

        An item name is the piece path relative to its data dir, with ``/``
        replaced by ``-`` and the ``_wf0.wav`` suffix stripped. With more than
        one data dir, item and speaker names get a ``ds{index}_`` prefix.
        """
        wav_suffix = '_wf0.wav'
        txt_suffix = '.txt'
        ph_suffix = '_ph.txt'
        tg_suffix = '.TextGrid'
        multi_dataset = len(self.processed_data_dirs) > 1
        for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
            all_wav_pieces = glob.glob(f'{processed_data_dir}/*/*{wav_suffix}')
            for piece_path in all_wav_pieces:
                item_name = piece_path[len(processed_data_dir) + 1:].replace('/', '-')[:-len(wav_suffix)]
                if multi_dataset:
                    item_name = f'ds{ds_id}_{item_name}'
                # Context managers close the sidecar files right away instead
                # of leaking a handle per piece.
                with open(piece_path.replace(wav_suffix, txt_suffix)) as f:
                    self.item2txt[item_name] = f.readline()
                with open(piece_path.replace(wav_suffix, ph_suffix)) as f:
                    self.item2ph[item_name] = f.readline()
                self.item2wavfn[item_name] = piece_path
                # Speaker = song directory name up to the first '-' or '#'.
                spk = re.split('-|#', piece_path.split('/')[-2])[0]
                if multi_dataset:
                    spk = f"ds{ds_id}_{spk}"
                self.item2spk[item_name] = spk
                self.item2tgfn[item_name] = piece_path.replace(wav_suffix, tg_suffix)
        print('spkers: ', set(self.item2spk.values()))
        self.item_names = sorted(self.item2txt)
        if self.binarization_args['shuffle']:
            random.seed(1234)
            random.shuffle(self.item_names)
        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)

    # NOTE(review): the split accessors are properties; presumably
    # BaseBinarizer reads them as attributes (e.g. iterating
    # ``self.train_item_names``). A plain method here would hand it a bound
    # method instead of a list - confirm against base_binarizer.
    @property
    def train_item_names(self):
        return self._train_item_names

    @property
    def valid_item_names(self):
        # Validation reuses the test split.
        return self._test_item_names

    @property
    def test_item_names(self):
        return self._test_item_names

    def process(self):
        """Load metadata, write ``spk_map.json``, build the phone encoder and
        binarize the ``valid``, ``test`` and ``train`` splits in that order."""
        self.load_meta_data()
        os.makedirs(hparams['binary_data_dir'], exist_ok=True)
        self.spk_map = self.build_spk_map()
        print("| spk_map: ", self.spk_map)
        spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json"
        with open(spk_map_fn, 'w') as f:
            json.dump(self.spk_map, f)
        self.phone_encoder = self._phone_encoder()
        self.process_data('valid')
        self.process_data('test')
        self.process_data('train')

    def _phone_encoder(self):
        """Ensure ``phone_set.json`` exists in ``binary_data_dir`` and build the encoder.

        The set is rebuilt from ``item2ph`` when ``hparams['reset_phone_dict']``
        is true or the file is missing; otherwise it is loaded and printed.
        """
        ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json"
        if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn):
            ph_set = sorted({ph for ph_sent in self.item2ph.values() for ph in ph_sent.split(' ')})
            # Closing the file guarantees it is flushed before
            # build_phone_encoder reads it back below.
            with open(ph_set_fn, 'w') as f:
                json.dump(ph_set, f)
            print("| Build phone set: ", ph_set)
        else:
            with open(ph_set_fn, 'r') as f:
                ph_set = json.load(f)
            print("| Load phone set: ", ph_set)
        return build_phone_encoder(hparams['binary_data_dir'])

    @classmethod
    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
        """Extract features for one item.

        Pitch is extracted from the waveform by the inherited ``get_pitch``.
        An older variant read precomputed ``_f0.npy`` files instead.

        Returns:
            A dict with the item metadata plus ``mel``/``wav``. It also holds
            ``phone`` (when ``with_txt``) and whatever the inherited
            ``get_pitch`` (``with_f0``) and ``get_align`` (``with_align``)
            add. Returns ``None`` if a ``BinarizationError`` makes the item
            unusable.
        """
        vocoder_name = hparams['vocoder']
        if vocoder_name not in VOCODERS:
            # Fall back to the last component of a dotted vocoder name.
            vocoder_name = vocoder_name.split('.')[-1]
        wav, mel = VOCODERS[vocoder_name].wav2spec(wav_fn)
        res = {
            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
        }
        try:
            if binarization_args['with_f0']:
                cls.get_pitch(wav, mel, res)
            if binarization_args['with_txt']:
                try:
                    phone_encoded = res['phone'] = encoder.encode(ph)
                except Exception as err:
                    traceback.print_exc()
                    raise BinarizationError("Empty phoneme") from err
                # Alignment needs the encoded phones, so it only runs with text.
                if binarization_args['with_align']:
                    cls.get_align(tg_fn, ph, mel, phone_encoded, res)
        except BinarizationError as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
            return None
        return res
class MidiSingingBinarizer(SingingBinarizer):
    """Singing binarizer that also attaches MIDI note information per item.

    Metadata comes from a ``meta.json`` (a list of dicts) in each data dir.
    The MIDI/duration/word-boundary tables are class attributes so the static
    ``get_pitch``/``get_align`` helpers can reach them without an instance.
    """
    item2midi = {}
    item2midi_dur = {}
    item2is_slur = {}
    item2ph_durs = {}
    item2wdb = {}

    def load_meta_data(self):
        """Read each data dir's ``meta.json``, fill the lookup tables and split the items.

        With more than one data dir, item and speaker names get a
        ``ds{index}_`` prefix. All items use speaker ``'pop-cs'``.
        """
        # Phones flagged 1 in the word-boundary list. Built once as a set
        # instead of concatenating the list again for every phone.
        word_boundary_phs = set(ALL_YUNMU + ['AP', 'SP', '<SIL>'])
        multi_dataset = len(self.processed_data_dirs) > 1
        for ds_id, processed_data_dir in enumerate(self.processed_data_dirs):
            with open(os.path.join(processed_data_dir, 'meta.json')) as f:
                meta_midi = json.load(f)  # [list of dict]
            for song_item in meta_midi:
                item_name = song_item['item_name']
                if multi_dataset:
                    item_name = f'ds{ds_id}_{item_name}'
                phs = song_item['phs']
                self.item2wavfn[item_name] = song_item['wav_fn']
                self.item2txt[item_name] = song_item['txt']
                self.item2ph[item_name] = ' '.join(phs)
                self.item2wdb[item_name] = [1 if x in word_boundary_phs else 0 for x in phs]
                self.item2ph_durs[item_name] = song_item['ph_dur']
                self.item2midi[item_name] = song_item['notes']
                self.item2midi_dur[item_name] = song_item['notes_dur']
                self.item2is_slur[item_name] = song_item['is_slur']
                spk = 'pop-cs'
                if multi_dataset:
                    spk = f"ds{ds_id}_{spk}"
                self.item2spk[item_name] = spk
        print('spkers: ', set(self.item2spk.values()))
        self.item_names = sorted(self.item2txt)
        if self.binarization_args['shuffle']:
            random.seed(1234)
            random.shuffle(self.item_names)
        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)

    @staticmethod
    def get_pitch(wav_fn, wav, spec, ph, res):
        """Attach MIDI annotations and waveform-derived f0 to ``res``.

        Adds ``pitch_midi``, ``midi_dur``, ``is_slur`` and ``word_boundary``
        (as arrays) plus ``f0``/``pitch``.

        Raises:
            BinarizationError: if the extracted f0 sums to zero.
        """
        # Lookup key = last two path components of wav_fn, without the
        # extension or the '_wf0' tag.
        # NOTE(review): load_meta_data prefixes keys with 'ds{i}_' when several
        # data dirs are used, so this lookup would miss in that case - confirm.
        item_name = '/'.join(os.path.splitext(wav_fn)[0].split('/')[-2:]).replace('_wf0', '')
        res['pitch_midi'] = np.asarray(MidiSingingBinarizer.item2midi[item_name])
        res['midi_dur'] = np.asarray(MidiSingingBinarizer.item2midi_dur[item_name])
        res['is_slur'] = np.asarray(MidiSingingBinarizer.item2is_slur[item_name])
        res['word_boundary'] = np.asarray(MidiSingingBinarizer.item2wdb[item_name])
        assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (
            res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)
        # gt f0: this calls the module-level get_pitch imported from
        # data_gen_utils, not this static method.
        gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
        if sum(gt_f0) == 0:
            raise BinarizationError("Empty **gt** f0")
        res['f0'] = gt_f0
        res['pitch'] = gt_pitch_coarse

    @staticmethod
    def get_align(ph_durs, mel, phone_encoded, res, hop_size=None, audio_sample_rate=None):
        """Build ``res['mel2ph']``, mapping each mel frame to a 1-based phone index.

        ``ph_durs`` are consecutive phone durations in seconds. Frame bounds
        are ``round(t * audio_sample_rate / hop_size)``. Frames past the last
        phone stay 0, and spans beyond ``mel.shape[0]`` are clipped.
        ``phone_encoded`` is unused; it is kept for interface compatibility.

        ``hop_size``/``audio_sample_rate`` default to the current ``hparams``
        values at call time. Previously they were evaluated when the class was
        defined, which fails if hparams is not yet populated and ignores later
        config changes.
        """
        if hop_size is None:
            hop_size = hparams['hop_size']
        if audio_sample_rate is None:
            audio_sample_rate = hparams['audio_sample_rate']
        mel2ph = np.zeros([mel.shape[0]], int)
        start_time = 0
        for i_ph, dur in enumerate(ph_durs):
            start_frame = int(start_time * audio_sample_rate / hop_size + 0.5)
            end_frame = int((start_time + dur) * audio_sample_rate / hop_size + 0.5)
            mel2ph[start_frame:end_frame] = i_ph + 1
            start_time = start_time + dur
        res['mel2ph'] = mel2ph

    @classmethod
    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
        """Extract features for one item, including MIDI info and duration-based alignment.

        Returns the feature dict, or ``None`` if a ``BinarizationError`` makes
        the item unusable.
        """
        vocoder_name = hparams['vocoder']
        if vocoder_name not in VOCODERS:
            # Fall back to the last component of a dotted vocoder name.
            vocoder_name = vocoder_name.split('.')[-1]
        wav, mel = VOCODERS[vocoder_name].wav2spec(wav_fn)
        res = {
            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
        }
        try:
            if binarization_args['with_f0']:
                cls.get_pitch(wav_fn, wav, mel, ph, res)
            if binarization_args['with_txt']:
                try:
                    phone_encoded = res['phone'] = encoder.encode(ph)
                except Exception as err:
                    traceback.print_exc()
                    raise BinarizationError("Empty phoneme") from err
                # Alignment needs the encoded phones, so it only runs with text.
                if binarization_args['with_align']:
                    cls.get_align(MidiSingingBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
        except BinarizationError as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
            return None
        return res
class ZhSingingBinarizer(ZhBinarizer, SingingBinarizer):
    """Mix ``ZhBinarizer`` into ``SingingBinarizer``.

    No members of its own; because ``ZhBinarizer`` is listed first, its
    definitions take precedence in the method resolution order.
    """
class M4SingerBinarizer(MidiSingingBinarizer):
    """MIDI singing binarizer for M4Singer-style data.

    Reads ``<raw_data_dir>/meta.json``. Item names have the form
    ``singer#song#sentence`` and audio lives at
    ``<raw_data_dir>/<singer>#<song>/<sentence>.wav``. Lookup tables are
    class attributes (separate from the parent's) for the static helpers.
    """
    item2midi = {}
    item2midi_dur = {}
    item2is_slur = {}
    item2ph_durs = {}
    item2wdb = {}

    def split_train_test_set(self, item_names):
        """Return ``(train_item_names, test_item_names)``.

        An item is a test item when its name *starts with* one of
        ``hparams['test_prefixes']``. Input order is kept in both lists.
        """
        item_names = deepcopy(item_names)
        # str.startswith accepts a tuple and matches if any prefix matches.
        test_prefixes = tuple(hparams['test_prefixes'])
        test_item_names = [x for x in item_names if x.startswith(test_prefixes)]
        # Build the membership set once; rebuilding it per item was O(n^2).
        test_item_set = set(test_item_names)
        train_item_names = [x for x in item_names if x not in test_item_set]
        logging.info("train {}".format(len(train_item_names)))
        logging.info("test {}".format(len(test_item_names)))
        return train_item_names, test_item_names

    def load_meta_data(self):
        """Read ``meta.json`` from ``hparams['raw_data_dir']``, fill the tables and split the items.

        The speaker of each item is the ``singer`` part of its name.
        """
        raw_data_dir = hparams['raw_data_dir']
        with open(os.path.join(raw_data_dir, 'meta.json')) as f:
            song_items = json.load(f)  # [list of dict]
        # Interior phones flagged as word boundaries. Built once as a set
        # instead of concatenating the list again for every phone.
        boundary_phs = set(ALL_YUNMU + ['<SP>', '<AP>'])
        for song_item in song_items:
            item_name = song_item['item_name']
            singer, song_name, sent_id = item_name.split("#")
            phs = song_item['phs']
            self.item2wavfn[item_name] = f'{raw_data_dir}/{singer}#{song_name}/{sent_id}.wav'
            self.item2txt[item_name] = song_item['txt']
            self.item2ph[item_name] = ' '.join(phs)
            self.item2ph_durs[item_name] = song_item['ph_dur']
            self.item2midi[item_name] = song_item['notes']
            self.item2midi_dur[item_name] = song_item['notes_dur']
            self.item2is_slur[item_name] = song_item['is_slur']
            # 1 for the last phone always, and for any interior phone in
            # boundary_phs; the first phone is never flagged unless it is also
            # the last.
            last = len(phs) - 1
            self.item2wdb[item_name] = [
                1 if (0 < i < last and p in boundary_phs) or i == last else 0
                for i, p in enumerate(phs)
            ]
            self.item2spk[item_name] = singer
        print('spkers: ', set(self.item2spk.values()))
        self.item_names = sorted(self.item2txt)
        if self.binarization_args['shuffle']:
            random.seed(1234)
            random.shuffle(self.item_names)
        self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names)

    @staticmethod
    def get_pitch(item_name, wav, spec, ph, res):
        """Attach MIDI annotations and waveform-derived f0 for ``item_name`` to ``res``.

        Adds ``pitch_midi``, ``midi_dur``, ``is_slur`` and ``word_boundary``
        (as arrays) plus ``f0``/``pitch``. An older variant loaded
        precomputed ``*_f0.npy`` files from ``text_f0_align``; f0 is now
        extracted from the waveform.

        Raises:
            BinarizationError: if the extracted f0 sums to zero.
        """
        res['pitch_midi'] = np.asarray(M4SingerBinarizer.item2midi[item_name])
        res['midi_dur'] = np.asarray(M4SingerBinarizer.item2midi_dur[item_name])
        res['is_slur'] = np.asarray(M4SingerBinarizer.item2is_slur[item_name])
        res['word_boundary'] = np.asarray(M4SingerBinarizer.item2wdb[item_name])
        assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (
            res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape)
        # gt f0: this calls the module-level get_pitch imported from
        # data_gen_utils, not this static method.
        gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams)
        if sum(gt_f0) == 0:
            raise BinarizationError("Empty **gt** f0")
        res['f0'] = gt_f0
        res['pitch'] = gt_pitch_coarse

    @classmethod
    def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args):
        """Extract features for one item; ``get_align`` is inherited from the parent class.

        Returns the feature dict, or ``None`` if a ``BinarizationError`` makes
        the item unusable.
        """
        vocoder_name = hparams['vocoder']
        if vocoder_name not in VOCODERS:
            # Fall back to the last component of a dotted vocoder name.
            vocoder_name = vocoder_name.split('.')[-1]
        wav, mel = VOCODERS[vocoder_name].wav2spec(wav_fn)
        res = {
            'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn,
            'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id
        }
        try:
            if binarization_args['with_f0']:
                cls.get_pitch(item_name, wav, mel, ph, res)
            if binarization_args['with_txt']:
                try:
                    phone_encoded = res['phone'] = encoder.encode(ph)
                except Exception as err:
                    traceback.print_exc()
                    raise BinarizationError("Empty phoneme") from err
                # Alignment needs the encoded phones, so it only runs with text.
                if binarization_args['with_align']:
                    cls.get_align(M4SingerBinarizer.item2ph_durs[item_name], mel, phone_encoded, res)
        except BinarizationError as e:
            print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}")
            return None
        return res
if __name__ == "__main__":
    # Script entry point: binarize with the plain (non-MIDI) singing binarizer.
    binarizer = SingingBinarizer()
    binarizer.process()