warisqr7 committed on
Commit
801f630
1 Parent(s): d567409

Update custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +22 -4
custom_interface.py CHANGED
@@ -140,6 +140,27 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
140
  rel_length = torch.tensor([1.0])
141
  outputs = self.encode_batch(batch, rel_length)
142
  return outputs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
 
145
  def classify_file(self, path):
@@ -192,10 +213,7 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
192
  (label encoder should be provided).
193
  """
194
  # Fake a batch:
195
- waveform = self.audio_normalizer(sample, sr)
196
- batch = waveform.unsqueeze(0)
197
- rel_length = torch.tensor([1.0])
198
- outputs = self.encode_batch(batch, rel_length)
199
  outputs = self.mods.output_mlp(outputs).squeeze(1)
200
  out_prob = self.hparams.softmax(outputs)
201
  score, index = torch.max(out_prob, dim=-1)
 
140
  rel_length = torch.tensor([1.0])
141
  outputs = self.encode_batch(batch, rel_length)
142
  return outputs
143
+
144
+ def embed_sample(self, sample, sr):
145
+ """Returns embedding (last layer output) for the given audio sample.
146
+
147
+ Arguments
148
+ ---------
149
+ sample : torch tensor
150
+ wav tensor. ([T, 1])
151
+ sr: int
152
+ sampling rate.
153
+
154
+ Returns
155
+ -------
156
+ embed
157
+ The embedding returned by the encoder ([batch, embed_dim])
158
+ """
159
+ waveform = self.audio_normalizer(sample, sr)
160
+ batch = waveform.unsqueeze(0)
161
+ rel_length = torch.tensor([1.0])
162
+ outputs = self.encode_batch(batch, rel_length)
163
+ return outputs
164
 
165
 
166
  def classify_file(self, path):
 
213
  (label encoder should be provided).
214
  """
215
  # Fake a batch:
216
+ outputs = self.embed_sample(sample, sr)
 
 
 
217
  outputs = self.mods.output_mlp(outputs).squeeze(1)
218
  out_prob = self.hparams.softmax(outputs)
219
  score, index = torch.max(out_prob, dim=-1)