Update custom_interface.py
Browse files- custom_interface.py +22 -4
custom_interface.py
CHANGED
@@ -140,6 +140,27 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
|
|
140 |
rel_length = torch.tensor([1.0])
|
141 |
outputs = self.encode_batch(batch, rel_length)
|
142 |
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
|
145 |
def classify_file(self, path):
|
@@ -192,10 +213,7 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
|
|
192 |
(label encoder should be provided).
|
193 |
"""
|
194 |
# Fake a batch:
|
195 |
-
|
196 |
-
batch = waveform.unsqueeze(0)
|
197 |
-
rel_length = torch.tensor([1.0])
|
198 |
-
outputs = self.encode_batch(batch, rel_length)
|
199 |
outputs = self.mods.output_mlp(outputs).squeeze(1)
|
200 |
out_prob = self.hparams.softmax(outputs)
|
201 |
score, index = torch.max(out_prob, dim=-1)
|
|
|
140 |
rel_length = torch.tensor([1.0])
|
141 |
outputs = self.encode_batch(batch, rel_length)
|
142 |
return outputs
|
143 |
+
|
144 |
+
def embed_sample(self, sample, sr):
|
145 |
+
"""Returns embedding (last layer output) for the given audiofile.
|
146 |
+
|
147 |
+
Arguments
|
148 |
+
---------
|
149 |
+
ample : torch tensor
|
150 |
+
wav tensor. ([T, 1])
|
151 |
+
sr: int
|
152 |
+
sampling rate.
|
153 |
+
|
154 |
+
Returns
|
155 |
+
-------
|
156 |
+
embed
|
157 |
+
The log posterior probabilities of each class ([batch, embed_dim])
|
158 |
+
"""
|
159 |
+
waveform = self.audio_normalizer(sample, sr)
|
160 |
+
batch = waveform.unsqueeze(0)
|
161 |
+
rel_length = torch.tensor([1.0])
|
162 |
+
outputs = self.encode_batch(batch, rel_length)
|
163 |
+
return outputs
|
164 |
|
165 |
|
166 |
def classify_file(self, path):
|
|
|
213 |
(label encoder should be provided).
|
214 |
"""
|
215 |
# Fake a batch:
|
216 |
+
outputs = self.embed_sample(sample, sr)
|
|
|
|
|
|
|
217 |
outputs = self.mods.output_mlp(outputs).squeeze(1)
|
218 |
out_prob = self.hparams.softmax(outputs)
|
219 |
score, index = torch.max(out_prob, dim=-1)
|