Commit e61c431
Experiment 2
Muhammad Naufal Rizqullah committed
1 Parent(s): 439c972

Files changed:
- config/core.py +9 -8
- data/dataloader.py +52 -0
- models/discriminator.py +48 -33
- models/generator.py +35 -19
- requirements.txt +3 -1
- training/train.py +196 -0
- utility/helper.py +161 -0
- utility/wgan_gp.py +77 -0
- weights/{epoch=957-step=1164300.ckpt → epoch=299-step=450000.ckpt} +2 -2
- weights/source.txt +2 -2
config/core.py
CHANGED
@@ -2,15 +2,16 @@ from pydantic_settings import BaseSettings
 
 class Config(BaseSettings):
     IMAGE_CHANNEL: int = 3
+    LABEL_CHANNEL: int = 3
     NUM_CLASSES: int = 3
     IMAGE_SIZE: int = 128
-    FEATURES_DISCRIMINATOR: int = 64
-    FEATURES_GENERATOR: int = 64
-    EMBED_SIZE: int =
-    INPUT_Z_DIM: int = 64
-    BATCH_SIZE: int =
+    FEATURES_DISCRIMINATOR: int = 64 * 2
+    FEATURES_GENERATOR: int = 64 * 2
+    EMBED_SIZE: int = 30 + 20
+    INPUT_Z_DIM: int = 64 * 2
+    BATCH_SIZE: int = 20
     DISPLAY_STEP: int = 500
-    MAX_SAMPLES: int =
+    MAX_SAMPLES: int = 2500
 
     LEARNING_RATE: float = 0.0002
     BETA_1: float = 0.5
@@ -22,8 +23,8 @@ class Config(BaseSettings):
     CRITIC_REPEAT: int = 3
 
     LOAD_CHECKPOINT: bool = True
-    PATH_DATASET: str = ""
-    CKPT_PATH: str = "./weights/epoch=
+    PATH_DATASET: str = "/kaggle/input/shoe-vs-sandal-vs-boot-dataset-15k-images/Shoe vs Sandal vs Boot Dataset"
+    CKPT_PATH: str = "./weights/epoch=299-step=450000.ckpt"
 
     OPTIONS_MAPPING: dict = {
         "Boot": 0,
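Note: the other modules in this commit import a ready-made settings instance (`from config.core import config`). That instantiation is outside the hunks shown above; a minimal sketch of what config/core.py presumably ends with:

# Assumed, not visible in the diff: the shared settings instance imported elsewhere.
config = Config()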
data/dataloader.py
ADDED
@@ -0,0 +1,52 @@
+import torch
+import lightning as L
+import torchvision.transforms as T
+import os
+
+from config.core import config
+from torch.utils.data import DataLoader
+from torchvision.datasets import ImageFolder
+from utility.helper import PadToSquare
+
+
+class ShoeSandalBoot(L.LightningDataModule):
+    def __init__(
+        self,
+        dataset_directory,
+        image_size=config.IMAGE_SIZE,
+        batch_size=config.BATCH_SIZE,
+        max_samples=None
+    ):
+        super().__init__()
+
+        self.data_dir = dataset_directory
+        self.bs = batch_size
+        self.max_samples = max_samples  # to limit the dataset size.
+
+        self.transforms = T.Compose([
+            # T.Resize(size=(image_size, image_size)),
+            PadToSquare(image_size),
+            T.ToTensor(),
+            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+        ])
+
+        self.ssb = None
+
+    def prepare_data(self):
+        pass
+
+    def setup(self, stage):
+        if stage == "fit":
+            dataset = ImageFolder(
+                root=self.data_dir,
+                transform=self.transforms
+            )
+
+            # To limit the dataset size
+            if self.max_samples:
+                print(f"[INFO] Dataset is limited to {self.max_samples} samples")
+                self.ssb = torch.utils.data.Subset(dataset, range(min(len(dataset), self.max_samples)))
+            else:
+                self.ssb = dataset
+
+    def train_dataloader(self):
+        return DataLoader(self.ssb, batch_size=self.bs, num_workers=os.cpu_count(), shuffle=True)
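A quick way to exercise the new DataModule is to build it with the values from config/core.py and pull one batch; the snippet below is an illustration only, not part of the commit:

from config.core import config
from data.dataloader import ShoeSandalBoot

# Illustration: instantiate the DataModule and inspect a single batch.
dm = ShoeSandalBoot(
    dataset_directory=config.PATH_DATASET,
    image_size=config.IMAGE_SIZE,
    batch_size=config.BATCH_SIZE,
    max_samples=config.MAX_SAMPLES,
)
dm.setup("fit")
images, labels = next(iter(dm.train_dataloader()))
print(images.shape, labels.shape)  # expected: torch.Size([20, 3, 128, 128]) torch.Size([20])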
models/discriminator.py
CHANGED
@@ -2,63 +2,78 @@ import torch
 import torch.nn as nn
 
 from .base import WSConv2d, ConvBlock
+from config.core import config
 
 
 class Discriminator(nn.Module):
-    def __init__(self, num_classes=3, image_size=128, features_discriminator=128, image_channel=3):
+    def __init__(self, num_classes=3, embed_size=128, image_size=128, features_discriminator=128, image_channel=3, label_channel=3):
         super().__init__()
 
         self.num_classes = num_classes
         self.image_size = image_size
 
+        self.embed_size = embed_size
+        self.label_channel = label_channel
+
         self.disc = nn.Sequential(
-            self._block_discriminator(image_channel + label_channel, features_discriminator, kernel_size=4, stride=2,
-            self._block_discriminator(features_discriminator, features_discriminator, kernel_size=4, stride=2,
-            self._block_discriminator(features_discriminator, features_discriminator
-            self._block_discriminator(features_discriminator * 2, features_discriminator * 4, kernel_size=4, stride=2,
-                                      padding=1),
-            self._block_discriminator(features_discriminator * 4, features_discriminator * 4, kernel_size=4, stride=2,
-                                      padding=1),
-            self._block_discriminator(features_discriminator * 4, 1, kernel_size=4, stride=1, padding=0,
-                                      final_layer=True)
+            self._block_discriminator(image_channel + label_channel, features_discriminator, kernel_size=4, stride=2, padding=1),
+            self._block_discriminator(features_discriminator, features_discriminator, kernel_size=4, stride=2, padding=1),
+            self._block_discriminator(features_discriminator, features_discriminator * 2, kernel_size=4, stride=2, padding=1),
+            self._block_discriminator(features_discriminator * 2, features_discriminator * 4, kernel_size=4, stride=2, padding=1),
+            self._block_discriminator(features_discriminator * 4, features_discriminator * 4, kernel_size=4, stride=2, padding=1),
+            self._block_discriminator(features_discriminator * 4, 1, kernel_size=4, stride=1, padding=0, final_layer=True)
         )
 
-        self.embed = nn.Embedding(num_classes,
+        self.embed = nn.Embedding(num_classes, embed_size)
+        self.embed_linear = nn.Linear(embed_size, label_channel * image_size * image_size)
 
     def forward(self, image, label):
        embedding = self.embed(label)
+
+        linear_embedding = self.embed_linear(embedding)
+
+        embedding_layer = linear_embedding.view(
             label.shape[0],
+            self.label_channel,
             self.image_size,
             self.image_size
         )
 
-        data = torch.cat([image,
+        data = torch.cat([image, embedding_layer], dim=1)
+
         x = self.disc(data)
+
         return x.view(len(x), -1)
 
     def _block_discriminator(
+        self,
+        input_channels,
+        output_channels,
+        kernel_size=3,
+        stride=2,
+        padding=0,
+        final_layer=False
     ):
         if not final_layer:
             return nn.Sequential(
                 ConvBlock(input_channels, output_channels),
-                WSConv2d(output_channels, output_channels, kernel_size, stride, padding)
+                WSConv2d(output_channels, output_channels, kernel_size, stride, padding),
             )
         else:
             return WSConv2d(input_channels, output_channels, kernel_size, stride, padding)
+
+
+def test():
+    sample = torch.randn(1, 3, 128, 128)
+    label = torch.tensor([1])
+
+    model = Discriminator(
+        num_classes=config.NUM_CLASSES,
+        embed_size=config.EMBED_SIZE,
+        image_size=config.IMAGE_SIZE,
+        features_discriminator=config.FEATURES_DISCRIMINATOR,
+        image_channel=config.IMAGE_CHANNEL,
+        label_channel=config.LABEL_CHANNEL
+    )
+
+    preds = model(sample, label)
+    print(preds.shape)
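For reference, the six `_block_discriminator` calls reduce a 128x128 conditional input (image plus projected label channels) to one score per image. A small arithmetic sketch of the spatial sizes, as an illustration only (not part of the commit):

# Conv output size for the blocks above: (size + 2*padding - kernel) // stride + 1
def conv_out(size, kernel=4, stride=2, padding=1):
    return (size + 2 * padding - kernel) // stride + 1

size = 128
for _ in range(5):           # the five stride-2 blocks: 128 -> 64 -> 32 -> 16 -> 8 -> 4
    size = conv_out(size)
print(size)                                  # 4
print(conv_out(size, stride=1, padding=0))   # final block: 4 -> 1, so the critic outputs (N, 1)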
models/generator.py
CHANGED
@@ -2,40 +2,42 @@ import torch
 import torch.nn as nn
 
 from .base import WSConv2d, ConvBlock, PixelNorm
+from config.core import config
 
 
 class Generator(nn.Module):
-    def __init__(self,
+    def __init__(self, embed_size=128, num_classes=3, image_size=128, features_generator=128, input_dim=128, image_channel=3):
         super().__init__()
 
         self.gen = nn.Sequential(
+            self._block(input_dim + embed_size, features_generator * 2, first_double_up=True),
+            self._block(features_generator * 2, features_generator * 4, first_double_up=False, final_layer=False),
+            self._block(features_generator * 4, features_generator * 4, first_double_up=False, final_layer=False),
+            self._block(features_generator * 4, features_generator * 4, first_double_up=False, final_layer=False),
+            self._block(features_generator * 4, features_generator * 2, first_double_up=False, final_layer=False),
+            self._block(features_generator * 2, features_generator, first_double_up=False, final_layer=False),
+            self._block(features_generator, image_channel, first_double_up=False, use_double=False, final_layer=True),
         )
 
         self.image_size = image_size
         self.embed_size = embed_size
 
         self.embed = nn.Embedding(num_classes, embed_size)
+        self.embed_linear = nn.Linear(embed_size, embed_size)
 
     def forward(self, noise, labels):
         embedding_label = self.embed(labels)
+        linear_embedding_label = self.embed_linear(embedding_label).unsqueeze(2).unsqueeze(3)
+
+        noise = noise.view(noise.size(0), noise.size(1), 1, 1)
+
+        x = torch.cat([noise, linear_embedding_label], dim=1)
+
         return self.gen(x)
 
+
     def _block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
+               first_double_up=False, use_double=True, final_layer=False):
         layers = []
 
         if not final_layer:
@@ -55,3 +57,17 @@ class Generator(nn.Module):
 
         return nn.Sequential(*layers)
 
+def test():
+    sample = torch.randn(1, config.INPUT_Z_DIM, 1, 1)
+    label = torch.tensor([1])
+
+    model = Generator(
+        embed_size=config.EMBED_SIZE,
+        num_classes=config.NUM_CLASSES,
+        image_size=config.IMAGE_SIZE,
+        features_generator=config.FEATURES_GENERATOR,
+        input_dim=config.INPUT_Z_DIM,
+    )
+
+    preds = model(sample, label)
+    print(preds.shape)
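In the new forward(), the noise is reshaped to (N, input_dim, 1, 1) and concatenated with the projected label embedding, so the first block receives input_dim + embed_size channels; with the values in config/core.py (INPUT_Z_DIM = 64 * 2 = 128, EMBED_SIZE = 30 + 20 = 50) that is 178 channels. A small illustration of that assembly, not part of the commit:

import torch

# Illustration only: how the conditional input to self.gen is assembled.
z_dim, embed_size, batch = 128, 50, 4
noise = torch.randn(batch, z_dim)                  # as sampled in training_step
label_embedding = torch.randn(batch, embed_size)   # stand-in for embed_linear(embed(labels))

noise = noise.view(batch, z_dim, 1, 1)
cond_input = torch.cat([noise, label_embedding.unsqueeze(2).unsqueeze(3)], dim=1)
print(cond_input.shape)                            # torch.Size([4, 178, 1, 1])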
requirements.txt
CHANGED
@@ -3,4 +3,6 @@ pytorch-lightning
 python-multipart
 fastapi
 pydantic
-pydantic-settings
+pydantic-settings
+opencv-python==4.10.0
+imageio==2.33.1
training/train.py
ADDED
@@ -0,0 +1,196 @@
+import torch
+import lightning as L
+import torch.optim as optim
+
+from models.generator import Generator
+from models.discriminator import Discriminator
+from utility.helper import initialize_weights, plot_images_from_tensor
+from utility.wgan_gp import gradient_penalty, calculate_generator_loss, calculate_critic_loss
+
+
+
+class ConditionalWGAN_GP(L.LightningModule):
+    def __init__(self, image_channel, label_channel, image_size, learning_rate, z_dim, embed_size, num_classes, critic_repeats, feature_gen, feature_critic, c_lambda, beta_1, beta_2, display_step):
+        super().__init__()
+
+        self.automatic_optimization = False
+
+        self.image_size = image_size
+        self.critic_repeats = critic_repeats
+        self.c_lambda = c_lambda
+
+        self.generator = Generator(
+            embed_size=embed_size,
+            num_classes=num_classes,
+            image_size=image_size,
+            features_generator=feature_gen,
+            input_dim=z_dim,
+        )
+
+        self.critic = Discriminator(
+            num_classes=num_classes,
+            embed_size=embed_size,
+            image_size=image_size,
+            features_discriminator=feature_critic,
+            image_channel=image_channel,
+            label_channel=label_channel,
+        )
+
+        self.critic_losses = []
+        self.generator_losses = []
+        self.curr_step = 0
+
+        self.fixed_latent_space = torch.randn(25, z_dim, 1, 1)
+        self.fixed_label = torch.tensor([i % num_classes for i in range(25)])
+
+        self.save_hyperparameters()
+
+    def configure_optimizers(self):
+        # READ: https://lightning.ai/docs/pytorch/stable/common/optimization.html#use-multiple-optimizers-like-gans
+        # READ: https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
+        # READ: https://lightning.ai/docs/pytorch/stable/model/build_model_advanced.html
+        # READ: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.core.LightningModule.html#lightning.pytorch.core.LightningModule.backward
+        # READ: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#manual-backward
+        optimizer_G = optim.Adam(self.generator.parameters(), lr=self.hparams.learning_rate, betas=(self.hparams.beta_1, self.hparams.beta_2))
+        optimizer_C = optim.Adam(self.critic.parameters(), lr=self.hparams.learning_rate, betas=(self.hparams.beta_1, self.hparams.beta_2))
+
+        return optimizer_G, optimizer_C
+
+    def on_load_checkpoint(self, checkpoint):
+        # List of keys that are expected to be loaded from the checkpoint
+        keys_to_load = ['critic_losses', 'generator_losses', 'curr_step', 'fixed_latent_space', 'fixed_label']
+
+        # Iterate over the keys and load them if they exist in the checkpoint
+        for key in keys_to_load:
+            if key in checkpoint:
+                setattr(self, key, checkpoint[key])
+
+    def on_save_checkpoint(self, checkpoint):
+        # Save the necessary variables to the checkpoint
+        checkpoint['critic_losses'] = self.critic_losses
+        checkpoint['generator_losses'] = self.generator_losses
+        checkpoint['curr_step'] = self.curr_step
+        checkpoint['fixed_latent_space'] = self.fixed_latent_space
+        checkpoint['fixed_label'] = self.fixed_label
+
+    def on_train_start(self):
+        if self.current_epoch == 0:
+            self.generator.apply(initialize_weights)
+            self.critic.apply(initialize_weights)
+
+    def training_step(self, batch, batch_idx):
+        # Get the optimizers
+        opt_generator, opt_critic = self.optimizers()
+
+        # Get the data and labels
+        X, labels = batch
+
+        # Get the current batch size
+        batch_size = X.shape[0]
+
+        ##############################
+        # Train Critic ###############
+        ##############################
+        mean_critic_loss_for_this_iteration = 0
+
+        for _ in range(self.critic_repeats):
+            # Clean the gradients
+            opt_critic.zero_grad()
+
+            # Generate the noise.
+            noise = torch.randn(batch_size, self.hparams.z_dim, device=self.device)
+
+            # Generate fake images.
+            fake = self.generator(noise, labels)
+
+            # Get the Critic's predictions on the reals and fakes
+            critic_fake_pred = self.critic(fake.detach(), labels)
+            critic_real_pred = self.critic(X, labels)
+
+            # Calculate the Critic loss using WGAN-GP
+
+            # Generate epsilon for the interpolated images.
+            epsilon = torch.rand(batch_size, 1, 1, 1, device=self.device, requires_grad=True)
+
+            # Calculate the gradient penalty for the Critic model
+            gp = gradient_penalty(self.critic, labels, X, fake.detach(), epsilon)
+
+            # Calculate the full WGAN-GP loss for the Critic
+            critic_loss = calculate_critic_loss(
+                critic_fake_pred, critic_real_pred, gp, self.c_lambda
+            )
+
+            # Keep track of the average Critic loss in this batch
+            mean_critic_loss_for_this_iteration += critic_loss.item() / self.critic_repeats
+
+            # Update the Critic's gradients
+            # self.manual_backward(critic_loss, retain_graph=True)
+            # retain_graph is not needed here because the fake images are detach()-ed,
+            # which already cuts the generator out of this backward pass.
+            self.manual_backward(critic_loss)
+
+            # Update the optimizer
+            opt_critic.step()
+
+        ##############################
+        # Train Generator ############
+        ##############################
+
+        # Clean the gradients
+        opt_generator.zero_grad()
+
+        # Generate the noise.
+        noise = torch.randn(batch_size, self.hparams.z_dim, device=self.device)
+
+        # Generate fake images.
+        fake = self.generator(noise, labels)
+
+        # Get the Critic's predictions on the Generator's fakes
+        generator_fake_predictions = self.critic(fake, labels)
+
+        # Calculate the loss for the Generator
+        generator_loss = calculate_generator_loss(generator_fake_predictions)
+
+        # Update the Generator's gradients
+        self.manual_backward(generator_loss)
+
+        # Update the optimizer
+        opt_generator.step()
+
+        ##############################
+        # Visualization ##############
+        ##############################
+
+        if self.curr_step % self.hparams.display_step == 0 and self.curr_step > 0:
+            VISUALIZE = True
+            if VISUALIZE:
+                with torch.no_grad():
+                    fake_images_fixed = self.generator(
+                        self.fixed_latent_space.to(self.device),
+                        self.fixed_label.to(self.device)
+                    )
+
+                path_save = f"/kaggle/working/generates/generated-{self.curr_step}-step.png"
+                plot_images_from_tensor(fake_images_fixed, size=(3, self.image_size, self.image_size), show=False, save_path=path_save)
+                plot_images_from_tensor(X, size=(3, self.image_size, self.image_size), show=False)
+
+                print(f" ==== Critic Loss: {mean_critic_loss_for_this_iteration} ==== ")
+                print(f" ==== Generator Loss: {generator_loss.item()} ==== ")
+
+        self.curr_step += 1
+
+        ##############################
+        # Logging ####################
+        ##############################
+        # Log the Critic and Generator losses
+        self.log("critic_loss", mean_critic_loss_for_this_iteration, on_step=False, on_epoch=True, prog_bar=True)
+        self.log("generator_loss", generator_loss.item(), on_step=False, on_epoch=True, prog_bar=True)
+
+        # Also store the losses in lists so they can be used later for visualization
+        self.critic_losses.append(mean_critic_loss_for_this_iteration)
+        self.generator_losses.append(generator_loss.item())
+
+    def forward(self, noise, labels):
+        return self.generator(noise, labels)
+
+    def predict_step(self, noise, labels):
+        return self.generator(noise, labels)
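train.py only defines the LightningModule; the Kaggle notebook that drives it is not part of this commit. A hypothetical sketch of how the new pieces fit together (c_lambda and beta_2 are assumptions, since those config fields are outside the hunks shown above):

import lightning as L

from config.core import config
from data.dataloader import ShoeSandalBoot
from training.train import ConditionalWGAN_GP

# Hypothetical wiring; not shown in the diff.
model = ConditionalWGAN_GP(
    image_channel=config.IMAGE_CHANNEL,
    label_channel=config.LABEL_CHANNEL,
    image_size=config.IMAGE_SIZE,
    learning_rate=config.LEARNING_RATE,
    z_dim=config.INPUT_Z_DIM,
    embed_size=config.EMBED_SIZE,
    num_classes=config.NUM_CLASSES,
    critic_repeats=config.CRITIC_REPEAT,
    feature_gen=config.FEATURES_GENERATOR,
    feature_critic=config.FEATURES_DISCRIMINATOR,
    c_lambda=10,          # assumed; the gradient-penalty weight is not visible in the diff
    beta_1=config.BETA_1,
    beta_2=0.999,         # assumed; BETA_2 is not visible in the diff
    display_step=config.DISPLAY_STEP,
)

datamodule = ShoeSandalBoot(
    dataset_directory=config.PATH_DATASET,
    image_size=config.IMAGE_SIZE,
    batch_size=config.BATCH_SIZE,
    max_samples=config.MAX_SAMPLES,
)

trainer = L.Trainer(max_epochs=300, accelerator="auto")
trainer.fit(model, datamodule=datamodule)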
utility/helper.py
CHANGED
@@ -1,7 +1,15 @@
 import torch
+import torch.nn as nn
+import cv2
+import imageio
+import os
+import matplotlib.pyplot as plt
 
 from config.core import config
 from models.generator import Generator
+from PIL import Image
+from torchvision.utils import make_grid
+
 
 def load_model_weights(checkpoint_path, model, device, prefix):
     """
@@ -61,3 +69,156 @@ def get_selected_value(label):
     """
     # Get the selected value from the options mapping based on the display label.
     return config.OPTIONS_MAPPING[label]
+
+def initialize_weights(model):
+    """
+    Initializes the weights of a model using a normal distribution.
+
+    Args:
+        model: The model to be initialized.
+
+    Returns:
+        None
+    """
+
+    for m in model.modules():
+        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.InstanceNorm2d)):
+            nn.init.normal_(m.weight.data, 0.0, 0.02)
+
+def plot_images_from_tensor(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True, save_path=None):
+    image_tensor = (image_tensor + 1) / 2
+    image_unflat = image_tensor.detach().cpu()
+    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
+    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
+    plt.axis('off')
+    if save_path:
+        os.makedirs(os.path.dirname(save_path), exist_ok=True)
+        plt.savefig(save_path, bbox_inches='tight', pad_inches=0)
+    if show:
+        plt.show()
+    else:
+        plt.close()
+
+def create_video(image_folder, video_name, fps, appearance_duration=None):
+    """
+    Creates a video from a sequence of images with customizable appearance duration.
+
+    Args:
+        image_folder (str): The path to the folder containing the images.
+        video_name (str): The name of the output video file.
+        fps (int): The frames per second of the video.
+        appearance_duration (int, optional): The desired appearance duration for each image in milliseconds.
+            If None, the default duration based on the frame rate is used.
+
+    Example:
+        image_folder = '/path/to/image/folder'
+        video_name = 'output_video.mp4'
+        fps = 12
+        appearance_duration = 200  # Appearance duration of 200 ms for each image
+
+        create_video(image_folder, video_name, fps, appearance_duration)
+    """
+
+    # Get a list of all image files in the folder
+    image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
+
+    # Sort the image files based on the step number
+    image_files = sorted(image_files, key=lambda x: int(x.split('-')[1].split('.')[0]))
+
+    # Load the first image to get the video size
+    image = cv2.imread(os.path.join(image_folder, image_files[0]))
+    height, width, layers = image.shape
+
+    # Create a VideoWriter object
+    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Specify the video codec
+    video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
+
+    # Write each image to the video with customizable appearance duration
+    for image_file in image_files:
+        image = cv2.imread(os.path.join(image_folder, image_file))
+        video.write(image)
+
+        if appearance_duration is not None:
+            # Calculate the number of frames for the desired appearance duration
+            num_frames = appearance_duration * fps // 1000
+            for _ in range(num_frames):
+                video.write(image)
+
+    # Release the video writer
+    video.release()
+
+def create_gif(image_folder, gif_name, fps, appearance_duration=None):
+    """
+    Creates a GIF from a sequence of images sorted by step number, with customizable appearance duration.
+
+    Args:
+        image_folder (str): The path to the folder containing the images.
+        gif_name (str): The name of the output GIF file.
+        fps (int): The frames per second of the GIF.
+        appearance_duration (int, optional): The desired appearance duration for each image in milliseconds.
+            If None, the default duration based on the frame rate is used.
+
+    Example:
+        image_folder = '/path/to/image/folder'
+        gif_name = 'output_animation.gif'
+        fps = 12
+        appearance_duration = 300  # Appearance duration of 300 ms for each image
+
+        create_gif(image_folder, gif_name, fps, appearance_duration)
+    """
+
+    # Get a list of all image files in the folder
+    image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
+
+    # Sort the image files based on the step number
+    image_files = sorted(image_files, key=lambda x: int(x.split('-')[1].split('.')[0]))
+
+    # Load the images into a list
+    images = []
+    for file in image_files:
+        images.append(imageio.imread(os.path.join(image_folder, file)))
+
+    # Create a list to store the repeated images
+    repeated_images = []
+
+    # Repeat each image for the desired duration
+    if appearance_duration is not None:
+        for image in images:
+            repeated_images.extend([image] * (appearance_duration * fps // 1000))
+    else:
+        repeated_images = images  # Default appearance duration (based on fps)
+
+    # Save the repeated images as a GIF
+    imageio.mimsave(gif_name, repeated_images, fps=fps)
+
+class PadToSquare:
+    """Pad an image to a square of the given size with a white background.
+
+    Args:
+        size (int): The target size for the output image.
+    """
+
+    def __init__(self, size):
+        self.size = size
+
+    def __call__(self, img):
+        """Pad the input image to the target size with a white background.
+
+        Args:
+            img (PIL.Image.Image): The input image.
+
+        Returns:
+            PIL.Image.Image: The padded image.
+        """
+        # Create a white canvas
+        white_canvas = Image.new('RGB', (self.size, self.size), (255, 255, 255))
+
+        # Calculate the position to paste the image onto the white canvas
+        left = (self.size - img.width) // 2
+        top = (self.size - img.height) // 2
+
+        # Paste the image onto the canvas
+        white_canvas.paste(img, (left, top))
+
+        return white_canvas
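PadToSquare stands in for the commented-out T.Resize in the DataModule: instead of stretching the image, it centers it on a white square canvas. A small illustration, not part of the commit:

from PIL import Image
from utility.helper import PadToSquare

# Illustration only: a 100x60 image is centered on a white 128x128 canvas.
img = Image.new("RGB", (100, 60))
padded = PadToSquare(128)(img)
print(padded.size)  # (128, 128)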
utility/wgan_gp.py
ADDED
@@ -0,0 +1,77 @@
+import torch
+
+# Gradient Penalty Calculation - Calculate Gradient of Critic Score
+def gradient_penalty(critic, labels, real_images, fake_images, epsilon):
+    """
+    This function calculates the gradient penalty for the WGAN-GP loss function.
+
+    Parameters:
+        critic (nn.Module): The critic model
+        labels (torch.tensor): The labels for the images
+        real_images (torch.tensor): The real images
+        fake_images (torch.tensor): The fake images
+        epsilon (torch.tensor): The interpolation parameter
+
+    Returns:
+        gradient_penalty (torch.tensor): The gradient penalty for the critic model
+    """
+
+    # Create the interpolated images as a weighted combination of real and fake images
+    interpolated_images = real_images * epsilon + fake_images * (1 - epsilon)
+
+    mixed_scores = critic(interpolated_images, labels)
+
+    create_real_label = torch.ones_like(mixed_scores)
+
+    gradient = torch.autograd.grad(
+        inputs=interpolated_images,
+        outputs=mixed_scores,
+        grad_outputs=create_real_label,
+        create_graph=True,
+        retain_graph=True,
+    )[0]
+
+    # Reshape each image in the batch into a 1D tensor (flatten the images)
+    gradient = gradient.view(len(gradient), -1)
+
+    # Calculate the L2 norm of the gradients
+    gradient_norm = gradient.norm(2, dim=1)
+
+    # Calculate the penalty as the mean squared distance of the norms from 1
+    gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
+
+    return gradient_penalty
+
+# Critic Loss Calculation
+def calculate_critic_loss(critic_fake_prediction, critic_real_prediction, gradient_penalty, critic_lambda):
+    """
+    Calculates the critic loss, which is the difference between the mean of the real scores and the mean of the fake scores, plus the gradient penalty.
+
+    Parameters:
+        critic_fake_prediction (torch.tensor): The critic predictions for the fake images
+        critic_real_prediction (torch.tensor): The critic predictions for the real images
+        gradient_penalty (torch.tensor): The gradient penalty for the critic model
+        critic_lambda (float): The coefficient for the gradient penalty
+
+    Returns:
+        critic_loss (torch.tensor): The critic loss
+    """
+    critic_loss = (
+        -(torch.mean(critic_real_prediction) - torch.mean(critic_fake_prediction)) + critic_lambda * gradient_penalty
+    )
+
+    return critic_loss
+
+# Generator Loss Calculation
+def calculate_generator_loss(critic_fake_prediction):
+    """
+    Calculates the generator loss, which is the mean of the critic predictions for the fake images with a negative sign.
+
+    Parameters:
+        critic_fake_prediction (torch.tensor): The critic predictions for the fake images
+
+    Returns:
+        generator_loss (torch.tensor): The generator loss
+    """
+    generator_loss = -1.0 * torch.mean(critic_fake_prediction)
+    return generator_loss
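These three functions implement the standard WGAN-GP objective: the critic is trained to maximize E[C(real)] - E[C(fake)] while keeping the gradient norm of critic scores at interpolated images close to 1, and the generator is trained to maximize E[C(fake)] (its loss is -E[C(fake)]). A tiny stand-alone check of gradient_penalty with a toy critic, as an illustration only (not part of the commit):

import torch
import torch.nn as nn

from utility.wgan_gp import gradient_penalty

# Toy critic: ignores the labels and maps a 4x4 RGB image to one score.
class ToyCritic(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 1, kernel_size=4)

    def forward(self, images, labels):
        return self.conv(images).view(len(images), -1)

real = torch.randn(8, 3, 4, 4)
fake = torch.randn(8, 3, 4, 4)
epsilon = torch.rand(8, 1, 1, 1, requires_grad=True)

gp = gradient_penalty(ToyCritic(), None, real, fake, epsilon)
print(gp)  # a non-negative scalar tensor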
weights/{epoch=957-step=1164300.ckpt → epoch=299-step=450000.ckpt}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:8db76fefe85e1e6077bf8aa7532699d379ea0dc1f0d1b6130ebac9ca3b814303
+size 640723340
weights/source.txt
CHANGED
@@ -1,2 +1,2 @@
-
-- https://www.kaggle.com/datasets/dimensioncore/conditional-gan-part-2/versions/
+- Notebook: [VERSION 19] https://www.kaggle.com/code/dimensioncore/57894-conditional-gan-try-part-2?scriptVersionId=197222851
+- Checkpoint: [VERSION 3016] https://www.kaggle.com/datasets/dimensioncore/conditional-gan-part-2/versions/3016