|
import random |
|
import argparse |
|
import numpy as np |
|
|
|
import torch |
|
from s3prl.nn import S3PRLUpstream |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
SAMPLE_RATE = 16000 |
|
BATCH_SIZE = 3 |
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
parser.add_argument("upstream") |
|
parser.add_argument("--ckpt") |
|
parser.add_argument("--device", default="cuda") |
|
parser.add_argument("--seed", type=int, default=0) |
|
args = parser.parse_args() |
|
|
|
random.seed(args.seed) |
|
np.random.seed(args.seed) |
|
torch.manual_seed(args.seed) |
|
if torch.cuda.is_available(): |
|
torch.cuda.manual_seed_all(args.seed) |
|
torch.backends.cudnn.deterministic = True |
|
torch.backends.cudnn.benchmark = False |
|
|
|
upstream = S3PRLUpstream(args.upstream, args.ckpt).to(args.device) |
|
wavs = [ |
|
torch.randn(random.randint(SAMPLE_RATE * 1, SAMPLE_RATE * 15)).to(args.device) |
|
for _ in range(BATCH_SIZE) |
|
] |
|
wavs_len = torch.LongTensor([len(w) for w in wavs]).to(args.device) |
|
wavs = pad_sequence(wavs, batch_first=True) |
|
|
|
with torch.no_grad(): |
|
upstream.eval() |
|
hidden, hidden_len = upstream(wavs, wavs_len) |
|
|