# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2023 NetEase Inc. (authors: Yuting Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet) and
# fairseq(https://github.com/facebookresearch/fairseq)

from typing import Dict, Optional

import torch
import torch.nn.functional as F

from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.ctl_model.encoder import TransformerEncoder
from wenet.transformer.asr_model import ASRModel
from wenet.utils.common import IGNORE_ID
class CTLModel(ASRModel):
    """
    Implementation of Interspeech 2023 paper:
    'Enhancing the Unified Streaming and Non-streaming Model
    with Contrastive Learning'
    https://arxiv.org/abs/2306.00755
    """
    def __init__(
        self,
        vocab_size: int,
        encoder: TransformerEncoder,
        decoder: TransformerDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        logit_temp: float = 0.1,
        n_negatives: int = 0,
        ctl_weight: float = 1.0,
        special_tokens: Optional[dict] = None,
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        super().__init__(vocab_size,
                         encoder,
                         decoder,
                         ctc,
                         ctc_weight,
                         ignore_id,
                         reverse_weight,
                         lsm_weight,
                         length_normalized_loss,
                         special_tokens=special_tokens)

        # For CTL loss
        self.n_negatives = n_negatives
        self.ctl_weight = ctl_weight
        self.logit_temp = logit_temp
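        # Hyperparameters of the contrastive branch: n_negatives distractor
        # frames are sampled per position in sample_negatives(), cosine
        # similarities are scaled by 1 / logit_temp in compute_preds(), and
        # the resulting loss is weighted by ctl_weight in forward().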
    def forward(
        self,
        batch: dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        speech = batch['feats'].to(device)
        speech_lengths = batch['feats_lengths'].to(device)
        text = batch['target'].to(device)
        text_lengths = batch['target_lengths'].to(device)

        # Full-context pass and chunk-based (streaming) pass over the same
        # utterances.
        loss_full, encoder_out_full, _, _ = self.forward_full(
            speech, speech_lengths, text, text_lengths)
        loss_chunk, encoder_out, lens_chunk, encoder_mask = self.forward_chunk(
            speech, speech_lengths, text, text_lengths)

        # Contrastive loss between chunk-mode frames and their full-context
        # counterparts, with negatives drawn from the full-context output.
        ctl_loss = 0.0
        if self.ctl_weight > 0 and self.n_negatives > 0:
            num = encoder_out_full.size(1)
            targets = encoder_out_full
            src = encoder_out
            negs, negs_idxs = self.sample_negatives(targets,
                                                    targets.size(1),
                                                    speech_lengths=lens_chunk)
            ctl_loss = self.CTL(src, targets, negs, encoder_mask)

        loss = loss_full + loss_chunk + self.ctl_weight * ctl_loss
        return {
            "loss": loss,
            "loss_full": loss_full,
            "loss_chunk": loss_chunk,
            "loss_ctl": ctl_loss
        }
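    # Note on inputs: forward() reads 'feats' (B, T, D), 'feats_lengths' (B,),
    # 'target' (B, L) and 'target_lengths' (B,) from `batch`; the shapes are
    # the usual WeNet layout and are shown here only for orientation.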
    def forward_full(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ):
        """Full context mode
        Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
        # 1. Encoder
        encoder_out, encoder_mask = self.encoder.forward_full(
            speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2a. Attention-decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
                                                    text, text_lengths)
        else:
            loss_att = None

        # 2b. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
                                text_lengths)
        else:
            loss_ctc = None

        if loss_ctc is None:
            loss = loss_att
        elif loss_att is None:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc[0] + (
                1 - self.ctc_weight) * loss_att
        return loss, encoder_out, encoder_out_lens, encoder_mask
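    # forward_chunk() below mirrors forward_full(); the only structural
    # difference is that the encoder is invoked through its default call path
    # (which, for the dual CTL encoder, presumably applies chunk masking)
    # rather than through forward_full().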
    def forward_chunk(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ):
        """Chunk-based context mode
        Frontend + Encoder + Decoder + Calc loss

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
        # 1. Encoder
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2a. Attention-decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
                                                    text, text_lengths)
        else:
            loss_att = None

        # 2b. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
                                text_lengths)
        else:
            loss_ctc = None

        if loss_ctc is None:
            loss = loss_att
        elif loss_att is None:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc[0] + (
                1 - self.ctc_weight) * loss_att
        return loss, encoder_out, encoder_out_lens, encoder_mask
    def sample_negatives(self, y, num, padding_count=0, speech_lengths=None):
        """Sample self.n_negatives distractor frames per position from y.

        Adapted from fairseq's wav2vec 2.0 negative sampling. Returns the
        negatives with shape (n_negatives, B, T, C) and the indices used to
        gather them from the flattened (B x T, C) view of y.
        """
        if self.n_negatives == 0:
            return y.new(0)
        bsz, tsz, fsz = y.shape
        y = y.reshape(-1, fsz)  # BTC => (BxT)C
        # FIXME: what happens if padding_count is specified?
        high = tsz - (padding_count or 0)
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"
            if self.n_negatives > 0:
                # Candidate indices that would point at the positive frame
                # itself are shifted past it below.
                tszs = (torch.arange(num).unsqueeze(-1).expand(
                    -1, self.n_negatives).flatten())
                if speech_lengths is not None:
                    # Restrict sampling to the valid (unpadded) frames of
                    # each utterance.
                    neg_idxs = [
                        torch.randint(low=0,
                                      high=speech_lengths[i].item() - 1,
                                      size=(1, self.n_negatives * tsz))
                        for i in range(len(speech_lengths))
                    ]
                    neg_idxs = torch.cat(neg_idxs).reshape(
                        bsz, self.n_negatives * tsz)
                else:
                    neg_idxs = torch.randint(low=0,
                                             high=num - 1,
                                             size=(bsz,
                                                   self.n_negatives * tsz))
                neg_idxs[neg_idxs >= tszs] += 1
        if self.n_negatives > 0:
            # Offset each utterance's indices into the flattened (B x T) axis.
            neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high)
        negs = y[neg_idxs.view(-1)]
        negs = negs.contiguous().view(bsz, num, self.n_negatives,
                                      fsz).permute(2, 0, 1, 3)  # to NxBxTxC
        return negs, neg_idxs
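    # Shape example for sample_negatives(), with illustrative sizes only:
    # for bsz=2, tsz=num=3, fsz=256 and n_negatives=4, neg_idxs has shape
    # (2, 12) and the returned negs has shape (4, 2, 3, 256), i.e. NxBxTxC.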
    def compute_preds(self, x, y, negatives):
        # x, y: (B, T, C); negatives: (N, B, T, C).
        # Flag negatives that are identical to the positive frame.
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)  # (N + 1, B, T, C)
        # Temperature-scaled cosine similarity against the positive and all
        # negatives.
        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1)
        logits = logits / self.logit_temp
        logits = logits.type_as(x)
        if neg_is_pos.any():
            if not hasattr(self, "_inftensor"):
                self._inftensor = float("-inf")
            # Drop duplicated positives from the softmax denominator.
            # logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor)
            logits[1:][neg_is_pos] = self._inftensor
        # (N + 1, B, T) -> (B, T, N + 1) -> (B * T, N + 1); the positive score
        # is column 0 of every row.
        logits = logits.transpose(0, 2)
        logits = logits.transpose(0, 1)
        logits = logits.reshape(-1, logits.size(-1))
        return logits
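    # The contrastive objective below is the usual InfoNCE form: for each
    # frame, -log( exp(sim(src, pos)/t) / sum_k exp(sim(src, k)/t) ), where k
    # ranges over the positive and the sampled negatives and t = logit_temp.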
    def CTL(self, x, y, negs, mask=None):
        # Step 1: compute cosine similarity, shape (B * T, n_negatives + 1)
        logits = self.compute_preds(x, y, negs)

        # Step 2: the positive is always column 0, so the target is 0 for
        # every row, shape (B * T,)
        target = x.new_zeros(x.size(0) * x.size(1), dtype=torch.long)

        # Step 3: compute the CTL loss, ignoring padded frames
        if mask is not None:
            normalize_length = mask.sum()
            bz, sz = mask.size(0), mask.size(-1)
            mask = mask.squeeze(1).reshape(bz * sz).eq(0)
            ce = F.cross_entropy(logits, target, reduction='none')
            loss = ce.masked_fill(mask, 0).sum() / normalize_length
        else:
            loss = F.cross_entropy(logits, target)
        return loss
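

# A minimal, self-contained sketch of the contrastive objective computed by
# CTL()/compute_preds()/sample_negatives() above, written with plain torch ops
# on random features. The sizes and the simplified negative sampling below are
# illustrative assumptions, not values or code taken from any WeNet config.
if __name__ == "__main__":
    B, T, C, N = 2, 5, 8, 4  # batch, frames, feature dim, negatives per frame
    logit_temp = 0.1
    src = torch.randn(B, T, C)  # stands in for the chunk-mode encoder output
    tgt = torch.randn(B, T, C)  # stands in for the full-context output
    # Draw N negative frames per position from the full-context output of the
    # same utterance (no length masking here, unlike sample_negatives()).
    neg_idxs = torch.randint(0, T, (B, N * T))
    flat = tgt.reshape(-1, C)
    negs = flat[(neg_idxs + torch.arange(B).unsqueeze(1) * T).view(-1)]
    negs = negs.view(B, T, N, C).permute(2, 0, 1, 3)  # (N, B, T, C)
    # Cosine similarity against the positive (column 0) and the negatives,
    # then a temperature-scaled softmax cross-entropy toward column 0.
    targets = torch.cat([tgt.unsqueeze(0), negs], dim=0)  # (N + 1, B, T, C)
    logits = torch.cosine_similarity(src.float(), targets.float(), dim=-1)
    logits = (logits / logit_temp).transpose(0, 2).transpose(0, 1)
    logits = logits.reshape(-1, N + 1)  # (B * T, N + 1)
    target = torch.zeros(B * T, dtype=torch.long)
    print("toy CTL loss:", F.cross_entropy(logits, target).item())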