|
import os |
|
|
|
import torch |
|
import numpy as np |
|
from modules.hifigan.hifigan import HifiGanGenerator |
|
from vocoders.hifigan import HifiGAN |
|
from inference.svs.opencpop.map import cpop_pinyin2ph_func |
|
|
|
from utils import load_ckpt |
|
from utils.hparams import set_hparams, hparams |
|
from utils.text_encoder import TokenTextEncoder |
|
from pypinyin import pinyin, lazy_pinyin, Style |
|
import librosa |
|
import glob |
|
import re |
|
|
|
|
|
class BaseSVSInfer:
    """Base class for singing-voice-synthesis (SVS) inference.

    Implements the pipeline shared by all SVS models: word/phoneme + note
    preprocessing, tensor batching, and HifiGAN vocoding. Subclasses must
    implement :meth:`build_model` and :meth:`forward_model`.
    """

    def __init__(self, hparams, device=None):
        """Build the acoustic model and vocoder.

        :param hparams: hyper-parameter dict (stored on the instance; note
            that several methods read the module-level ``hparams`` instead).
        :param device: torch device string; defaults to CUDA when available.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.hparams = hparams
        self.device = device

        # Opencpop phoneme inventory: pinyin initials/finals plus the special
        # tokens AP and SP (presumably breath / silence -- TODO confirm).
        phone_list = ["AP", "SP", "a", "ai", "an", "ang", "ao", "b", "c", "ch", "d", "e", "ei", "en", "eng", "er", "f", "g",
                      "h", "i", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "iu", "j", "k", "l", "m", "n", "o",
                      "ong", "ou", "p", "q", "r", "s", "sh", "t", "u", "ua", "uai", "uan", "uang", "ui", "un", "uo", "v",
                      "van", "ve", "vn", "w", "x", "y", "z", "zh"]
        self.ph_encoder = TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
        self.pinyin2phs = cpop_pinyin2ph_func()
        # Single-speaker setup: only the opencpop voice is mapped.
        self.spk_map = {'opencpop': 0}

        self.model = self.build_model()
        self.model.eval()
        self.model.to(self.device)
        self.vocoder = self.build_vocoder()
        self.vocoder.eval()
        self.vocoder.to(self.device)

    def build_model(self):
        """Construct and return the acoustic model. Implemented by subclasses."""
        raise NotImplementedError

    def forward_model(self, inp):
        """Run inference on one preprocessed item. Implemented by subclasses."""
        raise NotImplementedError

    def build_vocoder(self):
        """Load the newest HifiGAN generator found under ``hparams['vocoder_ckpt']``.

        NOTE: reads the module-level ``hparams``, not ``self.hparams``.
        :return: HifiGanGenerator in eval mode on ``self.device``, with
            weight norm removed for inference speed.
        """
        base_dir = hparams['vocoder_ckpt']
        config_path = f'{base_dir}/config.yaml'
        # Pick the checkpoint with the highest training step. Use a raw-string
        # pattern matched on the file-name portion only, so a base_dir that
        # contains regex metacharacters (e.g. '.') cannot break or poison the
        # step extraction (the original interpolated base_dir into the regex
        # and used a non-raw '\d' escape).
        ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'),
                      key=lambda x: int(re.findall(r'model_ckpt_steps_(\d+)\.ckpt', x)[0]))[-1]
        print('| load HifiGAN: ', ckpt)
        ckpt_dict = torch.load(ckpt, map_location="cpu")
        config = set_hparams(config_path, global_hparams=False)
        state = ckpt_dict["state_dict"]["model_gen"]
        vocoder = HifiGanGenerator(config)
        vocoder.load_state_dict(state, strict=True)
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to(self.device)
        return vocoder

    def run_vocoder(self, c, **kwargs):
        """Vocode a mel batch into a waveform.

        :param c: mel tensor, presumably [B, T, n_mels] -- transposed to
            [B, n_mels, T] before the vocoder call. TODO confirm layout.
        :param kwargs: may carry ``f0`` for NSF-conditioned vocoding
            (used only when ``hparams['use_nsf']`` is truthy).
        :return: waveform tensor of shape [1, n_samples].
        """
        c = c.transpose(2, 1)
        f0 = kwargs.get('f0')
        if f0 is not None and hparams.get('use_nsf'):
            y = self.vocoder(c, f0).view(-1)
        else:
            y = self.vocoder(c).view(-1)
        return y[None]

    def preprocess_word_level_input(self, inp):
        """Align word-level text with '|'-separated note/duration windows.

        :param inp: dict with 'text', 'notes' and 'notes_duration' strings.
        :return: (ph_seq, note_lst, midi_dur_lst, is_slur) or None when the
            per-word counts of phonemes/notes/durations disagree.
        """
        # Disambiguate the polyphonic character '长' so pypinyin picks the
        # intended reading in these phrases (runtime strings kept verbatim).
        text_raw = inp['text'].replace('最长', '最常').replace('长睫毛', '常睫毛') \
            .replace('那么长', '那么常').replace('多长', '多常') \
            .replace('很长', '很常')

        # One pinyin per character; silently drop anything (punctuation etc.)
        # that has no entry in the opencpop pinyin->phoneme table.
        # (Loop variable renamed: the original shadowed the `pinyin` import.)
        pinyins = lazy_pinyin(text_raw, strict=False)
        ph_per_word_lst = [self.pinyin2phs[py.strip()] for py in pinyins if py.strip() in self.pinyin2phs]

        note_per_word_lst = [x.strip() for x in inp['notes'].split('|') if x.strip() != '']
        mididur_per_word_lst = [x.strip() for x in inp['notes_duration'].split('|') if x.strip() != '']

        if len(note_per_word_lst) == len(ph_per_word_lst) == len(mididur_per_word_lst):
            print('Pass word-notes check.')
        else:
            print('The number of words doesn\'t match the number of notes\' windows. ',
                  'You should split the note(s) for each word by | mark.')
            print(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst)
            print(len(ph_per_word_lst), len(note_per_word_lst), len(mididur_per_word_lst))
            return None

        note_lst = []
        ph_lst = []
        midi_dur_lst = []
        is_slur = []
        for idx, ph_per_word in enumerate(ph_per_word_lst):
            ph_in_this_word = ph_per_word.split()
            note_in_this_word = note_per_word_lst[idx].split()
            midi_dur_in_this_word = mididur_per_word_lst[idx].split()
            # Every phoneme of the word gets the word's first note/duration.
            for ph in ph_in_this_word:
                ph_lst.append(ph)
                note_lst.append(note_in_this_word[0])
                midi_dur_lst.append(midi_dur_in_this_word[0])
                is_slur.append(0)
            # Extra notes on the same word are slurs: repeat the word's last
            # phoneme once per additional note. (Inner index renamed -- the
            # original reused `idx` and shadowed the outer loop variable.)
            if len(note_in_this_word) > 1:
                for n_idx in range(1, len(note_in_this_word)):
                    ph_lst.append(ph_in_this_word[-1])
                    note_lst.append(note_in_this_word[n_idx])
                    midi_dur_lst.append(midi_dur_in_this_word[n_idx])
                    is_slur.append(1)
        ph_seq = ' '.join(ph_lst)

        if len(ph_lst) == len(note_lst) == len(midi_dur_lst):
            print(len(ph_lst), len(note_lst), len(midi_dur_lst))
            print('Pass word-notes check.')
        else:
            print('The number of words doesn\'t match the number of notes\' windows. ',
                  'You should split the note(s) for each word by | mark.')
            return None
        return ph_seq, note_lst, midi_dur_lst, is_slur

    def preprocess_phoneme_level_input(self, inp):
        """Parse pre-aligned phoneme-level input.

        :param inp: dict with whitespace-separated 'ph_seq', 'note_seq',
            'note_dur_seq' and 'is_slur_seq' strings.
        :return: (ph_seq, note_lst, midi_dur_lst, is_slur) or None on a
            length mismatch.
        """
        ph_seq = inp['ph_seq']
        note_lst = inp['note_seq'].split()
        midi_dur_lst = inp['note_dur_seq'].split()
        # NOTE(review): slur flags are parsed as floats here but as ints in
        # the word-level path, and downstream they feed a LongTensor --
        # confirm the float dtype is intended before changing it.
        is_slur = [float(x) for x in inp['is_slur_seq'].split()]
        print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
        if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst):
            print('Pass word-notes check.')
        else:
            print('The number of words doesn\'t match the number of notes\' windows. ',
                  'You should split the note(s) for each word by | mark.')
            return None
        return ph_seq, note_lst, midi_dur_lst, is_slur

    def preprocess_input(self, inp, input_type='word'):
        """Turn raw user input into a model-ready item dict.

        :param inp: {'text': str, 'item_name': (str, optional),
            'spk_name': (str, optional)} plus, depending on ``input_type``,
            the word-level ('notes', 'notes_duration') or phoneme-level
            ('ph_seq', 'note_seq', 'note_dur_seq', 'is_slur_seq') fields.
        :param input_type: 'word' or 'phoneme'.
        :return: item dict with encoded phoneme tokens, MIDI pitches,
            durations and slur flags, or None when preprocessing fails.
        """
        item_name = inp.get('item_name', '<ITEM_NAME>')
        spk_name = inp.get('spk_name', 'opencpop')
        spk_id = self.spk_map[spk_name]

        if input_type == 'word':
            ret = self.preprocess_word_level_input(inp)
        elif input_type == 'phoneme':
            ret = self.preprocess_phoneme_level_input(inp)
        else:
            print('Invalid input type.')
            return None

        if ret:
            ph_seq, note_lst, midi_dur_lst, is_slur = ret
        else:
            print('==========> Preprocess_word_level or phone_level input wrong.')
            return None

        # 'rest' (no pitch) maps to MIDI 0; names like 'C4/Db4' keep only the
        # part before '/'.
        try:
            midis = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0
                     for x in note_lst]
            midi_dur_lst = [float(x) for x in midi_dur_lst]
        except Exception as e:
            print(e)
            print('Invalid Input Type.')
            return None

        ph_token = self.ph_encoder.encode(ph_seq)
        item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_id': spk_id,
                'ph_token': ph_token, 'pitch_midi': np.asarray(midis), 'midi_dur': np.asarray(midi_dur_lst),
                'is_slur': np.asarray(is_slur), }
        item['ph_len'] = len(item['ph_token'])
        return item

    def input_to_batch(self, item):
        """Wrap one preprocessed item into a batch of tensors on ``self.device``."""
        item_names = [item['item_name']]
        text = [item['text']]
        ph = [item['ph']]
        txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
        txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
        # NOTE(review): item['spk_id'] is an int, so torch.LongTensor(int)
        # allocates an *uninitialized* tensor of that length (empty for id 0)
        # instead of a tensor containing the id. Likely harmless here only
        # because the single-speaker setup ignores spk_ids -- confirm;
        # [item['spk_id']] may be what was intended. Kept as-is to preserve
        # behavior.
        spk_ids = torch.LongTensor(item['spk_id'])[None, :].to(self.device)

        # Truncate conditioning sequences to max_frames (module-level hparams).
        pitch_midi = torch.LongTensor(item['pitch_midi'])[None, :hparams['max_frames']].to(self.device)
        midi_dur = torch.FloatTensor(item['midi_dur'])[None, :hparams['max_frames']].to(self.device)
        is_slur = torch.LongTensor(item['is_slur'])[None, :hparams['max_frames']].to(self.device)

        batch = {
            'item_name': item_names,
            'text': text,
            'ph': ph,
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'spk_ids': spk_ids,
            'pitch_midi': pitch_midi,
            'midi_dur': midi_dur,
            'is_slur': is_slur
        }
        return batch

    def postprocess_output(self, output):
        """Hook for subclasses; the base implementation returns output unchanged."""
        return output

    def infer_once(self, inp):
        """Preprocess ``inp``, run the model, and postprocess the result."""
        inp = self.preprocess_input(inp, input_type=inp['input_type'] if inp.get('input_type') else 'word')
        output = self.forward_model(inp)
        output = self.postprocess_output(output)
        return output

    @classmethod
    def example_run(cls, inp):
        """Convenience entry point: build an inferer from global hparams and
        write one synthesized example to infer_out/example_out.wav."""
        from utils.audio import save_wav
        set_hparams(print_hparams=False)
        infer_ins = cls(hparams)
        out = infer_ins.infer_once(inp)
        os.makedirs('infer_out', exist_ok=True)
        save_wav(out, f'infer_out/example_out.wav', hparams['audio_sample_rate'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|