Update custom_interface.py
Browse files- custom_interface.py +33 -0
custom_interface.py
CHANGED
@@ -152,6 +152,39 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
|
|
152 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|
153 |
return out_prob, score, index, text_lab
|
154 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
def forward(self, wavs, wav_lens=None, normalize=False):
|
156 |
return self.encode_batch(
|
157 |
wavs=wavs, wav_lens=wav_lens, normalize=normalize
|
|
|
152 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|
153 |
return out_prob, score, index, text_lab
|
154 |
|
155 |
+
def classify_sample(self, sample, sr):
    """Classify one audio sample by running it through the model as a
    single-item batch.

    Arguments
    ---------
    sample : torch tensor
        Waveform of a single recording (documented shape: [T, 1]).
    sr : int
        Sampling rate of ``sample``; passed to the audio normalizer.

    Returns
    -------
    out_prob
        Result of ``hparams.softmax`` on the classifier outputs, described
        as the log posterior probabilities of each class
        ([batch, N_class]).
    score
        Largest value of ``out_prob`` along the last dimension, i.e. the
        value for the best class ([batch,]).
    index
        Index of the best class ([batch,]).
    text_lab
        Text labels obtained by decoding ``index`` with
        ``hparams.label_encoder`` (a label encoder must be provided).
    """
    # Normalize the audio, then add a leading dimension so the single
    # sample can be fed to the batch encoder.
    single_batch = self.audio_normalizer(sample, sr).unsqueeze(0)

    # One-element length tensor for the one-item batch
    # (presumably the relative length of the only item -- 1.0 = full).
    full_length = torch.tensor([1.0])

    encoded = self.encode_batch(single_batch, full_length)
    class_outputs = self.mods.output_mlp(encoded).squeeze(1)
    out_prob = self.hparams.softmax(class_outputs)

    # Best class: maximum over the last (class) dimension.
    score, index = torch.max(out_prob, dim=-1)
    return (
        out_prob,
        score,
        index,
        self.hparams.label_encoder.decode_torch(index),
    )
|
187 |
+
|
188 |
def forward(self, wavs, wav_lens=None, normalize=False):
|
189 |
return self.encode_batch(
|
190 |
wavs=wavs, wav_lens=wav_lens, normalize=normalize
|