File size: 1,557 Bytes
0b32ad6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import os
import torch
import random
import argparse
import transformers
from s3prl import hub
from packaging import version
# Sampling rate of the synthetic waveforms: each input's length is drawn as a
# multiple (1x-15x) of this value further down.
SAMPLE_RATE = 16_000
# Number of random waveforms pushed through both models in a single batch.
BATCH_SIZE = 8
parser = argparse.ArgumentParser(
    description=(
        "Compare the hidden states of s3prl's native wav2vec2 upstream with "
        "its HuggingFace-based counterpart on random waveforms."
    )
)
# Exactly one model size must be selected. Passing both would otherwise be
# ambiguous, because the script would silently fall back to --base.
model_size = parser.add_mutually_exclusive_group(required=True)
model_size.add_argument(
    "--base", action="store_true", help="compare the base-size checkpoints"
)
model_size.add_argument(
    "--large", action="store_true", help="compare the large-size checkpoints"
)
parser.add_argument(
    "--device",
    default="cuda",
    help="torch device both models and inputs are moved to (default: cuda)",
)
args = parser.parse_args()
# Guard the environment and CLI with explicit errors rather than `assert`, so the
# checks still run when Python is started with -O.
#
# The HuggingFace upstream depends on where older transformers releases perform
# feature extraction, so anything newer than 4.9.0 is rejected.
if version.parse(transformers.__version__) > version.parse("4.9.0"):
    raise RuntimeError(
        "Newer versions of transformers change the places for feature extraction: "
        f"transformers<=4.9.0 is required, found {transformers.__version__}."
    )
# Exactly one model size must be chosen (also enforced when parsing, if the
# parser declares the flags as mutually exclusive).
if not (args.base or args.large):
    parser.error("one of --base or --large is required")
if args.base and args.large:
    parser.error("--base and --large are mutually exclusive")
# Choose the pair of hub entries to compare: the native s3prl upstream and the
# HuggingFace-backed one, at the model size requested on the command line.
if args.base:
    s3prl_str, huggingface_str = "wav2vec2_base_960", "wav2vec2_hug_base_960"
else:
    s3prl_str, huggingface_str = "wav2vec2_large_ll60k", "wav2vec2_hug_large_ll60k"

# Instantiate both upstreams from the s3prl hub and move them to the target device.
s3prl = getattr(hub, s3prl_str)().to(args.device)
huggingface = getattr(hub, huggingface_str)().to(args.device)

if args.base:
    # For the base checkpoint, override the s3prl upstream's preprocessing flags.
    # NOTE(review): presumably this aligns its waveform normalization and padding
    # handling with the HuggingFace base model's feature extractor. TODO confirm.
    s3prl.wav_normalize = True
    s3prl.apply_padding_mask = False
    s3prl.numpy_wav_normalize = True

# Inference mode for both models before comparing outputs.
s3prl.eval()
huggingface.eval()
# Seed both RNGs so any mismatch reported below can be reproduced with exactly
# the same input lengths and sample values.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

# BATCH_SIZE random waveforms whose lengths vary between 1x and 15x SAMPLE_RATE
# samples. They are generated on the CPU and then moved, so the seeded values do
# not depend on --device.
# NOTE(review): the ragged lengths presumably exercise each upstream's padding of
# variable-length batches. Confirm if that is the intent.
wavs = [
    torch.randn(random.randint(SAMPLE_RATE * 1, SAMPLE_RATE * 15)).to(args.device)
    for _ in range(BATCH_SIZE)
]
# Run both upstreams on the identical batch. Gradients are not needed for a
# pure output comparison.
with torch.no_grad():
    hiddens1 = s3prl(wavs)["hidden_states"]
    hiddens2 = huggingface(wavs)["hidden_states"]

# Explicit raises (instead of `assert`) keep these checks active under `python -O`.
if len(hiddens1) != len(hiddens2):
    raise AssertionError(
        "Number of hidden states differs: "
        f"s3prl={len(hiddens1)}, huggingface={len(hiddens2)}"
    )
if not hiddens1:
    raise AssertionError("No hidden states were returned; nothing to compare.")

diffs = []
for idx, (hidden1, hidden2) in enumerate(zip(hiddens1, hiddens2)):
    # A shape mismatch would otherwise either crash inside the subtraction or be
    # silently masked by broadcasting.
    if hidden1.shape != hidden2.shape:
        raise AssertionError(
            f"hidden {idx} shape mismatch: "
            f"s3prl={tuple(hidden1.shape)}, huggingface={tuple(hidden2.shape)}"
        )
    # Largest element-wise absolute deviation for this layer.
    diff = (hidden1 - hidden2).abs().max().item()
    print(f"hidden {idx} difference: {diff}")
    diffs.append(diff)
# diffs holds plain Python floats, so the builtin max suffices (no tensor round-trip).
print(f"Max difference: {max(diffs)}")
|