thanhduycao commited on
Commit
8edb6e9
1 Parent(s): 92e64e9

Upload model_handling.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model_handling.py +163 -0
model_handling.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Wav2Vec2PreTrainedModel, Wav2Vec2Model
2
+ from torch import nn
3
+ import warnings
4
+ import torch
5
+ from transformers.modeling_outputs import CausalLMOutput
6
+ from collections import OrderedDict
7
+
8
+ _HIDDEN_STATES_START_POSITION = 2
9
+
10
+
11
+ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
12
+ def __init__(self, config):
13
+ super().__init__(config)
14
+
15
+ self.wav2vec2 = Wav2Vec2Model(config)
16
+ self.dropout = nn.Dropout(config.final_dropout)
17
+
18
+ self.feature_transform = nn.Sequential(OrderedDict([
19
+ ('linear1', nn.Linear(config.hidden_size, config.hidden_size)),
20
+ ('bn1', nn.BatchNorm1d(config.hidden_size)),
21
+ ('activation1', nn.LeakyReLU()),
22
+ ('drop1', nn.Dropout(config.final_dropout)),
23
+ ('linear2', nn.Linear(config.hidden_size, config.hidden_size)),
24
+ ('bn2', nn.BatchNorm1d(config.hidden_size)),
25
+ ('activation2', nn.LeakyReLU()),
26
+ ('drop2', nn.Dropout(config.final_dropout)),
27
+ ('linear3', nn.Linear(config.hidden_size, config.hidden_size)),
28
+ ('bn3', nn.BatchNorm1d(config.hidden_size)),
29
+ ('activation3', nn.LeakyReLU()),
30
+ ('drop3', nn.Dropout(config.final_dropout))
31
+ ]))
32
+
33
+ if config.vocab_size is None:
34
+ raise ValueError(
35
+ f"You are trying to instantiate {self.__class__} with a configuration that "
36
+ "does not define the vocabulary size of the language model head. Please "
37
+ "instantiate the model as follows: `Wav2Vec2ForCTC.from_pretrained(..., vocab_size=vocab_size)`. "
38
+ "or define `vocab_size` of your model's configuration."
39
+ )
40
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
41
+
42
+ self.is_wav2vec_freeze = False
43
+
44
+ # Initialize weights and apply final processing
45
+ self.post_init()
46
+
47
+ def freeze_feature_extractor(self):
48
+ """
49
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
50
+ not be updated during training.
51
+ """
52
+ warnings.warn(
53
+ "The method `freeze_feature_extractor` is deprecated and will be removed in Transformers v5."
54
+ "Please use the equivalent `freeze_feature_encoder` method instead.",
55
+ FutureWarning,
56
+ )
57
+ self.freeze_feature_encoder()
58
+
59
+ def freeze_feature_encoder(self):
60
+ """
61
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
62
+ not be updated during training.
63
+ """
64
+ self.wav2vec2.feature_extractor._freeze_parameters()
65
+
66
+ def freeze_wav2vec(self, is_freeze=True):
67
+ """
68
+ Calling this function will disable the gradient computation for the feature extractor so that its parameter
69
+ will not be updated during training.
70
+ """
71
+ if is_freeze:
72
+ self.is_wav2vec_freeze = True
73
+ for param in self.wav2vec2.parameters():
74
+ param.requires_grad = False
75
+ else:
76
+ self.is_wav2vec_freeze = False
77
+ for param in self.wav2vec2.parameters():
78
+ param.requires_grad = True
79
+ self.freeze_feature_encoder()
80
+
81
+ model_total_params = sum(p.numel() for p in self.parameters())
82
+ model_total_params_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
83
+ print("model_total_params: {}\nmodel_total_params_trainable: {}".format(model_total_params,
84
+ model_total_params_trainable))
85
+
86
+ def forward(
87
+ self,
88
+ input_values,
89
+ attention_mask=None,
90
+ output_attentions=None,
91
+ output_hidden_states=None,
92
+ return_dict=None,
93
+ labels=None,
94
+ ):
95
+ r"""
96
+ labels (`torch.LongTensor` of shape `(batch_size, target_length)`, *optional*):
97
+ Labels for connectionist temporal classification. Note that `target_length` has to be smaller or equal to
98
+ the sequence length of the output logits. Indices are selected in `[-100, 0, ..., config.vocab_size - 1]`.
99
+ All labels set to `-100` are ignored (masked), the loss is only computed for labels in `[0, ...,
100
+ config.vocab_size - 1]`.
101
+ """
102
+
103
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
104
+
105
+ outputs = self.wav2vec2(
106
+ input_values,
107
+ attention_mask=attention_mask,
108
+ output_attentions=output_attentions,
109
+ output_hidden_states=output_hidden_states,
110
+ return_dict=return_dict,
111
+ )
112
+
113
+ hidden_states = outputs[0]
114
+ hidden_states = self.dropout(hidden_states)
115
+
116
+ B, T, F = hidden_states.size()
117
+ hidden_states = hidden_states.view(B * T, F)
118
+
119
+ hidden_states = self.feature_transform(hidden_states)
120
+
121
+ hidden_states = hidden_states.view(B, T, F)
122
+
123
+ logits = self.lm_head(hidden_states)
124
+
125
+ loss = None
126
+ if labels is not None:
127
+
128
+ if labels.max() >= self.config.vocab_size:
129
+ raise ValueError(f"Label values must be <= vocab_size: {self.config.vocab_size}")
130
+
131
+ # retrieve loss input_lengths from attention_mask
132
+ attention_mask = (
133
+ attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
134
+ )
135
+ input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
136
+
137
+ # assuming that padded tokens are filled with -100
138
+ # when not being attended to
139
+ labels_mask = labels >= 0
140
+ target_lengths = labels_mask.sum(-1)
141
+ flattened_targets = labels.masked_select(labels_mask)
142
+
143
+ # ctc_loss doesn't support fp16
144
+ log_probs = nn.functional.log_softmax(logits, dim=-1, dtype=torch.float32).transpose(0, 1)
145
+
146
+ with torch.backends.cudnn.flags(enabled=False):
147
+ loss = nn.functional.ctc_loss(
148
+ log_probs,
149
+ flattened_targets,
150
+ input_lengths,
151
+ target_lengths,
152
+ blank=self.config.pad_token_id,
153
+ reduction=self.config.ctc_loss_reduction,
154
+ zero_infinity=self.config.ctc_zero_infinity,
155
+ )
156
+
157
+ if not return_dict:
158
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
159
+ return ((loss,) + output) if loss is not None else output
160
+
161
+ return CausalLMOutput(
162
+ loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
163
+ )