rahul7star committed
Commit 909058a · verified · 1 Parent(s): 703133d

Update jobs/process/TrainVAEProcess.py

Files changed (1)
  1. jobs/process/TrainVAEProcess.py +307 -34
jobs/process/TrainVAEProcess.py CHANGED
@@ -7,6 +7,7 @@ from collections import OrderedDict
 
 from PIL import Image
 from PIL.ImageOps import exif_transpose
+from einops import rearrange
 from safetensors.torch import save_file, load_file
 from torch.utils.data import DataLoader, ConcatDataset
 import torch
 
@@ -17,18 +18,22 @@ from jobs.process import BaseTrainProcess
 from toolkit.image_utils import show_tensors
 from toolkit.kohya_model_util import load_vae, convert_diffusers_back_to_ldm
 from toolkit.data_loader import ImageDataset
-from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss
+from toolkit.losses import ComparativeTotalVariation, get_gradient_penalty, PatternLoss, total_variation
 from toolkit.metadata import get_meta_for_safetensors
 from toolkit.optimizer import get_optimizer
 from toolkit.style import get_style_model_and_losses
 from toolkit.train_tools import get_torch_dtype
 from diffusers import AutoencoderKL
 from tqdm import tqdm
+import math
+import torchvision.utils
 import time
 import numpy as np
-from .models.vgg19_critic import Critic
+from .models.critic import Critic
 from torchvision.transforms import Resize
 import lpips
+import random
+import traceback
 
 IMAGE_TRANSFORMS = transforms.Compose(
     [
 
@@ -42,13 +47,21 @@ def unnormalize(tensor):
     return (tensor / 2 + 0.5).clamp(0, 1)
 
 
+def channel_dropout(x, p=0.5):
+    keep_prob = 1 - p
+    mask = torch.rand(x.size(0), x.size(1), 1, 1, device=x.device, dtype=x.dtype) < keep_prob
+    mask = mask / keep_prob  # scale
+    return x * mask
+
+
 class TrainVAEProcess(BaseTrainProcess):
     def __init__(self, process_id: int, job, config: OrderedDict):
         super().__init__(process_id, job, config)
         self.data_loader = None
         self.vae = None
         self.device = self.get_conf('device', self.job.device)
-        self.vae_path = self.get_conf('vae_path', required=True)
+        self.vae_path = self.get_conf('vae_path', None)
+        self.eq_vae = self.get_conf('eq_vae', False)
         self.datasets_objects = self.get_conf('datasets', required=True)
         self.batch_size = self.get_conf('batch_size', 1, as_type=int)
         self.resolution = self.get_conf('resolution', 256, as_type=int)
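The new `channel_dropout` helper is inverted dropout over whole latent channels: each (sample, channel) slice is either zeroed or rescaled by `1 / keep_prob`, so the expected value of the decoder input is unchanged when `dropout > 0`. A minimal standalone sketch of that behaviour (the toy shapes below are illustrative, not from the commit):

```python
import torch


def channel_dropout(x, p=0.5):
    # same helper as in the commit: drop whole channels, rescale the survivors
    keep_prob = 1 - p
    mask = torch.rand(x.size(0), x.size(1), 1, 1, device=x.device, dtype=x.dtype) < keep_prob
    mask = mask / keep_prob  # scale so the expected output matches the input
    return x * mask


latents = torch.randn(2, 4, 8, 8)          # illustrative (batch, channels, h, w) latent
dropped = channel_dropout(latents, p=0.5)

for b in range(dropped.shape[0]):
    for c in range(dropped.shape[1]):
        kept = bool(dropped[b, c].abs().sum() > 0)
        print(f"sample {b}, channel {c}: {'kept (scaled by 2.0)' if kept else 'dropped'}")
```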
 
@@ -65,11 +78,25 @@ class TrainVAEProcess(BaseTrainProcess):
         self.content_weight = self.get_conf('content_weight', 0, as_type=float)
         self.kld_weight = self.get_conf('kld_weight', 0, as_type=float)
         self.mse_weight = self.get_conf('mse_weight', 1e0, as_type=float)
-        self.tv_weight = self.get_conf('tv_weight', 1e0, as_type=float)
+        self.mv_loss_weight = self.get_conf('mv_loss_weight', 0, as_type=float)
+        self.tv_weight = self.get_conf('tv_weight', 0, as_type=float)
+        self.ltv_weight = self.get_conf('ltv_weight', 0, as_type=float)
+        self.lpm_weight = self.get_conf('lpm_weight', 0, as_type=float)  # latent pixel matching
         self.lpips_weight = self.get_conf('lpips_weight', 1e0, as_type=float)
         self.critic_weight = self.get_conf('critic_weight', 1, as_type=float)
-        self.pattern_weight = self.get_conf('pattern_weight', 1, as_type=float)
+        self.pattern_weight = self.get_conf('pattern_weight', 0, as_type=float)
         self.optimizer_params = self.get_conf('optimizer_params', {})
+        self.vae_config = self.get_conf('vae_config', None)
+        self.dropout = self.get_conf('dropout', 0.0, as_type=float)
+        self.train_encoder = self.get_conf('train_encoder', False, as_type=bool)
+        self.random_scaling = self.get_conf('random_scaling', False, as_type=bool)
+
+        if not self.train_encoder:
+            # remove losses that only target encoder
+            self.kld_weight = 0
+            self.mv_loss_weight = 0
+            self.ltv_weight = 0
+            self.lpm_weight = 0
 
         self.blocks_to_train = self.get_conf('blocks_to_train', ['all'])
         self.torch_dtype = get_torch_dtype(self.dtype)
 
@@ -133,7 +160,11 @@ class TrainVAEProcess(BaseTrainProcess):
         for dataset in self.datasets_objects:
             print(f" - Dataset: {dataset['path']}")
             ds = copy.copy(dataset)
-            ds['resolution'] = self.resolution
+            dataset_res = self.resolution
+            if self.random_scaling:
+                # scale 2x to allow for random scaling
+                dataset_res = int(dataset_res * 2)
+            ds['resolution'] = dataset_res
             image_dataset = ImageDataset(ds)
             datasets.append(image_dataset)
 
 
@@ -142,7 +173,7 @@
             concatenated_dataset,
             batch_size=self.batch_size,
             shuffle=True,
-            num_workers=6
+            num_workers=16
         )
 
     def remove_oldest_checkpoint(self):
 
@@ -153,6 +184,13 @@
         for folder in folders[:-max_to_keep]:
             print(f"Removing {folder}")
             shutil.rmtree(folder)
+        # also handle CRITIC_vae_42_000000500.safetensors format for critic
+        critic_files = glob.glob(os.path.join(self.save_root, f"CRITIC_{self.job.name}*.safetensors"))
+        if len(critic_files) > max_to_keep:
+            critic_files.sort(key=os.path.getmtime)
+            for file in critic_files[:-max_to_keep]:
+                print(f"Removing {file}")
+                os.remove(file)
 
     def setup_vgg19(self):
         if self.vgg_19 is None:
 
@@ -218,6 +256,62 @@
         else:
             return torch.tensor(0.0, device=self.device)
 
+    def get_mean_variance_loss(self, latents: torch.Tensor):
+        if self.mv_loss_weight > 0:
+            # collapse rows into channels
+            latents_col = rearrange(latents, 'b c h (gw w) -> b (c gw) h w', gw=latents.shape[-1])
+            mean_col = latents_col.mean(dim=(2, 3), keepdim=True)
+            std_col = latents_col.std(dim=(2, 3), keepdim=True, unbiased=False)
+            mean_loss_col = (mean_col ** 2).mean()
+            std_loss_col = ((std_col - 1) ** 2).mean()
+
+            # collapse columns into channels
+            latents_row = rearrange(latents, 'b c (gh h) w -> b (c gh) h w', gh=latents.shape[-2])
+            mean_row = latents_row.mean(dim=(2, 3), keepdim=True)
+            std_row = latents_row.std(dim=(2, 3), keepdim=True, unbiased=False)
+            mean_loss_row = (mean_row ** 2).mean()
+            std_loss_row = ((std_row - 1) ** 2).mean()
+
+            # do a global one
+
+            mean = latents.mean(dim=(2, 3), keepdim=True)
+            std = latents.std(dim=(2, 3), keepdim=True, unbiased=False)
+            mean_loss_global = (mean ** 2).mean()
+            std_loss_global = ((std - 1) ** 2).mean()
+
+            return (mean_loss_col + std_loss_col + mean_loss_row + std_loss_row + mean_loss_global + std_loss_global) / 3
+        else:
+            return torch.tensor(0.0, device=self.device)
+
+    def get_ltv_loss(self, latent):
+        # loss to reduce the latent space variance
+        if self.ltv_weight > 0:
+            return total_variation(latent).mean()
+        else:
+            return torch.tensor(0.0, device=self.device)
+
+    def get_latent_pixel_matching_loss(self, latent, pixels):
+        if self.lpm_weight > 0:
+            with torch.no_grad():
+                pixels = pixels.to(latent.device, dtype=latent.dtype)
+                # resize down to latent size
+                pixels = torch.nn.functional.interpolate(pixels, size=(latent.shape[2], latent.shape[3]), mode='bilinear', align_corners=False)
+
+                # mean the color channel and then expand to latent size
+                pixels = pixels.mean(dim=1, keepdim=True)
+                pixels = pixels.repeat(1, latent.shape[1], 1, 1)
+                # match the mean std of latent
+                latent_mean = latent.mean(dim=(2, 3), keepdim=True)
+                latent_std = latent.std(dim=(2, 3), keepdim=True)
+                pixels_mean = pixels.mean(dim=(2, 3), keepdim=True)
+                pixels_std = pixels.std(dim=(2, 3), keepdim=True)
+                pixels = (pixels - pixels_mean) / (pixels_std + 1e-6) * latent_std + latent_mean
+
+            return torch.nn.functional.mse_loss(latent.float(), pixels.float())
+
+        else:
+            return torch.tensor(0.0, device=self.device)
+
     def get_tv_loss(self, pred, target):
         if self.tv_weight > 0:
             get_tv_loss = ComparativeTotalVariation()
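The new `get_mean_variance_loss` pushes latent statistics toward zero mean and unit variance; it applies the penalty per channel globally and, via the `einops.rearrange` calls, also after folding each spatial row and column into its own channel group, so banding along either axis is penalized as well. A simplified sketch of just the global term, with an illustrative `(B, C, H, W)` latent (not the committed method itself):

```python
import torch


def global_mean_variance_penalty(latents: torch.Tensor) -> torch.Tensor:
    # penalize per-channel statistics that drift away from N(0, 1),
    # mirroring the "global" term of get_mean_variance_loss
    mean = latents.mean(dim=(2, 3), keepdim=True)
    std = latents.std(dim=(2, 3), keepdim=True, unbiased=False)
    return (mean ** 2).mean() + ((std - 1) ** 2).mean()


well_scaled = torch.randn(2, 4, 16, 16)            # already roughly N(0, 1)
badly_scaled = torch.randn(2, 4, 16, 16) * 5 + 3   # wrong mean and variance

print(global_mean_variance_penalty(well_scaled).item())   # small
print(global_mean_variance_penalty(badly_scaled).item())  # much larger
```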
 
@@ -277,7 +371,39 @@
                 input_img = img
                 img = IMAGE_TRANSFORMS(img).unsqueeze(0).to(self.device, dtype=self.torch_dtype)
                 img = img
-                decoded = self.vae(img).sample
+                latent = self.vae.encode(img).latent_dist.sample()
+
+                latent_img = latent.clone()
+                bs, ch, h, w = latent_img.shape
+                grid_size = math.ceil(math.sqrt(ch))
+                pad = grid_size * grid_size - ch
+
+                # take first item in batch
+                latent_img = latent_img[0]  # shape: (ch, h, w)
+
+                if pad > 0:
+                    padding = torch.zeros((pad, h, w), dtype=latent_img.dtype, device=latent_img.device)
+                    latent_img = torch.cat([latent_img, padding], dim=0)
+
+                # make grid
+                new_img = torch.zeros((1, grid_size * h, grid_size * w), dtype=latent_img.dtype, device=latent_img.device)
+                for x in range(grid_size):
+                    for y in range(grid_size):
+                        if x * grid_size + y < ch:
+                            new_img[0, x * h:(x + 1) * h, y * w:(y + 1) * w] = latent_img[x * grid_size + y]
+                latent_img = new_img
+                # make rgb
+                latent_img = latent_img.repeat(3, 1, 1).unsqueeze(0)
+                latent_img = (latent_img / 2 + 0.5).clamp(0, 1)
+
+                # resize to 256x256
+                latent_img = torch.nn.functional.interpolate(latent_img, size=(self.resolution, self.resolution), mode='nearest')
+                latent_img = latent_img.squeeze(0).cpu().permute(1, 2, 0).float().numpy()
+                latent_img = (latent_img * 255).astype(np.uint8)
+                # convert to pillow image
+                latent_img = Image.fromarray(latent_img)
+
+                decoded = self.vae.decode(latent).sample
                 decoded = (decoded / 2 + 0.5).clamp(0, 1)
                 # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
                 decoded = decoded.cpu().permute(0, 2, 3, 1).squeeze(0).float().numpy()
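The sampling path now also encodes the input and tiles the latent channels into a square grayscale grid that is pasted next to the input/reconstruction pair. Since the commit also adds `import torchvision.utils`, the manual tiling loop could equivalently be expressed with `make_grid`; a sketch of that alternative with illustrative shapes (not the committed code):

```python
import math

import torch
import torchvision.utils

latent = torch.randn(1, 16, 32, 32)                 # illustrative (B, C, H, W) latent
channels = latent[0].unsqueeze(1)                   # (C, 1, H, W): one grayscale tile per channel
grid_size = math.ceil(math.sqrt(channels.shape[0]))
grid = torchvision.utils.make_grid(channels, nrow=grid_size, padding=0)
if grid.shape[0] == 1:                              # ensure an RGB tensor either way
    grid = grid.repeat(3, 1, 1)
grid = (grid / 2 + 0.5).clamp(0, 1)                 # same 0-1 mapping used before converting to PIL
```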
 
@@ -289,9 +415,10 @@
                 input_img = input_img.resize((self.resolution, self.resolution))
                 decoded = decoded.resize((self.resolution, self.resolution))
 
-                output_img = Image.new('RGB', (self.resolution * 2, self.resolution))
+                output_img = Image.new('RGB', (self.resolution * 3, self.resolution))
                 output_img.paste(input_img, (0, 0))
                 output_img.paste(decoded, (self.resolution, 0))
+                output_img.paste(latent_img, (self.resolution * 2, 0))
 
                 scale_up = 2
                 if output_img.height <= 300:
 
@@ -326,12 +453,20 @@
         self.print(f"Loading VAE")
         self.print(f" - Loading VAE: {path_to_load}")
         if self.vae is None:
-            self.vae = AutoencoderKL.from_pretrained(path_to_load)
+            if path_to_load is not None:
+                self.vae = AutoencoderKL.from_pretrained(path_to_load)
+            elif self.vae_config is not None:
+                self.vae = AutoencoderKL(**self.vae_config)
+            else:
+                raise ValueError('vae_path or ae_config must be specified')
 
         # set decoder to train
         self.vae.to(self.device, dtype=self.torch_dtype)
-        self.vae.requires_grad_(False)
-        self.vae.eval()
+        if self.eq_vae:
+            self.vae.encoder.train()
+        else:
+            self.vae.requires_grad_(False)
+            self.vae.eval()
         self.vae.decoder.train()
         self.vae_scale_factor = 2 ** (len(self.vae.config['block_out_channels']) - 1)
 
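With `vae_path` now optional, the process can build a fresh `AutoencoderKL` from a `vae_config` mapping instead of loading pretrained weights. A hedged example of what such a config might contain; the keys are standard `AutoencoderKL` constructor arguments and the values below are purely illustrative:

```python
from diffusers import AutoencoderKL

vae_config = {
    "in_channels": 3,
    "out_channels": 3,
    "down_block_types": ["DownEncoderBlock2D"] * 4,
    "up_block_types": ["UpDecoderBlock2D"] * 4,
    "block_out_channels": [128, 256, 512, 512],
    "latent_channels": 16,
}
vae = AutoencoderKL(**vae_config)  # what the new `elif self.vae_config is not None` branch does
```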
 
@@ -374,6 +509,10 @@
         if train_all:
             params = list(self.vae.decoder.parameters())
             self.vae.decoder.requires_grad_(True)
+            if self.train_encoder:
+                # encoder
+                params += list(self.vae.encoder.parameters())
+                self.vae.encoder.requires_grad_(True)
         else:
             # mid_block
             if train_all or 'mid_block' in self.blocks_to_train:
 
@@ -388,12 +527,13 @@
                 params += list(self.vae.decoder.conv_out.parameters())
                 self.vae.decoder.conv_out.requires_grad_(True)
 
-        if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
+        if self.style_weight > 0 or self.content_weight > 0:
             self.setup_vgg19()
-            self.vgg_19.requires_grad_(False)
+            # self.vgg_19.requires_grad_(False)
             self.vgg_19.eval()
-        if self.use_critic:
-            self.critic.setup()
+
+        if self.use_critic:
+            self.critic.setup()
 
         if self.lpips_weight > 0 and self.lpips_loss is None:
             # self.lpips_loss = lpips.LPIPS(net='vgg')
 
@@ -426,6 +566,9 @@
             "style": [],
             "content": [],
             "mse": [],
+            "mvl": [],
+            "ltv": [],
+            "lpm": [],
             "kl": [],
             "tv": [],
             "ptn": [],
 
@@ -435,6 +578,9 @@
         epoch_losses = copy.deepcopy(blank_losses)
         log_losses = copy.deepcopy(blank_losses)
         # range start at self.epoch_num go to self.epochs
+
+        latent_size = self.resolution // self.vae_scale_factor
+
         for epoch in range(self.epoch_num, self.epochs, 1):
             if self.step_num >= self.max_steps:
                 break
 
@@ -442,8 +588,20 @@
                 if self.step_num >= self.max_steps:
                     break
                 with torch.no_grad():
-
                     batch = batch.to(self.device, dtype=self.torch_dtype)
+
+                    if self.random_scaling:
+                        # only random scale 0.5 of the time
+                        if random.random() < 0.5:
+                            # random scale the batch
+                            scale_factor = 0.25
+                        else:
+                            scale_factor = 0.5
+                        new_size = (int(batch.shape[2] * scale_factor), int(batch.shape[3] * scale_factor))
+                        # make sure it is vae divisible
+                        new_size = (new_size[0] // self.vae_scale_factor * self.vae_scale_factor,
+                                    new_size[1] // self.vae_scale_factor * self.vae_scale_factor)
+
 
                     # resize so it matches size of vae evenly
                     if batch.shape[2] % self.vae_scale_factor != 0 or batch.shape[3] % self.vae_scale_factor != 0:
 
@@ -451,27 +609,92 @@
                                         batch.shape[3] // self.vae_scale_factor * self.vae_scale_factor))(batch)
 
                 # forward pass
+                # grad only if eq_vae
+                with torch.set_grad_enabled(self.train_encoder):
                     dgd = self.vae.encode(batch).latent_dist
                     mu, logvar = dgd.mean, dgd.logvar
                     latents = dgd.sample()
-                latents.detach().requires_grad_(True)
-
-                pred = self.vae.decode(latents).sample
-
-                with torch.no_grad():
-                    show_tensors(
-                        pred.clamp(-1, 1).clone(),
-                        "combined tensor"
-                    )
+
+                if self.eq_vae:
+                    # process flips, rotate, scale
+                    latent_chunks = list(latents.chunk(latents.shape[0], dim=0))
+                    batch_chunks = list(batch.chunk(batch.shape[0], dim=0))
+                    out_chunks = []
+                    for i in range(len(latent_chunks)):
+                        try:
+                            do_rotate = random.randint(0, 3)
+                            do_flip_x = random.randint(0, 1)
+                            do_flip_y = random.randint(0, 1)
+                            do_scale = random.randint(0, 1)
+                            if do_rotate > 0:
+                                latent_chunks[i] = torch.rot90(latent_chunks[i], do_rotate, (2, 3))
+                                batch_chunks[i] = torch.rot90(batch_chunks[i], do_rotate, (2, 3))
+                            if do_flip_x > 0:
+                                latent_chunks[i] = torch.flip(latent_chunks[i], [2])
+                                batch_chunks[i] = torch.flip(batch_chunks[i], [2])
+                            if do_flip_y > 0:
+                                latent_chunks[i] = torch.flip(latent_chunks[i], [3])
+                                batch_chunks[i] = torch.flip(batch_chunks[i], [3])
+
+                            # resize latent to fit
+                            if latent_chunks[i].shape[2] != latent_size or latent_chunks[i].shape[3] != latent_size:
+                                latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], size=(latent_size, latent_size), mode='bilinear', align_corners=False)
+
+                            # if do_scale > 0:
+                            #     scale = 2
+                            #     start_latent_h = latent_chunks[i].shape[2]
+                            #     start_latent_w = latent_chunks[i].shape[3]
+                            #     start_batch_h = batch_chunks[i].shape[2]
+                            #     start_batch_w = batch_chunks[i].shape[3]
+                            #     latent_chunks[i] = torch.nn.functional.interpolate(latent_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False)
+                            #     batch_chunks[i] = torch.nn.functional.interpolate(batch_chunks[i], scale_factor=scale, mode='bilinear', align_corners=False)
+                            #     # random crop. latent is smaller than match but crops need to match
+                            #     latent_x = random.randint(0, latent_chunks[i].shape[2] - start_latent_h)
+                            #     latent_y = random.randint(0, latent_chunks[i].shape[3] - start_latent_w)
+                            #     batch_x = latent_x * self.vae_scale_factor
+                            #     batch_y = latent_y * self.vae_scale_factor
+
+                            #     # crop
+                            #     latent_chunks[i] = latent_chunks[i][:, :, latent_x:latent_x + start_latent_h, latent_y:latent_y + start_latent_w]
+                            #     batch_chunks[i] = batch_chunks[i][:, :, batch_x:batch_x + start_batch_h, batch_y:batch_y + start_batch_w]
+                        except Exception as e:
+                            print(f"Error processing image {i}: {e}")
+                            traceback.print_exc()
+                            raise e
+                        out_chunks.append(latent_chunks[i])
+                    latents = torch.cat(out_chunks, dim=0)
+                    # do dropout
+                    if self.dropout > 0:
+                        forward_latents = channel_dropout(latents, self.dropout)
+                    else:
+                        forward_latents = latents
+
+                    # resize batch to resolution if needed
+                    if batch_chunks[0].shape[2] != self.resolution or batch_chunks[0].shape[3] != self.resolution:
+                        batch_chunks = [torch.nn.functional.interpolate(b, size=(self.resolution, self.resolution), mode='bilinear', align_corners=False) for b in batch_chunks]
+                    batch = torch.cat(batch_chunks, dim=0)
+
+                else:
+                    latents.detach().requires_grad_(True)
+                    forward_latents = latents
+
+                forward_latents = forward_latents.to(self.device, dtype=self.torch_dtype)
+
+                if not self.train_encoder:
+                    # detach latents if not training encoder
+                    forward_latents = forward_latents.detach()
+
+                pred = self.vae.decode(forward_latents).sample
 
                 # Run through VGG19
-                if self.style_weight > 0 or self.content_weight > 0 or self.use_critic:
+                if self.style_weight > 0 or self.content_weight > 0:
                     stacked = torch.cat([pred, batch], dim=0)
                     stacked = (stacked / 2 + 0.5).clamp(0, 1)
                     self.vgg_19(stacked)
 
                 if self.use_critic:
-                    critic_d_loss = self.critic.step(self.vgg19_pool_4.tensor.detach())
+                    stacked = torch.cat([pred, batch], dim=0)
+                    critic_d_loss = self.critic.step(stacked.detach())
                 else:
                     critic_d_loss = 0.0
 
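The `eq_vae` branch is an equivariance objective: the same rotation/flip is applied to the sampled latent and to the pixel target, so the decoder is trained to map a transformed latent to the correspondingly transformed image. A minimal standalone sketch of that pairing, assuming any `AutoencoderKL` instance and purely illustrative shapes:

```python
import random

import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL

vae = AutoencoderKL()                       # illustrative, untrained model
image = torch.randn(1, 3, 64, 64)

with torch.no_grad():
    latent = vae.encode(image).latent_dist.sample()

# apply the same transform to the latent and to the reconstruction target
k = random.randint(0, 3)
latent_t = torch.rot90(latent, k, (2, 3))
target_t = torch.rot90(image, k, (2, 3))

recon = vae.decode(latent_t).sample
loss = F.mse_loss(recon, target_t)          # pushes the decoder toward rotation equivariance
loss.backward()
```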
 
@@ -489,7 +712,8 @@
                 tv_loss = self.get_tv_loss(pred, batch) * self.tv_weight
                 pattern_loss = self.get_pattern_loss(pred, batch) * self.pattern_weight
                 if self.use_critic:
-                    critic_gen_loss = self.critic.get_critic_loss(self.vgg19_pool_4.tensor) * self.critic_weight
+                    stacked = torch.cat([pred, batch], dim=0)
+                    critic_gen_loss = self.critic.get_critic_loss(stacked) * self.critic_weight
 
                     # do not let abs critic gen loss be higher than abs lpips * 0.1 if using it
                     if self.lpips_weight > 0:
726
  critic_gen_loss *= crit_g_scaler
727
  else:
728
  critic_gen_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
729
+
730
+ if self.mv_loss_weight > 0:
731
+ mv_loss = self.get_mean_variance_loss(latents) * self.mv_loss_weight
732
+ else:
733
+ mv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
734
+
735
+ if self.ltv_weight > 0:
736
+ ltv_loss = self.get_ltv_loss(latents) * self.ltv_weight
737
+ else:
738
+ ltv_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
739
+
740
+ if self.lpm_weight > 0:
741
+ lpm_loss = self.get_latent_pixel_matching_loss(latents, batch) * self.lpm_weight
742
+ else:
743
+ lpm_loss = torch.tensor(0.0, device=self.device, dtype=self.torch_dtype)
744
+
745
+ loss = style_loss + content_loss + kld_loss + mse_loss + tv_loss + critic_gen_loss + pattern_loss + lpips_loss + mv_loss + ltv_loss
746
+
747
+ # check if loss is NaN or Inf
748
+ if torch.isnan(loss) or torch.isinf(loss):
749
+ self.print(f"Loss is NaN or Inf, stopping at step {self.step_num}")
750
+ self.print(f" - Style loss: {style_loss.item()}")
751
+ self.print(f" - Content loss: {content_loss.item()}")
752
+ self.print(f" - KLD loss: {kld_loss.item()}")
753
+ self.print(f" - MSE loss: {mse_loss.item()}")
754
+ self.print(f" - LPIPS loss: {lpips_loss.item()}")
755
+ self.print(f" - TV loss: {tv_loss.item()}")
756
+ self.print(f" - Pattern loss: {pattern_loss.item()}")
757
+ self.print(f" - Critic gen loss: {critic_gen_loss.item()}")
758
+ self.print(f" - Critic D loss: {critic_d_loss}")
759
+ self.print(f" - Mean variance loss: {mv_loss.item()}")
760
+ self.print(f" - Latent TV loss: {ltv_loss.item()}")
761
+ self.print(f" - Latent pixel matching loss: {lpm_loss.item()}")
762
+ self.print(f" - Total loss: {loss.item()}")
763
+ self.print(f" - Stopping training")
764
+ exit(1)
765
 
766
  # Backward pass and optimization
767
  optimizer.zero_grad()
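The new guard prints every loss component before aborting, which makes it easy to see which term became non-finite. The check itself is the standard finiteness test; a compact equivalent, shown only as an illustration:

```python
import torch

loss = torch.tensor(float("nan"))
if not torch.isfinite(loss):
    raise RuntimeError(f"non-finite loss: {loss}")
```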
 
@@ -533,8 +791,17 @@
                 loss_string += f" crG: {critic_gen_loss.item():.2e}"
                 if self.use_critic:
                     loss_string += f" crD: {critic_d_loss:.2e}"
-
-                if self.optimizer_type.startswith('dadaptation') or \
+                if self.mv_loss_weight > 0:
+                    loss_string += f" mvl: {mv_loss:.2e}"
+                if self.ltv_weight > 0:
+                    loss_string += f" ltv: {ltv_loss:.2e}"
+                if self.lpm_weight > 0:
+                    loss_string += f" lpm: {lpm_loss:.2e}"
+
+
+                if hasattr(optimizer, 'get_avg_learning_rate'):
+                    learning_rate = optimizer.get_avg_learning_rate()
+                elif self.optimizer_type.startswith('dadaptation') or \
                         self.optimizer_type.lower().startswith('prodigy'):
                     learning_rate = (
                         optimizer.param_groups[0]["d"] *
 
@@ -562,6 +829,9 @@
                 epoch_losses["ptn"].append(pattern_loss.item())
                 epoch_losses["crG"].append(critic_gen_loss.item())
                 epoch_losses["crD"].append(critic_d_loss)
+                epoch_losses["mvl"].append(mv_loss.item())
+                epoch_losses["ltv"].append(ltv_loss.item())
+                epoch_losses["lpm"].append(lpm_loss.item())
 
                 log_losses["total"].append(loss_value)
                 log_losses["lpips"].append(lpips_loss.item())
 
@@ -573,6 +843,9 @@
                 log_losses["ptn"].append(pattern_loss.item())
                 log_losses["crG"].append(critic_gen_loss.item())
                 log_losses["crD"].append(critic_d_loss)
+                log_losses["mvl"].append(mv_loss.item())
+                log_losses["ltv"].append(ltv_loss.item())
+                log_losses["lpm"].append(lpm_loss.item())
 
                 # don't do on first step
                 if self.step_num != start_step:
 
@@ -609,4 +882,4 @@
             # reset epoch losses
             epoch_losses = copy.deepcopy(blank_losses)
 
-        self.save()
+        self.save()