Sin2pi commited on
Commit
e2b33e8
·
verified ·
1 Parent(s): effaa91

Upload model2.ipynb

Browse files
Files changed (1) hide show
  1. model2.ipynb +1166 -0
model2.ipynb ADDED
@@ -0,0 +1,1166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "\n",
10
+ "import base64, gzip, evaluate, math, os, sys, time\n",
11
+ "import gzip, neologdn\n",
12
+ "from transformers.modeling_utils import PreTrainedModel \n",
13
+ "import collections\n",
14
+ "import copy\n",
15
+ "import functools\n",
16
+ "from functools import partial, wraps\n",
17
+ "from threading import Thread\n",
18
+ "import gc\n",
19
+ "import importlib.metadata\n",
20
+ "import inspect\n",
21
+ "import itertools\n",
22
+ "from sklearn.metrics import accuracy_score, precision_score, f1_score, recall_score\n",
23
+ "from torch import amp, Tensor, optim\n",
24
+ "from torch.utils.checkpoint import checkpoint\n",
25
+ "from contextlib import contextmanager\n",
26
+ "from dataclasses import dataclass\n",
27
+ "from transformers.models.whisper.modeling_whisper import WhisperPreTrainedModel\n",
28
+ "from transformers.models.whisper.generation_whisper import WhisperGenerationMixin\n",
29
+ "from transformers.optimization import Adafactor, AdafactorSchedule\n",
30
+ "from huggingface_hub import PyTorchModelHubMixin\n",
31
+ "from datasets import IterableDatasetDict, Audio, load_dataset, load_from_disk\n",
32
+ "import numpy as np\n",
33
+ "import torch, transformers, warnings\n",
34
+ "from typing import Dict, Iterable, Optional, Tuple, Union, List, Any, Type\n",
35
+ "import torch.nn.functional as F\n",
36
+ "from torch import Tensor, nn\n",
37
+ "import torchaudio, torchaudio.transforms as T\n",
38
+ "from transformers import Seq2SeqTrainer, TrainerCallback, Seq2SeqTrainingArguments, WhisperTokenizer, WhisperForConditionalGeneration, WhisperConfig, WhisperProcessor, WhisperFeatureExtractor, WhisperTokenizer, WhisperTokenizerFast\n",
39
+ "from whisper.decoding import decode as decode_function\n",
40
+ "from whisper.decoding import detect_language as detect_language_function\n",
41
+ "from whisper.transcribe import transcribe as transcribe_function\n",
42
+ "from torch.utils.tensorboard import SummaryWriter\n",
43
+ "\n",
44
+ "try:\n",
45
+ " from torch.nn.functional import scaled_dot_product_attention\n",
46
+ "\n",
47
+ " SDPA_AVAILABLE = True\n",
48
+ "except (ImportError, RuntimeError, OSError):\n",
49
+ " scaled_dot_product_attention = None\n",
50
+ " SDPA_AVAILABLE = False\n",
51
+ "\n",
52
+ "transformers.utils.logging.set_verbosity_error()\n",
53
+ "warnings.filterwarnings(action=\"ignore\")\n",
54
+ "warnings.warn = lambda *args,**kwargs: None\n",
55
+ "torch.autograd.set_detect_anomaly(True)\n",
56
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
57
+ "dtype = torch.float32\n",
58
+ "torch.set_default_dtype(dtype)"
59
+ ]
60
+ },
61
+ {
62
+ "cell_type": "code",
63
+ "execution_count": null,
64
+ "metadata": {},
65
+ "outputs": [],
66
+ "source": [
67
+ "### Model ###\n",
68
+ "\n",
69
+ "class LayerNorm(nn.Module):\n",
70
+ " def __init__(self, num_features, eps=1e-6):\n",
71
+ " super(LayerNorm, self).__init__()\n",
72
+ " self.gamma = nn.Parameter(torch.ones(num_features))\n",
73
+ " self.beta = nn.Parameter(torch.zeros(num_features))\n",
74
+ " self.eps = eps\n",
75
+ "\n",
76
+ " def forward(self, x):\n",
77
+ " mean = x.mean(dim=-1, keepdim=True)\n",
78
+ " std = x.std(dim=-1, keepdim=True)\n",
79
+ " x = (x - mean) / (std + self.eps)\n",
80
+ " return self.gamma * x + self.beta\n",
81
+ "\n",
82
+ "class Linear(nn.Module):\n",
83
+ " def __init__(self, in_features: int, out_features: int, dropout_rate = 0.001, use_batchnorm: bool = True, activation: str = 'relu'):\n",
84
+ " super(Linear, self).__init__()\n",
85
+ " self.linear = nn.Linear(in_features, out_features)\n",
86
+ " self.dropout = nn.Dropout(dropout_rate)\n",
87
+ " self.use_batchnorm = use_batchnorm\n",
88
+ " self.activation = activation\n",
89
+ "\n",
90
+ " if self.use_batchnorm:\n",
91
+ " self.batchnorm = nn.BatchNorm1d(out_features)\n",
92
+ " self.reset_parameters()\n",
93
+ "\n",
94
+ " def reset_parameters(self):\n",
95
+ " nn.init.kaiming_uniform_(self.linear.weight, nonlinearity=self.activation)\n",
96
+ " if self.linear.bias is not None:\n",
97
+ " nn.init.zeros_(self.linear.bias)\n",
98
+ "\n",
99
+ " def forward(self, x):\n",
100
+ " batch_size, seq_len, _ = x.size()\n",
101
+ " x = x.view(-1, x.size(-1)) \n",
102
+ " x = self.linear(x)\n",
103
+ "\n",
104
+ " if self.use_batchnorm:\n",
105
+ " x = self.batchnorm(x)\n",
106
+ "\n",
107
+ " x = self.apply_activation(x=x)\n",
108
+ " x = self.dropout(x)\n",
109
+ " x = x.view(batch_size, seq_len, -1) \n",
110
+ " \n",
111
+ " return x\n",
112
+ "\n",
113
+ " def apply_activation(self, x):\n",
114
+ " if self.activation == 'relu':\n",
115
+ " return F.relu(x)\n",
116
+ " elif self.activation == 'tanh':\n",
117
+ " return torch.tanh(x)\n",
118
+ " elif self.activation == 'sigmoid':\n",
119
+ " return torch.sigmoid(x)\n",
120
+ " else:\n",
121
+ " raise ValueError(f'Unsupported activation function: {self.activation}')\n",
122
+ "\n",
123
+ "class Conv1d(nn.Conv1d):\n",
124
+ " def __init__(self, *args, **kwargs):\n",
125
+ " super().__init__(*args, **kwargs)\n",
126
+ " self.reset_parameters()\n",
127
+ "\n",
128
+ " def reset_parameters(self):\n",
129
+ " nn.init.kaiming_uniform_(self.weight, nonlinearity='relu')\n",
130
+ " if self.bias is not None:\n",
131
+ " nn.init.zeros_(self.bias)\n",
132
+ "\n",
133
+ " def _conv_forward(self, x, weight, bias) -> Tensor:\n",
134
+ " weight = self.weight.to(x.dtype)\n",
135
+ " bias = None if self.bias is None else self.bias.to(x.dtype)\n",
136
+ " return super()._conv_forward(x, weight, bias)\n",
137
+ "\n",
138
+ "class BiasedCrossAttention(nn.Module):\n",
139
+ " def __init__(self, n_state, n_head, dropout_rate=0.001):\n",
140
+ " super().__init__()\n",
141
+ " self.n_head = n_head\n",
142
+ " self.n_state = n_state\n",
143
+ " self.head_dim = n_state // n_head\n",
144
+ "\n",
145
+ " self.query = nn.Linear(n_state, n_state)\n",
146
+ " self.key = nn.Linear(n_state, n_state, bias=False)\n",
147
+ " self.value = nn.Linear(n_state, n_state)\n",
148
+ " self.out = nn.Linear(n_state, n_state)\n",
149
+ "\n",
150
+ " self.bias = nn.Parameter(torch.zeros(n_head, 1, self.head_dim))\n",
151
+ " self.dropout = nn.Dropout(dropout_rate)\n",
152
+ " self.norm = LayerNorm(num_features=n_state)\n",
153
+ " \n",
154
+ " def forward(self, q, k, v, mask=None):\n",
155
+ " batch_size, seq_length, _ = q.size()\n",
156
+ "\n",
157
+ " q = self.query(q).view(batch_size, seq_length, self.n_head, self.head_dim)\n",
158
+ " k = self.key(k).view(batch_size, seq_length, self.n_head, self.head_dim)\n",
159
+ " v = self.value(v).view(batch_size, seq_length, self.n_head, self.head_dim)\n",
160
+ "\n",
161
+ " qk = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5) + self.bias\n",
162
+ " if mask is not None:\n",
163
+ " qk = qk.masked_fill(mask == 0, float('-inf'))\n",
164
+ "\n",
165
+ " w = F.softmax(qk, dim=-1)\n",
166
+ " w = self.dropout(w)\n",
167
+ "\n",
168
+ " out = (w @ v).transpose(1, 2).contiguous().view(batch_size, seq_length, -1)\n",
169
+ " out = self.norm(self.out(out) + q.view(batch_size, seq_length, -1))\n",
170
+ " return out\n",
171
+ "\n",
172
+ "class DynamicConvAttention(nn.Module):\n",
173
+ " def __init__(self, n_state, n_head, kernel_size=3, dropout_rate=0.001):\n",
174
+ " super().__init__()\n",
175
+ " self.n_state = n_state\n",
176
+ " self.n_head = n_head\n",
177
+ " self.kernel_size = kernel_size\n",
178
+ "\n",
179
+ " self.conv = nn.Conv1d(n_state, n_state, kernel_size, padding=kernel_size // 2, groups=n_head)\n",
180
+ " self.dropout = nn.Dropout(dropout_rate)\n",
181
+ "\n",
182
+ " self.query = nn.Linear(n_state, n_state)\n",
183
+ " self.key = nn.Linear(n_state, n_state, bias=False)\n",
184
+ " self.value = nn.Linear(n_state, n_state)\n",
185
+ " self.out_proj = nn.Linear(n_state, n_state)\n",
186
+ "\n",
187
+ " self.norm = LayerNorm(num_features=n_state)\n",
188
+ "\n",
189
+ " def forward(self, x):\n",
190
+ " batch_size, seq_len, embed_dim = x.size()\n",
191
+ " if embed_dim != self.n_state:\n",
192
+ " raise ValueError(f\"Expected embed_dim of {self.n_state}, but got {embed_dim}\")\n",
193
+ "\n",
194
+ " q = self.query(x)\n",
195
+ " k = self.key(x)\n",
196
+ " v = self.value(x)\n",
197
+ "\n",
198
+ " x = x.permute(0, 2, 1)\n",
199
+ " conv_out = self.conv(x)\n",
200
+ " conv_out = conv_out.permute(0, 2, 1)\n",
201
+ " conv_out = self.norm(conv_out)\n",
202
+ " conv_out = self.dropout(conv_out)\n",
203
+ "\n",
204
+ " attention_out = F.softmax(torch.matmul(q, k.transpose(-2, -1)) / (self.n_state ** 0.5), dim=-1)\n",
205
+ " attention_out = torch.matmul(attention_out, v)\n",
206
+ " \n",
207
+ " combined_out = conv_out + attention_out\n",
208
+ " combined_out = self.norm(combined_out)\n",
209
+ " \n",
210
+ " return self.out_proj(self.dropout(combined_out)) + x.permute(0, 2, 1)\n",
211
+ "\n",
212
+ "class HybridAttention(nn.Module):\n",
213
+ " def __init__(self, n_state, n_head, window_size=1, dropout_rate=0.001):\n",
214
+ " super().__init__()\n",
215
+ " self.local_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)\n",
216
+ " self.global_attn = nn.MultiheadAttention(n_state, n_head, dropout=dropout_rate)\n",
217
+ " self.ln_local = LayerNorm(num_features=n_state)\n",
218
+ " self.ln_global = LayerNorm(num_features=n_state)\n",
219
+ "\n",
220
+ " self.dropout = nn.Dropout(dropout_rate)\n",
221
+ " self.window_size = window_size\n",
222
+ "\n",
223
+ " def forward(self, x):\n",
224
+ " x_local = self.ln_local(x)\n",
225
+ " x_global = self.ln_global(x)\n",
226
+ " x_local = x_local.permute(1, 0, 2)\n",
227
+ " x_global = x_global.permute(1, 0, 2)\n",
228
+ " local_out = self.sliding_window_attention(x_local)\n",
229
+ " global_out, _ = self.global_attn(x_global, x_global, x_global)\n",
230
+ " combined_out = local_out + global_out\n",
231
+ " combined_out = combined_out.permute(1, 0, 2)\n",
232
+ " return self.dropout(combined_out)\n",
233
+ "\n",
234
+ " def sliding_window_attention(self, x):\n",
235
+ " batch_size, seq_len, n_state = x.size()\n",
236
+ " window_size = min(self.window_size, max(1, seq_len // 4))\n",
237
+ " output = torch.zeros_like(x, device=x.device, dtype=x.dtype)\n",
238
+ "\n",
239
+ " for i in range(0, seq_len, step=window_size):\n",
240
+ " end = min(i + window_size, seq_len)\n",
241
+ " query = x[i:end, :, :]\n",
242
+ " start = max(0, i - window_size)\n",
243
+ " key = x[start:end, :, :]\n",
244
+ " value = x[start:end, :, :]\n",
245
+ " attn_output, _ = self.local_attn(query, key, value)\n",
246
+ " output[i:end, :, :] = attn_output[:end - i, :, :]\n",
247
+ "\n",
248
+ " return output\n",
249
+ "\n",
250
+ "class CombinedRotaryEmbedding(nn.Module):\n",
251
+ " def __init__(self, n_state, n_head, num_rotations, base=10000, checkpointing=False):\n",
252
+ " super().__init__()\n",
253
+ " self.n_state = n_state\n",
254
+ " self.n_head = n_head\n",
255
+ " self.h_dim = n_state // n_head\n",
256
+ " self.num_rotations = num_rotations\n",
257
+ " self.base = base\n",
258
+ " self.checkpointing = checkpointing\n",
259
+ " \n",
260
+ " self.thetas = nn.Parameter(torch.zeros(num_rotations))\n",
261
+ " self.rotation_pairs = nn.Parameter(torch.rand(num_rotations, 2) * self.h_dim)\n",
262
+ "\n",
263
+ " self.theta_scale = nn.Parameter(torch.ones(1)) \n",
264
+ "\n",
265
+ " self.rotation_matrix = nn.Parameter(torch.eye(self.h_dim))\n",
266
+ " \n",
267
+ " self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))\n",
268
+ " \n",
269
+ " def givens_rotation_matrix(self, n_state, i, j, theta):\n",
270
+ " G = torch.eye(n_state, device=theta.device)\n",
271
+ " G[i, i] = math.cos(theta)\n",
272
+ " G[i, j] = -math.sin(theta)\n",
273
+ " G[j, i] = math.sin(theta)\n",
274
+ " G[j, j] = math.cos(theta)\n",
275
+ " return G\n",
276
+ " \n",
277
+ " def update_base(self, new_base):\n",
278
+ " self.base = new_base\n",
279
+ " self.inv_freq = nn.Parameter(1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)))\n",
280
+ " \n",
281
+ " def reset_parameters(self):\n",
282
+ " nn.init.orthogonal_(self.rotation_matrix)\n",
283
+ " nn.init.zeros_(self.thetas)\n",
284
+ " \n",
285
+ " def forward(self, x):\n",
286
+ " if self.checkpointing:\n",
287
+ " return checkpoint(self._forward, x)\n",
288
+ " else:\n",
289
+ " return self._forward(x)\n",
290
+ " \n",
291
+ " def _forward(self, x):\n",
292
+ " if x.dim() not in [3, 4]:\n",
293
+ " raise ValueError(f\"Expected input tensor to be 3D or 4D, but got {x.dim()}D\")\n",
294
+ " \n",
295
+ " if x.dim() == 3:\n",
296
+ " batch_size, seq_len, n_state = x.size()\n",
297
+ " x = x.view(batch_size, seq_len, self.n_head, self.h_dim)\n",
298
+ " else:\n",
299
+ " batch_size, seq_len, n_head, h_dim = x.size()\n",
300
+ " if n_head != self.n_head or h_dim != self.h_dim:\n",
301
+ " raise ValueError(f\"Expected n_head {self.n_head} and h_dim {self.h_dim}, but got n_head {n_head} and h_dim {h_dim}\")\n",
302
+ " \n",
303
+ " x = x.reshape(-1, self.h_dim)\n",
304
+ " \n",
305
+ " for k in range(self.num_rotations):\n",
306
+ " i, j = self.rotation_pairs[k].long()\n",
307
+ " \n",
308
+ " theta = self.thetas[k] * self.theta_scale \n",
309
+ " \n",
310
+ " G = self.givens_rotation_matrix(n_state=self.h_dim, i=i, j=j, theta=theta)\n",
311
+ " x = torch.matmul(x, G)\n",
312
+ " \n",
313
+ " x = torch.matmul(x, self.rotation_matrix)\n",
314
+ " \n",
315
+ " x = x.view(batch_size, seq_len, self.n_head, self.h_dim)\n",
316
+ " \n",
317
+ " sinusoid_inp = torch.einsum('i, j -> i j', torch.arange(seq_len, device=x.device), self.inv_freq.to(x.device))\n",
318
+ " sin = sinusoid_inp.sin()[None, :, None, :]\n",
319
+ " cos = sinusoid_inp.cos()[None, :, None, :]\n",
320
+ " \n",
321
+ " x1, x2 = x[..., ::2], x[..., 1::2]\n",
322
+ " x = torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)\n",
323
+ " \n",
324
+ " x = x.view(batch_size, seq_len, self.n_state)\n",
325
+ " \n",
326
+ " return x\n",
327
+ "\n",
328
+ "class LearnedSinusoidalEmbeddings(nn.Module):\n",
329
+ " def __init__(self, n_ctx, n_state, checkpointing=False):\n",
330
+ " super().__init__()\n",
331
+ " self.n_ctx = n_ctx\n",
332
+ " self.n_state = n_state\n",
333
+ " self.checkpointing = checkpointing\n",
334
+ "\n",
335
+ " position = torch.arange(0, n_ctx, dtype=torch.float).unsqueeze(1)\n",
336
+ " div_term = torch.exp(torch.arange(0, n_state, 2).float() * -(math.log(10000.0) / n_state))\n",
337
+ " features = torch.zeros(n_ctx, n_state)\n",
338
+ " features[:, 0::2] = torch.sin(position * div_term)\n",
339
+ " features[:, 1::2] = torch.cos(position * div_term)\n",
340
+ " self.register_buffer('sinusoidal_features', features)\n",
341
+ "\n",
342
+ " self.positional_embeddings = nn.Parameter(self.sinusoidal_features.clone())\n",
343
+ "\n",
344
+ " def forward(self, positions):\n",
345
+ " if self.checkpointing:\n",
346
+ " position_embeddings = checkpoint(lambda x: self.positional_embeddings[x], positions)\n",
347
+ " else:\n",
348
+ " position_embeddings = self.positional_embeddings[positions]\n",
349
+ "\n",
350
+ " position_embeddings = torch.nn.functional.normalize(position_embeddings, p=2, dim=-1)\n",
351
+ " return position_embeddings\n",
352
+ "\n",
353
+ "class MultiHeadAttention(nn.Module):\n",
354
+ " use_sdpa = True\n",
355
+ "\n",
356
+ " def __init__(self, n_state: int, n_head: int, max_rel_dist: int = 1, base: int = 10000):\n",
357
+ " super().__init__()\n",
358
+ " assert n_state % n_head == 0, \"n_state must be divisible by n_head\"\n",
359
+ " self.n_head = n_head\n",
360
+ " self.h_dim = n_state // n_head\n",
361
+ " assert self.h_dim % 2 == 0, \"Head dimension must be even for rotary embeddings\"\n",
362
+ "\n",
363
+ " self.positional_scaling = nn.Parameter(torch.ones(1))\n",
364
+ "\n",
365
+ " self.query = nn.Linear(n_state, n_state)\n",
366
+ " self.key = nn.Linear(n_state, n_state, bias=False)\n",
367
+ " self.value = nn.Linear(n_state, n_state)\n",
368
+ " self.out = nn.Linear(n_state, n_state)\n",
369
+ "\n",
370
+ " self.max_rel_dist = max_rel_dist\n",
371
+ " self.base = base\n",
372
+ " inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim))\n",
373
+ " self.register_buffer('inv_freq', inv_freq)\n",
374
+ " self.rel_pos_bias = nn.Embedding(2 * self.max_rel_dist - 1, self.n_head)\n",
375
+ " self.rel_pos_bias.weight.data.fill_(0)\n",
376
+ "\n",
377
+ " self.combined_rotary = CombinedRotaryEmbedding(\n",
378
+ " n_state=n_state,\n",
379
+ " n_head=n_head,\n",
380
+ " num_rotations=self.h_dim // 2,\n",
381
+ " base=base,\n",
382
+ " checkpointing=False \n",
383
+ " )\n",
384
+ "\n",
385
+ " if device:\n",
386
+ " self.to(device)\n",
387
+ "\n",
388
+ " def update_base(self, new_base): \n",
389
+ " self.base = new_base \n",
390
+ " inv_freq = 1.0 / (self.base ** (torch.arange(0, self.h_dim, 2).float() / self.h_dim)) \n",
391
+ " self.register_buffer('inv_freq', inv_freq) \n",
392
+ " self.combined_rotary.update_base(new_base=new_base)\n",
393
+ "\n",
394
+ " def forward(self, x, xa = None, mask = None, kv_cache = None):\n",
395
+ " q = self.query(x)\n",
396
+ "\n",
397
+ " if kv_cache is None or xa is None or 'k' not in kv_cache:\n",
398
+ " k_input = x if xa is None else xa\n",
399
+ " k = self.key(k_input)\n",
400
+ " v = self.value(k_input)\n",
401
+ " if kv_cache is not None:\n",
402
+ " kv_cache['k'] = k\n",
403
+ " kv_cache['v'] = v\n",
404
+ " else:\n",
405
+ " k = kv_cache['k']\n",
406
+ " v = kv_cache['v']\n",
407
+ "\n",
408
+ " q = q.view(q.shape[0], q.shape[1], self.n_head, -1)\n",
409
+ " k = k.view(k.shape[0], k.shape[1], self.n_head, -1)\n",
410
+ " v = v.view(v.shape[0], v.shape[1], self.n_head, -1)\n",
411
+ "\n",
412
+ " q = self.combined_rotary(q) \n",
413
+ " k = self.combined_rotary(k)\n",
414
+ "\n",
415
+ " q = q.view(q.shape[0], q.shape[1], -1)\n",
416
+ " k = k.view(k.shape[0], k.shape[1], -1)\n",
417
+ "\n",
418
+ " wv, qk = self.qkv_attention(q=q, k=k, v=v, mask=mask)\n",
419
+ " return self.out(wv), qk\n",
420
+ " \n",
421
+ " def qkv_attention(self, q, k, v, mask = None):\n",
422
+ " n_batch, n_ctx, n_state = q.shape\n",
423
+ "\n",
424
+ " scale = (n_state // self.n_head) ** -0.25\n",
425
+ " q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n",
426
+ " k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n",
427
+ " v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)\n",
428
+ "\n",
429
+ " qk = (q * scale) @ (k * scale).transpose(-1, -2)\n",
430
+ "\n",
431
+ " seq_len_q = q.size(2)\n",
432
+ " seq_len_k = k.size(2)\n",
433
+ "\n",
434
+ " positions = torch.arange(seq_len_q, device=q.device).unsqueeze(1) - torch.arange(seq_len_k, device=q.device).unsqueeze(0)\n",
435
+ " positions = positions.clamp(-self.max_rel_dist + 1, self.max_rel_dist - 1) + self.max_rel_dist - 1\n",
436
+ " rel_bias = self.rel_pos_bias(positions) \n",
437
+ " rel_bias = rel_bias.permute(2, 0, 1).unsqueeze(0) \n",
438
+ "\n",
439
+ " qk = qk + rel_bias\n",
440
+ "\n",
441
+ " if mask is not None:\n",
442
+ " qk = qk + mask[:n_ctx, :n_ctx]\n",
443
+ " qk = qk.float()\n",
444
+ "\n",
445
+ " w = F.softmax(qk, dim=-1).to(q.dtype)\n",
446
+ " out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)\n",
447
+ " qk = qk.detach()\n",
448
+ "\n",
449
+ " return out, qk\n",
450
+ "\n",
451
+ "class ResidualAttentionBlock(nn.Module):\n",
452
+ " def __init__(self, n_state, n_head, cross_attention = False, max_rel_dist = 1, checkpointing=False):\n",
453
+ " super().__init__()\n",
454
+ "\n",
455
+ " self.attn = MultiHeadAttention(n_state=n_state, n_head=n_head)\n",
456
+ " self.attn_ln = LayerNorm(num_features=n_state)\n",
457
+ " self.checkpointing = checkpointing\n",
458
+ " self.max_rel_dist = max_rel_dist\n",
459
+ "\n",
460
+ " self.cross_attn = (\n",
461
+ " MultiHeadAttention(n_state=n_state, n_head=n_head) if cross_attention else None\n",
462
+ " )\n",
463
+ " self.cross_attn_ln = LayerNorm(num_features=n_state) if cross_attention else None\n",
464
+ "\n",
465
+ " n_mlp = n_state * 4\n",
466
+ " self.mlp = nn.Sequential(\n",
467
+ " Linear(in_features=n_state, out_features=n_mlp), nn.GELU(), Linear(in_features=n_mlp, out_features=n_state)\n",
468
+ " )\n",
469
+ " self.mlp_ln = LayerNorm(num_features=n_state)\n",
470
+ "\n",
471
+ " def forward(self, x, xa = None, mask = None, kv_cache = None):\n",
472
+ " if self.checkpointing:\n",
473
+ " x = checkpoint(self._attn_forward, x, mask, kv_cache)\n",
474
+ " else:\n",
475
+ " x = self._attn_forward(x=x, mask=mask, kv_cache=kv_cache)\n",
476
+ "\n",
477
+ " if self.cross_attn:\n",
478
+ " if self.checkpointing:\n",
479
+ " x = checkpoint(self._cross_attn_forward, x, xa, kv_cache)\n",
480
+ " else:\n",
481
+ " x = self._cross_attn_forward(x=x, xa=xa, kv_cache=kv_cache)\n",
482
+ "\n",
483
+ " if self.checkpointing:\n",
484
+ " x = checkpoint(self._mlp_forward, x)\n",
485
+ " else:\n",
486
+ " x = self._mlp_forward(x=x)\n",
487
+ "\n",
488
+ " return x\n",
489
+ "\n",
490
+ " def _attn_forward(self, x, mask, kv_cache):\n",
491
+ " residual = x\n",
492
+ " x = self.attn_ln(x)\n",
493
+ " x = residual + self.attn(x, mask=mask, kv_cache=kv_cache)[0]\n",
494
+ " return x\n",
495
+ "\n",
496
+ " def _cross_attn_forward(self, x, xa, kv_cache):\n",
497
+ " residual = x\n",
498
+ " x = self.cross_attn_ln(x)\n",
499
+ " x = residual + self.cross_attn(x, xa, kv_cache=kv_cache)[0]\n",
500
+ " return x\n",
501
+ "\n",
502
+ " def _mlp_forward(self, x):\n",
503
+ " residual = x\n",
504
+ " x = self.mlp_ln(x)\n",
505
+ " x = residual + self.mlp(x)\n",
506
+ " return x\n",
507
+ "\n",
508
+ "class AudioEncoder(nn.Module):\n",
509
+ " def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, max_rel_dist = 1, cross_attention=True, checkpointing=False, base=10000):\n",
510
+ " super().__init__()\n",
511
+ " self.conv1 = nn.Conv1d(n_mels, n_state, kernel_size=3, padding=1)\n",
512
+ " self.conv2 = nn.Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)\n",
513
+ " self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx=n_ctx, n_state=n_state, checkpointing=checkpointing)\n",
514
+ " self.checkpointing = checkpointing\n",
515
+ " self.h_dim = n_state // n_head\n",
516
+ "\n",
517
+ " self.combined_rotary = CombinedRotaryEmbedding(\n",
518
+ " n_state=n_state,\n",
519
+ " n_head=n_head,\n",
520
+ " num_rotations=self.h_dim // 2,\n",
521
+ " base=base,\n",
522
+ " checkpointing=False \n",
523
+ " )\n",
524
+ "\n",
525
+ " self.blocks = nn.ModuleList(\n",
526
+ " [ResidualAttentionBlock(n_state=n_state, n_head=n_head, cross_attention=cross_attention, max_rel_dist=max_rel_dist, checkpointing=checkpointing) for _ in range(n_layer)]\n",
527
+ " )\n",
528
+ " self.ln_post = LayerNorm(num_features=n_state)\n",
529
+ "\n",
530
+ " def update_base(self, new_base):\n",
531
+ " self.combined_rotary.update_base(new_base=new_base)\n",
532
+ " for block in self.blocks:\n",
533
+ " if isinstance(block.attn, MultiHeadAttention, CombinedRotaryEmbedding):\n",
534
+ " block.attn.update_base(new_base)\n",
535
+ " if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention, CombinedRotaryEmbedding):\n",
536
+ " block.cross_attn.update_base(new_base)\n",
537
+ "\n",
538
+ " def forward(self, x):\n",
539
+ " if self.checkpointing:\n",
540
+ " x = checkpoint(self._conv_forward, x)\n",
541
+ " else:\n",
542
+ " x = self._conv_forward(x=x)\n",
543
+ "\n",
544
+ " for block in self.blocks:\n",
545
+ " if self.checkpointing:\n",
546
+ " x = checkpoint(block, x)\n",
547
+ " else:\n",
548
+ " x = block(x)\n",
549
+ "\n",
550
+ " x = self.ln_post(x)\n",
551
+ " return x\n",
552
+ "\n",
553
+ " def _conv_forward(self, x):\n",
554
+ " x = F.gelu(self.conv1(x))\n",
555
+ " x = F.gelu(self.conv2(x))\n",
556
+ " x = x.permute(0, 2, 1)\n",
557
+ "\n",
558
+ " x = self.combined_rotary(x)\n",
559
+ "\n",
560
+ " pos_emb = self.positional_embedding(torch.arange(x.size(1), device=x.device)).unsqueeze(0)\n",
561
+ " x = x + pos_emb\n",
562
+ " return x\n",
563
+ "\n",
564
+ "class TextDecoder(nn.Module):\n",
565
+ " def __init__(self, vocab_size, n_ctx, n_state, n_head, n_layer, max_rel_dist = 1, cross_attention=True, checkpointing=False, base=10000):\n",
566
+ " super().__init__()\n",
567
+ " self.token_embedding = nn.Embedding(vocab_size, n_state)\n",
568
+ " self.positional_embedding = LearnedSinusoidalEmbeddings(n_ctx=n_ctx, n_state=n_state, checkpointing=checkpointing)\n",
569
+ " self.checkpointing = checkpointing\n",
570
+ " self.n_head = n_head\n",
571
+ " self.h_dim = n_state // n_head\n",
572
+ " \n",
573
+ " self.combined_rotary = CombinedRotaryEmbedding(\n",
574
+ " n_state=n_state,\n",
575
+ " n_head=n_head,\n",
576
+ " num_rotations=self.h_dim // 2, \n",
577
+ " base=base,\n",
578
+ " checkpointing=False \n",
579
+ " )\n",
580
+ "\n",
581
+ " self.blocks = nn.ModuleList([\n",
582
+ " ResidualAttentionBlock(n_state=n_state, n_head=n_head, cross_attention=cross_attention, max_rel_dist=max_rel_dist, checkpointing=checkpointing)\n",
583
+ " for _ in range(n_layer)\n",
584
+ " ])\n",
585
+ " self.ln = LayerNorm(num_features=n_state)\n",
586
+ " mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)\n",
587
+ " self.register_buffer(\"mask\", mask, persistent=False)\n",
588
+ "\n",
589
+ " def update_base(self, new_base):\n",
590
+ " self.combined_rotary.update_base(new_base=new_base)\n",
591
+ " for block in self.blocks:\n",
592
+ " if isinstance(block.attn, MultiHeadAttention, CombinedRotaryEmbedding):\n",
593
+ " block.attn.update_base(new_base)\n",
594
+ " if block.cross_attn and isinstance(block.cross_attn, MultiHeadAttention, CombinedRotaryEmbedding):\n",
595
+ " block.cross_attn.update_base(new_base)\n",
596
+ "\n",
597
+ " def forward(self, x, xa, kv_cache = None):\n",
598
+ " if self.checkpointing:\n",
599
+ " x = checkpoint(self._embedding_forward, x, xa, kv_cache)\n",
600
+ " else:\n",
601
+ " x = self._embedding_forward(x=x, xa=xa, kv_cache=kv_cache)\n",
602
+ " for block in self.blocks:\n",
603
+ " if self.checkpointing:\n",
604
+ " x = checkpoint(block, x, xa, self.mask, kv_cache)\n",
605
+ " else:\n",
606
+ " x = block(x, xa, self.mask, kv_cache)\n",
607
+ " x = self.ln(x)\n",
608
+ " logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()\n",
609
+ " return logits\n",
610
+ "\n",
611
+ " def _embedding_forward(self, x, xa, kv_cache):\n",
612
+ " offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0\n",
613
+ " positions = torch.arange(x.shape[1], device=x.device) + offset\n",
614
+ " pos_emb = self.positional_embedding(positions).unsqueeze(0)\n",
615
+ " x = self.token_embedding(x) + pos_emb\n",
616
+ " x = x.to(xa.dtype)\n",
617
+ " batch_size, seq_length, embedding_dim = x.shape\n",
618
+ " num_heads = self.n_head\n",
619
+ " head_dim = embedding_dim // num_heads\n",
620
+ " x = x.view(batch_size, seq_length, num_heads, head_dim)\n",
621
+ " x = self.combined_rotary(x)\n",
622
+ " x = x.view(batch_size, seq_length, embedding_dim)\n",
623
+ " return x\n",
624
+ " \n",
625
+ "class Echo(WhisperPreTrainedModel, PyTorchModelHubMixin):\n",
626
+ " config_class = WhisperConfig\n",
627
+ "\n",
628
+ " def __init__(self, config: WhisperConfig):\n",
629
+ " super().__init__(config)\n",
630
+ " self.config = config\n",
631
+ "\n",
632
+ " self.n_mels = self.config.num_mel_bins\n",
633
+ " self.n_audio_ctx = self.config.max_source_positions\n",
634
+ " self.n_audio_state = self.config.d_model\n",
635
+ " self.n_audio_head = self.config.encoder_attention_heads\n",
636
+ " self.n_audio_layer = self.config.encoder_layers\n",
637
+ " self.vocab_size = self.config.vocab_size\n",
638
+ " self.n_text_ctx = self.config.max_target_positions\n",
639
+ " self.n_text_state = self.config.d_model\n",
640
+ " self.n_text_head = self.config.decoder_attention_heads\n",
641
+ " self.n_text_layer = self.config.decoder_layers\n",
642
+ " self.checkpointing = self.config.checkpointing\n",
643
+ " self.max_rel_dist = self.config.max_rel_dist\n",
644
+ " self.cross_attention = self.config.cross_attention\n",
645
+ " self.base = self.config.base\n",
646
+ "\n",
647
+ " self.encoder = AudioEncoder(\n",
648
+ " n_mels=self.config.n_mels,\n",
649
+ " n_ctx=self.config.n_audio_ctx,\n",
650
+ " n_state=self.config.n_audio_state,\n",
651
+ " n_head=self.config.n_audio_head,\n",
652
+ " n_layer=self.config.n_audio_layer,\n",
653
+ " max_rel_dist=self.config.checkpointing,\n",
654
+ " cross_attention=self.config.max_rel_dist,\n",
655
+ " checkpointing=self.config.cross_attention,\n",
656
+ " base=self.config.base,\n",
657
+ " )\n",
658
+ " self.decoder = TextDecoder(\n",
659
+ " vocab_size=self.config.vocab_size,\n",
660
+ " n_ctx=self.config.n_text_ctx,\n",
661
+ " n_state=self.config.n_text_state,\n",
662
+ " n_head=self.config.n_text_head,\n",
663
+ " n_layer=self.config.n_text_layer,\n",
664
+ " max_rel_dist=self.config.checkpointing,\n",
665
+ " cross_attention=self.config.max_rel_dist,\n",
666
+ " checkpointing=self.config.cross_attention,\n",
667
+ " base=self.config.base,\n",
668
+ " )\n",
669
+ "\n",
670
+ " all_heads = torch.zeros(self.config.n_text_layer, self.config.n_text_head, dtype=torch.bool)\n",
671
+ " all_heads[self.config.n_text_layer // 2:] = True\n",
672
+ " self.register_buffer(\"alignment_heads\", all_heads.to_sparse(), persistent=False)\n",
673
+ "\n",
674
+ " self.best_loss = float('inf')\n",
675
+ " self.base = 10000 \n",
676
+ "\n",
677
+ " def update_base(self, new_base):\n",
678
+ " self.encoder.combined_rotary.update_base(new_base=new_base)\n",
679
+ " self.decoder.combined_rotary.update_base(new_base=new_base)\n",
680
+ "\n",
681
+ " for name, module in self.encoder.named_modules():\n",
682
+ " if isinstance(module, (MultiHeadAttention, CombinedRotaryEmbedding)):\n",
683
+ " module.update_base(new_base=new_base)\n",
684
+ "\n",
685
+ " for name, module in self.decoder.named_modules():\n",
686
+ " if isinstance(module, (MultiHeadAttention, CombinedRotaryEmbedding)):\n",
687
+ " module.update_base(new_base=new_base)\n",
688
+ "\n",
689
+ " def adjust_base(self, loss, factor=1.05):\n",
690
+ " if loss < self.best_loss:\n",
691
+ " new_base = self.base * factor\n",
692
+ " else:\n",
693
+ " new_base = self.base / factor\n",
694
+ "\n",
695
+ " self.update_base(new_base=new_base)\n",
696
+ " self.best_loss = loss\n",
697
+ " # print(f\"Adjusted base: {new_base}\")\n",
698
+ "\n",
699
+ " @staticmethod\n",
700
+ " def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id) -> torch.Tensor:\n",
701
+ " shifted_input_ids = input_ids.new_zeros(input_ids.shape)\n",
702
+ " shifted_input_ids[:, 1:] = input_ids[:, :-1]\n",
703
+ " shifted_input_ids[:, 0] = decoder_start_token_id\n",
704
+ " shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)\n",
705
+ " return shifted_input_ids\n",
706
+ "\n",
707
+ " def forward(self, input_features, labels=None, dec_input_ids=None):\n",
708
+ " if labels is not None:\n",
709
+ " if dec_input_ids is None:\n",
710
+ " dec_input_ids = self.shift_tokens_right(\n",
711
+ " input_ids=labels, pad_token_id=self.config.pad_token_id, decoder_start_token_id=self.config.decoder_start_token_id\n",
712
+ " )\n",
713
+ "\n",
714
+ " encoded_features = self.encoder(input_features).to(device)\n",
715
+ " logits = self.decoder(dec_input_ids, encoded_features)\n",
716
+ "\n",
717
+ " loss = None\n",
718
+ " if labels is not None:\n",
719
+ " loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100) \n",
720
+ " labels = labels.to(logits.device).long()\n",
721
+ " loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))\n",
722
+ "\n",
723
+ " self.adjust_base(loss.item())\n",
724
+ "\n",
725
+ " return {\n",
726
+ " \"loss\": loss,\n",
727
+ " \"logits\": logits,\n",
728
+ " }\n",
729
+ "\n",
730
+ " def _initialize_weights(self):\n",
731
+ " nn.init.normal_(self.decoder.token_embedding.weight, mean=0.0, std=self.config.init_std)\n",
732
+ " if hasattr(self.decoder.positional_embedding, 'weight'):\n",
733
+ " nn.init.normal_(self.decoder.positional_embedding.weight, mean=0.0, std=self.config.init_std)\n",
734
+ " for block in self.decoder.blocks:\n",
735
+ " for layer in block.children():\n",
736
+ " if isinstance(layer, nn.Linear):\n",
737
+ " nn.init.xavier_normal_(layer.weight)\n",
738
+ " if layer.bias is not None:\n",
739
+ " nn.init.zeros_(layer.bias)\n",
740
+ "\n",
741
+ " nn.init.constant_(self.decoder.ln.gamma, 1)\n",
742
+ " if self.decoder.ln.beta is not None:\n",
743
+ " nn.init.constant_(self.decoder.ln.beta, 0)\n",
744
+ "\n",
745
+ " nn.init.xavier_normal_(self.encoder.conv1.weight)\n",
746
+ " if self.encoder.conv1.bias is not None:\n",
747
+ " nn.init.zeros_(self.encoder.conv1.bias)\n",
748
+ "\n",
749
+ " nn.init.kaiming_normal_(self.encoder.conv2.weight, mode='fan_out', nonlinearity='relu')\n",
750
+ " if self.encoder.conv2.bias is not None:\n",
751
+ " nn.init.zeros_(self.encoder.conv2.bias)\n",
752
+ "\n",
753
+ " nn.init.constant_(self.encoder.ln_post.gamma, 1)\n",
754
+ " if self.encoder.ln_post.beta is not None:\n",
755
+ " nn.init.constant_(self.encoder.ln_post.beta, 0)\n",
756
+ " \n",
757
+ " def apply_initialization(self):\n",
758
+ " self._initialize_weights()\n",
759
+ "\n",
760
+ " def set_alignment_heads(self, dump: bytes):\n",
761
+ " array = np.frombuffer(\n",
762
+ " gzip.decompress(base64.b85decode(dump)), dtype=bool\n",
763
+ " ).copy()\n",
764
+ " mask = torch.from_numpy(array).reshape(\n",
765
+ " self.config.n_text_layer, self.config.n_text_head\n",
766
+ " )\n",
767
+ " self.register_buffer(\"alignment_heads\", mask.to_sparse(), persistent=False)\n",
768
+ "\n",
769
+ " def embed_audio(self, mel):\n",
770
+ " return self.encoder(mel)\n",
771
+ "\n",
772
+ " def logits(self, labels, input_features):\n",
773
+ " return self.decoder(labels, input_features)\n",
774
+ "\n",
775
+ " @property\n",
776
+ " def device(self):\n",
777
+ " return next(self.parameters()).device\n",
778
+ "\n",
779
+ " @property\n",
780
+ " def is_multilingual(self):\n",
781
+ " return self.config.vocab_size >= len(tokenizer)\n",
782
+ "\n",
783
+ " @property\n",
784
+ " def num_languages(self):\n",
785
+ " return self.config.vocab_size - (len(tokenizer)-100) - int(self.is_multilingual)\n",
786
+ "\n",
787
+ " def install_kv_cache_hooks(self, cache = None):\n",
788
+ " cache = {**cache} if cache is not None else {}\n",
789
+ " hooks = []\n",
790
+ "\n",
791
+ " def save_to_cache(module, _, output):\n",
792
+ " if module not in cache or output.shape[1] > self.config.n_text_ctx:\n",
793
+ " cache[module] = output\n",
794
+ " else:\n",
795
+ " cache[module] = torch.cat([cache[module], output], dim=1).detach()\n",
796
+ " return cache[module]\n",
797
+ "\n",
798
+ " def install_hooks(layer: nn.Module):\n",
799
+ " if isinstance(layer, MultiHeadAttention):\n",
800
+ " hooks.append(layer.key.register_forward_hook(save_to_cache))\n",
801
+ " hooks.append(layer.value.register_forward_hook(save_to_cache))\n",
802
+ "\n",
803
+ " self.decoder.apply(install_hooks)\n",
804
+ " return cache, hooks\n",
805
+ "\n",
806
+ " detect_language = detect_language_function\n",
807
+ " transcribe = transcribe_function\n",
808
+ " decode = decode_function\n",
809
+ "\n",
810
+ " def get_encoder(self):\n",
811
+ " return self.encoder\n",
812
+ "\n",
813
+ " def prepare_inputs_for_generation(self, input_ids, **kwargs):\n",
814
+ " return {'input_features': input_ids}\n",
815
+ "\n",
816
+ " def _prepare_decoder_input_ids_for_generation(self, batch_size, decoder_start_token_id=None, bos_token_id=None):\n",
817
+ " return torch.ones((batch_size, 1), dtype=torch.long, device=self.device) * self.config.decoder_start_token_id\n",
818
+ "\n",
819
+ " def can_generate(self):\n",
820
+ " return True\n",
821
+ " \n",
822
+ " def generate(self, inputs, **kwargs):\n",
823
+ " encoder_outputs = self.encoder(inputs)\n",
824
+ " decoder_input_ids = torch.zeros((inputs.size(0), 1), dtype=torch.long, device=inputs.device)\n",
825
+ " outputs = self.decoder(decoder_input_ids, encoder_outputs)\n",
826
+ " return outputs.argmax(dim=-1)\n",
827
+ "\n",
828
+ " def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):\n",
829
+ " if not self.supports_gradient_checkpointing:\n",
830
+ " raise ValueError(f\"{self.__class__.__name__} does not support gradient checkpointing.\")\n",
831
+ " if gradient_checkpointing_kwargs is None:\n",
832
+ " gradient_checkpointing_kwargs = {\"use_reentrant\": True}\n",
833
+ " gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)\n",
834
+ " _is_using_old_format = \"value\" in inspect.signature(self._set_gradient_checkpointing).parameters\n",
835
+ " if not _is_using_old_format:\n",
836
+ " self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)\n",
837
+ " else:\n",
838
+ " self.apply(partial(self._set_gradient_checkpointing, value=True))\n",
839
+ " if getattr(self, \"_hf_peft_config_loaded\", False):\n",
840
+ " self.enable_input_require_grads()\n",
841
+ "\n",
842
+ "config = WhisperConfig(\n",
843
+ " n_mels=128,\n",
844
+ " n_audio_ctx=1500,\n",
845
+ " n_audio_state=1024,\n",
846
+ " n_audio_head=16,\n",
847
+ " n_audio_layer=24,\n",
848
+ " vocab_size=51865,\n",
849
+ " n_text_ctx=448,\n",
850
+ " n_text_state=1024,\n",
851
+ " n_text_head=16,\n",
852
+ " n_text_layer=20,\n",
853
+ " max_rel_dist=50,\n",
854
+ " cross_attention=True,\n",
855
+ " checkpointing=True,\n",
856
+ " base=10000,\n",
857
+ " bos_token_id = 50257,\n",
858
+ " eos_token_id = 50257,\n",
859
+ " pad_token_id = 50257,\n",
860
+ " decoder_start_token_id = 50258,\n",
861
+ " is_encoder_decoder = True,\n",
862
+ " init_std=0.02,\n",
863
+ " )"
864
+ ]
865
+ },
866
+ {
867
+ "cell_type": "code",
868
+ "execution_count": null,
869
+ "metadata": {},
870
+ "outputs": [],
871
+ "source": [
872
+ "class GradientClippingCallback(TrainerCallback):\n",
873
+ " def on_step_end(self, args, state, control, **kwargs):\n",
874
+ " torch.nn.utils.clip_grad_norm_(kwargs[\"model\"].parameters(), max_norm=0.95)\n",
875
+ "\n",
876
+ "class MetricsCallback(TrainerCallback):\n",
877
+ " def __init__(self, tb_writer, tokenizer, metric, log_every_n_steps=20):\n",
878
+ " super().__init__()\n",
879
+ " self.tb_writer = tb_writer\n",
880
+ " self.tokenizer = tokenizer\n",
881
+ " self.metric = metric\n",
882
+ " self.log_every_n_steps = log_every_n_steps\n",
883
+ " self.predictions = None\n",
884
+ " self.label_ids = None\n",
885
+ "\n",
886
+ " def compute_cer(self, pred_str, label_str):\n",
887
+ " cer = 100 * self.metric.compute(predictions=pred_str, references=label_str)\n",
888
+ " return cer\n",
889
+ "\n",
890
+ " def on_evaluate(self, args, state, control, metrics=None, **kwargs):\n",
891
+ " if metrics is not None:\n",
892
+ " for key, value in metrics.items():\n",
893
+ " if key.startswith(\"eval_\"):\n",
894
+ " self.tb_writer.add_scalar(key, value, state.global_step)\n",
895
+ " print(f\"Step {state.global_step} - {key}: {value}\")\n",
896
+ "\n",
897
+ " if self.predictions is not None and self.label_ids is not None:\n",
898
+ " pred_str = self.tokenizer.batch_decode(self.predictions, skip_special_tokens=True)\n",
899
+ " label_str = self.tokenizer.batch_decode(self.label_ids, skip_special_tokens=True)\n",
900
+ "\n",
901
+ " sample_index = 1\n",
902
+ " self.tb_writer.add_text(\"Prediction\", pred_str[sample_index], state.global_step)\n",
903
+ " self.tb_writer.add_text(\"Label\", label_str[sample_index], state.global_step)\n",
904
+ "\n",
905
+ " print(f\"Step {state.global_step} - Sample Prediction: {pred_str[sample_index]}\")\n",
906
+ " print(f\"Step {state.global_step} - Sample Label: {label_str[sample_index]}\")\n",
907
+ "\n",
908
+ " self.predictions = None\n",
909
+ " self.label_ids = None\n",
910
+ "\n",
911
+ "def create_compute_metrics(callback_instance):\n",
912
+ " def compute_metrics(eval_pred):\n",
913
+ " pred_logits = eval_pred.predictions\n",
914
+ " label_ids = eval_pred.label_ids\n",
915
+ "\n",
916
+ " if isinstance(pred_logits, tuple):\n",
917
+ " pred_ids = pred_logits[0]\n",
918
+ " else:\n",
919
+ " pred_ids = pred_logits\n",
920
+ " if pred_ids.ndim == 3:\n",
921
+ " pred_ids = np.argmax(pred_ids, axis=-1)\n",
922
+ "\n",
923
+ " label_ids[label_ids == -100] = callback_instance.tokenizer.pad_token_id\n",
924
+ " callback_instance.predictions = pred_ids\n",
925
+ " callback_instance.label_ids = label_ids\n",
926
+ "\n",
927
+ " pred_str = callback_instance.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
928
+ " label_str = callback_instance.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n",
929
+ " cer = 100 * callback_instance.metric.compute(predictions=pred_str, references=label_str)\n",
930
+ "\n",
931
+ " pred_flat = pred_ids.flatten()\n",
932
+ " labels_flat = label_ids.flatten()\n",
933
+ " mask = labels_flat != callback_instance.tokenizer.pad_token_id\n",
934
+ "\n",
935
+ " accuracy = accuracy_score(y_true=labels_flat[mask], y_pred=pred_flat[mask])\n",
936
+ " precision = precision_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], average='weighted', zero_division=0)\n",
937
+ " recall = recall_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], average='weighted', zero_division=0)\n",
938
+ " f1 = f1_score(y_true=labels_flat[mask], y_pred=pred_flat[mask], average='weighted', zero_division=0)\n",
939
+ "\n",
940
+ " return {\n",
941
+ " \"cer\": cer,\n",
942
+ " \"accuracy\": accuracy,\n",
943
+ " \"precision\": precision,\n",
944
+ " \"recall\": recall,\n",
945
+ " \"f1\": f1\n",
946
+ " }\n",
947
+ " return compute_metrics\n"
948
+ ]
949
+ },
950
+ {
951
+ "cell_type": "code",
952
+ "execution_count": null,
953
+ "metadata": {},
954
+ "outputs": [],
955
+ "source": [
956
+ "def prepare_dataset(batch):\n",
957
+ " audio = batch[\"audio\"]\n",
958
+ " batch[\"input_features\"] = feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n",
959
+ " transcription = batch[\"sentence\"]\n",
960
+ " batch[\"labels\"] = tokenizer(transcription).input_ids\n",
961
+ " return batch\n",
962
+ "\n",
963
+ "def get_length_of_dataset(dataset):\n",
964
+ " length = 0\n",
965
+ " for item in dataset: length += (len(item[\"audio\"][\"array\"]) / item[\"audio\"][\"sampling_rate\"])\n",
966
+ " return length//3600\n",
967
+ "\n",
968
+ "@dataclass\n",
969
+ "class DataCollatorSpeechSeq2SeqWithPadding:\n",
970
+ " processor: Any\n",
971
+ " tokenizer: Any\n",
972
+ " feature_extractor: Any\n",
973
+ "\n",
974
+ " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n",
975
+ " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n",
976
+ " batch = self.feature_extractor.pad(input_features, return_tensors=\"pt\")\n",
977
+ " label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n",
978
+ " labels_batch = self.tokenizer.pad(label_features, return_tensors=\"pt\")\n",
979
+ " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n",
980
+ " if (labels[:, 0] == self.tokenizer.bos_token_id).all().cpu().item():\n",
981
+ " labels = labels[:, 1:]\n",
982
+ " batch[\"labels\"] = labels\n",
983
+ " return batch\n",
984
+ "\n",
985
+ "feature_extractor = WhisperFeatureExtractor.from_pretrained(pretrained_model_name_or_path=\"openai/whisper-small\", feature_size=128, n_fft=1024, hop_length=256, sampling_rate=16000)\n",
986
+ "tokenizer = WhisperTokenizerFast.from_pretrained(pretrained_model_name_or_path=\"D:/newproject/new_tokenizer2\")\n",
987
+ "processor = WhisperProcessor.from_pretrained(pretrained_model_name_or_path=\"openai/whisper-small\", language=\"Japanese\", task=\"transcribe\")\n",
988
+ "\n",
989
+ "def train():\n",
990
+ "\n",
991
+ " from datetime import datetime\n",
992
+ " log_dir = os.path.join('./output/', datetime.now().strftime('%Y-%m-%d_%H'))\n",
993
+ " os.makedirs(log_dir, exist_ok=True)\n",
994
+ "\n",
995
+ " name=\"echo\"\n",
996
+ " model = Echo(config=config).to(device)\n",
997
+ " model.apply_initialization()\n",
998
+ " model.save_pretrained(log_dir)\n",
999
+ " torch.save(obj=model.state_dict(), f=log_dir+name)\n",
1000
+ " # model = Echo.from_pretrained(log_dir).to(device)\n",
1001
+ "\n",
1002
+ " dataset = load_dataset(path=\"audiofolder\", data_dir=\"D:/proj/datasets/gv\")[\"train\"].to_iterable_dataset(num_shards=20).filter(lambda x: bool(x['sentence']))\n",
1003
+ " dataset = dataset.map(prepare_dataset).select_columns([\"input_features\", \"labels\"])\n",
1004
+ " test, train = dataset.take(100), dataset.skip(100)\n",
1005
+ " data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor, tokenizer=tokenizer, feature_extractor=feature_extractor)\n",
1006
+ " \n",
1007
+ " optimizer = transformers.Adafactor(params=model.parameters(), \n",
1008
+ " clip_threshold=0.99, \n",
1009
+ " weight_decay=0.005, \n",
1010
+ " scale_parameter=True, \n",
1011
+ " relative_step=True, \n",
1012
+ " warmup_init=True, \n",
1013
+ " lr=None)\n",
1014
+ "\n",
1015
+ " scheduler = transformers.optimization.AdafactorSchedule(optimizer=optimizer, initial_lr=2.25e-5)\n",
1016
+ " loss_fn = torch.nn.CrossEntropyLoss(ignore_index=-100)\n",
1017
+ "\n",
1018
+ " metric = evaluate.load(path=\"cer\")\n",
1019
+ " tb_writer = SummaryWriter(log_dir=log_dir)\n",
1020
+ "\n",
1021
+ " metrics_callback = MetricsCallback(tb_writer=tb_writer, tokenizer=tokenizer, metric=metric, log_every_n_steps=20)\n",
1022
+ " compute_metrics = create_compute_metrics(metrics_callback)\n",
1023
+ "\n",
1024
+ " torch.backends.cuda.matmul.allow_tf32 = True\n",
1025
+ " torch.backends.cudnn.allow_tf32 = True\n",
1026
+ " torch.cuda.empty_cache()\n",
1027
+ " torch.cuda.set_device(device=0)\n",
1028
+ "\n",
1029
+ " training_args = Seq2SeqTrainingArguments(\n",
1030
+ " output_dir=\"./test\", \n",
1031
+ " per_device_train_batch_size=1,\n",
1032
+ " per_device_eval_batch_size=1,\n",
1033
+ " gradient_accumulation_steps=1,\n",
1034
+ " eval_accumulation_steps=1,\n",
1035
+ " tf32=True,\n",
1036
+ " bf16=True,\n",
1037
+ " learning_rate=1e-5,\n",
1038
+ " warmup_steps=500,\n",
1039
+ " evaluation_strategy=\"steps\",\n",
1040
+ " max_steps=10000,\n",
1041
+ " save_steps=1000,\n",
1042
+ " eval_steps=50,\n",
1043
+ " logging_steps=5,\n",
1044
+ " report_to=[\"tensorboard\"],\n",
1045
+ " load_best_model_at_end=True,\n",
1046
+ " metric_for_best_model=\"loss\",\n",
1047
+ " greater_is_better=False,\n",
1048
+ " push_to_hub=False,\n",
1049
+ " optim=\"adafactor\",\n",
1050
+ " weight_decay=0.0025,\n",
1051
+ " disable_tqdm=False,\n",
1052
+ " save_total_limit=2,\n",
1053
+ " torch_empty_cache_steps=10,\n",
1054
+ " )\n",
1055
+ "\n",
1056
+ " trainer = Seq2SeqTrainer(\n",
1057
+ " args=training_args,\n",
1058
+ " model=model,\n",
1059
+ " train_dataset=train,\n",
1060
+ " eval_dataset=test,\n",
1061
+ " data_collator=data_collator,\n",
1062
+ " compute_metrics=compute_metrics,\n",
1063
+ " tokenizer=feature_extractor,\n",
1064
+ " callbacks=[metrics_callback]\n",
1065
+ " ) \n",
1066
+ "\n",
1067
+ " trainer.train(resume_from_checkpoint=False)\n",
1068
+ "\n",
1069
+ "\n",
1070
+ "if __name__==\"__main__\":\n",
1071
+ "\n",
1072
+ " train()\n",
1073
+ " import tensorboard\n",
1074
+ "\n",
1075
+ " # model.save_pretrained(\"./models/echo_train\")"
1076
+ ]
1077
+ },
1078
+ {
1079
+ "cell_type": "code",
1080
+ "execution_count": null,
1081
+ "metadata": {},
1082
+ "outputs": [],
1083
+ "source": [
1084
+ "# torch.backends.cudnn.benchmark = True\n",
1085
+ "# torch.autograd.set_detect_anomaly(False)\n",
1086
+ "# torch.autograd.profiler.profile(False)\n",
1087
+ "# torch.autograd.profiler.emit_nvtx(False)\n",
1088
+ "# torch.autograd.profiler.record_function_enter_exit(False)\n",
1089
+ "# torch.autograd.gradcheck(False)\n",
1090
+ "# torch.autograd.gradgradcheck(False)\n",
1091
+ "# torch.autograd.set_grad_enabled(True)\n",
1092
+ "# torch.autograd.detect_anomaly(False)\n",
1093
+ "# torch.autograd.profiler.profile(False)\n",
1094
+ "# torch.autograd.profiler.emit_nvtx(False)\n",
1095
+ "# torch.autograd.profiler.record_function_enter_exit(False)\n",
1096
+ "# torch.autograd.gradcheck(False)\n",
1097
+ "# torch.autograd.gradgradcheck(False)\n",
1098
+ "# torch.autograd.set_grad_enabled(True)\n",
1099
+ "# torch.autograd.detect_anomaly(False)\n",
1100
+ "# torch.autograd.profiler.profile(False)\n",
1101
+ "# torch.autograd.profiler.emit_nvtx(False)\n",
1102
+ "# torch.autograd.profiler.record_function_enter_exit(False)\n",
1103
+ "# torch.autograd.gradcheck(False)\n",
1104
+ "# torch.autograd.gradgradcheck(False)\n",
1105
+ "# torch.autograd.set_grad_enabled(True)\n",
1106
+ "# torch.autograd.detect_anomaly(False)\n",
1107
+ "# torch.autograd.profiler.profile(False)\n",
1108
+ "# torch.autograd.profiler.emit_nvtx(False)\n",
1109
+ "# torch.autograd.profiler.record_function_enter_exit(False)\n",
1110
+ "# torch.autograd.gradcheck(False)\n",
1111
+ "# torch.autograd.gradgradcheck(False)\n",
1112
+ "# torch.autograd.set_grad_enabled(True)\n",
1113
+ "# torch.autograd.detect_anomaly(False)\n",
1114
+ "# torch.autograd.profiler.profile(False)\n",
1115
+ "# torch.autograd.profiler.emit_nvtx(False)\n",
1116
+ "# torch.autograd.profiler.record_function_enter_exit(False)\n",
1117
+ "# torch.autograd.gradcheck(False)\n",
1118
+ "# torch.autograd.gradgradcheck(False)\n",
1119
+ "# torch.autograd.set_grad_enabled(True)\n",
1120
+ "# torch.autograd.detect_anomaly(False)\n",
1121
+ "# torch.autograd.profiler.profile(False)\n",
1122
+ "# torch.autograd.profiler.emit_nvtx(False)\n",
1123
+ "# torch.autograd.profiler.record_function_enter_exit(False)\n",
1124
+ "# torch.autograd.gradcheck(False)\n",
1125
+ "# torch.autograd.gradgradcheck(False)\n",
1126
+ "# torch.autograd.set_grad_enabled(True)\n",
1127
+ "# torch.autograd.detect_anomaly(False)\n",
1128
+ "# torch.autograd.profiler.profile(False)\n",
1129
+ "# torch.autograd.profiler.emit_nvtx(False)\n",
1130
+ "# torch.autograd.profiler.record_function_enter_exit(False)\n",
1131
+ "# torch.autograd.gradcheck(False)\n",
1132
+ "# torch.autograd.gradgradcheck(False)\n",
1133
+ "# torch.autograd.set_grad_enabled(True)\n",
1134
+ "# torch.autograd.detect_anomaly(False)\n",
1135
+ "# torch.autograd.profiler.profile(False)\n",
1136
+ "# torch.autograd.profiler.emit_nvtx(False)\n",
1137
+ "# torch.autograd.profiler.record_function_enter_exit(False)\n",
1138
+ "# torch.autograd.gradcheck(False)\n",
1139
+ "# torch.autograd.gradgradcheck(False)\n",
1140
+ "# torch.autograd.set_grad_enabled(True)\n",
1141
+ "# torch.autograd.detect_anomaly(False)"
1142
+ ]
1143
+ }
1144
+ ],
1145
+ "metadata": {
1146
+ "kernelspec": {
1147
+ "display_name": "Python 3",
1148
+ "language": "python",
1149
+ "name": "python3"
1150
+ },
1151
+ "language_info": {
1152
+ "codemirror_mode": {
1153
+ "name": "ipython",
1154
+ "version": 3
1155
+ },
1156
+ "file_extension": ".py",
1157
+ "mimetype": "text/x-python",
1158
+ "name": "python",
1159
+ "nbconvert_exporter": "python",
1160
+ "pygments_lexer": "ipython3",
1161
+ "version": "3.10.0"
1162
+ }
1163
+ },
1164
+ "nbformat": 4,
1165
+ "nbformat_minor": 2
1166
+ }