warisqr7 committed on
Commit
93fa313
1 Parent(s): 3103d3b

Update custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +33 -0
custom_interface.py CHANGED
@@ -152,6 +152,39 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
152
  text_lab = self.hparams.label_encoder.decode_torch(index)
153
  return out_prob, score, index, text_lab
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  def forward(self, wavs, wav_lens=None, normalize=False):
156
  return self.encode_batch(
157
  wavs=wavs, wav_lens=wav_lens, normalize=normalize
 
152
  text_lab = self.hparams.label_encoder.decode_torch(index)
153
  return out_prob, score, index, text_lab
154
 
155
def classify_sample(self, sample, sr):
    """Classify a single in-memory audio sample.

    The sample is first passed through ``self.audio_normalizer`` and then
    wrapped into a batch of size one, so that it can go through the same
    encode / MLP / softmax pipeline used for batched inputs.

    Arguments
    ---------
    sample : torch tensor
        Wav tensor ([T, 1]).
    sr : int
        Sampling rate of ``sample``.

    Returns
    -------
    out_prob
        The log posterior probabilities of each class ([batch, N_class]).
    score
        The value of the log-posterior for the best class ([batch,]).
    index
        The indexes of the best class ([batch,]).
    text_lab
        List with the text labels corresponding to the indexes
        (a label encoder must be provided in the hparams).
    """
    normalized = self.audio_normalizer(sample, sr)

    # Add a leading dimension so the lone sample looks like a batch.
    single_batch = normalized.unsqueeze(0)
    # Single relative length of 1.0 for the only item in the batch.
    full_length = torch.tensor([1.0])

    encoded = self.encode_batch(single_batch, full_length)
    mlp_out = self.mods.output_mlp(encoded)
    mlp_out = mlp_out.squeeze(1)
    out_prob = self.hparams.softmax(mlp_out)

    # Best-scoring class along the last dimension.
    best = torch.max(out_prob, dim=-1)
    score, index = best
    text_lab = self.hparams.label_encoder.decode_torch(index)
    return out_prob, score, index, text_lab
187
+
188
  def forward(self, wavs, wav_lens=None, normalize=False):
189
  return self.encode_batch(
190
  wavs=wavs, wav_lens=wav_lens, normalize=normalize