flexthink
commited on
Commit
·
4101f55
1
Parent(s):
6c51980
Gumbel fixes
Browse files- custom_interface.py +10 -2
custom_interface.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
import torch
|
2 |
import math
|
3 |
from speechbrain.inference.interfaces import Pretrained
|
@@ -80,7 +81,6 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
|
|
80 |
self.layers = layers
|
81 |
self.available_layers = available_layers
|
82 |
self.offsets = self.build_offsets()
|
83 |
-
self.layer_embs = self.compute_layer_embs()
|
84 |
self.chunk_size = chunk_size
|
85 |
|
86 |
def init_embedding(self, weights):
|
@@ -95,7 +95,10 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
|
|
95 |
)
|
96 |
if self.layers:
|
97 |
selected_layers = set(self.layers)
|
98 |
-
indexes = [
|
|
|
|
|
|
|
99 |
offsets = offsets[indexes]
|
100 |
return offsets
|
101 |
|
@@ -190,6 +193,11 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
|
|
190 |
)
|
191 |
return emb
|
192 |
|
|
|
|
|
|
|
|
|
|
|
193 |
|
194 |
class DiscreteSpkEmb(Pretrained):
|
195 |
"""A ready-to-use class for utterance-level classification (e.g, speaker-id,
|
|
|
1 |
+
from typing import Mapping
|
2 |
import torch
|
3 |
import math
|
4 |
from speechbrain.inference.interfaces import Pretrained
|
|
|
81 |
self.layers = layers
|
82 |
self.available_layers = available_layers
|
83 |
self.offsets = self.build_offsets()
|
|
|
84 |
self.chunk_size = chunk_size
|
85 |
|
86 |
def init_embedding(self, weights):
|
|
|
95 |
)
|
96 |
if self.layers:
|
97 |
selected_layers = set(self.layers)
|
98 |
+
indexes = [
|
99 |
+
idx for idx, layer in enumerate(self.available_layers)
|
100 |
+
if layer in selected_layers
|
101 |
+
]
|
102 |
offsets = offsets[indexes]
|
103 |
return offsets
|
104 |
|
|
|
193 |
)
|
194 |
return emb
|
195 |
|
196 |
+
def load_state_dict(self, state_dict, strict=True):
|
197 |
+
result = super().load_state_dict(state_dict, strict)
|
198 |
+
self.layer_embs = self.compute_layer_embs()
|
199 |
+
return result
|
200 |
+
|
201 |
|
202 |
class DiscreteSpkEmb(Pretrained):
|
203 |
"""A ready-to-use class for utterance-level classification (e.g, speaker-id,
|