Muhammad Naufal Rizqullah committed on
Commit e61c431 · 1 Parent(s): 439c972

Experiment 2
config/core.py CHANGED
@@ -2,15 +2,16 @@ from pydantic_settings import BaseSettings
 
 class Config(BaseSettings):
     IMAGE_CHANNEL: int = 3
+    LABEL_CHANNEL: int = 3
     NUM_CLASSES: int = 3
     IMAGE_SIZE: int = 128
-    FEATURES_DISCRIMINATOR: int = 64
-    FEATURES_GENERATOR: int = 64
-    EMBED_SIZE: int = 64
-    INPUT_Z_DIM: int = 64
-    BATCH_SIZE: int = 128
+    FEATURES_DISCRIMINATOR: int = 64 * 2
+    FEATURES_GENERATOR: int = 64 * 2
+    EMBED_SIZE: int = 30 + 20
+    INPUT_Z_DIM: int = 64 * 2
+    BATCH_SIZE: int = 20
     DISPLAY_STEP: int = 500
-    MAX_SAMPLES: int = 3000
+    MAX_SAMPLES: int = 2500
 
     LEARNING_RATE: float = 0.0002
     BETA_1: float = 0.5
@@ -22,8 +23,8 @@ class Config(BaseSettings):
     CRITIC_REPEAT: int = 3
 
     LOAD_CHECKPOINT: bool = True
-    PATH_DATASET: str = ""
-    CKPT_PATH: str = "./weights/epoch=957-step=1164300.ckpt"
+    PATH_DATASET: str = "/kaggle/input/shoe-vs-sandal-vs-boot-dataset-15k-images/Shoe vs Sandal vs Boot Dataset"
+    CKPT_PATH: str = "./weights/epoch=299-step=450000.ckpt"
 
     OPTIONS_MAPPING: dict = {
         "Boot": 0,
data/dataloader.py ADDED
@@ -0,0 +1,52 @@
+import torch
+import lightning as L
+import torchvision.transforms as T
+import os
+
+from config.core import config
+from torch.utils.data import DataLoader
+from torchvision.datasets import ImageFolder
+from utility.helper import PadToSquare
+
+class ShoeSandalBoot(L.LightningDataModule):
+    def __init__(
+        self,
+        dataset_directory,
+        image_size=config.IMAGE_SIZE,
+        batch_size=config.BATCH_SIZE,
+        max_samples=None
+    ):
+        super().__init__()
+
+        self.data_dir = dataset_directory
+        self.bs = batch_size
+        self.max_samples = max_samples  # to limit the dataset size
+
+        self.transforms = T.Compose([
+            # T.Resize(size=(image_size, image_size)),
+            PadToSquare(image_size),
+            T.ToTensor(),
+            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+        ])
+
+        self.ssb = None
+
+    def prepare_data(self):
+        pass
+
+    def setup(self, stage):
+        if stage == "fit":
+            dataset = ImageFolder(
+                root=self.data_dir,
+                transform=self.transforms
+            )
+
+            # Optionally limit the dataset size
+            if self.max_samples:
+                print(f"[INFO] Dataset is Limited to {self.max_samples} Samples")
+                self.ssb = torch.utils.data.Subset(dataset, range(min(len(dataset), self.max_samples)))
+            else:
+                self.ssb = dataset
+
+    def train_dataloader(self):
+        return DataLoader(self.ssb, batch_size=self.bs, num_workers=os.cpu_count(), shuffle=True)
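
A minimal smoke test of the datamodule, as a hedged sketch (the dataset path and the "fit" stage string come from the config above; the shapes assume BATCH_SIZE=20 and IMAGE_SIZE=128):

# Hypothetical usage sketch; PATH_DATASET comes from config/core.py above.
from config.core import config
from data.dataloader import ShoeSandalBoot

dm = ShoeSandalBoot(
    dataset_directory=config.PATH_DATASET,
    image_size=config.IMAGE_SIZE,
    batch_size=config.BATCH_SIZE,
    max_samples=config.MAX_SAMPLES,
)
dm.setup(stage="fit")
images, labels = next(iter(dm.train_dataloader()))
print(images.shape, labels.shape)  # expected: torch.Size([20, 3, 128, 128]) torch.Size([20])
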
models/discriminator.py CHANGED
@@ -2,63 +2,78 @@ import torch
 import torch.nn as nn
 
 from .base import WSConv2d, ConvBlock
+from config.core import config
 
 
 class Discriminator(nn.Module):
-    def __init__(self, num_classes=3, image_size=128, features_discriminator=128, image_channel=3):
+    def __init__(self, num_classes=3, embed_size=128, image_size=128, features_discriminator=128, image_channel=3, label_channel=3):
         super().__init__()
 
         self.num_classes = num_classes
         self.image_size = image_size
-        label_channel = 1
+
+        self.embed_size = embed_size
+        self.label_channel = label_channel
 
         self.disc = nn.Sequential(
-            self._block_discriminator(image_channel + label_channel, features_discriminator, kernel_size=4, stride=2,
-                                      padding=1),
-            self._block_discriminator(features_discriminator, features_discriminator, kernel_size=4, stride=2,
-                                      padding=1),
-            self._block_discriminator(features_discriminator, features_discriminator * 2, kernel_size=4, stride=2,
-                                      padding=1),
-            self._block_discriminator(features_discriminator * 2, features_discriminator * 4, kernel_size=4, stride=2,
-                                      padding=1),
-            self._block_discriminator(features_discriminator * 4, features_discriminator * 4, kernel_size=4, stride=2,
-                                      padding=1),
-            self._block_discriminator(features_discriminator * 4, 1, kernel_size=4, stride=1, padding=0,
-                                      final_layer=True)
+            self._block_discriminator(image_channel + label_channel, features_discriminator, kernel_size=4, stride=2, padding=1),
+            self._block_discriminator(features_discriminator, features_discriminator, kernel_size=4, stride=2, padding=1),
+            self._block_discriminator(features_discriminator, features_discriminator * 2, kernel_size=4, stride=2, padding=1),
+            self._block_discriminator(features_discriminator * 2, features_discriminator * 4, kernel_size=4, stride=2, padding=1),
+            self._block_discriminator(features_discriminator * 4, features_discriminator * 4, kernel_size=4, stride=2, padding=1),
+            self._block_discriminator(features_discriminator * 4, 1, kernel_size=4, stride=1, padding=0, final_layer=True)
         )
 
-        self.embed = nn.Embedding(num_classes, image_size * image_size)
+        self.embed = nn.Embedding(num_classes, embed_size)
+        self.embed_linear = nn.Linear(embed_size, label_channel * image_size * image_size)
 
     def forward(self, image, label):
         embedding = self.embed(label)
-        embedding = embedding.view(
+
+        linear_embedding = self.embed_linear(embedding)
+
+        embedding_layer = linear_embedding.view(
             label.shape[0],
-            1,
+            self.label_channel,
             self.image_size,
             self.image_size
         )
 
-        data = torch.cat([image, embedding], dim=1)
+        data = torch.cat([image, embedding_layer], dim=1)
 
         x = self.disc(data)
 
         return x.view(len(x), -1)
 
     def _block_discriminator(
         self,
         input_channels,
         output_channels,
         kernel_size=3,
         stride=2,
         padding=0,
         final_layer=False
     ):
         if not final_layer:
             return nn.Sequential(
                 ConvBlock(input_channels, output_channels),
-                WSConv2d(output_channels, output_channels, kernel_size, stride, padding)
+                WSConv2d(output_channels, output_channels, kernel_size, stride, padding),
             )
         else:
             return WSConv2d(input_channels, output_channels, kernel_size, stride, padding)
+
+
+def test():
+    sample = torch.randn(1, 3, 128, 128)
+    label = torch.tensor([1])
+
+    model = Discriminator(
+        num_classes=config.NUM_CLASSES,
+        embed_size=config.EMBED_SIZE,
+        image_size=config.IMAGE_SIZE,
+        features_discriminator=config.FEATURES_DISCRIMINATOR,
+        image_channel=config.IMAGE_CHANNEL,
+        label_channel=config.LABEL_CHANNEL
+    )
+
+    preds = model(sample, label)
+    print(preds.shape)
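
To make the new conditioning path concrete: the label index is embedded to a 50-dim vector (EMBED_SIZE = 30 + 20), projected by the linear layer to 3 * 128 * 128 values, reshaped into a 3-channel label map, and concatenated with the RGB image, so the first conv block sees 6 input channels. A minimal standalone sketch of just that path, using the config values from this commit (the layers below are throwaway stand-ins, not the repo's modules):

import torch
import torch.nn as nn

# Hypothetical shape walk-through of the label path, mirroring Discriminator.forward.
num_classes, embed_size, label_channel, image_size = 3, 50, 3, 128

embed = nn.Embedding(num_classes, embed_size)
embed_linear = nn.Linear(embed_size, label_channel * image_size * image_size)

label = torch.tensor([2, 0])  # batch of 2 labels
label_map = embed_linear(embed(label)).view(-1, label_channel, image_size, image_size)
image = torch.randn(2, 3, image_size, image_size)

data = torch.cat([image, label_map], dim=1)
print(label_map.shape, data.shape)  # torch.Size([2, 3, 128, 128]) torch.Size([2, 6, 128, 128])
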
models/generator.py CHANGED
@@ -2,40 +2,42 @@ import torch
 import torch.nn as nn
 
 from .base import WSConv2d, ConvBlock, PixelNorm
+from config.core import config
 
 
 class Generator(nn.Module):
     def __init__(self, embed_size=128, num_classes=3, image_size=128, features_generator=128, input_dim=128, image_channel=3):
         super().__init__()
 
         self.gen = nn.Sequential(
             self._block(input_dim + embed_size, features_generator * 2, first_double_up=True),
             self._block(features_generator * 2, features_generator * 4, first_double_up=False, final_layer=False),
             self._block(features_generator * 4, features_generator * 4, first_double_up=False, final_layer=False),
             self._block(features_generator * 4, features_generator * 4, first_double_up=False, final_layer=False),
             self._block(features_generator * 4, features_generator * 2, first_double_up=False, final_layer=False),
             self._block(features_generator * 2, features_generator, first_double_up=False, final_layer=False),
             self._block(features_generator, image_channel, first_double_up=False, use_double=False, final_layer=True),
         )
 
         self.image_size = image_size
         self.embed_size = embed_size
 
         self.embed = nn.Embedding(num_classes, embed_size)
+        self.embed_linear = nn.Linear(embed_size, embed_size)
 
     def forward(self, noise, labels):
-        embedding_label = self.embed(labels).unsqueeze(2).unsqueeze(
-            3)  # Add height and width channel; N x Noise_dim x 1 x 1
-
-        # Noise is 4 channel, or 2 channel. later will decide
-        noise = noise.view(noise.size(0), noise.size(1), 1, 1)  # Reshape to (batch_size, z_dim, 1, 1)
-
-        x = torch.cat([noise, embedding_label], dim=1)
-
+        embedding_label = self.embed(labels)
+        linear_embedding_label = self.embed_linear(embedding_label).unsqueeze(2).unsqueeze(3)
+
+        noise = noise.view(noise.size(0), noise.size(1), 1, 1)
+
+        x = torch.cat([noise, linear_embedding_label], dim=1)
+
         return self.gen(x)
 
     def _block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
                first_double_up=False, use_double=True, final_layer=False):
         layers = []
 
         if not final_layer:
@@ -55,3 +57,17 @@ class Generator(nn.Module):
 
         return nn.Sequential(*layers)
 
+
+def test():
+    sample = torch.randn(1, config.INPUT_Z_DIM, 1, 1)
+    label = torch.tensor([1])
+
+    model = Generator(
+        embed_size=config.EMBED_SIZE,
+        num_classes=config.NUM_CLASSES,
+        image_size=config.IMAGE_SIZE,
+        features_generator=config.FEATURES_GENERATOR,
+        input_dim=config.INPUT_Z_DIM,
+    )
+
+    preds = model(sample, label)
+    print(preds.shape)
requirements.txt CHANGED
@@ -3,4 +3,6 @@ pytorch-lightning
 python-multipart
 fastapi
 pydantic
-pydantic-settings
+pydantic-settings
+opencv-python==4.10.0
+imageio==2.33.1
training/train.py ADDED
@@ -0,0 +1,196 @@
+import torch
+import lightning as L
+import torch.optim as optim
+
+from models.generator import Generator
+from models.discriminator import Discriminator
+from utility.helper import initialize_weights, plot_images_from_tensor
+from utility.wgan_gp import gradient_penalty, calculate_generator_loss, calculate_critic_loss
+
+
+class ConditionalWGAN_GP(L.LightningModule):
+    def __init__(self, image_channel, label_channel, image_size, learning_rate, z_dim, embed_size, num_classes, critic_repeats, feature_gen, feature_critic, c_lambda, beta_1, beta_2, display_step):
+        super().__init__()
+
+        self.automatic_optimization = False
+
+        self.image_size = image_size
+        self.critic_repeats = critic_repeats
+        self.c_lambda = c_lambda
+
+        self.generator = Generator(
+            embed_size=embed_size,
+            num_classes=num_classes,
+            image_size=image_size,
+            features_generator=feature_gen,
+            input_dim=z_dim,
+        )
+
+        self.critic = Discriminator(
+            num_classes=num_classes,
+            embed_size=embed_size,
+            image_size=image_size,
+            features_discriminator=feature_critic,
+            image_channel=image_channel,
+            label_channel=label_channel,
+        )
+
+        self.critic_losses = []
+        self.generator_losses = []
+        self.curr_step = 0
+
+        self.fixed_latent_space = torch.randn(25, z_dim, 1, 1)
+        self.fixed_label = torch.tensor([i % num_classes for i in range(25)])
+
+        self.save_hyperparameters()
+
+    def configure_optimizers(self):
+        # READ: https://lightning.ai/docs/pytorch/stable/common/optimization.html#use-multiple-optimizers-like-gans
+        # READ: https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
+        # READ: https://lightning.ai/docs/pytorch/stable/model/build_model_advanced.html
+        # READ: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.backward
+        # READ: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#manual-backward
+        optimizer_G = optim.Adam(self.generator.parameters(), lr=self.hparams.learning_rate, betas=(self.hparams.beta_1, self.hparams.beta_2))
+        optimizer_C = optim.Adam(self.critic.parameters(), lr=self.hparams.learning_rate, betas=(self.hparams.beta_1, self.hparams.beta_2))
+
+        return optimizer_G, optimizer_C
+
+    def on_load_checkpoint(self, checkpoint):
+        # List of keys that we expect to load from the checkpoint
+        keys_to_load = ['critic_losses', 'generator_losses', 'curr_step', 'fixed_latent_space', 'fixed_label']
+
+        # Iterate over the keys and load them if they exist in the checkpoint
+        for key in keys_to_load:
+            if key in checkpoint:
+                setattr(self, key, checkpoint[key])
+
+    def on_save_checkpoint(self, checkpoint):
+        # Save the necessary variables to the checkpoint
+        checkpoint['critic_losses'] = self.critic_losses
+        checkpoint['generator_losses'] = self.generator_losses
+        checkpoint['curr_step'] = self.curr_step
+        checkpoint['fixed_latent_space'] = self.fixed_latent_space
+        checkpoint['fixed_label'] = self.fixed_label
+
+    def on_train_start(self):
+        if self.current_epoch == 0:
+            self.generator.apply(initialize_weights)
+            self.critic.apply(initialize_weights)
+
+    def training_step(self, batch, batch_idx):
+        # Get the optimizers
+        opt_generator, opt_critic = self.optimizers()
+
+        # Get data and labels
+        X, labels = batch
+
+        # Get the current batch size
+        batch_size = X.shape[0]
+
+        ##############################
+        # Train Critic ###############
+        ##############################
+        mean_critic_loss_for_this_iteration = 0
+
+        for _ in range(self.critic_repeats):
+            # Clean the gradients
+            opt_critic.zero_grad()
+
+            # Generate the noise.
+            noise = torch.randn(batch_size, self.hparams.z_dim, device=self.device)
+
+            # Generate fake images.
+            fake = self.generator(noise, labels)
+
+            # Get the Critic's predictions on the reals and fakes
+            critic_fake_pred = self.critic(fake.detach(), labels)
+            critic_real_pred = self.critic(X, labels)
+
+            # Calculate the Critic loss using WGAN-GP
+
+            # Generate epsilon for the interpolated images.
+            epsilon = torch.rand(batch_size, 1, 1, 1, device=self.device, requires_grad=True)
+
+            # Calculate the gradient penalty for the Critic model
+            gp = gradient_penalty(self.critic, labels, X, fake.detach(), epsilon)
+
+            # Calculate the full WGAN-GP loss for the Critic
+            critic_loss = calculate_critic_loss(
+                critic_fake_pred, critic_real_pred, gp, self.c_lambda
+            )
+
+            # Keep track of the average critic loss in this batch
+            mean_critic_loss_for_this_iteration += critic_loss.item() / self.critic_repeats
+
+            # Update the gradients of the Critic
+            # self.manual_backward(critic_loss, retain_graph=True)
+            self.manual_backward(critic_loss)  # no need for retain_graph=True here because fake.detach() already cuts the generator out of this backward pass; use retain_graph=True if not detaching
+
+            # Update the optimizer
+            opt_critic.step()
+
+        ##############################
+        # Train Generator ############
+        ##############################
+
+        # Clean the gradients
+        opt_generator.zero_grad()
+
+        # Generate the noise.
+        noise = torch.randn(batch_size, self.hparams.z_dim, device=self.device)
+
+        # Generate fake images.
+        fake = self.generator(noise, labels)
+
+        # Get the Critic's predictions on the fakes from the generator
+        generator_fake_predictions = self.critic(fake, labels)
+
+        # Calculate the loss for the Generator
+        generator_loss = calculate_generator_loss(generator_fake_predictions)
+
+        # Update the gradients of the Generator
+        self.manual_backward(generator_loss)
+
+        # Update the optimizer
+        opt_generator.step()
+
+        ##############################
+        # Visualization ##############
+        ##############################
+
+        if self.curr_step % self.hparams.display_step == 0 and self.curr_step > 0:
+            VISUALIZE = True
+            if VISUALIZE:
+                with torch.no_grad():
+                    fake_images_fixed = self.generator(
+                        self.fixed_latent_space.to(self.device),
+                        self.fixed_label.to(self.device)
+                    )
+
+                path_save = f"/kaggle/working/generates/generated-{self.curr_step}-step.png"
+                plot_images_from_tensor(fake_images_fixed, size=(3, self.image_size, self.image_size), show=False, save_path=path_save)
+                plot_images_from_tensor(X, size=(3, self.image_size, self.image_size), show=False)
+
+                print(f" ==== Critic Loss: {mean_critic_loss_for_this_iteration} ==== ")
+                print(f" ==== Generator Loss: {generator_loss.item()} ==== ")
+
+        self.curr_step += 1
+
+        ##############################
+        # Logging ####################
+        ##############################
+        # Log the Critic and Generator losses
+        self.log("critic_loss", mean_critic_loss_for_this_iteration, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("generator_loss", generator_loss.item(), on_step=False, on_epoch=True, prog_bar=True)
+
+        # Store them in lists so they can be used later for visualization
+        self.critic_losses.append(mean_critic_loss_for_this_iteration)
+        self.generator_losses.append(generator_loss.item())
+
+    def forward(self, noise, labels):
+        return self.generator(noise, labels)
+
+    def predict_step(self, noise, labels):
+        return self.generator(noise, labels)
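
The commit does not include a launcher script, so here is a hedged sketch of how this LightningModule and the datamodule above might be wired together on Kaggle. Everything taken from config/core.py is named explicitly; the Trainer arguments, the c_lambda value, and the BETA_2 key are assumptions, not part of this commit.

# Hypothetical training entry point; assumes the modules defined in this commit.
import lightning as L
from config.core import config
from data.dataloader import ShoeSandalBoot
from training.train import ConditionalWGAN_GP

model = ConditionalWGAN_GP(
    image_channel=config.IMAGE_CHANNEL,
    label_channel=config.LABEL_CHANNEL,
    image_size=config.IMAGE_SIZE,
    learning_rate=config.LEARNING_RATE,
    z_dim=config.INPUT_Z_DIM,
    embed_size=config.EMBED_SIZE,
    num_classes=config.NUM_CLASSES,
    critic_repeats=config.CRITIC_REPEAT,
    feature_gen=config.FEATURES_GENERATOR,
    feature_critic=config.FEATURES_DISCRIMINATOR,
    c_lambda=10,           # assumed; the config key for this is not shown in the hunks above
    beta_1=config.BETA_1,
    beta_2=config.BETA_2,  # assumed to exist alongside BETA_1 in the unshown part of Config
    display_step=config.DISPLAY_STEP,
)

dm = ShoeSandalBoot(config.PATH_DATASET, max_samples=config.MAX_SAMPLES)
trainer = L.Trainer(max_epochs=300, accelerator="auto")
trainer.fit(model, datamodule=dm, ckpt_path=config.CKPT_PATH if config.LOAD_CHECKPOINT else None)
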
utility/helper.py CHANGED
@@ -1,7 +1,15 @@
 import torch
+import torch.nn as nn
+import cv2
+import imageio
+import os
+import matplotlib.pyplot as plt
 
 from config.core import config
 from models.generator import Generator
+from PIL import Image
+from torchvision.utils import make_grid
+
 
 def load_model_weights(checkpoint_path, model, device, prefix):
     """
@@ -61,3 +69,156 @@ def get_selected_value(label):
     """
     # Get the selected value from the options mapping based on the display label.
     return config.OPTIONS_MAPPING[label]
+
+
+def initialize_weights(model):
+    """
+    Initializes the weights of a model using a normal distribution.
+
+    Args:
+        model: The model to be initialized.
+
+    Returns:
+        None
+    """
+
+    for m in model.modules():
+        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.InstanceNorm2d)):
+            nn.init.normal_(m.weight.data, 0.0, 0.02)
+
+
+def plot_images_from_tensor(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True, save_path=None):
+    image_tensor = (image_tensor + 1) / 2
+    image_unflat = image_tensor.detach().cpu()
+    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
+    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
+    plt.axis('off')
+    if save_path:
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
+    if show:
+        plt.show()
+    else:
+        plt.close()
+
+
+def create_video(image_folder, video_name, fps, appearance_duration=None):
+    """
+    Creates a video from a sequence of images with customizable appearance duration.
+
+    Args:
+        image_folder (str): The path to the folder containing the images.
+        video_name (str): The name of the output video file.
+        fps (int): The frames per second of the video.
+        appearance_duration (int, optional): The desired appearance duration for each image in milliseconds.
+            If None, the default duration based on the frame rate is used.
+
+    Example:
+        image_folder = '/path/to/image/folder'
+        video_name = 'output_video.mp4'
+        fps = 12
+        appearance_duration = 200  # Appearance duration of 200 ms for each image
+
+        create_video(image_folder, video_name, fps, appearance_duration)
+    """
+
+    # Get a list of all image files in the folder
+    image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
+
+    # Sort the image files based on the step number
+    image_files = sorted(image_files, key=lambda x: int(x.split('-')[1].split('.')[0]))
+
+    # Load the first image to get the video size
+    image = cv2.imread(os.path.join(image_folder, image_files[0]))
+    height, width, layers = image.shape
+
+    # Create a VideoWriter object
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Specify the video codec
+    video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
+
+    # Write each image to the video with customizable appearance duration
+    for image_file in image_files:
+        image = cv2.imread(os.path.join(image_folder, image_file))
+        video.write(image)
+
+        if appearance_duration is not None:
+            # Calculate the number of frames for the desired appearance duration
+            num_frames = appearance_duration * fps // 1000
+            for _ in range(num_frames):
+                video.write(image)
+
+    # Release the video writer
+    video.release()
+
+
+def create_gif(image_folder, gif_name, fps, appearance_duration=None):
+    """
+    Creates a GIF from a sequence of images sorted by step number, with customizable appearance duration.
+
+    Args:
+        image_folder (str): The path to the folder containing the images.
+        gif_name (str): The name of the output GIF file.
+        fps (int): The frames per second of the GIF.
+        appearance_duration (int, optional): The desired appearance duration for each image in milliseconds.
+            If None, the default duration based on the frame rate is used.
+
+    Example:
+        image_folder = '/path/to/image/folder'
+        gif_name = 'output_animation.gif'
+        fps = 12
+        appearance_duration = 300  # Appearance duration of 300 ms for each image
+
+        create_gif(image_folder, gif_name, fps, appearance_duration)
+    """
+
+    # Get a list of all image files in the folder
+    image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
+
+    # Sort the image files based on the step number
+    image_files = sorted(image_files, key=lambda x: int(x.split('-')[1].split('.')[0]))
+
+    # Load the images into a list
+    images = []
+    for file in image_files:
+        images.append(imageio.imread(os.path.join(image_folder, file)))
+
+    # Create a list to store the repeated images
+    repeated_images = []
+
+    # Repeat each image for the desired duration
+    if appearance_duration is not None:
+        for image in images:
+            repeated_images.extend([image] * (appearance_duration * fps // 1000))
+    else:
+        repeated_images = images  # Default appearance duration (based on fps)
+
+    # Save the repeated images as a GIF
+    imageio.mimsave(gif_name, repeated_images, fps=fps)
+
+
+class PadToSquare:
+    """Pad an image to a square of the given size with a white background.
+
+    Args:
+        size (int): The target size for the output image.
+    """
+
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, img):
+        """Pad the input image to the target size with a white background.
+
+        Args:
+            img (PIL.Image.Image): The input image.
+
+        Returns:
+            PIL.Image.Image: The padded image.
+        """
+        # Create a white canvas
+        white_canvas = Image.new('RGB', (self.size, self.size), (255, 255, 255))
+
+        # Calculate the position to paste the image onto the white canvas
+        left = (self.size - img.width) // 2
+        top = (self.size - img.height) // 2
+
+        # Paste the image onto the canvas
+        white_canvas.paste(img, (left, top))
+
+        return white_canvas
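
Since training_step saves a PNG grid to /kaggle/working/generates/ every DISPLAY_STEP steps, these helpers can turn that folder into a progress animation after training. A hedged usage sketch (the output file names below are arbitrary; the input file names follow the generated-{step}-step.png pattern used in training/train.py):

# Hypothetical post-training usage of the helpers above.
from utility.helper import create_gif, create_video

image_folder = "/kaggle/working/generates"  # written by training_step
create_gif(image_folder, "training_progress.gif", fps=12, appearance_duration=300)
create_video(image_folder, "training_progress.mp4", fps=12)
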
utility/wgan_gp.py ADDED
@@ -0,0 +1,77 @@
+import torch
+
+# Gradient Penalty Calculation - Calculate Gradient of Critic Score
+def gradient_penalty(critic, labels, real_images, fake_images, epsilon):
+    """
+    This function calculates the gradient penalty for the WGAN-GP loss function.
+
+    Parameters:
+        critic (nn.Module): The critic model
+        labels (torch.tensor): The labels for the images
+        real_images (torch.tensor): The real images
+        fake_images (torch.tensor): The fake images
+        epsilon (torch.tensor): The interpolation parameter
+
+    Returns:
+        gradient_penalty (torch.tensor): The gradient penalty for the critic model
+    """
+
+    # Create the interpolated images as a weighted combination of real and fake images
+    interpolated_images = real_images * epsilon + fake_images * (1 - epsilon)
+
+    mixed_scores = critic(interpolated_images, labels)
+
+    create_real_label = torch.ones_like(mixed_scores)
+
+    gradient = torch.autograd.grad(
+        inputs=interpolated_images,
+        outputs=mixed_scores,
+        grad_outputs=create_real_label,
+        create_graph=True,
+        retain_graph=True,
+    )[0]
+
+    # Reshape each image in the batch into a 1D tensor (flatten the images)
+    gradient = gradient.view(len(gradient), -1)
+
+    # Calculate the L2 norm of the gradients
+    gradient_norm = gradient.norm(2, dim=1)
+
+    # Calculate the penalty as the mean squared distance of the norms from 1
+    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
+
+    return gradient_penalty
+
+
+# Critic Loss Calculation
+def calculate_critic_loss(critic_fake_prediction, critic_real_prediction, gradient_penalty, critic_lambda):
+    """
+    Calculates the critic loss: the negated difference between the mean real score and the mean fake score, plus the weighted gradient penalty.
+
+    Parameters:
+        critic_fake_prediction (torch.tensor): The critic predictions for the fake images
+        critic_real_prediction (torch.tensor): The critic predictions for the real images
+        gradient_penalty (torch.tensor): The gradient penalty for the critic model
+        critic_lambda (float): The coefficient for the gradient penalty
+
+    Returns:
+        critic_loss (torch.tensor): The critic loss
+    """
+    critic_loss = (
+        -(torch.mean(critic_real_prediction) - torch.mean(critic_fake_prediction)) + critic_lambda * gradient_penalty
+    )
+
+    return critic_loss
+
+
+# Generator Loss Calculation
+def calculate_generator_loss(critic_fake_prediction):
+    """
+    Calculates the generator loss: the negated mean of the critic predictions for the fake images.
+
+    Parameters:
+        critic_fake_prediction (torch.tensor): The critic predictions for the fake images
+
+    Returns:
+        generator_loss (torch.tensor): The generator loss
+    """
+    generator_loss = -1.0 * torch.mean(critic_fake_prediction)
+    return generator_loss
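
A quick, self-contained sanity check of gradient_penalty and calculate_critic_loss with a throwaway critic (the ToyCritic below is purely illustrative and not part of the repo); the penalty should come out as a non-negative scalar:

import torch
import torch.nn as nn
from utility.wgan_gp import gradient_penalty, calculate_critic_loss

class ToyCritic(nn.Module):
    # Minimal stand-in with the same (image, label) call signature as Discriminator.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=4, stride=4)

    def forward(self, image, label):
        return self.conv(image).view(len(image), -1).mean(dim=1, keepdim=True)

critic = ToyCritic()
real = torch.randn(4, 3, 128, 128)
fake = torch.randn(4, 3, 128, 128)
labels = torch.randint(0, 3, (4,))
epsilon = torch.rand(4, 1, 1, 1, requires_grad=True)

gp = gradient_penalty(critic, labels, real, fake, epsilon)
loss = calculate_critic_loss(critic(fake, labels), critic(real, labels), gp, critic_lambda=10)
print(gp.item() >= 0, loss.shape)  # True, torch.Size([])
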
weights/{epoch=957-step=1164300.ckpt → epoch=299-step=450000.ckpt} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1602e9cb80fcfbc39ed2c5c74873744a546cbab346730aaa0f25b02140bc1d2f
-size 157799852
+oid sha256:8db76fefe85e1e6077bf8aa7532699d379ea0dc1f0d1b6130ebac9ca3b814303
+size 640723340
weights/source.txt CHANGED
@@ -1,2 +1,2 @@
-using a weight from kaggle after training 957 epoch:
-- https://www.kaggle.com/datasets/dimensioncore/conditional-gan-part-2/versions/2397
+- Notebook: [VERSION 19] https://www.kaggle.com/code/dimensioncore/57894-conditional-gan-try-part-2?scriptVersionId=197222851
+- Checkpoint: [VERSION 3016] https://www.kaggle.com/datasets/dimensioncore/conditional-gan-part-2/versions/3016