File size: 5,686 Bytes
7ff2ba3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
from io import BytesIO
import os
from typing import Any, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from rvc.jit import load_inputs, get_jit_model, export_jit_model, save_pickle
from .mel import MelSpectrogram
from .f0 import F0Predictor
from .models import get_rmvpe
def rmvpe_jit_export(
model_path: str,
mode: str = "script",
inputs_path: str = None,
save_path: str = None,
device=torch.device("cpu"),
is_half=False,
):
if not save_path:
save_path = model_path.rstrip(".pth")
save_path += ".half.jit" if is_half else ".jit"
if "cuda" in str(device) and ":" not in str(device):
device = torch.device("cuda:0")
model = get_rmvpe(model_path, device, is_half)
inputs = None
if mode == "trace":
inputs = load_inputs(inputs_path, device, is_half)
ckpt = export_jit_model(model, mode, inputs, device, is_half)
ckpt["device"] = str(device)
save_pickle(ckpt, save_path)
return ckpt
class RMVPE(F0Predictor):
def __init__(
self,
model_path: str,
is_half: bool,
device: str,
use_jit=False,
):
hop_length = 160
f0_min = 30
f0_max = 8000
sampling_rate = 16000
super().__init__(
hop_length,
f0_min,
f0_max,
sampling_rate,
device,
)
self.is_half = is_half
cents_mapping = 20 * np.arange(360) + 1997.3794084376191
self.cents_mapping = np.pad(cents_mapping, (4, 4)) # 368
self.mel_extractor = MelSpectrogram(
is_half=is_half,
n_mel_channels=128,
sampling_rate=sampling_rate,
win_length=1024,
hop_length=hop_length,
mel_fmin=f0_min,
mel_fmax=f0_max,
device=self.device,
).to(self.device)
if "privateuseone" in str(self.device):
import onnxruntime as ort
self.model = ort.InferenceSession(
"%s/rmvpe.onnx" % os.environ["rmvpe_root"],
providers=["DmlExecutionProvider"],
)
else:
def rmvpe_jit_model():
ckpt = get_jit_model(model_path, is_half, self.device, rmvpe_jit_export)
model = torch.jit.load(BytesIO(ckpt["model"]), map_location=self.device)
model = model.to(self.device)
return model
if use_jit and not (is_half and "cpu" in str(self.device)):
self.model = rmvpe_jit_model()
else:
self.model = get_rmvpe(model_path, self.device, is_half)
def compute_f0(
self,
wav: np.ndarray,
p_len: Optional[int] = None,
filter_radius: Optional[Union[int, float]] = None,
):
if p_len is None:
p_len = wav.shape[0] // self.hop_length
if not torch.is_tensor(wav):
wav = torch.from_numpy(wav)
mel = self.mel_extractor(wav.float().to(self.device).unsqueeze(0), center=True)
hidden = self._mel2hidden(mel)
if "privateuseone" not in str(self.device):
hidden = hidden.squeeze(0).cpu().numpy()
else:
hidden = hidden[0]
if self.is_half == True:
hidden = hidden.astype("float32")
f0 = self._decode(hidden, thred=filter_radius)
return self._interpolate_f0(self._resize_f0(f0, p_len))[0]
def _to_local_average_cents(self, salience, threshold=0.05):
center = np.argmax(salience, axis=1) # 帧长#index
salience = np.pad(salience, ((0, 0), (4, 4))) # 帧长,368
center += 4
todo_salience = []
todo_cents_mapping = []
starts = center - 4
ends = center + 5
for idx in range(salience.shape[0]):
todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
todo_salience = np.array(todo_salience) # 帧长,9
todo_cents_mapping = np.array(todo_cents_mapping) # 帧长,9
product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
weight_sum = np.sum(todo_salience, 1) # 帧长
devided = product_sum / weight_sum # 帧长
maxx = np.max(salience, axis=1) # 帧长
devided[maxx <= threshold] = 0
return devided
def _mel2hidden(self, mel):
with torch.no_grad():
n_frames = mel.shape[-1]
n_pad = 32 * ((n_frames - 1) // 32 + 1) - n_frames
if n_pad > 0:
mel = F.pad(mel, (0, n_pad), mode="constant")
if "privateuseone" in str(self.device):
onnx_input_name = self.model.get_inputs()[0].name
onnx_outputs_names = self.model.get_outputs()[0].name
hidden = self.model.run(
[onnx_outputs_names],
input_feed={onnx_input_name: mel.cpu().numpy()},
)[0]
else:
mel = mel.half() if self.is_half else mel.float()
hidden = self.model(mel)
return hidden[:, :n_frames]
def _decode(self, hidden, thred=0.03):
cents_pred = self._to_local_average_cents(hidden, threshold=thred)
f0 = 10 * (2 ** (cents_pred / 1200))
f0[f0 == 10] = 0
# f0 = np.array([10 * (2 ** (cent_pred / 1200)) if cent_pred else 0 for cent_pred in cents_pred])
return f0
|