Muhammad Naufal Rizqullah committed
Commit · ae0af75
Parent(s): b2c027e
first commit
- .gitignore +30 -0
- CHANGELOG.md +30 -0
- LICENSE +21 -0
- app.py +77 -0
- config/__init__.py +0 -0
- config/core.py +27 -0
- data/dataloader.py +68 -0
- data/dataset.py +39 -0
- data/examples/100.jpg +0 -0
- data/examples/1001.jpg +0 -0
- data/examples/1020.jpg +0 -0
- data/examples/1021.jpg +0 -0
- models/__init__.py +0 -0
- models/base.py +35 -0
- models/discriminator.py +48 -0
- models/generator.py +68 -0
- requirements.txt +8 -0
- train.py +82 -0
- training/__init__.py +0 -0
- training/callbacks.py +10 -0
- training/model.py +102 -0
- utility/__init__.py +0 -0
- utility/helper.py +208 -0
- weights/epoch=266-step=42186.ckpt +3 -0
- weights/source.txt +1 -0
.gitignore
ADDED
@@ -0,0 +1,30 @@
.idea
.ipynb_checkpoints
.mypy_cache
.vscode
__pycache__
.pytest_cache
htmlcov
dist
site
.coverage
coverage.xml
.netlify
test.db
log.txt
Pipfile.lock
env3.*
env
docs_build
site_build
venv
docs.zip
archive.zip

# vim temporary files
*~
.*.sw?
.cache

# macOS
.DS_Store
CHANGELOG.md
ADDED
@@ -0,0 +1,30 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased] - 2024-10-03
### Plan
Make the core modular, so that training on Kaggle only requires pulling the repo, setting a few dataset (Kaggle) config values, and running training / inference from a single script with CLI args.

## [Unreleased] - 2024-10-18
Built up the app so it can run as a demo. After testing, the model output looked poor in the interface, so another round of training may be needed.

Oddly, the training results look fine when inspecting the generated images directly (visualized during training), but the interface produces bad results.

## [0.0.1] - 2024-10-19
After reviewing the dataset, the earlier problem turned out to be that the model was trained on 1024x1024 close-up face images, so an input that shows a face plus part of the body confuses the model. Inputs similar to the provided examples are required for good results.
### Feature
- Turn a close-up face image into a comic style.

### Changed
- The examples were replaced, so users get some inspiration for the input.
- Two outputs: the first is the original image after transformation, the second is the image returned by the model.

### Removed
- Old examples.

### Fixed
- Resizing did not produce 256x256 because the size was not passed as a tuple; with a single parameter, Resize only constrains one side.
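For context on the Fixed entry above, a minimal illustration (not part of the repo) of how torchvision's Resize behaves with an int versus a tuple:

import torchvision.transforms as T
from PIL import Image

img = Image.new("RGB", (512, 384))            # PIL size is (width, height)
print(T.Resize(256)(img).size)                # (341, 256): only the shorter side is forced to 256
print(T.Resize((256, 256))(img).size)         # (256, 256): a tuple forces the exact shape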
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Muhammad Naufal Rizqullah

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
app.py
ADDED
@@ -0,0 +1,77 @@
import torch
import numpy as np
import gradio as gr
import os

from PIL import Image
import torchvision.transforms as T
from config.core import config
from utility.helper import load_model_weights, init_generator_model

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = init_generator_model()
model = load_model_weights(
    model=model,
    checkpoint_path=config.CKPT_PATH,
    device=device,
    prefix="gen",
)

# Transformation
transform_face = T.Compose([
    T.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
])

def inference(image: Image):
    # Transform the target image and add a batch dimension
    img = transform_face(image)
    img_un = img.unsqueeze(0)

    image_transform = img_un * 0.5 + 0.5  # Rescale from the Tanh range [-1, 1] to [0, 1]
    im_detach = image_transform.detach().cpu().squeeze(0)
    im_permute = im_detach.permute(1, 2, 0)
    im_array = im_permute.numpy()

    # Scale values to 0-255 range
    im_array = (im_array * 255).astype(np.uint8)

    # Convert numpy array to PIL Image
    im_pil = Image.fromarray(im_array)

    # Run inference on the image
    model.eval()
    with torch.inference_mode():
        c2f = model(img_un)

    c2f = c2f * 0.5 + 0.5  # Rescale from the Tanh range [-1, 1] to [0, 1]
    image_unflat = c2f.detach().cpu().squeeze(0)  # Remove batch dimension
    image = image_unflat.permute(1, 2, 0)  # Permute to (H, W, C)

    # Convert image to numpy array
    image_array = image.numpy()

    # Scale values to 0-255 range
    image_array = (image_array * 255).astype(np.uint8)

    # Convert numpy array to PIL Image
    image = Image.fromarray(image_array)

    return im_pil, image

demo = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Image(label="Original after Transform"),
        gr.Image(label="Converted by Model")
    ],
    title="Pix2Pix Face to Comic",
    description="An implementation of Pix2Pix from scratch in PyTorch",
    examples=[f"data/examples/{i}" for i in os.listdir("data/examples") if i.endswith(('.png', '.jpg', '.jpeg', '.gif'))]
)

demo.launch()
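For reference, a standalone sketch (not part of this commit) of the same inference path, handy for converting files without launching the Gradio UI; it assumes the repo root is the working directory and reuses the helpers above:

import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T

from config.core import config
from utility.helper import init_generator_model, load_model_weights

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
gen = load_model_weights(model=init_generator_model(), checkpoint_path=config.CKPT_PATH, device=device, prefix="gen")
gen.eval()

tfm = T.Compose([
    T.Resize((config.IMAGE_SIZE, config.IMAGE_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

img = tfm(Image.open("data/examples/100.jpg").convert("RGB")).unsqueeze(0)

with torch.inference_mode():
    out = gen(img) * 0.5 + 0.5  # undo the Tanh-range normalization, as in inference() above

arr = (out.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
Image.fromarray(arr).save("comic_100.png")  # output filename is arbitrary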
config/__init__.py
ADDED
File without changes
config/core.py
ADDED
@@ -0,0 +1,27 @@
import torch
from pydantic_settings import BaseSettings

class Config(BaseSettings):
    PATH_FACE: str = "/kaggle/input/comic-faces-paired-synthetic-v2/face2comics_v2.0.0_by_Sxela/face2comics_v2.0.0_by_Sxela/faces"
    PATH_COMIC: str = "/kaggle/input/comic-faces-paired-synthetic-v2/face2comics_v2.0.0_by_Sxela/face2comics_v2.0.0_by_Sxela/comics"
    PATH_OUTPUT: str = "/kaggle/working/generates"

    IMAGE_CHANNELS: int = 3

    FEATURE_DISCRIMINATOR: list = [64, 128, 256, 512]
    FEATURE_GENERATOR: int = 64

    IMAGE_SIZE: int = 256
    BATCH_SIZE: int = 128
    DISPLAY_STEP: int = 500
    MAX_SAMPLES: int = 5000

    LEARNING_RATE: float = 2e-4
    L1_LAMBDA: int = 100
    NUM_EPOCH: int = 500

    LOAD_CHECKPOINT: bool = False
    CKPT_PATH: str = "weights/epoch=266-step=42186.ckpt"  # forward slash so the path also resolves on Linux hosts


config = Config()
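Since Config extends pydantic's BaseSettings, each field can presumably also be overridden through an environment variable of the same name set before the module is imported; a small sketch (not part of the commit):

import os
os.environ["BATCH_SIZE"] = "32"      # overrides the default of 128
os.environ["MAX_SAMPLES"] = "1000"

from config.core import config
print(config.BATCH_SIZE, config.MAX_SAMPLES)  # 32 1000, parsed and type-checked by pydantic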
data/dataloader.py
ADDED
@@ -0,0 +1,68 @@
import lightning as L
import torchvision.transforms as T
import os

from torch.utils.data import DataLoader, Subset
from data.dataset import FaceToComicDataset

class FaceToComicDataModule(L.LightningDataModule):
    def __init__(
        self,
        face_path,
        comic_path,
        image_size=(128, 128),
        batch_size=32,
        max_samples=None
    ):
        super().__init__()

        self.face_dir = face_path
        self.comic_dir = comic_path
        self.image_size = image_size
        self.batch_size = batch_size
        self.max_samples = max_samples

        self.transform_face = T.Compose([
            T.Resize(self.image_size),
            T.ToTensor(),
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        self.transform_comic = T.Compose([
            T.Resize(self.image_size),
            T.ToTensor(),
            T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        self.face2comic = None

    def prepare_data(self):
        # No need to download or prepare data, as it's already present in the directories
        pass

    def setup(self, stage=None):
        if stage == "fit" or stage is None:
            dataset = FaceToComicDataset(
                face_path=self.face_dir,
                comic_path=self.comic_dir,
                transform_face=self.transform_face,
                transform_comic=self.transform_comic
            )

            # Limit the dataset size if requested
            if self.max_samples:
                print(f"[INFO] Dataset is limited to {self.max_samples} samples")
                self.face2comic = Subset(dataset, range(min(len(dataset), self.max_samples)))
            else:
                self.face2comic = dataset

    def train_dataloader(self):
        return DataLoader(self.face2comic, batch_size=self.batch_size, num_workers=os.cpu_count(), shuffle=True)

    def val_dataloader(self):
        # Implement if you need validation during training
        pass

    def test_dataloader(self):
        # Implement if you need testing after training
        pass
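A quick smoke test of the DataModule outside a Trainer (not part of the commit; the dataset paths are placeholders):

from data.dataloader import FaceToComicDataModule

dm = FaceToComicDataModule(
    face_path="path/to/faces",        # placeholder paths
    comic_path="path/to/comics",
    image_size=(256, 256),
    batch_size=4,
    max_samples=8,
)
dm.setup("fit")
faces, comics = next(iter(dm.train_dataloader()))
print(faces.shape, comics.shape)      # torch.Size([4, 3, 256, 256]) for each tensor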
data/dataset.py
ADDED
@@ -0,0 +1,39 @@
import os
from PIL import Image

from torch.utils.data import Dataset

class FaceToComicDataset(Dataset):
    def __init__(self, face_path, comic_path, transform_face=None, transform_comic=None):
        super().__init__()
        self.face_dir = face_path
        self.comic_dir = comic_path

        self.face_list_files = os.listdir(self.face_dir)
        self.comic_list_files = os.listdir(self.comic_dir)

        # Create a dictionary for quick lookup of comic files
        self.comic_dict = {comic_file: idx for idx, comic_file in enumerate(self.comic_list_files)}

        # Filter out files that don't have a corresponding pair (keep only paired files)
        self.face_list_files = [f for f in self.face_list_files if f in self.comic_dict]

        self.transform_face = transform_face
        self.transform_comic = transform_comic

    def __getitem__(self, index):
        face_file = self.face_list_files[index]
        comic_file = self.comic_list_files[self.comic_dict[face_file]]

        face_image = Image.open(os.path.join(self.face_dir, face_file))
        comic_image = Image.open(os.path.join(self.comic_dir, comic_file))

        if self.transform_face:
            face_image = self.transform_face(face_image)
        if self.transform_comic:
            comic_image = self.transform_comic(comic_image)

        return face_image, comic_image

    def __len__(self):
        return len(self.face_list_files)
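For reference, a minimal direct use of the dataset (not part of the commit; paths are placeholders). Only file names present in both folders survive the pairing filter above:

import torchvision.transforms as T
from data.dataset import FaceToComicDataset

tfm = T.Compose([T.Resize((256, 256)), T.ToTensor()])
ds = FaceToComicDataset("path/to/faces", "path/to/comics",
                        transform_face=tfm, transform_comic=tfm)
face, comic = ds[0]
print(len(ds), face.shape, comic.shape)   # N paired samples, each image (3, 256, 256)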
data/examples/100.jpg
ADDED
data/examples/1001.jpg
ADDED
data/examples/1020.jpg
ADDED
data/examples/1021.jpg
ADDED
models/__init__.py
ADDED
File without changes
models/base.py
ADDED
@@ -0,0 +1,35 @@
import torch.nn as nn


class Block(nn.Module):
    def __init__(self, in_channels, out_channels, down=True, act="relu", use_dropout=False):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False, padding_mode="reflect")
            if down
            else nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU() if act == "relu" else nn.LeakyReLU(0.2),
        )

        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
        self.down = down

    def forward(self, x):
        x = self.conv(x)
        # Only apply dropout when the block was configured with use_dropout=True
        return self.dropout(x) if self.use_dropout else x

class BlockCNN(nn.Module):
    def __init__(self, in_channels, out_channels, stride=2):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 4, stride, bias=False, padding_mode="reflect"),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.conv(x)
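A small shape check for Block (not part of the commit): with the kernel 4, stride 2, padding 1 hyperparameters above, a down block halves the spatial size and an up block doubles it:

import torch
from models.base import Block

x = torch.randn(1, 64, 128, 128)
print(Block(64, 128, down=True, act="leaky")(x).shape)   # torch.Size([1, 128, 64, 64])
print(Block(64, 32, down=False, act="relu")(x).shape)    # torch.Size([1, 32, 256, 256])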
models/discriminator.py
ADDED
@@ -0,0 +1,48 @@
import torch
import torch.nn as nn

from models.base import BlockCNN


class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 128, 256, 512], kernel_size=4, activation_slope=0.2):
        super().__init__()

        self.initial = nn.Sequential(
            nn.Conv2d(
                in_channels * 2,
                features[0],
                kernel_size,
                stride=2,
                padding=1,
                padding_mode="reflect",
            ),
            nn.LeakyReLU(activation_slope),
        )

        layers = []
        in_channels = features[0]
        for feature in features[1:]:
            layers.append(
                BlockCNN(in_channels, feature, stride=1 if feature == features[-1] else 2)
            )
            in_channels = feature

        layers.append(
            nn.Conv2d(
                in_channels, 1, kernel_size=kernel_size, stride=1, padding=1, padding_mode="reflect"
            )
        )

        self.model = nn.Sequential(*layers)

    def forward(self, x, y):
        x = torch.cat([x, y], dim=1)
        x = self.initial(x)
        return self.model(x)

def test():
    # Test Case for Discriminator Model
    x = torch.randn((1, 3, 256, 256))
    disc = Discriminator()
    print(f"Discriminator Output Shape: {disc(x, x).shape}")
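The discriminator is a PatchGAN: it is conditioned on the input face (hence the in_channels * 2 concatenation) and returns a grid of real/fake logits rather than a single scalar. Exercising the test case above (not part of the commit) should print roughly a 26x26 patch map for a 256x256 input with the default features:

from models.discriminator import test

test()  # e.g. Discriminator Output Shape: torch.Size([1, 1, 26, 26])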
models/generator.py
ADDED
@@ -0,0 +1,68 @@
import torch
import torch.nn as nn

from models.base import Block


class Generator(nn.Module):
    def __init__(self, in_channels=3, features=64):
        super().__init__()

        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 4, 2, 1, padding_mode="reflect"),
            nn.LeakyReLU(0.2),
        )

        self.down1 = Block(features, features * 2, down=True, act="leaky", use_dropout=False)
        self.down2 = Block(features * 2, features * 4, down=True, act="leaky", use_dropout=False)
        self.down3 = Block(features * 4, features * 8, down=True, act="leaky", use_dropout=False)
        self.down4 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
        self.down5 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)
        self.down6 = Block(features * 8, features * 8, down=True, act="leaky", use_dropout=False)

        self.bottleneck = nn.Sequential(
            nn.Conv2d(features * 8, features * 8, 4, 2, 1),
            nn.ReLU()
        )

        self.up1 = Block(features * 8, features * 8, down=False, act="relu", use_dropout=True)
        self.up2 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True)
        self.up3 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=True)
        self.up4 = Block(features * 8 * 2, features * 8, down=False, act="relu", use_dropout=False)
        self.up5 = Block(features * 8 * 2, features * 4, down=False, act="relu", use_dropout=False)
        self.up6 = Block(features * 4 * 2, features * 2, down=False, act="relu", use_dropout=False)
        self.up7 = Block(features * 2 * 2, features, down=False, act="relu", use_dropout=False)

        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(features * 2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        d6 = self.down5(d5)
        d7 = self.down6(d6)

        bottleneck = self.bottleneck(d7)

        up1 = self.up1(bottleneck)
        up2 = self.up2(torch.cat([up1, d7], 1))
        up3 = self.up3(torch.cat([up2, d6], 1))
        up4 = self.up4(torch.cat([up3, d5], 1))
        up5 = self.up5(torch.cat([up4, d4], 1))
        up6 = self.up6(torch.cat([up5, d3], 1))
        up7 = self.up7(torch.cat([up6, d2], 1))

        final_up = self.final_up(torch.cat([up7, d1], 1))

        return final_up

def test():
    # Test Case for Generator Model
    x = torch.randn((1, 3, 256, 256))
    gen = Generator()
    print(f"Generator Output Shape: {gen(x).shape}")
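The generator is a U-Net: seven downsampling steps plus the bottleneck shrink a 256x256 input to 1x1, and the seven up blocks with skip connections restore it, so the output keeps the input resolution (squashed into [-1, 1] by the final Tanh). Exercising the test case above (not part of the commit):

from models.generator import test

test()  # Generator Output Shape: torch.Size([1, 3, 256, 256])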
requirements.txt
ADDED
@@ -0,0 +1,8 @@
torch
pytorch-lightning
python-multipart
fastapi
pydantic
pydantic-settings
opencv-python==4.10.0
imageio==2.33.1
train.py
ADDED
@@ -0,0 +1,82 @@
import argparse
import lightning as L

from config.core import config
from training.model import Pix2Pix
from training.callbacks import MyCustomSavingCallback
from data.dataloader import FaceToComicDataModule


# Add argparser for config params
parser = argparse.ArgumentParser()
parser.add_argument("--load_checkpoint", action='store_true', help="Load checkpoint if this flag is set. If not set, start training from scratch.")
parser.add_argument("--no_load_checkpoint", action='store_false', dest='load_checkpoint', help="Do not load checkpoint. If set, start training from scratch.")

parser.add_argument("--ckpt_path", type=str, default=config.CKPT_PATH, help="Path to checkpoint file. If load_checkpoint is set, this path will be used to load the checkpoint.")
parser.add_argument("--learning_rate", type=float, default=config.LEARNING_RATE, help="Learning rate for Adam optimizer.")
parser.add_argument("--l1_lambda", type=int, default=config.L1_LAMBDA, help="Scale factor for L1 loss.")
parser.add_argument("--features_discriminator", type=int, nargs='+', default=config.FEATURE_DISCRIMINATOR, help="List of feature sizes for the discriminator network.")
parser.add_argument("--features_generator", type=int, default=config.FEATURE_GENERATOR, help="Feature size for the generator network.")
parser.add_argument("--display_step", type=int, default=config.DISPLAY_STEP, help="Interval of steps at which to display loss and save examples.")
parser.add_argument("--num_epoch", type=int, default=config.NUM_EPOCH, help="Number of epochs to train for.")
parser.add_argument("--path_face", type=str, default=config.PATH_FACE, help="Path to folder containing face images.")
parser.add_argument("--path_comic", type=str, default=config.PATH_COMIC, help="Path to folder containing comic images.")
parser.add_argument("--image_size", type=int, default=config.IMAGE_SIZE, help="Size of input images.")
parser.add_argument("--batch_size", type=int, default=config.BATCH_SIZE, help="Batch size for training.")
parser.add_argument("--max_samples", type=int, default=config.MAX_SAMPLES, help="Maximum number of samples to use for training. If set to None, all samples will be used.")

args = parser.parse_args()

config.LOAD_CHECKPOINT = args.load_checkpoint if args.load_checkpoint is not None else config.LOAD_CHECKPOINT
config.CKPT_PATH = args.ckpt_path
config.LEARNING_RATE = args.learning_rate
config.L1_LAMBDA = args.l1_lambda
config.FEATURE_DISCRIMINATOR = args.features_discriminator
config.FEATURE_GENERATOR = args.features_generator
config.DISPLAY_STEP = args.display_step
config.NUM_EPOCH = args.num_epoch
config.PATH_FACE = args.path_face
config.PATH_COMIC = args.path_comic
config.IMAGE_SIZE = args.image_size
config.BATCH_SIZE = args.batch_size
config.MAX_SAMPLES = args.max_samples

# Initialize the Lightning model
model = Pix2Pix(
    in_channels=3,
    learning_rate=config.LEARNING_RATE,
    l1_lambda=config.L1_LAMBDA,
    features_discriminator=config.FEATURE_DISCRIMINATOR,
    features_generator=config.FEATURE_GENERATOR,
    display_step=config.DISPLAY_STEP,
)

# Setup Trainer
n_log = None

trainer = L.Trainer(
    accelerator="auto",
    devices="auto",
    strategy="auto",
    log_every_n_steps=n_log,
    max_epochs=config.NUM_EPOCH,
    callbacks=[MyCustomSavingCallback()],
    default_root_dir="/kaggle/working/",
    precision="16-mixed",
    # fast_dev_run=True
)

# Lightning DataModule
dm = FaceToComicDataModule(
    face_path=config.PATH_FACE,
    comic_path=config.PATH_COMIC,
    image_size=(config.IMAGE_SIZE, config.IMAGE_SIZE),
    batch_size=config.BATCH_SIZE,
    max_samples=None
)

# Start training
if config.LOAD_CHECKPOINT:
    trainer.fit(model, datamodule=dm, ckpt_path=config.CKPT_PATH)
else:
    trainer.fit(model, datamodule=dm)
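For reference, a typical invocation of this script might look like `python train.py --path_face <faces_dir> --path_comic <comics_dir> --batch_size 16 --num_epoch 300 --no_load_checkpoint`; the flag names come from the parser above, while the values and placeholder paths are illustrative only.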
training/__init__.py
ADDED
File without changes
training/callbacks.py
ADDED
@@ -0,0 +1,10 @@
from lightning.pytorch.callbacks import Callback
from utility.helper import update_version_kaggle_dataset

class MyCustomSavingCallback(Callback):
    def __init__(self):
        super().__init__()

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        super().on_save_checkpoint(trainer, pl_module, checkpoint)
        update_version_kaggle_dataset()
training/model.py
ADDED
@@ -0,0 +1,102 @@
import torch
import torch.nn as nn
import lightning as L
import torch.optim as optim

from models.generator import Generator
from models.discriminator import Discriminator
from utility.helper import save_some_examples


class Pix2Pix(L.LightningModule):
    def __init__(self, in_channels, learning_rate, l1_lambda, features_generator, features_discriminator, display_step):
        super().__init__()

        self.automatic_optimization = False

        self.gen = Generator(
            in_channels=in_channels,
            features=features_generator
        )
        self.disc = Discriminator(
            in_channels=in_channels,
            features=features_discriminator
        )

        self.loss_fn = nn.BCEWithLogitsLoss()

        self.discriminator_losses = []
        self.generator_losses = []
        self.curr_step = 0

        self.bce = nn.BCEWithLogitsLoss()
        self.l1_loss = nn.L1Loss()

        self.save_hyperparameters()


    def configure_optimizers(self):
        optimizer_G = optim.Adam(self.gen.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.999))
        optimizer_D = optim.Adam(self.disc.parameters(), lr=self.hparams.learning_rate, betas=(0.5, 0.999))

        return optimizer_G, optimizer_D

    def on_load_checkpoint(self, checkpoint):
        # List of keys that you expect to load from the checkpoint
        keys_to_load = ['discriminator_losses', 'generator_losses', 'curr_step']

        # Iterate over the keys and load them if they exist in the checkpoint
        for key in keys_to_load:
            if key in checkpoint:
                setattr(self, key, checkpoint[key])

    def on_save_checkpoint(self, checkpoint):
        # Save the current state of the model
        checkpoint['discriminator_losses'] = self.discriminator_losses
        checkpoint['generator_losses'] = self.generator_losses
        checkpoint['curr_step'] = self.curr_step

    def training_step(self, batch, batch_idx):
        # Get the Optimizers
        opt_generator, opt_discriminator = self.optimizers()

        X, y = batch

        # Train Discriminator
        y_fake = self.gen(X)
        D_real = self.disc(X, y)
        D_fake = self.disc(X, y_fake.detach())

        D_real_loss = self.loss_fn(D_real, torch.ones_like(D_real))
        D_fake_loss = self.loss_fn(D_fake, torch.zeros_like(D_fake))
        D_loss = (D_real_loss + D_fake_loss) / 2

        opt_discriminator.zero_grad()
        self.manual_backward(D_loss)
        opt_discriminator.step()

        self.log("D_loss", D_loss.item(), on_step=False, on_epoch=True, prog_bar=True)
        self.discriminator_losses.append(D_loss.item())

        # Train Generator
        D_fake = self.disc(X, y_fake)
        G_fake_loss = self.bce(D_fake, torch.ones_like(D_fake))

        L1 = self.l1_loss(y_fake, y) * self.hparams.l1_lambda
        G_loss = G_fake_loss + L1

        opt_generator.zero_grad()
        self.manual_backward(G_loss)
        opt_generator.step()

        self.log("G_loss", G_loss.item(), on_step=False, on_epoch=True, prog_bar=True)
        self.generator_losses.append(G_loss.item())

        self.log("Current_Step", self.curr_step, on_step=False, on_epoch=True, prog_bar=True)

        # Visualize
        if self.curr_step % self.hparams.display_step == 0 and self.curr_step > 0:
            save_some_examples(self.gen, batch, self.current_epoch)

        self.curr_step += 1
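Because save_hyperparameters() is called and curr_step is persisted via on_save_checkpoint, the shipped checkpoint can presumably be restored as a full LightningModule; a sketch (not part of the commit):

from training.model import Pix2Pix

pix2pix = Pix2Pix.load_from_checkpoint("weights/epoch=266-step=42186.ckpt", map_location="cpu")
print(pix2pix.hparams.l1_lambda, pix2pix.curr_step)  # hyperparameters and step counter restored from the checkpoint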
utility/__init__.py
ADDED
File without changes
utility/helper.py
ADDED
@@ -0,0 +1,208 @@
import torch
import torch.nn as nn
import cv2
import imageio
import os
import subprocess

from config.core import config
from models.generator import Generator
from torchvision.utils import save_image


def save_some_examples(generator_model, batch, epoch, folder_path=config.PATH_OUTPUT, num_images=15):
    """
    Save some examples of the generator's output.

    Parameters:
        generator_model (nn.Module): The generator model.
        batch (tuple): The batch of input and target images as a tuple of tensors.
        epoch (int): The current epoch.
        folder_path (str): The folder path to save the examples to. Defaults to config.PATH_OUTPUT.
        num_images (int): The number of images to save. Defaults to 15.
    """

    # Ensure the folder exists
    os.makedirs(folder_path, exist_ok=True)

    x, y = batch  # Unpack the batch

    # Limit the number of images to the specified num_images
    x = x[:num_images]
    y = y[:num_images]

    generator_model.eval()

    with torch.inference_mode():
        y_fake = generator_model(x)
        y_fake = y_fake * 0.5 + 0.5  # Remove normalization by tanh

        # Create 3x5 grid for generated images
        save_image(y_fake, folder_path + f"/y_gen_{epoch}.png", nrow=5)  # Save Generated Image

        # Create 3x5 grid for input images
        save_image(x * 0.5 + 0.5, folder_path + f"/input_{epoch}.png", nrow=5)  # Save Real Image

    generator_model.train()

def update_version_kaggle_dataset():
    # Make Metadata json
    subprocess.run(['kaggle', 'datasets', 'init'], check=True)

    # Write new metadata
    with open('/kaggle/working/dataset-metadata.json', 'w') as json_fid:
        json_fid.write(f'{{\n  "title": "Update Logs Pix2Pix",\n  "id": "muhammadnaufal/pix2pix",\n  "licenses": [{{"name": "CC0-1.0"}}]}}')

    # Push new version
    subprocess.run(['kaggle', 'datasets', 'version', '-m', 'Updated Dataset', '--quiet', '--dir-mode', 'tar'], check=True)


def init_generator_model():
    """
    Initializes and returns the Generator model.

    Args:
        None.

    Returns:
        Generator: The initialized Generator model.
    """
    model = Generator(
        in_channels=config.IMAGE_CHANNELS,
        features=config.FEATURE_GENERATOR,
    )

    return model


def load_model_weights(checkpoint_path, model, device, prefix):
    """
    Load specific weights from a PyTorch Lightning checkpoint into a model.

    Parameters:
        checkpoint_path (str): Path to the checkpoint file.
        model (torch.nn.Module): The model instance to load weights into.
        prefix (str): The prefix in the checkpoint's state_dict keys to filter by and remove.

    Returns:
        model (torch.nn.Module): The model with loaded weights.
    """
    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Extract and modify the state_dict keys to match the model's keys
    model_weights = {k.replace(f"{prefix}.", ""): v for k, v in checkpoint["state_dict"].items() if k.startswith(f"{prefix}.")}

    # Load the weights into the model
    model.load_state_dict(model_weights)

    return model


def initialize_weights(model):
    """
    Initializes the weights of a model using a normal distribution.

    Args:
        model: The model to be initialized.

    Returns:
        None
    """

    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d, nn.InstanceNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)


def create_video(image_folder, video_name, fps, appearance_duration=None):
    """
    Creates a video from a sequence of images with customizable appearance duration.

    Args:
        image_folder (str): The path to the folder containing the images.
        video_name (str): The name of the output video file.
        fps (int): The frames per second of the video.
        appearance_duration (int, optional): The desired appearance duration for each image in milliseconds.
            If None, the default duration based on frame rate is used.

    Example:
        image_folder = '/path/to/image/folder'
        video_name = 'output_video.mp4'
        fps = 12
        appearance_duration = 200  # Appearance duration of 200ms for each image

        create_video(image_folder, video_name, fps, appearance_duration)
    """

    # Get a list of all image files in the folder
    image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]

    # Sort the image files based on the step number
    image_files = sorted(image_files, key=lambda x: int(x.split('-')[1].split('.')[0]))

    # Load the first image to get the video size
    image = cv2.imread(os.path.join(image_folder, image_files[0]))
    height, width, layers = image.shape

    # Create a VideoWriter object
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Specify the video codec
    video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))

    # Write each image to the video with customizable appearance duration
    for image_file in image_files:
        image = cv2.imread(os.path.join(image_folder, image_file))
        video.write(image)

        if appearance_duration is not None:
            # Calculate the number of frames for the desired appearance duration
            num_frames = appearance_duration * fps // 1000
            for _ in range(num_frames):
                video.write(image)

    # Release the video writer
    video.release()

def create_gif(image_folder, gif_name, fps, appearance_duration=None):
    """
    Creates a GIF from a sequence of images sorted by step number, with customizable appearance duration.

    Args:
        image_folder (str): The path to the folder containing the images.
        gif_name (str): The name of the output GIF file.
        fps (int): The frames per second of the GIF.
        appearance_duration (int, optional): The desired appearance duration for each image in milliseconds.
            If None, the default duration based on frame rate is used.

    Example:
        image_folder = '/path/to/image/folder'
        gif_name = 'output_animation.gif'
        fps = 12
        appearance_duration = 300  # Appearance duration of 300ms for each image

        create_gif(image_folder, gif_name, fps, appearance_duration)
    """

    # Get a list of all image files in the folder
    image_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]

    # Sort the image files based on the step number
    image_files = sorted(image_files, key=lambda x: int(x.split('-')[1].split('.')[0]))

    # Load the images into a list
    images = []
    for file in image_files:
        images.append(imageio.imread(os.path.join(image_folder, file)))

    # Create a list to store the repeated images
    repeated_images = []

    # Repeat each image for the desired duration
    if appearance_duration is not None:
        for image in images:
            repeated_images.extend([image] * (appearance_duration * fps // 1000))
    else:
        repeated_images = images  # Default appearance duration (based on fps)

    # Save the repeated images as a GIF
    imageio.mimsave(gif_name, repeated_images, fps=fps)
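A usage note (not part of the commit): the sort key in create_video/create_gif expects hyphenated file names such as frame-10.png, so frames saved by save_some_examples (y_gen_10.png) would need renaming or a different key. A minimal call with placeholder arguments:

from utility.helper import create_gif

# folder containing frames named like "frame-0.png", "frame-500.png", ... (placeholder path)
create_gif("/path/to/frames", "training_progress.gif", fps=12, appearance_duration=300)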
weights/epoch=266-step=42186.ckpt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ac94fc32bc10114294d5b0fe772847d1f6f3f83f28eaac13154ec2a99a13afec
size 686714944
weights/source.txt
ADDED
@@ -0,0 +1 @@
Dataset(Private): https://www.kaggle.com/datasets/muhammadnaufal/pix2pix