flexthink committed
Commit 6c51980 · 1 Parent(s): 380887d

Add support for Gumbel encoding

Files changed (1)
  1. custom_interface.py +130 -2
custom_interface.py CHANGED
@@ -1,6 +1,8 @@
 import torch
+import math
 from speechbrain.inference.interfaces import Pretrained
 
+
 class AttentionMLP(torch.nn.Module):
     def __init__(self, input_dim, hidden_dim):
         super(AttentionMLP, self).__init__()
@@ -32,8 +34,11 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
     init: boolean (default: False):
         If set to True, init the embedding with the tokenizer embedding otherwise init randomly.
     freeze: boolean (default: False)
-       If True, the embedding is frozen. If False, the model will be trained
+        If True, the embedding is frozen. If False, the model will be trained
         alongside with the rest of the pipeline.
+    chunk_size: int
+        The size of lengthwise chunks used when evaluating via
+        Gumbel softmax
 
     Example
     -------
@@ -62,6 +67,7 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
         freeze=False,
         available_layers=None,
         layers=None,
+        chunk_size=100,
     ):
         super(Discrete_EmbeddingLayer, self).__init__()
         self.vocab_size = vocab_size
@@ -74,6 +80,8 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
         self.layers = layers
         self.available_layers = available_layers
         self.offsets = self.build_offsets()
+        self.layer_embs = self.compute_layer_embs()
+        self.chunk_size = chunk_size
 
     def init_embedding(self, weights):
         with torch.no_grad():
@@ -111,6 +119,77 @@ class Discrete_EmbeddingLayer(torch.nn.Module):
         in_embs = self.embedding(in_tokens_offset.int())
         return in_embs
 
+    def compute_layer_embs(self):
+        weight = self.embedding.weight
+
+        # Compute per-layer offsets into the shared embedding table
+        layer_idx_map = {
+            layer: idx
+            for idx, layer in enumerate(self.available_layers)
+        }
+        layer_idx = [
+            layer_idx_map[layer]
+            for layer in self.layers
+        ]
+
+        offsets = [
+            idx * self.vocab_size
+            for idx in layer_idx
+        ]
+
+        layer_embs = torch.stack([
+            weight[offset:offset + self.vocab_size]
+            for offset in offsets
+        ])
+
+        # Add singleton batch and length dimensions: (1, 1, head, vocab, emb)
+        layer_embs = layer_embs.unsqueeze(0).unsqueeze(0)
+        return layer_embs
+
+    def encode_logits(self, logits, length=None):
+        """Computes embeddings from a batch of discrete unit logits
+
+        Arguments
+        ---------
+        logits: torch.tensor
+            Batch of discrete unit logits [batch, length, head, token]
+        length: torch.tensor
+            Batch of relative lengths [batch]
+
+        Returns
+        -------
+        emb: torch.tensor
+            Batch of embeddings [batch, length, head, emb_dim]
+        """
+
+        # Convert logits to one-hot representations
+        # without losing the gradient
+        units_gumbel = torch.nn.functional.gumbel_softmax(
+            logits,
+            hard=False,
+            dim=-1
+        )
+
+        # Straight-through trick
+        _, argmax_idx = logits.max(dim=-1, keepdim=True)
+        units_ref = torch.zeros_like(logits).scatter_(
+            dim=-1, index=argmax_idx, src=torch.ones_like(logits)
+        )
+        units_hard = units_ref - units_gumbel.detach() + units_gumbel
+
+        # Sum over embeddings for each layer, chunked along the length
+        units_hard_chunked = units_hard.chunk(
+            math.ceil(units_hard.size(1) / self.chunk_size),
+            dim=1
+        )
+        emb = torch.cat(
+            [
+                (self.layer_embs * units_hard_chunk.unsqueeze(-1)).sum(-2)
+                for units_hard_chunk in units_hard_chunked
+            ],
+            dim=1
+        )
+        return emb
+
 
 class DiscreteSpkEmb(Pretrained):
     """A ready-to-use class for utterance-level classification (e.g, speaker-id,
@@ -168,5 +247,54 @@ class DiscreteSpkEmb(Pretrained):
         embeddings = self.mods.embedding_model(feats, length)
         return embeddings.squeeze(1)
 
+    def encode_logits(self, logits, length=None):
+        """Encodes a batch of audio token logits into a single vector embedding.
+
+        Arguments
+        ---------
+        logits : torch.tensor
+            Batch of audio token logits [batch, time, heads, tokens]
+        length : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.tensor
+            The encoded batch
+        """
+        embeddings = self.mods.discrete_embedding_layer.encode_logits(logits)
+        att_w = self.mods.attention_mlp(embeddings)
+        feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)
+        embeddings = self.mods.embedding_model(feats, length)
+        return embeddings.squeeze(1)
+
     def forward(self, audio, length=None):
-        return self.encode_batch(audio, length)
+        """Encodes the input audio into a single vector embedding.
+        The audio should already be in the model's expected format.
+        Arguments
+        ---------
+        audio : torch.tensor
+            Batch of tokenized audio [batch, time, heads]
+            or logits [batch, time, heads, tokens]
+        length : torch.tensor
+            Lengths of the waveforms relative to the longest one in the
+            batch, tensor of shape [batch]. The longest one should have
+            relative length 1.0 and others len(waveform) / max_length.
+            Used for ignoring padding.
+
+        Returns
+        -------
+        torch.tensor
+            The encoded batch
+        """
+        audio_dim = audio.dim()
+        if audio_dim == 3:
+            embeddings = self.encode_batch(audio, length)
+        elif audio_dim == 4:
+            embeddings = self.encode_logits(audio, length)
+        else:
+            raise ValueError(f"Unsupported audio shape {audio.shape}")
+        return embeddings
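
For readers reviewing the new encode_logits path, here is a minimal, self-contained sketch of the straight-through Gumbel-softmax estimator it relies on: the forward pass uses a hard one-hot selection of each token, while gradients flow through the soft Gumbel-softmax sample. All names and tensor sizes below are illustrative, not taken from this repository.

import torch

# Illustrative sizes: batch=2, length=5, heads=3, vocab=10, emb_dim=8
logits = torch.randn(2, 5, 3, 10, requires_grad=True)
emb_table = torch.randn(10, 8, requires_grad=True)

# Soft, differentiable Gumbel-softmax sample
units_soft = torch.nn.functional.gumbel_softmax(logits, hard=False, dim=-1)

# Hard one-hot of the argmax (non-differentiable on its own)
argmax_idx = logits.argmax(dim=-1, keepdim=True)
units_onehot = torch.zeros_like(logits).scatter_(-1, argmax_idx, 1.0)

# Straight-through: the forward value is the hard one-hot,
# the backward gradient is that of the soft sample
units_hard = units_onehot - units_soft.detach() + units_soft

# The hard selection can now index an embedding table differentiably:
# (batch, length, heads, vocab) @ (vocab, emb) -> (batch, length, heads, emb)
emb = units_hard @ emb_table

emb.sum().backward()
print(logits.grad.shape)  # torch.Size([2, 5, 3, 10]) -- gradients reach the logits

The units_ref - units_gumbel.detach() + units_gumbel line in the diff is exactly this trick; chunking along the length dimension with chunk_size only bounds the memory used by the dense, vocab-sized one-hot product.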
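A hedged usage sketch of the extended interface: forward now dispatches on input rank, so a 3-D token tensor takes the existing encode_batch path while a 4-D logit tensor takes the new encode_logits path. Loading through foreign_class is the usual SpeechBrain pattern for a custom interface; the source path and all sizes below are placeholders and must match the actual model configuration.

import torch
from speechbrain.inference.interfaces import foreign_class

# Placeholder source; substitute the actual model repository or local path
classifier = foreign_class(
    source="path/to/this/model",
    pymodule_file="custom_interface.py",
    classname="DiscreteSpkEmb",
)

batch, time, heads, vocab = 4, 200, 3, 1000  # illustrative sizes
length = torch.ones(batch)  # all utterances unpadded

# 3-D input: discrete token indices -> existing encode_batch path
tokens = torch.randint(0, vocab, (batch, time, heads))
emb_from_tokens = classifier(tokens, length)

# 4-D input: token logits -> new Gumbel straight-through encode_logits path,
# keeping the speaker embedding differentiable with respect to the logits
logits = torch.randn(batch, time, heads, vocab, requires_grad=True)
emb_from_logits = classifier(logits, length)
emb_from_logits.sum().backward()  # gradients reach the logits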