# discrete_hubert_spk_rec_ecapatdn / custom_interface.py
from typing import Mapping
import torch
import math
from speechbrain.inference.interfaces import Pretrained
class AttentionMLP(torch.nn.Module):
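    """A small MLP that computes softmax-normalized attention weights over
    the codebook (layer) dimension of a token embedding tensor.

    Arguments
    ---------
    input_dim : int
        The dimension of the input embedding vectors.
    hidden_dim : int
        The dimension of the hidden layer.
    """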
def __init__(self, input_dim, hidden_dim):
super(AttentionMLP, self).__init__()
self.layers = torch.nn.Sequential(
torch.nn.Linear(input_dim, hidden_dim),
torch.nn.ReLU(),
torch.nn.Linear(hidden_dim, 1, bias=False),
)
def forward(self, x):
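        """Computes attention weights for each codebook position.

        Arguments
        ---------
        x : torch.Tensor
            Input embeddings of shape (Batch x Time x num_codebooks x emb_dim).

        Returns
        -------
        att_w : torch.Tensor
            Attention weights of shape (Batch x Time x num_codebooks x 1),
            normalized with a softmax over the codebook dimension.
        """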
x = self.layers(x)
att_w = torch.nn.functional.softmax(x, dim=2)
return att_w
class Discrete_EmbeddingLayer(torch.nn.Module):
"""This class handles embedding layers for discrete tokens.
Arguments
---------
    num_codebooks : int
        The number of codebooks of the tokenizer.
    vocab_size : int
        The size of the dictionary of embeddings (per codebook).
    emb_dim : int
        The size of each embedding vector.
    pad_index : int (default: 0)
        If specified, the entries at padding_idx do not contribute to the gradient.
    init : boolean (default: False)
        If set to True, initialize the embedding with the tokenizer embedding,
        otherwise initialize it randomly.
    freeze : boolean (default: False)
        If True, the embedding is frozen. If False, the embedding is trained
        alongside the rest of the pipeline.
    available_layers : list, optional
        The list of layers available from the tokenizer.
    layers : list, optional
        The subset of available layers actually used; determines which slices
        of the embedding table are selected.
    chunk_size : int
        The size of lengthwise chunks used when evaluating via Gumbel softmax.
Example
-------
>>> from speechbrain.lobes.models.huggingface_transformers.encodec import Encodec
>>> model_hub = "facebook/encodec_24khz"
>>> save_path = "savedir"
>>> model = Encodec(model_hub, save_path)
>>> audio = torch.randn(4, 1000)
>>> length = torch.tensor([1.0, .5, .75, 1.0])
>>> tokens, emb = model.encode(audio, length)
>>> print(tokens.shape)
torch.Size([4, 4, 2])
    >>> emb = Discrete_EmbeddingLayer(2, 1024, 1024)
>>> in_emb = emb(tokens)
>>> print(in_emb.shape)
torch.Size([4, 4, 2, 1024])
"""
def __init__(
self,
num_codebooks,
vocab_size,
emb_dim,
pad_index=0,
init=False,
freeze=False,
available_layers=None,
layers=None,
chunk_size=100,
):
super(Discrete_EmbeddingLayer, self).__init__()
self.vocab_size = vocab_size
self.num_codebooks = num_codebooks
self.freeze = freeze
self.embedding = torch.nn.Embedding(
num_codebooks * vocab_size, emb_dim
).requires_grad_(not self.freeze)
self.init = init
self.layers = layers
self.available_layers = available_layers
self.register_buffer("offsets", self.build_offsets())
self.register_buffer("layer_embs", self.compute_layer_embs())
self.chunk_size = chunk_size
def init_embedding(self, weights):
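        """Initializes the embedding weights from a given tensor, e.g. the
        embeddings provided by the tokenizer.

        Arguments
        ---------
        weights : torch.Tensor
            The initial embedding weights, of shape
            (num_codebooks * vocab_size, emb_dim).
        """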
with torch.no_grad():
self.embedding.weight = torch.nn.Parameter(weights)
def build_offsets(self):
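        """Builds per-codebook offsets into the shared embedding table,
        keeping only the offsets of the selected layers when a subset of
        the available layers is specified.

        Returns
        -------
        offsets : torch.Tensor
            A tensor of offsets, one per (selected) codebook.
        """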
offsets = torch.arange(
0,
self.num_codebooks * self.vocab_size,
self.vocab_size,
)
if self.layers:
selected_layers = set(self.layers)
indexes = [
idx for idx, layer in enumerate(self.available_layers)
if layer in selected_layers
]
offsets = offsets[indexes]
return offsets
def forward(self, in_tokens):
"""Computes the embedding for discrete tokens.
a sample.
Arguments
---------
in_tokens : torch.Tensor
A (Batch x Time x num_codebooks)
audio sample
Returns
-------
in_embs : torch.Tensor
"""
with torch.set_grad_enabled(not self.freeze):
            # Offset the token IDs so that each codebook indexes its own
            # vocab_size-sized slice of the shared embedding table
            in_tokens_offset = in_tokens + self.offsets.to(in_tokens.device)
            # Forward pass through the embedding layer
            in_embs = self.embedding(in_tokens_offset.int())
return in_embs
def compute_layer_embs(self):
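        """Precomputes per-layer slices of the embedding table, shaped for
        broadcasting against one-hot (or Gumbel-softmax) unit representations
        in encode_logits.

        Returns
        -------
        layer_embs : torch.Tensor
            A (1 x 1 x num_layers x vocab_size x emb_dim) tensor, or None if
            no layer selection is available.
        """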
        # Layer selection is required to precompute the per-layer tables
        if self.layers is None or self.available_layers is None:
            return None
        weight = self.embedding.weight
        # Compute the offset of each selected layer within the embedding table
layer_idx_map = {
layer: idx
for idx, layer in enumerate(self.available_layers)
}
layer_idx = [
layer_idx_map[layer]
for layer in self.layers
]
offsets = [
idx * self.vocab_size
for idx in layer_idx
]
layer_embs = torch.stack([
weight[offset:offset + self.vocab_size]
for offset in offsets
])
        # Add singleton batch and length dimensions for broadcasting:
        # (1 x 1 x num_layers x vocab_size x emb_dim)
        layer_embs = layer_embs.unsqueeze(0).unsqueeze(0)
return layer_embs
def encode_logits(self, logits, length=None):
"""Computes waveforms from a batch of discrete units
Arguments
---------
units: torch.tensor
Batch of discrete unit logits [batch, length, head, token]
or tokens [batch, length, head]
spk: torch.tensor
Batch of speaker embeddings [batch, spk_dim]
Returns
-------
waveforms: torch.tensor
Batch of mel-waveforms [batch, 1, time]
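
        Example
        -------
        A minimal, illustrative sketch (the layer IDs and sizes below are
        arbitrary placeholders, not values from a real tokenizer):

        >>> emb = Discrete_EmbeddingLayer(
        ...     2, 128, 64, available_layers=[1, 7], layers=[1, 7]
        ... )
        >>> logits = torch.randn(4, 10, 2, 128)
        >>> out = emb.encode_logits(logits)
        >>> print(out.shape)
        torch.Size([4, 10, 2, 64])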
"""
# Convert logits to one-hot representations
# without losing the gradient
units_gumbel = torch.nn.functional.gumbel_softmax(
logits,
hard=False,
dim=-1
)
# Straight-through trick
_, argmax_idx = logits.max(dim=-1, keepdim=True)
units_ref = torch.zeros_like(logits).scatter_(
dim=-1, index=argmax_idx, src=torch.ones_like(logits)
)
units_hard = units_ref - units_gumbel.detach() + units_gumbel
        # Sum the one-hot units against the per-layer embedding tables in
        # length-wise chunks to limit the memory used by the broadcast product
        units_hard_chunked = units_hard.chunk(
            math.ceil(units_hard.size(1) / self.chunk_size),
            dim=1
        )
emb = torch.cat(
[
(self.layer_embs * units_hard_chunk.unsqueeze(-1)).sum(-2)
for units_hard_chunk in units_hard_chunked
],
dim=1
)
return emb
def load_state_dict(self, state_dict, strict=True):
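        """Loads the state dict and recomputes the cached per-layer embedding
        tables so that they remain consistent with the loaded embedding
        weights."""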
result = super().load_state_dict(state_dict, strict)
self.layer_embs = self.compute_layer_embs()
return result
class DiscreteSpkEmb(Pretrained):
"""A ready-to-use class for utterance-level classification (e.g, speaker-id,
language-id, emotion recognition, keyword spotting, etc).
The class assumes that an self-supervised encoder like wav2vec2/hubert and a classifier model
are defined in the yaml file. If you want to
convert the predicted index into a corresponding text label, please
provide the path of the label_encoder in a variable called 'lab_encoder_file'
within the yaml.
The class can be used either to run only the encoder (encode_batch()) to
extract embeddings or to run a classification step (classify_batch()).
```
Example
-------
>>> import torchaudio
>>> from speechbrain.pretrained import EncoderClassifier
>>> # Model is downloaded from the speechbrain HuggingFace repo
>>> tmpdir = getfixture("tmpdir")
>>> classifier = EncoderClassifier.from_hparams(
... source="speechbrain/spkrec-ecapa-voxceleb",
... savedir=tmpdir,
... )
>>> # Compute embeddings
>>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
>>> embeddings = classifier.encode_batch(signal)
>>> # Classification
    >>> prediction = classifier.classify_batch(signal)
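    >>> # A hedged sketch for this interface; the source below is a placeholder
    >>> # and the token shape/vocabulary are illustrative assumptions.
    >>> spk_emb = DiscreteSpkEmb.from_hparams(
    ...     source="<path-or-repo-of-this-model>",
    ...     savedir=tmpdir,
    ... )  # doctest: +SKIP
    >>> tokens = torch.randint(0, 1000, (1, 200, 6))  # [batch, time, heads]
    >>> embeddings = spk_emb.encode_batch(tokens)  # doctest: +SKIP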
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def encode_batch(self, audio, length=None):
"""Encodes the input audio into a single vector embedding.
        The audio is expected to be already tokenized into discrete units.
Arguments
---------
audio : torch.tensor
Batch of tokenized audio [batch, time, heads]
length : torch.tensor
Lengths of the waveforms relative to the longest one in the
batch, tensor of shape [batch]. The longest one should have
relative length 1.0 and others len(waveform) / max_length.
Used for ignoring padding.
Returns
-------
torch.tensor
The encoded batch
"""
        # Embed the discrete tokens: (Batch x Time x num_codebooks x emb_dim)
        embeddings = self.mods.discrete_embedding_layer(audio)
        # Pool over the codebook dimension with attention weights
        att_w = self.mods.attention_mlp(embeddings)
        feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)
embeddings = self.mods.embedding_model(feats, length)
return embeddings.squeeze(1)
def encode_logits(self, logits, length=None):
"""Encodes the input audio logits into a single vector embedding.
Arguments
---------
        logits : torch.Tensor
            Batch of token logits [batch, time, heads, tokens]
length : torch.tensor
Lengths of the waveforms relative to the longest one in the
batch, tensor of shape [batch]. The longest one should have
relative length 1.0 and others len(waveform) / max_length.
Used for ignoring padding.
Returns
-------
torch.tensor
The encoded batch
"""
embeddings = self.mods.discrete_embedding_layer.encode_logits(logits)
att_w = self.mods.attention_mlp(embeddings)
feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)
embeddings = self.mods.embedding_model(feats, length)
return embeddings.squeeze(1)
def forward(self, audio, length=None):
"""Encodes the input audio into a single vector embedding.
        Accepts either discrete tokens or token logits and dispatches to
        encode_batch() or encode_logits() accordingly.
Arguments
---------
audio : torch.tensor
Batch of tokenized audio [batch, time, heads]
or logits [batch, time, heads, tokens]
length : torch.tensor
Lengths of the waveforms relative to the longest one in the
batch, tensor of shape [batch]. The longest one should have
relative length 1.0 and others len(waveform) / max_length.
Used for ignoring padding.
Returns
-------
torch.tensor
The encoded batch
"""
        audio_dim = audio.dim()
        if audio_dim == 3:
            # Discrete tokens: [batch, time, heads]
            embeddings = self.encode_batch(audio, length)
        elif audio_dim == 4:
            # Token logits: [batch, time, heads, tokens]
            embeddings = self.encode_logits(audio, length)
        else:
            raise ValueError(f"Unsupported audio shape {audio.shape}")
return embeddings