# Text-to-speech synthesis script: text -> mel spectrogram (FastSpeech2) -> waveform (vocoder).
import argparse
import os
import subprocess

import numpy as np
import torch
import yaml
from scipy.io import wavfile

from mtts.models.fs2_model import FastSpeech2
from mtts.models.vocoder import *
from mtts.text import TextProcessor
from mtts.utils.logging import get_logger
# Module-level logger tagged with this file's path (project helper; exact config not visible here).
logger = get_logger(__file__)
def check_ffmpeg():
    """Return True when an ``ffmpeg`` executable is found on PATH.

    Uses ``which ffmpeg`` via the shell; a zero exit status means found.
    NOTE(review): ``which`` is POSIX-only — presumably fine for this
    project's deployment targets; confirm if Windows support is needed.
    """
    # The original bound the command output to an unused `path` variable.
    status, _ = subprocess.getstatusoutput("which ffmpeg")
    return status == 0
# Module-level flag: True when an ffmpeg binary is on PATH (not referenced in the visible code below).
with_ffmpeg = check_ffmpeg()
def build_vocoder(device, config):
    """Instantiate the vocoder class named by ``config['vocoder']['type']``.

    The class is constructed with ``config['vocoder'][<type>]`` as its
    ``config`` keyword argument. ``device`` is accepted for interface
    symmetry with the rest of the script but is not used here — device
    placement is presumably handled elsewhere (TODO confirm).

    Raises:
        ValueError: if the named class is not importable in this module.
    """
    vocoder_name = config['vocoder']['type']
    # Look the class up in this module's namespace (populated by the
    # `from mtts.models.vocoder import *` at the top of the file) instead
    # of eval(), which would execute arbitrary expressions from the
    # config file.
    try:
        VocoderClass = globals()[vocoder_name]
    except KeyError:
        raise ValueError(f'unknown vocoder type: {vocoder_name!r}') from None
    return VocoderClass(config=config['vocoder'][vocoder_name])
def normalize(wav):
    """Peak-normalize a float32 waveform and convert to int16 PCM.

    The peak magnitude is scaled to 32767. An all-zero input is returned
    as all-zero int16; the original divided by the (zero) peak and
    produced NaNs.

    Args:
        wav: 1-D float32 numpy array (asserted).

    Returns:
        int16 numpy array of the same length.
    """
    assert wav.dtype == np.float32
    peak = np.max(np.abs(wav))
    if peak > 0:  # guard: silent buffers would otherwise divide by zero
        wav = wav / peak
    wav = wav * 32767
    return wav.astype('int16')
def to_int16(wav):
    """Convert a float waveform in [-1, 1] to int16 PCM, clipping overflow.

    Fixes the original, which called the nonexistent ``np.clamp`` (the
    NumPy function is ``np.clip``) and used bounds (-32767, 32768) —
    32768 does not fit in int16 and the low bound clipped one step short.
    """
    wav = wav * 32767
    wav = np.clip(wav, -32768, 32767)
    return wav.astype('int16')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='input.txt')
    parser.add_argument('--duration', type=float, default=1.0)
    parser.add_argument('--output_dir', type=str, default='./outputs/')
    # Forward slash works on every platform; the original default used a
    # literal backslash ('checkpoints\checkpoint_140000...'), which breaks
    # on POSIX systems.
    parser.add_argument('--checkpoint', type=str, required=False,
                        default='checkpoints/checkpoint_140000.pth.tar')
    parser.add_argument('-c', '--config', type=str, default='./config.yaml')
    parser.add_argument('-d', '--device', choices=['cuda', 'cpu'], type=str, default='cuda')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    with open(args.config) as f:
        # Read once, then both log and parse the same text. The original
        # called f.read() after yaml.safe_load(f), so the file pointer was
        # at EOF and an empty string was logged.
        config_text = f.read()
    logger.info(config_text)
    config = yaml.safe_load(config_text)

    sr = config['fbank']['sample_rate']
    vocoder = build_vocoder(args.device, config)
    text_processor = TextProcessor(config)
    model = FastSpeech2(config)

    if args.checkpoint != '':
        sd = torch.load(args.checkpoint, map_location=args.device)
        if 'model' in sd.keys():
            sd = sd['model']
        model.load_state_dict(sd)
        del sd  # free checkpoint memory before inference
    model = model.to(args.device)
    model.eval()  # inference mode: disable dropout / batch-norm updates
    torch.set_grad_enabled(False)

    try:
        with open(args.input) as fin:
            lines = fin.read().split('\n')
    except OSError:  # narrowed from bare except: only file-access errors trigger the fallback
        print('Failed to open text file', args.input)
        print('Treating input as text')
        lines = [args.input]

    for line in lines:
        # Skip blank lines and '#'-prefixed comment lines.
        if len(line) == 0 or line.startswith('#'):
            continue
        logger.info(f'processing {line}')
        name, tokens = text_processor(line)
        tokens = tokens.to(args.device)
        seq_len = torch.tensor([tokens.shape[1]])
        tokens = tokens.unsqueeze(1)
        seq_len = seq_len.to(args.device)
        max_src_len = torch.max(seq_len)
        output = model(tokens, seq_len, max_src_len=max_src_len, d_control=args.duration)
        mel_pred, mel_postnet, d_pred, src_mask, mel_mask, mel_len = output

        # Convert the predicted mel to a waveform. The transpose suggests
        # the vocoder expects (n_mels, T) — TODO confirm against the
        # vocoder implementation. The stored mel mean is added back to
        # undo normalization applied during training (presumably; verify
        # against the training pipeline).
        mel_postnet = mel_postnet[0].transpose(0, 1).detach()
        mel_postnet += config['fbank']['mel_mean']
        wav = vocoder(mel_postnet)
        if config['synthesis']['normalize']:
            wav = normalize(wav)
        else:
            wav = to_int16(wav)

        dst_file = os.path.join(args.output_dir, f'{name}.wav')
        logger.info(f'writing file to {dst_file}')
        wavfile.write(dst_file, sr, wav)