"""Dump the hidden states an S3PRL upstream produces for pseudo waveforms.

Usage: pass the upstream ``name`` as the positional argument. Any command-line
arguments not recognised by the parser are handed to ``parse_overrides`` and
forwarded to the upstream as ``extra_conf``. The resulting list of hidden
states is written to ``<output_dir>/<name>.pt`` with ``torch.save``.
"""

import logging
import argparse
from pathlib import Path

import torch

from s3prl.nn import S3PRLUpstream
from s3prl.util.pseudo_data import get_pseudo_wavs
from s3prl.util.override import parse_overrides

logger = logging.getLogger(__name__)

SAMPLE_RATE = 16000


def _build_parser():
    """Return the argument parser describing this script's known options."""
    cli = argparse.ArgumentParser()
    cli.add_argument("name")
    cli.add_argument("--output_dir", default="./sample_hidden_states")
    cli.add_argument("--refresh", action="store_true")
    cli.add_argument("--device", default="cuda")
    return cli


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Unknown arguments are not an error: they become upstream overrides.
    args, others = _build_parser().parse_known_args()
    overrides = parse_overrides(others)

    # Create the destination directory (and any missing parents) up front.
    output_dir = Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    model = S3PRLUpstream(args.name, refresh=args.refresh, extra_conf=overrides)
    model = model.to(args.device)
    model.eval()

    with torch.no_grad():
        # NOTE(review): presumably a padded waveform batch plus per-item
        # lengths -- confirm against get_pseudo_wavs.
        wavs, wavs_len = get_pseudo_wavs(padded=True)
        states, states_len = model(wavs.to(args.device), wavs_len.to(args.device))

        # Move every returned hidden state to CPU before serialising. The
        # lengths are only used to pair up with the states; their values are
        # not stored.
        cpu_states = []
        for state, _ in zip(states, states_len):
            cpu_states.append(state.detach().cpu())

    torch.save(cpu_states, output_dir / f"{args.name}.pt")