# Copyright 2024 Yiwei Guo
# Licensed under Apache 2.0

"""Extract intermediate-layer features with a WavLM model (from microsoft UniLM).

When run as a script, reads a ``wav.scp`` file (one ``<uttid> <wav-path>`` per
line) and writes the hidden states of ``--output-layer`` for every utterance to
``<out-dir>/feats.ark`` with an index in ``<out-dir>/feats.scp``.
"""

import argparse
import logging
import time
from pathlib import Path

import soundfile as sf
import torch
from kaldiio import WriteHelper
from tqdm import tqdm

from vec2wav2.ssl_models.WavLM import WavLM, WavLMConfig
from vec2wav2.utils.espnet_utils import pad_list, make_pad_mask
from vec2wav2.utils.utils import read_wav_16k

logger = logging.getLogger(__name__)


class Extractor:
    """Wrap a pretrained WavLM checkpoint and return the hidden states of one layer."""

    def __init__(self, checkpoint="pretrained/WavLM-Large.pt", device="cuda", output_layer=6):
        """Load a WavLM checkpoint and freeze it for inference.

        Args:
            checkpoint: path to a checkpoint dict holding ``'cfg'`` and ``'model'`` entries.
            device: torch device the model (and all inputs) are moved to.
            output_layer: layer index forwarded to ``WavLM.extract_features``.
        """
        self.device = device
        # Restore tensors on CPU first: a checkpoint saved from a GPU would
        # otherwise be loaded onto CUDA, which fails on CPU-only hosts (the
        # script's own fallback) and briefly duplicates weights in GPU memory.
        state = torch.load(checkpoint, map_location="cpu")
        self.cfg = WavLMConfig(state['cfg'])
        self.model = WavLM(self.cfg)
        self.model.load_state_dict(state['model'])
        self.model.to(device)
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.output_layer = output_layer

    def extract(self, wav):
        """Return the ``output_layer`` features of a single utterance.

        Args:
            wav: numpy waveform, presumably 1-D at 16 kHz (converted with
                ``torch.from_numpy`` and given a batch dimension of 1).

        Returns:
            torch.Tensor of shape [T, D] on ``self.device``.
        """
        with torch.no_grad():
            wav_input_16khz = torch.from_numpy(wav).unsqueeze(0).float().to(self.device)
            if self.cfg.normalize:
                # Normalise over the whole utterance, as the checkpoint's config requests.
                wav_input_16khz = torch.nn.functional.layer_norm(wav_input_16khz, wav_input_16khz.shape)
            rep = self.model.extract_features(wav_input_16khz, output_layer=self.output_layer)[0]
        return rep.squeeze(0).clone().detach()  # [T, D]

    def extract_batch(self, wav_list, frame_lens=None):
        """Return the ``output_layer`` features for several utterances in one forward pass.

        Each waveform is layer-normalised on its own (when the config asks for
        it) *before* zero-padding. The padding therefore does not leak into
        the normalisation statistics, and the normalisation matches ``extract``.

        Args:
            wav_list: list of unpadded numpy waveforms (presumably 1-D, 16 kHz).
            frame_lens: per-utterance lengths used to build the padding mask.
                Defaults to the number of samples in each element of ``wav_list``.

        Returns:
            torch.Tensor of shape [B, T, D].
        """
        if frame_lens is None:
            frame_lens = [len(wav) for wav in wav_list]
        pad_mask = make_pad_mask(frame_lens).to(self.device)
        with torch.no_grad():
            wavs = [torch.from_numpy(wav).float().to(self.device) for wav in wav_list]
            if self.cfg.normalize:
                wavs = [torch.nn.functional.layer_norm(wav, wav.shape) for wav in wavs]
            wav_batch = pad_list(wavs, 0)
            start = time.perf_counter()
            rep = self.model.extract_features(wav_batch, output_layer=self.output_layer,
                                              padding_mask=pad_mask)[0]
            if rep.is_cuda:
                # CUDA kernels run asynchronously; wait for them so the timing
                # measures the computation rather than just the kernel launches.
                torch.cuda.synchronize(rep.device)
            elapsed = time.perf_counter() - start
        logger.info("in batch mode, pure extracting costs %.4f s", elapsed)
        return rep.clone().detach()  # [B, T, D]


def calc_out_len(in_len, k, s):
    """Return the output length of a 1-D convolution with no padding and no dilation.

    Args:
        in_len: input length.
        k: kernel size.
        s: stride.
    """
    # Exact integer floor division; equals the former int((in_len-(k-1)-1)/s + 1)
    # for every in_len >= k - s, without float rounding on large lengths.
    return (in_len - k) // s + 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Extract WavLM layer features into a Kaldi ark/scp pair.")
    parser.add_argument('--wav-scp', type=str, required=True,
                        help="file with one '<uttid> <wav-path>' per line")
    parser.add_argument("--out-dir", type=str, required=True,
                        help="directory receiving feats.ark and feats.scp")
    parser.add_argument('--model', default="pretrained/WavLM-Large.pt", type=str)
    parser.add_argument('--output-layer', default=6, type=int)
    args = parser.parse_args()

    extractor = Extractor(checkpoint=args.model,
                          device="cuda" if torch.cuda.is_available() else "cpu",
                          output_layer=args.output_layer)

    out_dir = Path(args.out_dir).absolute()
    out_dir.mkdir(parents=True, exist_ok=True)

    # Read the list up front so tqdm knows the total; ignore blank lines.
    with open(args.wav_scp, 'r', encoding='utf-8') as f:
        entries = [line.strip() for line in f if line.strip()]

    with torch.no_grad(), WriteHelper(f"ark,scp:{out_dir}/feats.ark,{out_dir}/feats.scp") as writer:
        for entry in tqdm(entries):
            fields = entry.split(maxsplit=1)
            if len(fields) != 2:
                raise ValueError(f"malformed wav.scp line (expected '<uttid> <wav-path>'): {entry!r}")
            uttid, wav_path = fields
            logger.info("Extracting %s", uttid)
            audio = read_wav_16k(wav_path)
            rep = extractor.extract(audio).cpu().numpy()
            writer(uttid, rep)