Update custom_interface.py
Browse files- custom_interface.py +1 -1
custom_interface.py
CHANGED
@@ -139,7 +139,6 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
|
|
139 |
batch = waveform.unsqueeze(0)
|
140 |
rel_length = torch.tensor([1.0])
|
141 |
outputs = self.encode_batch(batch, rel_length)
|
142 |
-
outputs = self.mods.output_mlp(outputs).squeeze(1)
|
143 |
return outputs
|
144 |
|
145 |
|
@@ -164,6 +163,7 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
|
|
164 |
(label encoder should be provided).
|
165 |
"""
|
166 |
outputs = self.embed_file(path)
|
|
|
167 |
out_prob = self.hparams.softmax(outputs)
|
168 |
score, index = torch.max(out_prob, dim=-1)
|
169 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|
|
|
139 |
batch = waveform.unsqueeze(0)
|
140 |
rel_length = torch.tensor([1.0])
|
141 |
outputs = self.encode_batch(batch, rel_length)
|
|
|
142 |
return outputs
|
143 |
|
144 |
|
|
|
163 |
(label encoder should be provided).
|
164 |
"""
|
165 |
outputs = self.embed_file(path)
|
166 |
+
outputs = self.mods.output_mlp(outputs).squeeze(1)
|
167 |
out_prob = self.hparams.softmax(outputs)
|
168 |
score, index = torch.max(out_prob, dim=-1)
|
169 |
text_lab = self.hparams.label_encoder.decode_torch(index)
|