vec2wav2.0-demo / vec2wav2 /ssl_models /vqw2v_extractor.py
cantabile-kwok
prepare demo page
05005db
# Copyright 2024 Yiwei Guo
# Licensed under Apache 2.0
"""Extract VQ indexes using vq-wav2vec model (from fairseq)"""
import torch
import logging
from kaldiio import WriteHelper
import os
import fairseq
import argparse
import numpy as np
from pathlib import Path
import soundfile as sf
from tqdm import tqdm
from vec2wav2.utils.utils import read_wav_16k
# Configure the root logger once at import time: INFO and above, with each
# record tagged by timestamp, originating module and line number.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s',
)
class Extractor:
    """Extract discrete VQ code indexes with a fairseq vq-wav2vec model.

    The checkpoint is loaded once, moved to ``device``, switched to eval mode
    and frozen (``requires_grad`` is disabled on every parameter).
    """

    def __init__(self, checkpoint="pretrained/vq-wav2vec_kmeans.pt", device="cuda"):
        """Load a vq-wav2vec checkpoint.

        Args:
            checkpoint: path to a fairseq vq-wav2vec checkpoint file.
            device: torch device string the model is moved to ("cuda", "cpu", ...).
        """
        self.device = device
        models, self.cfg, self.task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint])
        # A single checkpoint is loaded, so the "ensemble" has exactly one member.
        self.model = models[0].to(device)
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)

    def extract(self, wav: np.ndarray) -> torch.Tensor:
        """Return the VQ code indexes for a single waveform.

        Args:
            wav: waveform samples; assumed to be a 1-D mono array at the rate
                the model expects (16 kHz for the callers here) -- TODO confirm.

        Returns:
            CPU tensor of code indexes for the (only) batch item, ``[L, Groups]``.
        """
        with torch.no_grad():
            # Add a leading batch dimension of size 1.
            audio = torch.from_numpy(wav).float().unsqueeze(0).to(self.device)
            z = self.model.feature_extractor(audio)
            _, idxs = self.model.vector_quantizer.forward_idx(z)
        return idxs[0].cpu()  # [L, Groups]

    def get_codebook(self) -> np.ndarray:
        """Return the quantizer codebook as a ``(groups, num_vars, dim)`` array.

        Only the "kmeans" and "gumbel" quantizer types are supported.

        Raises:
            ValueError: if ``cfg.model.vq_type`` is neither "kmeans" nor "gumbel".
        """
        quantizer = self.model.vector_quantizer
        vq_type = self.cfg.model.vq_type
        if vq_type == "kmeans":
            # expand_embedding is indexed as [var, group], hence the transpose to
            # put groups first. In fairseq's KmeansVectorQuantizer it is already
            # broadcast across groups when combine_groups is set, so it must NOT
            # be repeated again (verify against the installed fairseq version).
            codebook = quantizer.expand_embedding.data.transpose(0, 1).contiguous()
        elif vq_type == "gumbel":
            codebook = quantizer.vars.data
            if quantizer.combine_groups:
                # With combine_groups all groups share one set of num_vars
                # entries; tile it once per group before reshaping.
                codebook = codebook.repeat(1, quantizer.groups, 1)
        else:
            raise ValueError(f"Unsupported vq_type: {vq_type!r} (expected 'kmeans' or 'gumbel')")
        codebook = codebook.view(quantizer.groups, quantizer.num_vars, -1)
        return codebook.cpu().numpy()
if __name__ == "__main__":
    # Command-line entry point: read a wav.scp ("<uttid> <wav path>" per line),
    # extract VQ indexes for every utterance and write them to
    # <out-dir>/feats.ark and <out-dir>/feats.scp via kaldiio.
    parser = argparse.ArgumentParser()
    # Both paths are mandatory: without them the script can only crash later.
    parser.add_argument('--wav-scp', type=str, required=True)
    parser.add_argument("--out-dir", type=str, required=True)
    parser.add_argument('--model', default="pretrained/vq-wav2vec_kmeans.pt", type=str)
    args = parser.parse_args()

    extractor = Extractor(checkpoint=args.model, device="cuda" if torch.cuda.is_available() else "cpu")
    out_dir = Path(args.out_dir).absolute()
    # The ark/scp files are opened directly inside out_dir, so make sure a
    # fresh output directory exists before writing.
    out_dir.mkdir(parents=True, exist_ok=True)

    with open(args.wav_scp, 'r') as f:
        lines = f.readlines()  # materialized so tqdm can show a total

    with torch.no_grad(), WriteHelper(f"ark,scp:{out_dir}/feats.ark,{out_dir}/feats.scp") as writer:
        for line in tqdm(lines):
            fields = line.strip().split(maxsplit=1)
            if not fields:
                continue  # tolerate blank lines in wav.scp
            uttid, wav_path = fields
            logging.info("Extracting %s", uttid)
            audio = read_wav_16k(wav_path)
            idxs = extractor.extract(audio).cpu().numpy()
            # Integer code indexes are cast to float before writing, presumably
            # so kaldiio stores them as a floating-point matrix.
            idxs = idxs.astype(float)
            writer(uttid, idxs)