Muhammad Naufal Rizqullah committed · Commit 21a662b · Parent(s): 36cfe0b
first commit
Files changed:
- .gitignore +30 -0
- LICENSE +21 -0
- app.py +44 -0
- config/__init__.py +0 -0
- config/core.py +34 -0
- models/__init__.py +0 -0
- models/base.py +181 -0
- models/lightning.py +138 -0
- requirements.txt +3 -0
- utility/__init__.py +0 -0
- utility/helper.py +60 -0
- weights/epoch=999-step=96000.ckpt +3 -0
- weights/source.txt +2 -0
.gitignore
ADDED
@@ -0,0 +1,30 @@
.idea
.ipynb_checkpoints
.mypy_cache
.vscode
__pycache__
.pytest_cache
htmlcov
dist
site
.coverage
coverage.xml
.netlify
test.db
log.txt
Pipfile.lock
env3.*
env
docs_build
site_build
venv
docs.zip
archive.zip

# vim temporary files
*~
.*.sw?
.cache

# macOS
.DS_Store
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Muhammad Naufal Rizqullah

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
app.py
ADDED
@@ -0,0 +1,44 @@
import torch
from PIL import Image
import numpy as np
import gradio as gr

from config.core import config
from utility.helper import load_model_weights, init_generator_model, get_selected_value

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = init_generator_model()
model = load_model_weights(config.CKPT_PATH, model, DEVICE, "generator")
model.eval()

def inference(choice):
    z = torch.randn(1, config.INPUT_Z_DIM, 1, 1).to(DEVICE)
    label = torch.tensor([get_selected_value(choice)], device=DEVICE)

    image_tensor = model(z, label)

    image_tensor = (image_tensor + 1) / 2  # Shift and scale to 0 to 1
    image_unflat = image_tensor.detach().cpu().squeeze(0)  # Remove batch dimension
    image = image_unflat.permute(1, 2, 0)  # Permute to (H, W, C)

    # Convert image to numpy array
    image_array = image.numpy()

    # Scale values to 0-255 range
    image_array = (image_array * 255).astype(np.uint8)

    # Convert numpy array to PIL Image
    image = Image.fromarray(image_array)

    return image

demo = gr.Interface(
    fn=inference,
    inputs=gr.Dropdown(choices=list(config.OPTIONS_MAPPING.keys()), label="Select an option to generate images"),
    outputs=gr.Image(),
    title="Shoe, Sandal, Boot - Conditional GAN",
    description="Conditional WGAN-GP",
)

demo.launch()
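The generator's Tanh output lies in [-1, 1], so inference() shifts it to [0, 1] and then to 8-bit before handing it to Gradio. A minimal sketch of that post-processing on a dummy tensor (illustrative only, not part of the commit):

import torch
import numpy as np
from PIL import Image

# Dummy generator output in [-1, 1] with shape (1, 3, 128, 128)
fake = torch.tanh(torch.randn(1, 3, 128, 128))

img = (fake + 1) / 2                      # [-1, 1] -> [0, 1]
img = img.detach().cpu().squeeze(0)       # drop batch dim -> (3, 128, 128)
img = img.permute(1, 2, 0).numpy()        # (H, W, C) layout for PIL
img = (img * 255).astype(np.uint8)        # [0, 1] -> [0, 255]

Image.fromarray(img).save("sample.png")   # same conversion app.py returns to Gradio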
config/__init__.py
ADDED
File without changes
config/core.py
ADDED
@@ -0,0 +1,34 @@
from pydantic import BaseSettings

class Config(BaseSettings):
    IMAGE_CHANNEL: int = 3
    NUM_CLASSES: int = 3
    IMAGE_SIZE: int = 128
    FEATURES_DISCRIMINATOR: int = 64
    FEATURES_GENERATOR: int = 64
    EMBED_SIZE: int = 64
    INPUT_Z_DIM: int = 64
    BATCH_SIZE: int = 128
    DISPLAY_STEP: int = 500
    MAX_SAMPLES: int = 3000

    LEARNING_RATE: float = 0.0002
    BETA_1: float = 0.5
    BETA_2: float = 0.999
    C_LAMBDA: int = 10

    NUM_EPOCH: int = 200 * 5

    CRITIC_REPEAT: int = 3

    LOAD_CHECKPOINT: bool = True
    PATH_DATASET: str = ""
    CKPT_PATH: str = "./weights/epoch=999-step=96000.ckpt"

    OPTIONS_MAPPING: dict = {
        "Boot": 0,
        "Sandal": 1,
        "Shoe": 2
    }

config = Config()
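Because Config subclasses pydantic's BaseSettings, any field can be overridden through environment variables instead of editing the file. A small sketch of that behavior (assuming pydantic v1, where BaseSettings is importable from pydantic; the variable name used here is just an example):

import os

os.environ["CKPT_PATH"] = "./weights/other.ckpt"   # hypothetical override

from config.core import Config

cfg = Config()
print(cfg.CKPT_PATH)        # "./weights/other.ckpt" instead of the default
print(cfg.OPTIONS_MAPPING)  # {"Boot": 0, "Sandal": 1, "Shoe": 2}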
models/__init__.py
ADDED
File without changes
models/base.py
ADDED
@@ -0,0 +1,181 @@
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    """Discriminator model for Conditional GAN.

    Args:
        num_classes (int): Number of classes in the dataset.
        image_size (int): Size of the input images (assumes square images).
        features_discriminator (int): Number of feature maps in the first layer of the discriminator.
        image_channel (int): Number of channels in the input image.

    Attributes:
        disc (nn.Sequential): The sequential layers that define the discriminator.
        embed (nn.Embedding): Embedding layer to encode labels into an image-like format.
    """

    def __init__(self, num_classes=3, image_size=128, features_discriminator=128, image_channel=3):
        super().__init__()

        self.num_classes = num_classes
        self.image_size = image_size

        label_channel = 1

        self.disc = nn.Sequential(
            self._block_discriminator(image_channel + label_channel, features_discriminator, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator, features_discriminator, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator, features_discriminator * 2, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator * 2, features_discriminator * 4, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator * 4, features_discriminator * 4, kernel_size=4, stride=2, padding=1),
            self._block_discriminator(features_discriminator * 4, 1, kernel_size=4, stride=1, padding=0, final_layer=True)
        )

        self.embed = nn.Embedding(num_classes, image_size * image_size)

    def forward(self, image, label):
        """Forward pass for the discriminator.

        Args:
            image (torch.Tensor): Batch of input images.
            label (torch.Tensor): Corresponding labels for the images.

        Returns:
            torch.Tensor: Discriminator output.
        """
        # Embed label into an image-like format
        embedding = self.embed(label)
        embedding = embedding.view(
            label.shape[0],
            1,
            self.image_size,
            self.image_size
        )  # Reshape into a 1-channel image

        data = torch.cat([image, embedding], dim=1)  # Concatenate image with the label channel

        x = self.disc(data)

        return x.view(len(x), -1)

    def _block_discriminator(self, input_channels, output_channels, kernel_size=3, stride=2, padding=0, final_layer=False):
        """Creates a convolutional block for the discriminator.

        Args:
            input_channels (int): Number of input channels for the convolutional layer.
            output_channels (int): Number of output channels for the convolutional layer.
            kernel_size (int): Size of the kernel for the convolutional layer.
            stride (int): Stride of the convolutional layer.
            padding (int): Padding for the convolutional layer.
            final_layer (bool): If True, this is the final layer, which doesn't include normalization or activation.

        Returns:
            nn.Sequential: Sequential block for the discriminator.
        """
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding),
                nn.InstanceNorm2d(output_channels, affine=True),
                nn.LeakyReLU(0.2)
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride, padding),
            )

class Generator(nn.Module):
    """Generator model for Conditional GAN.

    Args:
        embed_size (int): Size of the embedding vector for the labels.
        num_classes (int): Number of classes in the dataset.
        image_size (int): Size of the output images (assumes square images).
        features_generator (int): Number of feature maps in the first layer of the generator.
        input_dim (int): Dimensionality of the noise vector.
        image_channel (int): Number of channels in the output image.

    Attributes:
        gen (nn.Sequential): The sequential layers that define the generator.
        embed (nn.Embedding): Embedding layer to encode labels.
    """

    def __init__(self, embed_size=128, num_classes=3, image_size=128, features_generator=128, input_dim=128, image_channel=3):
        super(Generator, self).__init__()

        self.gen = nn.Sequential(
            self._block(input_dim + embed_size, features_generator * 2, first_double_up=True),
            self._block(features_generator * 2, features_generator * 4, first_double_up=False, final_layer=False),
            self._block(features_generator * 4, features_generator * 4, first_double_up=False, final_layer=False),
            self._block(features_generator * 4, features_generator * 4, first_double_up=False, final_layer=False),
            self._block(features_generator * 4, features_generator * 2, first_double_up=False, final_layer=False),
            self._block(features_generator * 2, features_generator, first_double_up=False, final_layer=False),
            self._block(features_generator, image_channel, first_double_up=False, use_double=False, final_layer=True),
        )

        self.image_size = image_size
        self.embed_size = embed_size

        self.embed = nn.Embedding(num_classes, embed_size)

    def forward(self, noise, labels):
        """Forward pass for the generator.

        Args:
            noise (torch.Tensor): Batch of input noise vectors.
            labels (torch.Tensor): Corresponding labels for the noise vectors.

        Returns:
            torch.Tensor: Generated images.
        """
        embedding_label = self.embed(labels).unsqueeze(2).unsqueeze(3)  # Reshape to (batch_size, embed_size, 1, 1)

        noise = noise.view(noise.size(0), noise.size(1), 1, 1)  # Reshape to (batch_size, z_dim, 1, 1)

        x = torch.cat([noise, embedding_label], dim=1)

        return self.gen(x)

    def _block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
               first_double_up=False, use_double=True, final_layer=False):
        """Creates a convolutional block for the generator.

        Args:
            in_channels (int): Number of input channels for the convolutional layer.
            out_channels (int): Number of output channels for the convolutional layer.
            kernel_size (int): Size of the kernel for the convolutional layer.
            stride (int): Stride of the convolutional layer.
            padding (int): Padding for the convolutional layer.
            first_double_up (bool): If True, the first layer uses a different upsampling strategy.
            use_double (bool): If True, the block includes an upsampling layer.
            final_layer (bool): If True, this is the final layer, which uses Tanh activation.

        Returns:
            nn.Sequential: Sequential block for the generator.
        """
        layers = []

        if not final_layer:
            # Add first convolutional layer
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2))

            # Add second convolutional layer
            layers.append(nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2))
        else:
            layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding))
            layers.append(nn.Tanh())

        if use_double:
            if first_double_up:
                layers.append(nn.ConvTranspose2d(out_channels, out_channels, 4, 1, 0))
            else:
                layers.append(nn.ConvTranspose2d(out_channels, out_channels, 4, 2, 1))

            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.LeakyReLU(0.2))

        return nn.Sequential(*layers)
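With the configuration used by this Space (embed_size=64, z_dim=64, features=64), the generator upsamples a 1x1 latent to a 128x128 RGB image (one 1-to-4 transposed-conv step followed by five doublings), while the critic collapses a 128x128 input to a single score. A quick shape-check sketch, not part of the commit:

import torch
from models.base import Generator, Discriminator

gen = Generator(embed_size=64, num_classes=3, image_size=128,
                features_generator=64, input_dim=64)
critic = Discriminator(num_classes=3, image_size=128, features_discriminator=64)

z = torch.randn(4, 64)              # batch of latent vectors
labels = torch.randint(0, 3, (4,))  # class ids: Boot / Sandal / Shoe

fake = gen(z, labels)
print(fake.shape)                   # torch.Size([4, 3, 128, 128])
print(critic(fake, labels).shape)   # torch.Size([4, 1])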
models/lightning.py
ADDED
@@ -0,0 +1,138 @@
import torch
import torch.optim as optim
import lightning as L
from .base import Discriminator, Generator

class ConditionalWGAN_GP(L.LightningModule):
    """Conditional WGAN-GP implementation using PyTorch Lightning.

    Attributes:
        image_size: Size of the generated images.
        critic_repeats: Number of critic iterations per generator iteration.
        c_lambda: Gradient penalty lambda hyperparameter.
        generator: The generator model.
        critic: The discriminator (critic) model.
        critic_losses: List to store critic loss values.
        generator_losses: List to store generator loss values.
        curr_step: The current training step.
        fixed_latent_space: Fixed latent vectors for generating consistent images.
        fixed_label: Fixed labels corresponding to the latent vectors.
    """

    def __init__(self, image_size, learning_rate, z_dim, embed_size, num_classes,
                 critic_repeats, feature_gen, feature_critic, c_lambda, beta_1,
                 beta_2, display_step):
        """Initializes the Conditional WGAN-GP model.

        Args:
            image_size: Size of the generated images.
            learning_rate: Learning rate for the optimizers.
            z_dim: Dimension of the latent space.
            embed_size: Size of the embedding for the labels.
            num_classes: Number of classes for the conditional generation.
            critic_repeats: Number of critic iterations per generator iteration.
            feature_gen: Number of features for the generator.
            feature_critic: Number of features for the critic.
            c_lambda: Gradient penalty lambda hyperparameter.
            beta_1: Beta1 parameter for the Adam optimizer.
            beta_2: Beta2 parameter for the Adam optimizer.
            display_step: Step interval for displaying generated images.
        """
        super().__init__()

        self.automatic_optimization = False

        self.image_size = image_size
        self.critic_repeats = critic_repeats
        self.c_lambda = c_lambda

        self.generator = Generator(
            embed_size=embed_size,
            num_classes=num_classes,
            image_size=image_size,
            features_generator=feature_gen,
            input_dim=z_dim,
        )

        self.critic = Discriminator(
            num_classes=num_classes,
            image_size=image_size,
            features_discriminator=feature_critic,
        )

        self.critic_losses = []
        self.generator_losses = []
        self.curr_step = 0

        self.fixed_latent_space = torch.randn(25, z_dim, 1, 1)
        self.fixed_label = torch.tensor([i % num_classes for i in range(25)])

        self.save_hyperparameters()

    def configure_optimizers(self):
        """Configures the optimizers for the generator and critic.

        Returns:
            A tuple of two Adam optimizers, one for the generator and one for the critic.
        """
        optimizer_g = optim.Adam(
            self.generator.parameters(),
            lr=self.hparams.learning_rate,
            betas=(self.hparams.beta_1, self.hparams.beta_2),
        )
        optimizer_c = optim.Adam(
            self.critic.parameters(),
            lr=self.hparams.learning_rate,
            betas=(self.hparams.beta_1, self.hparams.beta_2),
        )

        return optimizer_g, optimizer_c

    def on_load_checkpoint(self, checkpoint):
        """Loads necessary variables from a checkpoint.

        Args:
            checkpoint: The checkpoint dictionary.
        """
        if self.current_epoch != 0:
            self.critic_losses = checkpoint['critic_losses']
            self.generator_losses = checkpoint['generator_losses']
            self.curr_step = checkpoint['curr_step']
            self.fixed_latent_space = checkpoint['fixed_latent_space']
            self.fixed_label = checkpoint['fixed_label']

    def on_save_checkpoint(self, checkpoint):
        """Saves necessary variables to a checkpoint.

        Args:
            checkpoint: The checkpoint dictionary.
        """
        checkpoint['critic_losses'] = self.critic_losses
        checkpoint['generator_losses'] = self.generator_losses
        checkpoint['curr_step'] = self.curr_step
        checkpoint['fixed_latent_space'] = self.fixed_latent_space
        checkpoint['fixed_label'] = self.fixed_label

    def forward(self, noise, labels):
        """Generates an image given noise and labels.

        Args:
            noise: Latent noise vector.
            labels: Class labels for conditional generation.

        Returns:
            Generated image tensor.
        """
        return self.generator(noise, labels)

    def predict_step(self, noise, labels):
        """Predicts an image given noise and labels.

        Args:
            noise: Latent noise vector.
            labels: Class labels for conditional generation.

        Returns:
            Generated image tensor.
        """
        return self.generator(noise, labels)
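This commit ships only the inference-relevant parts of the LightningModule; the training_step is not included. Since automatic_optimization is disabled and c_lambda is stored, a manual-optimization loop of this kind would typically rely on the standard WGAN-GP gradient penalty, sketched below (an illustration of the general technique, not the author's training code):

import torch

def gradient_penalty(critic, real, fake, labels, device):
    """Standard WGAN-GP penalty on interpolates between real and fake images."""
    batch_size = real.size(0)
    eps = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolates = (eps * real + (1 - eps) * fake).requires_grad_(True)

    scores = critic(interpolates, labels)
    grads = torch.autograd.grad(
        outputs=scores, inputs=interpolates,
        grad_outputs=torch.ones_like(scores),
        create_graph=True, retain_graph=True,
    )[0]

    grads = grads.view(batch_size, -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

# The critic loss would then combine the Wasserstein estimate with the penalty:
# loss_c = -(critic(real, y).mean() - critic(fake, y).mean()) + c_lambda * gradient_penalty(...)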
requirements.txt
ADDED
@@ -0,0 +1,3 @@
torch==2.1.2
pytorch-lightning==2.3.3
python-multipart
utility/__init__.py
ADDED
File without changes
utility/helper.py
ADDED
@@ -0,0 +1,60 @@
import torch

from config.core import config
from models.base import Generator

def load_model_weights(checkpoint_path, model, device, prefix):
    """
    Load specific weights from a PyTorch Lightning checkpoint into a model.

    Parameters:
        checkpoint_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model instance to load weights into.
        device (torch.device): Device to map the checkpoint onto.
        prefix (str): The prefix in the checkpoint's state_dict keys to filter by and remove.

    Returns:
        model (torch.nn.Module): The model with loaded weights.
    """
    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Extract and modify the state_dict keys to match the model's keys
    model_weights = {k.replace(f"{prefix}.", ""): v for k, v in checkpoint["state_dict"].items() if k.startswith(f"{prefix}.")}

    # Load the weights into the model
    model.load_state_dict(model_weights)

    return model

def init_generator_model():
    """
    Initializes and returns the Generator model.

    Args:
        None.

    Returns:
        Generator: The initialized Generator model.
    """
    model = Generator(
        embed_size=config.EMBED_SIZE,
        num_classes=config.NUM_CLASSES,
        image_size=config.IMAGE_SIZE,
        features_generator=config.FEATURES_GENERATOR,
        input_dim=config.INPUT_Z_DIM,
        image_channel=config.IMAGE_CHANNEL
    )
    return model

def get_selected_value(label):
    """
    Get the selected value based on the display label.

    Args:
        label (str): The display label.

    Returns:
        int: The selected value corresponding to the display label.
    """
    # Get the selected value from the options mapping based on the display label.
    return config.OPTIONS_MAPPING[label]
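load_model_weights works because a Lightning checkpoint stores the whole module's state_dict with submodule prefixes such as "generator." and "critic."; filtering on the prefix and stripping it recovers a state_dict the bare Generator accepts. A toy illustration of that key filtering, with hypothetical keys:

state_dict = {
    "generator.embed.weight": "...",    # kept, prefix stripped
    "generator.gen.0.0.weight": "...",  # kept, prefix stripped
    "critic.disc.0.0.weight": "...",    # dropped (different prefix)
}

prefix = "generator"
filtered = {k.replace(f"{prefix}.", ""): v
            for k, v in state_dict.items() if k.startswith(f"{prefix}.")}

print(list(filtered))  # ['embed.weight', 'gen.0.0.weight']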
weights/epoch=999-step=96000.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:93409ee79fbe7ecfbcd95fe775a7625408e088624c0e80153e14c234c93d8132
size 116330608
weights/source.txt
ADDED
@@ -0,0 +1,2 @@
Using weights from Kaggle after training for 600 epochs:
- https://www.kaggle.com/datasets/dimensioncore/conditional-gan-part-2/versions/1020