Sin2pi committed · verified
Commit 5e4bd82 · 1 Parent(s): badbec0

Delete echopipeline.py

Files changed (1)
  1. echopipeline.py +0 -662
echopipeline.py DELETED
@@ -1,662 +0,0 @@
import pyworld as pw
import os
import math
import logging
import torch
import torchaudio
import torch.nn.functional as F
import numpy as np
from typing import Optional, Dict, Union, List, Tuple, Any
from functools import partial
from datetime import datetime
from datasets import load_dataset, Audio, concatenate_datasets
from transformers.trainer_seq2seq import Seq2SeqTrainer
from transformers.training_args_seq2seq import Seq2SeqTrainingArguments
import evaluate
from dataclasses import dataclass

# Module-level placeholders; Echo (the model class) and the related components below
# are expected to be assigned elsewhere before create_model() is called.
extractor = None
tokenizer = None
optimizer = None
scheduler = None
model = None
Residual = None
MultiheadA = None
Echo = None

metric = evaluate.load(path="wer")

@dataclass
class Dimensions:
    vocab: int
    text_ctx: int
    text_dims: int
    text_head: int
    text_idx: int
    mels: int
    aud_ctx: int
    aud_dims: int
    aud_head: int
    aud_idx: int
    act: str
    debug: List[str]
    cross_attn: bool
    features: List[str]
    f0_rotary: bool

def align_f0(f0, ctx):
    # Nearest-neighbour resampling of a (batch, length) f0 track to a fixed context length.
    bat, length = f0.shape
    if length == ctx:
        return f0
    frames = length / ctx
    idx = torch.arange(ctx, device=f0.device)
    idx = (idx * frames).long()
    batch_idx = torch.arange(bat, device=f0.device).unsqueeze(1)
    return f0[batch_idx, idx.unsqueeze(0).expand(bat, -1)]

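# Usage sketch (hypothetical shapes):
#   f0 = torch.rand(2, 300)            # (batch, frames)
#   f0_aligned = align_f0(f0, 448)     # -> (2, 448)
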
@dataclass
class DataCollator:
    tokenizer: Any
    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        # Use the collator's own tokenizer rather than the module-level global.
        pad_token_id = self.tokenizer.pad_token_id if hasattr(self.tokenizer, 'pad_token_id') else 0
        bos_token_id = self.tokenizer.bos_token_id if hasattr(self.tokenizer, 'bos_token_id') else 1

        batch = {}

        if "spectrogram" in features[0] and features[0]["spectrogram"] is not None:
            spectrogram_list = [f["spectrogram"] for f in features]
            max_len_feat = max(f.shape[-1] for f in spectrogram_list)
            pad_spectrogram = []
            for feat in spectrogram_list:
                current_len = feat.shape[-1]
                padding = max_len_feat - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_spectrogram.append(pad_feat)
            batch["spectrogram"] = torch.stack(pad_spectrogram)

        if "waveform" in features[0] and features[0]["waveform"] is not None:
            waveform_list = [f["waveform"] for f in features]
            max_len_wav = max(w.shape[-1] for w in waveform_list)
            pad_waveforms = []
            for wav in waveform_list:
                current_len = wav.shape[-1]
                padding = max_len_wav - current_len
                if padding > 0:
                    if wav.ndim == 1:
                        wav = wav.unsqueeze(0)
                    pad_wav = F.pad(wav, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_wav = wav
                pad_waveforms.append(pad_wav)
            batch["waveform"] = torch.stack(pad_waveforms)

        if "label" in features[0] and features[0]["label"] is not None:
            labels_list = [f["label"] for f in features]
            max_len = max(len(l) for l in labels_list)
            all_ids = []
            all_labels = []

            for label in labels_list:
                label_list = label.tolist() if isinstance(label, torch.Tensor) else label
                decoder_input = [bos_token_id] + label_list
                label_eos = label_list + [pad_token_id]
                input_len = max_len + 1 - len(decoder_input)
                label_len = max_len + 1 - len(label_eos)
                padded_input = decoder_input + [pad_token_id] * input_len
                padded_labels = label_eos + [pad_token_id] * label_len
                all_ids.append(padded_input)
                all_labels.append(padded_labels)
            batch["input_ids"] = torch.tensor(all_ids, dtype=torch.long)
            batch["labels"] = torch.tensor(all_labels, dtype=torch.long)

        if "pitch" in features[0] and features[0]["pitch"] is not None:
            pitch_list = [f["pitch"] for f in features]
            max_len_pitch = max(e.shape[-1] for e in pitch_list)
            pad_pitch = []
            for pitch in pitch_list:
                current_len = pitch.shape[-1]
                padding = max_len_pitch - current_len
                if padding > 0:
                    pad_pitch_item = F.pad(pitch, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_pitch_item = pitch
                pad_pitch.append(pad_pitch_item)
            batch["pitch"] = torch.stack(pad_pitch)

        if "f0" in features[0] and features[0]["f0"] is not None:
            input_ids_batch = batch.get("input_ids", None)
            if input_ids_batch is not None:
                target_length = input_ids_batch.shape[-1]
                aligned_list = []
                original_list = []
                for feature in features:
                    f0 = feature["f0"]
                    original_list.append(f0)
                    if f0.shape[-1] != target_length:
                        aligned_f0 = align_f0(f0.unsqueeze(0), target_length).squeeze(0)
                    else:
                        aligned_f0 = f0
                    aligned_list.append(aligned_f0)
                batch["f0d"] = torch.stack(aligned_list)
                batch["f0"] = torch.stack(original_list)

        if "envelope" in features[0] and features[0]["envelope"] is not None:
            env_list = [f["envelope"] for f in features]
            max_len = max(f.shape[-1] for f in env_list)
            pad_env = []
            for feat in env_list:
                current_len = feat.shape[-1]
                padding = max_len - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_env.append(pad_feat)
            batch["envelope"] = torch.stack(pad_env)

        if "phase" in features[0] and features[0]["phase"] is not None:
            ph_list = [f["phase"] for f in features]
            max_len = max(f.shape[-1] for f in ph_list)
            pad_ph = []
            for feat in ph_list:
                current_len = feat.shape[-1]
                padding = max_len - current_len
                if padding > 0:
                    pad_feat = F.pad(feat, (0, padding), mode='constant', value=pad_token_id)
                else:
                    pad_feat = feat
                pad_ph.append(pad_feat)
            batch["phase"] = torch.stack(pad_ph)
        return batch

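# Minimal usage sketch (assuming the wrapped tokenizer from setup_tokenizer below):
#   collator = DataCollator(tokenizer=tokenizer)
#   batch = collator([{"spectrogram": torch.randn(128, 300), "label": [5, 9, 7]}])
# batch["spectrogram"] is stacked to (1, 128, 300); "input_ids" is the label sequence
# prefixed with the BOS id, and "labels" is the same sequence terminated with the pad id.
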
def hilbert_transform(x):
    N = x.shape[-1]
    xf = torch.fft.rfft(x)
    h = torch.zeros(N // 2 + 1, device=x.device, dtype=x.dtype)
    if N % 2 == 0:
        h[0] = h[N//2] = 1
        h[1:N//2] = 2
    else:
        h[0] = 1
        h[1:(N+1)//2] = 2
    return torch.fft.irfft(xf * h, n=N)

def analytic_signal(x):
    return x + 1j * hilbert_transform(x)

def hilbert_transform_2d(x, dim=-1):
    N = x.shape[dim]
    if dim == -1 or dim == len(x.shape) - 1:
        xf = torch.fft.rfft(x)
    else:
        xf = torch.fft.rfft(x, dim=dim)
    h_shape = [1] * len(x.shape)
    h_shape[dim] = N // 2 + 1
    h = torch.zeros(h_shape, device=x.device, dtype=x.dtype)
    if dim == -1 or dim == len(x.shape) - 1:
        if N % 2 == 0:
            h[..., 0] = h[..., -1] = 1
            h[..., 1:-1] = 2
        else:
            h[..., 0] = 1
            h[..., 1:] = 2
    else:
        pass
    return torch.fft.irfft(xf * h, n=N, dim=dim)

def hilbert_transform_true_2d(x):
    xf = torch.fft.rfft2(x)
    h1, h2 = torch.meshgrid(
        torch.fft.rfftfreq(x.shape[-2]) * 2 - 1,
        torch.fft.rfftfreq(x.shape[-1]) * 2 - 1,
        indexing='ij')
    h = -1j / (math.pi * (h1 + 1j*h2))
    h[0, 0] = 0
    return torch.fft.irfft2(xf * h.to(x.device))

def process_spectrogram_with_hilbert(spec):
    analytic = spec + 1j * hilbert_transform(spec)
    envelope = torch.abs(analytic)
    phase = torch.angle(analytic)
    return envelope, phase

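# Example: Hilbert envelope and instantaneous phase of one mel bin over time
# (hilbert_transform operates along the last dimension):
#   row = torch.randn(300)
#   envelope, phase = process_spectrogram_with_hilbert(row)   # both shaped (300,)
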
def load_wave(wave_data, sample_rate):
    if isinstance(wave_data, str):
        waveform, sr = torchaudio.load(uri=wave_data, normalize=False)
    elif isinstance(wave_data, dict):
        waveform = torch.tensor(data=wave_data["array"]).float()
        sr = wave_data["sampling_rate"]
    else:
        raise TypeError("Invalid wave_data format.")

    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=sample_rate)
        waveform = resampler(waveform)

    return waveform.flatten()

def extract_features(batch, tokenizer, spectrogram, waveforms, pitch, frequency=False,
        hop_length=128, fmin=0, fmax=8000, n_mels=128, n_fft=1024, sampling_rate=16000,
        pad_mode="constant", center=True, power=2.0, window_fn=torch.hann_window, mel_scale="htk",
        norm=None, normalized=False, downsamples=False, period=False, hilbert=False):

    audio = batch["audio"]
    sr = audio["sampling_rate"]
    wav = load_wave(wave_data=audio, sample_rate=sr)

    if spectrogram:
        transform = torchaudio.transforms.MelSpectrogram(
            f_max=fmax,
            f_min=fmin,
            n_mels=n_mels,
            sample_rate=sr,
            n_fft=n_fft,
            hop_length=hop_length,
            norm=norm,
            normalized=normalized,
            power=power,
            center=center,
            mel_scale=mel_scale,
            window_fn=window_fn,
            pad_mode=pad_mode)

        mel_spectrogram = transform(wav)
        log_mel = torch.clamp(mel_spectrogram, min=1e-10).log10()
        log_mel = torch.maximum(log_mel, log_mel.max() - 8.0)
        spec = (log_mel + 4.0) / 4.0
        batch["spectrogram"] = spec

        if hilbert:
            envelope_list = []
            phase_list = []

            for ch_idx in range(spec.shape[0]):
                envelope, phase = process_spectrogram_with_hilbert(spec[ch_idx])
                envelope_list.append(envelope)
                phase_list.append(phase)

            batch["envelope"] = torch.stack(envelope_list)
            batch["phase"] = torch.stack(phase_list)

    wav_1d = wav.unsqueeze(0)

    if waveforms:
        batch["waveform"] = wav_1d

    if pitch:
        wav_np = wav.numpy().astype(np.float64)
        f0, t = pw.dio(wav_np, sr, frame_period=hop_length / sr * 1000)
        f0 = pw.stonemask(wav_np, f0, t, sr)
        f0 = torch.from_numpy(f0).float()
        batch["pitch"] = f0.unsqueeze(0)

    if frequency:
        wav_np = wav.numpy().astype(np.float64)
        f0, t = pw.dio(wav_np, sr, frame_period=hop_length / sr * 1000)
        f0 = pw.stonemask(wav_np, f0, t, sr)
        batch["f0"] = torch.from_numpy(f0).float()

    if spectrogram and waveforms and pitch:
        spec_mean = batch["spectrogram"].mean()
        spec_std = batch["spectrogram"].std() + 1e-6
        batch["spectrogram"] = (batch["spectrogram"] - spec_mean) / spec_std

        wav_mean = batch["waveform"].mean()
        wav_std = batch["waveform"].std() + 1e-6
        batch["waveform"] = (batch["waveform"] - wav_mean) / wav_std

        if batch["pitch"].max() > 1.0:
            pitch_min = 50.0
            pitch_max = 600.0
            batch["pitch"] = (batch["pitch"] - pitch_min) / (pitch_max - pitch_min)

    batch["label"] = tokenizer.encode(batch["transcription"], add_special_tokens=False)
    return batch

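# Note on the log-mel scaling in extract_features: clamping to (max - 8.0) keeps an
# 80 dB dynamic range of the power spectrogram, and (log_mel + 4.0) / 4.0 then maps
# that range to roughly [-1, 1], the same normalization used by Whisper's front end.
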
def compute_metrics(eval_pred, compute_result: bool = True,
        print_pred: bool = False, num_samples: int = 0, tokenizer=None, pitch=None, model=None):

    pred_logits = eval_pred.predictions
    label_ids = eval_pred.label_ids

    if hasattr(pred_logits, "cpu"):
        pred_logits = pred_logits.cpu()
    if hasattr(label_ids, "cpu"):
        label_ids = label_ids.cpu()
    if isinstance(pred_logits, tuple):
        pred_ids = pred_logits[0]
    else:
        pred_ids = pred_logits
    if hasattr(pred_ids, "ndim") and pred_ids.ndim == 3:
        if not isinstance(pred_ids, torch.Tensor):
            pred_ids = torch.tensor(pred_ids)
        pred_ids = pred_ids.argmax(dim=-1)
    pred_ids = pred_ids.tolist()

    if hasattr(label_ids, "tolist"):
        label_ids = label_ids.tolist()

    label_ids = [[0 if token == -100 else token for token in seq] for seq in label_ids]
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=False)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=False)

    if print_pred:
        for i in range(min(num_samples, len(pred_str))):
            print(f"Preds: {pred_str[i]}")
            print(f"Label: {label_str[i]}")
            print(f"preds: {pred_ids[i]}")
            print(f"label: {label_ids[i]}")
            print("--------------------------------")

    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    if model is None:
        global global_model
        if 'global_model' in globals():
            model = global_model

    if model is not None:
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1_000_000
        if trainable_params > 0:
            efficiency_score = (100 - wer) / trainable_params
        else:
            print("Warning: Zero trainable parameters detected")
            efficiency_score = 0.0
    else:
        print("Warning: Model not available for parameter counting")
        trainable_params = 0.0
        efficiency_score = 0.0

    if hasattr(wer, "item"):
        wer = wer.item()

    metrics = {
        "wer": float(wer),
        "trainable_params_M": float(trainable_params),
        "efficiency_score": float(efficiency_score),
    }

    return metrics

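# The reported "efficiency_score" is (100 - WER) divided by trainable parameters in
# millions; e.g. 20% WER on a 40 M-parameter model gives (100 - 20) / 40 = 2.0.
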
logger = logging.getLogger(__name__)

def create_model(param: Dimensions) -> Echo:
    model = Echo(param).to('cuda')
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Trainable parameters: {trainable_params:,}")
    logger.info(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print(f"Total parameters: {total_params:,}")

    return model

def setup_tokenizer(token: str, local_tokenizer_path: str = "D:/newmodel/model/tokenn/"):
    from tokenizers import Tokenizer
    tokenizer = Tokenizer.from_file(f"{local_tokenizer_path}/tokenizer.json")
    orig_encode = tokenizer.encode
    def enc(text, add_special_tokens=True):
        ids = orig_encode(text).ids
        if not add_special_tokens:
            sp_ids = [tokenizer.token_to_id(t) for t in ["<PAD>", "<BOS>", "<EOS>"]]
            ids = [id for id in ids if id not in sp_ids]
        return ids
    def bdec(ids_list, skip_special_tokens=True):
        results = []
        for ids in ids_list:
            if skip_special_tokens:
                ids = [id for id in ids if id not in [0, 1, 2]]
            results.append(tokenizer.decode(ids))
        return results
    def save_pretrained(save_dir):
        os.makedirs(save_dir, exist_ok=True)
        tokenizer.save(f"{save_dir}/tokenizer.json")
    tokenizer.encode = enc
    tokenizer.batch_decode = bdec
    tokenizer.save_pretrained = save_pretrained
    tokenizer.pad_token_id = 0
    tokenizer.bos_token_id = 1
    tokenizer.eos_token_id = 2
    return tokenizer

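# Usage sketch: the returned object exposes the small tokenizer surface this pipeline
# relies on (encode, batch_decode, save_pretrained, and pad/bos/eos ids):
#   tokenizer = setup_tokenizer(token="")
#   ids = tokenizer.encode("hello world", add_special_tokens=False)
#   text = tokenizer.batch_decode([ids])[0]
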
def prepare_datasets(tokenizer, token: str, sanity_check: bool = False, dataset_config: Optional[Dict] = None) -> Tuple[Any, Any]:
    if dataset_config is None:
        dataset_config = {
            "spectrogram": True,
            "waveforms": True,
            "pitch": True,
            "frequency": True,
            "downsamples": True,
            "hop_length": 128,
            "fmin": 50,
            "fmax": 2000,
            "n_mels": 128,
            "n_fft": 1024,
            "sampling_rate": 16000,
        }

    dataset = load_dataset(
        "google/fleurs",
        "en_us",
        token=token,
        trust_remote_code=True,
        streaming=False)

    dataset = dataset.cast_column(column="audio", feature=Audio(sampling_rate=16000)).select_columns(["audio", "transcription"])

    if sanity_check:
        dataset = dataset["test"].take(10)
        dataset = dataset.select_columns(["audio", "transcription"])
        logger.info(f"Sanity dataset size: {dataset.num_rows}")
        print(f"Sanity dataset size: {dataset.num_rows}")
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)

        dataset = dataset.map(
            function=prepare_fn,
            remove_columns=["audio", "transcription"]
        ).with_format(type="torch")
        train_dataset = dataset
        test_dataset = dataset
    else:
        def filter_func(x):
            return (0 < len(x["transcription"]) < 512 and
                    len(x["audio"]["array"]) > 0 and
                    len(x["audio"]["array"]) < 1500 * 160)

        dataset = dataset.filter(filter_func).shuffle(seed=4)
        logger.info(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
        print(f"Dataset size: {dataset['train'].num_rows}, {dataset['test'].num_rows}")
        prepare_fn = partial(extract_features, tokenizer=tokenizer, **dataset_config)
        columns_to_remove = list(next(iter(dataset.values())).features)
        train_dataset = dataset["train"]
        test_dataset = dataset["test"].take(50)
        logger.info(f"Train dataset size: {train_dataset.num_rows}, Test dataset size: {test_dataset.num_rows}")

        train_dataset = train_dataset.map(
            function=prepare_fn,
            remove_columns=columns_to_remove
        ).with_format(type="torch")

        test_dataset = test_dataset.map(
            function=prepare_fn,
            remove_columns=columns_to_remove
        ).with_format(type="torch")

    return train_dataset, test_dataset

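# The length filter above keeps clips under 1500 * 160 = 240,000 samples, i.e. 15 s at
# 16 kHz (presumably sized against the model's 1500-frame audio context, assuming a
# 160-sample hop).
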
def get_training_args(
    log_dir: str,
    batch_eval_metrics: bool = False,
    max_steps: int = 10,
    save_steps: int = 1000,
    eval_steps: int = 1,
    warmup_steps: int = 0,
    num_train_epochs: int = 1,
    logging_steps: int = 1,
    eval_on_start: bool = False,
    learning_rate: float = 1e-4,
    weight_decay: float = 0.01,
    max_grad_norm: float = 1.0,
    ) -> Seq2SeqTrainingArguments:

    return Seq2SeqTrainingArguments(
        output_dir=log_dir,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=1,
        eval_accumulation_steps=1,
        tf32=True,
        bf16=True,
        eval_strategy="steps",
        save_strategy="steps",
        max_steps=max_steps,
        save_steps=save_steps,
        eval_steps=eval_steps,
        warmup_steps=warmup_steps,
        num_train_epochs=num_train_epochs,
        logging_steps=logging_steps,
        logging_dir=log_dir,
        logging_strategy="steps",
        report_to=["tensorboard"],
        push_to_hub=False,
        disable_tqdm=False,
        save_total_limit=1,
        label_names=["labels"],
        optim="adamw_torch",
        lr_scheduler_type="cosine",
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        save_safetensors=False,
        eval_on_start=eval_on_start,
        batch_eval_metrics=batch_eval_metrics,
        max_grad_norm=max_grad_norm,
    )

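# Note: tf32=True and bf16=True assume an Ampere-or-newer NVIDIA GPU; on older hardware
# drop these flags or fall back to fp16.
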
def main():

    token = ""
    log_dir = os.path.join('./output/logs', datetime.now().strftime(format='%m-%d_%H'))
    os.makedirs(name=log_dir, exist_ok=True)
    tokenizer = setup_tokenizer(token)

    def sanity(sanity: bool):

        if sanity:
            training_args = get_training_args(
                log_dir,
                batch_eval_metrics = False,
                max_steps = 10,
                save_steps = 0,
                eval_steps = 1,
                warmup_steps = 0,
                logging_steps = 1,
                eval_on_start = False,
                learning_rate = 5e-6,
                weight_decay = 0.01,
                )
        else:
            training_args = get_training_args(
                log_dir,
                batch_eval_metrics = False,
                max_steps = 1000,
                save_steps = 1000,
                eval_steps = 100,
                warmup_steps = 100,
                logging_steps = 10,
                eval_on_start = False,
                learning_rate = 2.5e-4,
                weight_decay = 0.01,
                )

        return training_args

    param = Dimensions(
        mels=128,
        aud_ctx=1500,
        aud_head=4,
        aud_dims=512,
        aud_idx=4,
        vocab=40000,
        text_ctx=512,
        text_head=4,
        text_dims=512,
        text_idx=4,
        act="swish",
        debug={},  # e.g. {"encoder", "decoder", "residual", "rotary"}
        cross_attn=True,
        f0_rotary=False,
        features=["spectrogram"],  # also supported: "waveform", "pitch", "f0", "envelope", "phase"
        )

    sanity_check = False
    training_args = sanity(sanity_check)
    dataset_config = {
        "spectrogram": True,
        "waveforms": False,
        "pitch": False,
        "downsamples": False,
        "frequency": True,
        "hilbert": False,
        "hop_length": 128,
        "fmin": 150,
        "fmax": 2000,
        "n_mels": 128,
        "n_fft": 1024,
        "sampling_rate": 16000,
        "pad_mode": "constant",
        "center": True,
        "power": 2.0,
        "window_fn": torch.hann_window,
        "mel_scale": "htk",
        "norm": None,
        "normalized": False}

    model = create_model(param)

    global global_model
    global_model = model

    metrics_fn = partial(compute_metrics, print_pred=False, num_samples=5,
                    tokenizer=tokenizer, model=model)

    print(f"{'Sanity check' if sanity_check else 'Training'} mode")
    train_dataset, test_dataset = prepare_datasets(
        tokenizer=tokenizer,
        token=token,
        sanity_check=sanity_check,
        dataset_config=dataset_config)

    trainer = Seq2SeqTrainer(
        args=training_args,
        model=model,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        data_collator=DataCollator(tokenizer=tokenizer),
        compute_metrics=metrics_fn,
    )

    model.init_weights()
    trainer.train()

if __name__ == "__main__":
    main()