text_to_speech / synthesize.py
wuxulong19950206
add files
9270314
raw
history blame
3.56 kB
import argparse
import os
import subprocess
import numpy as np
import torch
import yaml
from scipy.io import wavfile
from mtts.models.fs2_model import FastSpeech2
from mtts.models.vocoder import *
from mtts.text import TextProcessor
from mtts.utils.logging import get_logger
logger = get_logger(__file__)
def check_ffmpeg():
r, path = subprocess.getstatusoutput("which ffmpeg")
return r == 0
with_ffmpeg = check_ffmpeg()
def build_vocoder(device, config):
vocoder_name = config['vocoder']['type']
VocoderClass = eval(vocoder_name)
model = VocoderClass(config=config['vocoder'][vocoder_name])
return model
def normalize(wav):
assert wav.dtype == np.float32
eps = 1e-6
sil = wav[1500:2000]
#wav = wav - np.mean(sil)
#wav = (wav - np.min(wav))/(np.max(wav)-np.min(wav)+eps)
wav = wav / np.max(np.abs(wav))
#wav = wav*2-1
wav = wav * 32767
return wav.astype('int16')
def to_int16(wav):
wav = wav = wav * 32767
wav = np.clamp(wav, -32767, 32768)
return wav.astype('int16')
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', type=str, default='input.txt')
parser.add_argument('--duration', type=float, default=1.0)
parser.add_argument('--output_dir', type=str, default='./outputs/')
parser.add_argument('--checkpoint', type=str, required=False, default='checkpoints\checkpoint_140000.pth.tar')
parser.add_argument('-c', '--config', type=str, default='./config.yaml')
parser.add_argument('-d', '--device', choices=['cuda', 'cpu'], type=str, default='cuda')
args = parser.parse_args()
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
with open(args.config) as f:
config = yaml.safe_load(f)
logger.info(f.read())
sr = config['fbank']['sample_rate']
vocoder = build_vocoder(args.device, config)
text_processor = TextProcessor(config)
model = FastSpeech2(config)
if args.checkpoint != '':
sd = torch.load(args.checkpoint, map_location=args.device)
if 'model' in sd.keys():
sd = sd['model']
model.load_state_dict(sd)
del sd # to save mem
model = model.to(args.device)
torch.set_grad_enabled(False)
try:
lines = open(args.input).read().split('\n')
except:
print('Failed to open text file', args.input)
print('Treating input as text')
lines = [args.input]
for line in lines:
if len(line) == 0 or line.startswith('#'):
continue
logger.info(f'processing {line}')
name, tokens = text_processor(line)
tokens = tokens.to(args.device)
seq_len = torch.tensor([tokens.shape[1]])
tokens = tokens.unsqueeze(1)
seq_len = seq_len.to(args.device)
max_src_len = torch.max(seq_len)
output = model(tokens, seq_len, max_src_len=max_src_len, d_control=args.duration)
mel_pred, mel_postnet, d_pred, src_mask, mel_mask, mel_len = output
# convert to waveform using vocoder
mel_postnet = mel_postnet[0].transpose(0, 1).detach()
mel_postnet += config['fbank']['mel_mean']
wav = vocoder(mel_postnet)
if config['synthesis']['normalize']:
wav = normalize(wav)
else:
wav = to_int16(wav)
dst_file = os.path.join(args.output_dir, f'{name}.wav')
#np.save(dst_file+'.npy',mel_postnet.cpu().numpy())
logger.info(f'writing file to {dst_file}')
wavfile.write(dst_file, sr, wav)