File size: 3,458 Bytes
05005db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Copyright 2024 Yiwei Guo
#  Licensed under Apache 2.0

"""Extract hidden-layer features from a WavLM model (from Microsoft UniLM)."""

import torch
from vec2wav2.ssl_models.WavLM import WavLM, WavLMConfig
import soundfile as sf
from vec2wav2.utils.espnet_utils import pad_list, make_pad_mask
import time
from pathlib import Path
import argparse
from kaldiio import WriteHelper
from tqdm import tqdm
import logging
from vec2wav2.utils.utils import read_wav_16k

class Extractor:
    """Frozen WavLM model that returns the hidden states of one chosen layer.

    The checkpoint is loaded with ``torch.load`` and must contain a ``'cfg'``
    entry (passed to ``WavLMConfig``) and a ``'model'`` entry (the state dict).
    After loading, the model is moved to ``device``, put in eval mode and all
    of its parameters are frozen.

    Args:
        checkpoint: Path to the WavLM checkpoint file.
        device: Torch device string that the model and all inputs are moved to.
        output_layer: Layer index forwarded to ``WavLM.extract_features``.
    """

    def __init__(self, checkpoint="pretrained/WavLM-Large.pt", device="cuda", output_layer=6):
        self.device = device
        # Deserialize onto CPU first: a checkpoint saved from GPU tensors would
        # otherwise fail to load when `device` is "cpu" / CUDA is unavailable.
        # The model is moved to `device` right after construction anyway.
        state = torch.load(checkpoint, map_location="cpu")
        self.cfg = WavLMConfig(state['cfg'])
        self.model = WavLM(self.cfg)
        self.model.load_state_dict(state['model'])
        self.model.to(device)
        self.model.eval()
        for p in self.model.parameters():
            p.requires_grad_(False)
        self.output_layer = output_layer

    def extract(self, wav):
        """Extract features for a single utterance.

        Args:
            wav: numpy waveform array (assumed 1-D and already at 16 kHz, as
                the variable names suggest — TODO confirm against callers).

        Returns:
            Detached ``torch.Tensor`` on ``self.device`` with the batch
            dimension squeezed away (``[T, D]`` per the original annotation).
        """
        with torch.no_grad():
            # Add a batch dimension: (num_samples,) -> (1, num_samples).
            wav_input_16khz = torch.from_numpy(wav).unsqueeze(0).float().to(self.device)
            if self.cfg.normalize:
                # Normalize over the entire (1, num_samples) input, i.e. per utterance.
                wav_input_16khz = torch.nn.functional.layer_norm(wav_input_16khz, wav_input_16khz.shape)
            # extract_features returns a tuple; element 0 holds the layer output.
            rep = self.model.extract_features(wav_input_16khz, output_layer=self.output_layer)[0]
            return rep.squeeze(0).clone().detach()  # torch.tensor [T, D]

    def extract_batch(self, wav_list, frame_lens):
        """Extract features for several utterances in one forward pass.

        The inputs are numpy arrays of possibly different lengths. They are
        converted to tensors, normalized one by one (when ``cfg.normalize`` is
        set) and only then zero-padded with ``pad_list``. Padding therefore
        never enters the per-utterance normalization statistics. LayerNorm
        inside the model may still make batched results differ slightly from
        :meth:`extract` on the same utterance.

        Args:
            wav_list: List of numpy waveform arrays.
            frame_lens: Per-utterance lengths given to ``make_pad_mask`` to
                build the model's padding mask (presumably sample counts
                matching ``wav_list`` — TODO confirm).

        Returns:
            Detached ``torch.Tensor`` on ``self.device`` (``[B, T, D]`` per
            the original annotation).
        """
        pad_mask = make_pad_mask(frame_lens).to(self.device)
        with torch.no_grad():
            wav_input_16khz = [torch.from_numpy(wav).float().to(self.device) for wav in wav_list]
            if self.cfg.normalize:
                wav_input_16khz = [torch.nn.functional.layer_norm(wav, wav.shape) for wav in wav_input_16khz]
            wav_input_16khz = pad_list(wav_input_16khz, 0)
            start = time.perf_counter()
            rep = self.model.extract_features(wav_input_16khz, output_layer=self.output_layer, padding_mask=pad_mask)[0]
            elapsed = time.perf_counter() - start
            # Diagnostic timing only: log it rather than printing to stdout on every batch.
            # NOTE(review): on CUDA, kernels run asynchronously, so without a
            # synchronize this may under-report the true compute time.
            logging.debug('in batch mode, pure extracting costs %s s', elapsed)
            return rep.clone().detach()  # [B, T, D]


def calc_out_len(in_len, k, s, padding=0, dilation=1):
    """Return the output length of a 1-D convolution / sliding window.

    Computes ``floor((in_len + 2*padding - dilation*(k-1) - 1) / s) + 1``
    with integer arithmetic, so very large lengths are not affected by float
    rounding. Inputs shorter than the window's receptive field yield 0
    instead of a negative length (the previous truncate-toward-zero float
    version returned e.g. -1 for ``in_len=0, k=10, s=5``).

    Args:
        in_len: Input length.
        k: Kernel (window) size.
        s: Stride.
        padding: Padding added on each side (default 0 keeps the old behavior).
        dilation: Kernel dilation (default 1 keeps the old behavior).

    Returns:
        Non-negative ``int`` output length.
    """
    span = int(in_len) + 2 * padding - dilation * (k - 1) - 1
    return max(0, int(span // s) + 1)


if __name__ == '__main__':
    # Command-line entry: read a Kaldi-style wav.scp, extract features for each
    # utterance and write them to <out-dir>/feats.ark + feats.scp via kaldiio.
    parser = argparse.ArgumentParser(
        description="Extract WavLM layer features for every utterance in a wav.scp "
                    "and write them to feats.ark/feats.scp in --out-dir.")
    parser.add_argument('--wav-scp', type=str, required=True,
                        help="file with one '<uttid> <wav_path>' entry per line")
    parser.add_argument("--out-dir", type=str, required=True,
                        help="directory that receives feats.ark and feats.scp")
    parser.add_argument('--model', default="pretrained/WavLM-Large.pt", type=str,
                        help="path to the WavLM checkpoint")
    parser.add_argument('--output-layer', default=6, type=int,
                        help="WavLM layer whose output is extracted")
    args = parser.parse_args()

    extractor = Extractor(checkpoint=args.model,
                          device="cuda" if torch.cuda.is_available() else "cpu",
                          output_layer=args.output_layer)

    out_dir = Path(args.out_dir).absolute()
    out_dir.mkdir(parents=True, exist_ok=True)

    # Read the whole list up front (tqdm then knows the total) and release the
    # file handle before the long-running extraction loop.
    with open(args.wav_scp, 'r', encoding='utf-8') as f:
        lines = f.readlines()

    with torch.no_grad(), WriteHelper(f"ark,scp:{out_dir}/feats.ark,{out_dir}/feats.scp") as writer:
        for lineno, line in enumerate(tqdm(lines), start=1):
            # strip() first so the path does not keep the trailing newline.
            fields = line.strip().split(maxsplit=1)
            if not fields:
                continue  # tolerate blank lines, e.g. a trailing empty line
            if len(fields) != 2:
                raise ValueError(
                    f"{args.wav_scp}:{lineno}: expected '<uttid> <wav_path>', got {line!r}")
            uttid, wav_path = fields
            logging.info("Extracting %s", uttid)
            audio = read_wav_16k(wav_path)
            rep = extractor.extract(audio)
            writer(uttid, rep.cpu().numpy())