# ZeroRVC: infer/modules/train/extract_feature_print.py
# Author: JacobLinCool. Last commit a0ad823, "feat: hubert features" (3.61 kB).
import os
import traceback
import fairseq
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
# Choose the compute device: CUDA if available, otherwise Apple MPS, otherwise CPU.
device = (
    "cuda"
    if torch.cuda.is_available()
    else ("mps" if torch.backends.mps.is_available() else "cpu")
)

# Load the HuBERT checkpoint once, at import time. The resulting `model`,
# `saved_cfg`, `device` and `is_half` globals are read by the extraction code
# in this module.
model_path = "assets/hubert/hubert_base.pt"
models, saved_cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
    [model_path],
    suffix="",
)

# Half precision is switched off. If enabled, it is only applied when the
# device is neither "mps" nor "cpu".
is_half = False
model = models[0].to(device)
if is_half and device not in ["mps", "cpu"]:
    model = model.half()
model.eval()
# wave must be 16k, hop_size=320
def readwave(wav_path, normalize=False):
    """Load a 16 kHz audio file as a float32 tensor of shape ``(1, num_samples)``.

    Multi-channel audio is down-mixed to mono by averaging the channels.

    Args:
        wav_path: Path to an audio file readable by ``soundfile.read``.
        normalize: If true, apply parameter-free ``layer_norm`` over the whole
            waveform before returning it.

    Returns:
        A float32 tensor of shape ``(1, num_samples)``.

    Raises:
        ValueError: If the file's sample rate is not 16000 Hz, or the decoded
            audio is not 1-D after down-mixing.
    """
    wav, sr = sf.read(wav_path)
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    # Without them, a wrong-rate file would silently yield bogus features.
    if sr != 16000:
        raise ValueError("expected 16000 Hz audio, got %s Hz: %s" % (sr, wav_path))
    feats = torch.from_numpy(wav).float()
    if feats.dim() == 2:  # multi-channel (frames, channels): average to mono
        feats = feats.mean(-1)
    if feats.dim() != 1:
        raise ValueError(
            "expected 1-D audio after down-mixing, got %d-D: %s"
            % (feats.dim(), wav_path)
        )
    if normalize:
        with torch.no_grad():
            feats = F.layer_norm(feats, feats.shape)
    return feats.view(1, -1)
class HubertFeatureExtractor:
    """Extract HuBERT features for every wav file of an experiment directory.

    Reads ``<exp_dir>/1_16k_wavs/*.wav``. Each file goes through the
    module-level ``model`` (``output_layer=12``), and the result is saved to
    ``<exp_dir>/3_feature768/<name>.npy``. Progress and errors are printed and
    appended to ``<exp_dir>/extract_f0_feature.log``.

    The instance keeps the log file open. Call :meth:`close`, or use the
    instance as a context manager, once you are done with it.
    """

    def __init__(self, exp_dir: str):
        self.exp_dir = exp_dir
        # Append mode, so repeated runs accumulate in one log. UTF-8, so
        # non-ASCII file names can be logged regardless of platform locale.
        self.logfile = open(
            "%s/extract_f0_feature.log" % exp_dir, "a+", encoding="utf-8"
        )
        self.wavPath = "%s/1_16k_wavs" % exp_dir
        self.outPath = "%s/3_feature768" % exp_dir
        os.makedirs(self.outPath, exist_ok=True)

    def close(self):
        """Close the log file. Calling this more than once is harmless."""
        self.logfile.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    def println(self, strr):
        """Print ``strr`` and append it, newline-terminated, to the log file."""
        print(strr)
        self.logfile.write("%s\n" % strr)
        self.logfile.flush()

    @staticmethod
    def _npy_name(wav_name):
        """Return the feature file name for ``wav_name``, changing only its extension."""
        return os.path.splitext(wav_name)[0] + ".npy"

    def _extract_one(self, wav_name):
        """Extract and save features for one wav file and return their shape.

        Returns ``None`` without doing any work when the output ``.npy`` file
        already exists. Features containing NaN are reported and not saved.
        """
        wav_path = "%s/%s" % (self.wavPath, wav_name)
        out_path = "%s/%s" % (self.outPath, self._npy_name(wav_name))
        if os.path.exists(out_path):
            return None
        source = readwave(wav_path, normalize=saved_cfg.task.normalize)
        # No padding: the whole waveform is valid input.
        padding_mask = torch.zeros(source.shape, dtype=torch.bool)
        if is_half and device not in ["mps", "cpu"]:
            source = source.half()
        inputs = {
            "source": source.to(device),
            "padding_mask": padding_mask.to(device),
            "output_layer": 12,
        }
        with torch.no_grad():
            logits = model.extract_features(**inputs)
        feats = logits[0].squeeze(0).float().cpu().numpy()
        if not np.isnan(feats).any():
            np.save(out_path, feats, allow_pickle=False)
        else:
            self.println("%s-contains nan" % wav_name)
        return feats.shape

    def run(self):
        """Extract features for every ``.wav`` in ``wavPath`` that lacks output.

        An error on one file is logged with its traceback and does not stop
        the run.
        """
        todo = sorted(os.listdir(self.wavPath))
        # Report progress every n-th file, giving about ten progress lines
        # for large inputs.
        n = max(1, len(todo) // 10)
        if not todo:
            self.println("no-feature-todo")
        else:
            self.println("all-feature-%s" % len(todo))
            for idx, file in enumerate(todo):
                if not file.endswith(".wav"):
                    continue
                try:
                    shape = self._extract_one(file)
                    if shape is not None and idx % n == 0:
                        self.println(
                            "now-%s,all-%s,%s,%s" % (idx, len(todo), file, shape)
                        )
                except Exception:
                    # Best effort: log the failure and continue with the next file.
                    self.println(traceback.format_exc())
            self.println("all-feature-done")