# Provenance (was GitHub page residue pasted into the file):
# author Pendrokar, commit bd8f4e0 — "ndarray serialization fix".
import os
import re
import json
import codecs
import ffmpeg
import argparse
import platform
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.utils.rnn import pad_sequence
import scipy
import scipy.io.wavfile
import librosa
from scipy.io.wavfile import write
import numpy as np
try:
import sys
sys.path.append(".")
from resources.app.python.xvapitch.text import ALL_SYMBOLS, get_text_preprocessor, lang_names
from resources.app.python.xvapitch.xvapitch_model import xVAPitch as xVAPitchModel
except ModuleNotFoundError:
try:
from python.xvapitch.text import ALL_SYMBOLS, get_text_preprocessor, lang_names
from python.xvapitch.xvapitch_model import xVAPitch as xVAPitchModel
except ModuleNotFoundError:
try:
from xvapitch.text import ALL_SYMBOLS, get_text_preprocessor, lang_names
from xvapitch.xvapitch_model import xVAPitch as xVAPitchModel
except ModuleNotFoundError:
from text import ALL_SYMBOLS, get_text_preprocessor, lang_names
from xvapitch_model import xVAPitch as xVAPitchModel
class xVAPitch(object):
    """Wrapper around the xVAPitch TTS / voice-conversion model.

    Handles text preprocessing per language, checkpoint loading, single and
    batch inference, ARPAbet dictionary management, and device placement.
    """
    def __init__(self, logger, PROD, device, models_manager):
        # logger: application logger
        # PROD: True when running from the packaged ./resources/app tree
        #       (switches all resource paths); False in dev
        # device: torch device to place the model on
        # models_manager: registry giving access to sibling models
        #       ("speaker_rep", "nuwave2", "deepfilternet2")
        super(xVAPitch, self).__init__()
        self.logger = logger
        self.PROD = PROD
        self.models_manager = models_manager
        self.device = device
        self.ckpt_path = None
        self.arpabet_dict = {}
        # cudnn benchmark deliberately disabled (input sizes vary per line)
        # torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.benchmark = False
        self.base_dir = f'{"./resources/app" if self.PROD else "."}/python/xvapitch/text'
        # Per-language text preprocessors, created lazily; English always loaded
        self.lang_tp = {}
        self.lang_tp["en"] = get_text_preprocessor("en", self.base_dir, logger=self.logger)
        # Stable language-id mapping: alphabetical order of known language codes
        self.language_id_mapping = {name: i for i, name in enumerate(sorted(list(lang_names.keys())))}
        # Pre-extracted pitch/emotion embedding deltas; unsqueezed to
        # (1, dim, 1) to broadcast against speaker embeddings
        self.pitch_emb_values = torch.tensor(np.load(f'{"./resources/app" if self.PROD else "."}/python/xvapitch/embs/pitch_emb.npy')).unsqueeze(0).unsqueeze(-1)
        self.angry_emb_values = torch.tensor(np.load(f'{"./resources/app" if self.PROD else "."}/python/xvapitch/embs/angry.npy')).unsqueeze(0).unsqueeze(-1)
        self.happy_emb_values = torch.tensor(np.load(f'{"./resources/app" if self.PROD else "."}/python/xvapitch/embs/happy.npy')).unsqueeze(0).unsqueeze(-1)
        self.sad_emb_values = torch.tensor(np.load(f'{"./resources/app" if self.PROD else "."}/python/xvapitch/embs/sad.npy')).unsqueeze(0).unsqueeze(-1)
        self.surprise_emb_values = torch.tensor(np.load(f'{"./resources/app" if self.PROD else "."}/python/xvapitch/embs/surprise.npy')).unsqueeze(0).unsqueeze(-1)
        self.base_lang = "en"
        self.init_model()
        # Attach the emotion embeddings to the model, on the manager's device
        self.model.pitch_emb_values = self.pitch_emb_values.to(self.models_manager.device)
        self.model.angry_emb_values = self.angry_emb_values.to(self.models_manager.device)
        self.model.happy_emb_values = self.happy_emb_values.to(self.models_manager.device)
        self.model.sad_emb_values = self.sad_emb_values.to(self.models_manager.device)
        self.model.surprise_emb_values = self.surprise_emb_values.to(self.models_manager.device)
        self.isReady = True
def init_model (self):
parser = argparse.ArgumentParser()
args = parser.parse_args()
# Params from training
args.pitch = 1
args.pe_scaling = 0.1
args.expanded_flow = 0
args.ow_flow = 0
args.energy = 0
self.model = xVAPitchModel(args).to(self.device)
self.model.eval()
self.model.device = self.device
    def load_state_dict (self, ckpt_path, ckpt, n_speakers=1, base_lang="en"):
        """Load voice weights from a checkpoint and set the voice's base language.

        ckpt_path: path to the .pt checkpoint; a sibling .json metadata file
            (same name) is read if present to get the base speaker embedding.
        ckpt: the already-loaded checkpoint object; may wrap the state dict
            under a 'model' key.
        n_speakers: unused here (kept for interface compatibility).
        base_lang: language code the voice was trained with.
        """
        self.logger.info(f'load_state_dict base_lang: {base_lang}')
        # Lazily create the text preprocessor for this voice's base language
        if base_lang not in self.lang_tp.keys():
            self.lang_tp[base_lang] = get_text_preprocessor(base_lang, self.base_dir, logger=self.logger)
        self.base_lang = base_lang
        self.ckpt_path = ckpt_path
        if os.path.exists(ckpt_path.replace(".pt", ".json")):
            with open(ckpt_path.replace(".pt", ".json"), "r") as f:
                data = json.load(f)
                self.base_emb = data["games"][0]["base_speaker_emb"]
        if 'model' in ckpt:
            ckpt = ckpt['model']
        # Re-size the language embedding table to match the checkpoint.
        # 31 and 50 appear to be the two language counts used by shipped
        # models; NOTE(review): any other size is silently left unchanged
        # and would then rely on strict=False below — confirm intended.
        if ckpt["emb_l.weight"].shape[0]==31:
            self.model.emb_l = nn.Embedding(31, self.model.embedded_language_dim).to(self.models_manager.device)
        elif ckpt["emb_l.weight"].shape[0]==50:
            num_languages = 50
            self.model.emb_l = nn.Embedding(num_languages, self.model.embedded_language_dim).to(self.models_manager.device)
        # strict=False: tolerate keys present in only one of (checkpoint, model)
        self.model.load_state_dict(ckpt, strict=False)
        self.model = self.model.float()
        self.model.eval()
def init_arpabet_dicts (self):
if len(list(self.arpabet_dict.keys()))==0:
self.refresh_arpabet_dicts()
def refresh_arpabet_dicts (self):
self.arpabet_dict = {}
json_files = sorted(os.listdir(f'{"./resources/app" if self.PROD else "."}/arpabet'))
json_files = [fname for fname in json_files if fname.endswith(".json")]
for fname in json_files:
with codecs.open(f'{"./resources/app" if self.PROD else "."}/arpabet/{fname}', encoding="utf-8") as f:
json_data = json.load(f)
for word in list(json_data["data"].keys()):
if json_data["data"][word]["enabled"]==True:
self.arpabet_dict[word] = json_data["data"][word]["arpabet"]
def run_speech_to_speech (self, audiopath, audio_out_path, style_emb, models_manager, plugin_manager, vc_strength=1, useSR=False, useCleanup=False):
if ".wav" in style_emb:
self.logger.info(f'Getting style emb from: {style_emb}')
style_emb = models_manager.models("speaker_rep").compute_embedding(style_emb).squeeze()
else:
self.logger.info(f'Given style emb')
style_emb = torch.tensor(style_emb).squeeze()
try:
content_emb = models_manager.models("speaker_rep").compute_embedding(audiopath).squeeze()
except:
return "TOO_SHORT"
style_emb = F.normalize(style_emb.unsqueeze(0), dim=1).unsqueeze(-1).to(self.models_manager.device)
content_emb = F.normalize(content_emb.unsqueeze(0), dim=1).unsqueeze(-1).to(self.models_manager.device)
content_emb = content_emb + (-(vc_strength-1) * (style_emb - content_emb))
y, sr = librosa.load(audiopath, sr=22050)
D = librosa.stft(
y=y,
n_fft=1024,
hop_length=256,
win_length=1024,
pad_mode="reflect",
window="hann",
center=True,
)
spec = np.abs(D).astype(np.float32)
ref_spectrogram = torch.FloatTensor(spec).unsqueeze(0)
y_lengths = torch.tensor([ref_spectrogram.size(-1)]).to(self.models_manager.device)
y = ref_spectrogram.to(self.models_manager.device)
wav = self.model.voice_conversion(y=y, y_lengths=y_lengths, spk1_emb=content_emb, spk2_emb=style_emb)
wav = wav.squeeze().cpu().detach().numpy()
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
if useCleanup:
ffmpeg_path = 'ffmpeg' if platform.system() == 'Linux' else f'{"./resources/app" if self.PROD else "."}/python/ffmpeg.exe'
if useSR:
scipy.io.wavfile.write(audio_out_path.replace(".wav", "_preSR.wav"), 22050, wav_norm.astype(np.int16))
else:
scipy.io.wavfile.write(audio_out_path.replace(".wav", "_preCleanupPreFFmpeg.wav"), 22050, wav_norm.astype(np.int16))
stream = ffmpeg.input(audio_out_path.replace(".wav", "_preCleanupPreFFmpeg.wav"))
ffmpeg_options = {"ar": 48000}
output_path = audio_out_path.replace(".wav", "_preCleanup.wav")
stream = ffmpeg.output(stream, output_path, **ffmpeg_options)
out, err = (ffmpeg.run(stream, cmd=ffmpeg_path, capture_stdout=True, capture_stderr=True, overwrite_output=True))
os.remove(audio_out_path.replace(".wav", "_preCleanupPreFFmpeg.wav"))
else:
scipy.io.wavfile.write(audio_out_path.replace(".wav", "_preSR.wav") if useSR else audio_out_path, 22050, wav_norm.astype(np.int16))
if useSR:
self.models_manager.init_model("nuwave2")
self.models_manager.models("nuwave2").sr_audio(audio_out_path.replace(".wav", "_preSR.wav"), audio_out_path.replace(".wav", "_preCleanup.wav") if useCleanup else audio_out_path)
if useCleanup:
self.models_manager.init_model("deepfilternet2")
self.models_manager.models("deepfilternet2").cleanup_audio(audio_out_path.replace(".wav", "_preCleanup.wav"), audio_out_path)
return
def infer_batch(self, plugin_manager, linesBatch, outputJSON, vocoder, speaker_i, old_sequence=None, useSR=False, useCleanup=False):
print(f'Inferring batch of {len(linesBatch)} lines')
text_sequences = []
cleaned_text_sequences = []
lang_embs = []
speaker_embs = []
# [sequence, pitch, duration, pace, tempFileLocation, outPath, outFolder, pitch_amp, base_lang, base_emb, vc_content, vc_style]
vc_input = []
tts_input = []
for ri,record in enumerate(linesBatch):
if record[-2]: # If a VC content file has been given, handle this as VC
vc_input.append(record)
else:
tts_input.append(record)
# =================
# ======= Handle VC
# =================
if len(vc_input):
for ri,record in enumerate(vc_input):
content_emb = self.models_manager.models("speaker_rep").compute_embedding(record[-2]).squeeze()
style_emb = self.models_manager.models("speaker_rep").compute_embedding(record[-1]).squeeze()
# content_emb = F.normalize(content_emb.unsqueeze(0), dim=1).squeeze(0)
# style_emb = F.normalize(style_emb.unsqueeze(0), dim=1).squeeze(0)
content_emb = content_emb.unsqueeze(0).unsqueeze(-1).to(self.models_manager.device)
style_emb = style_emb.unsqueeze(0).unsqueeze(-1).to(self.models_manager.device)
y, sr = librosa.load(record[-2], sr=22050)
D = librosa.stft(
y=y,
n_fft=1024,
hop_length=256,
win_length=1024,
pad_mode="reflect",
window="hann",
center=True,
)
spec = np.abs(D).astype(np.float32)
ref_spectrogram = torch.FloatTensor(spec).unsqueeze(0)
y_lengths = torch.tensor([ref_spectrogram.size(-1)]).to(self.models_manager.device)
y = ref_spectrogram.to(self.models_manager.device)
# Run Voice Conversion
self.model.logger = self.logger
wav = self.model.voice_conversion(y=y, y_lengths=y_lengths, spk1_emb=content_emb, spk2_emb=style_emb)
wav = wav.squeeze().cpu().detach().numpy()
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
if useCleanup:
ffmpeg_path = 'ffmpeg' if platform.system() == 'Linux' else f'{"./resources/app" if self.PROD else "."}/python/ffmpeg.exe'
if useSR:
scipy.io.wavfile.write(tts_input[ri][4].replace(".wav", "_preSR.wav"), 22050, wav_norm.astype(np.int16))
else:
scipy.io.wavfile.write(tts_input[ri][4].replace(".wav", "_preCleanupPreFFmpeg.wav"), 22050, wav_norm.astype(np.int16))
stream = ffmpeg.input(tts_input[ri][4].replace(".wav", "_preCleanupPreFFmpeg.wav"))
ffmpeg_options = {"ar": 48000}
output_path = tts_input[ri][4].replace(".wav", "_preCleanup.wav")
stream = ffmpeg.output(stream, output_path, **ffmpeg_options)
out, err = (ffmpeg.run(stream, cmd=ffmpeg_path, capture_stdout=True, capture_stderr=True, overwrite_output=True))
os.remove(tts_input[ri][4].replace(".wav", "_preCleanupPreFFmpeg.wav"))
else:
scipy.io.wavfile.write(vc_input[ri][4].replace(".wav", "_preSR.wav") if useSR else vc_input[ri][4], 22050, wav_norm.astype(np.int16))
if useSR:
self.models_manager.init_model("nuwave2")
self.models_manager.models("nuwave2").sr_audio(vc_input[ri][4].replace(".wav", "_preSR.wav"), vc_input[ri][4].replace(".wav", "_preCleanup.wav") if useCleanup else vc_input[ri][4])
os.remove(vc_input[ri][4].replace(".wav", "_preSR.wav"))
if useCleanup:
self.models_manager.init_model("deepfilternet2")
self.models_manager.models("deepfilternet2").cleanup_audio(vc_input[ri][4].replace(".wav", "_preCleanup.wav"), vc_input[ri][4])
os.remove(vc_input[ri][4].replace(".wav", "_preCleanup.wav"))
# ==================
# ======= Handle TTS
# ==================
if len(tts_input):
lang_embs_sizes = []
for ri,record in enumerate(tts_input):
# Pre-process text
text = record[0].replace("/lang", "\\lang")
base_lang = record[-4]
self.logger.info(f'[infer_batch] text: {text}')
sequenceSplitByLanguage = self.preprocess_prompt_language(text, base_lang)
# Make sure all languages' text processors are initialized
for subSequence in sequenceSplitByLanguage:
langCode = list(subSequence.keys())[0]
if langCode not in self.lang_tp.keys():
self.lang_tp[langCode] = get_text_preprocessor(langCode, self.base_dir, logger=self.logger)
try:
pad_symb = len(ALL_SYMBOLS)-2
all_sequence = []
all_cleaned_text = []
all_text = []
all_lang_ids = []
# Collapse same-language words into phrases, so that heteronyms can still be detected
sequenceSplitByLanguage_grouped = []
last_lang_group = None
group = ""
for ssi, subSequence in enumerate(sequenceSplitByLanguage):
if list(subSequence.keys())[0]!=last_lang_group:
if last_lang_group is not None:
sequenceSplitByLanguage_grouped.append({last_lang_group: group})
group = ""
last_lang_group = list(subSequence.keys())[0]
group += subSequence[last_lang_group]
if len(group):
sequenceSplitByLanguage_grouped.append({last_lang_group: group})
for ssi, subSequence in enumerate(sequenceSplitByLanguage_grouped):
langCode = list(subSequence.keys())[0]
subSeq = subSequence[langCode]
sequence, cleaned_text = self.lang_tp[langCode].text_to_sequence(subSeq)
if ssi<len(sequenceSplitByLanguage_grouped)-1:
sequence = sequence + [pad_symb]
all_sequence.append(sequence)
all_cleaned_text += ("|"+cleaned_text) if len(all_cleaned_text) else cleaned_text
if ssi<len(sequenceSplitByLanguage_grouped)-1:
all_cleaned_text = all_cleaned_text + ["|<PAD>"]
all_text.append(torch.LongTensor(sequence))
language_id = self.language_id_mapping[langCode]
all_lang_ids += [language_id for _ in range(len(sequence))]
except ValueError as e:
self.logger.info("====")
self.logger.info(str(e))
self.logger.info("====--")
if "not in list" in str(e):
symbol_not_in_list = str(e).split("is not in list")[0].split("ValueError:")[-1].replace("'", "").strip()
return f'ERR: ARPABET_NOT_IN_LIST: {symbol_not_in_list}'
all_cleaned_text = "".join(all_cleaned_text)
text = torch.cat(all_text, dim=0)
cleaned_text_sequences.append(all_cleaned_text)
text = torch.LongTensor(text)
text_sequences.append(text)
lang_ids = torch.tensor(all_lang_ids).to(self.models_manager.device)
lang_embs.append(lang_ids)
lang_embs_sizes.append(lang_ids.shape[0])
speaker_embs.append(torch.tensor(tts_input[ri][-3]).unsqueeze(-1))
lang_embs = pad_sequence(lang_embs, batch_first=True).to(self.models_manager.device)
text_sequences = pad_sequence(text_sequences, batch_first=True).to(self.models_manager.device)
speaker_embs = pad_sequence(speaker_embs, batch_first=True).to(self.models_manager.device)
pace = torch.tensor([record[3] for record in tts_input]).unsqueeze(1).to(self.device)
pitch_amp = torch.tensor([record[7] for record in tts_input]).unsqueeze(1).to(self.device)
# Could pass indexes (and get them returned) to the tts inference fn
# Do the same to the vc infer fn
# Then marge them into their place in an output array?
out = self.model.infer_advanced(self.logger, plugin_manager, [cleaned_text_sequences], text_sequences, lang_embs=lang_embs, speaker_embs=speaker_embs, pace=pace, old_sequence=None, pitch_amp=pitch_amp)
if isinstance(out, str):
return out
else:
output_wav, dur_pred, pitch_pred, energy_pred, _, _, _, _ = out
for i,wav in enumerate(output_wav):
wav = wav.squeeze().cpu().detach().numpy()
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
if useCleanup:
ffmpeg_path = 'ffmpeg' if platform.system() == 'Linux' else f'{"./resources/app" if self.PROD else "."}/python/ffmpeg.exe'
if useSR:
scipy.io.wavfile.write(tts_input[i][4].replace(".wav", "_preSR.wav"), 22050, wav_norm.astype(np.int16))
else:
scipy.io.wavfile.write(tts_input[i][4].replace(".wav", "_preCleanupPreFFmpeg.wav"), 22050, wav_norm.astype(np.int16))
stream = ffmpeg.input(tts_input[i][4].replace(".wav", "_preCleanupPreFFmpeg.wav"))
ffmpeg_options = {"ar": 48000}
output_path = tts_input[i][4].replace(".wav", "_preCleanup.wav")
stream = ffmpeg.output(stream, output_path, **ffmpeg_options)
out, err = (ffmpeg.run(stream, cmd=ffmpeg_path, capture_stdout=True, capture_stderr=True, overwrite_output=True))
os.remove(tts_input[i][4].replace(".wav", "_preCleanupPreFFmpeg.wav"))
else:
scipy.io.wavfile.write(tts_input[i][4].replace(".wav", "_preSR.wav") if useSR else tts_input[i][4], 22050, wav_norm.astype(np.int16))
if useSR:
self.models_manager.init_model("nuwave2")
self.models_manager.models("nuwave2").sr_audio(tts_input[i][4].replace(".wav", "_preSR.wav"), tts_input[i][4].replace(".wav", "_preCleanup.wav") if useCleanup else tts_input[i][4])
os.remove(tts_input[i][4].replace(".wav", "_preSR.wav"))
if useCleanup:
self.models_manager.init_model("deepfilternet2")
self.models_manager.models("deepfilternet2").cleanup_audio(tts_input[i][4].replace(".wav", "_preCleanup.wav"), tts_input[i][4])
os.remove(tts_input[i][4].replace(".wav", "_preCleanup.wav"))
if outputJSON:
for ri, record in enumerate(tts_input):
# tts_input: sequence, pitch, duration, pace, tempFileLocation, outPath, outFolder
output_fname = tts_input[ri][5].replace(".wav", ".json")
containing_folder = "/".join(output_fname.split("/")[:-1])
os.makedirs(containing_folder, exist_ok=True)
with open(output_fname, "w+") as f:
data = {}
data["modelType"] = "xVAPitch"
data["inputSequence"] = str(tts_input[ri][0])
data["pacing"] = float(tts_input[ri][3])
data["letters"] = [char.replace("{", "").replace("}", "") for char in list(cleaned_text_sequences[ri].split("|"))]
data["currentVoice"] = self.ckpt_path.split("/")[-1].replace(".pt", "")
# data["resetEnergy"] = [float(val) for val in list(energy_pred[ri].cpu().detach().numpy())]
data["resetEnergy"] = [float(1) for val in list(pitch_pred[ri][0].cpu().detach().numpy())]
data["resetPitch"] = [float(val) for val in list(pitch_pred[ri][0].cpu().detach().numpy())]
data["resetDurs"] = [float(val) for val in list(dur_pred[ri][0].cpu().detach().numpy())]
data["ampFlatCounter"] = 0
data["pitchNew"] = data["resetPitch"]
data["energyNew"] = data["resetEnergy"]
data["dursNew"] = data["resetDurs"]
f.write(json.dumps(data, indent=4))
return ""
# Split words by space, while also breaking out the \land[code][text] formatting
def splitWords (self, sequence, addSpace=False):
words = []
for word in sequence:
if word.startswith("\\lang["):
words.append(word.split("][")[0]+"][")
word = word.split("][")[1]
for char in ["}","]","[","{"]:
if word.startswith(char):
words.append(char)
word = word[1:]
end_extras = []
for char in ["}","]","[","{"]:
if word.endswith(char):
end_extras.append(char)
word = word[:-1]
words.append(word)
end_extras.reverse()
for extra in end_extras:
words.append(extra)
if addSpace:
words.append(" ")
return words
def preprocess_prompt_language (self, sequence, base_lang):
# Separate the ARPAbet brackets from punctuation
sequence = sequence.replace("}.", "} .")
sequence = sequence.replace("}!", "} !")
sequence = sequence.replace("}?", "} ?")
sequence = sequence.replace("},", "} ,")
sequence = sequence.replace("}\"", "} \"")
sequence = sequence.replace("}'", "} '")
sequence = sequence.replace("}-", "} -")
sequence = sequence.replace("})", "} )")
sequence = sequence.replace(".{", ". {")
sequence = sequence.replace("!{", "! {")
sequence = sequence.replace("?{", "? {")
sequence = sequence.replace(",{", ", {")
sequence = sequence.replace("\"{", "\" {")
sequence = sequence.replace("'{", "' {")
sequence = sequence.replace("-{", "- {")
sequence = sequence.replace("({", "( {")
# Prepare the input sequence for processing. Do a few times to catch edge cases
sequence = self.splitWords(sequence.split(" "), True)
sequence = self.splitWords(sequence)
sequence = self.splitWords(sequence)
sequence = self.splitWords(sequence)
subSequences = []
openedLangs = 0
langs_stack = [base_lang]
for word in sequence:
skip_word = False
if word.startswith("\\lang["):
openedLangs += 1
langs_stack.append(word.split("lang[")[1].split("]")[0])
skip_word = True
if word.endswith("]"):
openedLangs -= 1
langs_stack.pop()
skip_word = True
# Add the word to the list if not skipping it, if it's not empty, or it's not a second space in a row
if not skip_word and len(word) and (word!=" " or len(subSequences)==0 or subSequences[-1][list(subSequences[-1].keys())[0]]!=" "):
subSequences.append({langs_stack[-1]: word})
subSequences_collapsed = []
current_open_arpabet = []
last_lang = None
is_in_arpabet = False
# Collapse groups of inlined ARPABet symbols, to have them treated as such
for subSequence in subSequences:
ss_lang = list(subSequence.keys())[0]
ss_val = subSequence[ss_lang]
if ss_lang is not last_lang:
if len(current_open_arpabet):
subSequences_collapsed.append({ss_lang: "{"+" ".join(current_open_arpabet).replace(" "," ")+"}"})
current_open_arpabet = []
last_lang = ss_lang
if ss_val.strip()=="{":
is_in_arpabet = True
elif ss_val.strip()=="}":
subSequences_collapsed.append({ss_lang: "{"+" ".join(current_open_arpabet).replace(" "," ")+"}"})
current_open_arpabet = []
is_in_arpabet = False
else:
if is_in_arpabet:
current_open_arpabet.append(ss_val)
else:
subSequences_collapsed.append({ss_lang: ss_val})
return subSequences_collapsed
    def getG2P (self, text, base_lang):
        """Return the grapheme-to-phoneme rendering of `text` as an editable
        prompt string: ARPAbet symbols wrapped in {...} with \\lang[code][...]
        markers re-inserted around language switches.
        """
        sequenceSplitByLanguage = self.preprocess_prompt_language(text, base_lang)
        # Make sure all languages' text processors are initialized
        for subSequence in sequenceSplitByLanguage:
            langCode = list(subSequence.keys())[0]
            if langCode not in self.lang_tp.keys():
                self.lang_tp[langCode] = get_text_preprocessor(langCode, self.base_dir, logger=self.logger)
        returnString = "{"
        langs_stack = [base_lang]
        last_lang = base_lang
        for subSequence in sequenceSplitByLanguage:
            langCode = list(subSequence.keys())[0]
            subSeq = subSequence[langCode]
            sequence, cleaned_text = self.lang_tp[langCode].text_to_sequence(subSeq)
            if langCode != last_lang:
                last_lang = langCode
                if len(langs_stack)>1 and langs_stack[-2]==langCode:
                    # Returning to the previous language: close the \lang[..][ span
                    langs_stack.pop()
                    if returnString[-1]=="}":
                        returnString = returnString[:-1]
                    returnString += "]}"
                else:
                    # Entering a new language: open a \lang[..][ span
                    langs_stack.append(langCode)
                    if returnString[-1]=="{":
                        returnString = returnString[:-1]
                    returnString += f'\\lang[{langCode}][' + "{"
            # Append this chunk's phonemes; "_" separates words in cleaned_text
            # and becomes a "} {" boundary between ARPAbet groups
            returnString += " ".join([symb for symb in cleaned_text.split("|") if symb != "<PAD>"]).replace("_", "} {")
            if returnString[-1]=="{":
                returnString = returnString[:-1]
            else:
                returnString = returnString+"}"
        # Post-fixes: move punctuation outside the braces and drop empty/doubled braces
        returnString = returnString.replace(".}", "}.")
        returnString = returnString.replace(",}", "},")
        returnString = returnString.replace("!}", "}!")
        returnString = returnString.replace("?}", "}?")
        returnString = returnString.replace("]}", "}]")
        returnString = returnString.replace("}]}", "}]")
        returnString = returnString.replace("{"+"}", "")
        returnString = returnString.replace("}"+"}", "}")
        returnString = returnString.replace("{"+"{", "{")
        return returnString
def infer(self, plugin_manager, text, out_path, vocoder, speaker_i, pace=1.0, editor_data=None, old_sequence=None, globalAmplitudeModifier=None, base_lang="en", base_emb=None, useSR=False, useCleanup=False):
sequenceSplitByLanguage = self.preprocess_prompt_language(text, base_lang)
# Make sure all languages' text processors are initialized
for subSequence in sequenceSplitByLanguage:
langCode = list(subSequence.keys())[0]
if langCode not in self.lang_tp.keys():
self.lang_tp[langCode] = get_text_preprocessor(langCode, self.base_dir, logger=self.logger)
try:
pad_symb = len(ALL_SYMBOLS)-2
all_sequence = []
all_cleaned_text = []
all_text = []
all_lang_ids = []
# Collapse same-language words into phrases, so that heteronyms can still be detected
sequenceSplitByLanguage_grouped = []
last_lang_group = None
group = ""
for ssi, subSequence in enumerate(sequenceSplitByLanguage):
if list(subSequence.keys())[0]!=last_lang_group:
if last_lang_group is not None:
sequenceSplitByLanguage_grouped.append({last_lang_group: group})
group = ""
last_lang_group = list(subSequence.keys())[0]
group += subSequence[last_lang_group]
if len(group):
sequenceSplitByLanguage_grouped.append({last_lang_group: group})
for ssi, subSequence in enumerate(sequenceSplitByLanguage_grouped):
langCode = list(subSequence.keys())[0]
subSeq = subSequence[langCode]
sequence, cleaned_text = self.lang_tp[langCode].text_to_sequence(subSeq)
if ssi<len(sequenceSplitByLanguage_grouped)-1:
sequence = sequence + [pad_symb]
all_sequence.append(sequence)
all_cleaned_text += ("|"+cleaned_text) if len(all_cleaned_text) else cleaned_text
if ssi<len(sequenceSplitByLanguage_grouped)-1:
all_cleaned_text = all_cleaned_text + ["|<PAD>"]
all_text.append(torch.LongTensor(sequence))
language_id = self.language_id_mapping[langCode]
all_lang_ids += [language_id for _ in range(len(sequence))]
except ValueError as e:
self.logger.info("====")
self.logger.info(str(e))
self.logger.info("====--")
if "not in list" in str(e):
symbol_not_in_list = str(e).split("is not in list")[0].split("ValueError:")[-1].replace("'", "").strip()
return f'ERR: ARPABET_NOT_IN_LIST: {symbol_not_in_list}'
all_cleaned_text = "".join(all_cleaned_text)
text = torch.cat(all_text, dim=0)
text = pad_sequence([text], batch_first=True).to(self.models_manager.device)
with torch.no_grad():
if old_sequence is not None:
old_sequence = re.sub(r'[^a-zA-Z\s\(\)\[\]0-9\?\.\,\!\'\{\}\_\@]+', '', old_sequence)
old_sequence, clean_old_sequence = self.lang_tp[base_lang].text_to_sequence(old_sequence)#, "english_basic", ['english_cleaners'])
old_sequence = torch.LongTensor(old_sequence)
old_sequence = pad_sequence([old_sequence], batch_first=True).to(self.models_manager.device)
lang_ids = torch.tensor(all_lang_ids).to(self.models_manager.device)
num_embs = text.shape[1]
base_emb = [float(val) for val in base_emb.split(",")] if "," in base_emb else self.base_emb
speaker_embs = [torch.tensor(base_emb).unsqueeze(dim=0)[0].unsqueeze(-1)]
speaker_embs = torch.stack(speaker_embs, dim=0).to(self.models_manager.device)#.unsqueeze(-1)
speaker_embs = speaker_embs.repeat(1,1,num_embs)
# Do interpolations of speaker style embeddings
if editor_data is not None:
editorStyles = editor_data[-1]
if editorStyles is not None:
style_keys = list(editorStyles.keys())
for style_key in style_keys:
emb = editorStyles[style_key]["embedding"]
sliders_vals = editorStyles[style_key]["sliders"]
style_embs = [torch.tensor(emb).unsqueeze(dim=0)[0].unsqueeze(-1)]
style_embs = torch.stack(style_embs, dim=0).to(self.models_manager.device)#.unsqueeze(-1)
style_embs = style_embs.repeat(1,1,num_embs)
sliders_vals = torch.tensor(sliders_vals).to(self.models_manager.device)
speaker_embs = speaker_embs*(1-sliders_vals) + sliders_vals*style_embs
speaker_embs = speaker_embs.float()
lang_embs = lang_ids # TODO, use pre-extracted trained language embeddings, for interpolation
out = self.model.infer_advanced(self.logger, plugin_manager, [all_cleaned_text], text, lang_embs=lang_embs, speaker_embs=speaker_embs, pace=pace, editor_data=editor_data, old_sequence=old_sequence)
if isinstance(out, str):
return f'ERR:{out}'
else:
output_wav, dur_pred, pitch_pred, energy_pred, em_pred, start_index, end_index, wav_mult = out
[em_angry_pred, em_happy_pred, em_sad_pred, em_surprise_pred] = em_pred
wav = output_wav.squeeze().cpu().detach().numpy()
wav_norm = wav * (32767 / max(0.01, np.max(np.abs(wav))))
if wav_mult is not None:
wav_norm = wav_norm * wav_mult
if useCleanup:
ffmpeg_path = 'ffmpeg' if platform.system() == 'Linux' else f'{"./resources/app" if self.PROD else "."}/python/ffmpeg.exe'
if useSR:
scipy.io.wavfile.write(out_path.replace(".wav", "_preSR.wav"), 22050, wav_norm.astype(np.int16))
else:
scipy.io.wavfile.write(out_path.replace(".wav", "_preCleanupPreFFmpeg.wav"), 22050, wav_norm.astype(np.int16))
stream = ffmpeg.input(out_path.replace(".wav", "_preCleanupPreFFmpeg.wav"))
ffmpeg_options = {"ar": 48000}
output_path = out_path.replace(".wav", "_preCleanup.wav")
stream = ffmpeg.output(stream, output_path, **ffmpeg_options)
out, err = (ffmpeg.run(stream, cmd=ffmpeg_path, capture_stdout=True, capture_stderr=True, overwrite_output=True))
os.remove(out_path.replace(".wav", "_preCleanupPreFFmpeg.wav"))
else:
scipy.io.wavfile.write(out_path.replace(".wav", "_preSR.wav") if useSR else out_path, 22050, wav_norm.astype(np.int16))
if useSR:
self.models_manager.init_model("nuwave2")
self.models_manager.models("nuwave2").sr_audio(out_path.replace(".wav", "_preSR.wav"), out_path.replace(".wav", "_preCleanup.wav") if useCleanup else out_path)
if useCleanup:
self.models_manager.init_model("deepfilternet2")
self.models_manager.models("deepfilternet2").cleanup_audio(out_path.replace(".wav", "_preCleanup.wav"), out_path)
[pitch, durations, energy, em_angry, em_happy, em_sad, em_surprise] = [
pitch_pred.squeeze().cpu().detach().numpy(),
dur_pred.squeeze().cpu().detach().numpy(),
energy_pred.cpu().detach().numpy() if energy_pred is not None else [],
em_angry_pred.squeeze().cpu().detach().numpy() if em_angry_pred is not None else [],
em_happy_pred.squeeze().cpu().detach().numpy() if em_happy_pred is not None else [],
em_sad_pred.squeeze().cpu().detach().numpy() if em_sad_pred is not None else [],
em_surprise_pred.squeeze().cpu().detach().numpy() if em_surprise_pred is not None else [],
]
pitch = [float(v) for v in pitch]
durations = [float(v) for v in durations]
energy = [float(v) for v in energy]
em_angry = [float(v) for v in em_angry]
em_happy = [float(v) for v in em_happy]
em_sad = [float(v) for v in em_sad]
em_surprise = [float(v) for v in em_surprise]
del pitch_pred, dur_pred, energy_pred, text, sequence
return {
"pitch": pitch,
"durations": durations,
"energy": energy,
"em_angry": em_angry,
"em_happy": em_happy,
"em_sad": em_sad,
"em_surprise": em_surprise,
"editorStyles": json.dumps(editorStyles),
"arpabet": all_cleaned_text
}
def set_device (self, device):
self.device = device
self.model = self.model.to(device)
self.model.pitch_emb_values = self.model.pitch_emb_values.to(device)
self.model.device = device