"""Compare hidden states from two wav2vec 2.0 upstreams exposed by ``s3prl.hub``.

The script loads the s3prl entry and the matching ``wav2vec2_hug_*`` entry,
feeds both the same batch of random waveforms, and prints the maximum absolute
difference of every pair of hidden states plus the overall maximum.

Usage: pass ``--base`` or ``--large`` (``--base`` wins if both are given) and
optionally ``--device`` (defaults to ``cuda``).
"""

import argparse
import os
import random

import torch
import transformers
from packaging import version
from s3prl import hub

SAMPLE_RATE = 16000
BATCH_SIZE = 8

parser = argparse.ArgumentParser()
parser.add_argument("--base", action="store_true")
parser.add_argument("--large", action="store_true")
parser.add_argument("--device", default="cuda")
args = parser.parse_args()

# Explicit raise instead of ``assert`` so the gate still runs under ``python -O``.
if version.parse(transformers.__version__) > version.parse("4.9.0"):
    raise RuntimeError(
        "Newer version of transformers change the places for feature extraction."
    )

# A model size is mandatory; report it as a normal CLI usage error.
if not (args.base or args.large):
    parser.error("one of --base or --large is required")

s3prl_str = "wav2vec2_base_960" if args.base else "wav2vec2_large_ll60k"
huggingface_str = "wav2vec2_hug_base_960" if args.base else "wav2vec2_hug_large_ll60k"

s3prl = getattr(hub, s3prl_str)().to(args.device)
huggingface = getattr(hub, huggingface_str)().to(args.device)

if args.base:
    # NOTE(review): the original indentation was lost, so the extent of this
    # branch is reconstructed. All three overrides are assumed to apply to the
    # base model only, presumably to align s3prl's preprocessing with the
    # Hugging Face pipeline -- confirm against the upstream implementations.
    s3prl.wav_normalize = True
    s3prl.apply_padding_mask = False
    s3prl.numpy_wav_normalize = True

s3prl.eval()
huggingface.eval()

# Random waveforms between 1 and 15 seconds long at SAMPLE_RATE, created on
# the default device and then moved, exactly as the original did.
wavs = [
    torch.randn(random.randint(SAMPLE_RATE * 1, SAMPLE_RATE * 15)).to(args.device)
    for _ in range(BATCH_SIZE)
]

with torch.no_grad():
    hiddens1 = s3prl(wavs)["hidden_states"]
    hiddens2 = huggingface(wavs)["hidden_states"]

# ``zip`` below would silently drop unmatched entries, so fail loudly first.
if len(hiddens1) != len(hiddens2):
    raise RuntimeError(
        f"Number of hidden states differs: {len(hiddens1)} vs {len(hiddens2)}"
    )

diffs = []
for idx, (hidden1, hidden2) in enumerate(zip(hiddens1, hiddens2)):
    diff = (hidden1 - hidden2).abs().max().item()
    print(f"hidden {idx} difference: {diff}")
    diffs.append(diff)

# Same tensor reduction as the original, kept so the reported value is unchanged.
print(f"Max difference: {torch.FloatTensor(diffs).max().item()}")