import os
import sys
import glob
import json
import torch
import hashlib
import logging
import argparse
import datetime
import warnings
import logging.handlers

import numpy as np
import soundfile as sf
import matplotlib.pyplot as plt
import torch.distributed as dist
import torch.utils.data as tdata
import torch.multiprocessing as mp

from tqdm import tqdm
from collections import OrderedDict
from random import randint, shuffle
from torch.utils.checkpoint import checkpoint
from torch.cuda.amp import GradScaler, autocast
from torch.utils.tensorboard import SummaryWriter
from time import time as ttime
from torch.nn import functional as F
from distutils.util import strtobool
from librosa.filters import mel as librosa_mel_fn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.nn.utils.parametrizations import spectral_norm, weight_norm

sys.path.append(os.getcwd())

from main.configs.config import Config
from main.library.algorithm.residuals import LRELU_SLOPE
from main.library.algorithm.synthesizers import Synthesizer
from main.library.algorithm.commons import get_padding, slice_segments, clip_grad_value

MATPLOTLIB_FLAG = False
translations = Config().translations
warnings.filterwarnings("ignore")
logging.getLogger("torch").setLevel(logging.ERROR)


class HParams:
    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            self[k] = HParams(**v) if isinstance(v, dict) else v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return self.__dict__[key]

    def __setitem__(self, key, value):
        self.__dict__[key] = value

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return repr(self.__dict__)


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--rvc_version", type=str, default="v2")
    parser.add_argument("--save_every_epoch", type=int, required=True)
    parser.add_argument("--save_only_latest", type=lambda x: bool(strtobool(x)), default=True)
    parser.add_argument("--save_every_weights", type=lambda x: bool(strtobool(x)), default=True)
    parser.add_argument("--total_epoch", type=int, default=300)
    parser.add_argument("--sample_rate", type=int, required=True)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--gpu", type=str, default="0")
    parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
    parser.add_argument("--g_pretrained_path", type=str, default="")
    parser.add_argument("--d_pretrained_path", type=str, default="")
    parser.add_argument("--overtraining_detector", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--overtraining_threshold", type=int, default=50)
    parser.add_argument("--cleanup", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--cache_data_in_gpu", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--model_author", type=str)
    parser.add_argument("--vocoder", type=str, default="Default")
    parser.add_argument("--checkpointing", type=lambda x: bool(strtobool(x)), default=False)
    return parser.parse_args()


args = parse_arguments()
model_name = args.model_name
save_every_epoch = args.save_every_epoch
total_epoch = args.total_epoch
pretrainG = args.g_pretrained_path
pretrainD = args.d_pretrained_path
version = args.rvc_version
gpus = args.gpu
batch_size = args.batch_size
sample_rate = args.sample_rate
pitch_guidance = args.pitch_guidance
save_only_latest = args.save_only_latest
save_every_weights = args.save_every_weights
cache_data_in_gpu = args.cache_data_in_gpu
overtraining_detector = args.overtraining_detector
overtraining_threshold = args.overtraining_threshold
cleanup = args.cleanup
model_author = args.model_author
vocoder = args.vocoder
checkpointing = args.checkpointing
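# Example invocation (illustrative only; the script name and values below are
# assumptions, not taken from this file):
#   python train.py --model_name my_voice --sample_rate 40000 \
#       --save_every_epoch 10 --total_epoch 300 --batch_size 8 --gpu 0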
experiment_dir = os.path.join("assets", "logs", model_name)
training_file_path = os.path.join(experiment_dir, "training_data.json")
config_save_path = os.path.join(experiment_dir, "config.json")

os.environ["CUDA_VISIBLE_DEVICES"] = gpus.replace("-", ",")
n_gpus = len(gpus.split("-"))

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False

lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
global_step, last_loss_gen_all, overtrain_save_epoch = 0, 0, 0
loss_gen_history, smoothed_loss_gen_history, loss_disc_history, smoothed_loss_disc_history = [], [], [], []

with open(config_save_path, "r") as f:
    config = json.load(f)

config = HParams(**config)
config.data.training_files = os.path.join(experiment_dir, "filelist.txt")

logger = logging.getLogger(__name__)
# Reset any inherited handlers, then attach console (INFO) and rotating-file
# (DEBUG) logging.
if logger.hasHandlers():
    logger.handlers.clear()

log_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

console_handler = logging.StreamHandler()
console_handler.setFormatter(log_formatter)
console_handler.setLevel(logging.INFO)

file_handler = logging.handlers.RotatingFileHandler(os.path.join(experiment_dir, "train.log"), maxBytes=5 * 1024 * 1024, backupCount=3, encoding="utf-8")
file_handler.setFormatter(log_formatter)
file_handler.setLevel(logging.DEBUG)

logger.addHandler(console_handler)
logger.addHandler(file_handler)
logger.setLevel(logging.DEBUG)

log_data = {
    translations["modelname"]: model_name,
    translations["save_every_epoch"]: save_every_epoch,
    translations["total_e"]: total_epoch,
    translations["dorg"].format(pretrainG=pretrainG, pretrainD=pretrainD): "",
    translations["training_version"]: version,
    "Gpu": gpus,
    translations["batch_size"]: batch_size,
    translations["pretrain_sr"]: sample_rate,
    translations["training_f0"]: pitch_guidance,
    translations["save_only_latest"]: save_only_latest,
    translations["save_every_weights"]: save_every_weights,
    translations["cache_in_gpu"]: cache_data_in_gpu,
    translations["overtraining_detector"]: overtraining_detector,
    translations["threshold"]: overtraining_threshold,
    translations["cleanup_training"]: cleanup,
    translations["memory_efficient_training"]: checkpointing,
}
if model_author:
    log_data[translations["model_author"].format(model_author=model_author)] = ""
if vocoder != "Default":
    log_data[translations["vocoder"]] = vocoder

for key, value in log_data.items():
    logger.debug(f"{key}: {value}" if value != "" else f"{key} {value}")
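# main() below spawns one training process per GPU via torch.multiprocessing
# and records the child PIDs in config.json so an external caller can stop
# training; run() is the per-process entry point.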
def main():
    global training_file_path, last_loss_gen_all, smoothed_loss_gen_history, loss_gen_history, loss_disc_history, smoothed_loss_disc_history, overtrain_save_epoch, model_author, vocoder, checkpointing

    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))

    if torch.cuda.is_available():
        device, n_gpus = torch.device("cuda"), torch.cuda.device_count()
    elif torch.backends.mps.is_available():
        device, n_gpus = torch.device("mps"), 1
    else:
        device, n_gpus = torch.device("cpu"), 1

    def start():
        children = []
        pid_data = {"process_pids": []}
        with open(config_save_path, "r") as pid_file:
            try:
                pid_data.update(json.load(pid_file))
            except json.JSONDecodeError:
                pass
        with open(config_save_path, "w") as pid_file:
            for i in range(n_gpus):
                subproc = mp.Process(target=run, args=(i, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, total_epoch, save_every_weights, config, device, model_author, vocoder, checkpointing))
                children.append(subproc)
                subproc.start()
                pid_data["process_pids"].append(subproc.pid)
            json.dump(pid_data, pid_file, indent=4)
        for i in range(n_gpus):
            children[i].join()

    def load_from_json(file_path):
        if os.path.exists(file_path):
            with open(file_path, "r") as f:
                data = json.load(f)
                return (data.get("loss_disc_history", []), data.get("smoothed_loss_disc_history", []), data.get("loss_gen_history", []), data.get("smoothed_loss_gen_history", []))
        return [], [], [], []

    def continue_overtrain_detector(training_file_path):
        if overtraining_detector and os.path.exists(training_file_path):
            (loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history) = load_from_json(training_file_path)

    n_gpus = torch.cuda.device_count()
    if not torch.cuda.is_available() and torch.backends.mps.is_available():
        n_gpus = 1
    if n_gpus < 1:
        logger.warning(translations["not_gpu"])
        n_gpus = 1

    if cleanup:
        for root, dirs, files in os.walk(experiment_dir, topdown=False):
            for name in files:
                file_path = os.path.join(root, name)
                _, file_extension = os.path.splitext(name)
                if file_extension == ".0" or (name.startswith("D_") and file_extension == ".pth") or (name.startswith("G_") and file_extension == ".pth") or file_extension == ".index":
                    os.remove(file_path)
            for name in dirs:
                if name == "eval":
                    folder_path = os.path.join(root, name)
                    for item in os.listdir(folder_path):
                        item_path = os.path.join(folder_path, item)
                        if os.path.isfile(item_path):
                            os.remove(item_path)
                    os.rmdir(folder_path)

    continue_overtrain_detector(training_file_path)
    start()


def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        plt.switch_backend("Agg")
        MATPLOTLIB_FLAG = True
    fig, ax = plt.subplots(figsize=(10, 2))
    plt.colorbar(ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none"), ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()
    fig.canvas.draw()
    plt.close(fig)
    return np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8).reshape(fig.canvas.get_width_height()[::-1] + (3,))


def verify_checkpoint_shapes(checkpoint_path, model):
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    checkpoint_state_dict = checkpoint["model"]
    try:
        model_state_dict = model.module.load_state_dict(checkpoint_state_dict) if hasattr(model, "module") else model.load_state_dict(checkpoint_state_dict)
    except RuntimeError:
        logger.warning(translations["checkpointing_err"])
        sys.exit(1)
    else:
        del checkpoint, checkpoint_state_dict, model_state_dict


def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sample_rate=22050):
    for k, v in scalars.items():
        writer.add_scalar(k, v, global_step)
    for k, v in histograms.items():
        writer.add_histogram(k, v, global_step)
    for k, v in images.items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in audios.items():
        writer.add_audio(k, v, global_step, audio_sample_rate)
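# PyTorch's parametrized weight-norm stores weights as
# ".parametrizations.weight.original0/original1", while older checkpoints use
# ".weight_g"/".weight_v". load_checkpoint/save_checkpoint below rename keys in
# both directions (via replace_keys_in_dict, defined further down) so either
# naming stays loadable.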
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
    assert os.path.isfile(checkpoint_path), translations["not_found_checkpoint"].format(checkpoint_path=checkpoint_path)
    checkpoint_dict = replace_keys_in_dict(replace_keys_in_dict(torch.load(checkpoint_path, map_location="cpu"), ".weight_v", ".parametrizations.weight.original1"), ".weight_g", ".parametrizations.weight.original0")
    new_state_dict = {k: checkpoint_dict["model"].get(k, v) for k, v in (model.module.state_dict() if hasattr(model, "module") else model.state_dict()).items()}
    if hasattr(model, "module"):
        model.module.load_state_dict(new_state_dict, strict=False)
    else:
        model.load_state_dict(new_state_dict, strict=False)
    if optimizer and load_opt == 1:
        optimizer.load_state_dict(checkpoint_dict.get("optimizer", {}))
    logger.debug(translations["save_checkpoint"].format(checkpoint_path=checkpoint_path, checkpoint_dict=checkpoint_dict["iteration"]))
    return (model, optimizer, checkpoint_dict.get("learning_rate", 0), checkpoint_dict["iteration"])


def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    state_dict = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
    torch.save(replace_keys_in_dict(replace_keys_in_dict({"model": state_dict, "iteration": iteration, "optimizer": optimizer.state_dict(), "learning_rate": learning_rate}, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), checkpoint_path)
    logger.info(translations["save_model"].format(checkpoint_path=checkpoint_path, iteration=iteration))


def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    checkpoints = sorted(glob.glob(os.path.join(dir_path, regex)), key=lambda f: int("".join(filter(str.isdigit, f))))
    return checkpoints[-1] if checkpoints else None


def load_wav_to_torch(full_path):
    data, sample_rate = sf.read(full_path, dtype="float32")
    return torch.FloatTensor(data.astype(np.float32)), sample_rate


def load_filepaths_and_text(filename, split="|"):
    with open(filename, encoding="utf-8") as f:
        return [line.strip().split(split) for line in f]


def feature_loss(fmap_r, fmap_g):
    loss = 0
    for dr, dg in zip(fmap_r, fmap_g):
        for rl, gl in zip(dr, dg):
            loss += torch.mean(torch.abs(rl.float().detach() - gl.float()))
    return loss * 2


def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    loss = 0
    r_losses, g_losses = [], []
    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
        dr = dr.float()
        dg = dg.float()
        r_loss = torch.mean((1 - dr) ** 2)
        g_loss = torch.mean(dg**2)
        loss += r_loss + g_loss
        r_losses.append(r_loss.item())
        g_losses.append(g_loss.item())
    return loss, r_losses, g_losses


def generator_loss(disc_outputs):
    loss = 0
    gen_losses = []
    for dg in disc_outputs:
        l = torch.mean((1 - dg.float()) ** 2)
        gen_losses.append(l)
        loss += l
    return loss, gen_losses


def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    z_p = z_p.float()
    logs_q = logs_q.float()
    m_p = m_p.float()
    logs_p = logs_p.float()
    z_mask = z_mask.float()
    kl = logs_p - logs_q - 0.5
    kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
    return torch.sum(kl * z_mask) / torch.sum(z_mask)
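# kl_loss is the standard VITS prior/posterior term: with the posterior sample
# z_p, posterior log-std logs_q, and prior N(m_p, exp(logs_p)^2), each element
# contributes
#   logs_p - logs_q - 1/2 + (z_p - m_p)^2 / (2 * exp(2 * logs_p))
# and the sum is averaged over the valid positions selected by z_mask.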
class TextAudioLoaderMultiNSFsid(tdata.Dataset):
    def __init__(self, hparams):
        self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
        self.max_wav_value = hparams.max_wav_value
        self.sample_rate = hparams.sample_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.min_text_len = getattr(hparams, "min_text_len", 1)
        self.max_text_len = getattr(hparams, "max_text_len", 5000)
        self._filter()

    def _filter(self):
        audiopaths_and_text_new, lengths = [], []
        for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
            if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
                audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
                lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
        self.audiopaths_and_text = audiopaths_and_text_new
        self.lengths = lengths

    def get_sid(self, sid):
        try:
            sid = torch.LongTensor([int(sid)])
        except ValueError as e:
            logger.error(translations["sid_error"].format(sid=sid, e=e))
            sid = torch.LongTensor([0])
        return sid

    def get_audio_text_pair(self, audiopath_and_text):
        phone, pitch, pitchf = self.get_labels(audiopath_and_text[1], audiopath_and_text[2], audiopath_and_text[3])
        spec, wav = self.get_audio(audiopath_and_text[0])
        dv = self.get_sid(audiopath_and_text[4])
        len_phone = phone.size()[0]
        len_spec = spec.size()[-1]
        if len_phone != len_spec:
            len_min = min(len_phone, len_spec)
            len_wav = len_min * self.hop_length
            spec, wav, phone = spec[:, :len_min], wav[:, :len_wav], phone[:len_min, :]
            pitch, pitchf = pitch[:len_min], pitchf[:len_min]
        return (spec, wav, phone, pitch, pitchf, dv)

    def get_labels(self, phone, pitch, pitchf):
        phone = np.repeat(np.load(phone), 2, axis=0)
        n_num = min(phone.shape[0], 900)
        return torch.FloatTensor(phone[:n_num, :]), torch.LongTensor(np.load(pitch)[:n_num]), torch.FloatTensor(np.load(pitchf)[:n_num])

    def get_audio(self, filename):
        audio, sample_rate = load_wav_to_torch(filename)
        if sample_rate != self.sample_rate:
            raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
        audio_norm = audio.unsqueeze(0)
        spec_filename = filename.replace(".wav", ".spec.pt")
        if os.path.exists(spec_filename):
            try:
                spec = torch.load(spec_filename)
            except Exception as e:
                logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
                spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
                torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
        else:
            spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
            torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
        return spec, audio_norm

    def __getitem__(self, index):
        return self.get_audio_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        return len(self.audiopaths_and_text)
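# get_audio caches each spectrogram beside its source wav as "<name>.spec.pt",
# so the STFT is computed once per utterance and reused across epochs and runs.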
class TextAudioCollateMultiNSFsid:
    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
        spec_lengths, wave_lengths = torch.LongTensor(len(batch)), torch.LongTensor(len(batch))
        spec_padded, wave_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max([x[0].size(1) for x in batch])), torch.FloatTensor(len(batch), 1, max([x[1].size(1) for x in batch]))
        spec_padded.zero_()
        wave_padded.zero_()
        max_phone_len = max([x[2].size(0) for x in batch])
        phone_lengths, phone_padded = torch.LongTensor(len(batch)), torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
        pitch_padded, pitchf_padded = torch.LongTensor(len(batch), max_phone_len), torch.FloatTensor(len(batch), max_phone_len)
        phone_padded.zero_()
        pitch_padded.zero_()
        pitchf_padded.zero_()
        sid = torch.LongTensor(len(batch))
        for i in range(len(ids_sorted_decreasing)):
            row = batch[ids_sorted_decreasing[i]]
            spec = row[0]
            spec_padded[i, :, : spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)
            wave = row[1]
            wave_padded[i, :, : wave.size(1)] = wave
            wave_lengths[i] = wave.size(1)
            phone = row[2]
            phone_padded[i, : phone.size(0), :] = phone
            phone_lengths[i] = phone.size(0)
            pitch = row[3]
            pitch_padded[i, : pitch.size(0)] = pitch
            pitchf = row[4]
            pitchf_padded[i, : pitchf.size(0)] = pitchf
            sid[i] = row[5]
        return (phone_padded, phone_lengths, pitch_padded, pitchf_padded, spec_padded, spec_lengths, wave_padded, wave_lengths, sid)


class TextAudioLoader(tdata.Dataset):
    def __init__(self, hparams):
        self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
        self.max_wav_value = hparams.max_wav_value
        self.sample_rate = hparams.sample_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.min_text_len = getattr(hparams, "min_text_len", 1)
        self.max_text_len = getattr(hparams, "max_text_len", 5000)
        self._filter()

    def _filter(self):
        audiopaths_and_text_new, lengths = [], []
        for entry in self.audiopaths_and_text:
            if len(entry) >= 3:
                audiopath, text, dv = entry[:3]
                if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
                    audiopaths_and_text_new.append([audiopath, text, dv])
                    lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
        self.audiopaths_and_text = audiopaths_and_text_new
        self.lengths = lengths

    def get_sid(self, sid):
        try:
            sid = torch.LongTensor([int(sid)])
        except ValueError as e:
            logger.error(translations["sid_error"].format(sid=sid, e=e))
            sid = torch.LongTensor([0])
        return sid

    def get_audio_text_pair(self, audiopath_and_text):
        phone = self.get_labels(audiopath_and_text[1])
        spec, wav = self.get_audio(audiopath_and_text[0])
        dv = self.get_sid(audiopath_and_text[2])
        len_phone = phone.size()[0]
        len_spec = spec.size()[-1]
        if len_phone != len_spec:
            len_min = min(len_phone, len_spec)
            len_wav = len_min * self.hop_length
            spec = spec[:, :len_min]
            wav = wav[:, :len_wav]
            phone = phone[:len_min, :]
        return (spec, wav, phone, dv)

    def get_labels(self, phone):
        phone = np.repeat(np.load(phone), 2, axis=0)
        return torch.FloatTensor(phone[: min(phone.shape[0], 900), :])

    def get_audio(self, filename):
        audio, sample_rate = load_wav_to_torch(filename)
        if sample_rate != self.sample_rate:
            raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
        audio_norm = audio.unsqueeze(0)
        spec_filename = filename.replace(".wav", ".spec.pt")
        if os.path.exists(spec_filename):
            try:
                spec = torch.load(spec_filename)
            except Exception as e:
                logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
                spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
                torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
        else:
            spec = torch.squeeze(spectrogram_torch(audio_norm, self.filter_length, self.hop_length, self.win_length, center=False), 0)
            torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
        return spec, audio_norm

    def __getitem__(self, index):
        return self.get_audio_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        return len(self.audiopaths_and_text)
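# TextAudioLoader/TextAudioCollate below mirror the *MultiNSFsid classes above,
# minus the pitch (f0) features, for training without pitch guidance.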
class TextAudioCollate:
    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
        spec_lengths, wave_lengths = torch.LongTensor(len(batch)), torch.LongTensor(len(batch))
        spec_padded, wave_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max([x[0].size(1) for x in batch])), torch.FloatTensor(len(batch), 1, max([x[1].size(1) for x in batch]))
        spec_padded.zero_()
        wave_padded.zero_()
        max_phone_len = max([x[2].size(0) for x in batch])
        phone_lengths, phone_padded = torch.LongTensor(len(batch)), torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
        phone_padded.zero_()
        sid = torch.LongTensor(len(batch))
        for i in range(len(ids_sorted_decreasing)):
            row = batch[ids_sorted_decreasing[i]]
            spec = row[0]
            spec_padded[i, :, : spec.size(1)] = spec
            spec_lengths[i] = spec.size(1)
            wave = row[1]
            wave_padded[i, :, : wave.size(1)] = wave
            wave_lengths[i] = wave.size(1)
            phone = row[2]
            phone_padded[i, : phone.size(0), :] = phone
            phone_lengths[i] = phone.size(0)
            sid[i] = row[3]
        return (phone_padded, phone_lengths, spec_padded, spec_lengths, wave_padded, wave_lengths, sid)


class DistributedBucketSampler(tdata.distributed.DistributedSampler):
    def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        self.boundaries = boundaries
        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for i in range(len(self.lengths)):
            idx_bucket = self._bisect(self.lengths[i])
            if idx_bucket != -1:
                buckets[idx_bucket].append(i)
        for i in range(len(buckets) - 1, -1, -1):
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)
        num_samples_per_bucket = []
        for i in range(len(buckets)):
            len_bucket = len(buckets[i])
            total_batch_size = self.num_replicas * self.batch_size
            num_samples_per_bucket.append(len_bucket + ((total_batch_size - (len_bucket % total_batch_size)) % total_batch_size))
        return buckets, num_samples_per_bucket

    def __iter__(self):
        g = torch.Generator()
        g.manual_seed(self.epoch)
        indices, batches = [], []
        if self.shuffle:
            for bucket in self.buckets:
                indices.append(torch.randperm(len(bucket), generator=g).tolist())
        else:
            for bucket in self.buckets:
                indices.append(list(range(len(bucket))))
        for i in range(len(self.buckets)):
            bucket = self.buckets[i]
            len_bucket = len(bucket)
            ids_bucket = indices[i]
            rem = self.num_samples_per_bucket[i] - len_bucket
            ids_bucket = (ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)])[self.rank :: self.num_replicas]
            for j in range(len(ids_bucket) // self.batch_size):
                batches.append([bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]])
        if self.shuffle:
            batches = [batches[i] for i in torch.randperm(len(batches), generator=g).tolist()]
        self.batches = batches
        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        if hi is None:
            hi = len(self.boundaries) - 1
        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        return -1

    def __len__(self):
        return self.num_samples // self.batch_size
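# Bucket padding example: with num_replicas=2 and batch_size=8 the combined
# batch is 16, so a bucket holding 21 clips is padded to 32 samples
# (21 + (16 - 21 % 16) % 16 = 32) by repeating indices, keeping every rank's
# batch count identical so the DDP ranks stay in lockstep.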
class MultiPeriodDiscriminator(torch.nn.Module):
    def __init__(self, version, use_spectral_norm=False, checkpointing=False):
        super(MultiPeriodDiscriminator, self).__init__()
        self.checkpointing = checkpointing
        periods = [2, 3, 5, 7, 11, 17] if version == "v1" else [2, 3, 5, 7, 11, 17, 23, 37]
        self.discriminators = torch.nn.ModuleList([DiscriminatorS(use_spectral_norm=use_spectral_norm, checkpointing=checkpointing)] + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm, checkpointing=checkpointing) for p in periods])

    def forward(self, y, y_hat):
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for d in self.discriminators:
            if self.training and self.checkpointing:
                def forward_discriminator(d, y, y_hat):
                    y_d_r, fmap_r = d(y)
                    y_d_g, fmap_g = d(y_hat)
                    return y_d_r, fmap_r, y_d_g, fmap_g
                y_d_r, fmap_r, y_d_g, fmap_g = checkpoint(forward_discriminator, d, y, y_hat, use_reentrant=False)
            else:
                y_d_r, fmap_r = d(y)
                y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            fmap_rs.append(fmap_r)
            y_d_gs.append(y_d_g)
            fmap_gs.append(fmap_g)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False, checkpointing=False):
        super(DiscriminatorS, self).__init__()
        self.checkpointing = checkpointing
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = torch.nn.ModuleList([
            norm_f(torch.nn.Conv1d(1, 16, 15, 1, padding=7)),
            norm_f(torch.nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)),
            norm_f(torch.nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)),
            norm_f(torch.nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
            norm_f(torch.nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
            norm_f(torch.nn.Conv1d(1024, 1024, 5, 1, padding=2)),
        ])
        self.conv_post = norm_f(torch.nn.Conv1d(1024, 1, 3, 1, padding=1))
        self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)

    def forward(self, x):
        fmap = []
        for conv in self.convs:
            x = checkpoint(self.lrelu, checkpoint(conv, x, use_reentrant=False), use_reentrant=False) if self.training and self.checkpointing else self.lrelu(conv(x))
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, checkpointing=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.checkpointing = checkpointing
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv2d(in_ch, out_ch, (kernel_size, 1), (stride, 1), padding=(get_padding(kernel_size, 1), 0))) for in_ch, out_ch in zip([1, 32, 128, 512, 1024], [32, 128, 512, 1024, 1024])])
        self.conv_post = norm_f(torch.nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
        self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)

    def forward(self, x):
        fmap = []
        b, c, t = x.shape
        if t % self.period != 0:
            x = torch.nn.functional.pad(x, (0, self.period - (t % self.period)), "reflect")
        x = x.view(b, c, -1, self.period)
        for conv in self.convs:
            x = checkpoint(self.lrelu, checkpoint(conv, x, use_reentrant=False), use_reentrant=False) if self.training and self.checkpointing else self.lrelu(conv(x))
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap


class EpochRecorder:
    def __init__(self):
        self.last_time = ttime()

    def record(self):
        now_time = ttime()
        elapsed_time = now_time - self.last_time
        self.last_time = now_time
        return translations["time_or_speed_training"].format(current_time=datetime.datetime.now().strftime("%H:%M:%S"), elapsed_time_str=str(datetime.timedelta(seconds=int(round(elapsed_time, 1)))))


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C
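# Mel magnitudes are compared in the log domain; clamping at clip_val=1e-5
# keeps log() finite on silent frames. Decompression (exp(x) / C) inverts the
# mapping for values above the clamp.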
def spectral_normalize_torch(magnitudes):
    return dynamic_range_compression_torch(magnitudes)


def spectral_de_normalize_torch(magnitudes):
    return dynamic_range_decompression_torch(magnitudes)


# Caches keyed by (size or fmax, dtype, device) so windows and mel filterbanks
# are built once per configuration instead of on every call.
mel_basis, hann_window = {}, {}


def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):
    global hann_window
    wnsize_dtype_device = str(win_size) + "_" + str(y.dtype) + "_" + str(y.device)
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
    spec = torch.stft(torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect").squeeze(1), n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
    return torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)


def spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax):
    global mel_basis
    fmax_dtype_device = str(fmax) + "_" + str(spec.dtype) + "_" + str(spec.device)
    if fmax_dtype_device not in mel_basis:
        mel_basis[fmax_dtype_device] = torch.from_numpy(librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)).to(dtype=spec.dtype, device=spec.device)
    return spectral_normalize_torch(torch.matmul(mel_basis[fmax_dtype_device], spec))


def mel_spectrogram_torch(y, n_fft, num_mels, sample_rate, hop_size, win_size, fmin, fmax, center=False):
    return spec_to_mel_torch(spectrogram_torch(y, n_fft, hop_size, win_size, center), n_fft, num_mels, sample_rate, fmin, fmax)


def replace_keys_in_dict(d, old_key_part, new_key_part):
    updated_dict = OrderedDict() if isinstance(d, OrderedDict) else {}
    for key, value in d.items():
        updated_dict[key.replace(old_key_part, new_key_part) if isinstance(key, str) else key] = replace_keys_in_dict(value, old_key_part, new_key_part) if isinstance(value, dict) else value
    return updated_dict


def extract_model(ckpt, sr, pitch_guidance, name, model_path, epoch, step, version, hps, model_author, vocoder):
    try:
        logger.info(translations["savemodel"].format(model_dir=model_path, epoch=epoch, step=step))
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        opt = OrderedDict(weight={key: value.half() for key, value in ckpt.items() if "enc_q" not in key})
        opt["config"] = [hps.data.filter_length // 2 + 1, 32, hps.model.inter_channels, hps.model.hidden_channels, hps.model.filter_channels, hps.model.n_heads, hps.model.n_layers, hps.model.kernel_size, hps.model.p_dropout, hps.model.resblock, hps.model.resblock_kernel_sizes, hps.model.resblock_dilation_sizes, hps.model.upsample_rates, hps.model.upsample_initial_channel, hps.model.upsample_kernel_sizes, hps.model.spk_embed_dim, hps.model.gin_channels, hps.data.sample_rate]
        opt["epoch"] = f"{epoch}epoch"
        opt["step"] = step
        opt["sr"] = sr
        opt["f0"] = int(pitch_guidance)
        opt["version"] = version
        opt["creation_date"] = datetime.datetime.now().isoformat()
        opt["model_hash"] = hashlib.sha256(f"{str(ckpt)} {epoch} {step} {datetime.datetime.now().isoformat()}".encode()).hexdigest()
        opt["model_name"] = name
        opt["author"] = model_author
        opt["vocoder"] = vocoder
        torch.save(replace_keys_in_dict(replace_keys_in_dict(opt, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), model_path)
    except Exception as e:
        logger.error(f"{translations['extract_model_error']}: {e}")
def run(rank, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, custom_total_epoch, custom_save_every_weights, config, device, model_author, vocoder, checkpointing):
    global global_step

    if rank == 0:
        writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))
    else:
        writer_eval = None

    dist.init_process_group(backend="gloo", init_method="env://", world_size=n_gpus, rank=rank)
    torch.manual_seed(config.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    train_dataset = TextAudioLoaderMultiNSFsid(config.data)
    train_loader = tdata.DataLoader(train_dataset, num_workers=4, shuffle=False, pin_memory=True, collate_fn=TextAudioCollateMultiNSFsid(), batch_sampler=DistributedBucketSampler(train_dataset, batch_size * n_gpus, [100, 200, 300, 400, 500, 600, 700, 800, 900], num_replicas=n_gpus, rank=rank, shuffle=True), persistent_workers=True, prefetch_factor=8)

    net_g = Synthesizer(config.data.filter_length // 2 + 1, config.train.segment_size // config.data.hop_length, **config.model, use_f0=pitch_guidance, sr=sample_rate, vocoder=vocoder, checkpointing=checkpointing)
    net_d = MultiPeriodDiscriminator(version, config.model.use_spectral_norm, checkpointing=checkpointing)
    net_g, net_d = (net_g.cuda(rank), net_d.cuda(rank)) if torch.cuda.is_available() else (net_g.to(device), net_d.to(device))

    optim_g = torch.optim.AdamW(net_g.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)
    optim_d = torch.optim.AdamW(net_d.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)
    net_g, net_d = (DDP(net_g, device_ids=[rank]), DDP(net_d, device_ids=[rank])) if torch.cuda.is_available() else (DDP(net_g), DDP(net_d))

    try:
        logger.info(translations["start_training"])
        _, _, _, epoch_str = load_checkpoint(os.path.join(experiment_dir, "D_latest.pth") if save_only_latest else latest_checkpoint_path(experiment_dir, "D_*.pth"), net_d, optim_d)
        _, _, _, epoch_str = load_checkpoint(os.path.join(experiment_dir, "G_latest.pth") if save_only_latest else latest_checkpoint_path(experiment_dir, "G_*.pth"), net_g, optim_g)
        epoch_str += 1
        global_step = (epoch_str - 1) * len(train_loader)
    except:  # no loadable checkpoint yet, start a fresh run
        epoch_str, global_step = 1, 0

    if pretrainG != "" and pretrainG != "None":
        if rank == 0:
            verify_checkpoint_shapes(pretrainG, net_g)
            logger.info(translations["import_pretrain"].format(dg="G", pretrain=pretrainG))
        if hasattr(net_g, "module"):
            net_g.module.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
        else:
            net_g.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
    else:
        logger.warning(translations["not_using_pretrain"].format(dg="G"))

    if pretrainD != "" and pretrainD != "None":
        if rank == 0:
            verify_checkpoint_shapes(pretrainD, net_d)
            logger.info(translations["import_pretrain"].format(dg="D", pretrain=pretrainD))
        if hasattr(net_d, "module"):
            net_d.module.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
        else:
            net_d.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
    else:
        logger.warning(translations["not_using_pretrain"].format(dg="D"))

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)
    optim_d.step()
    optim_g.step()

    scaler = GradScaler(enabled=False)
    cache = []
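    # The first batch is kept as a fixed reference input so the TensorBoard
    # audio preview at each log interval is rendered from the same utterances.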
    for info in train_loader:
        phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
        reference = (phone.cuda(rank, non_blocking=True), phone_lengths.cuda(rank, non_blocking=True), pitch.cuda(rank, non_blocking=True) if pitch_guidance else None, pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None, sid.cuda(rank, non_blocking=True)) if device.type == "cuda" else (phone.to(device), phone_lengths.to(device), pitch.to(device) if pitch_guidance else None, pitchf.to(device) if pitch_guidance else None, sid.to(device))
        break

    for epoch in range(epoch_str, total_epoch + 1):
        train_and_evaluate(rank, epoch, config, [net_g, net_d], [optim_g, optim_d], scaler, train_loader, writer_eval, cache, custom_save_every_weights, custom_total_epoch, device, reference, model_author, vocoder)
        scheduler_g.step()
        scheduler_d.step()


def train_and_evaluate(rank, epoch, hps, nets, optims, scaler, train_loader, writer, cache, custom_save_every_weights, custom_total_epoch, device, reference, model_author, vocoder):
    global global_step, lowest_value, loss_disc, consecutive_increases_gen, consecutive_increases_disc

    if epoch == 1:
        lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
        last_loss_gen_all, consecutive_increases_gen, consecutive_increases_disc = 0.0, 0, 0

    net_g, net_d = nets
    optim_g, optim_d = optims
    train_loader.batch_sampler.set_epoch(epoch)
    net_g.train()
    net_d.train()

    if device.type == "cuda" and cache_data_in_gpu:
        data_iterator = cache
        if cache == []:
            for batch_idx, info in enumerate(train_loader):
                cache.append((batch_idx, [tensor.cuda(rank, non_blocking=True) for tensor in info]))
        else:
            shuffle(cache)
    else:
        data_iterator = enumerate(train_loader)

    epoch_recorder = EpochRecorder()
    with tqdm(total=len(train_loader), leave=False) as pbar:
        for batch_idx, info in data_iterator:
            if device.type == "cuda" and not cache_data_in_gpu:
                info = [tensor.cuda(rank, non_blocking=True) for tensor in info]
            elif device.type != "cuda":
                info = [tensor.to(device) for tensor in info]
            phone, phone_lengths, pitch, pitchf, spec, spec_lengths, wave, _, sid = info
            pitch = pitch if pitch_guidance else None
            pitchf = pitchf if pitch_guidance else None

            with autocast(enabled=False):
                y_hat, ids_slice, _, z_mask, (_, z_p, m_p, logs_p, _, logs_q) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
                mel = spec_to_mel_torch(spec, config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.mel_fmin, config.data.mel_fmax)
                y_mel = slice_segments(mel, ids_slice, config.train.segment_size // config.data.hop_length, dim=3)
                with autocast(enabled=False):
                    y_hat_mel = mel_spectrogram_torch(y_hat.float().squeeze(1), config.data.filter_length, config.data.n_mel_channels, config.data.sample_rate, config.data.hop_length, config.data.win_length, config.data.mel_fmin, config.data.mel_fmax)
                wave = slice_segments(wave, ids_slice * config.data.hop_length, config.train.segment_size, dim=3)
                y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
                with autocast(enabled=False):
                    loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)

            optim_d.zero_grad()
            scaler.scale(loss_disc).backward()
            scaler.unscale_(optim_d)
            grad_norm_d = clip_grad_value(net_d.parameters(), None)
            scaler.step(optim_d)

            with autocast(enabled=False):
                y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
                with autocast(enabled=False):
                    loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel
                    loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl
                    loss_fm = feature_loss(fmap_r, fmap_g)
                    loss_gen, losses_gen = generator_loss(y_d_hat_g)
                    loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
                    if loss_gen_all < lowest_value["value"]:
                        lowest_value["value"] = loss_gen_all
                        lowest_value["step"] = global_step
                        lowest_value["epoch"] = epoch
lowest_value["epoch"]: logger.warning(translations["training_warning"]) optim_g.zero_grad() scaler.scale(loss_gen_all).backward() scaler.unscale_(optim_g) grad_norm_g = clip_grad_value(net_g.parameters(), None) scaler.step(optim_g) scaler.update() if rank == 0 and global_step % config.train.log_interval == 0: if loss_mel > 75: loss_mel = 75 if loss_kl > 9: loss_kl = 9 scalar_dict = {"loss/g/total": loss_gen_all, "loss/d/total": loss_disc, "learning_rate": optim_g.param_groups[0]["lr"], "grad/norm_d": grad_norm_d, "grad/norm_g": grad_norm_g, "loss/g/fm": loss_fm, "loss/g/mel": loss_mel, "loss/g/kl": loss_kl} scalar_dict.update({f"loss/g/{i}": v for i, v in enumerate(losses_gen)}) scalar_dict.update({f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)}) scalar_dict.update({f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)}) with torch.no_grad(): o, *_ = net_g.module.infer(*reference) if hasattr(net_g, "module") else net_g.infer(*reference) summarize(writer=writer, global_step=global_step, images={"slice/mel_org": plot_spectrogram_to_numpy(y_mel[0].data.cpu().numpy()), "slice/mel_gen": plot_spectrogram_to_numpy(y_hat_mel[0].data.cpu().numpy()), "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy())}, scalars=scalar_dict, audios={f"gen/audio_{global_step:07d}": o[0, :, :]}, audio_sample_rate=config.data.sample_rate) global_step += 1 pbar.update(1) def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004): if len(smoothed_loss_history) < threshold + 1: return False for i in range(-threshold, -1): if smoothed_loss_history[i + 1] > smoothed_loss_history[i]: return True if abs(smoothed_loss_history[i + 1] - smoothed_loss_history[i]) >= epsilon: return False return True def update_exponential_moving_average(smoothed_loss_history, new_value, smoothing=0.987): smoothed_value = new_value if not smoothed_loss_history else (smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value) smoothed_loss_history.append(smoothed_value) return smoothed_value def save_to_json(file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history): with open(file_path, "w") as f: json.dump({"loss_disc_history": loss_disc_history, "smoothed_loss_disc_history": smoothed_loss_disc_history, "loss_gen_history": loss_gen_history, "smoothed_loss_gen_history": smoothed_loss_gen_history}, f) model_add, model_del = [], [] done = False if rank == 0: if epoch % save_every_epoch == False: checkpoint_suffix = f"{'latest' if save_only_latest else global_step}.pth" save_checkpoint(net_g, optim_g, config.train.learning_rate, epoch, os.path.join(experiment_dir, "G_" + checkpoint_suffix)) save_checkpoint(net_d, optim_d, config.train.learning_rate, epoch, os.path.join(experiment_dir, "D_" + checkpoint_suffix)) if custom_save_every_weights: model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth")) if overtraining_detector and epoch > 1: current_loss_disc = float(loss_disc) loss_disc_history.append(current_loss_disc) smoothed_value_disc = update_exponential_moving_average(smoothed_loss_disc_history, current_loss_disc) is_overtraining_disc = check_overtraining(smoothed_loss_disc_history, overtraining_threshold * 2) if is_overtraining_disc: consecutive_increases_disc += 1 else: consecutive_increases_disc = 0 current_loss_gen = float(lowest_value["value"]) loss_gen_history.append(current_loss_gen) smoothed_value_gen = update_exponential_moving_average(smoothed_loss_gen_history, current_loss_gen) is_overtraining_gen = 
            is_overtraining_gen = check_overtraining(smoothed_loss_gen_history, overtraining_threshold, 0.01)
            if is_overtraining_gen:
                consecutive_increases_gen += 1
            else:
                consecutive_increases_gen = 0
            if epoch % save_every_epoch == 0:
                save_to_json(training_file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history)
            if (is_overtraining_gen and consecutive_increases_gen == overtraining_threshold) or (is_overtraining_disc and consecutive_increases_disc == (overtraining_threshold * 2)):
                logger.info(translations["overtraining_find"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
                done = True
            else:
                logger.info(translations["best_epoch"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
                for file in glob.glob(os.path.join("assets", "weights", f"{model_name}_*e_*s_best_epoch.pth")):
                    model_del.append(file)
                model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth"))

        if epoch >= custom_total_epoch:
            logger.info(translations["success_training"].format(epoch=epoch, global_step=global_step, loss_gen_all=round(loss_gen_all.item(), 3)))
            logger.info(translations["training_info"].format(lowest_value_rounded=round(float(lowest_value["value"]), 3), lowest_value_epoch=lowest_value["epoch"], lowest_value_step=lowest_value["step"]))
            pid_file_path = os.path.join(experiment_dir, "config.json")
            with open(pid_file_path, "r") as pid_file:
                pid_data = json.load(pid_file)
            with open(pid_file_path, "w") as pid_file:
                pid_data.pop("process_pids", None)
                json.dump(pid_data, pid_file, indent=4)
            model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
            done = True

        for m in model_del:
            os.remove(m)
        if model_add:
            ckpt = net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict()
            for m in model_add:
                extract_model(ckpt=ckpt, sr=sample_rate, pitch_guidance=bool(pitch_guidance), name=model_name, model_path=m, epoch=epoch, step=global_step, version=version, hps=hps, model_author=model_author, vocoder=vocoder)

        lowest_value_rounded = round(float(lowest_value["value"]), 3)
        if epoch > 1 and overtraining_detector:
            logger.info(translations["model_training_info"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value["epoch"], lowest_value_step=lowest_value["step"], remaining_epochs_gen=(overtraining_threshold - consecutive_increases_gen), remaining_epochs_disc=((overtraining_threshold * 2) - consecutive_increases_disc), smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
        elif epoch > 1 and not overtraining_detector:
            logger.info(translations["model_training_info_2"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value["epoch"], lowest_value_step=lowest_value["step"]))
        else:
            logger.info(translations["model_training_info_3"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record()))

    last_loss_gen_all = loss_gen_all
    if done:
        os._exit(0)


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    try:
        main()
    except Exception as e:
        logger.error(f"{translations['training_error']} {e}")
        import traceback
        logger.debug(traceback.format_exc())