# rtrvc.py — real-time RVC inference (RVC-UI, infer/lib/rtrvc.py)
# Provenance: uploaded via huggingface_hub by Blane187, commit c7b379a
# (verified), 13.6 kB; raw-view page chrome removed so the file parses.
from io import BytesIO
import os
from typing import Union, Literal, Optional
import fairseq
import faiss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.transforms import Resample
from rvc.f0 import PM, Harvest, RMVPE, CRePE, Dio, FCPE
from rvc.synthesizer import load_synthesizer
class RVC:
    """Real-time RVC (Retrieval-based Voice Conversion) inference pipeline.

    Bundles a HuBERT feature extractor, an optional faiss retrieval index
    and a synthesizer network, and converts blocks of 16 kHz audio with
    pitch and formant shifting.
    """

    def __init__(
        self,
        key: Union[int, float],
        formant: Union[int, float],
        pth_path: torch.serialization.FILE_LIKE,
        index_path: str,
        index_rate: Union[int, float],
        n_cpu: int = os.cpu_count(),
        device: str = "cpu",
        use_jit: bool = False,
        is_half: bool = False,
        is_dml: bool = False,
    ) -> None:
        """Load the HuBERT encoder, the faiss index and the synthesizer.

        Args:
            key: pitch transpose in semitones.
            formant: formant shift in semitones.
            pth_path: path (or file-like) of the synthesizer checkpoint.
            index_path: path of the faiss feature index.
            index_rate: blend ratio for retrieved features; <= 0 disables
                index loading entirely.
            n_cpu: worker count (NOTE: the default is evaluated once at
                import time, not per call).
            device: torch device string, e.g. "cpu" or "cuda:0".
            use_jit: prefer a TorchScript synthesizer when supported.
            is_half: run models in fp16.
            is_dml: DirectML backend; patches fairseq's GradMultiply
                forward, whose autograd op is unsupported there.
        """
        if is_dml:
            # DirectML cannot execute GradMultiply; replace its forward
            # with an identity copy. Inference-only, so the lost gradient
            # scaling does not matter.
            def forward_dml(ctx, x, scale):
                ctx.scale = scale
                res = x.clone().detach()
                return res

            fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
        self.device = device
        self.f0_up_key = key
        self.formant_shift = formant
        self.sr = 16000  # hubert sampling rate
        self.window = 160  # hop length
        self.f0_min = 50
        self.f0_max = 1100
        # F0 range mapped to mel scale; used to quantize pitch into the
        # 1..255 coarse bins expected by the synthesizer embedding.
        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
        self.n_cpu = n_cpu
        self.use_jit = use_jit
        self.is_half = is_half
        if index_rate > 0:
            self.index = faiss.read_index(index_path)
            # Materialize every index vector so retrieved neighbours can be
            # blended by weighted reconstruction during inference.
            self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
        self.pth_path = pth_path
        self.index_path = index_path
        self.index_rate = index_rate
        # Rolling caches of the most recent 1024 pitch frames
        # (coarse bins and raw float F0), shifted left every block.
        self.cache_pitch: torch.Tensor = torch.zeros(
            1024, device=self.device, dtype=torch.long
        )
        self.cache_pitchf = torch.zeros(1024, device=self.device, dtype=torch.float32)
        self.resample_kernel = {}  # upp_res -> cached torchaudio Resample
        # Dispatch table: f0 method name -> bound extractor method.
        self.f0_methods = {
            "crepe": self._get_f0_crepe,
            "rmvpe": self._get_f0_rmvpe,
            "fcpe": self._get_f0_fcpe,
            "pm": self._get_f0_pm,
            "harvest": self._get_f0_harvest,
            "dio": self._get_f0_dio,
        }
        models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            ["assets/hubert/hubert_base.pt"],
            suffix="",
        )
        hubert_model = models[0]
        hubert_model = hubert_model.to(self.device)
        if self.is_half:
            hubert_model = hubert_model.half()
        else:
            hubert_model = hubert_model.float()
        hubert_model.eval()
        self.hubert = hubert_model
        self.net_g: Optional[nn.Module] = None

        def set_default_model():
            # Eager (non-TorchScript) synthesizer path.
            self.net_g, cpt = load_synthesizer(self.pth_path, self.device)
            self.tgt_sr = cpt["config"][-1]
            cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
            self.if_f0 = cpt.get("f0", 1)
            self.version = cpt.get("version", "v1")
            if self.is_half:
                self.net_g = self.net_g.half()
            else:
                self.net_g = self.net_g.float()

        def set_jit_model():
            from rvc.jit import get_jit_model
            from rvc.synthesizer import synthesizer_jit_export

            cpt = get_jit_model(self.pth_path, self.is_half, synthesizer_jit_export)
            self.tgt_sr = cpt["config"][-1]
            self.if_f0 = cpt.get("f0", 1)
            self.version = cpt.get("version", "v1")
            self.net_g = torch.jit.load(BytesIO(cpt["model"]), map_location=self.device)
            self.net_g.infer = self.net_g.forward
            self.net_g.eval().to(self.device)

        # TorchScript is skipped on DML and for fp16-on-CPU combinations.
        if (
            self.use_jit
            and not is_dml
            and not (self.is_half and "cpu" in str(self.device))
        ):
            set_jit_model()
        else:
            set_default_model()

    def set_key(self, new_key):
        """Update the pitch transpose (semitones)."""
        self.f0_up_key = new_key

    def set_formant(self, new_formant):
        """Update the formant shift (semitones)."""
        self.formant_shift = new_formant

    def set_index_rate(self, new_index_rate):
        """Update the index blend ratio, loading the index on first enable."""
        if new_index_rate > 0 and self.index_rate <= 0:
            self.index = faiss.read_index(self.index_path)
            self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
        self.index_rate = new_index_rate

    def infer(
        self,
        input_wav: torch.Tensor,
        block_frame_16k: int,
        skip_head: int,
        return_length: int,
        f0method: Union[tuple, str],
        inp_f0: Optional[np.ndarray] = None,
        protect: float = 1.0,
    ) -> np.ndarray:
        """Convert one 16 kHz audio block and return synthesized audio.

        Args:
            input_wav: 1-D (or 2-channel) 16 kHz waveform tensor.
            block_frame_16k: new samples in this block (16 kHz frames).
            skip_head: leading feature frames to skip for crossfading.
            return_length: frames of audio to return.
            f0method: extractor name, or a precomputed (pitch, pitchf)
                tuple that bypasses extraction.
            inp_f0: optional externally supplied f0 curve.
            protect: < 0.5 enables voiceless-consonant protection by
                blending indexed and raw HuBERT features.
        """
        with torch.no_grad():
            if self.is_half:
                feats = input_wav.half()
            else:
                feats = input_wav.float()
            feats = feats.to(self.device)
            if feats.dim() == 2:  # double channels
                feats = feats.mean(-1)
            feats = feats.view(1, -1)
            padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
            inputs = {
                "source": feats,
                "padding_mask": padding_mask,
                # v1 models consume layer-9 features via final_proj;
                # v2 models take layer-12 output directly.
                "output_layer": 9 if self.version == "v1" else 12,
            }
            logits = self.hubert.extract_features(**inputs)
            feats = (
                self.hubert.final_proj(logits[0]) if self.version == "v1" else logits[0]
            )
            # Duplicate the last frame so the 2x interpolation below covers
            # the full window.
            feats = torch.cat((feats, feats[:, -1:, :]), 1)
            if protect < 0.5 and self.if_f0 == 1:
                feats0 = feats.clone()
            try:
                if hasattr(self, "index") and self.index_rate > 0:
                    npy = feats[0][skip_head // 2 :].cpu().numpy()
                    if self.is_half:
                        npy = npy.astype("float32")
                    score, ix = self.index.search(npy, k=8)
                    if (ix >= 0).all():
                        # Inverse-square-distance weighted blend of the
                        # 8 nearest index vectors.
                        weight = np.square(1 / score)
                        weight /= weight.sum(axis=1, keepdims=True)
                        npy = np.sum(
                            self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1
                        )
                        if self.is_half:
                            npy = npy.astype("float16")
                        feats[0][skip_head // 2 :] = (
                            torch.from_numpy(npy).unsqueeze(0).to(self.device)
                            * self.index_rate
                            + (1 - self.index_rate) * feats[0][skip_head // 2 :]
                        )
            except Exception:
                # Best-effort retrieval: any failure falls back to the raw
                # HuBERT features. (Was a bare `except:`, which also
                # swallowed KeyboardInterrupt/SystemExit.)
                pass
            p_len = input_wav.shape[0] // self.window
            # Formant factor: 2^(semitones/12) stretches the output length.
            factor = pow(2, self.formant_shift / 12)
            return_length2 = int(np.ceil(return_length * factor))
            cache_pitch = cache_pitchf = None
            pitch = pitchf = None
            if isinstance(f0method, tuple):
                pitch, pitchf = f0method
                pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
                pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
            elif self.if_f0 == 1:
                # Extra context (800 samples) stabilizes f0 at block edges.
                f0_extractor_frame = block_frame_16k + 800
                if f0method == "rmvpe":
                    # RMVPE requires input padded to a multiple of 5120.
                    f0_extractor_frame = (
                        5120 * ((f0_extractor_frame - 1) // 5120 + 1) - self.window
                    )
                if inp_f0 is not None:
                    pitch, pitchf = self._get_f0_post(
                        inp_f0, self.f0_up_key - self.formant_shift
                    )
                else:
                    pitch, pitchf = self._get_f0(
                        input_wav[-f0_extractor_frame:],
                        self.f0_up_key - self.formant_shift,
                        method=f0method,
                    )
                # Roll the pitch caches left by one block and append the
                # freshly extracted frames (dropping 3 head / 1 tail frame
                # of extractor edge effects).
                shift = block_frame_16k // self.window
                self.cache_pitch[:-shift] = self.cache_pitch[shift:].clone()
                self.cache_pitchf[:-shift] = self.cache_pitchf[shift:].clone()
                self.cache_pitch[4 - pitch.shape[0] :] = pitch[3:-1]
                self.cache_pitchf[4 - pitch.shape[0] :] = pitchf[3:-1]
                cache_pitch = self.cache_pitch[None, -p_len:]
                cache_pitchf = (
                    self.cache_pitchf[None, -p_len:] * return_length2 / return_length
                )
            # HuBERT frames are 20 ms; upsample 2x to the 10 ms hop grid.
            feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
            feats = feats[:, :p_len, :]
            if protect < 0.5 and pitch is not None and pitchf is not None:
                # NOTE(review): feats0 only exists when self.if_f0 == 1; a
                # pitch tuple supplied with if_f0 == 0 would raise NameError
                # here — presumed unreachable in practice, confirm upstream.
                feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
                    0, 2, 1
                )
                feats0 = feats0[:, :p_len, :]
                # Voiced frames (f0 > 0) keep indexed features; unvoiced
                # frames blend in `protect` worth of raw features.
                pitchff = pitchf.clone()
                pitchff[pitchf > 0] = 1
                pitchff[pitchf < 1] = protect
                pitchff = pitchff.unsqueeze(-1)
                feats = feats * pitchff + feats0 * (1 - pitchff)
                feats = feats.to(feats0.dtype)
            p_len = torch.LongTensor([p_len]).to(self.device)
            sid = torch.LongTensor([0]).to(self.device)
            with torch.no_grad():
                infered_audio = (
                    self.net_g.infer(
                        feats,
                        p_len,
                        sid,
                        pitch=cache_pitch,
                        pitchf=cache_pitchf,
                        skip_head=skip_head,
                        return_length=return_length,
                        return_length2=return_length2,
                    )
                    .squeeze(1)
                    .float()
                )
            # Formant shifting changes the effective rate; resample back to
            # the target sample rate, caching one Resample per ratio.
            upp_res = int(np.floor(factor * self.tgt_sr // 100))
            if upp_res != self.tgt_sr // 100:
                if upp_res not in self.resample_kernel:
                    self.resample_kernel[upp_res] = Resample(
                        orig_freq=upp_res,
                        new_freq=self.tgt_sr // 100,
                        dtype=torch.float32,
                    ).to(self.device)
                infered_audio = self.resample_kernel[upp_res](
                    infered_audio[:, : return_length * upp_res]
                )
        return infered_audio.squeeze()

    def _get_f0(
        self,
        x: torch.Tensor,
        f0_up_key: Union[int, float],
        filter_radius: Optional[Union[int, float]] = None,
        method: Literal["crepe", "rmvpe", "fcpe", "pm", "harvest", "dio"] = "fcpe",
    ):
        """Dispatch f0 extraction to the method named in `self.f0_methods`."""
        if method not in self.f0_methods.keys():
            raise RuntimeError("Not supported f0 method: " + method)
        return self.f0_methods[method](x, f0_up_key, filter_radius)

    def _get_f0_post(self, f0, f0_up_key):
        """Transpose f0 and quantize it to 1..255 coarse mel bins.

        Returns (coarse_long_tensor, float_f0_tensor); bin 1 marks
        unvoiced (f0 == 0) frames.
        """
        f0 *= pow(2, f0_up_key / 12)
        if not torch.is_tensor(f0):
            f0 = torch.from_numpy(f0)
        f0 = f0.float().to(self.device).squeeze()
        f0_mel = 1127 * torch.log(1 + f0 / 700)
        f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (
            self.f0_mel_max - self.f0_mel_min
        ) + 1
        f0_mel[f0_mel <= 1] = 1
        f0_mel[f0_mel > 255] = 255
        f0_coarse = torch.round(f0_mel).long()
        return f0_coarse, f0

    def _get_f0_pm(self, x, f0_up_key, filter_radius):
        """Extract f0 with Parselmouth; lazily constructs the extractor."""
        if not hasattr(self, "pm"):
            self.pm = PM(hop_length=160, sampling_rate=16000)
        f0 = self.pm.compute_f0(x.cpu().numpy())
        return self._get_f0_post(f0, f0_up_key)

    def _get_f0_harvest(self, x, f0_up_key, filter_radius=3):
        """Extract f0 with WORLD Harvest; `filter_radius` defaults to 3."""
        if not hasattr(self, "harvest"):
            self.harvest = Harvest(
                self.window,
                self.f0_min,
                self.f0_max,
                self.sr,
            )
        if filter_radius is None:
            filter_radius = 3
        f0 = self.harvest.compute_f0(x.cpu().numpy(), filter_radius=filter_radius)
        return self._get_f0_post(f0, f0_up_key)

    def _get_f0_dio(self, x, f0_up_key, filter_radius):
        """Extract f0 with WORLD DIO; lazily constructs the extractor."""
        if not hasattr(self, "dio"):
            self.dio = Dio(
                self.window,
                self.f0_min,
                self.f0_max,
                self.sr,
            )
        f0 = self.dio.compute_f0(x.cpu().numpy())
        return self._get_f0_post(f0, f0_up_key)

    def _get_f0_crepe(self, x, f0_up_key, filter_radius):
        """Extract f0 with CREPE (torch, runs on self.device)."""
        if not hasattr(self, "crepe"):  # was `hasattr(...) == False`
            self.crepe = CRePE(
                self.window,
                self.f0_min,
                self.f0_max,
                self.sr,
                self.device,
            )
        f0 = self.crepe.compute_f0(x)
        return self._get_f0_post(f0, f0_up_key)

    def _get_f0_rmvpe(self, x, f0_up_key, filter_radius=0.03):
        """Extract f0 with RMVPE; model path comes from $rmvpe_root."""
        if not hasattr(self, "rmvpe"):  # was `hasattr(...) == False`
            self.rmvpe = RMVPE(
                "%s/rmvpe.pt" % os.environ["rmvpe_root"],
                is_half=self.is_half,
                device=self.device,
                use_jit=self.use_jit,
            )
        if filter_radius is None:
            filter_radius = 0.03
        return self._get_f0_post(
            self.rmvpe.compute_f0(x, filter_radius=filter_radius),
            f0_up_key,
        )

    def _get_f0_fcpe(self, x, f0_up_key, filter_radius):
        """Extract f0 with FCPE at a fixed 16 kHz / 160-hop configuration."""
        if not hasattr(self, "fcpe"):  # was `hasattr(...) == False`
            self.fcpe = FCPE(
                160,
                self.f0_min,
                self.f0_max,
                16000,
                self.device,
            )
        f0 = self.fcpe.compute_f0(x)
        return self._get_f0_post(f0, f0_up_key)