# (web-page residue from the hosting page: "Spaces: Running")
# Copyright 2024 Yiwei Guo
# Licensed under Apache 2.0
"""Extract hidden-layer features using the WavLM model (from Microsoft UniLM)."""
import torch | |
from vec2wav2.ssl_models.WavLM import WavLM, WavLMConfig | |
import soundfile as sf | |
from vec2wav2.utils.espnet_utils import pad_list, make_pad_mask | |
import time | |
from pathlib import Path | |
import argparse | |
from kaldiio import WriteHelper | |
from tqdm import tqdm | |
import logging | |
from vec2wav2.utils.utils import read_wav_16k | |
class Extractor:
    """Frozen WavLM wrapper that turns raw waveforms into hidden representations.

    The checkpoint file is expected to deserialize to a mapping with a
    ``'cfg'`` entry (handed to ``WavLMConfig``) and a ``'model'`` entry
    (a state dict loaded into ``WavLM``).
    """

    def __init__(self, checkpoint="pretrained/WavLM-Large.pt", device="cuda", output_layer=6):
        """Load the checkpoint, move the model to ``device`` and freeze it.

        Args:
            checkpoint: Path to the WavLM checkpoint file.
            device: Device on which the model and all inputs are placed.
            output_layer: Value forwarded as ``output_layer`` to
                ``WavLM.extract_features`` on every extraction call.
        """
        self.device = device
        # Deserialize onto CPU first: without map_location, a checkpoint that
        # was saved from a GPU cannot be loaded on a CPU-only machine. The
        # model is moved to the requested device right below.
        state = torch.load(checkpoint, map_location="cpu")
        self.cfg = WavLMConfig(state['cfg'])
        self.model = WavLM(self.cfg)
        self.model.load_state_dict(state['model'])
        self.model.to(device)
        self.model.eval()
        # Inference only: make sure no parameter ever tracks gradients.
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.output_layer = output_layer

    def extract(self, wav):
        """Extract features for a single utterance.

        Args:
            wav: NumPy array of audio samples (converted with
                ``torch.from_numpy``). Assumed to be 1-D, 16 kHz audio
                -- TODO confirm against callers.

        Returns:
            A detached copy of the model output with the batch axis removed;
            ``[T, D]`` for a 1-D input.
        """
        with torch.no_grad():
            # Add a batch axis of size 1 and cast to float32.
            wav_input_16khz = torch.from_numpy(wav).unsqueeze(0).float().to(self.device)
            if self.cfg.normalize:
                # Normalize over the whole (single-utterance) tensor.
                wav_input_16khz = torch.nn.functional.layer_norm(wav_input_16khz, wav_input_16khz.shape)
            rep = self.model.extract_features(wav_input_16khz, output_layer=self.output_layer)[0]
            return rep.squeeze(0).clone().detach()  # torch.tensor [T, D]

    def extract_batch(self, wav_list, frame_lens):
        """Extract features for several utterances in one forward pass.

        Args:
            wav_list: Iterable of NumPy arrays, one per utterance. They are
                converted to tensors and padded with 0 here, so they do not
                need to be padded beforehand.
            frame_lens: Lengths handed to ``make_pad_mask`` to build the
                ``padding_mask`` argument of the model. NOTE(review): the
                unit (samples vs. frames) is not visible in this file --
                confirm it matches the padded waveform axis.

        Returns:
            A detached copy of the model output, ``[B, T, D]``.
        """
        pad_mask = make_pad_mask(frame_lens).to(self.device)
        with torch.no_grad():
            wav_input_16khz = [torch.from_numpy(wav).float().to(self.device) for wav in wav_list]
            if self.cfg.normalize:
                # Normalize each utterance on its own *before* padding, so the
                # padded zeros do not affect the statistics and the result
                # matches what `extract` computes for a single utterance.
                wav_input_16khz = [torch.nn.functional.layer_norm(wav, wav.shape) for wav in wav_input_16khz]
            wav_input_16khz = pad_list(wav_input_16khz, 0)
            start = time.time()
            rep = self.model.extract_features(
                wav_input_16khz, output_layer=self.output_layer, padding_mask=pad_mask
            )[0]
            elapsed = time.time() - start
            print(f'in batch mode, pure extracting costs {elapsed} s')
        return rep.clone().detach()  # [B, T, D]
def calc_out_len(in_len, k, s):
    """Return the output length of a 1-D convolution-style sliding window.

    Computes ``int((in_len - (k - 1) - 1) / s + 1)``.

    Args:
        in_len: Input length.
        k: Window (kernel) size.
        s: Step (stride); must be non-zero.

    Returns:
        The number of output positions as an ``int``. The quotient is
        truncated toward zero by ``int()``.
    """
    span = in_len - (k - 1) - 1
    return int(span / s + 1)
if __name__ == '__main__':
    # Command-line entry point: read a Kaldi-style wav.scp ("<uttid> <wav_path>"
    # per line), run every utterance through the extractor and write the
    # results to <out-dir>/feats.ark with an index in <out-dir>/feats.scp.
    parser = argparse.ArgumentParser()
    # Both paths are mandatory; without `required=True` a missing option only
    # surfaces later as an opaque TypeError from open()/Path().
    parser.add_argument('--wav-scp', type=str, required=True)
    parser.add_argument("--out-dir", type=str, required=True)
    parser.add_argument('--model', default="pretrained/WavLM-Large.pt", type=str)
    parser.add_argument('--output-layer', default=6, type=int)
    args = parser.parse_args()

    extractor = Extractor(checkpoint=args.model,
                          device="cuda" if torch.cuda.is_available() else "cpu",
                          output_layer=args.output_layer)

    out_dir = Path(args.out_dir).absolute()
    out_dir.mkdir(parents=True, exist_ok=True)

    with open(args.wav_scp, 'r') as f, torch.no_grad(), \
            WriteHelper(f"ark,scp:{out_dir}/feats.ark,{out_dir}/feats.scp") as writer:
        # readlines() is kept so tqdm knows the total and can show progress.
        for line in tqdm(f.readlines()):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            fields = line.split(maxsplit=1)
            if len(fields) != 2:
                raise ValueError(
                    f"Malformed wav.scp line, expected '<uttid> <wav_path>': {line!r}"
                )
            uttid, wav_path = fields
            logging.info("Extracting %s", uttid)
            audio = read_wav_16k(wav_path)
            rep = extractor.extract(audio)
            # kaldiio writes NumPy arrays, so move the result off the device.
            rep = rep.cpu().numpy()
            writer(uttid, rep)