Commit eb42124 · Muhammad Naufal Rizqullah committed
Parent(s): 2b42608
change architecture, model, etc.

Changed files:
- README.md +49 -13
- config/core.py +1 -1
- models/base.py +68 -149
- models/discriminator.py +64 -0
- models/generator.py +57 -0
- utility/helper.py +4 -1
- weights/source.txt +2 -2
README.md
CHANGED
@@ -1,13 +1,49 @@
# Conditional GAN Shoe, Sandal, Boot

## Project Overview

This project implements a Conditional Generative Adversarial Network (CGAN) using the Conditional WGAN-GP architecture to generate images of shoes, sandals, and boots.

### Key Features

* **Conditional WGAN-GP architecture**: generates images of shoes, sandals, and boots
* **Trained on a limited dataset**: due to VRAM constraints
* **Model architecture and hyperparameters optimized**: for a Kaggle environment with 15 GB of VRAM
* **Embedding used instead of one-hot encoding for training labels**: avoids representing labels with mostly-zero vectors
* **Implemented using the PyTorch Lightning framework**

## Training Details

### Training Epochs

* Approximately 1000 epochs

### Model Architecture Compromises

* **Reduced size of the latent space (z dim) and embedding**: due to VRAM limitations
* **Limited features for the generator and critic networks**: due to VRAM limitations
* **Image size limitations**: due to VRAM limitations
* **Dataset used**: [Shoe vs Sandal vs Boot Image Dataset (15K Images)](https://www.kaggle.com/datasets/hasibalmuzdadid/shoe-vs-sandal-vs-boot-dataset-15k-images)

## Results

Despite the architectural compromises, the model produces reasonable results. However, the quality of the generated images may not be optimal due to the limited dataset and VRAM constraints.

## Demo

The demo of this project is deployed on Hugging Face Spaces and uses the Gradio framework to provide a user-friendly interface for interacting with the model. You can try it out by visiting [this link](https://huggingface.co/spaces/SkylarWhite/57894).

## Future Work

* **Experiment with larger datasets and more complex model architectures**
* **Investigate alternative optimization techniques to improve model performance**
* **Explore other applications of Conditional GANs in computer vision**
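One of the Key Features above, using an embedding instead of one-hot encoding for the class labels, can be illustrated with a minimal sketch. This is not part of the commit; the 3-class setup and the 128-dimensional embedding mirror the defaults used in models/generator.py below.

```python
import torch
import torch.nn as nn

# Labels are looked up in a learned embedding table rather than expanded
# into sparse one-hot vectors.
num_classes, embed_size = 3, 128           # Boot / Sandal / Shoe, embed size as in models/generator.py
embed = nn.Embedding(num_classes, embed_size)

labels = torch.tensor([0, 1, 2])           # one class index per image in the batch
label_vectors = embed(labels)              # shape (3, 128), trained jointly with the GAN

# The one-hot alternative the README refers to: fixed, mostly-zero vectors.
one_hot = torch.nn.functional.one_hot(labels, num_classes).float()   # shape (3, 3)
```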
config/core.py
CHANGED
@@ -23,7 +23,7 @@ class Config(BaseSettings):
     LOAD_CHECKPOINT: bool = True
     PATH_DATASET: str = ""
-    CKPT_PATH: str = "./weights/epoch=
+    CKPT_PATH: str = "./weights/epoch=957-step=1164300.ckpt"

     OPTIONS_MAPPING: dict = {
         "Boot": 0,
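For orientation only, a sketch of how these settings might be consumed at inference time. This is an assumption for illustration rather than code from the repository; it uses only the `config.OPTIONS_MAPPING` and `config.CKPT_PATH` values shown above, imported the same way utility/helper.py imports the config.

```python
import torch
from config.core import config

choice = "Boot"                                          # e.g. the option selected in the Gradio demo
label = torch.tensor([config.OPTIONS_MAPPING[choice]])   # tensor([0])

checkpoint_path = config.CKPT_PATH                       # "./weights/epoch=957-step=1164300.ckpt"
```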
models/base.py
CHANGED
@@ -1,181 +1,100 @@
- (removed: the previous Discriminator and Generator class definitions and their docstrings; reworked versions of both models are added in models/discriminator.py and models/generator.py below)
import torch
import torch.nn as nn


class WSConv2d(nn.Module):
    """
    A 2D convolutional layer with weight scaling.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_size (int, optional): Size of the convolving kernel. Default is 3.
        stride (int, optional): Stride of the convolution. Default is 1.
        padding (int, optional): Zero-padding added to both sides of the input. Default is 1.
        gain (float, optional): Gain factor for weight scaling. Default is 2.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, gain=2):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.scale = (gain / (in_channels * (kernel_size ** 2))) ** 0.5
        self.bias = self.conv.bias
        self.conv.bias = None

        # Initialize Conv Layer
        nn.init.normal_(self.conv.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x):
        """
        Forward pass of the WSConv2d layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output tensor after applying convolution, weight scaling, and bias addition.
        """
        return self.conv(x * self.scale) + self.bias.view(1, self.bias.shape[0], 1, 1)


class PixelNorm(nn.Module):
    """
    Pixel normalization layer.

    Args:
        eps (float, optional): Small value to avoid division by zero. Default is 1e-8.
    """

    def __init__(self, eps=1e-8):
        super(PixelNorm, self).__init__()
        self.epsilon = eps

    def forward(self, x):
        """
        Forward pass of the PixelNorm layer.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, channels, height, width).

        Returns:
            torch.Tensor: Normalized tensor.
        """
        return x / torch.sqrt(torch.mean(x ** 2, dim=1, keepdim=True) + self.epsilon)


class ConvBlock(nn.Module):
    """
    A block of two convolutional layers, with optional pixel normalization and LeakyReLU activation.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        use_pixelnorm (bool, optional): Whether to apply pixel normalization. Default is True.
    """

    def __init__(self, in_channels, out_channels, use_pixelnorm=True):
        super(ConvBlock, self).__init__()
        self.use_pn = use_pixelnorm
        self.conv1 = WSConv2d(in_channels, out_channels)
        self.conv2 = WSConv2d(out_channels, out_channels)
        self.leaky = nn.LeakyReLU(0.2)
        self.pn = PixelNorm()

    def forward(self, x):
        """
        Forward pass of the ConvBlock.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output tensor after two convolutional layers, optional pixel normalization, and LeakyReLU activation.
        """
        x = self.leaky(self.conv1(x))
        x = self.pn(x) if self.use_pn else x

        x = self.leaky(self.conv2(x))
        x = self.pn(x) if self.use_pn else x
        return x
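WSConv2d rescales its input by sqrt(gain / (in_channels * kernel_size^2)) on every forward pass, the equalized-learning-rate trick used in progressive-GAN-style models, and PixelNorm normalizes each pixel across channels. A quick, illustrative shape check of the three building blocks above (not part of the commit):

```python
import torch
from models.base import WSConv2d, PixelNorm, ConvBlock

x = torch.randn(4, 16, 32, 32)       # (batch, channels, height, width)

print(WSConv2d(16, 32)(x).shape)     # torch.Size([4, 32, 32, 32]) - 3x3 conv, stride 1, padding 1
print(PixelNorm()(x).shape)          # torch.Size([4, 16, 32, 32]) - shape preserved, per-pixel channel norm
print(ConvBlock(16, 32)(x).shape)    # torch.Size([4, 32, 32, 32]) - two WSConv2d + LeakyReLU (+ PixelNorm)
```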
models/discriminator.py
ADDED
@@ -0,0 +1,64 @@
import torch
import torch.nn as nn

from .base import WSConv2d, ConvBlock


class Discriminator(nn.Module):
    def __init__(self, num_classes=3, image_size=128, features_discriminator=128, image_channel=3):
        super().__init__()

        self.num_classes = num_classes
        self.image_size = image_size
        label_channel = 1

        self.disc = nn.Sequential(
            self._block_discriminator(image_channel + label_channel, features_discriminator, kernel_size=4, stride=2,
                                      padding=1),
            self._block_discriminator(features_discriminator, features_discriminator, kernel_size=4, stride=2,
                                      padding=1),
            self._block_discriminator(features_discriminator, features_discriminator * 2, kernel_size=4, stride=2,
                                      padding=1),
            self._block_discriminator(features_discriminator * 2, features_discriminator * 4, kernel_size=4, stride=2,
                                      padding=1),
            self._block_discriminator(features_discriminator * 4, features_discriminator * 4, kernel_size=4, stride=2,
                                      padding=1),
            self._block_discriminator(features_discriminator * 4, 1, kernel_size=4, stride=1, padding=0,
                                      final_layer=True)
        )

        self.embed = nn.Embedding(num_classes, image_size * image_size)

    def forward(self, image, label):
        embedding = self.embed(label)
        embedding = embedding.view(
            label.shape[0],
            1,
            self.image_size,
            self.image_size
        )

        data = torch.cat([image, embedding], dim=1)

        x = self.disc(data)

        return x.view(len(x), -1)

    def _block_discriminator(
            self,
            input_channels,
            output_channels,
            kernel_size=3,
            stride=2,
            padding=0,
            final_layer=False
    ):
        if not final_layer:
            return nn.Sequential(
                ConvBlock(input_channels, output_channels),
                WSConv2d(output_channels, output_channels, kernel_size, stride, padding)
            )
        else:
            return WSConv2d(input_channels, output_channels, kernel_size, stride, padding)
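The README describes this critic as part of a Conditional WGAN-GP objective, but the training loop itself is not included in this commit. The following gradient-penalty sketch is therefore an assumption about how the conditional critic could be used, not the project's actual implementation:

```python
import torch
from models.discriminator import Discriminator

def gradient_penalty(critic, real, fake, labels, device="cpu"):
    """WGAN-GP penalty computed on interpolates between real and fake images."""
    batch_size = real.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)   # per-sample mixing factor
    # Build the interpolates as a fresh leaf tensor so autograd can return d(score)/d(interpolated)
    interpolated = (real.detach() * alpha + fake.detach() * (1 - alpha)).requires_grad_(True)

    mixed_scores = critic(interpolated, labels)               # critic is conditioned on the labels

    gradients = torch.autograd.grad(
        outputs=mixed_scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(mixed_scores),
        create_graph=True,
        retain_graph=True,
    )[0].view(batch_size, -1)

    # Penalize deviation of the per-sample gradient norm from 1
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()
```

In a WGAN-GP setup this term is typically added to the critic loss with a weight of around 10.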
models/generator.py
ADDED
@@ -0,0 +1,57 @@
import torch
import torch.nn as nn

from .base import WSConv2d, ConvBlock, PixelNorm


class Generator(nn.Module):
    def __init__(self, embed_size=128, num_classes=3, image_size=128, features_generator=128, input_dim=128, image_channel=3):
        super().__init__()

        self.gen = nn.Sequential(
            self._block(input_dim + embed_size, features_generator * 2, first_double_up=True),
            self._block(features_generator * 2, features_generator * 4, first_double_up=False, final_layer=False),
            self._block(features_generator * 4, features_generator * 4, first_double_up=False, final_layer=False),
            self._block(features_generator * 4, features_generator * 4, first_double_up=False, final_layer=False),
            self._block(features_generator * 4, features_generator * 2, first_double_up=False, final_layer=False),
            self._block(features_generator * 2, features_generator, first_double_up=False, final_layer=False),
            self._block(features_generator, image_channel, first_double_up=False, use_double=False, final_layer=True),
        )

        self.image_size = image_size
        self.embed_size = embed_size

        self.embed = nn.Embedding(num_classes, embed_size)

    def forward(self, noise, labels):
        # Add height and width dimensions to the label embedding: (N, embed_size, 1, 1)
        embedding_label = self.embed(labels).unsqueeze(2).unsqueeze(3)

        # Noise may arrive flattened; reshape to (batch_size, z_dim, 1, 1)
        noise = noise.view(noise.size(0), noise.size(1), 1, 1)

        x = torch.cat([noise, embedding_label], dim=1)

        return self.gen(x)

    def _block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1,
               first_double_up=False, use_double=True, final_layer=False):
        layers = []

        if not final_layer:
            layers.append(ConvBlock(in_channels, out_channels))
        else:
            layers.append(WSConv2d(in_channels, out_channels, kernel_size, stride, padding))
            layers.append(nn.Tanh())

        if use_double:
            if first_double_up:
                layers.append(nn.ConvTranspose2d(out_channels, out_channels, 4, 1, 0))
            else:
                layers.append(nn.ConvTranspose2d(out_channels, out_channels, 4, 2, 1))

            layers.append(PixelNorm())
            layers.append(nn.LeakyReLU(0.2))

        return nn.Sequential(*layers)
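A minimal sampling sketch using the defaults above (z_dim = embed_size = 128, image_size = 128); illustrative only, not part of the commit:

```python
import torch
from models.generator import Generator

gen = Generator()                        # defaults: 3 classes, 128x128 RGB output
gen.eval()

noise = torch.randn(4, 128)              # (batch_size, z_dim)
labels = torch.tensor([0, 1, 2, 0])      # class indices, e.g. Boot / Sandal / Shoe / Boot
with torch.no_grad():
    images = gen(noise, labels)          # Tanh output in [-1, 1]

print(images.shape)                      # torch.Size([4, 3, 128, 128])
```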
utility/helper.py
CHANGED
@@ -1,7 +1,7 @@
 import torch

 from config.core import config
-from models.
+from models.generator import Generator

 def load_model_weights(checkpoint_path, model, device, prefix):
     """
@@ -26,6 +26,9 @@ def load_model_weights(checkpoint_path, model, device, prefix):

     return model

+def load_latent_space(checkpoint_path):
+    pass
+
 def init_generator_model():
     """
     Initializes and returns the Generator model.
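A sketch of how these helpers might be combined to restore the generator from the checkpoint configured in config/core.py. The "generator." prefix is an assumption about how the PyTorch Lightning checkpoint names the generator's parameters; the rest follows the signatures shown above.

```python
import torch
from config.core import config
from utility.helper import init_generator_model, load_model_weights

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = init_generator_model()
# Hypothetical prefix: select the "generator." keys from the Lightning checkpoint
generator = load_model_weights(config.CKPT_PATH, generator, device, prefix="generator.")
generator.eval()
```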
weights/source.txt
CHANGED
@@ -1,2 +1,2 @@
-using a weight from kaggle after training
-- https://www.kaggle.com/datasets/dimensioncore/conditional-gan-part-2/versions/
+using weights from Kaggle after training for 957 epochs:
+- https://www.kaggle.com/datasets/dimensioncore/conditional-gan-part-2/versions/2397