import copy
import glob
import os
import shutil
import time
from collections import OrderedDict

from PIL import Image
from PIL.ImageOps import exif_transpose
from einops import rearrange
from safetensors.torch import save_file, load_file
from torch.utils.data import DataLoader, ConcatDataset
import torch
from torch import nn
from torchvision.transforms import transforms

from jobs.process import BaseTrainProcess
from toolkit.image_utils import show_tensors
from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
from toolkit.data_loader import ImageDataset
from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation
from toolkit.metadata import get_meta_for_safetensors
from toolkit.optimizer import get_optimizer
from toolkit.style import get_style_model_and_losses
from toolkit.train_tools import get_torch_dtype
from diffusers import AutoencoderKL
from tqdm import tqdm
import math
import torchvision.utils
import numpy as np
from .models.critic import Critic
from torchvision.transforms import Resize
import lpips
import random
import traceback

IMAGE_TRANSFORMS = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)


def unnormalize(tensor):
    return (tensor / 2 + 0.5).clamp(0, 1)


def channel_dropout(x, p=0.5):
    keep_prob = 1 - p
    mask = torch.rand(x.size(0), x.size(1), 1, 1, device=x.device, dtype=x.dtype) < keep_prob
    mask = mask / keep_prob  # scale surviving channels to preserve the expected magnitude
    return x * mask
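
# Tiny illustration of the inverted-dropout scaling in channel_dropout above
# (values are hypothetical, not taken from the training code): with p=0.5 the
# keep probability is 0.5, so surviving channels are multiplied by 1/0.5 = 2
# and the expected activation magnitude is preserved.
#
#   x = torch.ones(2, 4, 8, 8)
#   y = channel_dropout(x, p=0.5)
#   # each channel of y is now all 2.0 (kept) or all 0.0 (dropped),
#   # and y.mean() approaches x.mean() over many draws
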

class TrainVAEProcess(BaseTrainProcess):
    def __init__(self, process_id: int, job, config: OrderedDict):
        super().__init__(process_id, job, config)
        self.data_loader = None
        self.vae = None
        self.device = self.get_conf('device', self.job.device)
        self.vae_path = self.get_conf('vae_path', None)
        self.eq_vae = self.get_conf('eq_vae', False)
        self.datasets_objects = self.get_conf('datasets', required=True)
        self.batch_size = self.get_conf('batch_size', 1, as_type=int)
        self.resolution = self.get_conf('resolution', 256, as_type=int)
        self.learning_rate = self.get_conf('learning_rate', 1e-6, as_type=float)
        self.sample_every = self.get_conf('sample_every', None)
        self.optimizer_type = self.get_conf('optimizer', 'adam')
        self.epochs = self.get_conf('epochs', None, as_type=int)
        self.max_steps = self.get_conf('max_steps', None, as_type=int)
        self.save_every = self.get_conf('save_every', None)
        self.dtype = self.get_conf('dtype', 'float32')
        self.sample_sources = self.get_conf('sample_sources', None)
        self.log_every = self.get_conf('log_every', 100, as_type=int)
        self.style_weight = self.get_conf('style_weight', 0, as_type=float)
        self.content_weight = self.get_conf('content_weight', 0, as_type=float)
        self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
        self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
        self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float)
        self.tv_weight = self.get_conf('tv_weight', 0, as_type=float)
        self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float)
        self.lpm_weight = self.get_conf('lpm_weight', 0, as_type=float)  # latent pixel matching
        self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
        self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
        self.pattern_weight = self.get_conf('pattern_weight', 0, as_type=float)
        self.optimizer_params = self.get_conf('optimizer_params', {})
        self.vae_config = self.get_conf('vae_config', None)
        self.dropout = self.get_conf('dropout', 0.0, as_type=float)
        self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
        self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)

        if not self.train_encoder:
            # zero out losses that only target the encoder
            self.kld_weight = 0
            self.mv_loss_weight = 0
            self.ltv_weight = 0
            self.lpm_weight = 0

        self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
        self.torch_dtype = get_torch_dtype(self.dtype)
        self.vgg_19 = None
        self.style_weight_scalers = []
        self.content_weight_scalers = []
        self.lpips_loss: lpips.LPIPS = None
        self.vae_scale_factor = 8
        self.step_num = 0
        self.epoch_num = 0

        self.use_critic = self.get_conf('use_critic', False, as_type=bool)
        self.critic = None
        if self.use_critic:
            self.critic = Critic(
                device=self.device,
                dtype=self.dtype,
                process=self,
                **self.get_conf('critic', {})  # pass any other params
            )

        if self.sample_every is not None and self.sample_sources is None:
            raise ValueError('sample_every is specified but sample_sources is not')
        if self.epochs is None and self.max_steps is None:
            raise ValueError('epochs or max_steps must be specified')

        self.data_loaders = []
        # check datasets
        assert isinstance(self.datasets_objects, list)
        for dataset in self.datasets_objects:
            if 'path' not in dataset:
                raise ValueError('dataset must have a path')
            # check that it is a directory
            if not os.path.isdir(dataset['path']):
                raise ValueError(f"dataset path is not a directory: {dataset['path']}")

        # make training folder
        if not os.path.exists(self.save_root):
            os.makedirs(self.save_root, exist_ok=True)

        self._pattern_loss = None

    def update_training_metadata(self):
        self.add_meta(OrderedDict({"training_info": self.get_training_info()}))

    def get_training_info(self):
        info = OrderedDict({
            'step': self.step_num,
            'epoch': self.epoch_num,
        })
        return info

    def load_datasets(self):
        if self.data_loader is None:
            print("Loading datasets")
            datasets = []
            for dataset in self.datasets_objects:
                print(f" - Dataset: {dataset['path']}")
                ds = copy.copy(dataset)
                dataset_res = self.resolution
                if self.random_scaling:
                    # scale 2x to allow for random scaling
                    dataset_res = int(dataset_res * 2)
                ds['resolution'] = dataset_res
                image_dataset = ImageDataset(ds)
                datasets.append(image_dataset)

            concatenated_dataset = ConcatDataset(datasets)
            self.data_loader = DataLoader(
                concatenated_dataset,
                batch_size=self.batch_size,
                shuffle=True,
                num_workers=16
            )

    def remove_oldest_checkpoint(self):
        max_to_keep = 4
        folders = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
        if len(folders) > max_to_keep:
            folders.sort(key=os.path.getmtime)
            for folder in folders[:-max_to_keep]:
                print(f"Removing {folder}")
                shutil.rmtree(folder)
        # also handle critic checkpoints, e.g. CRITIC_vae_42_000000500.safetensors
        critic_files = glob.glob(os.path.join(self.save_root, f"CRITIC_{self.job.name}*.safetensors"))
        if len(critic_files) > max_to_keep:
            critic_files.sort(key=os.path.getmtime)
            for file in critic_files[:-max_to_keep]:
                print(f"Removing {file}")
                os.remove(file)

    def setup_vgg19(self):
        if self.vgg_19 is None:
            self.vgg_19, self.style_losses, self.content_losses, self.vgg19_pool_4 = get_style_model_and_losses(
                single_target=True,
                device=self.device,
                output_layer_name='pool_4',
                dtype=self.torch_dtype
            )
            self.vgg_19.to(self.device, dtype=self.torch_dtype)
            self.vgg_19.requires_grad_(False)

            # we run random noise through first to get layer scalers to normalize the loss per layer
            # bs of 2 because we run pred and target through stacked
            noise = torch.randn((2, 3, self.resolution, self.resolution), device=self.device,
                                dtype=self.torch_dtype)
            self.vgg_19(noise)
            for style_loss in self.style_losses:
                # get a scaler to normalize to 1
                scaler = 1 / torch.mean(style_loss.loss).item()
                self.style_weight_scalers.append(scaler)
            for content_loss in self.content_losses:
                # get a scaler to normalize to 1
                scaler = 1 / torch.mean(content_loss.loss).item()
                self.content_weight_scalers.append(scaler)

            self.print(f"Style weight scalers: {self.style_weight_scalers}")
            self.print(f"Content weight scalers: {self.content_weight_scalers}")
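
    # Note on the scalers computed in setup_vgg19: each VGG layer's raw style
    # and content losses have very different magnitudes, so a scaler of
    # 1 / mean(loss-on-noise) normalizes every layer to roughly 1.0 before
    # style_weight / content_weight are applied. Illustrative (hypothetical)
    # numbers: if pool_4's raw style loss on noise is 4e-3, its scaler is
    # 1 / 4e-3 = 250 and its scaled contribution starts near 1.0.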

    def get_style_loss(self):
        if self.style_weight > 0:
            # scale all losses with loss scalers
            loss = torch.sum(torch.stack(
                [loss.loss * scaler for loss, scaler in zip(self.style_losses, self.style_weight_scalers)]))
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def get_content_loss(self):
        if self.content_weight > 0:
            # scale all losses with loss scalers
            loss = torch.sum(torch.stack(
                [loss.loss * scaler for loss, scaler in zip(self.content_losses, self.content_weight_scalers)]))
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def get_mse_loss(self, pred, target):
        if self.mse_weight > 0:
            loss_fn = nn.MSELoss()
            loss = loss_fn(pred, target)
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def get_kld_loss(self, mu, log_var):
        if self.kld_weight > 0:
            # Kullback-Leibler divergence between the latent distribution and a
            # standard normal. Only useful when training the encoder; the decoder
            # alone cannot change the distribution of the latent space. Normally
            # it helps keep the latents normally distributed.
            KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
            return KLD
        else:
            return torch.tensor(0.0, device=self.device)

    def get_mean_variance_loss(self, latents: torch.Tensor):
        if self.mv_loss_weight > 0:
            # collapse rows into channels
            latents_col = rearrange(latents, 'b c h (gw w) -> b (c gw) h w', gw=latents.shape[-1])
            mean_col = latents_col.mean(dim=(2, 3), keepdim=True)
            std_col = latents_col.std(dim=(2, 3), keepdim=True, unbiased=False)
            mean_loss_col = (mean_col ** 2).mean()
            std_loss_col = ((std_col - 1) ** 2).mean()

            # collapse columns into channels
            latents_row = rearrange(latents, 'b c (gh h) w -> b (c gh) h w', gh=latents.shape[-2])
            mean_row = latents_row.mean(dim=(2, 3), keepdim=True)
            std_row = latents_row.std(dim=(2, 3), keepdim=True, unbiased=False)
            mean_loss_row = (mean_row ** 2).mean()
            std_loss_row = ((std_row - 1) ** 2).mean()

            # do a global one
            mean = latents.mean(dim=(2, 3), keepdim=True)
            std = latents.std(dim=(2, 3), keepdim=True, unbiased=False)
            mean_loss_global = (mean ** 2).mean()
            std_loss_global = ((std - 1) ** 2).mean()

            # average over the three groupings (each contributes a mean and a std term)
            return (mean_loss_col + std_loss_col +
                    mean_loss_row + std_loss_row +
                    mean_loss_global + std_loss_global) / 3
        else:
            return torch.tensor(0.0, device=self.device)

    def get_ltv_loss(self, latent):
        # loss to reduce the total variation of the latent space
        if self.ltv_weight > 0:
            return total_variation(latent).mean()
        else:
            return torch.tensor(0.0, device=self.device)

    def get_latent_pixel_matching_loss(self, latent, pixels):
        if self.lpm_weight > 0:
            with torch.no_grad():
                pixels = pixels.to(latent.device, dtype=latent.dtype)
                # resize down to latent size
                pixels = torch.nn.functional.interpolate(
                    pixels, size=(latent.shape[2], latent.shape[3]),
                    mode='bilinear', align_corners=False
                )
                # mean the color channels and then expand to the latent channel count
                pixels = pixels.mean(dim=1, keepdim=True)
                pixels = pixels.repeat(1, latent.shape[1], 1, 1)
                # match the mean/std of the latent
                latent_mean = latent.mean(dim=(2, 3), keepdim=True)
                latent_std = latent.std(dim=(2, 3), keepdim=True)
                pixels_mean = pixels.mean(dim=(2, 3), keepdim=True)
                pixels_std = pixels.std(dim=(2, 3), keepdim=True)
                pixels = (pixels - pixels_mean) / (pixels_std + 1e-6) * latent_std + latent_mean
            return torch.nn.functional.mse_loss(latent.float(), pixels.float())
        else:
            return torch.tensor(0.0, device=self.device)
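
    # Shape sketch for get_latent_pixel_matching_loss above (assumes the
    # default 8x VAE and a 256px batch; other sizes follow the same flow):
    #   pixels (B, 3, 256, 256) -> bilinear resize -> (B, 3, 32, 32)
    #   -> mean over RGB -> (B, 1, 32, 32) -> repeat -> (B, C_latent, 32, 32)
    #   -> re-standardized to the latent's per-channel mean/std,
    # so the MSE penalizes mismatched spatial structure, not statistics.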

    def get_tv_loss(self, pred, target):
        if self.tv_weight > 0:
            get_tv_loss = ComparativeTotalVariation()
            loss = get_tv_loss(pred, target)
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def get_pattern_loss(self, pred, target):
        if self.pattern_weight > 0:
            if self._pattern_loss is None:
                self._pattern_loss = PatternLoss(pattern_size=16, dtype=self.torch_dtype).to(
                    self.device, dtype=self.torch_dtype)
            loss = torch.mean(self._pattern_loss(pred, target))
            return loss
        else:
            return torch.tensor(0.0, device=self.device)

    def save(self, step=None):
        if not os.path.exists(self.save_root):
            os.makedirs(self.save_root, exist_ok=True)
        step_num = ''
        if step is not None:
            # zero-pad to 9 digits
            step_num = f"_{str(step).zfill(9)}"
        self.update_training_metadata()
        filename = f'{self.job.name}{step_num}_diffusers'
        self.vae = self.vae.to("cpu", dtype=torch.float16)
        self.vae.save_pretrained(
            save_directory=os.path.join(self.save_root, filename)
        )
        self.vae = self.vae.to(self.device, dtype=self.torch_dtype)
        self.print(f"Saved to {os.path.join(self.save_root, filename)}")
        if self.use_critic:
            self.critic.save(step)
        self.remove_oldest_checkpoint()

    def sample(self, step=None):
        sample_folder = os.path.join(self.save_root, 'samples')
        if not os.path.exists(sample_folder):
            os.makedirs(sample_folder, exist_ok=True)

        with torch.no_grad():
            for i, img_url in enumerate(self.sample_sources):
                img = exif_transpose(Image.open(img_url))
                img = img.convert('RGB')
                # crop if not square
                if img.width != img.height:
                    min_dim = min(img.width, img.height)
                    img = img.crop((0, 0, min_dim, min_dim))
                # resize
                img = img.resize((self.resolution, self.resolution))
                input_img = img
                img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
                latent = self.vae.encode(img).latent_dist.sample()

                # build a grid image of the latent channels
                latent_img = latent.clone()
                bs, ch, h, w = latent_img.shape
                grid_size = math.ceil(math.sqrt(ch))
                pad = grid_size * grid_size - ch
                # take first item in batch
                latent_img = latent_img[0]  # shape: (ch, h, w)
                if pad > 0:
                    padding = torch.zeros((pad, h, w), dtype=latent_img.dtype, device=latent_img.device)
                    latent_img = torch.cat([latent_img, padding], dim=0)
                # make grid
                new_img = torch.zeros((1, grid_size * h, grid_size * w), dtype=latent_img.dtype,
                                      device=latent_img.device)
                for x in range(grid_size):
                    for y in range(grid_size):
                        if x * grid_size + y < ch:
                            new_img[0, x * h:(x + 1) * h, y * w:(y + 1) * w] = latent_img[x * grid_size + y]
                latent_img = new_img
                # make rgb
                latent_img = latent_img.repeat(3, 1, 1).unsqueeze(0)
                latent_img = (latent_img / 2 + 0.5).clamp(0, 1)
                # resize to the sample resolution
                latent_img = torch.nn.functional.interpolate(latent_img, size=(self.resolution, self.resolution),
                                                             mode='nearest')
                latent_img = latent_img.squeeze(0).cpu().permute(1, 2, 0).float().numpy()
                latent_img = (latent_img * 255).astype(np.uint8)
                # convert to pillow image
                latent_img = Image.fromarray(latent_img)

                decoded = self.vae.decode(latent).sample
                decoded = (decoded / 2 + 0.5).clamp(0, 1)
                # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
                decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
                # convert to pillow image
                decoded = Image.fromarray((decoded * 255).astype(np.uint8))

                # stack input image, decoded image, and latent grid side by side
                input_img = input_img.resize((self.resolution, self.resolution))
                decoded = decoded.resize((self.resolution, self.resolution))
                output_img = Image.new('RGB', (self.resolution * 3, self.resolution))
                output_img.paste(input_img, (0, 0))
                output_img.paste(decoded, (self.resolution, 0))
                output_img.paste(latent_img, (self.resolution * 2, 0))

                scale_up = 2
                if output_img.height <= 300:
                    scale_up = 4
                # scale up using nearest neighbor
                output_img = output_img.resize((output_img.width * scale_up, output_img.height * scale_up),
                                               Image.NEAREST)

                step_num = ''
                if step is not None:
                    # zero-pad to 9 digits
                    step_num = f"_{str(step).zfill(9)}"
                seconds_since_epoch = int(time.time())
                # zero-pad to 2 digits
                i_str = str(i).zfill(2)
                filename = f"{seconds_since_epoch}{step_num}_{i_str}.png"
                output_img.save(os.path.join(sample_folder, filename))
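
    # sample() writes one strip per sample source:
    #   [ input | decoded reconstruction | latent channel grid ]
    # with each panel resolution x resolution, then upscales the strip 2x
    # (or 4x when shorter than 300px) with nearest neighbor for inspection.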

    def load_vae(self):
        path_to_load = self.vae_path
        # see if we have a checkpoint in our output folder to resume from
        self.print(f"Looking for latest checkpoint in {self.save_root}")
        files = glob.glob(os.path.join(self.save_root, f"{self.job.name}*_diffusers"))
        if files:
            latest_file = max(files, key=os.path.getmtime)
            print(f" - Latest checkpoint is: {latest_file}")
            path_to_load = latest_file
            # todo update step and epoch count
        else:
            self.print(" - No checkpoint found, starting from scratch")
        # load vae
        self.print("Loading VAE")
        self.print(f" - Loading VAE: {path_to_load}")
        if self.vae is None:
            if path_to_load is not None:
                self.vae = AutoencoderKL.from_pretrained(path_to_load)
            elif self.vae_config is not None:
                self.vae = AutoencoderKL(**self.vae_config)
            else:
                raise ValueError('vae_path or vae_config must be specified')

        # set the decoder (and optionally the encoder) to train
        self.vae.to(self.device, dtype=self.torch_dtype)
        if self.eq_vae:
            self.vae.encoder.train()
        else:
            self.vae.requires_grad_(False)
            self.vae.eval()
        self.vae.decoder.train()
        self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)
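
    # How epochs and max_steps interact in run() below (worked example with
    # hypothetical numbers): with 1,000 batches per epoch, epochs=10 and
    # max_steps=2,500, max_step_epochs = 2500 // 1000 = 2 caps num_epochs
    # at 2; max_epoch_steps = 2 * 1000 = 2000 then caps num_steps at 2000.
    # Training stops at whichever limit is hit first.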

    def run(self):
        super().run()
        self.load_datasets()

        steps_per_epoch = len(self.data_loader)
        # clamp epochs and max_steps against each other (either may be None,
        # but not both, per the check in __init__)
        max_step_epochs = self.max_steps // steps_per_epoch if self.max_steps is not None else None
        num_epochs = self.epochs
        if num_epochs is None or (max_step_epochs is not None and num_epochs > max_step_epochs):
            num_epochs = max_step_epochs

        max_epoch_steps = steps_per_epoch * num_epochs
        num_steps = self.max_steps
        if num_steps is None or num_steps > max_epoch_steps:
            num_steps = max_epoch_steps
        self.max_steps = num_steps
        self.epochs = num_epochs
        start_step = self.step_num
        self.first_step = start_step

        self.print("Training VAE")
        self.print(f" - Training folder: {self.training_folder}")
        self.print(f" - Batch size: {self.batch_size}")
        self.print(f" - Learning rate: {self.learning_rate}")
        self.print(f" - Epochs: {num_epochs}")
        self.print(f" - Max steps: {self.max_steps}")

        # load vae
        self.load_vae()

        params = []
        # freeze the decoder, then re-enable only the configured blocks
        for param in self.vae.decoder.parameters():
            param.requires_grad = False

        train_all = 'all' in self.blocks_to_train

        if train_all:
            params = list(self.vae.decoder.parameters())
            self.vae.decoder.requires_grad_(True)
            if self.train_encoder:
                # encoder
                params += list(self.vae.encoder.parameters())
                self.vae.encoder.requires_grad_(True)
        else:
            # mid_block
            if 'mid_block' in self.blocks_to_train:
                params += list(self.vae.decoder.mid_block.parameters())
                self.vae.decoder.mid_block.requires_grad_(True)
            # up_blocks
            if 'up_blocks' in self.blocks_to_train:
                params += list(self.vae.decoder.up_blocks.parameters())
                self.vae.decoder.up_blocks.requires_grad_(True)
            # conv_out (single conv layer output)
            if 'conv_out' in self.blocks_to_train:
                params += list(self.vae.decoder.conv_out.parameters())
                self.vae.decoder.conv_out.requires_grad_(True)

        if self.style_weight > 0 or self.content_weight > 0:
            self.setup_vgg19()
            self.vgg_19.requires_grad_(False)
            self.vgg_19.eval()

        if self.use_critic:
            self.critic.setup()

        if self.lpips_weight > 0 and self.lpips_loss is None:
            self.lpips_loss = lpips.LPIPS(net='vgg').to(self.device, dtype=self.torch_dtype)

        optimizer = get_optimizer(params, self.optimizer_type, self.learning_rate,
                                  optimizer_params=self.optimizer_params)

        # setup scheduler
        # todo allow other schedulers
        scheduler = torch.optim.lr_scheduler.ConstantLR(
            optimizer,
            total_iters=num_steps,
            factor=1,
            verbose=False
        )

        # setup tqdm progress bar
        self.progress_bar = tqdm(
            total=num_steps,
            desc='Training VAE',
            leave=True
        )

        # sample first
        self.sample()
        blank_losses = OrderedDict({
            "total": [],
            "lpips": [],
            "style": [],
            "content": [],
            "mse": [],
            "mvl": [],
            "ltv": [],
            "lpm": [],
            "kl": [],
            "tv": [],
            "ptn": [],
            "crD": [],
            "crG": [],
        })
        epoch_losses = copy.deepcopy(blank_losses)
        log_losses = copy.deepcopy(blank_losses)
        # range starts at self.epoch_num and goes to self.epochs
        latent_size = self.resolution // self.vae_scale_factor
        for epoch in range(self.epoch_num, self.epochs, 1):
            if self.step_num >= self.max_steps:
                break
            for batch in self.data_loader:
                if self.step_num >= self.max_steps:
                    break
                with torch.no_grad():
                    batch = batch.to(self.device, dtype=self.torch_dtype)

                    if self.random_scaling:
                        # only random scale half of the time
                        if random.random() < 0.5:
                            scale_factor = 0.25
                        else:
                            scale_factor = 0.5
                        new_size = (int(batch.shape[2] * scale_factor), int(batch.shape[3] * scale_factor))
                        # make sure it is vae divisible
                        new_size = (new_size[0] // self.vae_scale_factor * self.vae_scale_factor,
                                    new_size[1] // self.vae_scale_factor * self.vae_scale_factor)
                        batch = torch.nn.functional.interpolate(batch, size=new_size, mode='bilinear',
                                                                align_corners=False)

                    # resize so it matches the vae scale factor evenly
                    if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
                        batch = Resize((batch.shape[2] // self.vae_scale_factor * self.vae_scale_factor,
                                        batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)

                # forward pass
                # gradients flow through the encoder only when training it
                with torch.set_grad_enabled(self.train_encoder):
                    dgd = self.vae.encode(batch).latent_dist
                    mu, logvar = dgd.mean, dgd.logvar
                    latents = dgd.sample()

                if self.eq_vae:
                    # apply matching flips / rotations to latents and pixels
                    latent_chunks = list(latents.chunk(latents.shape[0], dim=0))
                    batch_chunks = list(batch.chunk(batch.shape[0], dim=0))
                    out_chunks = []
                    for i in range(len(latent_chunks)):
                        try:
                            do_rotate = random.randint(0, 3)
                            do_flip_x = random.randint(0, 1)
                            do_flip_y = random.randint(0, 1)
                            do_scale = random.randint(0, 1)
                            if do_rotate > 0:
                                latent_chunks[i] = torch.rot90(latent_chunks[i], do_rotate, (2, 3))
                                batch_chunks[i] = torch.rot90(batch_chunks[i], do_rotate, (2, 3))
                            if do_flip_x > 0:
                                latent_chunks[i] = torch.flip(latent_chunks[i], [2])
                                batch_chunks[i] = torch.flip(batch_chunks[i], [2])
                            if do_flip_y > 0:
                                latent_chunks[i] = torch.flip(latent_chunks[i], [3])
                                batch_chunks[i] = torch.flip(batch_chunks[i], [3])
                            # resize latent to fit
                            if latent_chunks[i].shape[2] != latent_size or latent_chunks[i].shape[3] != latent_size:
                                latent_chunks[i] = torch.nn.functional.interpolate(
                                    latent_chunks[i], size=(latent_size, latent_size),
                                    mode='bilinear', align_corners=False)
                            # if do_scale > 0:
                            #     scale = 2
                            #     start_latent_h = latent_chunks[i].shape[2]
                            #     start_latent_w = latent_chunks[i].shape[3]
                            #     start_batch_h = batch_chunks[i].shape[2]
                            #     start_batch_w = batch_chunks[i].shape[3]
                            #     latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False)
                            #     batch_chunks[i] = torch.nn.functional.interpolate(batch_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False)
                            #     # random crop. latent is smaller than batch but crops need to match
                            #     latent_x = random.randint(0, latent_chunks[i].shape[2] - start_latent_h)
                            #     latent_y = random.randint(0, latent_chunks[i].shape[3] - start_latent_w)
                            #     batch_x = latent_x * self.vae_scale_factor
                            #     batch_y = latent_y * self.vae_scale_factor
                            #     # crop
                            #     latent_chunks[i] = latent_chunks[i][:, :, latent_x:latent_x + start_latent_h, latent_y:latent_y + start_latent_w]
                            #     batch_chunks[i] = batch_chunks[i][:, :, batch_x:batch_x + start_batch_h, batch_y:batch_y + start_batch_w]
                        except Exception as e:
                            print(f"Error processing image {i}: {e}")
                            traceback.print_exc()
                            raise e
                        out_chunks.append(latent_chunks[i])
                    latents = torch.cat(out_chunks, dim=0)
                    # do dropout
                    if self.dropout > 0:
                        forward_latents = channel_dropout(latents, self.dropout)
                    else:
                        forward_latents = latents

                    # resize batch to resolution if needed
                    if batch_chunks[0].shape[2] != self.resolution or batch_chunks[0].shape[3] != self.resolution:
                        batch_chunks = [
                            torch.nn.functional.interpolate(b, size=(self.resolution, self.resolution),
                                                            mode='bilinear', align_corners=False)
                            for b in batch_chunks
                        ]
                    batch = torch.cat(batch_chunks, dim=0)
                else:
                    forward_latents = latents

                forward_latents = forward_latents.to(self.device, dtype=self.torch_dtype)
                if not self.train_encoder:
                    # detach latents if not training encoder
                    forward_latents = forward_latents.detach()
                pred = self.vae.decode(forward_latents).sample

                # run through VGG19
                if self.style_weight > 0 or self.content_weight > 0:
                    stacked = torch.cat([pred, batch], dim=0)
                    stacked = (stacked / 2 + 0.5).clamp(0, 1)
                    self.vgg_19(stacked)

                if self.use_critic:
                    stacked = torch.cat([pred, batch], dim=0)
                    critic_d_loss = self.critic.step(stacked.detach())
                else:
                    critic_d_loss = 0.0

                style_loss = self.get_style_loss() * self.style_weight
                content_loss = self.get_content_loss() * self.content_weight
                kld_loss = self.get_kld_loss(mu, logvar) * self.kld_weight
                mse_loss = self.get_mse_loss(pred, batch) * self.mse_weight
                if self.lpips_weight > 0:
                    lpips_loss = self.lpips_loss(
                        pred.clamp(-1, 1),
                        batch.clamp(-1, 1)
                    ).mean() * self.lpips_weight
                else:
                    lpips_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
                tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
                pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight
                if self.use_critic:
                    stacked = torch.cat([pred, batch], dim=0)
                    critic_gen_loss = self.critic.get_critic_loss(stacked) * self.critic_weight

                    # do not let abs critic gen loss be higher than abs lpips * 0.1 if using it
                    if self.lpips_weight > 0:
                        max_target = lpips_loss.abs() * 0.1
                        with torch.no_grad():
                            crit_g_scaler = 1.0
                            if critic_gen_loss.abs() > max_target:
                                crit_g_scaler = max_target / critic_gen_loss.abs()
                        critic_gen_loss *= crit_g_scaler
                else:
                    critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

                if self.mv_loss_weight > 0:
                    mv_loss = self.get_mean_variance_loss(latents) * self.mv_loss_weight
                else:
                    mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

                if self.ltv_weight > 0:
                    ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight
                else:
                    ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

                if self.lpm_weight > 0:
                    lpm_loss = self.get_latent_pixel_matching_loss(latents, batch) * self.lpm_weight
                else:
                    lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)

                loss = (style_loss + content_loss + kld_loss + mse_loss + tv_loss +
                        critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss + lpm_loss)

                # check if loss is NaN or Inf
                if torch.isnan(loss) or torch.isinf(loss):
                    self.print(f"Loss is NaN or Inf, stopping at step {self.step_num}")
                    self.print(f" - Style loss: {style_loss.item()}")
                    self.print(f" - Content loss: {content_loss.item()}")
                    self.print(f" - KLD loss: {kld_loss.item()}")
                    self.print(f" - MSE loss: {mse_loss.item()}")
                    self.print(f" - LPIPS loss: {lpips_loss.item()}")
                    self.print(f" - TV loss: {tv_loss.item()}")
                    self.print(f" - Pattern loss: {pattern_loss.item()}")
                    self.print(f" - Critic gen loss: {critic_gen_loss.item()}")
                    self.print(f" - Critic D loss: {critic_d_loss}")
                    self.print(f" - Mean variance loss: {mv_loss.item()}")
                    self.print(f" - Latent TV loss: {ltv_loss.item()}")
                    self.print(f" - Latent pixel matching loss: {lpm_loss.item()}")
                    self.print(f" - Total loss: {loss.item()}")
                    self.print(" - Stopping training")
                    exit(1)

                # backward pass and optimization
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                # update progress bar
                loss_value = loss.item()
                # format losses in exponent notation like 3.54e-4
                loss_string = f"loss: {loss_value:.2e}"
                if self.lpips_weight > 0:
                    loss_string += f" lpips: {lpips_loss.item():.2e}"
                if self.content_weight > 0:
                    loss_string += f" cnt: {content_loss.item():.2e}"
                if self.style_weight > 0:
                    loss_string += f" sty: {style_loss.item():.2e}"
                if self.kld_weight > 0:
                    loss_string += f" kld: {kld_loss.item():.2e}"
                if self.mse_weight > 0:
                    loss_string += f" mse: {mse_loss.item():.2e}"
                if self.tv_weight > 0:
                    loss_string += f" tv: {tv_loss.item():.2e}"
                if self.pattern_weight > 0:
                    loss_string += f" ptn: {pattern_loss.item():.2e}"
                if self.use_critic and self.critic_weight > 0:
                    loss_string += f" crG: {critic_gen_loss.item():.2e}"
                if self.use_critic:
                    loss_string += f" crD: {critic_d_loss:.2e}"
                if self.mv_loss_weight > 0:
                    loss_string += f" mvl: {mv_loss.item():.2e}"
                if self.ltv_weight > 0:
                    loss_string += f" ltv: {ltv_loss.item():.2e}"
                if self.lpm_weight > 0:
                    loss_string += f" lpm: {lpm_loss.item():.2e}"

                if hasattr(optimizer, 'get_avg_learning_rate'):
                    learning_rate = optimizer.get_avg_learning_rate()
                elif self.optimizer_type.startswith('dadaptation') or \
                        self.optimizer_type.lower().startswith('prodigy'):
                    learning_rate = (
                            optimizer.param_groups[0]["d"] *
                            optimizer.param_groups[0]["lr"]
                    )
                else:
                    learning_rate = optimizer.param_groups[0]['lr']

                lr_critic_string = ''
                if self.use_critic:
                    lr_critic = self.critic.get_lr()
                    lr_critic_string = f" lrC: {lr_critic:.1e}"

                self.progress_bar.set_postfix_str(f"lr: {learning_rate:.1e}{lr_critic_string} {loss_string}")
                self.progress_bar.set_description(f"E: {epoch}")
                self.progress_bar.update(1)

                epoch_losses["total"].append(loss_value)
                epoch_losses["lpips"].append(lpips_loss.item())
                epoch_losses["style"].append(style_loss.item())
                epoch_losses["content"].append(content_loss.item())
                epoch_losses["mse"].append(mse_loss.item())
                epoch_losses["kl"].append(kld_loss.item())
                epoch_losses["tv"].append(tv_loss.item())
                epoch_losses["ptn"].append(pattern_loss.item())
                epoch_losses["crG"].append(critic_gen_loss.item())
                epoch_losses["crD"].append(critic_d_loss)
                epoch_losses["mvl"].append(mv_loss.item())
                epoch_losses["ltv"].append(ltv_loss.item())
                epoch_losses["lpm"].append(lpm_loss.item())

                log_losses["total"].append(loss_value)
                log_losses["lpips"].append(lpips_loss.item())
                log_losses["style"].append(style_loss.item())
                log_losses["content"].append(content_loss.item())
                log_losses["mse"].append(mse_loss.item())
                log_losses["kl"].append(kld_loss.item())
                log_losses["tv"].append(tv_loss.item())
                log_losses["ptn"].append(pattern_loss.item())
                log_losses["crG"].append(critic_gen_loss.item())
                log_losses["crD"].append(critic_d_loss)
                log_losses["mvl"].append(mv_loss.item())
                log_losses["ltv"].append(ltv_loss.item())
                log_losses["lpm"].append(lpm_loss.item())
log_losses["tv"].append(tv_loss.item()) log_losses["ptn"].append(pattern_loss.item()) log_losses["crG"].append(critic_gen_loss.item()) log_losses["crD"].append(critic_d_loss) log_losses["mvl"].append(mv_loss.item()) log_losses["ltv"].append(ltv_loss.item()) log_losses["lpm"].append(lpm_loss.item()) # don't do on first step if self.step_num != start_step: if self.sample_every and self.step_num % self.sample_every == 0: # print above the progress bar self.print(f"Sampling at step {self.step_num}") self.sample(self.step_num) if self.save_every and self.step_num % self.save_every == 0: # print above the progress bar self.print(f"Saving at step {self.step_num}") self.save(self.step_num) if self.log_every and self.step_num % self.log_every == 0: # log to tensorboard if self.writer is not None: # get avg loss for key in log_losses: log_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + 1e-6) # if log_losses[key] > 0: self.writer.add_scalar(f"loss/{key}", log_losses[key], self.step_num) # reset log losses log_losses = copy.deepcopy(blank_losses) self.step_num += 1 # end epoch if self.writer is not None: eps = 1e-6 # get avg loss for key in epoch_losses: epoch_losses[key] = sum(log_losses[key]) / (len(log_losses[key]) + eps) if epoch_losses[key] > 0: self.writer.add_scalar(f"epoch loss/{key}", epoch_losses[key], epoch) # reset epoch losses epoch_losses = copy.deepcopy(blank_losses) self.save()