flexthink committed · Commit 6c51980
Parent(s): 380887d

Add support for Gumbel encoding

custom_interface.py CHANGED (+130 -2)
@@ -1,6 +1,8 @@
 import torch
+import math
 from speechbrain.inference.interfaces import Pretrained
 
+
 class AttentionMLP(torch.nn.Module):
     def __init__(self, input_dim, hidden_dim):
         super(AttentionMLP, self).__init__()
@@ -32,8 +34,11 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
     init: boolean (default: False):
         If set to True, init the embedding with the tokenizer embedding otherwise init randomly.
     freeze: boolean (default: False)
-
+        If True, the embedding is frozen. If False, the model will be trained
         alongside with the rest of the pipeline.
+    chunk_size: int
+        The size of the lengthwise chunks used when evaluating via
+        Gumbel softmax.
 
     Example
     -------
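A quick illustration of the chunk_size parameter documented above (a standalone sketch with made-up sizes, not code from this commit): the length dimension is split into ceil(length / chunk_size) roughly equal chunks with torch.chunk, so each chunk holds at most chunk_size frames, as the new encode_logits method further down in this diff does.

import math
import torch

# Made-up sizes, for illustration only
batch, length, heads, tokens = 2, 250, 3, 1000
chunk_size = 100  # default value added by this commit

logits = torch.randn(batch, length, heads, tokens)

# Split along the length (time) dimension
chunks = logits.chunk(math.ceil(length / chunk_size), dim=1)
print([c.size(1) for c in chunks])  # [84, 84, 82] -- each chunk <= 100 frames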
@@ -62,6 +67,7 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
         freeze=False,
         available_layers=None,
         layers=None,
+        chunk_size=100,
     ):
         super(Discrete_EmbeddingLayer, self).__init__()
         self.vocab_size = vocab_size
@@ -74,6 +80,8 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
         self.layers = layers
         self.available_layers = available_layers
         self.offsets = self.build_offsets()
+        self.layer_embs = self.compute_layer_embs()
+        self.chunk_size = chunk_size
 
     def init_embedding(self, weights):
         with torch.no_grad():
@@ -111,6 +119,77 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
         in_embs = self.embedding(in_tokens_offset.int())
         return in_embs
 
+    def compute_layer_embs(self):
+        weight = self.embedding.weight
+
+        # Compute per-layer offsets into the flat embedding table
+        layer_idx_map = {
+            layer: idx
+            for idx, layer in enumerate(self.available_layers)
+        }
+        layer_idx = [
+            layer_idx_map[layer]
+            for layer in self.layers
+        ]
+
+        offsets = [
+            idx * self.vocab_size
+            for idx in layer_idx
+        ]
+
+        layer_embs = torch.stack([
+            weight[offset:offset + self.vocab_size]
+            for offset in offsets
+        ])
+
+        # Add batch and length dims for broadcasting: (1, 1, layers, vocab, emb)
+        layer_embs = layer_embs.unsqueeze(0).unsqueeze(0)
+        return layer_embs
+
+    def encode_logits(self, logits, length=None):
+        """Encodes a batch of discrete unit logits into embeddings
+        using a straight-through Gumbel-softmax estimator.
+        Arguments
+        ---------
+        logits: torch.tensor
+            Batch of discrete unit logits [batch, length, head, token]
+        length: torch.tensor
+            Relative lengths [batch], currently unused
+        Returns
+        -------
+        emb: torch.tensor
+            Batch of unit embeddings [batch, length, head, emb_dim]
+        """
+
+        # Convert logits to one-hot representations
+        # without losing the gradient
+        units_gumbel = torch.nn.functional.gumbel_softmax(
+            logits,
+            hard=False,
+            dim=-1
+        )
+
+        # Straight-through trick
+        _, argmax_idx = logits.max(dim=-1, keepdim=True)
+        units_ref = torch.zeros_like(logits).scatter_(
+            dim=-1, index=argmax_idx, src=torch.ones_like(logits)
+        )
+        units_hard = units_ref - units_gumbel.detach() + units_gumbel
+
+        # Weighted sum over embeddings for each layer, in lengthwise chunks
+        units_hard_chunked = units_hard.chunk(
+            math.ceil(units_hard.size(1) / self.chunk_size),
+            dim=1
+        )
+        emb = torch.cat(
+            [
+                (self.layer_embs * units_hard_chunk.unsqueeze(-1)).sum(-2)
+                for units_hard_chunk in units_hard_chunked
+            ],
+            dim=1
+        )
+        return emb
+
 
 class DiscreteSpkEmb(Pretrained):
     """A ready-to-use class for utterance-level classification (e.g, speaker-id,
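The heart of the new Discrete_EmbeddingLayer.encode_logits shown above is a straight-through Gumbel-softmax: the forward pass uses a hard one-hot pick per unit, gradients flow through the soft Gumbel samples, and the one-hot is turned into an embedding by a weighted sum against the per-layer embedding tables built by compute_layer_embs. Below is a minimal standalone sketch of that pattern with toy tensor sizes; it is an illustration, not code from this repository.

import torch

# Toy sizes (assumptions, for illustration only)
batch, length, heads, vocab, emb_dim = 2, 50, 3, 100, 64
logits = torch.randn(batch, length, heads, vocab, requires_grad=True)

# Soft (differentiable) Gumbel samples and a hard one-hot built from the argmax
soft = torch.nn.functional.gumbel_softmax(logits, hard=False, dim=-1)
idx = logits.argmax(dim=-1, keepdim=True)
hard = torch.zeros_like(logits).scatter_(-1, idx, 1.0)

# Straight-through estimator: hard values forward, soft gradients backward
units = hard - soft.detach() + soft

# Per-layer embedding tables shaped for broadcasting: (1, 1, heads, vocab, emb)
layer_embs = torch.randn(1, 1, heads, vocab, emb_dim)

# One-hot selection expressed as a weighted sum over the vocabulary axis
embs = (layer_embs * units.unsqueeze(-1)).sum(-2)
print(embs.shape)  # torch.Size([2, 50, 3, 64])

In the actual method, the same weighted sum is evaluated chunk by chunk along the length dimension (see chunk_size) to keep peak memory bounded.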
@@ -168,5 +247,54 @@ class DiscreteSpkEmb(Pretrained):
         embeddings = self.mods.embedding_model(feats, length)
         return embeddings.squeeze(1)
 
+    def encode_logits(self, logits, length=None):
+        """Encodes the input audio logits into a single vector embedding.
+
+        Arguments
+        ---------
+        logits : torch.tensor
+            Batch of token logits [batch, time, heads, tokens]
+        length : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.tensor
+            The encoded batch
+        """
+        embeddings = self.mods.discrete_embedding_layer.encode_logits(logits)
+        att_w = self.mods.attention_mlp(embeddings)
+        feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)
+        embeddings = self.mods.embedding_model(feats, length)
+        return embeddings.squeeze(1)
+
     def forward(self, audio, length=None):
-
+        """Encodes the input audio into a single vector embedding.
+        The input should already be in the model's desired format.
+        Arguments
+        ---------
+        audio : torch.tensor
+            Batch of tokenized audio [batch, time, heads]
+            or logits [batch, time, heads, tokens]
+        length : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.tensor
+            The encoded batch
+        """
+        audio_dim = audio.dim()
+        if audio_dim == 3:
+            embeddings = self.encode_batch(audio, length)
+        elif audio_dim == 4:
+            embeddings = self.encode_logits(audio, length)
+        else:
+            raise ValueError(f"Unsupported audio shape {audio.shape}")
+        return embeddings
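With this commit, DiscreteSpkEmb.forward dispatches on the input's dimensionality: a 3-D tensor of discrete tokens goes through encode_batch, while a 4-D tensor of token logits goes through the new Gumbel-based encode_logits. A hypothetical usage sketch follows; the source value is a placeholder, the tensor sizes are made up, and loading via SpeechBrain's foreign_class helper is an assumption about how this custom interface gets wired in.

import torch
from speechbrain.inference.interfaces import foreign_class

# Placeholder source; point it at the actual model repo or a local copy
classifier = foreign_class(
    source="<model-repo-or-local-path>",
    pymodule_file="custom_interface.py",
    classname="DiscreteSpkEmb",
)

batch, time, heads, tokens = 4, 200, 3, 1000  # made-up shapes
length = torch.ones(batch)                    # relative lengths; 1.0 = no padding

# 3-D input (discrete tokens) -> forward() routes to encode_batch()
token_ids = torch.randint(0, tokens, (batch, time, heads))
spk_emb = classifier(token_ids, length)

# 4-D input (token logits) -> forward() routes to the new encode_logits()
logits = torch.randn(batch, time, heads, tokens)
spk_emb_from_logits = classifier(logits, length)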