# Copyright 2024 Yiwei Guo
# Licensed under Apache 2.0

"""Extract VQ indexes using vq-wav2vec model (from fairseq)"""

import argparse
import logging
from pathlib import Path

import fairseq
import numpy as np
import torch
from kaldiio import WriteHelper
from tqdm import tqdm

from vec2wav2.utils.utils import read_wav_16k

logging.basicConfig(level=logging.INFO, format='%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s')

class Extractor:
    """Wraps a pretrained vq-wav2vec checkpoint for VQ index extraction."""

    def __init__(self, checkpoint="pretrained/vq-wav2vec_kmeans.pt", device="cuda"):
        self.device = device
        # load_model_ensemble_and_task returns a list of models; keep only the first one.
        models, self.cfg, self.task = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint])
        self.model = models[0].to(device)
        self.model.eval()
        # Inference only: freeze all parameters.
        for p in self.model.parameters():
            p.requires_grad_(False)

    def extract(self, wav: np.ndarray) -> torch.Tensor:
        """Return the VQ index sequence of a 16 kHz waveform, shape [L, Groups]."""
        with torch.no_grad():
            # [T] waveform -> [1, T] batch on the target device
            audio = torch.from_numpy(wav).float().unsqueeze(0).to(self.device)
            # Convolutional feature extractor, then the quantizer's index lookup
            z = self.model.feature_extractor(audio)
            _, idxs = self.model.vector_quantizer.forward_idx(z)
        return idxs[0].cpu()  # [L, Groups]

    def get_codebook(self) -> np.ndarray:
        """Return the quantizer codebook as an array of shape [Groups, num_vars, dim]."""
        quantizer = self.model.vector_quantizer
        if self.cfg.model.vq_type == "kmeans":
            codebook = quantizer.expand_embedding.data.transpose(0, 1).contiguous()
        elif self.cfg.model.vq_type == "gumbel":
            codebook = quantizer.vars.data
            if quantizer.combine_groups:
                # A single shared codebook is reused across all groups.
                codebook = codebook.repeat(1, quantizer.groups, 1)
            codebook = codebook.view(quantizer.groups, quantizer.num_vars, -1)
        else:
            raise ValueError(f"Unsupported vq_type: {self.cfg.model.vq_type}")
        return codebook.cpu().numpy()
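
# Minimal programmatic usage sketch (the wav file name below is illustrative,
# not part of this module):
#
#     extractor = Extractor(checkpoint="pretrained/vq-wav2vec_kmeans.pt", device="cpu")
#     wav = read_wav_16k("example.wav")       # 16 kHz mono waveform as np.ndarray
#     idxs = extractor.extract(wav)           # torch.Tensor of shape [L, Groups]
#     codebook = extractor.get_codebook()     # np.ndarray codebook of the quantizer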

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract vq-wav2vec VQ indexes from a Kaldi-style wav.scp.")
    parser.add_argument('--wav-scp', type=str, required=True, help="Kaldi-style wav.scp; each line is '<uttid> <wav-path>'.")
    parser.add_argument('--out-dir', type=str, required=True, help="Output directory for feats.ark and feats.scp.")
    parser.add_argument('--model', default="pretrained/vq-wav2vec_kmeans.pt", type=str, help="Path to the vq-wav2vec checkpoint.")
    args = parser.parse_args()

    extractor = Extractor(checkpoint=args.model, device="cuda" if torch.cuda.is_available() else "cpu")

    out_dir = Path(args.out_dir).absolute()
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(args.wav_scp, 'r') as f, torch.no_grad(), \
            WriteHelper(f"ark,scp:{out_dir}/feats.ark,{out_dir}/feats.scp") as writer:
        for line in tqdm(f.readlines()):
            uttid, wav_path = line.strip().split(maxsplit=1)
            logging.info("Extracting " + uttid)
            audio = read_wav_16k(wav_path)
            idxs = extractor.extract(audio).numpy()
            # Kaldi ark matrices are stored as floats, so cast the integer indexes before writing.
            writer(uttid, idxs.astype(float))
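
# Example invocation (the script name and data paths are illustrative):
#
#     python extract_vq_indexes.py \
#         --wav-scp data/train/wav.scp \
#         --out-dir feats/train/vq \
#         --model pretrained/vq-wav2vec_kmeans.pt
#
# This writes out-dir/feats.ark and out-dir/feats.scp, with one [L, Groups]
# index matrix per utterance, readable back with kaldiio (e.g. ReadHelper).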