# OSUM / wenet / whisper / whisper_with_clap.py
# Author: tomxxie — adapted for ZeroGPU (适配zeroGPU), commit 568e264
# Copyright (c) 2023 Wenet Community. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from [Whisper](https://github.com/openai/whisper)
import torch
from typing import Tuple, Dict, List
from torch import nn
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.ctc import CTC
from wenet.transformer.encoder import TransformerEncoder
from wenet.transformer.decoder import TransformerDecoder
from wenet.utils.common import IGNORE_ID, add_whisper_tokens, th_accuracy
class Whisper(ASRModel):
    """Whisper-style attention ASR model augmented with a frozen CLAP branch.

    On top of the standard wenet ``ASRModel`` (encoder/decoder/CTC), this
    variant loads a pretrained CLAP audio model (kept frozen) plus a mapper
    network — a linear projection followed by an ``nn.TransformerEncoder`` —
    and a learned prefix (``prefix_const``) intended to condition decoding
    on CLAP embeddings.
    """

    # Default location of the pretrained CLAP checkpoint.  Kept as the
    # historical hard-coded cluster path for backward compatibility; pass
    # ``clap_checkpoint`` to ``__init__`` to override it.
    DEFAULT_CLAP_CHECKPOINT = (
        "/home/work_nfs11/wjtian/work_space/wenet_whisper_finetune/"
        "examples/wenetspeech/whisper/pretrain_ckpt/clap-htsat-unfused")

    def __init__(
        self,
        vocab_size: int,
        encoder: TransformerEncoder,
        decoder: TransformerDecoder,
        ctc: CTC = None,
        ctc_weight: float = 0.5,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        special_tokens: dict = None,
        clap_checkpoint: str = None,
    ):
        """Build the model.

        Args:
            vocab_size: output vocabulary size.
            encoder: acoustic encoder module.
            decoder: autoregressive token decoder.
            ctc: optional CTC head (may be ``None``).
            ctc_weight: interpolation weight of the CTC loss.
            ignore_id: padding/ignore label id.
            reverse_weight: must be 0.0 — Whisper has no right-to-left decoder.
            lsm_weight: label-smoothing weight.
            length_normalized_loss: normalize loss by sequence length.
            special_tokens: Whisper special-token table; must contain
                ``"sot"`` and ``"eot"``.
            clap_checkpoint: path/name of the pretrained CLAP checkpoint;
                defaults to ``DEFAULT_CLAP_CHECKPOINT``.
        """
        super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
                         ignore_id, reverse_weight, lsm_weight,
                         length_normalized_loss, special_tokens)
        # Whisper decoding is strictly left-to-right.
        assert reverse_weight == 0.0
        self.sos = special_tokens["sot"]
        self.eos = special_tokens["eot"]
        # Maximum decodable length, taken from the decoder's positional
        # embedding table.
        self.decode_maxlen = self.decoder.embed[1].max_len

        # --- CLAP prefix-mapper hyper-parameters ---
        self.clip_length = 40    # number of projected CLAP tokens
        self.prefix_length = 40  # number of learned prefix vectors
        num_layers = 12
        dim_embedding = 1024     # mapper / decoder embedding width
        dim_clip = 512           # CLAP embedding width
        nhead = 8
        # Mapper implemented with the stock nn.Transformer encoder.
        self.ttt = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=dim_embedding,
                                                     nhead=nhead),
            num_layers=num_layers,
        )
        # Projects one CLAP vector into ``clip_length`` embedding-sized tokens.
        self.linear = nn.Linear(dim_clip, self.clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(
            torch.randn(self.prefix_length, dim_embedding), requires_grad=True)

        # Lazy import: ``transformers`` is only required when this model
        # variant is instantiated.
        from transformers import ClapModel, AutoFeatureExtractor
        if clap_checkpoint is None:
            clap_checkpoint = self.DEFAULT_CLAP_CHECKPOINT
        # Load pretrained CLAP model and its feature extractor.
        self.model = ClapModel.from_pretrained(clap_checkpoint)
        self.processor = AutoFeatureExtractor.from_pretrained(clap_checkpoint)
        # Freeze CLAP — it serves as a fixed feature extractor.
        # NOTE(review): consider also calling self.model.eval() so dropout /
        # batch-norm layers stay deterministic during training — confirm the
        # intended training behavior before changing.
        for param in self.model.parameters():
            param.requires_grad = False

    # TODO(xcsong): time align
    def set_alignment_heads(self, dump: bytes):
        """Load cross-attention alignment heads (not implemented)."""
        raise NotImplementedError

    @property
    def is_multilingual(self) -> bool:
        """True when the vocabulary matches a multilingual Whisper model."""
        return self.vocab_size >= 51865

    @property
    def num_languages(self) -> int:
        """Number of language tokens in the vocabulary (Whisper convention)."""
        return self.vocab_size - 51765 - int(self.is_multilingual)

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        infos: Dict[str, List[str]],
    ) -> Tuple[torch.Tensor, float]:
        """Compute the attention (decoder cross-entropy) loss.

        Args:
            encoder_out: encoder output sequence.
            encoder_mask: mask for ``encoder_out``.
            ys_pad: padded target token ids.
            ys_pad_lens: lengths of the unpadded targets.
            infos: per-utterance metadata; must contain ``"tasks"`` and
                ``"langs"`` lists for Whisper special-token insertion.

        Returns:
            Tuple of (attention loss tensor, token accuracy).
        """
        prev_len = ys_pad.size(1)
        # Prepend Whisper special tokens (<sot>, language, task, ...).
        ys_in_pad, ys_out_pad = add_whisper_tokens(self.special_tokens,
                                                   ys_pad,
                                                   self.ignore_id,
                                                   tasks=infos['tasks'],
                                                   no_timestamp=True,
                                                   langs=infos['langs'],
                                                   use_prev=False)
        cur_len = ys_in_pad.size(1)
        # Account for the extra special tokens in the input lengths.
        ys_in_lens = ys_pad_lens + cur_len - prev_len

        # 1. Forward decoder (reverse decoder output unused: reverse_weight=0).
        decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask,
                                                     ys_in_pad, ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        return loss_att, acc_att