from io import BytesIO
import os
from typing import Union, Literal, Optional

import fairseq
import faiss
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchaudio.transforms import Resample

from rvc.f0 import PM, Harvest, RMVPE, CRePE, Dio, FCPE
from rvc.synthesizer import load_synthesizer


class RVC:
    def __init__(
        self,
        key: Union[int, float],
        formant: Union[int, float],
        pth_path: torch.serialization.FILE_LIKE,
        index_path: str,
        index_rate: Union[int, float],
        n_cpu: int = os.cpu_count(),
        device: str = "cpu",
        use_jit: bool = False,
        is_half: bool = False,
        is_dml: bool = False,
    ) -> None:
        if is_dml:
            # Replace fairseq's GradMultiply.forward with a pass-through for DirectML compatibility.
            def forward_dml(ctx, x, scale):
                ctx.scale = scale
                res = x.clone().detach()
                return res

            fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml

        self.device = device
        self.f0_up_key = key
        self.formant_shift = formant
        self.sr = 16000  # hubert sampling rate
        self.window = 160  # hop length
        self.f0_min = 50
        self.f0_max = 1100
        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
        self.n_cpu = n_cpu
        self.use_jit = use_jit
        self.is_half = is_half

        # Load the FAISS feature index only when retrieval is enabled.
        if index_rate > 0:
            self.index = faiss.read_index(index_path)
            self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
        self.pth_path = pth_path
        self.index_path = index_path
        self.index_rate = index_rate

        # Rolling caches holding the most recent pitch frames (coarse and float).
        self.cache_pitch: torch.Tensor = torch.zeros(
            1024, device=self.device, dtype=torch.long
        )
        self.cache_pitchf = torch.zeros(1024, device=self.device, dtype=torch.float32)

        self.resample_kernel = {}

        self.f0_methods = {
            "crepe": self._get_f0_crepe,
            "rmvpe": self._get_f0_rmvpe,
            "fcpe": self._get_f0_fcpe,
            "pm": self._get_f0_pm,
            "harvest": self._get_f0_harvest,
            "dio": self._get_f0_dio,
        }

        # Load the HuBERT content encoder.
        models, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
            ["assets/hubert/hubert_base.pt"],
            suffix="",
        )
        hubert_model = models[0]
        hubert_model = hubert_model.to(self.device)
        if self.is_half:
            hubert_model = hubert_model.half()
        else:
            hubert_model = hubert_model.float()
        hubert_model.eval()
        self.hubert = hubert_model

        self.net_g: Optional[nn.Module] = None

        def set_default_model():
            self.net_g, cpt = load_synthesizer(self.pth_path, self.device)
            self.tgt_sr = cpt["config"][-1]
            cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
            self.if_f0 = cpt.get("f0", 1)
            self.version = cpt.get("version", "v1")
            if self.is_half:
                self.net_g = self.net_g.half()
            else:
                self.net_g = self.net_g.float()

        def set_jit_model():
            from rvc.jit import get_jit_model
            from rvc.synthesizer import synthesizer_jit_export

            cpt = get_jit_model(self.pth_path, self.is_half, synthesizer_jit_export)
            self.tgt_sr = cpt["config"][-1]
            self.if_f0 = cpt.get("f0", 1)
            self.version = cpt.get("version", "v1")
            self.net_g = torch.jit.load(
                BytesIO(cpt["model"]), map_location=self.device
            )
            self.net_g.infer = self.net_g.forward
            self.net_g.eval().to(self.device)

        # Use the TorchScript synthesizer only where it is usable.
        if (
            self.use_jit
            and not is_dml
            and not (self.is_half and "cpu" in str(self.device))
        ):
            set_jit_model()
        else:
            set_default_model()

    def set_key(self, new_key):
        self.f0_up_key = new_key

    def set_formant(self, new_formant):
        self.formant_shift = new_formant

    def set_index_rate(self, new_index_rate):
        # Lazily load the index if retrieval is turned on after construction.
        if new_index_rate > 0 and self.index_rate <= 0:
            self.index = faiss.read_index(self.index_path)
            self.big_npy = self.index.reconstruct_n(0, self.index.ntotal)
        self.index_rate = new_index_rate

    def infer(
        self,
        input_wav: torch.Tensor,
        block_frame_16k: int,
        skip_head: int,
        return_length: int,
        f0method: Union[tuple, str],
        inp_f0: Optional[np.ndarray] = None,
        protect: float = 1.0,
    ) -> np.ndarray:
        # HuBERT feature extraction.
        with torch.no_grad():
            if self.is_half:
                feats = input_wav.half()
            else:
                feats = input_wav.float()
            feats = feats.to(self.device)
            if feats.dim() == 2:  # double channels
                feats = feats.mean(-1)
            feats = feats.view(1, -1)
            padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
            inputs = {
                "source": feats,
                "padding_mask": padding_mask,
                "output_layer": 9 if self.version == "v1" else 12,
            }
            logits = self.hubert.extract_features(**inputs)
            feats = (
                self.hubert.final_proj(logits[0])
                if self.version == "v1"
                else logits[0]
            )
            feats = torch.cat((feats, feats[:, -1:, :]), 1)
            if protect < 0.5 and self.if_f0 == 1:
                feats0 = feats.clone()

        # Feature retrieval: mix in nearest neighbours from the FAISS index.
        try:
            if hasattr(self, "index") and self.index_rate > 0:
                npy = feats[0][skip_head // 2 :].cpu().numpy()
                if self.is_half:
                    npy = npy.astype("float32")
                score, ix = self.index.search(npy, k=8)
                if (ix >= 0).all():
                    weight = np.square(1 / score)
                    weight /= weight.sum(axis=1, keepdims=True)
                    npy = np.sum(
                        self.big_npy[ix] * np.expand_dims(weight, axis=2), axis=1
                    )
                    if self.is_half:
                        npy = npy.astype("float16")
                    feats[0][skip_head // 2 :] = (
                        torch.from_numpy(npy).unsqueeze(0).to(self.device)
                        * self.index_rate
                        + (1 - self.index_rate) * feats[0][skip_head // 2 :]
                    )
        except:
            pass

        p_len = input_wav.shape[0] // self.window
        factor = pow(2, self.formant_shift / 12)
        return_length2 = int(np.ceil(return_length * factor))

        # Pitch extraction.
        cache_pitch = cache_pitchf = None
        pitch = pitchf = None
        if isinstance(f0method, tuple):
            # Pre-computed (pitch, pitchf) passed in directly.
            pitch, pitchf = f0method
            pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
            pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
        elif self.if_f0 == 1:
            f0_extractor_frame = block_frame_16k + 800
            if f0method == "rmvpe":
                f0_extractor_frame = (
                    5120 * ((f0_extractor_frame - 1) // 5120 + 1) - self.window
                )
            if inp_f0 is not None:
                pitch, pitchf = self._get_f0_post(
                    inp_f0, self.f0_up_key - self.formant_shift
                )
            else:
                pitch, pitchf = self._get_f0(
                    input_wav[-f0_extractor_frame:],
                    self.f0_up_key - self.formant_shift,
                    method=f0method,
                )
            # Shift the rolling pitch caches left by one block and append the new frames.
            shift = block_frame_16k // self.window
            self.cache_pitch[:-shift] = self.cache_pitch[shift:].clone()
            self.cache_pitchf[:-shift] = self.cache_pitchf[shift:].clone()
            self.cache_pitch[4 - pitch.shape[0] :] = pitch[3:-1]
            self.cache_pitchf[4 - pitch.shape[0] :] = pitchf[3:-1]
            cache_pitch = self.cache_pitch[None, -p_len:]
            cache_pitchf = (
                self.cache_pitchf[None, -p_len:] * return_length2 / return_length
            )

        # Upsample features to the frame rate expected by the synthesizer.
        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
        feats = feats[:, :p_len, :]

        if protect < 0.5 and pitch is not None and pitchf is not None:
            # Blend retrieved features back toward the original ones on unvoiced frames.
            feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
                0, 2, 1
            )
            feats0 = feats0[:, :p_len, :]
            pitchff = pitchf.clone()
            pitchff[pitchf > 0] = 1
            pitchff[pitchf < 1] = protect
            pitchff = pitchff.unsqueeze(-1)
            feats = feats * pitchff + feats0 * (1 - pitchff)
            feats = feats.to(feats0.dtype)

        # Synthesis.
        p_len = torch.LongTensor([p_len]).to(self.device)
        sid = torch.LongTensor([0]).to(self.device)
        with torch.no_grad():
            infered_audio = (
                self.net_g.infer(
                    feats,
                    p_len,
                    sid,
                    pitch=cache_pitch,
                    pitchf=cache_pitchf,
                    skip_head=skip_head,
                    return_length=return_length,
                    return_length2=return_length2,
                )
                .squeeze(1)
                .float()
            )

        # Resample when the formant shift changes the effective output rate.
        upp_res = int(np.floor(factor * self.tgt_sr // 100))
        if upp_res != self.tgt_sr // 100:
            if upp_res not in self.resample_kernel:
                self.resample_kernel[upp_res] = Resample(
                    orig_freq=upp_res,
                    new_freq=self.tgt_sr // 100,
                    dtype=torch.float32,
                ).to(self.device)
            infered_audio = self.resample_kernel[upp_res](
                infered_audio[:, : return_length * upp_res]
            )
        return infered_audio.squeeze()

    def _get_f0(
        self,
        x: torch.Tensor,
        f0_up_key: Union[int, float],
        filter_radius: Optional[Union[int, float]] = None,
        method: Literal["crepe", "rmvpe", "fcpe", "pm", "harvest", "dio"] = "fcpe",
    ):
        if method not in self.f0_methods:
            raise RuntimeError("Unsupported f0 method: " + method)
        return self.f0_methods[method](x, f0_up_key, filter_radius)

    def _get_f0_post(self, f0, f0_up_key):
        # Apply the pitch shift and quantize f0 to coarse mel bins in [1, 255].
        f0 *= pow(2, f0_up_key / 12)
        if not torch.is_tensor(f0):
            f0 = torch.from_numpy(f0)
        f0 = f0.float().to(self.device).squeeze()
        f0_mel = 1127 * torch.log(1 + f0 / 700)
        f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (
            self.f0_mel_max - self.f0_mel_min
        ) + 1
        f0_mel[f0_mel <= 1] = 1
        f0_mel[f0_mel > 255] = 255
        f0_coarse = torch.round(f0_mel).long()
        return f0_coarse, f0

    # The extractors below are instantiated lazily on first use.

    def _get_f0_pm(self, x, f0_up_key, filter_radius):
        if not hasattr(self, "pm"):
            self.pm = PM(hop_length=160, sampling_rate=16000)
        f0 = self.pm.compute_f0(x.cpu().numpy())
        return self._get_f0_post(f0, f0_up_key)

    def _get_f0_harvest(self, x, f0_up_key, filter_radius=3):
        if not hasattr(self, "harvest"):
            self.harvest = Harvest(
                self.window,
                self.f0_min,
                self.f0_max,
                self.sr,
            )
        if filter_radius is None:
            filter_radius = 3
        f0 = self.harvest.compute_f0(x.cpu().numpy(), filter_radius=filter_radius)
        return self._get_f0_post(f0, f0_up_key)

    def _get_f0_dio(self, x, f0_up_key, filter_radius):
        if not hasattr(self, "dio"):
            self.dio = Dio(
                self.window,
                self.f0_min,
                self.f0_max,
                self.sr,
            )
        f0 = self.dio.compute_f0(x.cpu().numpy())
        return self._get_f0_post(f0, f0_up_key)

    def _get_f0_crepe(self, x, f0_up_key, filter_radius):
        if not hasattr(self, "crepe"):
            self.crepe = CRePE(
                self.window,
                self.f0_min,
                self.f0_max,
                self.sr,
                self.device,
            )
        f0 = self.crepe.compute_f0(x)
        return self._get_f0_post(f0, f0_up_key)

    def _get_f0_rmvpe(self, x, f0_up_key, filter_radius=0.03):
        if not hasattr(self, "rmvpe"):
            self.rmvpe = RMVPE(
                "%s/rmvpe.pt" % os.environ["rmvpe_root"],
                is_half=self.is_half,
                device=self.device,
                use_jit=self.use_jit,
            )
        if filter_radius is None:
            filter_radius = 0.03
        return self._get_f0_post(
            self.rmvpe.compute_f0(x, filter_radius=filter_radius),
            f0_up_key,
        )

    def _get_f0_fcpe(self, x, f0_up_key, filter_radius):
        if not hasattr(self, "fcpe"):
            self.fcpe = FCPE(
                160,
                self.f0_min,
                self.f0_max,
                16000,
                self.device,
            )
        f0 = self.fcpe.compute_f0(x)
        return self._get_f0_post(f0, f0_up_key)