File size: 8,100 Bytes
6c404c6 f19e2e7 6c404c6 9a56940 801f630 ef78dc0 801f630 ef78dc0 801f630 ef78dc0 801f630 9a56940 6c404c6 9a56940 d567409 6c404c6 93fa313 801f630 93fa313 6c404c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
import torch
from speechbrain.inference.interfaces import Pretrained
class CustomEncoderWav2vec2Classifier(Pretrained):
    """A ready-to-use class for utterance-level classification (e.g.,
    speaker-id, language-id, emotion recognition, keyword spotting, etc.).

    The class assumes that a self-supervised encoder such as wav2vec2/HuBERT
    and a classifier model are defined in the yaml file. The methods below
    access them as:

    - ``self.mods.wav2vec2``: the self-supervised encoder;
    - ``self.mods.avg_pool``: the pooling layer applied to the encoder output;
    - ``self.mods.output_mlp``: the classifier head;
    - ``self.hparams.softmax``: applied to the classifier output. The returned
      values are documented as log-posteriors, so this is presumably a
      log-softmax — confirm against the yaml;
    - ``self.hparams.label_encoder``: converts predicted indices into text
      labels. Every ``classify_*`` method calls it unconditionally, so it
      must be defined (e.g. by providing the path of the label encoder in a
      variable called 'lab_encoder_file' within the yaml).

    The class can be used either to run only the encoder (``encode_batch``,
    ``embed_file``, ``embed_sample``) to extract embeddings, or to run a
    full classification step (``classify_batch``, ``classify_file``,
    ``classify_sample``).

    Example
    -------
    >>> import torchaudio
    >>> classifier = CustomEncoderWav2vec2Classifier.from_hparams(
    ...     source="<local-dir-or-hub-id-of-the-model>",
    ...     savedir=getfixture("tmpdir"),
    ... )  # doctest: +SKIP
    >>> # Compute embeddings
    >>> signal, fs = torchaudio.load(
    ...     "samples/audio_samples/example1.wav"
    ... )  # doctest: +SKIP
    >>> embeddings = classifier.encode_batch(signal)  # doctest: +SKIP
    >>> # Classification
    >>> prediction = classifier.classify_batch(signal)  # doctest: +SKIP
    """

    def encode_batch(self, wavs, wav_lens=None, normalize=False):
        """Encodes the input audio into a single vector embedding per item.

        The waveforms should already be in the model's desired format.
        You can call:
        ``normalized = <this>.audio_normalizer(signal, sample_rate)``
        to get a correctly converted signal in most cases.

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. A 1-D tensor is treated as a single
            waveform and given a batch dimension. Make sure the sample rate
            is fs=16000 Hz.
        wav_lens : torch.Tensor, optional
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding. Defaults to all ones (no padding).
        normalize : bool
            Accepted for interface compatibility (``forward`` passes it
            through). This implementation applies no embedding
            normalization, so the value is currently ignored.

        Returns
        -------
        torch.Tensor
            The encoded batch, flattened to shape [batch, -1].
        """
        # Manage single waveforms in input.
        if wavs.dim() == 1:
            wavs = wavs.unsqueeze(0)

        # Treat every waveform as full length when no lengths are given.
        if wav_lens is None:
            wav_lens = torch.ones(wavs.shape[0], device=self.device)

        # Move inputs to the model's device and cast audio to float32.
        wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
        wavs = wavs.float()

        # Run the self-supervised encoder.
        outputs = self.mods.wav2vec2(wavs)

        # Pool into one vector per utterance; wav_lens is passed so the
        # pooling layer can ignore padded positions.
        outputs = self.mods.avg_pool(outputs, wav_lens)
        return outputs.view(outputs.shape[0], -1)

    def _decode_predictions(self, logits):
        """Turns classifier-head outputs into (out_prob, score, index, text_lab).

        Shared by all ``classify_*`` methods so they stay consistent.
        """
        out_prob = self.hparams.softmax(logits)
        # Best class per item: its (log-)posterior value and its index.
        score, index = torch.max(out_prob, dim=-1)
        text_lab = self.hparams.label_encoder.decode_torch(index)
        return out_prob, score, index, text_lab

    def classify_batch(self, wavs, wav_lens=None):
        """Performs classification on top of the encoded features.

        It returns the posterior probabilities, the index of the best class
        and its text label (``self.hparams.label_encoder`` must be defined).

        Arguments
        ---------
        wavs : torch.Tensor
            Batch of waveforms [batch, time, channels] or [batch, time]
            depending on the model. Make sure the sample rate is fs=16000 Hz.
        wav_lens : torch.Tensor, optional
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        out_prob
            The log posterior probabilities of each class ([batch, N_class])
        score
            It is the value of the log-posterior for the best class ([batch,])
        index
            The indexes of the best class ([batch,])
        text_lab
            List with the text labels corresponding to the indexes.
        """
        embeddings = self.encode_batch(wavs, wav_lens)
        return self._decode_predictions(self.mods.output_mlp(embeddings))

    def embed_file(self, path, **kwargs):
        """Returns the embedding (encoder + pooling output) for an audio file.

        Arguments
        ---------
        path : str
            Path to the audio file to embed.
        **kwargs
            Extra keyword arguments forwarded to ``self.load_audio``
            (accepted options depend on the SpeechBrain version).

        Returns
        -------
        torch.Tensor
            The embedding of the file, shape [1, embed_dim].
        """
        waveform = self.load_audio(path, **kwargs)
        # Fake a batch containing one full-length waveform.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        return self.encode_batch(batch, rel_length)

    def embed_sample(self, sample, sr):
        """Returns the embedding (encoder + pooling output) for an audio sample.

        Arguments
        ---------
        sample : torch.Tensor
            Waveform tensor laid out as [channels, time], e.g. [1, T]. It is
            transposed to [time, channels] before being passed to
            ``self.audio_normalizer``.
        sr : int
            Sampling rate of ``sample``, passed to ``self.audio_normalizer``.

        Returns
        -------
        torch.Tensor
            The embedding of the sample, shape [1, embed_dim].
        """
        waveform = self.audio_normalizer(sample.transpose(0, 1), sr)
        # Fake a batch containing one full-length waveform.
        batch = waveform.unsqueeze(0)
        rel_length = torch.tensor([1.0])
        return self.encode_batch(batch, rel_length)

    def classify_file(self, path, **kwargs):
        """Classifies the given audio file into the given set of labels.

        Arguments
        ---------
        path : str
            Path to audio file to classify.
        **kwargs
            Extra keyword arguments forwarded to ``self.load_audio``.

        Returns
        -------
        out_prob
            The log posterior probabilities of each class ([batch, N_class])
        score
            It is the value of the log-posterior for the best class ([batch,])
        index
            The indexes of the best class ([batch,])
        text_lab
            List with the text labels corresponding to the indexes
            (``self.hparams.label_encoder`` must be defined).
        """
        embeddings = self.embed_file(path, **kwargs)
        # NOTE(review): squeeze(1) presumably drops a singleton dim from the
        # head output; it would also collapse a single-class output — confirm.
        logits = self.mods.output_mlp(embeddings).squeeze(1)
        return self._decode_predictions(logits)

    def classify_sample(self, sample, sr):
        """Classifies the given audio sample into the given set of labels.

        Arguments
        ---------
        sample : torch.Tensor
            Waveform tensor laid out as [channels, time], e.g. [1, T]
            (same layout as accepted by ``embed_sample``).
        sr : int
            Sampling rate of ``sample``.

        Returns
        -------
        out_prob
            The log posterior probabilities of each class ([batch, N_class])
        score
            It is the value of the log-posterior for the best class ([batch,])
        index
            The indexes of the best class ([batch,])
        text_lab
            List with the text labels corresponding to the indexes
            (``self.hparams.label_encoder`` must be defined).
        """
        embeddings = self.embed_sample(sample, sr)
        # NOTE(review): same squeeze(1) as classify_file — confirm intent.
        logits = self.mods.output_mlp(embeddings).squeeze(1)
        return self._decode_predictions(logits)

    def forward(self, wavs, wav_lens=None, normalize=False):
        """Module-style entry point; equivalent to ``encode_batch``."""
        return self.encode_batch(
            wavs=wavs, wav_lens=wav_lens, normalize=normalize
        )