# FastSpeech2 text-to-speech synthesis script.
# (Removed non-code web-page artifacts left over from extraction.)
import argparse
import os
import subprocess
import numpy as np
import torch
import yaml
from scipy.io import wavfile
from mtts.models.fs2_model import FastSpeech2
from mtts.models.vocoder import *
from mtts.text import TextProcessor
from mtts.utils.logging import get_logger
# Module-level logger keyed to this file's path.
logger = get_logger(__file__)
def check_ffmpeg():
    """Return True if an ffmpeg executable is available on PATH.

    Uses shutil.which instead of shelling out to `which`, which is
    portable (works on Windows too) and avoids spawning a subprocess.
    """
    import shutil
    return shutil.which('ffmpeg') is not None


# Probed once at import time; downstream code can branch on this flag.
with_ffmpeg = check_ffmpeg()
def build_vocoder(device, config):
    """Instantiate the vocoder class named by config['vocoder']['type'].

    `device` is currently unused but kept for interface compatibility.
    The vocoder's own sub-config (keyed by its class name) is passed to
    its constructor.
    """
    vocoder_name = config['vocoder']['type']
    # Resolve the class from this module's namespace (populated by the
    # wildcard import from mtts.models.vocoder) rather than eval(),
    # which would execute arbitrary strings coming from the config file.
    VocoderClass = globals()[vocoder_name]
    return VocoderClass(config=config['vocoder'][vocoder_name])
def normalize(wav):
    """Peak-normalize a float32 waveform and convert it to int16 PCM.

    The waveform is scaled so its largest absolute sample maps to
    +/-32767. An all-zero (silent) buffer is returned as zeros instead
    of producing NaNs from a division by zero.
    """
    assert wav.dtype == np.float32
    peak = np.max(np.abs(wav))
    # Guard: dividing by a zero peak would yield NaNs.
    if peak > 0:
        wav = wav / peak
    wav = wav * 32767
    return wav.astype('int16')
def to_int16(wav):
    """Convert a float waveform in [-1, 1] to int16 PCM, clipping overflow.

    Fixes the original, which called the nonexistent np.clamp (runtime
    AttributeError) with inverted bounds (-32767, 32768); 32768 would
    overflow int16. The correct call is np.clip over the int16 range.
    """
    wav = wav * 32767
    wav = np.clip(wav, -32768, 32767)
    return wav.astype('int16')
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input', type=str, default='input.txt',
                        help='Text file to synthesize (falls back to literal text).')
    parser.add_argument('--duration', type=float, default=1.0,
                        help='Duration control factor passed to FastSpeech2.')
    parser.add_argument('--output_dir', type=str, default='./outputs/')
    # Fixed the original Windows-style default ('checkpoints\\checkpoint_...'),
    # whose literal backslash broke the path on POSIX systems.
    parser.add_argument('--checkpoint', type=str, required=False,
                        default='checkpoints/checkpoint_140000.pth.tar')
    parser.add_argument('-c', '--config', type=str, default='./config.yaml')
    parser.add_argument('-d', '--device', choices=['cuda', 'cpu'], type=str, default='cuda')
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Read the config once so it can be both logged and parsed. (The
    # original called f.read() after yaml.safe_load had consumed the
    # file handle, so it always logged an empty string.)
    with open(args.config) as f:
        config_text = f.read()
    logger.info(config_text)
    config = yaml.safe_load(config_text)

    sr = config['fbank']['sample_rate']
    vocoder = build_vocoder(args.device, config)
    text_processor = TextProcessor(config)
    model = FastSpeech2(config)

    if args.checkpoint != '':
        sd = torch.load(args.checkpoint, map_location=args.device)
        # Checkpoints may wrap the state dict under a 'model' key.
        if 'model' in sd.keys():
            sd = sd['model']
        model.load_state_dict(sd)
        del sd  # free checkpoint memory before synthesis
    model = model.to(args.device)
    torch.set_grad_enabled(False)  # inference only

    try:
        with open(args.input) as f:
            lines = f.read().split('\n')
    except OSError:
        # Not a readable file; treat the argument as literal input text.
        print('Failed to open text file', args.input)
        print('Treating input as text')
        lines = [args.input]

    for line in lines:
        # Skip blank lines and '#'-prefixed comments.
        if len(line) == 0 or line.startswith('#'):
            continue
        logger.info(f'processing {line}')
        name, tokens = text_processor(line)
        tokens = tokens.to(args.device)
        seq_len = torch.tensor([tokens.shape[1]])
        tokens = tokens.unsqueeze(1)
        seq_len = seq_len.to(args.device)
        max_src_len = torch.max(seq_len)
        output = model(tokens, seq_len, max_src_len=max_src_len, d_control=args.duration)
        mel_pred, mel_postnet, d_pred, src_mask, mel_mask, mel_len = output
        # Convert the predicted mel spectrogram to a waveform.
        mel_postnet = mel_postnet[0].transpose(0, 1).detach()
        mel_postnet += config['fbank']['mel_mean']
        wav = vocoder(mel_postnet)
        wav = normalize(wav) if config['synthesis']['normalize'] else to_int16(wav)
        dst_file = os.path.join(args.output_dir, f'{name}.wav')
        logger.info(f'writing file to {dst_file}')
        wavfile.write(dst_file, sr, wav)