import os

import numpy as np
import torch
import torch.nn.functional as F
import yaml
from datasets import load_dataset
from tqdm import tqdm

from data2vec_feature_reader import Data2vecFeatureReader
from repcodec.RepCodec import RepCodec

cache_dir = "./../../../cache"

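# Load every LibriSpeech split (train/validation/test) as a DatasetDict;
# trust_remote_code is needed because the dataset repo ships a loading script.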
dataset = load_dataset("openslr/librispeech_asr", cache_dir=cache_dir, trust_remote_code=True)

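# Build the RepCodec tokenizer: the YAML config defines the architecture and
# the checkpoint provides trained weights for the data2vec-large layer-18 codec.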
config = "./../repcodec/configs/repcodec_dim1024.yaml"
with open(config) as fp:
    conf = yaml.load(fp, Loader=yaml.FullLoader)

model = RepCodec(**conf)
ckpt = torch.load("./../../models/data2vec_large_l18.pkl", map_location="cuda:0")
model.load_state_dict(ckpt["model"]["repcodec"])
model.quantizer.initial()  # initialize the quantizer codebook before inference
model.eval()
model.to("cuda:0")

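# Feature reader: runs the pretrained data2vec model and returns hidden states
# from layer 18; max_chunk caps the number of samples per forward pass.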
reader = Data2vecFeatureReader("./../../models/vox_pretrained.pt", 18, device="cuda:0", max_chunk=1600000)

# Create the output directory up front; np.savez does not create directories.
os.makedirs("./tkns", exist_ok=True)

for split in dataset.keys():
    tokens = []

    for idx in tqdm(range(len(dataset[split]))):
        sample = dataset[split][idx]
        x = sample["audio"]["array"]
        
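        # Tokenize one utterance: extract data2vec features, then quantize them.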
        with torch.no_grad():
            x = torch.from_numpy(x).float().to(reader.device)
            if reader.task.cfg.normalize:
                x = F.layer_norm(x, x.shape)
            x = x.view(1, -1)
        
            # Run the waveform through data2vec in fixed-size chunks so that
            # long utterances do not exhaust GPU memory.
            feat = []
            for start in range(0, x.size(1), reader.max_chunk):
                x_chunk = x[:, start: start + reader.max_chunk]
                res = reader.model.extract_features(
                    source=x_chunk,
                    padding_mask=None,
                    mask=False,
                    layer=reader.layer,
                )
                feat.append(res["x"])

            # (batch, frames, dim) -> (batch, dim, frames) for the codec encoder.
            features = torch.cat(feat, 1).permute(0, 2, 1)
        
            # Encode, project, and quantize the features into discrete token ids.
            x = model.encoder(features)
            z = model.projector(x)
            _, codes = model.quantizer.codebook.forward_index(z.transpose(2, 1))
            tkn = codes.detach().cpu().numpy()[0]
            
        tokens.append(tkn)

    # One .npz archive per split; arrays are stored positionally (arr_0, arr_1, ...).
    np.savez(f"./tkns/{split}.npz", *tokens)