import os import shutil import tempfile import numpy as np import wandb from transformers import VitsModel import math import torch from accelerate.utils import ProjectConfiguration, is_wandb_available, set_seed from accelerate import Accelerator, DistributedDataParallelKwargs from transformers.utils import send_example_telemetry import logging import sys from transformers.trainer_utils import get_last_checkpoint, is_main_process from transformers.trainer_pt_utils import LengthGroupedSampler from transformers.optimization import get_scheduler from .data_collator import DataCollatorTTSWithPadding from .discriminator import VitsDiscriminator from .feature_extraction import VitsFeatureExtractor from .plot import plot_alignment_to_numpy, plot_spectrogram_to_numpy #............................................. if is_wandb_available(): import wandb ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=False) logger = logging.getLogger(__name__) #............................................. def discriminator_loss(disc_real_outputs, disc_generated_outputs): loss = 0 real_losses = 0 generated_losses = 0 for disc_real, disc_generated in zip(disc_real_outputs, disc_generated_outputs): real_loss = torch.mean((1 - disc_real) ** 2) generated_loss = torch.mean(disc_generated**2) loss += real_loss + generated_loss real_losses += real_loss generated_losses += generated_loss return loss, real_losses, generated_losses def feature_loss(feature_maps_real, feature_maps_generated): loss = 0 for feature_map_real, feature_map_generated in zip(feature_maps_real, feature_maps_generated): for real, generated in zip(feature_map_real, feature_map_generated): real = real.detach() loss += torch.mean(torch.abs(real - generated)) return loss * 2 def generator_loss(disc_outputs): total_loss = 0 gen_losses = [] for disc_output in disc_outputs: disc_output = disc_output loss = torch.mean((1 - disc_output) ** 2) gen_losses.append(loss) total_loss += loss return total_loss, gen_losses def kl_loss(prior_latents, posterior_log_variance, prior_means, prior_log_variance, labels_mask): """ z_p, logs_q: [b, h, t_t] prior_means, prior_log_variance: [b, h, t_t] """ kl = prior_log_variance - posterior_log_variance - 0.5 kl += 0.5 * ((prior_latents - prior_means) ** 2) * torch.exp(-2.0 * prior_log_variance) kl = torch.sum(kl * labels_mask) loss = kl / torch.sum(labels_mask) return loss def log_on_trackers( trackers, generated_audio, generated_attn, generated_spec, target_spec, full_generation_waveform, epoch, sampling_rate, ): max_num_samples = min(len(generated_audio), 50) generated_audio = generated_audio[:max_num_samples] generated_attn = generated_attn[:max_num_samples] generated_spec = generated_spec[:max_num_samples] target_spec = target_spec[:max_num_samples] for tracker in trackers: if tracker.name == "tensorboard": for cpt, audio in enumerate(generated_audio): tracker.writer.add_audio(f"train_step_audio_{cpt}", audio[None, :], epoch, sample_rate=sampling_rate) for cpt, audio in enumerate(full_generation_waveform): tracker.writer.add_audio( f"full_generation_sample{cpt}", audio[None, :], epoch, sample_rate=sampling_rate ) tracker.writer.add_images("alignements", np.stack(generated_attn), dataformats="NHWC") tracker.writer.add_images("spectrogram", np.stack(generated_spec), dataformats="NHWC") tracker.writer.add_images("target spectrogram", np.stack(target_spec), dataformats="NHWC") elif tracker.name == "wandb": # wandb can only loads 100 audios per step tracker.log( { "alignments": [wandb.Image(attn, caption=f"Audio epoch {epoch}") for attn in generated_attn], "spectrogram": [wandb.Image(spec, caption=f"Audio epoch {epoch}") for spec in generated_spec], "target spectrogram": [wandb.Image(spec, caption=f"Audio epoch {epoch}") for spec in target_spec], "train generated audio": [ wandb.Audio( audio[0], caption=f"Audio during train step epoch {epoch}", sample_rate=sampling_rate, ) for audio in generated_audio ], "full generations samples": [ wandb.Audio(w, caption=f"Full generation sample {epoch}", sample_rate=sampling_rate) for w in full_generation_waveform ], } ) else: logger.warn(f"audio logging not implemented for {tracker.name}") def compute_val_metrics_and_losses( val_losses, accelerator, model_outputs, mel_scaled_generation, mel_scaled_target, batch_size, compute_clap_similarity=False, ): loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation) loss_kl = kl_loss( model_outputs.prior_latents, model_outputs.posterior_log_variances, model_outputs.prior_means, model_outputs.prior_log_variances, model_outputs.labels_padding_mask, ) losses_mel_kl = loss_mel + loss_kl losses = torch.stack([loss_mel, loss_kl, losses_mel_kl]) losses = accelerator.gather(losses.repeat(batch_size, 1)).mean(0) for key, loss in zip(["val_loss_mel", "val_loss_kl", "val_loss_mel_kl"], losses): val_losses[key] = val_losses.get(key, 0) + loss.item() return val_losses #............................................. def vits_trainin( model, tokenizer, model_args, data_args, training_args, train_dataset, eval_dataset, ): send_example_telemetry("run_vits_finetuning", model_args, data_args) logging.basicConfig( format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", handlers=[logging.StreamHandler(sys.stdout)], ) log_level = training_args.get_process_log_level() logger.setLevel(log_level) # datasets.utils.logging.set_verbosity(log_level) # transformers.utils.logging.set_verbosity(log_level) # transformers.utils.logging.enable_default_handler() # transformers.utils.logging.enable_explicit_format() # # logger.setLevel(logging.INFO if is_main_process(training_args.local_rank) else logging.WARN) # if is_main_process(training_args.local_rank): # transformers.utils.logging.set_verbosity_info() set_seed(training_args.seed) config = model.config feature_extractor = VitsFeatureExtractor() forward_attention_mask = True with training_args.main_process_first(desc="apply_weight_norm"): # apply weight norms model.decoder.apply_weight_norm() for flow in model.flow.flows: torch.nn.utils.weight_norm(flow.conv_pre) torch.nn.utils.weight_norm(flow.conv_post) with training_args.main_process_first(): # only the main process saves them if is_main_process(training_args.local_rank): # save feature extractor, tokenizer and config feature_extractor.save_pretrained(training_args.output_dir) tokenizer.save_pretrained(training_args.output_dir) config.save_pretrained(training_args.output_dir) data_collator = DataCollatorTTSWithPadding( tokenizer=tokenizer, feature_extractor=feature_extractor, forward_attention_mask=forward_attention_mask, ) with training_args.main_process_first(): input_str = data_args.full_generation_sample_text full_generation_sample = tokenizer(input_str, return_tensors="pt") project_name = data_args.project_name logging_dir = os.path.join(training_args.output_dir, training_args.logging_dir) accelerator_project_config = ProjectConfiguration(project_dir=training_args.output_dir, logging_dir=logging_dir) accelerator = Accelerator( gradient_accumulation_steps=training_args.gradient_accumulation_steps, log_with=training_args.report_to, project_config=accelerator_project_config, kwargs_handlers=[ddp_kwargs], ) per_device_train_batch_size = ( training_args.per_device_train_batch_size if training_args.per_device_train_batch_size else 1 ) total_batch_size = ( per_device_train_batch_size * accelerator.num_processes * training_args.gradient_accumulation_steps ) num_speakers = model.config.num_speakers if training_args.gradient_checkpointing: model.gradient_checkpointing_enable() train_dataloader = None if training_args.do_train: sampler = ( LengthGroupedSampler( batch_size=per_device_train_batch_size, dataset=train_dataset, lengths=train_dataset["tokens_input_length"], ) if training_args.group_by_length else None ) train_dataloader = torch.utils.data.DataLoader( train_dataset, shuffle=False,#not training_args.group_by_length, collate_fn=data_collator, batch_size=training_args.per_device_train_batch_size, num_workers=training_args.dataloader_num_workers, sampler=sampler, ) eval_dataloader = None if training_args.do_eval: eval_sampler = ( LengthGroupedSampler( batch_size=training_args.per_device_eval_batch_size, dataset=eval_dataset, lengths=eval_dataset["tokens_input_length"], ) if training_args.group_by_length else None ) eval_dataloader = torch.utils.data.DataLoader( eval_dataset, shuffle=False, collate_fn=data_collator, batch_size=training_args.per_device_eval_batch_size, num_workers=training_args.dataloader_num_workers, sampler=eval_sampler, ) model_segment_size = model.segment_size config_segment_size = model.config.segment_size sampling_rate = model.config.sampling_rate # Scheduler and math around the number of training steps. overrode_max_train_steps = False num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if training_args.max_steps == -1: training_args.max_steps = training_args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True # We need to recalculate our total training steps as the size of the training dataloader may have changed. num_update_steps_per_epoch = math.ceil(len(train_dataloader) / training_args.gradient_accumulation_steps) if overrode_max_train_steps: training_args.max_steps = int(training_args.num_train_epochs * num_update_steps_per_epoch) # Afterwards we recalculate our number of training epochs training_args.num_train_epochs = math.ceil(training_args.max_steps / num_update_steps_per_epoch) # hack to be able to train on multiple device with tempfile.TemporaryDirectory() as tmpdirname: model.discriminator.save_pretrained(tmpdirname) discriminator = VitsDiscriminator.from_pretrained(tmpdirname) for disc in discriminator.discriminators: disc.apply_weight_norm() del model.discriminator # init gen_optimizer, gen_lr_scheduler, disc_optimizer, dics_lr_scheduler gen_optimizer = torch.optim.AdamW( model.parameters(), training_args.learning_rate, betas=[training_args.adam_beta1, training_args.adam_beta2], eps=training_args.adam_epsilon, ) disc_optimizer = torch.optim.AdamW( discriminator.parameters(), training_args.learning_rate, betas=[training_args.adam_beta1, training_args.adam_beta2], eps=training_args.adam_epsilon, ) num_warmups_steps = training_args.get_warmup_steps(training_args.num_train_epochs * accelerator.num_processes) num_training_steps = training_args.num_train_epochs * accelerator.num_processes gen_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( gen_optimizer, gamma=training_args.lr_decay, last_epoch=-1 ) disc_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR( disc_optimizer, gamma=training_args.lr_decay, last_epoch=-1 ) # Prepare everything with our `accelerator`. ( model, discriminator, gen_optimizer, gen_lr_scheduler, disc_optimizer, disc_lr_scheduler, train_dataloader, eval_dataloader, ) = accelerator.prepare( model, discriminator, gen_optimizer, gen_lr_scheduler, disc_optimizer, disc_lr_scheduler, train_dataloader, eval_dataloader, ) # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. if accelerator.is_main_process: tracker_config = training_args.to_sanitized_dict() accelerator.init_trackers(project_name, tracker_config) # Train! logger.info("***** Running training *****") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num Epochs = {training_args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {per_device_train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") logger.info(f" Gradient Accumulation steps = {training_args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {training_args.max_steps}") global_step = 0 first_epoch = 0 # Potentially load in the weights and states from a previous save if training_args.resume_from_checkpoint: if training_args.resume_from_checkpoint != "latest": path = os.path.basename(training_args.resume_from_checkpoint) else: # Get the most recent checkpoint dirs = os.listdir(training_args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None if path is None: accelerator.print( f"Checkpoint '{training_args.resume_from_checkpoint}' does not exist. Starting a new training run." ) training_args.resume_from_checkpoint = None initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") accelerator.load_state(os.path.join(training_args.output_dir, path)) global_step = int(path.split("-")[1]) initial_global_step = global_step first_epoch = global_step // num_update_steps_per_epoch else: initial_global_step = 0 #.......................loop training............................ for epoch in range(first_epoch, training_args.num_train_epochs): # keep track of train losses train_losses = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] disc_lr_scheduler.step() gen_lr_scheduler.step() for step, batch in enumerate(train_dataloader): print(f"TRAINIG - batch {step}, process{accelerator.process_index}, waveform {(batch['waveform'].shape)}, tokens {(batch['input_ids'].shape)}... ") with accelerator.accumulate(model, discriminator): # forward through model model_outputs = model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"], labels_attention_mask=batch["labels_attention_mask"], speaker_id=batch["speaker_id"], encoder_output = batch['text_encoder_output'], return_dict=True, monotonic_alignment_function=None, ) mel_scaled_labels = batch["mel_scaled_input_features"] mel_scaled_target = model.slice_segments(mel_scaled_labels, model_outputs.ids_slice, model_segment_size) mel_scaled_generation = feature_extractor._torch_extract_fbank_features( model_outputs.waveform.squeeze(1) )[1] target_waveform = batch["waveform"].transpose(1, 2) target_waveform = model.slice_segments( target_waveform, model_outputs.ids_slice * feature_extractor.hop_length, config_segment_size ) # ----------------------- # Train Discriminator # ----------------------- discriminator_target, _ = discriminator(target_waveform) discriminator_candidate, _ = discriminator(model_outputs.waveform.detach()) loss_disc, loss_real_disc, loss_fake_disc = discriminator_loss( discriminator_target, discriminator_candidate ) # backpropagate accelerator.backward(loss_disc * training_args.weight_disc) if accelerator.sync_gradients: accelerator.clip_grad_norm_(discriminator.parameters(), training_args.max_grad_norm) disc_optimizer.step() if not training_args.do_step_schedule_per_epoch: disc_lr_scheduler.step() disc_optimizer.zero_grad() # ----------------------- # Train Generator # ----------------------- _, fmaps_target = discriminator(target_waveform) discriminator_candidate, fmaps_candidate = discriminator(model_outputs.waveform) loss_duration = torch.sum(model_outputs.log_duration) loss_mel = torch.nn.functional.l1_loss(mel_scaled_target, mel_scaled_generation) loss_kl = kl_loss( model_outputs.prior_latents, model_outputs.posterior_log_variances, model_outputs.prior_means, model_outputs.prior_log_variances, model_outputs.labels_padding_mask, ) loss_fmaps = feature_loss(fmaps_target, fmaps_candidate) loss_gen, losses_gen = generator_loss(discriminator_candidate) total_generator_loss = ( loss_duration * training_args.weight_duration + loss_mel * training_args.weight_mel + loss_kl * training_args.weight_kl + loss_fmaps * training_args.weight_fmaps + loss_gen * training_args.weight_gen ) # backpropagate accelerator.backward(total_generator_loss) if accelerator.sync_gradients: accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm) gen_optimizer.step() if not training_args.do_step_schedule_per_epoch: gen_lr_scheduler.step() gen_optimizer.zero_grad() # update and gather losses losses = torch.stack( [ # for fair comparison, don't use weighted loss loss_duration + loss_mel + loss_kl + loss_fmaps + loss_gen, loss_duration, loss_mel, loss_kl, loss_fmaps, loss_gen, loss_disc, loss_real_disc, loss_fake_disc, ] ) losses = accelerator.gather(losses.repeat(per_device_train_batch_size, 1)).mean(0) train_losses = [ l + losses[i].item() / training_args.gradient_accumulation_steps for (i, l) in enumerate(train_losses) ] # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: ( train_summed_losses, train_loss_duration, train_loss_mel, train_loss_kl, train_loss_fmaps, train_loss_gen, train_loss_disc, train_loss_real_disc, train_loss_fake_disc, ) = train_losses global_step += 1 accelerator.log( { "train_summed_losses": train_summed_losses, "train_loss_disc": train_loss_disc, "train_loss_real_disc": train_loss_real_disc, "train_loss_fake_disc": train_loss_fake_disc, "train_loss_duration": train_loss_duration, "train_loss_mel": train_loss_mel, "train_loss_kl": train_loss_kl, "train_loss_fmaps": train_loss_fmaps, "train_loss_gen": train_loss_gen, "lr": disc_lr_scheduler.get_last_lr()[0], }, step=global_step, ) train_losses = [0.0 for _ in train_losses] if global_step % training_args.save_steps == 0: if accelerator.is_main_process: # _before_ saving state, check if this save would set us over the `save_total_limit` if training_args.save_total_limit is not None: checkpoints = os.listdir(training_args.output_dir) checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) # before we save the new checkpoint, we need to have at _most_ `save_total_limit - 1` checkpoints if len(checkpoints) >= training_args.save_total_limit: num_to_remove = len(checkpoints) - training_args.save_total_limit + 1 removing_checkpoints = checkpoints[0:num_to_remove] logger.info( f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" ) logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") for removing_checkpoint in removing_checkpoints: removing_checkpoint = os.path.join(training_args.output_dir, removing_checkpoint) shutil.rmtree(removing_checkpoint) save_path = os.path.join(training_args.output_dir, f"checkpoint-{global_step}") accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") logs = { "step_loss": total_generator_loss.detach().item(), "lr": disc_lr_scheduler.get_last_lr()[0], "step_loss_duration": loss_duration.detach().item(), "step_loss_mel": loss_mel.detach().item(), "step_loss_kl": loss_kl.detach().item(), "step_loss_fmaps": loss_fmaps.detach().item(), "step_loss_gen": loss_gen.detach().item(), "step_loss_disc": loss_disc.detach().item(), "step_loss_real_disc": loss_real_disc.detach().item(), "step_loss_fake_disc": loss_fake_disc.detach().item(), } if global_step >= training_args.max_steps: break eval_steps = training_args.eval_steps if training_args.eval_steps else 1 do_eval = training_args.do_eval and (global_step % eval_steps == 0) and accelerator.sync_gradients if do_eval: logger.info("Running validation... ") generated_audio = [] generated_attn = [] generated_spec = [] target_spec = [] val_losses = {} for step, batch in enumerate(eval_dataloader): print( f"VALIDATION - batch {step}, process{accelerator.process_index}, waveform {(batch['waveform'].shape)}, tokens {(batch['input_ids'].shape)}... " ) with torch.no_grad(): model_outputs_train = model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"], labels_attention_mask=batch["labels_attention_mask"], speaker_id=batch["speaker_id"], encoder_output = batch['text_encoder_output'], return_dict=True, monotonic_alignment_function=None, ) mel_scaled_labels = batch["mel_scaled_input_features"] mel_scaled_target = model.slice_segments( mel_scaled_labels, model_outputs_train.ids_slice, model_segment_size ) mel_scaled_generation = feature_extractor._torch_extract_fbank_features( model_outputs_train.waveform.squeeze(1) )[1] val_losses = compute_val_metrics_and_losses( val_losses, accelerator, model_outputs_train, mel_scaled_generation, mel_scaled_target, per_device_train_batch_size, compute_clap_similarity=False, ) print(f"VALIDATION - batch {step}, process{accelerator.process_index}, PADDING AND GATHER... ") specs = feature_extractor._torch_extract_fbank_features(model_outputs_train.waveform.squeeze(1))[0] padded_attn, specs, target_specs = accelerator.pad_across_processes( [model_outputs_train.attn.squeeze(1), specs, batch["labels"]], dim=1 ) padded_attn, specs, target_specs = accelerator.pad_across_processes( [padded_attn, specs, target_specs], dim=2 ) generated_train_waveform, padded_attn, specs, target_specs = accelerator.gather_for_metrics( [model_outputs_train.waveform, padded_attn, specs, target_specs] ) if accelerator.is_main_process: with torch.no_grad(): speaker_id = None if num_speakers < 2 else list(range(min(5, num_speakers))) full_generation = model(**full_generation_sample.to(model.device), speaker_id=speaker_id) generated_audio.append(generated_train_waveform.cpu()) generated_attn.append(padded_attn.cpu()) generated_spec.append(specs.cpu()) target_spec.append(target_specs.cpu()) logger.info("Validation inference done, now evaluating... ") if accelerator.is_main_process: generated_audio = [audio.numpy() for audio_batch in generated_audio for audio in audio_batch] generated_attn = [ plot_alignment_to_numpy(attn.numpy()) for attn_batch in generated_attn for attn in attn_batch ] generated_spec = [ plot_spectrogram_to_numpy(attn.numpy()) for attn_batch in generated_spec for attn in attn_batch ] target_spec = [ plot_spectrogram_to_numpy(attn.numpy()) for attn_batch in target_spec for attn in attn_batch ] full_generation_waveform = full_generation.waveform.cpu().numpy() accelerator.log(val_losses, step=global_step) log_on_trackers( accelerator.trackers, generated_audio, generated_attn, generated_spec, target_spec, full_generation_waveform, epoch, sampling_rate, ) logger.info("Validation finished... ") accelerator.wait_for_everyone() accelerator.wait_for_everyone() if accelerator.is_main_process: epoch = training_args.num_train_epochs if training_args.num_train_epochs else 1 eval_steps = training_args.eval_steps if training_args.eval_steps else 1 # Run a final round of inference. do_eval = training_args.do_eval if do_eval: logger.info("Running final validation... ") generated_audio = [] generated_attn = [] generated_spec = [] target_spec = [] val_losses = {} for step, batch in enumerate(eval_dataloader): print( f"VALIDATION - batch {step}, process{accelerator.process_index}, waveform {(batch['waveform'].shape)}, tokens {(batch['input_ids'].shape)}... " ) with torch.no_grad(): model_outputs_train = model( input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"], labels_attention_mask=batch["labels_attention_mask"], speaker_id=batch["speaker_id"], encoder_output = batch['text_encoder_output'], return_dict=True, monotonic_alignment_function=None, ) mel_scaled_labels = batch["mel_scaled_input_features"] mel_scaled_target = model.slice_segments( mel_scaled_labels, model_outputs_train.ids_slice, model_segment_size ) mel_scaled_generation = feature_extractor._torch_extract_fbank_features( model_outputs_train.waveform.squeeze(1) )[1] val_losses = compute_val_metrics_and_losses( val_losses, accelerator, model_outputs_train, mel_scaled_generation, mel_scaled_target, per_device_train_batch_size, compute_clap_similarity=False, ) specs = feature_extractor._torch_extract_fbank_features(model_outputs_train.waveform.squeeze(1))[0] padded_attn, specs, target_specs = accelerator.pad_across_processes( [model_outputs_train.attn.squeeze(1), specs, batch["labels"]], dim=1 ) padded_attn, specs, target_specs = accelerator.pad_across_processes( [padded_attn, specs, target_specs], dim=2 ) generated_train_waveform, padded_attn, specs, target_specs = accelerator.gather_for_metrics( [model_outputs_train.waveform, padded_attn, specs, target_specs] ) if accelerator.is_main_process: with torch.no_grad(): speaker_id = None if num_speakers < 2 else list(range(min(5, num_speakers))) full_generation = model(**full_generation_sample.to(model.device), speaker_id=speaker_id) generated_audio.append(generated_train_waveform.cpu()) generated_attn.append(padded_attn.cpu()) generated_spec.append(specs.cpu()) target_spec.append(target_specs.cpu()) logger.info("Validation inference done, now evaluating... ") if accelerator.is_main_process: generated_audio = [audio.numpy() for audio_batch in generated_audio for audio in audio_batch] generated_attn = [ plot_alignment_to_numpy(attn.numpy()) for attn_batch in generated_attn for attn in attn_batch ] generated_spec = [ plot_spectrogram_to_numpy(attn.numpy()) for attn_batch in generated_spec for attn in attn_batch ] target_spec = [ plot_spectrogram_to_numpy(attn.numpy()) for attn_batch in target_spec for attn in attn_batch ] full_generation_waveform = full_generation.waveform.cpu().numpy() log_on_trackers( accelerator.trackers, generated_audio, generated_attn, generated_spec, target_spec, full_generation_waveform, epoch, sampling_rate, ) accelerator.log(val_losses, step=global_step) logger.info("Validation finished... ") accelerator.wait_for_everyone() # unwrap, save and push final model model = accelerator.unwrap_model(model) discriminator = accelerator.unwrap_model(discriminator) model.discriminator = discriminator # add weight norms for disc in model.discriminator.discriminators: disc.remove_weight_norm() model.decoder.remove_weight_norm() for flow in model.flow.flows: torch.nn.utils.remove_weight_norm(flow.conv_pre) torch.nn.utils.remove_weight_norm(flow.conv_post) model.save_pretrained(training_args.output_dir) if training_args.push_to_hub: VitsModel.from_pretrained(training_args.output_dir).push_to_hub(training_args.hub_model_id) accelerator.end_training() logger.info("***** Training / Inference Done *****") #...............................................................................