warisqr7 commited on
Commit
d567409
1 Parent(s): 9a56940

Update custom_interface.py

Browse files
Files changed (1) hide show
  1. custom_interface.py +1 -1
custom_interface.py CHANGED
@@ -139,7 +139,6 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
139
  batch = waveform.unsqueeze(0)
140
  rel_length = torch.tensor([1.0])
141
  outputs = self.encode_batch(batch, rel_length)
142
- outputs = self.mods.output_mlp(outputs).squeeze(1)
143
  return outputs
144
 
145
 
@@ -164,6 +163,7 @@ class CustomEncoderWav2vec2Classifier(Pretrained):
164
  (label encoder should be provided).
165
  """
166
  outputs = self.embed_file(path)
 
167
  out_prob = self.hparams.softmax(outputs)
168
  score, index = torch.max(out_prob, dim=-1)
169
  text_lab = self.hparams.label_encoder.decode_torch(index)
 
139
  batch = waveform.unsqueeze(0)
140
  rel_length = torch.tensor([1.0])
141
  outputs = self.encode_batch(batch, rel_length)
 
142
  return outputs
143
 
144
 
 
163
  (label encoder should be provided).
164
  """
165
  outputs = self.embed_file(path)
166
+ outputs = self.mods.output_mlp(outputs).squeeze(1)
167
  out_prob = self.hparams.softmax(outputs)
168
  score, index = torch.max(out_prob, dim=-1)
169
  text_lab = self.hparams.label_encoder.decode_torch(index)