File size: 955 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import os
import torch
import random
import argparse
import numpy as np
from s3prl import hub

# Script overview:
#   1. Parse the command line (upstream name, output path, device, seed).
#   2. Seed every RNG used below so repeated runs draw the same inputs.
#   3. Build BATCH_SIZE random waveforms of varying length.
#   4. Instantiate the s3prl hub entry named by --upstream, run it on the
#      waveforms, and save its "last_hidden_state" output with torch.save.

SAMPLE_RATE = 16000
BATCH_SIZE = 8

parser = argparse.ArgumentParser(
    description=(
        "Run an s3prl upstream on seeded random waveforms and save its "
        "last hidden state."
    )
)
parser.add_argument(
    "--upstream",
    "-u",
    required=True,
    help="name of the s3prl.hub attribute to call to build the upstream",
)
parser.add_argument(
    "--output",
    "-o",
    required=True,
    help="path the extracted tensor is written to with torch.save",
)
# Only default to CUDA when a CUDA device is actually usable; otherwise the
# script would fail on CPU-only machines unless --device was given explicitly.
parser.add_argument(
    "--device",
    default="cuda" if torch.cuda.is_available() else "cpu",
    help="torch device for the inputs and the model "
    "(default: cuda if available, else cpu)",
)
parser.add_argument(
    "--seed",
    type=int,
    default=0,
    help="seed applied to the random, numpy and torch RNGs (default: 0)",
)
args = parser.parse_args()

# Make sure the destination directory exists *before* doing any model work,
# so a missing directory is not discovered only at the final torch.save call.
output_dir = os.path.dirname(args.output)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Seed every RNG this script (or the upstream) may draw from.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(args.seed)
# cuDNN settings recommended by PyTorch's reproducibility notes.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Each waveform is a 1-D tensor of Gaussian noise whose length is drawn
# uniformly (inclusive) from [SAMPLE_RATE * 1, SAMPLE_RATE * 15] samples.
# The noise is drawn by torch.randn without a device argument and moved to
# the target device afterwards, so the draw itself does not depend on
# --device.
wavs = [
    torch.randn(random.randint(SAMPLE_RATE * 1, SAMPLE_RATE * 15)).to(args.device)
    for _ in range(BATCH_SIZE)
]

# Look the upstream constructor up by name on s3prl.hub and call it with no
# arguments; an unknown name raises AttributeError here.
upstream = getattr(hub, args.upstream)().to(args.device)
upstream.eval()

# Inference only: no autograd graph is needed for the extracted features.
with torch.no_grad():
    hidden = upstream(wavs)["last_hidden_state"].detach().cpu()

torch.save(hidden, args.output)