# Copyright 2024 Yiwei Guo
# Licensed under Apache 2.0

"""Extract VQ indexes using a wav2vec 2.0 model (fairseq checkpoint, loaded via HuggingFace transformers)."""

import argparse
import logging
from pathlib import Path

import numpy as np
import soundfile as sf
import torch
from kaldiio import WriteHelper
from tqdm import tqdm
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining

logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')


class Extractor:
    """Wraps Wav2Vec2ForPreTraining and exposes frame-level VQ codebook indexes."""

    def __init__(self, checkpoint="pretrained/wav2vec2-large-lv60/", device="cuda"):
        self.device = device
        feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(checkpoint)
        model = Wav2Vec2ForPreTraining.from_pretrained(checkpoint)
        model.to(self.device)
        model.half()  # fp16 inference; assumes a GPU device
        model.eval()
        self.model = model
        self.feature_extractor = feature_extractor
        logging.info(self.model)
        for p in self.model.parameters():
            p.requires_grad_(False)

    def extract(self, wav: np.ndarray, sample_rate: int) -> torch.Tensor:
        """Return the hard-assigned quantizer indexes for one utterance, shape [frames, num_groups]."""
        with torch.no_grad():
            wav = torch.from_numpy(wav).float()
            input_values = self.feature_extractor(
                wav, return_tensors="pt", sampling_rate=sample_rate).input_values
            input_values = input_values.half().to(self.device)
            # Run the base wav2vec2 model; outputs[1] is the CNN feature-extractor output
            # that feeds the quantizer during pre-training.
            outputs = self.model.wav2vec2(input_values)
            extract_features = self.model.dropout_features(outputs[1])

            hidden_states = extract_features
            batch_size, sequence_length, hidden_size = hidden_states.shape
            # Replicate the quantizer's hard (non-Gumbel) code selection:
            # project to per-group logits and take the argmax within each group.
            hidden_states = self.model.quantizer.weight_proj(hidden_states)
            hidden_states = hidden_states.view(
                batch_size * sequence_length * self.model.quantizer.num_groups, -1)
            codevector_idx = hidden_states.argmax(dim=-1)
            idxs = codevector_idx.view(batch_size, sequence_length, self.model.quantizer.num_groups)
            return idxs[0].cpu()  # [frames, num_groups]

    def get_codebook(self) -> np.ndarray:
        """Return the quantizer codevectors reshaped to (num_groups, num_vars, dim_per_group)."""
        quantizer = self.model.quantizer
        codebook = quantizer.codevectors  # (1, num_groups * num_vars, dim), e.g. (1, 640, 384)
        codebook = codebook.view(quantizer.num_groups, quantizer.num_vars, -1)  # e.g. (2, 320, 384)
        return codebook.cpu().numpy()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--wav-scp', type=str)
    parser.add_argument('--out-dir', type=str)
    parser.add_argument('--model', default="pretrained/wav2vec2-large-lv60/", type=str)
    args = parser.parse_args()

    extractor = Extractor(checkpoint=args.model,
                          device="cuda" if torch.cuda.is_available() else "cpu")
    out_dir = Path(args.out_dir).absolute()
    out_dir.mkdir(parents=True, exist_ok=True)

    with open(args.wav_scp, 'r') as f, torch.no_grad(), \
            WriteHelper(f"ark,scp:{out_dir}/feats.ark,{out_dir}/feats.scp") as writer:
        for line in tqdm(f.readlines()):
            uttid, wav_path = line.strip().split(maxsplit=1)
            logging.info("Extracting " + uttid)
            audio, sample_rate = sf.read(wav_path)
            idxs = extractor.extract(audio, sample_rate=sample_rate)
            # Kaldi matrices are float; convert the integer index tensor before writing.
            idxs = idxs.numpy().astype(float)
            writer(uttid, idxs)
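
# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original script): illustrates how the
# written ark/scp pair could be read back with kaldiio.ReadHelper and mapped
# onto the codebook returned by Extractor.get_codebook(). The path
# "exp/vq_idx/feats.scp" is a placeholder, not one defined by this repo.
#
#   from kaldiio import ReadHelper
#
#   extractor = Extractor()
#   codebook = extractor.get_codebook()  # (num_groups, num_vars, dim_per_group)
#   with ReadHelper("scp:exp/vq_idx/feats.scp") as reader:
#       for uttid, idxs in reader:
#           idxs = idxs.astype(int)  # stored as float, but the values are integral indexes
#           # Concatenate the per-group codevectors of the first frame to recover the
#           # (unprojected) quantized feature, mirroring the HF quantizer's hard path.
#           frame0 = np.concatenate(
#               [codebook[g, idxs[0, g]] for g in range(codebook.shape[0])])
#           print(uttid, idxs.shape, frame0.shape)
#           break
# ---------------------------------------------------------------------------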