# wavlm-large / s3prl_s3prl_main /test /test_upstream.py
# lmzjms's picture
# Upload 1162 files
# 0b32ad6 verified
import logging
import os
import shutil
import tempfile
import traceback
from pathlib import Path
from subprocess import check_call
import pytest
import torch
from filelock import FileLock
from s3prl.nn import Featurizer, S3PRLUpstream
from s3prl.util.download import _urls_to_filepaths
from s3prl.util.pseudo_data import get_pseudo_wavs
logger = logging.getLogger(__name__)

# Extra forward passes used by the determinism / layer-count checks.
TEST_MORE_ITER = 2  # repetitions in eval mode
TRAIN_MORE_ITER = 5  # repetitions in train mode

SAMPLE_RATE = 16_000

# Tolerances used when comparing against the pre-extracted hidden states.
ATOL = 1e-2  # absolute tolerance passed to torch.allclose
MAX_LENGTH_DIFF = 3  # largest accepted mismatch of size(1) between two hidden states

# Duration in seconds (passed as min_secs / max_secs) of the extra-short pseudo wavs.
EXTRA_SHORT_SEC = 5e-2

# Local checkout of the reference hidden states and the repository it is cloned from.
EXTRACTED_GT_DIR = Path(__file__).parent.parent / "sample_hidden_states"
S3PRL_HF_SAMPLE_HS = "https://huggingface.co/datasets/s3prl/sample_hidden_states"
# Expect the following directory structure:
#
# -- s3prl/ (repository root)
# ---- s3prl/ (package root)
# ---- test/
# ------- test_upstream.py
# ---- sample_hidden_states/
def _prepare_sample_hidden_states():
    """Make sure the pre-extracted reference hidden states exist locally.

    If ``EXTRACTED_GT_DIR`` does not exist, clone ``S3PRL_HF_SAMPLE_HS`` (with
    git lfs) into a temporary directory and move the result into place.
    Otherwise run ``git pull ... main`` inside it. A ``FileLock`` on a
    sibling ``.lock`` file serializes this between test processes sharing the
    same checkout.

    Raises:
        subprocess.CalledProcessError: if any git command fails.
    """
    lock_file = Path(__file__).parent.parent / "sample_hidden_states.lock"
    with FileLock(str(lock_file)):
        # NOTE: the HOME variable is necessary for git lfs to work.
        # Environment values given to a subprocess must be ``str``; a ``Path``
        # object is rejected on Windows, so convert explicitly.
        env = dict(os.environ)
        if "HOME" not in env:
            env["HOME"] = str(Path.home())

        if not EXTRACTED_GT_DIR.is_dir():
            # Clone into a scratch directory first so a failed clone does not
            # leave a partial EXTRACTED_GT_DIR behind.
            # TemporaryDirectory already creates the directory.
            with tempfile.TemporaryDirectory() as tempdir:
                tempdir = Path(tempdir)
                logger.info("Downloading extracted sample hidden states...")
                check_call("git lfs install".split(), cwd=tempdir, env=env)
                check_call(
                    f"git clone {S3PRL_HF_SAMPLE_HS}".split(),
                    cwd=tempdir,
                    env=env,
                )
                # EXTRACTED_GT_DIR does not exist yet, so this is a rename to it.
                shutil.move(
                    str(tempdir / "sample_hidden_states"), str(EXTRACTED_GT_DIR)
                )
        else:
            logger.info(f"{EXTRACTED_GT_DIR} exists. Perform git pull...")
            check_call(
                f"git pull {S3PRL_HF_SAMPLE_HS} main".split(),
                cwd=EXTRACTED_GT_DIR,
                env=env,
            )

    # NOTE(review): deleting the lock file after releasing it lets a process
    # that is still waiting on the old file and a newcomer lock different
    # files. It is kept so no stray .lock file is left in the repository root.
    # Confirm this trade-off.
    try:
        lock_file.unlink()
    except FileNotFoundError:
        pass
def _extract_feat(
    model: S3PRLUpstream,
    seed: int = 0,
    **pseudo_wavs_args,
):
    """Run ``model`` on a padded pseudo-waveform batch and return its hidden states.

    Args:
        model: the upstream to run.
        seed: forwarded as ``seed`` to ``get_pseudo_wavs``.
        **pseudo_wavs_args: extra keyword arguments for ``get_pseudo_wavs``
            (callers in this file pass ``min_secs``, ``max_secs`` and ``n``).

    Returns:
        The first output of ``model`` (the hidden states); the second output
        (lengths) is discarded.
    """
    pseudo_wavs, pseudo_wavs_len = get_pseudo_wavs(
        padded=True, seed=seed, **pseudo_wavs_args
    )
    model_outputs = model(pseudo_wavs, pseudo_wavs_len)
    hidden_states, _ = model_outputs
    return hidden_states
def _all_hidden_states_same(hs1, hs2):
    """Assert that two sequences of hidden-state tensors match within tolerance.

    Tensors are compared pairwise, layer by layer. Their sizes along dim 1
    may differ by at most ``MAX_LENGTH_DIFF``; when they differ, both are
    trimmed to the shorter length before comparison. Values are compared
    with ``torch.allclose(..., atol=ATOL)``.

    Args:
        hs1, hs2: sequences of 3-D tensors (indexed as ``[:, :len, :]``).

    Raises:
        AssertionError: if the number of tensors differs, a length difference
            exceeds ``MAX_LENGTH_DIFF``, the (trimmed) shapes differ, or the
            values are not close.
    """
    # zip() stops at the shorter sequence, so a missing layer would otherwise
    # go unnoticed.
    assert len(hs1) == len(
        hs2
    ), f"number of hidden states differ: {len(hs1)} != {len(hs2)}"
    for layer_id, (h1, h2) in enumerate(zip(hs1, hs2)):
        if h1.size(1) != h2.size(1):
            length_diff = abs(h1.size(1) - h2.size(1))
            assert length_diff <= MAX_LENGTH_DIFF, f"{length_diff} > {MAX_LENGTH_DIFF}"
            min_seqlen = min(h1.size(1), h2.size(1))
            h1 = h1[:, :min_seqlen, :]
            h2 = h2[:, :min_seqlen, :]
        # torch.allclose accepts broadcastable shapes, so check shapes first.
        assert (
            h1.shape == h2.shape
        ), f"layer {layer_id}: shape {tuple(h1.shape)} != {tuple(h2.shape)}"
        assert torch.allclose(
            h1, h2, atol=ATOL
        ), f"layer {layer_id}: values differ beyond atol={ATOL}"
def _load_ground_truth(name: str):
    """Load the pre-extracted hidden states stored for upstream ``name``.

    The source is ``f"{EXTRACTED_GT_DIR}/{name}.pt"``. If that string starts
    with ``"http"`` it is resolved to a local file with ``_urls_to_filepaths``;
    otherwise it is used as a path directly.

    Returns:
        The object saved in the ``.pt`` file, with tensors mapped to the CPU.
    """
    source = f"{EXTRACTED_GT_DIR}/{name}.pt"
    if source.startswith("http"):
        path = _urls_to_filepaths(source)
    else:
        path = source
    # Nothing in this file moves the upstream to a GPU, so its outputs are
    # compared on the CPU. Map stored tensors there as well, so references
    # saved from a GPU still load on CPU-only machines and stay comparable.
    return torch.load(path, map_location="cpu")
def _compare_with_extracted(name: str):
    """Check upstream ``name`` against its reference features and for stability.

    Runs entirely under ``torch.no_grad()``:

    1. In eval mode, features of the default-seed pseudo batch must match the
       stored reference (see ``_all_hidden_states_same``).
    2. Re-running the same batch ``TEST_MORE_ITER`` times must give
       ``torch.allclose``-identical features.
    3. Batches from other seeds must yield the same number of hidden states,
       first in eval mode, then (``TRAIN_MORE_ITER`` times) in train mode.
    """
    upstream = S3PRLUpstream(name)
    upstream.eval()

    with torch.no_grad():
        reference_hs = _extract_feat(upstream)
        ground_truth_hs = _load_ground_truth(name)
        _all_hidden_states_same(reference_hs, ground_truth_hs)

        # Same input, eval mode -> same features.
        for _ in range(TEST_MORE_ITER):
            repeated_hs = _extract_feat(upstream)
            for expected, actual in zip(reference_hs, repeated_hs):
                assert torch.allclose(
                    expected, actual
                ), "should have deterministic representation in eval mode"

        # Different inputs -> same layer count, in eval mode ...
        for seed in range(1, TEST_MORE_ITER + 1):
            other_hs = _extract_feat(upstream, seed=seed)
            assert len(reference_hs) == len(
                other_hs
            ), "should have deterministic num_layer in eval mode"

        # ... and in train mode.
        upstream.train()
        for seed in range(1, TRAIN_MORE_ITER + 1):
            other_hs = _extract_feat(upstream, seed=seed)
            assert len(reference_hs) == len(
                other_hs
            ), "should have deterministic num_layer in train mode"
def _test_forward_backward(name: str, **pseudo_wavs_args):
    """
    Check that upstream ``name`` can run a forward pass followed by a backward pass.

    The model is built and run inside ``torch.autograd.set_detect_anomaly(True)``,
    so a NaN produced during backward raises immediately. The scalar being
    back-propagated is the sum of all elements of all returned hidden states.
    ``pseudo_wavs_args`` is forwarded to ``_extract_feat``.
    """
    with torch.autograd.set_detect_anomaly(True):
        upstream = S3PRLUpstream(name)
        hidden_states = _extract_feat(upstream, **pseudo_wavs_args)
        total = sum(layer.sum() for layer in hidden_states)
        total.backward()
def _filter_options(options: list):
options = [
name
for name in options
if (not name == "customized_upstream")
and (
not "mos" in name
) # mos models do not have hidden_states key. They only return a single mos score
and (
not "stft_mag" in name
) # stft_mag upstream must past the config file currently and is not so important. So, skip the test now
and (
not "pase" in name
) # pase_plus needs lots of dependencies and is difficult to be tested and is not very worthy today
and (
not name == "xls_r_1b"
) # skip due to too large model, too long download time
and (
not name == "xls_r_2b"
) # skip due to too large model, too long download time
and (
not name in ["ast", "ssast_patch_base", "ssast_frame_base"]
) # FIXME: remove timm dependency
and (not name == "vggish") # FIXME: remove resampy dependency
and (not name == "byol_s_cvt") # FIXME: remove einops dependency
and (not "lighthubert" in name) # FIXME: solve the random subnet issue
and (not name == "passt_hop160base2lvl") # too huge memory usage
and (not name == "passt_hop160base2lvlmel") # too huge memory usage
and (not name == "passt_hop100base2lvl") # too huge memory usage
and (not name == "passt_hop100base2lvlmel") # too huge memory usage
]
options = [option for option in options if "passt" in option]
return options
"""
Test cases ensure that all upstreams are working and are same with pre-extracted features
"""
def _test_specific_upstream(name: str):
    """Run every check on upstream ``name``.

    First compares against the pre-extracted features. Then runs
    forward + backward three times with ``min_secs=EXTRA_SHORT_SEC`` and
    ``n`` = 1, 2, 3 (presumably the number of pseudo waveforms). ``max_secs``
    is ``EXTRA_SHORT_SEC`` for the first two runs and 1 for the last.
    """
    _compare_with_extracted(name)
    forward_backward_configs = (
        (1, EXTRA_SHORT_SEC),
        (2, EXTRA_SHORT_SEC),
        (3, 1),
    )
    for n_wavs, longest_secs in forward_backward_configs:
        _test_forward_backward(
            name, min_secs=EXTRA_SHORT_SEC, max_secs=longest_secs, n=n_wavs
        )
@pytest.mark.upstream
@pytest.mark.parametrize(
    "name",
    [
        "apc",
        "audio_albert",
        "fbank",
        "mel",
        "modified_cpc",
        "data2vec",
        "decoar_layers",
        "decoar2",
        "distilhubert",
        # "espnet_hubert_base_iter1", # espnet will be tested separately due to complex dependency
        "hubert",
        "lighthubert_base",
        "mockingjay",
        "npc",
        "discretebert",
        "tera",
        "unispeech_sat_base",
        "vq_apc",
        "vq_wav2vec",
        "wav2vec",
        "wav2vec2",
        "wavlm",
    ],
)
def test_common_upstream(name):
    """Run the full check suite on a curated list of common upstreams."""
    if "espnet" in name:
        # Only a missing package should lead to a skip. Any other error raised
        # while importing espnet is a real failure and must surface. The skip
        # is reported to pytest instead of passing silently.
        try:
            import espnet  # noqa: F401
        except ImportError:
            logger.info("Skip ESPNet upstream test cases if espnet is not installed")
            pytest.skip("espnet is not installed")
    _prepare_sample_hidden_states()
    _test_specific_upstream(name)
@pytest.mark.upstream
def test_specific_upstream(upstream_names: str):
    """Run the full check suite on the comma-separated upstreams in ``upstream_names``.

    ``upstream_names`` comes from a pytest fixture that is not defined in this
    file. When it is ``None`` there is nothing to test: the test is skipped
    before any sample data is downloaded. Every requested upstream is tried.
    Failures are collected, logged, and reported together at the end.
    """
    if upstream_names is None:
        pytest.skip("no upstream names given")

    # Tolerate spaces and empty entries in the comma-separated list.
    options = [name.strip() for name in upstream_names.split(",") if name.strip()]
    _prepare_sample_hidden_states()

    tracebacks = []
    for name in options:
        logger.info(f"Testing upstream: '{name}'")
        try:
            _test_specific_upstream(name)
        except Exception:
            tb = traceback.format_exc()
            logger.error(f"{name}\n{tb}")
            tracebacks.append((name, tb))

    if tracebacks:
        for name, tb in tracebacks:
            logger.error(f"Error in {name}:\n{tb}")
        failed = [name for name, _ in tracebacks]
        logger.error(f"All failed models:\n{failed}")
        # Unlike ``assert False``, pytest.fail is not removed under ``python -O``,
        # and it names the failing upstreams in the report.
        pytest.fail(f"Failed upstreams: {failed}")
@pytest.mark.upstream
@pytest.mark.slow
def test_upstream_with_extracted(upstream_names: str):
    """Compare upstreams with their pre-extracted hidden states.

    Args:
        upstream_names: comma-separated upstream names from a fixture not
            defined in this file. ``None`` means: every upstream with a
            registered checkpoint, minus those dropped by ``_filter_options``.
            Explicitly requested names are not filtered.

    Upstreams are tried in alphabetical order. Failures are collected,
    logged, and reported together at the end.
    """
    _prepare_sample_hidden_states()
    if upstream_names is not None:
        # Tolerate spaces and empty entries in the comma-separated list.
        options = [name.strip() for name in upstream_names.split(",") if name.strip()]
    else:
        options = _filter_options(
            S3PRLUpstream.available_names(only_registered_ckpt=True)
        )
    options = sorted(options)

    tracebacks = []
    for name in options:
        logger.info(f"Testing upstream: '{name}'")
        try:
            _compare_with_extracted(name)
        except Exception:
            tb = traceback.format_exc()
            logger.error(f"{name}\n{tb}")
            tracebacks.append((name, tb))

    if tracebacks:
        for name, tb in tracebacks:
            logger.error(f"Error in {name}:\n{tb}")
        failed = [name for name, _ in tracebacks]
        logger.error(f"All failed models:\n{failed}")
        # Unlike ``assert False``, pytest.fail is not removed under ``python -O``,
        # and it names the failing upstreams in the report.
        pytest.fail(f"Failed upstreams: {failed}")
@pytest.mark.upstream
@pytest.mark.slow
def test_upstream_forward_backward(upstream_names: str):
    """Check that upstreams support forward + backward with default pseudo wavs.

    Args:
        upstream_names: comma-separated upstream names from a fixture not
            defined in this file. ``None`` means: every upstream with a
            registered checkpoint, minus those dropped by ``_filter_options``.
            Explicitly requested names are not filtered.

    Upstreams are tried in reverse alphabetical order. Failures are
    collected, logged, and reported together at the end. No reference data
    is needed, so the sample hidden states are not downloaded here.
    """
    if upstream_names is not None:
        # Tolerate spaces and empty entries in the comma-separated list.
        options = [name.strip() for name in upstream_names.split(",") if name.strip()]
    else:
        options = _filter_options(
            S3PRLUpstream.available_names(only_registered_ckpt=True)
        )
    options = sorted(options, reverse=True)

    tracebacks = []
    for name in options:
        logger.info(f"Testing upstream: '{name}'")
        try:
            _test_forward_backward(name)
        except Exception:
            tb = traceback.format_exc()
            logger.error(f"{name}\n{tb}")
            tracebacks.append((name, tb))

    if tracebacks:
        for name, tb in tracebacks:
            logger.error(f"Error in {name}:\n{tb}")
        failed = [name for name, _ in tracebacks]
        logger.error(f"All failed models:\n{failed}")
        # Unlike ``assert False``, pytest.fail is not removed under ``python -O``,
        # and it names the failing upstreams in the report.
        pytest.fail(f"Failed upstreams: {failed}")
@pytest.mark.upstream
@pytest.mark.parametrize("layer_selections", [None, [0, 4, 9]])
@pytest.mark.parametrize("normalize", [False, True])
def test_featurizer(layer_selections, normalize):
    """Featurizer over hubert must return a ``torch.FloatTensor`` and ``torch.LongTensor`` lengths.

    ``torch.FloatTensor`` / ``torch.LongTensor`` are CPU float32 / int64 tensors.
    """
    upstream = S3PRLUpstream("hubert")
    featurizer = Featurizer(
        upstream, normalize=normalize, layer_selections=layer_selections
    )

    pseudo_wavs, pseudo_wavs_len = get_pseudo_wavs(padded=True)
    layer_hs, layer_lens = upstream(pseudo_wavs, pseudo_wavs_len)
    features, feature_lens = featurizer(layer_hs, layer_lens)

    assert isinstance(features, torch.FloatTensor)
    assert isinstance(feature_lens, torch.LongTensor)
@pytest.mark.upstream
def test_upstream_properties():
    """Check the types of the size/rate properties of the upstream and featurizer.

    ``hidden_sizes`` and ``downsample_rates`` must be tuples/lists whose first
    element is an int. ``output_size`` and ``downsample_rate`` of the
    featurizer must be ints.
    """
    upstream = S3PRLUpstream("hubert")
    featurizer = Featurizer(upstream)

    for attr_name in ("hidden_sizes", "downsample_rates"):
        values = getattr(upstream, attr_name)
        assert isinstance(values, (tuple, list)) and isinstance(values[0], int)

    for attr_name in ("output_size", "downsample_rate"):
        assert isinstance(getattr(featurizer, attr_name), int)