Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) 2023 Wenet Community. (authors: Xingchen Song) | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
# Modified from [Whisper](https://github.com/openai/whisper) | |
import torch | |
from typing import Tuple, Dict, List | |
from wenet.transformer.asr_model import ASRModel | |
from wenet.transformer.ctc import CTC | |
from wenet.transformer.encoder import TransformerEncoder | |
from wenet.transformer.decoder import TransformerDecoder | |
from wenet.utils.common import IGNORE_ID, add_whisper_tokens, th_accuracy | |
class Whisper(ASRModel):
    """Whisper-style attention encoder-decoder ASR model built on ``ASRModel``.

    Differences from the base ``ASRModel`` visible in this class:

    * the sos/eos ids are taken from ``special_tokens["sot"]`` and
      ``special_tokens["eot"]``;
    * the attention loss runs ``add_whisper_tokens`` on the targets (using the
      per-batch ``tasks``/``langs`` infos) before forwarding the decoder;
    * ``reverse_weight`` must be 0.0 (asserted in ``__init__``).
    """

    def __init__(
        self,
        vocab_size: int,
        encoder: TransformerEncoder,
        decoder: TransformerDecoder,
        ctc: CTC = None,
        ctc_weight: float = 0.5,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        special_tokens: dict = None,
    ):
        """Build the model.

        Args:
            vocab_size: output vocabulary size; also drives
                ``is_multilingual()`` and ``num_languages()``.
            encoder, decoder, ctc, ctc_weight, ignore_id, lsm_weight,
            length_normalized_loss: forwarded unchanged to ``ASRModel``.
            reverse_weight: must be 0.0.
            special_tokens: mapping that must contain at least the keys
                ``"sot"`` and ``"eot"``. Despite the ``None`` default it is
                required: it is subscripted unconditionally below.

        Raises:
            AssertionError: if ``reverse_weight`` is not 0.0.
        """
        super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
                         ignore_id, reverse_weight, lsm_weight,
                         length_normalized_loss, special_tokens)
        assert reverse_weight == 0.0
        # Use the Whisper special-token ids as sos/eos for this model.
        self.sos = special_tokens["sot"]
        self.eos = special_tokens["eot"]
        # Maximum decoding length, read from the decoder's embedding stack.
        # NOTE(review): assumes decoder.embed[1] is the positional-encoding
        # module exposing ``max_len`` — confirm against TransformerDecoder.
        self.decode_maxlen = self.decoder.embed[1].max_len

    # TODO(xcsong): time align
    def set_alignment_heads(self, dump: bytes):
        """Not implemented yet (time alignment is still a TODO).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError

    def is_multilingual(self):
        """Return True when ``vocab_size`` is at least 51865.

        51865 is the vocabulary-size threshold this class uses to treat the
        model as multilingual.
        """
        return self.vocab_size >= 51865

    def num_languages(self):
        """Return the number of language tokens implied by ``vocab_size``.

        Computed as ``vocab_size - 51765 - int(is_multilingual())``.
        """
        # ``is_multilingual`` is a regular method, so it must be called:
        # ``int(self.is_multilingual)`` would try to convert the bound
        # method object itself and raise TypeError on every call.
        return self.vocab_size - 51765 - int(self.is_multilingual())

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        infos: Dict[str, List[str]],
    ) -> Tuple[torch.Tensor, float]:
        """Compute the attention-decoder loss and token accuracy.

        Args:
            encoder_out: encoder output, passed to the decoder.
            encoder_mask: encoder mask, passed to the decoder.
            ys_pad: padded target token ids; dimension 1 is the padded length.
            ys_pad_lens: per-utterance target lengths.
            infos: must provide ``'tasks'`` and ``'langs'``, which are
                forwarded to ``add_whisper_tokens``.

        Returns:
            ``(loss_att, acc_att)``: ``criterion_att`` applied to the decoder
            output against the rebuilt targets, and ``th_accuracy`` computed
            with ``ignore_label=self.ignore_id``.
        """
        prev_len = ys_pad.size(1)
        # Build decoder inputs/targets with Whisper special tokens
        # (no_timestamp=True, use_prev=False).
        ys_in_pad, ys_out_pad = add_whisper_tokens(self.special_tokens,
                                                   ys_pad,
                                                   self.ignore_id,
                                                   tasks=infos['tasks'],
                                                   no_timestamp=True,
                                                   langs=infos['langs'],
                                                   use_prev=False)
        # Grow every target length by the number of columns added above.
        # NOTE(review): assumes add_whisper_tokens adds the same number of
        # tokens to every utterance in the batch — confirm.
        cur_len = ys_in_pad.size(1)
        ys_in_lens = ys_pad_lens + cur_len - prev_len

        # 1. Forward decoder; only the first output is used here.
        decoder_out, _, _ = self.decoder(encoder_out, encoder_mask,
                                         ys_in_pad, ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        return loss_att, acc_att