File size: 1,557 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import torch
import random
import argparse
import transformers
from s3prl import hub
from packaging import version

SAMPLE_RATE = 16000
BATCH_SIZE = 8

# Compare the hidden states produced by s3prl's own wav2vec2 upstream and its
# HuggingFace-backed counterpart on the same batch of random-noise waveforms,
# printing the per-layer and overall max absolute difference.
parser = argparse.ArgumentParser()
# Exactly one model size must be chosen; argparse rejects neither/both, and
# (unlike an `assert`) this validation is not stripped under `python -O`.
model_size = parser.add_mutually_exclusive_group(required=True)
model_size.add_argument("--base", action="store_true")
model_size.add_argument("--large", action="store_true")
parser.add_argument("--device", default="cuda")
parser.add_argument(
    "--seed",
    type=int,
    default=None,
    help="Seed for the random waveform lengths/contents (unseeded by default).",
)
args = parser.parse_args()

# Explicit raise instead of `assert` so the guard survives `python -O`.
if version.parse(transformers.__version__) > version.parse("4.9.0"):
    raise RuntimeError(
        "Newer version of transformers change the places for feature extraction."
    )

if args.seed is not None:
    random.seed(args.seed)
    torch.manual_seed(args.seed)

s3prl_str = "wav2vec2_base_960" if args.base else "wav2vec2_large_ll60k"
huggingface_str = "wav2vec2_hug_base_960" if args.base else "wav2vec2_hug_large_ll60k"

s3prl_model = getattr(hub, s3prl_str)().to(args.device)
huggingface_model = getattr(hub, huggingface_str)().to(args.device)

# NOTE(review): these flags presumably align s3prl's input preprocessing with
# what the HuggingFace checkpoint expects — confirm against the s3prl upstream.
if args.base:
    s3prl_model.wav_normalize = True
    s3prl_model.apply_padding_mask = False
s3prl_model.numpy_wav_normalize = True

s3prl_model.eval()
huggingface_model.eval()

# Random-noise waveforms of 1-15 seconds each; lengths differ so padding
# handling is exercised as well.
wavs = [
    torch.randn(random.randint(SAMPLE_RATE * 1, SAMPLE_RATE * 15)).to(args.device)
    for _ in range(BATCH_SIZE)
]

with torch.no_grad():
    hiddens1 = s3prl_model(wavs)["hidden_states"]
    hiddens2 = huggingface_model(wavs)["hidden_states"]
    if len(hiddens1) != len(hiddens2):
        raise RuntimeError(
            f"Hidden-state count mismatch: s3prl returned {len(hiddens1)}, "
            f"huggingface returned {len(hiddens2)}"
        )

    diffs = []
    for idx, (hidden1, hidden2) in enumerate(zip(hiddens1, hiddens2)):
        # Guard against silent broadcasting producing a meaningless diff.
        if hidden1.shape != hidden2.shape:
            raise RuntimeError(
                f"hidden {idx} shape mismatch: "
                f"{tuple(hidden1.shape)} vs {tuple(hidden2.shape)}"
            )
        diff = (hidden1 - hidden2).abs().max().item()
        print(f"hidden {idx} difference: {diff}")
        diffs.append(diff)

    print(f"Max difference: {max(diffs)}")