flexthink committed on
Commit
4101f55
·
1 Parent(s): 6c51980

Gumbel fixes

Browse files
Files changed (1) hide show
  1. custom_interface.py +10 -2
custom_interface.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import math
3
  from speechbrain.inference.interfaces import Pretrained
@@ -80,7 +81,6 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
80
  self.layers = layers
81
  self.available_layers = available_layers
82
  self.offsets = self.build_offsets()
83
- self.layer_embs = self.compute_layer_embs()
84
  self.chunk_size = chunk_size
85
 
86
  def init_embedding(self, weights):
@@ -95,7 +95,10 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
95
  )
96
  if self.layers:
97
  selected_layers = set(self.layers)
98
- indexes = [idx for idx, layer in enumerate(self.layers) if layer in selected_layers]
 
 
 
99
  offsets = offsets[indexes]
100
  return offsets
101
 
@@ -190,6 +193,11 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
190
  )
191
  return emb
192
 
 
 
 
 
 
193
 
194
  class DiscreteSpkEmb(Pretrained):
195
  """A ready-to-use class for utterance-level classification (e.g, speaker-id,
 
1
+ from typing import Mapping
2
  import torch
3
  import math
4
  from speechbrain.inference.interfaces import Pretrained
 
81
  self.layers = layers
82
  self.available_layers = available_layers
83
  self.offsets = self.build_offsets()
 
84
  self.chunk_size = chunk_size
85
 
86
  def init_embedding(self, weights):
 
95
  )
96
  if self.layers:
97
  selected_layers = set(self.layers)
98
+ indexes = [
99
+ idx for idx, layer in enumerate(self.available_layers)
100
+ if layer in selected_layers
101
+ ]
102
  offsets = offsets[indexes]
103
  return offsets
104
 
 
193
  )
194
  return emb
195
 
196
+ def load_state_dict(self, state_dict, strict=True):
197
+ result = super().load_state_dict(state_dict, strict)
198
+ self.layer_embs = self.compute_layer_embs()
199
+ return result
200
+
201
 
202
  class DiscreteSpkEmb(Pretrained):
203
  """A ready-to-use class for utterance-level classification (e.g, speaker-id,