wasmdashai committed
Commit 60026f7 · verified · 1 Parent(s): d9452a6

Delete VitsModelSplit/Trainer.py

Files changed (1)
  1. VitsModelSplit/Trainer.py +0 -848
VitsModelSplit/Trainer.py DELETED
@@ -1,848 +0,0 @@
import os
import shutil
import tempfile
import numpy as np
from transformers import VitsModel
import math
import torch
from accelerate.utils import ProjectConfiguration, is_wandb_available, set_seed
from accelerate import Accelerator, DistributedDataParallelKwargs
from transformers.utils import send_example_telemetry
import logging
import sys
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.trainer_pt_utils import LengthGroupedSampler
from transformers.optimization import get_scheduler


from .data_collator import DataCollatorTTSWithPadding
from .discriminator import VitsDiscriminator
from .feature_extraction import VitsFeatureExtractor
from .plot import plot_alignment_to_numpy, plot_spectrogram_to_numpy

#.............................................

if is_wandb_available():
    import wandb

ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
logger = logging.getLogger(__name__)
#.............................................

def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    real_losses = 0
    generated_losses = 0
    for disc_real, disc_generated in zip(disc_real_outputs, disc_generated_outputs):
        real_loss = torch.mean((1 - disc_real) ** 2)
        generated_loss = torch.mean(disc_generated**2)
        loss += real_loss + generated_loss
        real_losses += real_loss
        generated_losses += generated_loss

    return loss, real_losses, generated_losses


def feature_loss(feature_maps_real, feature_maps_generated):
    loss = 0
    for feature_map_real, feature_map_generated in zip(feature_maps_real, feature_maps_generated):
        for real, generated in zip(feature_map_real, feature_map_generated):
            real = real.detach()
            loss += torch.mean(torch.abs(real - generated))

    return loss * 2


def generator_loss(disc_outputs):
    total_loss = 0
    gen_losses = []
    for disc_output in disc_outputs:
        loss = torch.mean((1 - disc_output) ** 2)
        gen_losses.append(loss)
        total_loss += loss

    return total_loss, gen_losses


def kl_loss(prior_latents, posterior_log_variance, prior_means, prior_log_variance, labels_mask):
    """
    prior_latents (z_p), posterior_log_variance (logs_q): [b, h, t_t]
    prior_means, prior_log_variance: [b, h, t_t]
    """

    kl = prior_log_variance - posterior_log_variance - 0.5
    kl += 0.5 * ((prior_latents - prior_means) ** 2) * torch.exp(-2.0 * prior_log_variance)
    kl = torch.sum(kl * labels_mask)
    loss = kl / torch.sum(labels_mask)
    return loss


def log_on_trackers(
    trackers,
    generated_audio,
    generated_attn,
    generated_spec,
    target_spec,
    full_generation_waveform,
    epoch,
    sampling_rate,
):
    max_num_samples = min(len(generated_audio), 50)
    generated_audio = generated_audio[:max_num_samples]
    generated_attn = generated_attn[:max_num_samples]
    generated_spec = generated_spec[:max_num_samples]
    target_spec = target_spec[:max_num_samples]

    for tracker in trackers:
        if tracker.name == "tensorboard":
            for cpt, audio in enumerate(generated_audio):
                tracker.writer.add_audio(f"train_step_audio_{cpt}", audio[None, :], epoch, sample_rate=sampling_rate)

            for cpt, audio in enumerate(full_generation_waveform):
                tracker.writer.add_audio(
                    f"full_generation_sample{cpt}", audio[None, :], epoch, sample_rate=sampling_rate
                )

            tracker.writer.add_images("alignments", np.stack(generated_attn), dataformats="NHWC")
            tracker.writer.add_images("spectrogram", np.stack(generated_spec), dataformats="NHWC")
            tracker.writer.add_images("target spectrogram", np.stack(target_spec), dataformats="NHWC")
        elif tracker.name == "wandb":
            # wandb can only log 100 audio clips per step
            tracker.log(
                {
                    "alignments": [wandb.Image(attn, caption=f"Audio epoch {epoch}") for attn in generated_attn],
                    "spectrogram": [wandb.Image(spec, caption=f"Audio epoch {epoch}") for spec in generated_spec],
                    "target spectrogram": [wandb.Image(spec, caption=f"Audio epoch {epoch}") for spec in target_spec],
                    "train generated audio": [
                        wandb.Audio(
                            audio[0],
                            caption=f"Audio during train step epoch {epoch}",
                            sample_rate=sampling_rate,
                        )
                        for audio in generated_audio
                    ],
                    "full generations samples": [
                        wandb.Audio(w, caption=f"Full generation sample {epoch}", sample_rate=sampling_rate)
                        for w in full_generation_waveform
                    ],
                }
            )
        else:
            logger.warning(f"audio logging not implemented for {tracker.name}")


def compute_val_metrics_and_losses(
    val_losses,
    accelerator,
    model_outputs,
    mel_scaled_generation,
    mel_scaled_target,
    batch_size,
    compute_clap_similarity=False,
):
    loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation)
    loss_kl = kl_loss(
        model_outputs.prior_latents,
        model_outputs.posterior_log_variances,
        model_outputs.prior_means,
        model_outputs.prior_log_variances,
        model_outputs.labels_padding_mask,
    )

    losses_mel_kl = loss_mel + loss_kl

    losses = torch.stack([loss_mel, loss_kl, losses_mel_kl])
    losses = accelerator.gather(losses.repeat(batch_size, 1)).mean(0)

    for key, loss in zip(["val_loss_mel", "val_loss_kl", "val_loss_mel_kl"], losses):
        val_losses[key] = val_losses.get(key, 0) + loss.item()

    return val_losses


#.............................................

def vits_trainin(
    model,
    tokenizer,
    model_args,
    data_args,
    training_args,
    train_dataset,
    eval_dataset,
):

    send_example_telemetry("run_vits_finetuning", model_args, data_args)

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    # datasets.utils.logging.set_verbosity(log_level)
    # transformers.utils.logging.set_verbosity(log_level)
    # transformers.utils.logging.enable_default_handler()
    # transformers.utils.logging.enable_explicit_format()
    # logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN)
    # if is_main_process(training_args.local_rank):
    #     transformers.utils.logging.set_verbosity_info()

    set_seed(training_args.seed)

    config = model.config
    feature_extractor = VitsFeatureExtractor()

    forward_attention_mask = True

    with training_args.main_process_first(desc="apply_weight_norm"):
        # apply weight norms
        model.decoder.apply_weight_norm()
        for flow in model.flow.flows:
            torch.nn.utils.weight_norm(flow.conv_pre)
            torch.nn.utils.weight_norm(flow.conv_post)

    with training_args.main_process_first():
        # only the main process saves them
        if is_main_process(training_args.local_rank):
            # save feature extractor, tokenizer and config
            feature_extractor.save_pretrained(training_args.output_dir)
            tokenizer.save_pretrained(training_args.output_dir)
            config.save_pretrained(training_args.output_dir)

    data_collator = DataCollatorTTSWithPadding(
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        forward_attention_mask=forward_attention_mask,
    )

    with training_args.main_process_first():
        input_str = data_args.full_generation_sample_text
        full_generation_sample = tokenizer(input_str, return_tensors="pt")

    project_name = data_args.project_name
    logging_dir = os.path.join(training_args.output_dir, training_args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=training_args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        log_with=training_args.report_to,
        project_config=accelerator_project_config,
        kwargs_handlers=[ddp_kwargs],
    )

    per_device_train_batch_size = (
        training_args.per_device_train_batch_size if training_args.per_device_train_batch_size else 1
    )
    total_batch_size = (
        per_device_train_batch_size * accelerator.num_processes * training_args.gradient_accumulation_steps
    )

    num_speakers = model.config.num_speakers
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    train_dataloader = None
    if training_args.do_train:
        sampler = (
            LengthGroupedSampler(
                batch_size=per_device_train_batch_size,
                dataset=train_dataset,
                lengths=train_dataset["tokens_input_length"],
            )
            if training_args.group_by_length
            else None
        )
        train_dataloader = torch.utils.data.DataLoader(
            train_dataset,
            shuffle=False,  # not training_args.group_by_length
            collate_fn=data_collator,
            batch_size=training_args.per_device_train_batch_size,
            num_workers=training_args.dataloader_num_workers,
            sampler=sampler,
        )

    eval_dataloader = None
    if training_args.do_eval:
        eval_sampler = (
            LengthGroupedSampler(
                batch_size=training_args.per_device_eval_batch_size,
                dataset=eval_dataset,
                lengths=eval_dataset["tokens_input_length"],
            )
            if training_args.group_by_length
            else None
        )

        eval_dataloader = torch.utils.data.DataLoader(
            eval_dataset,
            shuffle=False,
            collate_fn=data_collator,
            batch_size=training_args.per_device_eval_batch_size,
            num_workers=training_args.dataloader_num_workers,
            sampler=eval_sampler,
        )

    model_segment_size = model.segment_size
    config_segment_size = model.config.segment_size
    sampling_rate = model.config.sampling_rate

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
    if training_args.max_steps == -1:
        training_args.max_steps = training_args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        training_args.max_steps = int(training_args.num_train_epochs * num_update_steps_per_epoch)
    # Afterwards we recalculate our number of training epochs
    training_args.num_train_epochs = math.ceil(training_args.max_steps / num_update_steps_per_epoch)

    # hack to be able to train on multiple devices
    with tempfile.TemporaryDirectory() as tmpdirname:
        model.discriminator.save_pretrained(tmpdirname)
        discriminator = VitsDiscriminator.from_pretrained(tmpdirname)
        for disc in discriminator.discriminators:
            disc.apply_weight_norm()
        del model.discriminator

    # init gen_optimizer, gen_lr_scheduler, disc_optimizer, disc_lr_scheduler
    gen_optimizer = torch.optim.AdamW(
        model.parameters(),
        training_args.learning_rate,
        betas=[training_args.adam_beta1, training_args.adam_beta2],
        eps=training_args.adam_epsilon,
    )

    disc_optimizer = torch.optim.AdamW(
        discriminator.parameters(),
        training_args.learning_rate,
        betas=[training_args.adam_beta1, training_args.adam_beta2],
        eps=training_args.adam_epsilon,
    )

    num_warmups_steps = training_args.get_warmup_steps(training_args.num_train_epochs * accelerator.num_processes)
    num_training_steps = training_args.num_train_epochs * accelerator.num_processes

    gen_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        gen_optimizer, gamma=training_args.lr_decay, last_epoch=-1
    )
    disc_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
        disc_optimizer, gamma=training_args.lr_decay, last_epoch=-1
    )

    # Prepare everything with our `accelerator`.
    (
        model,
        discriminator,
        gen_optimizer,
        gen_lr_scheduler,
        disc_optimizer,
        disc_lr_scheduler,
        train_dataloader,
        eval_dataloader,
    ) = accelerator.prepare(
        model,
        discriminator,
        gen_optimizer,
        gen_lr_scheduler,
        disc_optimizer,
        disc_lr_scheduler,
        train_dataloader,
        eval_dataloader,
    )

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = training_args.to_sanitized_dict()
        accelerator.init_trackers(project_name, tracker_config)

    # Train!
    logger.info("***** Running training *****")
    logger.info(f" Num examples = {len(train_dataset)}")
    logger.info(f" Num Epochs = {training_args.num_train_epochs}")
    logger.info(f" Instantaneous batch size per device = {per_device_train_batch_size}")
    logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f" Gradient Accumulation steps = {training_args.gradient_accumulation_steps}")
    logger.info(f" Total optimization steps = {training_args.max_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if training_args.resume_from_checkpoint:
        if training_args.resume_from_checkpoint != "latest":
            path = os.path.basename(training_args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(training_args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None

        if path is None:
            accelerator.print(
                f"Checkpoint '{training_args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            training_args.resume_from_checkpoint = None
            initial_global_step = 0
        else:
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(os.path.join(training_args.output_dir, path))
            global_step = int(path.split("-")[1])

            initial_global_step = global_step
            first_epoch = global_step // num_update_steps_per_epoch

    else:
        initial_global_step = 0

    #.......................loop training............................

    for epoch in range(first_epoch, training_args.num_train_epochs):
        # keep track of train losses
        train_losses = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

        disc_lr_scheduler.step()
        gen_lr_scheduler.step()

        for step, batch in enumerate(train_dataloader):
            print(
                f"TRAINING - batch {step}, process {accelerator.process_index}, waveform {(batch['waveform'].shape)}, tokens {(batch['input_ids'].shape)}... "
            )
            with accelerator.accumulate(model, discriminator):
                # forward through model
                model_outputs = model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    labels=batch["labels"],
                    labels_attention_mask=batch["labels_attention_mask"],
                    speaker_id=batch["speaker_id"],
                    encoder_output=batch["text_encoder_output"],
                    return_dict=True,
                    monotonic_alignment_function=None,
                )

                mel_scaled_labels = batch["mel_scaled_input_features"]
                mel_scaled_target = model.slice_segments(mel_scaled_labels, model_outputs.ids_slice, model_segment_size)
                mel_scaled_generation = feature_extractor._torch_extract_fbank_features(
                    model_outputs.waveform.squeeze(1)
                )[1]

                target_waveform = batch["waveform"].transpose(1, 2)
                target_waveform = model.slice_segments(
                    target_waveform, model_outputs.ids_slice * feature_extractor.hop_length, config_segment_size
                )

                # -----------------------
                #  Train Discriminator
                # -----------------------

                discriminator_target, _ = discriminator(target_waveform)
                discriminator_candidate, _ = discriminator(model_outputs.waveform.detach())

                loss_disc, loss_real_disc, loss_fake_disc = discriminator_loss(
                    discriminator_target, discriminator_candidate
                )

                # backpropagate
                accelerator.backward(loss_disc * training_args.weight_disc)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(discriminator.parameters(), training_args.max_grad_norm)
                disc_optimizer.step()
                if not training_args.do_step_schedule_per_epoch:
                    disc_lr_scheduler.step()
                disc_optimizer.zero_grad()

                # -----------------------
                #  Train Generator
                # -----------------------

                _, fmaps_target = discriminator(target_waveform)
                discriminator_candidate, fmaps_candidate = discriminator(model_outputs.waveform)

                loss_duration = torch.sum(model_outputs.log_duration)
                loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation)
                loss_kl = kl_loss(
                    model_outputs.prior_latents,
                    model_outputs.posterior_log_variances,
                    model_outputs.prior_means,
                    model_outputs.prior_log_variances,
                    model_outputs.labels_padding_mask,
                )
                loss_fmaps = feature_loss(fmaps_target, fmaps_candidate)
                loss_gen, losses_gen = generator_loss(discriminator_candidate)

                total_generator_loss = (
                    loss_duration * training_args.weight_duration
                    + loss_mel * training_args.weight_mel
                    + loss_kl * training_args.weight_kl
                    + loss_fmaps * training_args.weight_fmaps
                    + loss_gen * training_args.weight_gen
                )

                # backpropagate
                accelerator.backward(total_generator_loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
                gen_optimizer.step()
                if not training_args.do_step_schedule_per_epoch:
                    gen_lr_scheduler.step()
                gen_optimizer.zero_grad()

                # update and gather losses
                losses = torch.stack(
                    [
                        # for fair comparison, don't use weighted loss
                        loss_duration + loss_mel + loss_kl + loss_fmaps + loss_gen,
                        loss_duration,
                        loss_mel,
                        loss_kl,
                        loss_fmaps,
                        loss_gen,
                        loss_disc,
                        loss_real_disc,
                        loss_fake_disc,
                    ]
                )
                losses = accelerator.gather(losses.repeat(per_device_train_batch_size, 1)).mean(0)

                train_losses = [
                    l + losses[i].item() / training_args.gradient_accumulation_steps
                    for (i, l) in enumerate(train_losses)
                ]

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                (
                    train_summed_losses,
                    train_loss_duration,
                    train_loss_mel,
                    train_loss_kl,
                    train_loss_fmaps,
                    train_loss_gen,
                    train_loss_disc,
                    train_loss_real_disc,
                    train_loss_fake_disc,
                ) = train_losses

                global_step += 1
                accelerator.log(
                    {
                        "train_summed_losses": train_summed_losses,
                        "train_loss_disc": train_loss_disc,
                        "train_loss_real_disc": train_loss_real_disc,
                        "train_loss_fake_disc": train_loss_fake_disc,
                        "train_loss_duration": train_loss_duration,
                        "train_loss_mel": train_loss_mel,
                        "train_loss_kl": train_loss_kl,
                        "train_loss_fmaps": train_loss_fmaps,
                        "train_loss_gen": train_loss_gen,
                        "lr": disc_lr_scheduler.get_last_lr()[0],
                    },
                    step=global_step,
                )
                train_losses = [0.0 for _ in train_losses]

                if global_step % training_args.save_steps == 0:
                    if accelerator.is_main_process:
                        # _before_ saving state, check if this save would set us over the `save_total_limit`
                        if training_args.save_total_limit is not None:
                            checkpoints = os.listdir(training_args.output_dir)
                            checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                            checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                            # before we save the new checkpoint, we need to have at _most_ `save_total_limit - 1` checkpoints
                            if len(checkpoints) >= training_args.save_total_limit:
                                num_to_remove = len(checkpoints) - training_args.save_total_limit + 1
                                removing_checkpoints = checkpoints[0:num_to_remove]

                                logger.info(
                                    f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                                )
                                logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                                for removing_checkpoint in removing_checkpoints:
                                    removing_checkpoint = os.path.join(training_args.output_dir, removing_checkpoint)
                                    shutil.rmtree(removing_checkpoint)

                        save_path = os.path.join(training_args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {
                "step_loss": total_generator_loss.detach().item(),
                "lr": disc_lr_scheduler.get_last_lr()[0],
                "step_loss_duration": loss_duration.detach().item(),
                "step_loss_mel": loss_mel.detach().item(),
                "step_loss_kl": loss_kl.detach().item(),
                "step_loss_fmaps": loss_fmaps.detach().item(),
                "step_loss_gen": loss_gen.detach().item(),
                "step_loss_disc": loss_disc.detach().item(),
                "step_loss_real_disc": loss_real_disc.detach().item(),
                "step_loss_fake_disc": loss_fake_disc.detach().item(),
            }

            if global_step >= training_args.max_steps:
                break

            eval_steps = training_args.eval_steps if training_args.eval_steps else 1
            do_eval = training_args.do_eval and (global_step % eval_steps == 0) and accelerator.sync_gradients

            if do_eval:
                logger.info("Running validation... ")
                generated_audio = []
                generated_attn = []
                generated_spec = []
                target_spec = []
                val_losses = {}
                for step, batch in enumerate(eval_dataloader):
                    print(
                        f"VALIDATION - batch {step}, process {accelerator.process_index}, waveform {(batch['waveform'].shape)}, tokens {(batch['input_ids'].shape)}... "
                    )
                    with torch.no_grad():
                        model_outputs_train = model(
                            input_ids=batch["input_ids"],
                            attention_mask=batch["attention_mask"],
                            labels=batch["labels"],
                            labels_attention_mask=batch["labels_attention_mask"],
                            speaker_id=batch["speaker_id"],
                            encoder_output=batch["text_encoder_output"],
                            return_dict=True,
                            monotonic_alignment_function=None,
                        )

                    mel_scaled_labels = batch["mel_scaled_input_features"]
                    mel_scaled_target = model.slice_segments(
                        mel_scaled_labels, model_outputs_train.ids_slice, model_segment_size
                    )
                    mel_scaled_generation = feature_extractor._torch_extract_fbank_features(
                        model_outputs_train.waveform.squeeze(1)
                    )[1]

                    val_losses = compute_val_metrics_and_losses(
                        val_losses,
                        accelerator,
                        model_outputs_train,
                        mel_scaled_generation,
                        mel_scaled_target,
                        per_device_train_batch_size,
                        compute_clap_similarity=False,
                    )

                    print(f"VALIDATION - batch {step}, process {accelerator.process_index}, PADDING AND GATHER... ")
                    specs = feature_extractor._torch_extract_fbank_features(model_outputs_train.waveform.squeeze(1))[0]
                    padded_attn, specs, target_specs = accelerator.pad_across_processes(
                        [model_outputs_train.attn.squeeze(1), specs, batch["labels"]], dim=1
                    )
                    padded_attn, specs, target_specs = accelerator.pad_across_processes(
                        [padded_attn, specs, target_specs], dim=2
                    )

                    generated_train_waveform, padded_attn, specs, target_specs = accelerator.gather_for_metrics(
                        [model_outputs_train.waveform, padded_attn, specs, target_specs]
                    )

                    if accelerator.is_main_process:
                        with torch.no_grad():
                            speaker_id = None if num_speakers < 2 else list(range(min(5, num_speakers)))
                            full_generation = model(**full_generation_sample.to(model.device), speaker_id=speaker_id)

                        generated_audio.append(generated_train_waveform.cpu())
                        generated_attn.append(padded_attn.cpu())
                        generated_spec.append(specs.cpu())
                        target_spec.append(target_specs.cpu())

                logger.info("Validation inference done, now evaluating... ")
                if accelerator.is_main_process:
                    generated_audio = [audio.numpy() for audio_batch in generated_audio for audio in audio_batch]
                    generated_attn = [
                        plot_alignment_to_numpy(attn.numpy()) for attn_batch in generated_attn for attn in attn_batch
                    ]
                    generated_spec = [
                        plot_spectrogram_to_numpy(spec.numpy()) for spec_batch in generated_spec for spec in spec_batch
                    ]
                    target_spec = [
                        plot_spectrogram_to_numpy(spec.numpy()) for spec_batch in target_spec for spec in spec_batch
                    ]
                    full_generation_waveform = full_generation.waveform.cpu().numpy()

                    accelerator.log(val_losses, step=global_step)

                    log_on_trackers(
                        accelerator.trackers,
                        generated_audio,
                        generated_attn,
                        generated_spec,
                        target_spec,
                        full_generation_waveform,
                        epoch,
                        sampling_rate,
                    )

                logger.info("Validation finished... ")

            accelerator.wait_for_everyone()

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        epoch = training_args.num_train_epochs if training_args.num_train_epochs else 1
        eval_steps = training_args.eval_steps if training_args.eval_steps else 1

        # Run a final round of inference.
        do_eval = training_args.do_eval

        if do_eval:
            logger.info("Running final validation... ")
            generated_audio = []
            generated_attn = []
            generated_spec = []
            target_spec = []
            val_losses = {}
            for step, batch in enumerate(eval_dataloader):
                print(
                    f"VALIDATION - batch {step}, process {accelerator.process_index}, waveform {(batch['waveform'].shape)}, tokens {(batch['input_ids'].shape)}... "
                )
                with torch.no_grad():
                    model_outputs_train = model(
                        input_ids=batch["input_ids"],
                        attention_mask=batch["attention_mask"],
                        labels=batch["labels"],
                        labels_attention_mask=batch["labels_attention_mask"],
                        speaker_id=batch["speaker_id"],
                        encoder_output=batch["text_encoder_output"],
                        return_dict=True,
                        monotonic_alignment_function=None,
                    )

                mel_scaled_labels = batch["mel_scaled_input_features"]
                mel_scaled_target = model.slice_segments(
                    mel_scaled_labels, model_outputs_train.ids_slice, model_segment_size
                )
                mel_scaled_generation = feature_extractor._torch_extract_fbank_features(
                    model_outputs_train.waveform.squeeze(1)
                )[1]

                val_losses = compute_val_metrics_and_losses(
                    val_losses,
                    accelerator,
                    model_outputs_train,
                    mel_scaled_generation,
                    mel_scaled_target,
                    per_device_train_batch_size,
                    compute_clap_similarity=False,
                )
                specs = feature_extractor._torch_extract_fbank_features(model_outputs_train.waveform.squeeze(1))[0]
                padded_attn, specs, target_specs = accelerator.pad_across_processes(
                    [model_outputs_train.attn.squeeze(1), specs, batch["labels"]], dim=1
                )
                padded_attn, specs, target_specs = accelerator.pad_across_processes(
                    [padded_attn, specs, target_specs], dim=2
                )

                generated_train_waveform, padded_attn, specs, target_specs = accelerator.gather_for_metrics(
                    [model_outputs_train.waveform, padded_attn, specs, target_specs]
                )

                if accelerator.is_main_process:
                    with torch.no_grad():
                        speaker_id = None if num_speakers < 2 else list(range(min(5, num_speakers)))
                        full_generation = model(**full_generation_sample.to(model.device), speaker_id=speaker_id)

                    generated_audio.append(generated_train_waveform.cpu())
                    generated_attn.append(padded_attn.cpu())
                    generated_spec.append(specs.cpu())
                    target_spec.append(target_specs.cpu())

            logger.info("Validation inference done, now evaluating... ")
            if accelerator.is_main_process:
                generated_audio = [audio.numpy() for audio_batch in generated_audio for audio in audio_batch]
                generated_attn = [
                    plot_alignment_to_numpy(attn.numpy()) for attn_batch in generated_attn for attn in attn_batch
                ]
                generated_spec = [
                    plot_spectrogram_to_numpy(spec.numpy()) for spec_batch in generated_spec for spec in spec_batch
                ]
                target_spec = [
                    plot_spectrogram_to_numpy(spec.numpy()) for spec_batch in target_spec for spec in spec_batch
                ]
                full_generation_waveform = full_generation.waveform.cpu().numpy()

                log_on_trackers(
                    accelerator.trackers,
                    generated_audio,
                    generated_attn,
                    generated_spec,
                    target_spec,
                    full_generation_waveform,
                    epoch,
                    sampling_rate,
                )

            accelerator.log(val_losses, step=global_step)
            logger.info("Validation finished... ")

    accelerator.wait_for_everyone()

    # unwrap, save and push final model
    model = accelerator.unwrap_model(model)
    discriminator = accelerator.unwrap_model(discriminator)

    model.discriminator = discriminator

    # remove weight norms
    for disc in model.discriminator.discriminators:
        disc.remove_weight_norm()
    model.decoder.remove_weight_norm()
    for flow in model.flow.flows:
        torch.nn.utils.remove_weight_norm(flow.conv_pre)
        torch.nn.utils.remove_weight_norm(flow.conv_post)

    model.save_pretrained(training_args.output_dir)

    if training_args.push_to_hub:
        VitsModel.from_pretrained(training_args.output_dir).push_to_hub(training_args.hub_model_id)

    accelerator.end_training()

    logger.info("***** Training / Inference Done *****")


#...............................................................................
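
For context, the deleted `vits_trainin` entry point expected an already-loaded split VITS generator (exposing `discriminator`, `flow`, `decoder` and `segment_size`), a tokenizer, three argument objects, and preprocessed train/eval datasets. The sketch below illustrates how it could be wired up; the argument dataclasses, module paths and the `load_prepared_datasets` helper are assumptions for illustration only, not part of the deleted file or a verified API of this repository.

# Hypothetical usage sketch (not from the deleted file): wiring up vits_trainin.
from transformers import AutoTokenizer, HfArgumentParser

# Assumed module paths within this repo; adjust to the actual package layout.
from VitsModelSplit.vits_model import VitsModel          # split VITS generator with a discriminator attached (assumed)
from VitsModelSplit.Trainer import vits_trainin          # the entry point deleted in this commit
from VitsModelSplit.arguments import (                   # assumed dataclasses carrying the custom fields the trainer reads
    DataTrainingArguments,                               # e.g. project_name, full_generation_sample_text
    ModelArguments,
    VITSTrainingArguments,                               # e.g. weight_mel, weight_disc, lr_decay, do_step_schedule_per_epoch
)


def main():
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, VITSTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = VitsModel.from_pretrained(model_args.model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

    # Datasets are expected to already carry the columns the collator and sampler use:
    # "input_ids", "waveform", "labels", "mel_scaled_input_features", "speaker_id",
    # "text_encoder_output" and "tokens_input_length".
    train_dataset, eval_dataset = load_prepared_datasets(data_args)  # hypothetical helper

    vits_trainin(
        model=model,
        tokenizer=tokenizer,
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )


if __name__ == "__main__":
    main()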