Delete echopipeline.py
echopipeline.py  DELETED  +0 -662
@@ -1,662 +0,0 @@
import pyworld as pw
import os
import math
import logging
import torch
import torchaudio
import torch.nn.functional as F
import numpy as np
from typing import Optional, Dict, Union, List, Tuple, Any
from functools import partial
from datetime import datetime
from datasets import load_dataset, Audio, concatenate_datasets
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
import evaluate
from dataclasses import dataclass

extractor = None
tokenizer = None
optimizer = None
scheduler = None
model = None
Residual = None
MultiheadA = None
Echo = None

metric = evaluate.load(path="wer")
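
# The None placeholders above appear to be module-level slots filled in from
# the model-definition module (Echo, Residual and MultiheadA are referenced
# below but never defined here). A minimal wiring sketch, assuming a sibling
# module named `model` that exports these symbols (the module name is an
# assumption, not from this file):
#
#   import model as model_module
#   Echo = model_module.Echo
#   Residual = model_module.Residual
#   MultiheadA = model_module.MultiheadA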

@dataclass
class Dimensions:
    vocab: int
    text_ctx: int
    text_dims: int
    text_head: int
    text_idx: int
    mels: int
    aud_ctx: int
    aud_dims: int
    aud_head: int
    aud_idx: int
    act: str
    debug: List[str]
    cross_attn: bool
    features: List[str]
    f0_rotary: bool

def align_f0(f0, ctx):
    # Resample an F0 track to `ctx` frames by nearest-index selection.
    bat, length = f0.shape
    if length == ctx:
        return f0
    frames = length / ctx
    idx = torch.arange(ctx, device=f0.device)
    idx = (idx * frames).long()
    batch_idx = torch.arange(bat, device=f0.device).unsqueeze(1)
    return f0[batch_idx, idx.unsqueeze(0).expand(bat, -1)]
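
# Quick illustration of align_f0's nearest-index resampling (not from the
# original file): a 6-frame track squeezed to 3 frames keeps frames 0, 2, 4.
#
#   f0 = torch.arange(6.0).unsqueeze(0)   # shape (1, 6)
#   align_f0(f0, 3)                       # tensor([[0., 2., 4.]])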

@dataclass
class DataCollator:
    tokenizer: Any

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        pad_token_id = self.tokenizer.pad_token_id if hasattr(self.tokenizer, 'pad_token_id') else 0
        bos_token_id = self.tokenizer.bos_token_id if hasattr(self.tokenizer, 'bos_token_id') else 1

        batch = {}

        if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
            spectrogram_list = [f["spectrogram"] for f in features]
            max_len_feat = max(f.shape[-1] for f in spectrogram_list)
            pad_spectrogram = []
            for feat in spectrogram_list:
                current_len = feat.shape[-1]
                padding = max_len_feat - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_spectrogram.append(pad_feat)
            batch["spectrogram"] = torch.stack(pad_spectrogram)

        if "waveform" in features[0] and features[0]["waveform"] is not None:
            waveform_list = [f["waveform"] for f in features]
            max_len_wav = max(w.shape[-1] for w in waveform_list)
            pad_waveforms = []
            for wav in waveform_list:
                current_len = wav.shape[-1]
                padding = max_len_wav - current_len
                if padding > 0:
                    if wav.ndim == 1:
                        wav = wav.unsqueeze(0)
                    pad_wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_wav = wav
                pad_waveforms.append(pad_wav)
            batch["waveform"] = torch.stack(pad_waveforms)

        if "label" in features[0] and features[0]["label"] is not None:
            labels_list = [f["label"] for f in features]
            max_len = max(len(l) for l in labels_list)
            all_ids = []
            all_labels = []

            for label in labels_list:
                label_list = label.tolist() if isinstance(label, torch.Tensor) else label
                # Shift right for teacher forcing: decoder input starts with BOS,
                # labels end with the pad id (which doubles as EOS here).
                decoder_input = [bos_token_id] + label_list
                label_eos = label_list + [pad_token_id]
                input_len = max_len + 1 - len(decoder_input)
                label_len = max_len + 1 - len(label_eos)
                padded_input = decoder_input + [pad_token_id] * input_len
                padded_labels = label_eos + [pad_token_id] * label_len
                all_ids.append(padded_input)
                all_labels.append(padded_labels)
            batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
            batch["labels"] = torch.tensor(all_labels, dtype=torch.long)

        if "pitch" in features[0] and features[0]["pitch"] is not None:
            pitch_list = [f["pitch"] for f in features]
            max_len_pitch = max(e.shape[-1] for e in pitch_list)
            pad_pitch = []
            for pitch in pitch_list:
                current_len = pitch.shape[-1]
                padding = max_len_pitch - current_len
                if padding > 0:
                    pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_pitch_item = pitch
                pad_pitch.append(pad_pitch_item)
            batch["pitch"] = torch.stack(pad_pitch)

        if "f0" in features[0] and features[0]["f0"] is not None:
            input_ids_batch = batch.get("input_ids", None)
            if input_ids_batch is not None:
                target_length = input_ids_batch.shape[-1]
                aligned_list = []
                original_list = []
                for feature in features:
                    f0 = feature["f0"]
                    original_list.append(f0)
                    if f0.shape[-1] != target_length:
                        aligned_f0 = align_f0(f0.unsqueeze(0), target_length).squeeze(0)
                    else:
                        aligned_f0 = f0
                    aligned_list.append(aligned_f0)
                batch["f0d"] = torch.stack(aligned_list)
                # Stacking the unaligned tracks assumes equal lengths within the
                # batch (batch size is 1 in the training setup below).
                batch["f0"] = torch.stack(original_list)

        if "envelope" in features[0] and features[0]["envelope"] is not None:
            env_list = [f["envelope"] for f in features]
            max_len = max(f.shape[-1] for f in env_list)
            pad_env = []
            for feat in env_list:
                current_len = feat.shape[-1]
                padding = max_len - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_env.append(pad_feat)
            batch["envelope"] = torch.stack(pad_env)

        if "phase" in features[0] and features[0]["phase"] is not None:
            ph_list = [f["phase"] for f in features]
            max_len = max(f.shape[-1] for f in ph_list)
            pad_ph = []
            for feat in ph_list:
                current_len = feat.shape[-1]
                padding = max_len - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_ph.append(pad_feat)
            batch["phase"] = torch.stack(pad_ph)
        return batch
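
# Minimal collation sketch (illustrative, not part of the original file): two
# ragged examples are padded to a common length and stacked; the fallback ids
# (pad=0, bos=1) apply when the tokenizer lacks them.
#
#   collator = DataCollator(tokenizer=None)
#   out = collator([
#       {"spectrogram": torch.zeros(128, 90),  "label": [5, 6, 7]},
#       {"spectrogram": torch.zeros(128, 120), "label": [8, 9]},
#   ])
#   out["spectrogram"].shape   # torch.Size([2, 128, 120])
#   out["input_ids"]           # tensor([[1, 5, 6, 7], [1, 8, 9, 0]])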

def hilbert_transform(x):
    # FFT-based Hilbert transform along the last axis: double the positive
    # frequencies, zero the negative ones, and take the imaginary part of the
    # resulting analytic signal.
    N = x.shape[-1]
    xf = torch.fft.fft(x)
    h = torch.zeros(N, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h[0] = h[N // 2] = 1
        h[1:N // 2] = 2
    else:
        h[0] = 1
        h[1:(N + 1) // 2] = 2
    return torch.fft.ifft(xf * h).imag

def analytic_signal(x):
    return x + 1j * hilbert_transform(x)

def hilbert_transform_2d(x, dim=-1):
    # Same construction as hilbert_transform, but along an arbitrary axis.
    N = x.shape[dim]
    xf = torch.fft.fft(x, dim=dim)
    h = torch.zeros(N, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h[0] = h[N // 2] = 1
        h[1:N // 2] = 2
    else:
        h[0] = 1
        h[1:(N + 1) // 2] = 2
    shape = [1] * x.ndim
    shape[dim] = N
    return torch.fft.ifft(xf * h.view(shape), dim=dim).imag

def hilbert_transform_true_2d(x):
    xf = torch.fft.rfft2(x)
    # Frequency grids matching the rfft2 output: full spectrum along dim -2,
    # half spectrum along dim -1.
    h1, h2 = torch.meshgrid(
        torch.fft.fftfreq(x.shape[-2]) * 2 - 1,
        torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
        indexing='ij')
    h = -1j / (math.pi * (h1 + 1j * h2))
    h[0, 0] = 0
    return torch.fft.irfft2(xf * h.to(x.device), s=x.shape[-2:])

def process_spectrogram_with_hilbert(spec):
    analytic = spec + 1j * hilbert_transform(spec)
    envelope = torch.abs(analytic)
    phase = torch.angle(analytic)
    return envelope, phase
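
# Sanity-check sketch (illustrative): for a pure tone the analytic envelope
# should sit near 1 away from the edges.
#
#   t = torch.linspace(0, 1, 1000)
#   x = torch.cos(2 * math.pi * 50 * t)
#   env, _ = process_spectrogram_with_hilbert(x)
#   env[100:-100].mean()   # ~1.0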

def load_wave(wave_data, sample_rate):
    if isinstance(wave_data, str):
        waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
        # normalize=False keeps the native dtype (e.g. int16 PCM); cast so
        # resampling and downstream math operate on float.
        waveform = waveform.float()
    elif isinstance(wave_data, dict):
        waveform = torch.tensor(data=wave_data["array"]).float()
        sr = wave_data["sampling_rate"]
    else:
        raise TypeError("Invalid wave_data format.")

    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)

    return waveform.flatten()

def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
        hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
        pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
        norm=None, normalized=False, downsamples=False, period=False, hilbert=False):

    audio = batch["audio"]
    sr = audio["sampling_rate"]
    sampling_rate = sr  # prefer the example's native rate over the default
    wav = load_wave(wave_data=audio, sample_rate=sr)

    if spectrogram:
        transform = torchaudio.transforms.MelSpectrogram(
            f_max=fmax,
            f_min=fmin,
            n_mels=n_mels,
            sample_rate=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            norm=norm,
            normalized=normalized,
            power=power,
            center=center,
            mel_scale=mel_scale,
            window_fn=window_fn,
            pad_mode=pad_mode)

        mel_spectrogram = transform(wav)
        # Whisper-style log-mel compression: clamp, floor at 8 dB below the
        # peak, then rescale to roughly [-1, 1].
        log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        spec = (log_mel + 4.0) / 4.0
        batch["spectrogram"] = spec

        if hilbert:
            envelope_list = []
            phase_list = []

            for ch_idx in range(spec.shape[0]):
                envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx])
                envelope_list.append(envelope)
                phase_list.append(phase)

            batch["envelope"] = torch.stack(envelope_list)
            batch["phase"] = torch.stack(phase_list)

    wav_1d = wav.unsqueeze(0)

    if waveforms:
        batch["waveform"] = wav_1d

    if pitch:
        # Coarse F0 via DIO, refined by StoneMask (pyworld).
        wav_np = wav.numpy().astype(np.float64)
        f0, t = pw.dio(wav_np, sampling_rate,
                       frame_period=hop_length / sampling_rate * 1000)
        f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
        f0 = torch.from_numpy(f0).float()
        batch["pitch"] = f0.unsqueeze(0)

    if frequency:
        wav_np = wav.numpy().astype(np.float64)
        f0, t = pw.dio(wav_np, sampling_rate,
                       frame_period=hop_length / sampling_rate * 1000)
        f0 = pw.stonemask(wav_np, f0, t, sampling_rate)
        batch["f0"] = torch.from_numpy(f0).float()

    if spectrogram and waveforms and pitch:
        spec_mean = batch["spectrogram"].mean()
        spec_std = batch["spectrogram"].std() + 1e-6
        batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std

        wav_mean = batch["waveform"].mean()
        wav_std = batch["waveform"].std() + 1e-6
        batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std

        if batch["pitch"].max() > 1.0:
            # Map raw Hz into [0, 1] over a nominal 50-600 Hz speech range.
            pitch_min = 50.0
            pitch_max = 600.0
            batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)

    batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
    return batch
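
# Shape sketch (illustrative, assumes the patched tokenizer from
# setup_tokenizer below): with the defaults above, one second of 16 kHz audio
# yields a (128, 126) log-mel spectrogram (floor(16000 / 128) + 1 frames with
# center=True).
#
#   example = {"audio": {"array": np.random.randn(16000),
#                        "sampling_rate": 16000},
#              "transcription": "hello world"}
#   out = extract_features(example, tokenizer, spectrogram=True,
#                          waveforms=False, pitch=False)
#   out["spectrogram"].shape   # torch.Size([128, 126])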

def compute_metrics(eval_pred, compute_result: bool = True,
        print_pred: bool = False, num_samples: int = 0, tokenizer=None, pitch=None, model=None):

    pred_logits = eval_pred.predictions
    label_ids = eval_pred.label_ids

    if hasattr(pred_logits, "cpu"):
        pred_logits = pred_logits.cpu()
    if hasattr(label_ids, "cpu"):
        label_ids = label_ids.cpu()
    if isinstance(pred_logits, tuple):
        pred_ids = pred_logits[0]
    else:
        pred_ids = pred_logits
    if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
        if not isinstance(pred_ids, torch.Tensor):
            pred_ids = torch.tensor(pred_ids)
        pred_ids = pred_ids.argmax(dim=-1)
    pred_ids = pred_ids.tolist()

    if hasattr(label_ids, "tolist"):
        label_ids = label_ids.tolist()

    # Replace the -100 ignore index with the pad id before decoding.
    label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids]

    if print_pred:
        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
        label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)
        for i in range(min(num_samples, len(pred_str))):
            print(f"Preds: {pred_str[i]}")
            print(f"Label: {label_str[i]}")
            print(f"preds: {pred_ids[i]}")
            print(f"label: {label_ids[i]}")
            print("--------------------------------")

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    if model is None:
        global global_model
        if 'global_model' in globals():
            model = global_model

    if model is not None:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
        if trainable_params > 0:
            efficiency_score = (100 - wer) / trainable_params
        else:
            print("Warning: Zero trainable parameters detected")
            efficiency_score = 0.0
    else:
        print("Warning: Model not available for parameter counting")
        trainable_params = 0.0
        efficiency_score = 0.0

    if hasattr(wer, "item"):
        wer = wer.item()

    metrics = {
        "wer": float(wer),
        "trainable_params_M": float(trainable_params),
        "efficiency_score": float(efficiency_score),
    }

    return metrics
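
# The efficiency score trades accuracy off against model size: for example, a
# model with 36 M trainable parameters scoring 20 WER gets (100 - 20) / 36
# ~= 2.2, while a 300 M model at the same WER gets ~= 0.27.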

logger = logging.getLogger(__name__)

def create_model(param: Dimensions) -> Echo:
    model = Echo(param).to('cuda')
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Total parameters: {total_params:,}")

    return model

def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"):
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
    orig_encode = tokenizer.encode

    def enc(text, add_special_tokens=True):
        ids = orig_encode(text).ids
        if not add_special_tokens:
            sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
            ids = [i for i in ids if i not in sp_ids]
        return ids

    def bdec(ids_list, skip_special_tokens=True):
        results = []
        for ids in ids_list:
            if skip_special_tokens:
                ids = [i for i in ids if i not in [0, 1, 2]]
            results.append(tokenizer.decode(ids))
        return results

    def save_pretrained(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(f"{save_dir}/tokenizer.json")

    # Monkey-patch a minimal HF-tokenizer-like surface onto the raw Tokenizer.
    tokenizer.encode = enc
    tokenizer.batch_decode = bdec
    tokenizer.save_pretrained = save_pretrained
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    return tokenizer
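
# Round-trip sketch (illustrative; the exact ids depend on the tokenizer.json
# at the hard-coded local path):
#
#   tok = setup_tokenizer(token="")
#   ids = tok.encode("hello world", add_special_tokens=False)
#   tok.batch_decode([ids])   # ["hello world"]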

def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[Any, Any]:
    if dataset_config is None:
        dataset_config = {
            "spectrogram": True,
            "waveforms": True,
            "pitch": True,
            "frequency": True,
            "downsamples": True,
            "hop_length": 128,
            "fmin": 50,
            "fmax": 2000,
            "n_mels": 128,
            "n_fft": 1024,
            "sampling_rate": 16000,
        }

    dataset = load_dataset(
        "google/fleurs",
        "en_us",
        token=token,
        trust_remote_code=True,
        streaming=False)

    dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])

    if sanity_check:
        dataset = dataset["test"].take(10)
        dataset = dataset.select_columns(["audio", "transcription"])
        logger.info(f"Sanity dataset size: {dataset.num_rows}")
        print(f"Sanity dataset size: {dataset.num_rows}")
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)

        dataset = dataset.map(
            function=prepare_fn,
            remove_columns=["audio", "transcription"]
        ).with_format(type="torch")
        train_dataset = dataset
        test_dataset = dataset
    else:
        def filter_func(x):
            # Keep non-empty transcripts under 512 characters and clips under
            # 1500 * 160 samples (15 s at 16 kHz).
            return (0 < len(x["transcription"]) < 512 and
                    len(x["audio"]["array"]) > 0 and
                    len(x["audio"]["array"]) < 1500 * 160)

        dataset = dataset.filter(filter_func).shuffle(seed=4)
        logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
        print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
        columns_to_remove = list(next(iter(dataset.values())).features)
        train_dataset = dataset["train"]
        test_dataset = dataset["test"].take(50)
        logger.info(f"Train dataset size: {train_dataset.num_rows}, Test dataset size: {test_dataset.num_rows}")

        train_dataset = train_dataset.map(
            function=prepare_fn,
            remove_columns=columns_to_remove
        ).with_format(type="torch")

        test_dataset = test_dataset.map(
            function=prepare_fn,
            remove_columns=columns_to_remove
        ).with_format(type="torch")

    return train_dataset, test_dataset

def get_training_args(
    log_dir: str,
    batch_eval_metrics: bool = False,
    max_steps: int = 10,
    save_steps: int = 1000,
    eval_steps: int = 1,
    warmup_steps: int = 0,
    num_train_epochs: int = 1,
    logging_steps: int = 1,
    eval_on_start: bool = False,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
    ) -> Seq2SeqTrainingArguments:

    return Seq2SeqTrainingArguments(
        output_dir=log_dir,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=1,
        eval_accumulation_steps=1,
        tf32=True,
        bf16=True,
        eval_strategy="steps",
        save_strategy="steps",
        max_steps=max_steps,
        save_steps=save_steps,
        eval_steps=eval_steps,
        warmup_steps=warmup_steps,
        num_train_epochs=num_train_epochs,
        logging_steps=logging_steps,
        logging_dir=log_dir,
        logging_strategy="steps",
        report_to=["tensorboard"],
        push_to_hub=False,
        disable_tqdm=False,
        save_total_limit=1,
        label_names=["labels"],
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        save_safetensors=False,
        eval_on_start=eval_on_start,
        batch_eval_metrics=batch_eval_metrics,
        max_grad_norm=max_grad_norm,
    )

def main():

    token = ""
    log_dir = os.path.join('./output/logs', datetime.now().strftime(format='%m-%d_%H'))
    os.makedirs(name=log_dir, exist_ok=True)
    tokenizer = setup_tokenizer(token)

    def sanity(sanity: bool):

        if sanity:
            training_args = get_training_args(
                log_dir,
                batch_eval_metrics=False,
                max_steps=10,
                save_steps=0,
                eval_steps=1,
                warmup_steps=0,
                logging_steps=1,
                eval_on_start=False,
                learning_rate=5e-6,
                weight_decay=0.01,
            )
        else:
            training_args = get_training_args(
                log_dir,
                batch_eval_metrics=False,
                max_steps=1000,
                save_steps=1000,
                eval_steps=100,
                warmup_steps=100,
                logging_steps=10,
                eval_on_start=False,
                learning_rate=2.5e-4,
                weight_decay=0.01,
            )

        return training_args

    param = Dimensions(
        mels=128,
        aud_ctx=1500,
        aud_head=4,
        aud_dims=512,
        aud_idx=4,
        vocab=40000,
        text_ctx=512,
        text_head=4,
        text_dims=512,
        text_idx=4,
        act="swish",
        debug=[],  # e.g. ["encoder", "decoder", "residual", "rotary"]
        cross_attn=True,
        f0_rotary=False,
        features=["spectrogram"],  # plus optionally "waveform", "pitch", "f0", "envelope", "phase"
    )

    sanity_check = False
    training_args = sanity(sanity_check)
    dataset_config = {
        "spectrogram": True,
        "waveforms": False,
        "pitch": False,
        "downsamples": False,
        "frequency": True,
        "hilbert": False,
        "hop_length": 128,
        "fmin": 150,
        "fmax": 2000,
        "n_mels": 128,
        "n_fft": 1024,
        "sampling_rate": 16000,
        "pad_mode": "constant",
        "center": True,
        "power": 2.0,
        "window_fn": torch.hann_window,
        "mel_scale": "htk",
        "norm": None,
        "normalized": False}

    model = create_model(param)

    global global_model
    global_model = model

    metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
                         tokenizer=tokenizer, model=model)

    print(f"{'Sanity check' if sanity_check else 'Training'} mode")
    train_dataset, test_dataset = prepare_datasets(
        tokenizer=tokenizer,
        token=token,
        sanity_check=sanity_check,
        dataset_config=dataset_config)

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=DataCollator(tokenizer=tokenizer),
        compute_metrics=metrics_fn,
    )

    model.init_weights()
    trainer.train()

if __name__ == "__main__":
    main()