# Copyright (c) 2023 Wenet Community. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Modified from [Whisper](https://github.com/openai/whisper)

import torch

from typing import Dict, List, Optional, Tuple

from torch import nn

from wenet.transformer.asr_model import ASRModel
from wenet.transformer.ctc import CTC
from wenet.transformer.encoder import TransformerEncoder
from wenet.transformer.decoder import TransformerDecoder
from wenet.utils.common import IGNORE_ID, add_whisper_tokens, th_accuracy


class Whisper(ASRModel):

    def __init__(
        self,
        vocab_size: int,
        encoder: TransformerEncoder,
        decoder: TransformerDecoder,
        ctc: Optional[CTC] = None,
        ctc_weight: float = 0.5,
        ignore_id: int = IGNORE_ID,
        reverse_weight: float = 0.0,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        special_tokens: Optional[dict] = None,
    ):
        super().__init__(vocab_size, encoder, decoder, ctc, ctc_weight,
                         ignore_id, reverse_weight, lsm_weight,
                         length_normalized_loss, special_tokens)
        assert reverse_weight == 0.0
        assert special_tokens is not None
        self.sos = special_tokens["sot"]
        self.eos = special_tokens["eot"]
        self.decode_maxlen = self.decoder.embed[1].max_len

        # Add CLAP: modules that project a CLAP audio embedding into prefix embeddings.
        self.clip_length = 40  # number of prefix tokens projected from the CLAP embedding
        self.prefix_length = 40  # number of learnable prefix constants
        num_layers = 12
        dim_embedding = 1024
        dim_clip = 512  # dimensionality of the CLAP audio embedding
        # Use PyTorch's built-in nn.TransformerEncoder for the mapping network.
        # Note: with the default batch_first=False it expects (seq, batch, dim) inputs.
        nhead = 8
        self.ttt = nn.TransformerEncoder(
            encoder_layer=nn.TransformerEncoderLayer(d_model=dim_embedding,
                                                     nhead=nhead),
            num_layers=num_layers)
        self.linear = nn.Linear(dim_clip, self.clip_length * dim_embedding)
        self.prefix_const = nn.Parameter(torch.randn(self.prefix_length,
                                                     dim_embedding),
                                         requires_grad=True)

        from transformers import ClapModel, AutoFeatureExtractor
        # Load the pretrained CLAP model and its feature extractor.
        self.model = ClapModel.from_pretrained(
            "/home/work_nfs11/wjtian/work_space/wenet_whisper_finetune/examples/wenetspeech/whisper/pretrain_ckpt/clap-htsat-unfused")
        self.processor = AutoFeatureExtractor.from_pretrained(
            "/home/work_nfs11/wjtian/work_space/wenet_whisper_finetune/examples/wenetspeech/whisper/pretrain_ckpt/clap-htsat-unfused")
        # Freeze CLAP; it is used only as a fixed feature extractor.
        for param in self.model.parameters():
            param.requires_grad = False

    # TODO(xcsong): time align
    def set_alignment_heads(self, dump: bytes):
        raise NotImplementedError

    @property
    def is_multilingual(self):
        return self.vocab_size >= 51865

    @property
    def num_languages(self):
        return self.vocab_size - 51765 - int(self.is_multilingual)

    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_mask: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
        infos: Dict[str, List[str]],
    ) -> Tuple[torch.Tensor, float]:
        prev_len = ys_pad.size(1)
        ys_in_pad, ys_out_pad = add_whisper_tokens(self.special_tokens,
                                                   ys_pad,
                                                   self.ignore_id,
                                                   tasks=infos['tasks'],
                                                   no_timestamp=True,
                                                   langs=infos['langs'],
                                                   use_prev=False)
        cur_len = ys_in_pad.size(1)
        ys_in_lens = ys_pad_lens + cur_len - prev_len

        # 1. Forward decoder
        decoder_out, r_decoder_out, _ = self.decoder(encoder_out, encoder_mask,
                                                     ys_in_pad, ys_in_lens)

        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        return loss_att, acc_att
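
    # Illustrative sketch (hypothetical, not part of the original file): the
    # CLAP prefix modules defined in __init__ (self.linear, self.prefix_const,
    # self.ttt) are not exercised in this excerpt. One plausible way to compose
    # them, assuming a ClipCap-style mapping network, is shown below; the method
    # name and tensor shapes are assumptions for illustration only.
    def _clap_prefix_sketch(self, clap_embed: torch.Tensor) -> torch.Tensor:
        # clap_embed: (batch, dim_clip) pooled CLAP audio embedding
        batch = clap_embed.size(0)
        # Project to clip_length token embeddings: (batch, clip_length, dim_embedding)
        prefix = self.linear(clap_embed).view(batch, self.clip_length, -1)
        # Append the learnable prefix constants:
        # (batch, clip_length + prefix_length, dim_embedding)
        const = self.prefix_const.unsqueeze(0).expand(batch, -1, -1)
        x = torch.cat((prefix, const), dim=1)
        # nn.TransformerEncoder expects (seq, batch, dim) with the default
        # batch_first=False, hence the transposes.
        out = self.ttt(x.transpose(0, 1)).transpose(0, 1)
        # Keep only the outputs at the prefix_const positions.
        return out[:, self.clip_length:]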