saefro991 commited on
Commit
e13d732
·
1 Parent(s): 7f297f1

add scorer for quick start

Browse files
Files changed (2) hide show
  1. predict.py +85 -0
  2. score.py +122 -0
predict.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import pathlib
3
+ import tqdm
4
+ from torch.utils.data import Dataset, DataLoader
5
+ import torchaudio
6
+ from score import Score
7
+ import torch
8
+
9
+ def get_arg():
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("--bs", required=False, default=None, type=int)
12
+ parser.add_argument("--mode", required=True, choices=["predict_file", "predict_dir"], type=str)
13
+ parser.add_argument("--ckpt_path", required=False, default="epoch=3-step=7459.ckpt", type=pathlib.Path)
14
+ parser.add_argument("--inp_dir", required=False, default=None, type=pathlib.Path)
15
+ parser.add_argument("--inp_path", required=False, default=None, type=pathlib.Path)
16
+ parser.add_argument("--out_path", required=True, type=pathlib.Path)
17
+ parser.add_argument("--num_workers", required=False, default=0, type=int)
18
+ return parser.parse_args()
19
+
20
+
21
+ class Dataset(Dataset):
22
+ def __init__(self, dir_path: pathlib.Path):
23
+ self.wavlist = list(dir_path.glob("*.wav"))
24
+ _, self.sr = torchaudio.load(self.wavlist[0])
25
+
26
+ def __len__(self):
27
+ return len(self.wavlist)
28
+
29
+ def __getitem__(self, idx):
30
+ fname = self.wavlist[idx]
31
+ wav, _ = torchaudio.load(fname)
32
+ sample = {
33
+ "wav": wav}
34
+ return sample
35
+
36
+ def collate_fn(self, batch):
37
+ max_len = max([x["wav"].shape[1] for x in batch])
38
+ out = []
39
+ # Performing repeat padding
40
+ for t in batch:
41
+ wav = t["wav"]
42
+ amount_to_pad = max_len - wav.shape[1]
43
+ padding_tensor = wav.repeat(1,1+amount_to_pad//wav.size(1))
44
+ out.append(torch.cat((wav,padding_tensor[:,:amount_to_pad]),dim=1))
45
+ return torch.stack(out, dim=0)
46
+
47
+
48
+ def main():
49
+ args = get_arg()
50
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
51
+ if args.mode == "predict_file":
52
+ assert args.inp_path is not None, "inp_path is required when mode is predict_file."
53
+ assert args.inp_dir is None, "inp_dir should be None."
54
+ assert args.inp_path.exists()
55
+ assert args.inp_path.is_file()
56
+ wav, sr = torchaudio.load(args.inp_path)
57
+ scorer = Score(ckpt_path=args.ckpt_path, input_sample_rate=sr, device=device)
58
+ score = scorer.score(wav.to(device))
59
+ with open(args.out_path, "w") as fw:
60
+ fw.write(str(score[0]))
61
+ else:
62
+ assert args.inp_dir is not None, "inp_dir is required when mode is predict_dir."
63
+ assert args.bs is not None, "bs is required when mode is predict_dir."
64
+ assert args.inp_path is None, "inp_path should be None."
65
+ assert args.inp_dir.exists()
66
+ assert args.inp_dir.is_dir()
67
+ dataset = Dataset(dir_path=args.inp_dir)
68
+ loader = DataLoader(
69
+ dataset,
70
+ batch_size=args.bs,
71
+ collate_fn=dataset.collate_fn,
72
+ shuffle=True,
73
+ num_workers=args.num_workers)
74
+ sr = dataset.sr
75
+ scorer = Score(ckpt_path=args.ckpt_path, input_sample_rate=sr, device=device)
76
+ with open(args.out_path, 'w'):
77
+ pass
78
+ for batch in tqdm.tqdm(loader):
79
+ scores = scorer.score(batch.to(device))
80
+ with open(args.out_path, 'a') as fw:
81
+ fw.write("\n".join([str(s) for s in scores]))
82
+
83
+
84
+ if __name__ == '__main__':
85
+ main()
score.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ import lightning_module
4
+ import torch
5
+ import torchaudio
6
+ import unittest
7
+
8
+ class Score:
9
+ """Predicting score for each audio clip."""
10
+
11
+ def __init__(
12
+ self,
13
+ ckpt_path: str = "epoch=3-step=7459.ckpt",
14
+ input_sample_rate: int = 16000,
15
+ device: str = "cpu"):
16
+ """
17
+ Args:
18
+ ckpt_path: path to pretrained checkpoint of UTMOS strong learner.
19
+ input_sample_rate: sampling rate of input audio tensor. The input audio tensor
20
+ is automatically downsampled to 16kHz.
21
+ """
22
+ print(f"Using device: {device}")
23
+ self.device = device
24
+ self.model = lightning_module.BaselineLightningModule.load_from_checkpoint(
25
+ ckpt_path).eval().to(device)
26
+ self.in_sr = input_sample_rate
27
+ self.resampler = torchaudio.transforms.Resample(
28
+ orig_freq=input_sample_rate,
29
+ new_freq=16000,
30
+ resampling_method="sinc_interpolation",
31
+ lowpass_filter_width=6,
32
+ dtype=torch.float32,
33
+ ).to(device)
34
+
35
+ def score(self, wavs: torch.tensor) -> torch.tensor:
36
+ """
37
+ Args:
38
+ wavs: audio waveform to be evaluated. When len(wavs) == 1 or 2,
39
+ the model processes the input as a single audio clip. The model
40
+ performs batch processing when len(wavs) == 3.
41
+ """
42
+ if len(wavs.shape) == 1:
43
+ out_wavs = wavs.unsqueeze(0).unsqueeze(0)
44
+ elif len(wavs.shape) == 2:
45
+ out_wavs = wavs.unsqueeze(0)
46
+ elif len(wavs.shape) == 3:
47
+ out_wavs = wavs
48
+ else:
49
+ raise ValueError('Dimension of input tensor needs to be <= 3.')
50
+ if self.in_sr != 16000:
51
+ out_wavs = self.resampler(out_wavs)
52
+ bs = out_wavs.shape[0]
53
+ batch = {
54
+ 'wav': out_wavs,
55
+ 'domains': torch.zeros(bs, dtype=torch.int).to(self.device),
56
+ 'judge_id': torch.ones(bs, dtype=torch.int).to(self.device)*288
57
+ }
58
+ with torch.no_grad():
59
+ output = self.model(batch)
60
+
61
+ return output.mean(dim=1).squeeze(1).cpu().detach().numpy()*2 + 3
62
+
63
+
64
+ class TestFunc(unittest.TestCase):
65
+ """Test class."""
66
+
67
+ def test_1dim_0(self):
68
+ scorer = Score(input_sample_rate=16000)
69
+ seq_len = 10000
70
+ inp_audio = torch.ones(seq_len)
71
+ pred = scorer.score(inp_audio)
72
+ self.assertGreaterEqual(pred, 0.)
73
+ self.assertLessEqual(pred, 5.)
74
+
75
+ def test_1dim_1(self):
76
+ scorer = Score(input_sample_rate=24000)
77
+ seq_len = 10000
78
+ inp_audio = torch.ones(seq_len)
79
+ pred = scorer.score(inp_audio)
80
+ self.assertGreaterEqual(pred, 0.)
81
+ self.assertLessEqual(pred, 5.)
82
+
83
+ def test_2dim_0(self):
84
+ scorer = Score(input_sample_rate=16000)
85
+ seq_len = 10000
86
+ inp_audio = torch.ones(1, seq_len)
87
+ pred = scorer.score(inp_audio)
88
+ self.assertGreaterEqual(pred, 0.)
89
+ self.assertLessEqual(pred, 5.)
90
+
91
+ def test_2dim_1(self):
92
+ scorer = Score(input_sample_rate=24000)
93
+ seq_len = 10000
94
+ inp_audio = torch.ones(1, seq_len)
95
+ pred = scorer.score(inp_audio)
96
+ print(pred)
97
+ print(pred.shape)
98
+ self.assertGreaterEqual(pred, 0.)
99
+ self.assertLessEqual(pred, 5.)
100
+
101
+ def test_3dim_0(self):
102
+ scorer = Score(input_sample_rate=16000)
103
+ seq_len = 10000
104
+ batch = 8
105
+ inp_audio = torch.ones(batch, 1, seq_len)
106
+ pred = scorer.score(inp_audio)
107
+ for p in pred:
108
+ self.assertGreaterEqual(p, 0.)
109
+ self.assertLessEqual(p, 5.)
110
+
111
+ def test_3dim_1(self):
112
+ scorer = Score(input_sample_rate=24000)
113
+ seq_len = 10000
114
+ batch = 8
115
+ inp_audio = torch.ones(batch, 1, seq_len)
116
+ pred = scorer.score(inp_audio)
117
+ for p in pred:
118
+ self.assertGreaterEqual(p, 0.)
119
+ self.assertLessEqual(p, 5.)
120
+
121
+ if __name__ == '__main__':
122
+ unittest.main()