Muhammad Naufal Rizqullah committed
Commit ae0af75 · 1 Parent(s): b2c027e

first commit

.gitignore ADDED
@@ -0,0 +1,30 @@
+ .idea
+ .ipynb_checkpoints
+ .mypy_cache
+ .vscode
+ __pycache__
+ .pytest_cache
+ htmlcov
+ dist
+ site
+ .coverage
+ coverage.xml
+ .netlify
+ test.db
+ log.txt
+ Pipfile.lock
+ env3.*
+ env
+ docs_build
+ site_build
+ venv
+ docs.zip
+ archive.zip
+
+ # vim temporary files
+ *~
+ .*.sw?
+ .cache
+
+ # macOS
+ .DS_Store
CHANGELOG.md ADDED
@@ -0,0 +1,30 @@
+ # Changelog
+
+ All notable changes to this project will be documented in this file.
+
+ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
+ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
+
+ ## [Unreleased] - 2024-10-03
+ ### Plan
+ Make the core modular, so that training on Kaggle only requires pulling the repo, setting a dataset config (Kaggle paths), and running training / inference from a script with arguments.
+
+ ## [Unreleased] - 2024-10-18
+ Built up the app so it can run a demo. After that, testing showed the model output looks poor, so it may be back to training again.
+
+ Strangely, the training results look fine when inspecting the generated images directly (visualization), but the interface performs badly.
+
+ ## [0.0.1] - 2024-10-19
+ After reviewing the dataset, the earlier problem turned out to be that training used 1024x1024 close-up face images, so an input that contains a face plus part of the body confuses the model. Inputs similar to the provided examples are required for good results.
+ ### Feature
+ - Turn a close-up face image into a comic style.
+
+ ### Changed
+ - The examples were replaced so users get some inspiration for the input.
+ - Two outputs: first the original image after the transform, second the image returned by the model.
+
+ ### Removed
+ - Old examples.
+
+ ### Fixed
+ - Resize did not produce 256x256 because the size was not given as a tuple; with a single value only one edge is resized (the aspect ratio is kept).
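For context on the Fixed entry above, a minimal sketch (not part of the commit) of the torchvision `Resize` behavior it refers to; the input size is illustrative:

```python
from PIL import Image
import torchvision.transforms as T

img = Image.new("RGB", (1024, 768))   # illustrative input size (width, height)

# A single int resizes only the shorter edge and keeps the aspect ratio,
# so the result is not square:
print(T.Resize(256)(img).size)        # (341, 256)

# A (height, width) tuple forces the exact shape the model expects:
print(T.Resize((256, 256))(img).size) # (256, 256)
```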
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Muhammad Naufal Rizqullah
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,77 @@
+ import torch
+ import numpy as np
+ import gradio as gr
+ import os
+
+ from PIL import Image
+ import torchvision.transforms as T
+ from config.core import config
+ from utility.helper import load_model_weights, init_generator_model
+
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+ model = init_generator_model()
+ model = load_model_weights(
+     model=model,
+     checkpoint_path=config.CKPT_PATH,
+     device=device,
+     prefix="gen",
+ )
+
+ # Transformation
+ transform_face = T.Compose([
+     T.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
+     T.ToTensor(),
+     T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ])
+
+ def inference(image: Image):
+     # Transform the input image and add a batch dimension
+     img = transform_face(image)
+     img_un = img.unsqueeze(0)
+
+     image_transform = img_un * 0.5 + 0.5  # Undo the Tanh-range normalization
+     im_detach = image_transform.detach().cpu().squeeze(0)
+     im_permute = im_detach.permute(1, 2, 0)
+     im_array = im_permute.numpy()
+
+     # Scale values to 0-255 range
+     im_array = (im_array * 255).astype(np.uint8)
+
+     # Convert numpy array to PIL Image
+     im_pil = Image.fromarray(im_array)
+
+     # Run the model on the transformed image
+     model.eval()
+     with torch.inference_mode():
+         c2f = model(img_un)
+
+     c2f = c2f * 0.5 + 0.5  # Undo the Tanh-range normalization
+     image_unflat = c2f.detach().cpu().squeeze(0)  # Remove batch dimension
+     image = image_unflat.permute(1, 2, 0)  # Permute to (H, W, C)
+
+     # Convert image to numpy array
+     image_array = image.numpy()
+
+     # Scale values to 0-255 range
+     image_array = (image_array * 255).astype(np.uint8)
+
+     # Convert numpy array to PIL Image
+     image = Image.fromarray(image_array)
+
+     return im_pil, image
+
+ demo = gr.Interface(
+     fn=inference,
+     inputs=gr.Image(type="pil"),
+     outputs=[
+         gr.Image(label="Original after Transform"),
+         gr.Image(label="Converted by Model")
+     ],
+     title="Pix2Pix Face to Comic",
+     description="An implementation of Pix2Pix from scratch in PyTorch",
+     examples=[f"data/examples/{i}" for i in os.listdir("data/examples") if i.endswith(('.png', '.jpg', '.jpeg', '.gif'))]
+ )
+
+ demo.launch()
+
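A possible follow-up, not part of this commit: `app.py` selects a `device` but never moves the model or the input batch onto it, so inference always runs on CPU. A hedged sketch of what that move could look like, reusing the names from `app.py` above:

```python
import torch

# Sketch only: assumes `model`, `device`, and a batched tensor `img_un` as defined in app.py.
model = model.to(device)  # place the generator on the GPU when one is available

def run_generator(img_un: torch.Tensor) -> torch.Tensor:
    model.eval()
    with torch.inference_mode():
        c2f = model(img_un.to(device))  # keep input and weights on the same device
    return c2f.cpu()  # move back to CPU for the numpy/PIL post-processing
```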
config/__init__.py ADDED
File without changes
config/core.py ADDED
@@ -0,0 +1,27 @@
+ import torch
+ from pydantic_settings import BaseSettings
+
+ class Config(BaseSettings):
+     PATH_FACE: str = "/kaggle/input/comic-faces-paired-synthetic-v2/face2comics_v2.0.0_by_Sxela/face2comics_v2.0.0_by_Sxela/faces"
+     PATH_COMIC: str = "/kaggle/input/comic-faces-paired-synthetic-v2/face2comics_v2.0.0_by_Sxela/face2comics_v2.0.0_by_Sxela/comics"
+     PATH_OUTPUT: str = "/kaggle/working/generates"
+
+     IMAGE_CHANNELS: int = 3
+
+     FEATURE_DISCRIMINATOR: list = [64, 128, 256, 512]
+     FEATURE_GENERATOR: int = 64
+
+     IMAGE_SIZE: int = 256
+     BATCH_SIZE: int = 128
+     DISPLAY_STEP: int = 500
+     MAX_SAMPLES: int = 5000
+
+     LEARNING_RATE: float = 2e-4
+     L1_LAMBDA: int = 100
+     NUM_EPOCH: int = 500
+
+     LOAD_CHECKPOINT: bool = False
+     CKPT_PATH: str = "weights/epoch=266-step=42186.ckpt"  # forward slash so the path also resolves on Linux
+
+
+ config = Config()
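Because `Config` subclasses pydantic-settings' `BaseSettings`, any field can also be overridden through environment variables instead of editing this file. A small sketch; the values are only examples:

```python
import os

# pydantic-settings matches environment variables to field names (case-insensitive),
# so exporting these before the import changes the resulting config object.
os.environ["BATCH_SIZE"] = "16"
os.environ["MAX_SAMPLES"] = "1000"

from config.core import config
print(config.BATCH_SIZE, config.MAX_SAMPLES)  # -> 16 1000
```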
data/dataloader.py ADDED
@@ -0,0 +1,68 @@
+ import lightning as L
+ import torchvision.transforms as T
+ import os
+
+ from torch.utils.data import DataLoader, Subset
+ from data.dataset import FaceToComicDataset
+
+ class FaceToComicDataModule(L.LightningDataModule):
+     def __init__(
+         self,
+         face_path,
+         comic_path,
+         image_size=(128, 128),
+         batch_size=32,
+         max_samples=None
+     ):
+         super().__init__()
+
+         self.face_dir = face_path
+         self.comic_dir = comic_path
+         self.image_size = image_size
+         self.batch_size = batch_size
+         self.max_samples = max_samples
+
+         self.transform_face = T.Compose([
+             T.Resize(self.image_size),
+             T.ToTensor(),
+             T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+         ])
+
+         self.transform_comic = T.Compose([
+             T.Resize(self.image_size),
+             T.ToTensor(),
+             T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+         ])
+
+         self.face2comic = None
+
+     def prepare_data(self):
+         # No need to download or prepare data, as it's already present in the directories
+         pass
+
+     def setup(self, stage=None):
+         if stage == "fit" or stage is None:
+             dataset = FaceToComicDataset(
+                 face_path=self.face_dir,
+                 comic_path=self.comic_dir,
+                 transform_face=self.transform_face,
+                 transform_comic=self.transform_comic
+             )
+
+             # To Limit Dataset
+             if self.max_samples:
+                 print(f"[INFO] Dataset is Limited to {self.max_samples} Samples")
+                 self.face2comic = Subset(dataset, range(min(len(dataset), self.max_samples)))
+             else:
+                 self.face2comic = dataset
+
+     def train_dataloader(self):
+         return DataLoader(self.face2comic, batch_size=self.batch_size, num_workers=os.cpu_count(), shuffle=True)
+
+     def val_dataloader(self):
+         # Implement if you need validation during training
+         pass
+
+     def test_dataloader(self):
+         # Implement if you need testing after training
+         pass
data/dataset.py ADDED
@@ -0,0 +1,39 @@
+ import os
+ from PIL import Image
+
+ from torch.utils.data import Dataset
+
+ class FaceToComicDataset(Dataset):
+     def __init__(self, face_path, comic_path, transform_face=None, transform_comic=None):
+         super().__init__()
+         self.face_dir = face_path
+         self.comic_dir = comic_path
+
+         self.face_list_files = os.listdir(self.face_dir)
+         self.comic_list_files = os.listdir(self.comic_dir)
+
+         # Create a dictionary for quick lookup of comic files
+         self.comic_dict = {comic_file: idx for idx, comic_file in enumerate(self.comic_list_files)}
+
+         # Filter out files that don't have a corresponding pair (keep only files that have a pair)
+         self.face_list_files = [f for f in self.face_list_files if f in self.comic_list_files]
+
+         self.transform_face = transform_face
+         self.transform_comic = transform_comic
+
+     def __getitem__(self, index):
+         face_file = self.face_list_files[index]
+         comic_file = self.comic_list_files[self.comic_dict[face_file]]
+
+         face_image = Image.open(os.path.join(self.face_dir, face_file))
+         comic_image = Image.open(os.path.join(self.comic_dir, comic_file))
+
+         if self.transform_face:
+             face_image = self.transform_face(face_image)
+         if self.transform_comic:
+             comic_image = self.transform_comic(comic_image)
+
+         return face_image, comic_image
+
+     def __len__(self):
+         return len(self.face_list_files)
data/examples/100.jpg ADDED
data/examples/1001.jpg ADDED
data/examples/1020.jpg ADDED
data/examples/1021.jpg ADDED
models/__init__.py ADDED
File without changes
models/base.py ADDED
@@ -0,0 +1,35 @@
+ import torch.nn as nn
+
+
+ class Block(nn.Module):
+     def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
+         super().__init__()
+
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
+             if down
+             else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2),
+         )
+
+         self.use_dropout = use_dropout
+         self.dropout = nn.Dropout(0.5)
+         self.down = down
+
+     def forward(self, x):
+         x = self.conv(x)
+         return self.dropout(x) if self.use_dropout else x  # apply dropout only when enabled
+
+ class BlockCNN(nn.Module):
+     def __init__(self, in_channels, out_channels, stride=2):
+         super().__init__()
+
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, 4, stride, bias=False, padding_mode="reflect"),
+             nn.BatchNorm2d(out_channels),
+             nn.LeakyReLU(0.2),
+         )
+
+     def forward(self, x):
+         return self.conv(x)
models/discriminator.py ADDED
@@ -0,0 +1,48 @@
+ import torch
+ import torch.nn as nn
+
+ from models.base import BlockCNN
+
+
+ class Discriminator(nn.Module):
+     def __init__(self, in_channels=3, features=[64, 128, 256, 512], kernel_size=4, activation_slope=0.2):
+         super().__init__()
+
+         self.initial = nn.Sequential(
+             nn.Conv2d(
+                 in_channels * 2,
+                 features[0],
+                 kernel_size,
+                 stride=2,
+                 padding=1,
+                 padding_mode="reflect",
+             ),
+             nn.LeakyReLU(activation_slope),
+         )
+
+         layers = []
+         in_channels = features[0]
+         for feature in features[1:]:
+             layers.append(
+                 BlockCNN(in_channels, feature, stride=1 if feature == features[-1] else 2)
+             )
+             in_channels = feature
+
+         layers.append(
+             nn.Conv2d(
+                 in_channels, 1, kernel_size=kernel_size, stride=1, padding=1, padding_mode="reflect"
+             )
+         )
+
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x, y):
+         x = torch.cat([x, y], dim=1)
+         x = self.initial(x)
+         return self.model(x)
+
+ def test():
+     # Test Case for Discriminator Model
+     x = torch.randn((1, 3, 256, 256))
+     disc = Discriminator()
+     print(f"Discriminator Output Shape: {disc(x, x).shape}")
models/generator.py ADDED
@@ -0,0 +1,68 @@
+ import torch
+ import torch.nn as nn
+
+ from models.base import Block
+
+
+ class Generator(nn.Module):
+     def __init__(self, in_channels=3, features=64):
+         super().__init__()
+
+         self.initial_down = nn.Sequential(
+             nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
+             nn.LeakyReLU(0.2),
+         )
+
+         self.down1 = Block(features, features * 2, down=True, act="leaky", use_dropout=False)
+         self.down2 = Block(features * 2, features * 4, down=True, act="leaky", use_dropout=False)
+         self.down3 = Block(features * 4, features * 8, down=True, act="leaky", use_dropout=False)
+         self.down4 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
+         self.down5 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
+         self.down6 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
+
+         self.bottleneck = nn.Sequential(
+             nn.Conv2d(features * 8, features * 8, 4, 2, 1),
+             nn.ReLU()
+         )
+
+         self.up1 = Block(features * 8, features * 8, down=False, act="relu", use_dropout=True)
+         self.up2 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True)
+         self.up3 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True)
+         self.up4 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=False)
+         self.up5 = Block(features * 8 * 2, features * 4, down=False, act="relu", use_dropout=False)
+         self.up6 = Block(features * 4 * 2, features * 2, down=False, act="relu", use_dropout=False)
+         self.up7 = Block(features * 2 * 2, features, down=False, act="relu", use_dropout=False)
+
+         self.final_up = nn.Sequential(
+             nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
+             nn.Tanh(),
+         )
+
+     def forward(self, x):
+         d1 = self.initial_down(x)
+         d2 = self.down1(d1)
+         d3 = self.down2(d2)
+         d4 = self.down3(d3)
+         d5 = self.down4(d4)
+         d6 = self.down5(d5)
+         d7 = self.down6(d6)
+
+         bottleneck = self.bottleneck(d7)
+
+         up1 = self.up1(bottleneck)
+         up2 = self.up2(torch.cat([up1, d7], 1))
+         up3 = self.up3(torch.cat([up2, d6], 1))
+         up4 = self.up4(torch.cat([up3, d5], 1))
+         up5 = self.up5(torch.cat([up4, d4], 1))
+         up6 = self.up6(torch.cat([up5, d3], 1))
+         up7 = self.up7(torch.cat([up6, d2], 1))
+
+         final_up = self.final_up(torch.cat([up7, d1], 1))
+
+         return final_up
+
+ def test():
+     # Test Case for Generator Model
+     x = torch.randn((1, 3, 256, 256))
+     gen = Generator()
+     print(f"Generator Output Shape: {gen(x).shape}")
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ torch
+ pytorch-lightning
+ python-multipart
+ fastapi
+ pydantic
+ pydantic-settings
+ opencv-python==4.10.0
+ imageio==2.33.1
train.py ADDED
@@ -0,0 +1,82 @@
+ import argparse
+ import lightning as L
+
+ from config.core import config
+ from training.model import Pix2Pix
+ from training.callbacks import MyCustomSavingCallback
+ from data.dataloader import FaceToComicDataModule
+
+
+ # Add argparser for config params
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--load_checkpoint", action='store_true', help="Load checkpoint if this flag is set. If not set, start training from scratch.")
+ parser.add_argument("--no_load_checkpoint", action='store_false', dest='load_checkpoint', help="Do not load checkpoint. If set, start training from scratch.")
+
+ parser.add_argument("--ckpt_path", type=str, default=config.CKPT_PATH, help="Path to checkpoint file. If load_checkpoint is set, this path will be used to load the checkpoint.")
+ parser.add_argument("--learning_rate", type=float, default=config.LEARNING_RATE, help="Learning rate for Adam optimizer.")
+ parser.add_argument("--l1_lambda", type=int, default=config.L1_LAMBDA, help="Scale factor for L1 loss.")
+ parser.add_argument("--features_discriminator", type=int, nargs='+', default=config.FEATURE_DISCRIMINATOR, help="List of feature sizes for the discriminator network.")
+ parser.add_argument("--features_generator", type=int, default=config.FEATURE_GENERATOR, help="Feature size for the generator network.")
+ parser.add_argument("--display_step", type=int, default=config.DISPLAY_STEP, help="Interval of epochs to display loss and save examples.")
+ parser.add_argument("--num_epoch", type=int, default=config.NUM_EPOCH, help="Number of epochs to train for.")
+ parser.add_argument("--path_face", type=str, default=config.PATH_FACE, help="Path to folder containing face images.")
+ parser.add_argument("--path_comic", type=str, default=config.PATH_COMIC, help="Path to folder containing comic images.")
+ parser.add_argument("--image_size", type=int, default=config.IMAGE_SIZE, help="Size of input images.")
+ parser.add_argument("--batch_size", type=int, default=config.BATCH_SIZE, help="Batch size for training.")
+ parser.add_argument("--max_samples", type=int, default=config.MAX_SAMPLES, help="Maximum number of samples to use for training. If set to None, all samples will be used.")
+
+ args = parser.parse_args()
+
+ config.LOAD_CHECKPOINT = args.load_checkpoint if args.load_checkpoint is not None else config.LOAD_CHECKPOINT
+ config.CKPT_PATH = args.ckpt_path
+ config.LEARNING_RATE = args.learning_rate
+ config.L1_LAMBDA = args.l1_lambda
+ config.FEATURE_DISCRIMINATOR = args.features_discriminator
+ config.FEATURE_GENERATOR = args.features_generator
+ config.DISPLAY_STEP = args.display_step
+ config.NUM_EPOCH = args.num_epoch
+ config.PATH_FACE = args.path_face
+ config.PATH_COMIC = args.path_comic
+ config.IMAGE_SIZE = args.image_size
+ config.BATCH_SIZE = args.batch_size
+ config.MAX_SAMPLES = args.max_samples
+
+ # Initialize the Lightning model
+ model = Pix2Pix(
+     in_channels=3,
+     learning_rate=config.LEARNING_RATE,
+     l1_lambda=config.L1_LAMBDA,
+     features_discriminator=config.FEATURE_DISCRIMINATOR,
+     features_generator=config.FEATURE_GENERATOR,
+     display_step=config.DISPLAY_STEP,
+ )
+
+ # Setup Trainer
+ n_log = None
+
+ trainer = L.Trainer(
+     accelerator="auto",
+     devices="auto",
+     strategy="auto",
+     log_every_n_steps=n_log,
+     max_epochs=config.NUM_EPOCH,
+     callbacks=[MyCustomSavingCallback()],
+     default_root_dir="/kaggle/working/",
+     precision="16-mixed",
+     # fast_dev_run=True
+ )
+
+ # Lightning DataModule
+ dm = FaceToComicDataModule(
+     face_path=config.PATH_FACE,
+     comic_path=config.PATH_COMIC,
+     image_size=(config.IMAGE_SIZE, config.IMAGE_SIZE),
+     batch_size=config.BATCH_SIZE,
+     max_samples=None
+ )
+
+ # Training
+ if config.LOAD_CHECKPOINT:
+     trainer.fit(model, datamodule=dm, ckpt_path=config.CKPT_PATH)
+ else:
+     trainer.fit(model, datamodule=dm)
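As a usage sketch, the flags defined in `train.py` could be combined like this (dataset paths fall back to the defaults in `config/core.py`, so the values shown are only examples):

```
python train.py --batch_size 16 --num_epoch 100 --max_samples 5000 --no_load_checkpoint
```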
training/__init__.py ADDED
File without changes
training/callbacks.py ADDED
@@ -0,0 +1,10 @@
+ from lightning.pytorch.callbacks import Callback
+ from utility.helper import update_version_kaggle_dataset
+
+ class MyCustomSavingCallback(Callback):
+     def __init__(self):
+         super().__init__()
+
+     def on_save_checkpoint(self, trainer, pl_module, checkpoint):
+         super().on_save_checkpoint(trainer, pl_module, checkpoint)
+         update_version_kaggle_dataset()
training/model.py ADDED
@@ -0,0 +1,102 @@
+ import torch
+ import torch.nn as nn
+ import lightning as L
+ import torch.optim as optim
+
+ from models.generator import Generator
+ from models.discriminator import Discriminator
+ from utility.helper import save_some_examples
+
+
+ class Pix2Pix(L.LightningModule):
+     def __init__(self, in_channels, learning_rate, l1_lambda, features_generator, features_discriminator, display_step):
+         super().__init__()
+
+         self.automatic_optimization = False
+
+         self.gen = Generator(
+             in_channels=in_channels,
+             features=features_generator
+         )
+         self.disc = Discriminator(
+             in_channels=in_channels,
+             features=features_discriminator
+         )
+
+         self.loss_fn = nn.BCEWithLogitsLoss()
+
+         self.discriminator_losses = []
+         self.generator_losses = []
+         self.curr_step = 0
+
+         self.bce = nn.BCEWithLogitsLoss()
+         self.l1_loss = nn.L1Loss()
+
+         self.save_hyperparameters()
+
+
+     def configure_optimizers(self):
+         optimizer_G = optim.Adam(self.gen.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.999))
+         optimizer_D = optim.Adam(self.disc.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.999))
+
+         return optimizer_G, optimizer_D
+
+     def on_load_checkpoint(self, checkpoint):
+         # List of keys that you expect to load from the checkpoint
+         keys_to_load = ['discriminator_losses', 'generator_losses', 'curr_step']
+
+         # Iterate over the keys and load them if they exist in the checkpoint
+         for key in keys_to_load:
+             if key in checkpoint:
+                 setattr(self, key, checkpoint[key])
+
+     def on_save_checkpoint(self, checkpoint):
+         # Save the current state of the model
+         checkpoint['discriminator_losses'] = self.discriminator_losses
+         checkpoint['generator_losses'] = self.generator_losses
+         checkpoint['curr_step'] = self.curr_step
+
+     def training_step(self, batch, batch_idx):
+         # Get the Optimizers
+         opt_generator, opt_discriminator = self.optimizers()
+
+         X, y = batch
+
+         # Train Discriminator
+         y_fake = self.gen(X)
+         D_real = self.disc(X, y)
+         D_fake = self.disc(X, y_fake.detach())
+
+         D_real_loss = self.loss_fn(D_real, torch.ones_like(D_real))
+         D_fake_loss = self.loss_fn(D_fake, torch.zeros_like(D_fake))
+         D_loss = (D_real_loss + D_fake_loss) / 2
+
+         opt_discriminator.zero_grad()
+         self.manual_backward(D_loss)
+         opt_discriminator.step()
+
+         self.log("D_loss", D_loss.item(), on_step=False, on_epoch=True, prog_bar=True)
+         self.discriminator_losses.append(D_loss.item())
+
+         # Train Generator
+         D_fake = self.disc(X, y_fake)
+         G_fake_loss = self.bce(D_fake, torch.ones_like(D_fake))
+
+         L1 = self.l1_loss(y_fake, y) * self.hparams.l1_lambda
+         G_loss = G_fake_loss + L1
+
+         opt_generator.zero_grad()
+         self.manual_backward(G_loss)
+         opt_generator.step()
+
+         self.log("G_loss", G_loss.item(), on_step=False, on_epoch=True, prog_bar=True)
+         self.generator_losses.append(G_loss.item())
+
+         self.log("Current_Step", self.curr_step, on_step=False, on_epoch=True, prog_bar=True)
+
+         # Visualize
+         if self.curr_step % self.hparams.display_step == 0 and self.curr_step > 0:
+             save_some_examples(self.gen, batch, self.current_epoch)
+
+         self.curr_step += 1
+
utility/__init__.py ADDED
File without changes
utility/helper.py ADDED
@@ -0,0 +1,208 @@
+ import torch
+ import torch.nn as nn
+ import cv2
+ import imageio
+ import os
+ import subprocess
+
+ from config.core import config
+ from models.generator import Generator
+ from torchvision.utils import save_image
+
+
+ def save_some_examples(generator_model, batch, epoch, folder_path=config.PATH_OUTPUT, num_images=15):
+     """
+     Save some examples of the generator's output.
+
+     Parameters:
+         generator_model (nn.Module): The generator model.
+         batch (tuple): The batch of input and target images as a tuple of tensors.
+         epoch (int): The current epoch.
+         folder_path (str): The folder path to save the examples to. Defaults to config.PATH_OUTPUT.
+         num_images (int): The number of images to save. Defaults to 15.
+     """
+
+     # Ensure the folder exists
+     os.makedirs(folder_path, exist_ok=True)
+
+     x, y = batch  # Unpack the batch
+
+     # Limit the number of images to the specified num_images
+     x = x[:num_images]
+     y = y[:num_images]
+
+     generator_model.eval()
+
+     with torch.inference_mode():
+         y_fake = generator_model(x)
+         y_fake = y_fake * 0.5 + 0.5  # Remove normalization by tanh
+
+         # Create 3x5 grid for generated images
+         save_image(y_fake, folder_path + f"/y_gen_{epoch}.png", nrow=5)  # Save Generated Image
+
+         # Create 3x5 grid for input images
+         save_image(x * 0.5 + 0.5, folder_path + f"/input_{epoch}.png", nrow=5)  # Save Real Image
+
+     generator_model.train()
+
+ def update_version_kaggle_dataset():
+     # Make Metadata json
+     subprocess.run(['kaggle', 'datasets', 'init'], check=True)
+
+     # Write new metadata
+     with open('/kaggle/working/dataset-metadata.json', 'w') as json_fid:
+         json_fid.write(f'{{\n  "title": "Update Logs Pix2Pix",\n  "id": "muhammadnaufal/pix2pix",\n  "licenses": [{{"name": "CC0-1.0"}}]}}')
+
+     # Push new version
+     subprocess.run(['kaggle', 'datasets', 'version', '-m', 'Updated Dataset', '--quiet', '--dir-mode', 'tar'], check=True)
+
+
+ def init_generator_model():
+     """
+     Initializes and returns the Generator model.
+
+     Args:
+         None.
+
+     Returns:
+         Generator: The initialized Generator model.
+     """
+     model = Generator(
+         in_channels=config.IMAGE_CHANNELS,
+         features=config.FEATURE_GENERATOR,
+     )
+
+     return model
+
+
+ def load_model_weights(checkpoint_path, model, device, prefix):
+     """
+     Load specific weights from a PyTorch Lightning checkpoint into a model.
+
+     Parameters:
+         checkpoint_path (str): Path to the checkpoint file.
+         model (torch.nn.Module): The model instance to load weights into.
+         prefix (str): The prefix in the checkpoint's state_dict keys to filter by and remove.
+
+     Returns:
+         model (torch.nn.Module): The model with loaded weights.
+     """
+     # Load the checkpoint
+     checkpoint = torch.load(checkpoint_path, map_location=device)
+
+     # Extract and modify the state_dict keys to match the model's keys
+     model_weights = {k.replace(f"{prefix}.", ""): v for k, v in checkpoint["state_dict"].items() if k.startswith(f"{prefix}.")}
+
+     # Load the weights into the model
+     model.load_state_dict(model_weights)
+
+     return model
+
+
+ def initialize_weights(model):
+     """
+     Initializes the weights of a model using a normal distribution.
+
+     Args:
+         model: The model to be initialized.
+
+     Returns:
+         None
+     """
+
+     for m in model.modules():
+         if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.InstanceNorm2d)):
+             nn.init.normal_(m.weight.data, 0.0, 0.02)
+
+
+ def create_video(image_folder, video_name, fps, appearance_duration=None):
+     """
+     Creates a video from a sequence of images with customizable appearance duration.
+
+     Args:
+         image_folder (str): The path to the folder containing the images.
+         video_name (str): The name of the output video file.
+         fps (int): The frames per second of the video.
+         appearance_duration (int, optional): The desired appearance duration for each image in milliseconds.
+             If None, the default duration based on frame rate is used.
+
+     Example:
+         image_folder = '/path/to/image/folder'
+         video_name = 'output_video.mp4'
+         fps = 12
+         appearance_duration = 200  # Appearance duration of 200ms for each image
+
+         create_video(image_folder, video_name, fps, appearance_duration)
+     """
+
+     # Get a list of all image files in the folder
+     image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
+
+     # Sort the image files based on the step number
+     image_files = sorted(image_files, key=lambda x: int(x.split('-')[1].split('.')[0]))
+
+     # Load the first image to get the video size
+     image = cv2.imread(os.path.join(image_folder, image_files[0]))
+     height, width, layers = image.shape
+
+     # Create a VideoWriter object
+     fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Specify the video codec
+     video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
+
+     # Write each image to the video with customizable appearance duration
+     for image_file in image_files:
+         image = cv2.imread(os.path.join(image_folder, image_file))
+         video.write(image)
+
+         if appearance_duration is not None:
+             # Calculate the number of frames for the desired appearance duration
+             num_frames = appearance_duration * fps // 1000
+             for _ in range(num_frames):
+                 video.write(image)
+
+     # Release the video writer
+     video.release()
+
+ def create_gif(image_folder, gif_name, fps, appearance_duration=None):
+     """
+     Creates a GIF from a sequence of images sorted by step number, with customizable appearance duration.
+
+     Args:
+         image_folder (str): The path to the folder containing the images.
+         gif_name (str): The name of the output GIF file.
+         fps (int): The frames per second of the GIF.
+         appearance_duration (int, optional): The desired appearance duration for each image in milliseconds.
+             If None, the default duration based on frame rate is used.
+
+     Example:
+         image_folder = '/path/to/image/folder'
+         gif_name = 'output_animation.gif'
+         fps = 12
+         appearance_duration = 300  # Appearance duration of 300ms for each image
+
+         create_gif(image_folder, gif_name, fps, appearance_duration)
+     """
+
+     # Get a list of all image files in the folder
+     image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
+
+     # Sort the image files based on the step number
+     image_files = sorted(image_files, key=lambda x: int(x.split('-')[1].split('.')[0]))
+
+     # Load the images into a list
+     images = []
+     for file in image_files:
+         images.append(imageio.imread(os.path.join(image_folder, file)))
+
+     # Create a list to store the repeated images
+     repeated_images = []
+
+     # Repeat each image for the desired duration
+     if appearance_duration is not None:
+         for image in images:
+             repeated_images.extend([image] * (appearance_duration * fps // 1000))
+     else:
+         repeated_images = images  # Default appearance duration (based on fps)
+
+     # Save the repeated images as a GIF
+     imageio.mimsave(gif_name, repeated_images, fps=fps)
weights/epoch=266-step=42186.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ac94fc32bc10114294d5b0fe772847d1f6f3f83f28eaac13154ec2a99a13afec
+ size 686714944
weights/source.txt ADDED
@@ -0,0 +1 @@
+ Dataset (private): https://www.kaggle.com/datasets/muhammadnaufal/pix2pix