|
import math

import torch

from speechbrain.inference.interfaces import Pretrained
|
|
class AttentionMLP(torch.nn.Module):
    """A small MLP that computes attention weights over the codebook
    dimension of discrete token embeddings.

    Arguments
    ---------
    input_dim : int
        The dimension of the input embeddings.
    hidden_dim : int
        The dimension of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.layers = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1, bias=False),
        )

    def forward(self, x):
        """Computes attention weights over the codebook dimension.

        Arguments
        ---------
        x : torch.Tensor
            Embeddings of shape (Batch x Time x num_codebooks x input_dim).

        Returns
        -------
        att_w : torch.Tensor
            Attention weights of shape (Batch x Time x num_codebooks x 1).
        """
        x = self.layers(x)
        att_w = torch.nn.functional.softmax(x, dim=2)
        return att_w
|
|
class Discrete_EmbeddingLayer(torch.nn.Module):
    """This class handles embedding layers for discrete tokens.

    Arguments
    ---------
    num_codebooks : int
        The number of codebooks of the tokenizer.
    vocab_size : int
        The size of the dictionary of embeddings (per codebook).
    emb_dim : int
        The size of each embedding vector.
    pad_index : int (default: 0)
        If specified, the entries at pad_index do not contribute to the gradient.
    init : bool (default: False)
        If True, initialize the embedding with the tokenizer embedding;
        otherwise, initialize it randomly.
    freeze : bool (default: False)
        If True, the embedding is frozen. If False, the embedding is trained
        alongside the rest of the pipeline.
    available_layers : list, optional
        The list of layers the underlying tokenizer can output.
    layers : list, optional
        The subset of tokenizer layers actually used; it determines which
        codebook offsets and embedding tables are kept.
    chunk_size : int (default: 100)
        The size of the length-wise chunks used when evaluating logits via
        the Gumbel softmax (see ``encode_logits``).

    Example
    -------
    >>> from speechbrain.lobes.models.huggingface_transformers.encodec import Encodec
    >>> model_hub = "facebook/encodec_24khz"
    >>> save_path = "savedir"
    >>> model = Encodec(model_hub, save_path)
    >>> audio = torch.randn(4, 1000)
    >>> length = torch.tensor([1.0, .5, .75, 1.0])
    >>> tokens, emb = model.encode(audio, length)
    >>> print(tokens.shape)
    torch.Size([4, 4, 2])
    >>> emb = Discrete_EmbeddingLayer(2, 1024, 1024)
    >>> in_emb = emb(tokens)
    >>> print(in_emb.shape)
    torch.Size([4, 4, 2, 1024])
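
    ``encode_logits`` embeds a distribution over tokens (e.g. logits) instead
    of hard token indices; the shapes below are illustrative:

    >>> logits = torch.randn(4, 4, 2, 1024)
    >>> in_emb_soft = emb.encode_logits(logits)
    >>> print(in_emb_soft.shape)
    torch.Size([4, 4, 2, 1024])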
|
""" |
|
|
|
    def __init__(
        self,
        num_codebooks,
        vocab_size,
        emb_dim,
        pad_index=0,
        init=False,
        freeze=False,
        available_layers=None,
        layers=None,
        chunk_size=100,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.num_codebooks = num_codebooks
        self.freeze = freeze
        self.embedding = torch.nn.Embedding(
            num_codebooks * vocab_size, emb_dim
        ).requires_grad_(not self.freeze)
        self.init = init
        self.layers = layers
        self.available_layers = available_layers
        self.register_buffer("offsets", self.build_offsets())
        self.register_buffer("layer_embs", self.compute_layer_embs())
        self.chunk_size = chunk_size
|
    def init_embedding(self, weights):
        """Initializes the embedding table with the given weights.

        Arguments
        ---------
        weights : torch.Tensor
            The initial embedding weights, of shape
            (num_codebooks * vocab_size, emb_dim).
        """
        with torch.no_grad():
            self.embedding.weight = torch.nn.Parameter(weights)
|
    def build_offsets(self):
        """Builds the index offsets for each codebook so that every codebook
        addresses a disjoint slice of the shared embedding table (e.g. with
        vocab_size=1024 and 2 codebooks, the offsets are [0, 1024]). If a
        subset of layers is selected, only the corresponding offsets are kept.

        Returns
        -------
        offsets : torch.Tensor
            A 1-D tensor of per-codebook offsets.
        """
        offsets = torch.arange(
            0,
            self.num_codebooks * self.vocab_size,
            self.vocab_size,
        )
        if self.layers:
            selected_layers = set(self.layers)
            indexes = [
                idx
                for idx, layer in enumerate(self.available_layers)
                if layer in selected_layers
            ]
            offsets = offsets[indexes]
        return offsets
|
    def forward(self, in_tokens):
        """Computes the embeddings for discrete tokens.

        Arguments
        ---------
        in_tokens : torch.Tensor
            A (Batch x Time x num_codebooks) tensor of discrete tokens.

        Returns
        -------
        in_embs : torch.Tensor
            A (Batch x Time x num_codebooks x emb_dim) tensor of embeddings.
        """
        with torch.set_grad_enabled(not self.freeze):
            # Offset the token indices so that each codebook addresses its
            # own slice of the shared embedding table.
            in_tokens_offset = in_tokens + self.offsets.to(in_tokens.device)
            in_embs = self.embedding(in_tokens_offset.int())
            return in_embs
|
    def compute_layer_embs(self):
        """Precomputes the embedding table reshaped per codebook, for use in
        ``encode_logits``.

        Returns
        -------
        layer_embs : torch.Tensor
            One embedding table per selected codebook, of shape
            (1 x 1 x num_codebooks x vocab_size x emb_dim).
        """
        weight = self.embedding.weight

        # Map the selected layers to their positions among the available
        # layers. If no explicit layer selection is given, use all codebooks
        # in order.
        if self.layers and self.available_layers:
            layer_idx_map = {
                layer: idx for idx, layer in enumerate(self.available_layers)
            }
            layer_idx = [layer_idx_map[layer] for layer in self.layers]
        else:
            layer_idx = list(range(self.num_codebooks))

        offsets = [idx * self.vocab_size for idx in layer_idx]

        layer_embs = torch.stack(
            [weight[offset : offset + self.vocab_size] for offset in offsets]
        )

        # Add singleton batch and time dimensions for broadcasting.
        layer_embs = layer_embs.unsqueeze(0).unsqueeze(0)
        return layer_embs
|
    def encode_logits(self, logits, length=None):
        """Computes embeddings from a batch of token logits using a hard
        (straight-through) Gumbel softmax, keeping the result differentiable
        with respect to the logits.

        Arguments
        ---------
        logits : torch.Tensor
            Batch of discrete unit logits [batch, length, head, token].
        length : torch.Tensor
            Relative lengths of the samples in the batch (currently unused).

        Returns
        -------
        emb : torch.Tensor
            Batch of embeddings [batch, length, head, emb_dim].
        """
        # Draw soft Gumbel-softmax samples from the logits.
        units_gumbel = torch.nn.functional.gumbel_softmax(
            logits, hard=False, dim=-1
        )

        # Build hard one-hot vectors from the argmax and apply the
        # straight-through estimator: hard values in the forward pass,
        # Gumbel-softmax gradients in the backward pass.
        _, argmax_idx = logits.max(dim=-1, keepdim=True)
        units_ref = torch.zeros_like(logits).scatter_(
            dim=-1, index=argmax_idx, src=torch.ones_like(logits)
        )
        units_hard = units_ref - units_gumbel.detach() + units_gumbel

        # Multiply the one-hot vectors by the per-codebook embedding tables
        # in length-wise chunks to limit peak memory usage.
        units_hard_chunked = units_hard.chunk(
            math.ceil(units_hard.size(1) / self.chunk_size), dim=1
        )
        emb = torch.cat(
            [
                (self.layer_embs * units_hard_chunk.unsqueeze(-1)).sum(-2)
                for units_hard_chunk in units_hard_chunked
            ],
            dim=1,
        )
        return emb
|
    def load_state_dict(self, state_dict, strict=True):
        """Loads the state dict and recomputes the per-codebook embedding
        tables from the (possibly new) embedding weights."""
        result = super().load_state_dict(state_dict, strict)
        self.layer_embs = self.compute_layer_embs()
        return result
|
|
class DiscreteSpkEmb(Pretrained):
    """A ready-to-use class for utterance-level classification (e.g., speaker-id,
    language-id, emotion recognition, keyword spotting, etc).

    The class assumes that a self-supervised encoder like wav2vec2/hubert and a
    classifier model are defined in the yaml file. If you want to convert the
    predicted index into a corresponding text label, please provide the path of
    the label_encoder in a variable called 'lab_encoder_file' within the yaml.

    The class can be used either to run only the encoder (encode_batch()) to
    extract embeddings or to run a classification step (classify_batch()).

    Example
    -------
    >>> import torchaudio
    >>> from speechbrain.pretrained import EncoderClassifier
    >>> # Model is downloaded from the speechbrain HuggingFace repo
    >>> tmpdir = getfixture("tmpdir")
    >>> classifier = EncoderClassifier.from_hparams(
    ...     source="speechbrain/spkrec-ecapa-voxceleb",
    ...     savedir=tmpdir,
    ... )
    >>> # Compute embeddings
    >>> signal, fs = torchaudio.load("samples/audio_samples/example1.wav")
    >>> embeddings = classifier.encode_batch(signal)
    >>> # Classification
    >>> prediction = classifier.classify_batch(signal)
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
|
    def encode_batch(self, audio, length=None):
        """Encodes the input audio into a single vector embedding.

        The waveforms should already be in the model's desired format.

        Arguments
        ---------
        audio : torch.Tensor
            Batch of tokenized audio [batch, time, heads].
        length : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.Tensor
            The encoded batch.
        """
        # Embed the discrete tokens, one embedding per codebook (head).
        embeddings = self.mods.discrete_embedding_layer(audio)
        # Attention-pool the embeddings over the codebook dimension.
        att_w = self.mods.attention_mlp(embeddings)
        feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)
        # Compute the utterance-level speaker embedding.
        embeddings = self.mods.embedding_model(feats, length)
        return embeddings.squeeze(1)
|
    def encode_logits(self, logits, length=None):
        """Encodes a batch of token logits into a single vector embedding.

        Arguments
        ---------
        logits : torch.Tensor
            Batch of token logits [batch, time, heads, tokens].
        length : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.Tensor
            The encoded batch.
        """
        # Embed the logits with a straight-through Gumbel softmax, then pool
        # and encode as in encode_batch().
        embeddings = self.mods.discrete_embedding_layer.encode_logits(logits)
        att_w = self.mods.attention_mlp(embeddings)
        feats = torch.matmul(att_w.transpose(2, -1), embeddings).squeeze(-2)
        embeddings = self.mods.embedding_model(feats, length)
        return embeddings.squeeze(1)
|
    def forward(self, audio, length=None):
        """Encodes the input audio into a single vector embedding.

        The waveforms should already be in the model's desired format.

        Arguments
        ---------
        audio : torch.Tensor
            Batch of tokenized audio [batch, time, heads]
            or logits [batch, time, heads, tokens].
        length : torch.Tensor
            Lengths of the waveforms relative to the longest one in the
            batch, tensor of shape [batch]. The longest one should have
            relative length 1.0 and others len(waveform) / max_length.
            Used for ignoring padding.

        Returns
        -------
        torch.Tensor
            The encoded batch.
        """
        # Dispatch on the input shape: 3-D inputs are token indices,
        # 4-D inputs are logits over tokens.
        audio_dim = audio.dim()
        if audio_dim == 3:
            embeddings = self.encode_batch(audio, length)
        elif audio_dim == 4:
            embeddings = self.encode_logits(audio, length)
        else:
            raise ValueError(f"Unsupported audio shape {audio.shape}")
        return embeddings
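

# A minimal, self-contained sketch (illustrative only; the sizes and the
# layer ids below are arbitrary assumptions, not values from the original
# recipe) of how the pieces above fit together: tokens are embedded per
# codebook by Discrete_EmbeddingLayer, AttentionMLP computes weights over the
# codebook dimension, and the embeddings are attention-pooled as in
# DiscreteSpkEmb.encode_batch (without the final speaker encoder, which
# requires pretrained hparams).
if __name__ == "__main__":
    batch, time, num_codebooks, vocab_size, emb_dim = 2, 20, 2, 1024, 64

    emb_layer = Discrete_EmbeddingLayer(
        num_codebooks,
        vocab_size,
        emb_dim,
        available_layers=[1, 7],
        layers=[1, 7],
    )
    att_mlp = AttentionMLP(emb_dim, emb_dim)

    # Hard tokens -> embeddings -> attention pooling over codebooks.
    tokens = torch.randint(0, vocab_size, (batch, time, num_codebooks))
    embs = emb_layer(tokens)  # [batch, time, num_codebooks, emb_dim]
    att_w = att_mlp(embs)  # [batch, time, num_codebooks, 1]
    feats = torch.matmul(att_w.transpose(2, -1), embs).squeeze(-2)
    print(feats.shape)  # torch.Size([2, 20, 64])

    # Token logits can be embedded differentiably via encode_logits.
    logits = torch.randn(batch, time, num_codebooks, vocab_size)
    logit_embs = emb_layer.encode_logits(logits)
    print(logit_embs.shape)  # torch.Size([2, 20, 2, 64])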
|
|