# OSUM/wenet/ssl/bestrq/bestrq_model.py
import math
from typing import Dict, Optional, Tuple
import torch
from wenet.ssl.bestrq.mask import compute_mask_indices_v2
from wenet.utils.mask import make_non_pad_mask, make_pad_mask
from wenet.transformer.attention import RelPositionMultiHeadedAttention
from wenet.transformer.encoder_layer import ConformerEncoderLayer


def quantize_vector(latent: torch.Tensor, codebook: torch.Tensor):
"""
Symbols in comments:
B: batch_size.
D: latent_dim.
C: num_latent_classes per group
G: num of codebook groups.
Args:
latent: [B, D]
codebook: [C, G, D // G]
Returns:
(quantized, codes, onehot).
- quantized: [B, D]
- codes: [B, G]
- onehot: [B, G, C]
"""
assert len(codebook.size()) == 3
b, d = latent.size()
c, g, _ = codebook.size()
assert d % g == 0
latent = latent.reshape(b, g, d // g)
    # squared L2 distance between every latent group and every codebook
    # entry, shape [B, G, C]
distance = (
# [b, g, 1]
torch.sum(latent**2, -1, keepdim=True) -
# [b, g, c]
2 * torch.einsum('bgd,cgd->bgc', latent, codebook) +
# [1, g, c]
torch.sum(codebook.permute([2, 1, 0])**2, 0, keepdim=True))
# [B, G]
codes = torch.argmin(distance, dim=-1)
# [B, G, C]
one_hot = torch.nn.functional.one_hot(codes, c).type(codebook.dtype)
quantized = torch.einsum('bgc,cgd->bgd', one_hot, codebook)
quantized = torch.reshape(quantized, [b, d])
return quantized, codes, one_hot
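
# Hedged usage sketch (illustrative shapes, not from this repo): with B=2,
# D=16, G=4 codebook groups and C=8192 entries per group,
#
#     latent = torch.randn(2, 16)
#     codebook = torch.randn(8192, 4, 4)   # [C, G, D // G]
#     quantized, codes, one_hot = quantize_vector(latent, codebook)
#
# returns quantized [2, 16], codes [2, 4] (nearest-entry index per group)
# and one_hot [2, 4, 8192].
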
class BestRQModel(torch.nn.Module):
def __init__(
self,
encoder: torch.nn.Module,
num_mel_bins: int = 80,
embedding_dim: int = 16,
num_embeddings: int = 8192,
num_codebooks: int = 1,
mask_prob: float = 0.01,
mask_length: int = 10,
min_masks: int = 2,
norm_epsilon: float = 1e-5,
out_bias: bool = False,
features_regularization_weight: float = 0.01,
) -> None:
super().__init__()
assert mask_prob > 0.0
self.mask_prob = mask_prob
self.mask_length = mask_length
self.min_masks = min_masks
self.num_codebooks = num_codebooks
self.num_embeddings = num_embeddings
self.features_regularization_weight = features_regularization_weight
# encoder
self.encoder = encoder
        # one softmax head per codebook
self.encoder_top_n_out = torch.nn.parameter.Parameter(
torch.empty(self.num_codebooks, self.encoder.output_size(),
num_embeddings))
torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)
self.out_bias = out_bias
if self.out_bias:
self.encoder_top_n_out_bias = torch.nn.parameter.Parameter(
torch.empty(self.num_codebooks, num_embeddings))
torch.nn.init.zeros_(self.encoder_top_n_out_bias)
        # stacked input features, e.g. fbank
self.stack_frames = self.encoder.embed.right_context + 1
self.stride = self.encoder.embed.subsampling_rate
input_dim = num_mel_bins * self.stride
        # random projection (frozen)
self.projection = torch.nn.parameter.Parameter(
torch.empty(input_dim, embedding_dim * self.num_codebooks),
requires_grad=False,
)
torch.nn.init.xavier_uniform_(self.projection)
# codebooks
        # [num_embeddings, num_codebooks, embedding_dim] means
        # [C, G, D // G], see quantize_vector
self.embeddings = torch.nn.parameter.Parameter(
torch.empty(num_embeddings, self.num_codebooks, embedding_dim),
requires_grad=False,
)
torch.nn.init.normal_(self.embeddings)
self.embeddings /= (self.embeddings.norm(dim=-1, p=2, keepdim=True) +
1e-8)
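        # NOTE: both the random projection above and the codebook embeddings
        # are created with requires_grad=False and never trained; the codebook
        # rows are L2-normalized so that the nearest-neighbour lookup on the
        # normalized projected features is well behaved (see quantize_vector).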
        # force reset of encoder parameters
self.reset_encoder_parameter()

    def reset_encoder_parameter(self):
def _reset_parameter(module: torch.nn.Module):
if isinstance(module, torch.nn.Linear):
torch.nn.init.trunc_normal_(module.weight.data,
mean=0.0,
std=0.02)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, torch.nn.Conv1d):
torch.nn.init.kaiming_normal_(module.weight)
if module.bias is not None:
k = math.sqrt(module.groups /
(module.in_channels * module.kernel_size[0]))
torch.nn.init.uniform_(module.bias, a=-k, b=k)
elif isinstance(module, torch.Tensor):
torch.nn.init.trunc_normal_(module)
else:
                raise NotImplementedError("other modules are not supported yet")
encoders = self.encoder.encoders
        for layer in encoders:
self_attn = layer.self_attn
_reset_parameter(self_attn.linear_q)
_reset_parameter(self_attn.linear_k)
_reset_parameter(self_attn.linear_v)
_reset_parameter(self_attn.linear_out)
if isinstance(self_attn, RelPositionMultiHeadedAttention):
_reset_parameter(self_attn.pos_bias_u)
_reset_parameter(self_attn.pos_bias_v)
if isinstance(layer, ConformerEncoderLayer):
conv1, conv2 = (layer.conv_module.pointwise_conv1,
layer.conv_module.depthwise_conv)
_reset_parameter(conv1)
_reset_parameter(conv2)

    def forward(
self,
batch: Dict,
device: torch.device,
):
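        # Hedged usage sketch (illustrative values, not part of this file):
        #     batch = {'feats': torch.randn(4, 1000, 80),
        #              'feats_lengths': torch.tensor([1000, 960, 800, 720])}
        #     info = model(batch, torch.device('cpu'))
        #     info['loss'].backward()
        # where `model` is a BestRQModel wrapping a conformer-style encoder.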
xs = batch['feats'].to(device)
xs_lens = batch['feats_lengths'].to(device)
input = xs
features_pen: Optional[torch.Tensor] = None
if self.features_regularization_weight != 0.0:
features_pen = input.pow(2).mean()
# 1 mask input
xs, code_ids_mask = self._apply_mask_signal(xs, xs_lens)
# 2.0 stack fbank
unmasked_xs = self._stack_features(input, xs_lens)
masked_xs = xs
# 2.1 get nearest embedding
target_ids = self._nearest_embedding_idx(unmasked_xs)
target_ids = target_ids[:, :code_ids_mask.size(1), :]
        # 3 forward the subsampling layer and the xxx-former blocks
out, out_mask = self.encoder(masked_xs, xs_lens)
# 4 get logits
out = out.unsqueeze(1) # [B, 1, T', dim]
top_n_out = self.encoder_top_n_out.unsqueeze(
0) # [1, num_codebooks, dim, num_embeddings]
out = torch.matmul(out,
top_n_out) # [B, num_codebooks, T', num_embeddings]
if self.out_bias:
out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(2)
# 5 compute loss
masks = out_mask.squeeze(1) * code_ids_mask
loss = self._compute_loss(out, target_ids, mask=masks)
if self.features_regularization_weight != 0.0:
loss = loss + self.features_regularization_weight * features_pen
# 6 other info: num codes used in batch, unique num codes used in batch
num_codes = masks.sum() * self.num_codebooks
uniq_num_codes = torch.tensor(
torch.unique(target_ids * masks.unsqueeze(2)).numel()).detach()
ids_corr = out.argmax(dim=-1, keepdim=False).transpose(1,
2) == target_ids
codes_acc = (ids_corr * masks.unsqueeze(2)).sum() / num_codes
return {
"codes_acc": codes_acc,
"features_l2": features_pen,
"loss": loss,
"num_codes": num_codes,
"uniq_num_codes": uniq_num_codes,
"th_accuracy": codes_acc,
}

    def _apply_mask_signal(
self, input: torch.Tensor,
input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
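        # Shape walk-through (hedged example): with a subsample-by-4 front
        # end, stride == 4 and stack_frames == 7. The padding mask [B, T] is
        # unfolded to [B, T', stack_frames] at the subsampled rate, span masks
        # are sampled at that rate, and each masked subsampled step is then
        # expanded back to its `stride` input frames, which are replaced by
        # small Gaussian noise.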
device = input.device
B, T, _ = input.size()
padding_mask = make_pad_mask(input_lens)
# calc subsampling masks
padding_mask_stride = padding_mask.unfold(
1,
size=self.stack_frames,
step=self.stride,
)
padding_mask, _ = torch.max(padding_mask_stride, dim=-1)
masks = compute_mask_indices_v2(padding_mask.size(),
padding_mask,
self.mask_prob,
self.mask_length,
min_masks=self.min_masks,
device=device)
# calc signal mask
subsampling_mask = masks
bool_stride_mask = torch.ones_like(padding_mask_stride, device=device)
mask_stride = torch.where(masks.unsqueeze(-1), bool_stride_mask, False)
        # recover original sequence masks
masks = mask_stride[:, :, :self.stride].flatten(start_dim=1)
masks_padding = torch.zeros(
B,
T,
device=device,
dtype=padding_mask.dtype,
)
masks_padding[:, :masks.size(-1)] = masks
masks = masks_padding
masks_expand = masks.unsqueeze(-1) # [B, T, 1]
# NOTE(Mddct): you can use size (b,t,d) for torch.normal
mask_emb = torch.normal(mean=0, std=0.1,
size=(1, 1, input.size(2))).to(input.device)
xs = torch.where(masks_expand, mask_emb, input)
return xs, subsampling_mask

    def _stack_features(self, input: torch.Tensor,
input_lens: torch.Tensor) -> torch.Tensor:
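        # Hedged shape example: with stride == 4 and 80-dim fbank, input
        # [B, T, 80] is unfolded into non-overlapping windows and reshaped to
        # [B, T // 4, 320]; the stacked features are then mean/std normalized
        # over the non-padded time steps before codebook lookup.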
stack_input = input.unfold(1, size=self.stride, step=self.stride)
stack_input = stack_input.transpose(-1, -2)
b, n, f, d = stack_input.size()
stack_input = stack_input.reshape(b, n, f * d)
# NOTE(Mddct): important!!!
# norm stack features
mask = make_non_pad_mask(input_lens)
stack_mask = mask.unfold(1, size=self.stride, step=self.stride)
stack_mask, _ = torch.min(stack_mask, dim=-1)
stack_input = stack_input * stack_mask.unsqueeze(2)
mean = stack_input.sum(1, keepdim=True) / stack_mask.sum(
dim=1, keepdim=True).unsqueeze(1)
std = torch.sqrt(((stack_input - mean)**2).sum(dim=1, keepdim=True) /
stack_mask.sum(dim=1, keepdim=True).unsqueeze(1))
norm_stack_input = (stack_input - mean) / (std + 1e-5)
return norm_stack_input

    def _compute_loss(self, input: torch.Tensor, target: torch.Tensor,
mask: torch.Tensor) -> torch.Tensor:
        # [B, num_codebooks, T', num_embeddings]
        # -> [B * T' * num_codebooks, num_embeddings]
        logits = input.transpose(1, 2).contiguous().view(-1, input.size(-1))
loss = torch.nn.functional.cross_entropy(
logits,
target.contiguous().view(-1),
reduction='none',
)
loss = (loss * mask.view(-1)).sum() / mask.sum()
return loss
def _nearest_embedding_idx(self, xs: torch.Tensor) -> torch.Tensor:
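        # Hedged shape example: stacked features [B, T', stride * num_mel_bins]
        # are projected to [B, T', embedding_dim * num_codebooks], L2-normalized,
        # flattened to [B * T', embedding_dim * num_codebooks] and quantized
        # against the frozen codebook, giving one code index per codebook group.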
xs = torch.matmul(xs, self.projection.to(xs.device))
xs = xs / (xs.norm(dim=-1, p=2, keepdim=True) + 1e-8)
codebooks = self.embeddings
B, T, C = xs.size()
xs_flatten = xs.view(B * T, C)
_, codes, _ = quantize_vector(xs_flatten, codebooks)
return codes.reshape(B, T, -1) # [B, T, num_codebooks]