# soni_cloned / voice_main.py
# Last commit: 31c69b4 "Update voice_main.py" (author: test-rtechs, verified)
from soni_translate.logging_setup import logger
import torch
import gc
import numpy as np
import os
import shutil
import warnings
import threading
from tqdm import tqdm
from lib.infer_pack.models import (
SynthesizerTrnMs256NSFsid,
SynthesizerTrnMs256NSFsid_nono,
SynthesizerTrnMs768NSFsid,
SynthesizerTrnMs768NSFsid_nono,
)
from lib.audio import load_audio
import soundfile as sf
import edge_tts
import asyncio
from soni_translate.utils import (
remove_directory_contents,
create_directories,
write_chunked,
)
from scipy import signal
from time import time as ttime
import faiss
from vci_pipeline import VC, change_rms, bh, ah
import librosa
# Install a process-wide filter that ignores every warning category.
warnings.simplefilter("ignore")
class Config:
    """Select the inference device, precision and chunking windows.

    After construction the instance exposes ``device``, ``is_half``,
    ``n_cpu``, ``gpu_name``, ``gpu_mem`` and the window parameters
    ``x_pad``, ``x_query``, ``x_center`` and ``x_max``.
    """

    def __init__(self, only_cpu=False):
        self.device = "cuda:0"
        self.is_half = True
        self.n_cpu = 0
        self.gpu_name = None
        self.gpu_mem = None
        windows = self.device_config(only_cpu)
        self.x_pad, self.x_query, self.x_center, self.x_max = windows

    def device_config(self, only_cpu) -> tuple:
        """Detect the device, update precision/GPU fields, return the windows.

        Args:
            only_cpu: When true, CUDA and MPS are not used even if present.

        Returns:
            ``(x_pad, x_query, x_center, x_max)``.
        """
        if torch.cuda.is_available() and not only_cpu:
            device_index = int(self.device.split(":")[-1])
            name = torch.cuda.get_device_name(device_index)
            self.gpu_name = name
            upper_name = name.upper()
            # Name heuristic for cards that are run in single precision.
            prefers_fp32 = (
                ("16" in name and "V100" not in upper_name)
                or "P40" in upper_name
                or any(series in name for series in ("1060", "1070", "1080"))
            )
            if prefers_fp32:
                logger.info(
                    "16/10 Series GPUs and P40 excel "
                    "in single-precision tasks."
                )
                self.is_half = False
            else:
                self.gpu_name = None
            # Total VRAM in GiB, rounded up from .6 onwards.
            total_bytes = torch.cuda.get_device_properties(
                device_index
            ).total_memory
            self.gpu_mem = int(total_bytes / 1024 / 1024 / 1024 + 0.4)
        elif torch.backends.mps.is_available() and not only_cpu:
            logger.info("Supported N-card not found, using MPS for inference")
            self.device = "mps"
        else:
            logger.info("No supported N-card found, using CPU for inference")
            self.device = "cpu"
            self.is_half = False

        if self.n_cpu == 0:
            self.n_cpu = os.cpu_count()

        # (x_pad, x_query, x_center, x_max) presets: a reduced set for GPUs
        # reporting <= 4 GiB, the "6GB VRAM" set for half precision and the
        # "5GB VRAM" set otherwise.
        if self.gpu_mem is not None and self.gpu_mem <= 4:
            windows = (1, 5, 30, 32)
        elif self.is_half:
            windows = (3, 10, 60, 65)
        else:
            windows = (1, 6, 38, 41)

        logger.info(
            f"Config: Device is {self.device}, "
            f"half precision is {self.is_half}"
        )
        return windows
# Remote folder that hosts the shared base models.
BASE_DOWNLOAD_LINK = (
    "https://huggingface.co/r3gm/sonitranslate_voice_models/resolve/main/"
)
# Base model files that load_hu_bert requests from BASE_DOWNLOAD_LINK.
BASE_MODELS = ["hubert_base.pt", "rmvpe.pt"]
# Local destination directory handed to the downloader.
BASE_DIR = "."
def load_hu_bert(config):
    """Fetch the base models and load the HuBERT encoder.

    Every file in ``BASE_MODELS`` is requested from ``BASE_DOWNLOAD_LINK``
    with ``BASE_DIR`` as the destination. The HuBERT checkpoint is then
    loaded from ``BASE_DIR``, moved to ``config.device``, cast to half or
    float according to ``config.is_half`` and put in eval mode.

    Args:
        config: Object exposing ``device`` and ``is_half`` (e.g. ``Config``).

    Returns:
        The HuBERT model in eval mode.
    """
    from fairseq import checkpoint_utils
    from soni_translate.utils import download_manager

    base_url = BASE_DOWNLOAD_LINK.rstrip("/")
    for id_model in BASE_MODELS:
        # URLs are built by string concatenation: os.path.join is a
        # filesystem API that uses the platform's path separator.
        download_manager(f"{base_url}/{id_model}", BASE_DIR)

    # NOTE(review): assumes download_manager stores the file under its URL
    # basename inside BASE_DIR; confirm against soni_translate.utils.
    hubert_path = os.path.join(BASE_DIR, "hubert_base.pt")
    models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
        [hubert_path],
        suffix="",
    )
    hubert_model = models[0].to(config.device)
    if config.is_half:
        hubert_model = hubert_model.half()
    else:
        hubert_model = hubert_model.float()
    hubert_model.eval()
    return hubert_model
def load_trained_model(model_path, config):
    """Load an RVC voice checkpoint and build its synthesizer and pipeline.

    Args:
        model_path: Path of the checkpoint read with ``torch.load``.
        config: Object exposing ``device`` and ``is_half`` (e.g. ``Config``).

    Returns:
        ``(n_spk, tgt_sr, net_g, vc, cpt, version)``. Here ``cpt`` is the raw
        checkpoint dict, ``net_g`` the synthesizer on ``config.device``, and
        ``vc`` the ``VC`` pipeline built for ``tgt_sr``.

    Raises:
        ValueError: If ``model_path`` is empty or the checkpoint's
            ``version`` is neither "v1" nor "v2".
    """
    if not model_path:
        raise ValueError("No model found")
    logger.info("Loading %s" % model_path)
    # NOTE(review): torch.load unpickles the file; only load trusted models.
    cpt = torch.load(model_path, map_location="cpu")
    tgt_sr = cpt["config"][-1]
    # Set the speaker count from the rows of the speaker embedding table.
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]  # n_spk
    if_f0 = cpt.get("f0", 1)
    version = cpt.get("version", "v1")

    # v1 uses the 256-dim synthesizers and v2 the 768-dim ones; "_nono"
    # variants are for models without f0.
    if version == "v1":
        f0_class = SynthesizerTrnMs256NSFsid
        no_f0_class = SynthesizerTrnMs256NSFsid_nono
    elif version == "v2":
        f0_class = SynthesizerTrnMs768NSFsid
        no_f0_class = SynthesizerTrnMs768NSFsid_nono
    else:
        # Previously fell through and failed with UnboundLocalError.
        raise ValueError(f"Unsupported model version: {version!r}")

    if if_f0 == 1:
        net_g = f0_class(*cpt["config"], is_half=config.is_half)
    else:
        net_g = no_f0_class(*cpt["config"])

    # enc_q is removed before loading; strict=False tolerates the resulting
    # key mismatch with the checkpoint weights.
    del net_g.enc_q
    net_g.load_state_dict(cpt["weight"], strict=False)
    net_g.eval().to(config.device)
    if config.is_half:
        net_g = net_g.half()
    else:
        net_g = net_g.float()

    vc = VC(tgt_sr, config)
    n_spk = cpt["config"][-3]
    return n_spk, tgt_sr, net_g, vc, cpt, version
class ClassVoices:
    """Run RVC voice conversion on audio files, one configuration per tag.

    Register settings with ``apply_conf`` and then call the instance with
    audio files (and optionally their tags). Each file is converted in a
    worker thread and written as ogg/vorbis through ``write_chunked``.
    """

    def __init__(self, only_cpu=False):
        # tag -> settings dict from apply_conf; __call__ adds a "result"
        # list with the output paths for that tag.
        self.model_config = {}
        # Config instance, created on the first apply_conf call.
        self.config = None
        self.only_cpu = only_cpu
        # Shared models: loaded lazily by __call__, released by
        # unload_models.
        self.hu_bert_model = None
        self.model_pitch_estimator = None
        # Output paths produced by the most recent __call__.
        self.output_list = []

    def apply_conf(
        self,
        tag="base_model",
        file_model="",
        pitch_algo="pm",
        pitch_lvl=0,
        file_index="",
        index_influence=0.66,
        respiration_median_filtering=3,
        envelope_ratio=0.25,
        consonant_breath_protection=0.33,
        resample_sr=0,
        file_pitch_algo="",
    ):
        """Register (or replace) the inference settings for ``tag``.

        Args:
            tag: Key used to select this configuration in ``__call__``.
            file_model: Path of the trained voice model checkpoint.
            pitch_algo: f0 method passed to the pipeline; a name containing
                "rmvpe" also loads the RMVPE estimator.
            pitch_lvl: Pitch shift, converted with ``int()`` at inference.
            file_index: Optional faiss index path ("" or None to skip).
            index_influence: Index rate; 0 disables the index.
            respiration_median_filtering: ``filter_radius`` for f0 estimation.
            envelope_ratio: RMS mix rate; 1 skips ``change_rms``.
            consonant_breath_protection: ``protect`` value for the pipeline.
            resample_sr: Output sample rate; only applied when >= 16000.
            file_pitch_algo: Optional f0 file path ("" or None to skip).

        Returns:
            A confirmation message.

        Raises:
            ValueError: If ``file_model`` is empty.
        """
        if not file_model:
            raise ValueError("Model not found")
        if file_index is None:
            file_index = ""
        if file_pitch_algo is None:
            file_pitch_algo = ""
        if not self.config:
            # First configuration: choose device/precision once and reset
            # the shared models.
            self.config = Config(self.only_cpu)
            self.hu_bert_model = None
            self.model_pitch_estimator = None
        self.model_config[tag] = {
            "file_model": file_model,
            "pitch_algo": pitch_algo,
            "pitch_lvl": pitch_lvl,  # no decimal
            "file_index": file_index,
            "index_influence": index_influence,
            "respiration_median_filtering": respiration_median_filtering,
            "envelope_ratio": envelope_ratio,
            "consonant_breath_protection": consonant_breath_protection,
            "resample_sr": resample_sr,
            "file_pitch_algo": file_pitch_algo,
        }
        return f"CONFIGURATION APPLIED FOR {tag}: {file_model}"

    def infer(
        self,
        task_id,
        params,
        # load model
        n_spk,
        tgt_sr,
        net_g,
        pipe,
        cpt,
        version,
        if_f0,
        # load index
        index_rate,
        index,
        big_npy,
        # load f0 file
        inp_f0,
        # audio file
        input_audio_path,
        overwrite,
    ):
        """Convert one audio file with an already loaded model and save it.

        Called from a worker thread started by ``__call__``. The written
        path is appended to ``self.model_config[task_id]["result"]`` and to
        ``self.output_list``. ``n_spk`` and ``cpt`` are part of the call
        signature but are not used here.

        Raises:
            ValueError: If ``input_audio_path`` does not exist.
        """
        f0_method = params["pitch_algo"]
        f0_up_key = params["pitch_lvl"]
        filter_radius = params["respiration_median_filtering"]
        resample_sr = params["resample_sr"]
        rms_mix_rate = params["envelope_ratio"]
        protect = params["consonant_breath_protection"]

        if not os.path.exists(input_audio_path):
            raise ValueError(
                "The audio file was not found or is not "
                f"a valid file: {input_audio_path}"
            )

        f0_up_key = int(f0_up_key)
        audio = load_audio(input_audio_path, 16000)

        # Scale down so the peak magnitude is at most 0.95.
        audio_max = np.abs(audio).max() / 0.95
        if audio_max > 1:
            audio /= audio_max
        times = [0, 0, 0]

        # Filter with the (bh, ah) coefficients, then choose split points:
        # near every multiple of t_center, the offset within +/- t_query
        # where the windowed sum of the signal has the smallest magnitude.
        audio = signal.filtfilt(bh, ah, audio)
        audio_pad = np.pad(
            audio, (pipe.window // 2, pipe.window // 2), mode="reflect"
        )
        opt_ts = []
        if audio_pad.shape[0] > pipe.t_max:
            audio_sum = np.zeros_like(audio)
            for i in range(pipe.window):
                audio_sum += audio_pad[i:i - pipe.window]
            for t in range(pipe.t_center, audio.shape[0], pipe.t_center):
                region = np.abs(
                    audio_sum[t - pipe.t_query: t + pipe.t_query]
                )
                opt_ts.append(
                    t
                    - pipe.t_query
                    + np.where(region == region.min())[0][0]
                )

        s = 0
        audio_opt = []
        t = None
        t1 = ttime()
        sid_value = 0
        sid = torch.tensor(sid_value, device=pipe.device).unsqueeze(0).long()

        # Reflect-pad by t_pad on both sides; p_len is the number of
        # window-sized frames in the padded signal.
        audio_pad = np.pad(audio, (pipe.t_pad, pipe.t_pad), mode="reflect")
        p_len = audio_pad.shape[0] // pipe.window

        # Pitch is only estimated for models with f0 (if_f0 == 1).
        pitch, pitchf = None, None
        if if_f0 == 1:
            pitch, pitchf = pipe.get_f0(
                input_audio_path,
                audio_pad,
                p_len,
                f0_up_key,
                f0_method,
                filter_radius,
                inp_f0,
            )
            pitch = pitch[:p_len]
            pitchf = pitchf[:p_len]
            if pipe.device == "mps":
                pitchf = pitchf.astype(np.float32)
            pitch = torch.tensor(
                pitch, device=pipe.device
            ).unsqueeze(0).long()
            pitchf = torch.tensor(
                pitchf, device=pipe.device
            ).unsqueeze(0).float()
        t2 = ttime()
        times[1] += t2 - t1

        # Convert chunk by chunk between consecutive split points.
        for t in opt_ts:
            t = t // pipe.window * pipe.window
            if if_f0 == 1:
                pitch_slice = pitch[
                    :, s // pipe.window: (t + pipe.t_pad2) // pipe.window
                ]
                pitchf_slice = pitchf[
                    :, s // pipe.window: (t + pipe.t_pad2) // pipe.window
                ]
            else:
                pitch_slice = None
                pitchf_slice = None
            audio_slice = audio_pad[s:t + pipe.t_pad2 + pipe.window]
            audio_opt.append(
                pipe.vc(
                    self.hu_bert_model,
                    net_g,
                    sid,
                    audio_slice,
                    pitch_slice,
                    pitchf_slice,
                    times,
                    index,
                    big_npy,
                    index_rate,
                    version,
                    protect,
                )[pipe.t_pad_tgt:-pipe.t_pad_tgt]
            )
            s = t

        # Tail after the last split point (the whole file if there were
        # none). Models without f0 have pitch/pitchf == None, so they must
        # not be sliced.
        if if_f0 == 1 and t is not None:
            pitch_end_slice = pitch[:, t // pipe.window:]
            pitchf_end_slice = pitchf[:, t // pipe.window:]
        else:
            pitch_end_slice = pitch
            pitchf_end_slice = pitchf
        audio_opt.append(
            pipe.vc(
                self.hu_bert_model,
                net_g,
                sid,
                audio_pad[t:],
                pitch_end_slice,
                pitchf_end_slice,
                times,
                index,
                big_npy,
                index_rate,
                version,
                protect,
            )[pipe.t_pad_tgt:-pipe.t_pad_tgt]
        )
        audio_opt = np.concatenate(audio_opt)

        if rms_mix_rate != 1:
            audio_opt = change_rms(
                audio, 16000, audio_opt, tgt_sr, rms_mix_rate
            )
        if resample_sr >= 16000 and tgt_sr != resample_sr:
            audio_opt = librosa.resample(
                audio_opt, orig_sr=tgt_sr, target_sr=resample_sr
            )

        # Convert to int16, scaling down when the peak exceeds 0.99.
        audio_max = np.abs(audio_opt).max() / 0.99
        max_int16 = 32768
        if audio_max > 1:
            max_int16 /= audio_max
        audio_opt = (audio_opt * max_int16).astype(np.int16)
        del pitch, pitchf, sid
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Same condition as the resampling step above.
        if tgt_sr != resample_sr >= 16000:
            final_sr = resample_sr
        else:
            final_sr = tgt_sr

        if overwrite:
            output_audio_path = input_audio_path  # Overwrite
        else:
            output_audio_path = self._edited_output_path(input_audio_path)
            logger.info(str(output_audio_path))

        # Save file
        write_chunked(
            file=output_audio_path,
            samplerate=final_sr,
            data=audio_opt,
            format="ogg",
            subtype="vorbis",
        )
        self.model_config[task_id]["result"].append(output_audio_path)
        self.output_list.append(output_audio_path)

    @staticmethod
    def _edited_output_path(input_audio_path):
        """Return ``<dir>/<stem>_edited<ext>`` for ``input_audio_path``."""
        # splitext only splits off the last extension, so names containing
        # extra dots keep them.
        root, ext = os.path.splitext(input_audio_path)
        return root + "_edited" + ext

    def make_test(
        self,
        tts_text,
        tts_voice,
        model_path,
        index_path,
        transpose,
        f0_method,
    ):
        """Synthesize ``tts_text`` with edge-tts and convert it with a model.

        Files are written to the "test" folder, whose contents are cleared
        first. The copy at "test/test_edited.wav" is converted in place.

        Returns:
            ``(converted_path, original_tts_path)``.

        Raises:
            ValueError: If edge-tts fails to produce audio, or from
                ``apply_conf`` / ``__call__``.
        """
        folder_test = "test"
        tag = "test_edge"
        tts_file = "test/test.wav"
        tts_edited = "test/test_edited.wav"
        create_directories(folder_test)
        remove_directory_contents(folder_test)

        if "SET_LIMIT" == os.getenv("DEMO"):
            if len(tts_text) > 60:
                tts_text = tts_text[:60]
                logger.warning("DEMO; limit to 60 characters")

        try:
            # The last "-"-separated part of tts_voice is dropped before it
            # is passed to edge-tts.
            asyncio.run(edge_tts.Communicate(
                tts_text, "-".join(tts_voice.split('-')[:-1])
            ).save(tts_file))
        except Exception as e:
            raise ValueError(
                "No audio was received. Please change the "
                f"tts voice for {tts_voice}. Error: {str(e)}"
            ) from e

        shutil.copy(tts_file, tts_edited)

        self.apply_conf(
            tag=tag,
            file_model=model_path,
            pitch_algo=f0_method,
            pitch_lvl=transpose,
            file_index=index_path,
            index_influence=0.66,
            respiration_median_filtering=3,
            envelope_ratio=0.25,
            consonant_breath_protection=0.33,
        )

        self(
            audio_files=tts_edited,
            tag_list=tag,
            overwrite=True
        )

        return tts_edited, tts_file

    def run_threads(self, threads):
        """Start every thread, wait for all of them, then free memory.

        Exceptions raised inside a worker stay in that thread (they are
        reported by ``threading.excepthook``) and are not re-raised here.
        """
        # Start threads
        for thread in threads:
            thread.start()

        # Wait for all threads to finish
        for thread in threads:
            thread.join()

        gc.collect()
        torch.cuda.empty_cache()

    def unload_models(self):
        """Drop the shared HuBERT and pitch-estimator models and free memory."""
        self.hu_bert_model = None
        self.model_pitch_estimator = None
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _load_index(file_index, index_rate):
        """Return ``(index_rate, index, big_npy)``; rate is 0 when unusable."""
        if os.path.exists(file_index) and index_rate != 0:
            try:
                index = faiss.read_index(file_index)
                big_npy = index.reconstruct_n(0, index.ntotal)
            except Exception as error:
                logger.error(f"Index: {str(error)}")
                index_rate = 0
                index = big_npy = None
        else:
            logger.warning("File index not found")
            index_rate = 0
            index = big_npy = None
        return index_rate, index, big_npy

    @staticmethod
    def _load_f0_file(f0_file):
        """Parse a comma-separated f0 file into a float32 array, else None."""
        if not os.path.exists(f0_file):
            return None
        try:
            with open(f0_file, "r", encoding="utf-8") as f:
                lines = f.read().strip("\n").split("\n")
            # Build the array in one step so a parse error never leaves a
            # partially filled list behind.
            return np.array(
                [[float(value) for value in line.split(",")] for line in lines],
                dtype="float32",
            )
        except Exception as error:
            logger.error(f"f0 file: {str(error)}")
            return None

    def __call__(
        self,
        audio_files=None,
        tag_list=None,
        overwrite=False,
        parallel_workers=1,
    ):
        """Convert ``audio_files`` with the models configured for each tag.

        Args:
            audio_files: A path or a list of paths to convert.
            tag_list: A tag or a list of tags, one per file. If empty, the
                last tag in ``model_config`` is used for every file. A
                shorter list is padded with its first tag and a longer one
                is truncated. The caller's list is not modified.
            overwrite: Write results over the inputs instead of
                ``<stem>_edited<ext>`` files.
            parallel_workers: Maximum number of files converted at the same
                time with one model.

        Returns:
            Output paths collected for the tags in the (adjusted) tag list.

        Raises:
            ValueError: If no model is configured or no audio is given.
        """
        logger.info(f"Parallel workers: {str(parallel_workers)}")

        self.output_list = []

        if not self.model_config:
            raise ValueError("No model has been configured for inference")

        if isinstance(audio_files, str):
            audio_files = [audio_files]
        if isinstance(tag_list, str):
            tag_list = [tag_list]

        if not audio_files:
            raise ValueError("No audio found to convert")
        # Work on copies so the caller's lists are never mutated.
        audio_files = list(audio_files)
        tag_list = list(tag_list) if tag_list else []
        if not tag_list:
            tag_list = [list(self.model_config.keys())[-1]] * len(audio_files)

        if len(audio_files) > len(tag_list):
            logger.info("Extend tag list to match audio files")
            extend_number = len(audio_files) - len(tag_list)
            tag_list.extend([tag_list[0]] * extend_number)

        if len(audio_files) < len(tag_list):
            logger.info("Cut list tags")
            tag_list = tag_list[:len(audio_files)]

        # Group files by tag so each model is loaded only once.
        sorted_tag_file = sorted(
            zip(tag_list, audio_files), key=lambda pair: pair[0]
        )

        # Base params
        if not self.hu_bert_model:
            self.hu_bert_model = load_hu_bert(self.config)

        cache_params = None
        threads = []
        progress_bar = tqdm(total=len(tag_list), desc="Progress")
        for id_tag, input_audio_path in sorted_tag_file:

            if id_tag not in self.model_config:
                logger.info(
                    f"No configured model for {id_tag} with {input_audio_path}"
                )
                continue

            # Run the pending batch when it is full or the model changes.
            if (
                len(threads) >= parallel_workers
                or (cache_params != id_tag and cache_params is not None)
            ):
                self.run_threads(threads)
                progress_bar.update(len(threads))
                threads = []

            if cache_params != id_tag:

                self.model_config[id_tag]["result"] = []

                # Unload previous
                (
                    n_spk,
                    tgt_sr,
                    net_g,
                    pipe,
                    cpt,
                    version,
                    if_f0,
                    index_rate,
                    index,
                    big_npy,
                    inp_f0,
                ) = [None] * 11
                gc.collect()
                torch.cuda.empty_cache()

                # Model params
                params = self.model_config[id_tag]

                model_path = params["file_model"]
                f0_method = params["pitch_algo"]
                file_index = params["file_index"]
                index_rate = params["index_influence"]
                f0_file = params["file_pitch_algo"]

                # Load model
                (
                    n_spk,
                    tgt_sr,
                    net_g,
                    pipe,
                    cpt,
                    version
                ) = load_trained_model(model_path, self.config)
                if_f0 = cpt.get("f0", 1)  # pitch data

                # Load index
                index_rate, index, big_npy = self._load_index(
                    file_index, index_rate
                )

                # Load f0 file
                inp_f0 = self._load_f0_file(f0_file)

                if "rmvpe" in f0_method:
                    if not self.model_pitch_estimator:
                        from lib.rmvpe import RMVPE

                        logger.info("Loading vocal pitch estimator model")
                        self.model_pitch_estimator = RMVPE(
                            "rmvpe.pt",
                            is_half=self.config.is_half,
                            device=self.config.device
                        )

                    pipe.model_rmvpe = self.model_pitch_estimator

                cache_params = id_tag

            thread = threading.Thread(
                target=self.infer,
                args=(
                    id_tag,
                    params,
                    # loaded model
                    n_spk,
                    tgt_sr,
                    net_g,
                    pipe,
                    cpt,
                    version,
                    if_f0,
                    # loaded index
                    index_rate,
                    index,
                    big_npy,
                    # loaded f0 file
                    inp_f0,
                    # audio file
                    input_audio_path,
                    overwrite,
                )
            )

            threads.append(thread)

        # Run last
        if threads:
            self.run_threads(threads)

        progress_bar.update(len(threads))
        progress_bar.close()

        # Collect results per unique tag in first-appearance order, so the
        # output order is deterministic (a set was used before).
        final_result = []
        for tag in dict.fromkeys(tag_list):
            tag_config = self.model_config.get(tag)
            if tag_config is not None and "result" in tag_config:
                final_result.extend(tag_config["result"])

        return final_result