# UTMOS-style MOS predictor (Score) with unittest coverage.
import unittest

import numpy as np
import torch
import torchaudio

import lightning_module
class Score:
    """Predicting score for each audio clip."""

    def __init__(
            self,
            ckpt_path: str = "epoch=3-step=7459.ckpt",
            input_sample_rate: int = 16000,
            device: str = "cpu"):
        """
        Args:
            ckpt_path: path to pretrained checkpoint of UTMOS strong learner.
            input_sample_rate: sampling rate of input audio tensor. The input audio tensor
                is automatically downsampled to 16kHz.
            device: torch device string the model and resampler are moved to.
        """
        print(f"Using device: {device}")
        self.device = device
        self.model = lightning_module.BaselineLightningModule.load_from_checkpoint(
            ckpt_path).eval().to(device)
        self.in_sr = input_sample_rate
        # NOTE(review): "sinc_interpolation" is the legacy torchaudio name
        # (renamed "sinc_interp_hann" in newer releases) — confirm against the
        # pinned torchaudio version before upgrading.
        self.resampler = torchaudio.transforms.Resample(
            orig_freq=input_sample_rate,
            new_freq=16000,
            resampling_method="sinc_interpolation",
            lowpass_filter_width=6,
            dtype=torch.float32,
        ).to(device)

    def score(self, wavs: torch.Tensor) -> np.ndarray:
        """Predict a MOS-like score for each input clip.

        Args:
            wavs: audio waveform(s) to be evaluated. A 1-dim tensor (samples,)
                or a 2-dim tensor (1, samples) is treated as a single clip; a
                3-dim tensor (batch, 1, samples) is processed as a batch.

        Returns:
            1-dim numpy array with one predicted score per clip.

        Raises:
            ValueError: if ``wavs`` has more than 3 dimensions.
        """
        # Normalize every accepted rank to the 3-dim (batch, channel, samples)
        # layout the model consumes.
        if wavs.dim() == 1:
            out_wavs = wavs.unsqueeze(0).unsqueeze(0)
        elif wavs.dim() == 2:
            out_wavs = wavs.unsqueeze(0)
        elif wavs.dim() == 3:
            out_wavs = wavs
        else:
            raise ValueError('Dimension of input tensor needs to be <= 3.')
        if self.in_sr != 16000:
            out_wavs = self.resampler(out_wavs)
        bs = out_wavs.shape[0]
        batch = {
            'wav': out_wavs,
            'domains': torch.zeros(bs, dtype=torch.int).to(self.device),
            # NOTE(review): 288 presumably selects a fixed judge embedding for
            # inference — confirm against the checkpoint's judge table.
            'judge_id': torch.ones(bs, dtype=torch.int).to(self.device) * 288
        }
        with torch.no_grad():
            output = self.model(batch)
        # Average over the time axis, then rescale; *2 + 3 presumably maps the
        # model's normalized output onto the 1-5 MOS range — TODO confirm.
        return output.mean(dim=1).squeeze(1).cpu().detach().numpy() * 2 + 3
    
class TestFunc(unittest.TestCase):
    """Tests for Score across input ranks (1-3 dims) and sample rates."""

    def _check(self, input_sample_rate, shape):
        """Score an all-ones waveform of `shape` and assert each predicted
        score falls in the expected [0, 5] range."""
        scorer = Score(input_sample_rate=input_sample_rate)
        pred = scorer.score(torch.ones(*shape))
        for p in pred:
            self.assertGreaterEqual(p, 0.)
            self.assertLessEqual(p, 5.)

    def test_1dim_0(self):
        self._check(16000, (10000,))

    def test_1dim_1(self):
        # 24 kHz input exercises the resampling path.
        self._check(24000, (10000,))

    def test_2dim_0(self):
        self._check(16000, (1, 10000))

    def test_2dim_1(self):
        self._check(24000, (1, 10000))

    def test_3dim_0(self):
        # 3-dim input is batch-processed: one score per clip.
        self._check(16000, (8, 1, 10000))

    def test_3dim_1(self):
        self._check(24000, (8, 1, 10000))
if __name__ == '__main__':
    # Removed stray " |" residue that made this line a syntax error.
    unittest.main()