# Copyright (c) 2023 Wenet Community. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from [Whisper](https://github.com/openai/whisper)

import torch

from typing import Dict, List, Optional, Tuple

from torch import nn

from wenet.transformer.asr_model import ASRModel
from wenet.transformer.ctc import CTC
from wenet.transformer.encoder import TransformerEncoder
from wenet.transformer.decoder import TransformerDecoder
from wenet.utils.common import IGNORE_ID, add_whisper_tokens, th_accuracy


class Whisper(ASRModel):

    def __init__(
        self,
        vocab_size: int,
        encoder: TransformerEncoder,
        decoder: TransformerDecoder,
        ctc: Optional[CTC] = None,
        ctc_weight: float = 0.5,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        special_tokens: Optional[dict] = None,
    ):
        super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
                         ignore_id, reverse_weight, lsm_weight,
                         length_normalized_loss, special_tokens)
        assert reverse_weight == 0.0
        self.sos = special_tokens["sot"]
        self.eos = special_tokens["eot"]
        self.decode_maxlen = self.decoder.embed[1].max_len

        # Add CLAP: components for mapping a CLAP audio embedding into a
        # sequence of prefix embeddings (a ClipCap-style mapping network).
        self.clip_length = 40
        self.prefix_length = 40
        num_layers = 12
        dim_embedding = 1024
        dim_clip = 512
        # Use nn.TransformerEncoder as the mapping network.
        nhead = 8
        self.ttt = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=dim_embedding,
                                                     nhead=nhead),
            num_layers=num_layers)
        # Project the CLAP embedding (dim_clip) to clip_length pseudo-tokens.
        self.linear = nn.Linear(dim_clip, self.clip_length * dim_embedding)
        # Learned constant prefix, concatenated with the projected tokens.
        self.prefix_const = nn.Parameter(torch.randn(self.prefix_length,
                                                     dim_embedding),
                                         requires_grad=True)

        from transformers import ClapModel, AutoFeatureExtractor

        # Load the pretrained CLAP model and feature extractor, then freeze
        # all CLAP parameters.
        self.model = ClapModel.from_pretrained(
            "/home/work_nfs11/wjtian/work_space/wenet_whisper_finetune/examples/wenetspeech/whisper/pretrain_ckpt/clap-htsat-unfused")
        self.processor = AutoFeatureExtractor.from_pretrained(
            "/home/work_nfs11/wjtian/work_space/wenet_whisper_finetune/examples/wenetspeech/whisper/pretrain_ckpt/clap-htsat-unfused")
        for param in self.model.parameters():
            param.requires_grad = False

    # TODO(xcsong): time align
    def set_alignment_heads(self, dump: bytes):
        raise NotImplementedError

    @property
    def is_multilingual(self):
        return self.vocab_size >= 51865

    @property
    def num_languages(self):
        return self.vocab_size - 51765 - int(self.is_multilingual)

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        infos: Dict[str, List[str]],
    ) -> Tuple[torch.Tensor, float]:
        prev_len = ys_pad.size(1)
        ys_in_pad, ys_out_pad = add_whisper_tokens(self.special_tokens,
                                                   ys_pad,
                                                   self.ignore_id,
                                                   tasks=infos['tasks'],
                                                   no_timestamp=True,
                                                   langs=infos['langs'],
                                                   use_prev=False)
        cur_len = ys_in_pad.size(1)
        ys_in_lens = ys_pad_lens + cur_len - prev_len

        # 1. Forward decoder
        decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask,
                                                     ys_in_pad, ys_in_lens)
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        return loss_att, acc_att
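

# ---------------------------------------------------------------------------
# Minimal sketch (not part of the original module): how the prefix-mapping
# components defined in __init__ above can be wired together. The forward
# wiring is not shown in this file, so the data flow below is an assumption
# based on the ClipCap pattern; all names are local to the sketch, and it
# uses batch_first=True plus fewer layers purely to keep the check quick.
if __name__ == "__main__":
    dim_clip, dim_embedding = 512, 1024        # matches the sizes in __init__
    clip_length, prefix_length = 40, 40

    linear = nn.Linear(dim_clip, clip_length * dim_embedding)
    prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding))
    mapper = nn.TransformerEncoder(
        nn.TransformerEncoderLayer(d_model=dim_embedding, nhead=8,
                                   batch_first=True),
        num_layers=2)

    batch = 3
    # Stand-in for a frozen CLAP audio embedding, e.g. the output of
    # ClapModel.get_audio_features (shape: batch x 512).
    clap_emb = torch.randn(batch, dim_clip)
    # Project to clip_length pseudo-tokens, append the learned constant
    # prefix, then refine the whole sequence with the mapping network.
    prefix = linear(clap_emb).view(batch, clip_length, dim_embedding)
    prefix = torch.cat(
        [prefix, prefix_const.unsqueeze(0).expand(batch, -1, -1)], dim=1)
    out = mapper(prefix)
    print(out.shape)  # torch.Size([3, 80, 1024])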