# wenet/ctl_model/asr_model_ctl.py
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
#               2023 NetEase Inc. (authors: Yuting Yang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet) and
# fairseq(https://github.com/facebookresearch/fairseq)

from typing import Dict, Optional

import torch
import torch.nn.functional as F

from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import TransformerDecoder
from wenet.ctl_model.encoder import TransformerEncoder
from wenet.transformer.asr_model import ASRModel
from wenet.utils.common import IGNORE_ID


class CTLModel(ASRModel):
    """
    Implementation of Interspeech 2023 paper:
    'Enhancing the Unified Streaming and Non-streaming Model
    with Contrastive Learning'
    https://arxiv.org/abs/2306.00755
    """

    def __init__(
        self,
        vocab_size: int,
        encoder: TransformerEncoder,
        decoder: TransformerDecoder,
        ctc: CTC,
        ctc_weight: float = 0.5,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        logit_temp: float = 0.1,
        n_negatives: int = 0,
        ctl_weight: float = 1,
        special_tokens: dict = None,
    ):
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight

        super().__init__(vocab_size,
                         encoder,
                         decoder,
                         ctc,
                         ctc_weight,
                         ignore_id,
                         reverse_weight,
                         lsm_weight,
                         length_normalized_loss,
                         special_tokens=special_tokens)

        # Hyper-parameters of the contrastive (CTL) loss
        self.n_negatives = n_negatives
        self.ctl_weight = ctl_weight
        self.logit_temp = logit_temp
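
    # A hypothetical training-config snippet for the CTL-specific options
    # (the key names below simply mirror the constructor arguments; the
    # exact keys in a given recipe may differ):
    #
    #   model_conf:
    #       ctc_weight: 0.3
    #       ctl_weight: 1.0
    #       n_negatives: 100
    #       logit_temp: 0.1
    #
    # n_negatives is how many distractor frames are drawn per target frame in
    # sample_negatives(), and logit_temp scales the cosine-similarity logits
    # before the cross-entropy in CTL().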

    @torch.jit.unused
    def forward(
        self,
        batch: dict,
        device: torch.device,
    ) -> Dict[str, Optional[torch.Tensor]]:
        speech = batch['feats'].to(device)
        speech_lengths = batch['feats_lengths'].to(device)
        text = batch['target'].to(device)
        text_lengths = batch['target_lengths'].to(device)

        # Run the same utterances through the full-context (non-streaming)
        # path and the chunk-masked (streaming) path of the dual encoder.
        loss_full, encoder_out_full, _, _ = self.forward_full(
            speech, speech_lengths, text, text_lengths)
        loss_chunk, encoder_out, lens_chunk, encoder_mask = self.forward_chunk(
            speech, speech_lengths, text, text_lengths)

        # Contrastive loss: full-context frames are the targets, chunk-mode
        # frames are the source, plus sampled negative (distractor) frames.
        ctl_loss = 0.0
        if self.ctl_weight > 0 and self.n_negatives > 0:
            num = encoder_out_full.size(1)
            targets = encoder_out_full
            src = encoder_out
            negs, negs_idxs = self.sample_negatives(targets,
                                                    targets.size(1),
                                                    speech_lengths=lens_chunk)
            ctl_loss = self.CTL(src, targets, negs, encoder_mask)

        loss = loss_full + loss_chunk + self.ctl_weight * ctl_loss
        return {
            "loss": loss,
            "loss_full": loss_full,
            "loss_chunk": loss_chunk,
            "loss_ctl": ctl_loss
        }

    def forward_full(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ):
        """Full-context mode: frontend + encoder + decoder + loss.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
        # 1. Encoder (full-context pass of the dual encoder)
        encoder_out, encoder_mask = self.encoder.forward_full(
            speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2a. Attention-decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
                                                    text, text_lengths)
        else:
            loss_att = None

        # 2b. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
                                text_lengths)
        else:
            loss_ctc = None

        # 3. Hybrid loss: weighted sum of the CTC and attention losses
        if loss_ctc is None:
            loss = loss_att
        elif loss_att is None:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc[0] + (1 -
                                                    self.ctc_weight) * loss_att
        return loss, encoder_out, encoder_out_lens, encoder_mask
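
    # Note: with a typical wenet encoder, encoder_out is (B, T', D) after
    # subsampling, encoder_mask is (B, 1, T') and encoder_out_lens is (B,),
    # the number of valid frames per utterance. forward() feeds the
    # chunk-mode lengths into sample_negatives() as speech_lengths.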

    def forward_chunk(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
    ):
        """Chunk-based (streaming) mode: frontend + encoder + decoder + loss.

        Args:
            speech: (Batch, Length, ...)
            speech_lengths: (Batch, )
            text: (Batch, Length)
            text_lengths: (Batch,)
        """
        assert text_lengths.dim() == 1, text_lengths.shape
        # Check that batch_size is unified
        assert (speech.shape[0] == speech_lengths.shape[0] == text.shape[0] ==
                text_lengths.shape[0]), (speech.shape, speech_lengths.shape,
                                         text.shape, text_lengths.shape)
        # 1. Encoder (default chunk-masked pass of the dual encoder)
        encoder_out, encoder_mask = self.encoder(speech, speech_lengths)
        encoder_out_lens = encoder_mask.squeeze(1).sum(1)

        # 2a. Attention-decoder branch
        if self.ctc_weight != 1.0:
            loss_att, acc_att = self._calc_att_loss(encoder_out, encoder_mask,
                                                    text, text_lengths)
        else:
            loss_att = None

        # 2b. CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc = self.ctc(encoder_out, encoder_out_lens, text,
                                text_lengths)
        else:
            loss_ctc = None

        # 3. Hybrid loss: weighted sum of the CTC and attention losses
        if loss_ctc is None:
            loss = loss_att
        elif loss_att is None:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc[0] + (1 -
                                                    self.ctc_weight) * loss_att
        return loss, encoder_out, encoder_out_lens, encoder_mask
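
    # forward_chunk is identical to forward_full except for the encoder call:
    # self.encoder(...) runs the encoder in its chunk-masked (streaming)
    # configuration, while self.encoder.forward_full(...) above attends over
    # the whole utterance.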

    def sample_negatives(self, y, num, padding_count=0, speech_lengths=None):
        """Sample self.n_negatives distractor frames for every target frame.

        Adapted from fairseq (see the attribution at the top of this file).
        """
        if self.n_negatives == 0:
            return y.new(0)

        bsz, tsz, fsz = y.shape
        y = y.reshape(-1, fsz)  # BTC => (BxT)C

        # FIXME: what happens if padding_count is specified?
        high = tsz - (padding_count or 0)
        with torch.no_grad():
            assert high > 1, f"{bsz,tsz,fsz}"
            if self.n_negatives > 0:
                # Each target position is repeated n_negatives times.
                tszs = (torch.arange(num).unsqueeze(-1).expand(
                    -1, self.n_negatives).flatten())
                if speech_lengths is not None:
                    # Sample only within the valid (unpadded) frames of each
                    # utterance.
                    neg_idxs = [
                        torch.randint(low=0,
                                      high=speech_lengths[i].item() - 1,
                                      size=(1, self.n_negatives * tsz))
                        for i in range(len(speech_lengths))
                    ]
                    neg_idxs = torch.cat(neg_idxs).reshape(
                        bsz, self.n_negatives * tsz)
                else:
                    neg_idxs = torch.randint(low=0,
                                             high=num - 1,
                                             size=(bsz,
                                                   self.n_negatives * tsz))
                # Shift indices at or above the target position so a frame is
                # never drawn as its own negative.
                neg_idxs[neg_idxs >= tszs] += 1

        if self.n_negatives > 0:
            # Offset into the flattened (B*T, C) tensor, one block per batch.
            neg_idxs = neg_idxs + (torch.arange(bsz).unsqueeze(1) * high)

        negs = y[neg_idxs.view(-1)]
        negs = negs.contiguous().view(bsz, num, self.n_negatives,
                                      fsz).permute(2, 0, 1, 3)  # to NxBxTxC
        return negs, neg_idxs
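
    # Shape walkthrough for sample_negatives (illustrative numbers): with
    # B=2 utterances, T=4 encoder frames, C=256 dims and n_negatives=3,
    # y is (2, 4, 256) and the returned negs tensor is (3, 2, 4, 256),
    # i.e. (n_negatives, batch, frames, dim), so that compute_preds() can
    # concatenate the positive on a new leading axis.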

    def compute_preds(self, x, y, negatives):
        """Cosine-similarity logits of x against the positive y and negatives."""
        # A negative that happens to be identical to the positive gives no
        # learning signal; mark it so it can be masked out below.
        neg_is_pos = (y == negatives).all(-1)
        y = y.unsqueeze(0)
        targets = torch.cat([y, negatives], dim=0)  # (n_negatives + 1, B, T, C)

        logits = torch.cosine_similarity(x.float(), targets.float(), dim=-1)
        logits = logits / self.logit_temp
        logits = logits.type_as(x)

        if neg_is_pos.any():
            if not hasattr(self, "_inftensor"):
                self._inftensor = float("-inf")
            # logits[1:] = index_put(logits[1:], neg_is_pos, self._inftensor)
            logits[1:][neg_is_pos] = self._inftensor

        # (n_negatives + 1, B, T) -> (B * T, n_negatives + 1)
        logits = logits.transpose(0, 2)
        logits = logits.transpose(0, 1)
        logits = logits.reshape(-1, logits.size(-1))
        return logits
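
    # CTL below is an InfoNCE-style contrastive loss: for every frame the
    # model must pick the full-context frame (column 0 of the logits) over
    # n_negatives distractors, so the cross-entropy target is all zeros.
    # Padded frames are excluded via the encoder mask.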

    def CTL(self, x, y, negs, mask=None):
        # Step 1: cosine-similarity logits, shape (B * T, n_negatives + 1)
        logits = self.compute_preds(x, y, negs)

        # Step 2: the positive is always at index 0, so the target is zeros
        # of shape (B * T,)
        target = x.new_zeros(x.size(0) * x.size(1), dtype=torch.long)

        # Step 3: cross-entropy, optionally ignoring padded frames and
        # normalizing by the number of valid frames
        if mask is not None:
            normalize_length = mask.sum()
            bz, sz = mask.size(0), mask.size(-1)
            mask = mask.squeeze(1).reshape(bz * sz).eq(0)
            ce = F.cross_entropy(logits, target, reduction='none')
            loss = ce.masked_fill(mask, 0).sum() / normalize_length
        else:
            loss = F.cross_entropy(logits, target)
        return loss
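

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original model): the block below
# re-creates the contrastive (CTL) computation on random tensors so the
# expected shapes can be checked in isolation. All sizes are invented for
# the demo and the tensors stand in for real encoder outputs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    B, T, C, N = 2, 5, 8, 3  # batch, frames, feature dim, negatives
    logit_temp = 0.1

    src = torch.randn(B, T, C)      # stand-in for the chunk encoder output
    tgt = torch.randn(B, T, C)      # stand-in for the full-context output
    negs = torch.randn(N, B, T, C)  # N negatives per (batch, frame) position

    # Same math as compute_preds(): positive first, then the negatives.
    targets = torch.cat([tgt.unsqueeze(0), negs], dim=0)  # (N + 1, B, T, C)
    logits = torch.cosine_similarity(src.float(), targets.float(), dim=-1)
    logits = (logits / logit_temp).transpose(0, 2).transpose(0, 1)
    logits = logits.reshape(-1, logits.size(-1))  # (B * T, N + 1)

    # Index 0 is the positive pair, so the target is all zeros, exactly as
    # in CTL() above (without the padding mask).
    target = torch.zeros(B * T, dtype=torch.long)
    loss = F.cross_entropy(logits, target)
    print(logits.shape, loss.item())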