Space build status: Build error

Commit 412f263 (parent: 35758f6)
Initial commit with Git LFS tracking
Files changed:
- .gitattributes +6 -0
- assets/docs/ex1.gif +3 -0
- assets/docs/ex2.gif +3 -0
- assets/docs/ex3.gif +3 -0
- assets/docs/ex4.gif +3 -0
- assets/docs/ex5.gif +3 -0
- assets/docs/ex5_img.png +3 -0
- assets/docs/sam_ex.gif +3 -0
- assets/docs/vid20.gif +3 -0
- assets/docs/vid35orig.gif +3 -0
- assets/docs/vid60.gif +3 -0
- assets/gradio_example_images/1.png +3 -0
- assets/gradio_example_images/2.png +3 -0
- assets/gradio_example_images/3.png +3 -0
- assets/gradio_example_images/4.png +3 -0
- assets/gradio_example_images/5.png +3 -0
- assets/gradio_example_images/6.png +3 -0
- assets/gradio_example_images/orig.mp4 +3 -0
- assets/mask1024.jpg +3 -0
- assets/mask512.jpg +3 -0
- model/__pycache__/models.cpython-310.pyc +0 -0
- model/best_discriminator_model.pth +3 -0
- model/best_unet_model.pth +3 -0
- model/losses.py +70 -0
- model/models.py +99 -0
- requirements.txt +12 -0
- scripts/__pycache__/test_functions.cpython-310.pyc +0 -0
- scripts/app.py +101 -0
- scripts/gradio_demo.py +116 -0
- scripts/test_functions.py +229 -0
- scripts/train.py +216 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
The following image assets were ADDED and are stored with Git LFS:
- assets/docs/ex1.gif
- assets/docs/ex2.gif
- assets/docs/ex3.gif
- assets/docs/ex4.gif
- assets/docs/ex5.gif
- assets/docs/ex5_img.png
- assets/docs/sam_ex.gif
- assets/docs/vid20.gif
- assets/docs/vid35orig.gif
- assets/docs/vid60.gif
- assets/gradio_example_images/1.png
- assets/gradio_example_images/2.png
- assets/gradio_example_images/3.png
- assets/gradio_example_images/4.png
- assets/gradio_example_images/5.png
- assets/gradio_example_images/6.png
assets/gradio_example_images/orig.mp4
ADDED (Git LFS pointer)
version https://git-lfs.github.com/spec/v1
oid sha256:107f1bb068638619d3712dfd72fa21b4fba8d072faa9768a85090d3369c70b8e
size 3675407
The following mask assets were ADDED and are stored with Git LFS:
- assets/mask1024.jpg
- assets/mask512.jpg
model/__pycache__/models.cpython-310.pyc
ADDED
Binary file (3.22 kB)
model/best_discriminator_model.pth
ADDED (Git LFS pointer)
version https://git-lfs.github.com/spec/v1
oid sha256:aa0362ca43e381848ac34fd8b44a3d65d9eb6100b1bcb77aa706c0e1c58e06f1
size 2668758
model/best_unet_model.pth
ADDED (Git LFS pointer)
version https://git-lfs.github.com/spec/v1
oid sha256:230b6a007c65af43a67dcbbb46f4504cff71031e6d410952ef195ba6db90e942
size 124275652
model/losses.py
ADDED (70 lines)

import torch
import torch.nn as nn
import lpips  # LPIPS library for perceptual loss

class GeneratorLoss(nn.Module):
    def __init__(self, discriminator_model, l1_weight=1.0, perceptual_weight=1.0, adversarial_weight=0.05,
                 device="cpu"):
        super(GeneratorLoss, self).__init__()
        self.discriminator_model = discriminator_model
        self.l1_weight = l1_weight
        self.perceptual_weight = perceptual_weight
        self.adversarial_weight = adversarial_weight
        self.criterion_l1 = nn.L1Loss()
        self.criterion_adversarial = nn.BCEWithLogitsLoss()
        self.criterion_perceptual = lpips.LPIPS(net='vgg').to(device)

    def forward(self, output, target, source):
        # L1 loss
        l1_loss = self.criterion_l1(output, target)

        # Perceptual loss
        perceptual_loss = torch.mean(self.criterion_perceptual(output, target))

        # Adversarial loss
        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)
        fake_prediction = self.discriminator_model(fake_input)

        adversarial_loss = self.criterion_adversarial(fake_prediction, torch.ones_like(fake_prediction))

        # Combine losses
        generator_loss = self.l1_weight * l1_loss + self.perceptual_weight * perceptual_loss + \
                         self.adversarial_weight * adversarial_loss

        return generator_loss, l1_loss, perceptual_loss, adversarial_loss

class DiscriminatorLoss(nn.Module):
    def __init__(self, discriminator_model, fake_weight=1.0, real_weight=2.0, mock_weight=.5):
        super(DiscriminatorLoss, self).__init__()
        self.discriminator_model = discriminator_model
        self.criterion_adversarial = nn.BCEWithLogitsLoss()
        self.fake_weight = fake_weight
        self.real_weight = real_weight
        self.mock_weight = mock_weight

    def forward(self, output, target, source):
        # Adversarial loss
        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)  # prediction img with target age
        real_input = torch.cat([target, source[:, 4:5, :, :]], dim=1)  # target img with target age

        mock_input1 = torch.cat([source[:, :3, :, :], source[:, 4:5, :, :]], dim=1)  # source img with target age
        mock_input2 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with source age
        mock_input3 = torch.cat([output, source[:, 3:4, :, :]], dim=1)  # prediction img with source age
        mock_input4 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with target age

        fake_pred, real_pred = self.discriminator_model(fake_input), self.discriminator_model(real_input)
        mock_pred1, mock_pred2, mock_pred3, mock_pred4 = (self.discriminator_model(mock_input1),
                                                          self.discriminator_model(mock_input2),
                                                          self.discriminator_model(mock_input3),
                                                          self.discriminator_model(mock_input4))

        discriminator_loss = (self.fake_weight * self.criterion_adversarial(fake_pred, torch.zeros_like(fake_pred)) +
                              self.real_weight * self.criterion_adversarial(real_pred, torch.ones_like(real_pred)) +
                              self.mock_weight * self.criterion_adversarial(mock_pred1, torch.zeros_like(mock_pred1)) +
                              self.mock_weight * self.criterion_adversarial(mock_pred2, torch.zeros_like(mock_pred2)) +
                              self.mock_weight * self.criterion_adversarial(mock_pred3, torch.zeros_like(mock_pred3)) +
                              self.mock_weight * self.criterion_adversarial(mock_pred4, torch.zeros_like(mock_pred4))
                              )

        return discriminator_loss
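For orientation: both loss classes assume the 5-channel conditioning layout used throughout this commit (3 RGB channels, a source-age plane at index 3, and a target-age plane at index 4), and a discriminator that accepts 4 channels (an RGB image plus one age plane), as train.py constructs. Below is a minimal sketch exercising them on random tensors; the 256 x 256 toy resolution and the tensor values are assumptions for illustration only, and the repository root is assumed to be on sys.path so the model package resolves.

import torch
from model.losses import GeneratorLoss, DiscriminatorLoss
from model.models import PatchGANDiscriminator

# Toy batch: "source" carries RGB + source-age plane + target-age plane.
source = torch.rand(1, 5, 256, 256)
target = torch.rand(1, 3, 256, 256)   # ground-truth image at the target age
output = torch.rand(1, 3, 256, 256)   # stand-in for a generator prediction

# 4 input channels: RGB concatenated with one age plane.
discriminator = PatchGANDiscriminator(input_channels=4)

gen_loss_fn = GeneratorLoss(discriminator, device="cpu")  # LPIPS downloads VGG weights on first use
disc_loss_fn = DiscriminatorLoss(discriminator)

g_total, l1, perceptual, adversarial = gen_loss_fn(output, target, source)
d_total = disc_loss_fn(output.detach(), target, source)  # detached, as in the training loop
print(g_total.item(), d_total.item())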
model/models.py
ADDED (99 lines)

import torch
import torch.nn as nn
import antialiased_cnns


class DownLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownLayer, self).__init__()
        self.layer = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=1),
            antialiased_cnns.BlurPool(in_channels, stride=2),
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, x):
        return self.layer(x)


class UpLayer(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpLayer, self).__init__()
        # Conv transpose upsampling
        self.blur_upsample = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
            antialiased_cnns.BlurPool(out_channels, stride=1)
        )

        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(inplace=True)
        )

    def forward(self, x, skip):
        x = self.blur_upsample(x)
        x = torch.cat([x, skip], dim=1)  # Concatenate with skip connection
        return self.layer(x)


class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.init_conv = nn.Sequential(
            nn.Conv2d(5, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
            nn.LeakyReLU(inplace=True)
        )

        self.down1 = DownLayer(64, 128)  # output: 256 x 256 x 128
        self.down2 = DownLayer(128, 256)  # output: 128 x 128 x 256
        self.down3 = DownLayer(256, 512)  # output: 64 x 64 x 512
        self.down4 = DownLayer(512, 1024)  # output: 32 x 32 x 1024
        self.up1 = UpLayer(1024, 512)  # output: 64 x 64 x 512
        self.up2 = UpLayer(512, 256)  # output: 128 x 128 x 256
        self.up3 = UpLayer(256, 128)  # output: 256 x 256 x 128
        self.up4 = UpLayer(128, 64)  # output: 512 x 512 x 64
        self.final_conv = nn.Conv2d(64, 3, kernel_size=1)  # output: 512 x 512 x 3

    def forward(self, x):
        x0 = self.init_conv(x)
        x1 = self.down1(x0)
        x2 = self.down2(x1)
        x3 = self.down3(x2)
        x4 = self.down4(x3)
        x = self.up1(x4, x3)
        x = self.up2(x, x2)
        x = self.up3(x, x1)
        x = self.up4(x, x0)
        x = self.final_conv(x)
        return x


class PatchGANDiscriminator(nn.Module):
    def __init__(self, input_channels=3):
        super(PatchGANDiscriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
            # Output layer with 1 channel for binary classification
        )

    def forward(self, x):
        return self.model(x)
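As a quick sanity check of the shapes annotated in the layer comments above, the following sketch runs both networks on random tensors with untrained weights. It assumes the repository root is importable; the batch contents are placeholders.

import torch
from model.models import UNet, PatchGANDiscriminator

unet = UNet().eval()
disc = PatchGANDiscriminator(input_channels=4).eval()

# 5-channel input (RGB + two age planes) at the 512 x 512 resolution the comments assume.
x = torch.rand(1, 5, 512, 512)
with torch.no_grad():
    residual = unet(x)                                       # the generator predicts an RGB residual
    logits = disc(torch.cat([residual, x[:, 4:5]], dim=1))   # RGB + target-age plane -> patch logits

print(residual.shape)  # torch.Size([1, 3, 512, 512])
print(logits.shape)    # torch.Size([1, 1, 63, 63])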
requirements.txt
ADDED (12 lines)

torch
torchvision
antialiased_cnns
lpips
ffmpy
av
gradio
cmake
face_recognition
dlib
numpy
Pillow
scripts/__pycache__/test_functions.cpython-310.pyc
ADDED
Binary file (6.11 kB)
scripts/app.py
ADDED (101 lines)

import gradio as gr
import torch
from model.models import UNet
from scripts.test_functions import process_image, process_video

window_size = 512
stride = 256
steps = 18
frame_count = 0

def get_model():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    unet_model = UNet().to(device)
    unet_model.load_state_dict(torch.load("model/best_unet_model.pth", map_location=device))
    unet_model.eval()
    return unet_model

unet_model = get_model()

def block_img(image, source_age, target_age):
    from PIL import Image as PILImage
    import numpy as np
    if isinstance(image, str):
        image = PILImage.open(image).convert('RGB')
    elif isinstance(image, np.ndarray) and image.dtype == object:
        image = image.astype(np.uint8)
    return process_image(unet_model, image, video=False, source_age=source_age,
                         target_age=target_age, window_size=window_size, stride=stride)

def block_img_vid(image, source_age):
    from PIL import Image as PILImage
    import numpy as np
    if isinstance(image, str):
        image = PILImage.open(image).convert('RGB')
    elif isinstance(image, np.ndarray) and image.dtype == object:
        image = image.astype(np.uint8)
    return process_image(unet_model, image, video=True, source_age=source_age,
                         target_age=0, window_size=window_size, stride=stride, steps=steps)

def block_vid(video_path, source_age, target_age):
    return process_video(unet_model, video_path, source_age, target_age,
                         window_size=window_size, stride=stride, frame_count=frame_count)

demo_img = gr.Interface(
    fn=block_img,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
        gr.Slider(10, 90, value=80, step=1, label="Target age", info="Choose the age you want to become")
    ],
    outputs="image",
    examples=[
        ['assets/gradio_example_images/1.png', 20, 80],
        ['assets/gradio_example_images/2.png', 75, 40],
        ['assets/gradio_example_images/3.png', 30, 70],
        ['assets/gradio_example_images/4.png', 22, 60],
        ['assets/gradio_example_images/5.png', 28, 75],
        ['assets/gradio_example_images/6.png', 35, 15]
    ],
    description="Input an image of a person and age them from the source age to the target age."
)

demo_img_vid = gr.Interface(
    fn=block_img_vid,
    inputs=[
        gr.Image(type="pil"),
        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
    ],
    outputs=gr.Video(),
    examples=[
        ['assets/gradio_example_images/1.png', 20],
        ['assets/gradio_example_images/2.png', 75],
        ['assets/gradio_example_images/3.png', 30],
        ['assets/gradio_example_images/4.png', 22],
        ['assets/gradio_example_images/5.png', 28],
        ['assets/gradio_example_images/6.png', 35]
    ],
    description="Input an image of a person and a video will be returned of the person at different ages."
)

demo_vid = gr.Interface(
    fn=block_vid,
    inputs=[
        gr.Video(),
        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
        gr.Slider(10, 90, value=80, step=1, label="Target age", info="Choose the age you want to become")
    ],
    outputs=gr.Video(),
    examples=[
        ['assets/gradio_example_images/orig.mp4', 35, 60],
    ],
    description="Input a video of a person, and it will be aged frame-by-frame."
)

demo = gr.TabbedInterface([demo_img, demo_img_vid, demo_vid],
                          tab_names=['Image inference demo', 'Image animation demo', 'Video inference demo'],
                          title="Face Re-Aging Demo",
                          )

if __name__ == "__main__":
    demo.launch()
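Because the heavy lifting lives in module-level functions, the image path of the demo can also be exercised without launching the UI. A minimal sketch, assuming it is run from the repository root so the checkpoint, mask assets, and example images resolve; note that importing scripts.app loads model/best_unet_model.pth and builds the interfaces as a side effect, but does not start the server.

from PIL import Image
from scripts.app import block_img  # the import loads the checkpoint once

# Same call the "Image inference demo" tab makes; returns a PIL image.
aged = block_img(Image.open("assets/gradio_example_images/1.png"), source_age=20, target_age=80)
aged.save("aged_1.png")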
scripts/gradio_demo.py
ADDED (116 lines)

import gradio as gr
import torch
import argparse

import sys
sys.path.append(".")

from model.models import UNet
from scripts.test_functions import process_image, process_video

# default settings
window_size = 512
stride = 256
steps = 18
frame_count = 0

def run(model_path):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    unet_model = UNet().to(device)
    unet_model.load_state_dict(torch.load(model_path, map_location=device))
    unet_model.eval()

    def block_img(image, source_age, target_age):
        from PIL import Image as PILImage
        import numpy as np
        # If image is a file path (from examples), load it
        if isinstance(image, str):
            image = PILImage.open(image).convert('RGB')
        # If image is a numpy array with dtype object (sometimes from Gradio), convert to uint8
        elif isinstance(image, np.ndarray) and image.dtype == object:
            image = image.astype(np.uint8)
        return process_image(unet_model, image, video=False, source_age=source_age,
                             target_age=target_age, window_size=window_size, stride=stride)

    def block_img_vid(image, source_age):
        from PIL import Image as PILImage
        import numpy as np
        if isinstance(image, str):
            image = PILImage.open(image).convert('RGB')
        elif isinstance(image, np.ndarray) and image.dtype == object:
            image = image.astype(np.uint8)
        return process_image(unet_model, image, video=True, source_age=source_age,
                             target_age=0, window_size=window_size, stride=stride, steps=steps)

    def block_vid(video_path, source_age, target_age):
        return process_video(unet_model, video_path, source_age, target_age,
                             window_size=window_size, stride=stride, frame_count=frame_count)

    demo_img = gr.Interface(
        fn=block_img,
        inputs=[
            gr.Image(type="pil"),
            gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
            gr.Slider(10, 90, value=80, step=1, label="Target age", info="Choose the age you want to become")
        ],
        outputs="image",
        examples=[
            ['assets/gradio_example_images/1.png', 20, 80],
            ['assets/gradio_example_images/2.png', 75, 40],
            ['assets/gradio_example_images/3.png', 30, 70],
            ['assets/gradio_example_images/4.png', 22, 60],
            ['assets/gradio_example_images/5.png', 28, 75],
            ['assets/gradio_example_images/6.png', 35, 15]
        ],
        description="Input an image of a person and age them from the source age to the target age."
    )

    demo_img_vid = gr.Interface(
        fn=block_img_vid,
        inputs=[
            gr.Image(type="pil"),
            gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
        ],
        outputs=gr.Video(),
        examples=[
            ['assets/gradio_example_images/1.png', 20],
            ['assets/gradio_example_images/2.png', 75],
            ['assets/gradio_example_images/3.png', 30],
            ['assets/gradio_example_images/4.png', 22],
            ['assets/gradio_example_images/5.png', 28],
            ['assets/gradio_example_images/6.png', 35]
        ],
        description="Input an image of a person and a video will be returned of the person at different ages."
    )

    demo_vid = gr.Interface(
        fn=block_vid,
        inputs=[
            gr.Video(),
            gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
            gr.Slider(10, 90, value=80, step=1, label="Target age", info="Choose the age you want to become")
        ],
        outputs=gr.Video(),
        examples=[
            ['assets/gradio_example_images/orig.mp4', 35, 60],
        ],
        description="Input a video of a person, and it will be aged frame-by-frame."
    )

    demo = gr.TabbedInterface([demo_img, demo_img_vid, demo_vid],
                              tab_names=['Image inference demo', 'Image animation demo', 'Video inference demo'],
                              title="Face Re-Aging Demo",
                              )

    demo.launch()


if __name__ == "__main__":
    # Define command-line arguments
    parser = argparse.ArgumentParser(description="Testing script - Image demo")
    parser.add_argument("--model_path", type=str, default="model/best_unet_model.pth", help="Path to the model")

    # Parse command-line arguments
    args = parser.parse_args()

    run(args.model_path)
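gradio_demo.py is the command-line counterpart of app.py; the only functional difference is that the checkpoint path is configurable via --model_path. Equivalently, run() can be called from Python, as sketched below; the alternative checkpoint path is a placeholder, not a file in this commit.

from scripts.gradio_demo import run

# Launch the same three-tab demo with a checkpoint of your own
# ("checkpoints/my_unet.pth" is a hypothetical path used only for illustration).
run("checkpoints/my_unet.pth")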
scripts/test_functions.py
ADDED (229 lines)

import face_recognition
import numpy as np
import os
import torch
from torch.autograd import Variable
from torchvision import transforms
from torchvision.io import write_video
import tempfile
import subprocess
import json
from ffmpy import FFmpeg, FFprobe
from PIL import Image

mask_file = torch.from_numpy(np.array(Image.open('assets/mask1024.jpg').convert('L'))) / 255
small_mask_file = torch.from_numpy(np.array(Image.open('assets/mask512.jpg').convert('L'))) / 255

def sliding_window_tensor(input_tensor, window_size, stride, your_model, mask=mask_file, small_mask=small_mask_file):
    """
    Apply aging operation on input tensor using a sliding-window method. This operation is done on the GPU, if available.
    """

    input_tensor = input_tensor.to(next(your_model.parameters()).device)
    mask = mask.to(next(your_model.parameters()).device)
    small_mask = small_mask.to(next(your_model.parameters()).device)

    n, c, h, w = input_tensor.size()
    output_tensor = torch.zeros((n, 3, h, w), dtype=input_tensor.dtype, device=input_tensor.device)

    count_tensor = torch.zeros((n, 3, h, w), dtype=torch.float32, device=input_tensor.device)

    add = 2 if window_size % stride != 0 else 1

    for y in range(0, h - window_size + add, stride):
        for x in range(0, w - window_size + add, stride):
            window = input_tensor[:, :, y:y + window_size, x:x + window_size]

            # Apply the same preprocessing as during training
            input_variable = Variable(window, requires_grad=False)  # Assuming GPU is available

            # Forward pass
            with torch.no_grad():
                output = your_model(input_variable)

            output_tensor[:, :, y:y + window_size, x:x + window_size] += output * small_mask
            count_tensor[:, :, y:y + window_size, x:x + window_size] += small_mask

    count_tensor = torch.clamp(count_tensor, min=1.0)

    # Average the overlapping regions
    output_tensor /= count_tensor

    # Apply mask
    output_tensor *= mask

    return output_tensor.cpu()


def process_image(your_model, image, video, source_age, target_age=0,
                  window_size=512, stride=256, steps=18):
    input_size = (1024, 1024)
    # Robustly handle image input for face_recognition
    from PIL import Image as PILImage
    import numpy as np
    if isinstance(image, PILImage.Image):
        image = image.convert('RGB')
        image = np.array(image)
    elif isinstance(image, np.ndarray):
        if image.ndim == 2:  # grayscale
            image = np.stack([image]*3, axis=-1)
        elif image.shape[2] == 4:  # RGBA
            image = image[..., :3]
        if image.dtype == np.float32 or image.dtype == np.float64:
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                image = image.astype(np.uint8)
        elif image.dtype != np.uint8:
            image = image.astype(np.uint8)
    else:
        image = np.array(PILImage.fromarray(image).convert('RGB'))
    # Ensure shape is (H, W, 3) and contiguous
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(f"Image must have shape (H, W, 3), got {image.shape}")
    image = np.ascontiguousarray(image, dtype=np.uint8)
    print(f"[DEBUG] image type: {type(image)}, shape: {image.shape}, dtype: {image.dtype}, contiguous: {image.flags['C_CONTIGUOUS']}")
    if video:  # h264 codec requires frame size to be divisible by 2.
        width, height, depth = image.shape
        new_width = width if width % 2 == 0 else width - 1
        new_height = height if height % 2 == 0 else height - 1
        image.resize((new_width, new_height, depth))

    # Diagnostic: try face_recognition on this image, and if it fails, save and reload
    try:
        fl = face_recognition.face_locations(image)[0]
    except Exception as e:
        print(f"[DEBUG] face_locations failed: {e}. Saving image for test...")
        import tempfile
        from PIL import Image as PILImage
        temp_path = tempfile.mktemp(suffix='.png')
        PILImage.fromarray(image).save(temp_path)
        print(f"[DEBUG] Saved image to {temp_path}. Trying face_recognition.load_image_file...")
        loaded_img = face_recognition.load_image_file(temp_path)
        print(f"[DEBUG] loaded_img type: {type(loaded_img)}, shape: {loaded_img.shape}, dtype: {loaded_img.dtype}")
        fl = face_recognition.face_locations(loaded_img)[0]

    # calculate margins
    margin_y_t = int((fl[2] - fl[0]) * .63 * .85)  # larger as the forehead is often cut off
    margin_y_b = int((fl[2] - fl[0]) * .37 * .85)
    margin_x = int((fl[1] - fl[3]) // (2 / .85))
    margin_y_t += 2 * margin_x - margin_y_t - margin_y_b  # make sure square is preserved

    l_y = max([fl[0] - margin_y_t, 0])
    r_y = min([fl[2] + margin_y_b, image.shape[0]])
    l_x = max([fl[3] - margin_x, 0])
    r_x = min([fl[1] + margin_x, image.shape[1]])

    # crop image
    cropped_image = image[l_y:r_y, l_x:r_x, :]

    # Resizing
    orig_size = cropped_image.shape[:2]

    cropped_image = transforms.ToTensor()(cropped_image)

    cropped_image_resized = transforms.Resize(input_size, interpolation=Image.BILINEAR, antialias=True)(cropped_image)

    source_age_channel = torch.full_like(cropped_image_resized[:1, :, :], source_age / 100)
    target_age_channel = torch.full_like(cropped_image_resized[:1, :, :], target_age / 100)
    input_tensor = torch.cat([cropped_image_resized, source_age_channel, target_age_channel], dim=0).unsqueeze(0)

    image = transforms.ToTensor()(image)

    if video:
        # aging in steps
        interval = .8 / steps
        aged_cropped_images = torch.zeros((steps, 3, input_size[1], input_size[0]))
        for i in range(0, steps):
            input_tensor[:, -1, :, :] += interval

            # performing actions on image
            aged_cropped_images[i, ...] = sliding_window_tensor(input_tensor, window_size, stride, your_model)

        # resize back to original size
        aged_cropped_images_resized = transforms.Resize(orig_size, interpolation=Image.BILINEAR, antialias=True)(
            aged_cropped_images)

        # re-apply
        image = image.repeat(steps, 1, 1, 1)

        image[:, :, l_y:r_y, l_x:r_x] += aged_cropped_images_resized
        image = torch.clamp(image, 0, 1)
        image = (image * 255).to(torch.uint8)

        output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)

        write_video(output_file.name, image.permute(0, 2, 3, 1), 2)

        return output_file.name

    else:
        # performing actions on image
        aged_cropped_image = sliding_window_tensor(input_tensor, window_size, stride, your_model)

        # resize back to original size
        aged_cropped_image_resized = transforms.Resize(orig_size, interpolation=Image.BILINEAR, antialias=True)(
            aged_cropped_image)

        # re-apply
        image[:, l_y:r_y, l_x:r_x] += aged_cropped_image_resized.squeeze(0)
        image = torch.clamp(image, 0, 1)

        return transforms.functional.to_pil_image(image)


def process_video(your_model, video_path, source_age, target_age, window_size=512, stride=256, frame_count=0):
    """
    Applying the aging to a video.
    We age as from source_age to target_age, and return an image.
    To limit the number of frames in a video, we can set frame_count.
    """

    # Extracting frames and placing them in a temporary directory
    frames_dir = tempfile.TemporaryDirectory()
    output_template = os.path.join(frames_dir.name, '%04d.jpg')

    if frame_count:
        ff = FFmpeg(
            inputs={video_path: None},
            outputs={output_template: ['-vf', f'select=lt(n\,{frame_count})', '-q:v', '1']}
        )
    else:
        ff = FFmpeg(
            inputs={video_path: None},
            outputs={output_template: ['-q:v', '1']}
        )

    ff.run()

    # Getting framerate (for reconstruction later)
    ff = FFprobe(inputs={video_path: None},
                 global_options=['-v', 'error', '-select_streams', 'v', '-show_entries', 'stream=r_frame_rate', '-of',
                                 'default=noprint_wrappers=1:nokey=1'])
    stdout, _ = ff.run(stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    frame_rate = eval(stdout.decode('utf-8').strip())


    # Applying process_image to frames
    processed_dir = tempfile.TemporaryDirectory()

    for name in os.listdir(frames_dir.name):
        image_path = os.path.join(frames_dir.name, name)
        image = Image.open(image_path).convert('RGB')
        image_aged = process_image(your_model, image, False, source_age, target_age, window_size, stride)
        image_aged.save(os.path.join(processed_dir.name, name))

    # Generating a new video
    input_template = os.path.join(processed_dir.name, '%04d.jpg')
    output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    ff = FFmpeg(
        inputs={input_template: f'-framerate {frame_rate}'}, global_options=['-y'],
        outputs={output_file.name: ['-c:v', 'libx264', '-pix_fmt', 'yuv420p']}
    )

    ff.run()

    frames_dir.cleanup()
    processed_dir.cleanup()

    return output_file.name
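Putting the pieces together, single-image inference outside the Gradio demo amounts to loading the U-Net checkpoint and calling process_image, which handles face detection, cropping, the 5-channel input construction, and sliding-window blending internally. A minimal sketch, assuming it is run from the repository root so the mask assets, example image, and checkpoint paths resolve; the source and target ages are illustrative values.

import torch
from PIL import Image

from model.models import UNet
from scripts.test_functions import process_image

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet().to(device)
model.load_state_dict(torch.load("model/best_unet_model.pth", map_location=device))
model.eval()

# Age the subject from 25 to 70; the example image ships with this commit.
img = Image.open("assets/gradio_example_images/1.png").convert("RGB")
aged = process_image(model, img, video=False, source_age=25, target_age=70,
                     window_size=512, stride=256)
aged.save("aged.png")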
scripts/train.py
ADDED (216 lines)

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import argparse

import sys
sys.path.append(".")

from model.models import UNet, PatchGANDiscriminator
from model.losses import GeneratorLoss, DiscriminatorLoss
from utils.dataloader import CustomDataset, transform


def train_model(root_dir, start_epoch, num_epochs, load_model_g, load_model_d, num_workers,
                val_freq, batch_size, accum_iter, lr, lr_d, wandb_tracking, desc):
    if wandb_tracking:
        import wandb

        wandb.init(project="FRAN",
                   # track hyperparameters and run metadata
                   config={
                       "lr": lr,
                       "lr_d": lr_d,
                       "dataset": root_dir,
                       "epochs": num_epochs,
                       "batch_size": batch_size,
                       "description": desc
                   }
                   )

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"device: {device}")
    if torch.cuda.device_count() > 0:
        print(f"{torch.cuda.device_count()} GPU(s)")
    if torch.cuda.device_count() > 1:
        print("multi-GPU training is currently not supported.")

    # Create instances of the dataset and split into scripts and validation sets
    dataset = CustomDataset(root_dir=root_dir, transform=transform)

    # Assuming you want to use 80% of the data for scripts and 20% for validation
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Create data loaders for scripts and validation
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Create instances of the U-Net, discriminator, and loss models
    unet_model = UNet()
    discriminator_model = PatchGANDiscriminator(input_channels=4)

    if load_model_g:
        unet_model.load_state_dict(torch.load(load_model_g, map_location=device))
        print(f'loaded {load_model_g} for unet_model')
    if load_model_d:
        discriminator_model.load_state_dict(torch.load(load_model_d, map_location=device))
        print(f'loaded {load_model_d} for discriminator_model')

    unet_model = unet_model.to(device)
    discriminator_model = discriminator_model.to(device)

    # if multiGPU:
    #     unet_model = nn.DataParallel(unet_model)
    #     discriminator_model = nn.DataParallel(discriminator_model)

    # Create loss instances
    generator_loss_func = GeneratorLoss(discriminator_model, l1_weight=1.0, perceptual_weight=1.0,
                                        adversarial_weight=0.05, device=device)
    discriminator_loss_func = DiscriminatorLoss(discriminator_model)

    # Create instances of the Adam optimizer
    optimizer_g = optim.Adam(unet_model.parameters(), lr=lr)
    optimizer_d = optim.Adam(discriminator_model.parameters(), lr=lr_d)

    # Training and validation loop
    best_val_loss = float('inf')

    for epoch in range(start_epoch - 1, num_epochs):
        # Training
        unet_model.train()
        discriminator_model.train()
        batch_idx = 0
        for batch in train_dataloader:
            batch_idx += 1
            source_images, target_images = batch

            # if not multiGPU:
            # if multi GPU, nn.DataParallel will already put the batches on the right devices.
            # Otherwise, we do it manually
            source_images = source_images.to(device)
            target_images = target_images.to(device)

            # Zero gradients
            # optimizer_g.zero_grad()
            # optimizer_d.zero_grad()

            # Forward pass
            output_images = unet_model(source_images)
            # if multiGPU:
            #     output_device = output_images.get_device()
            #     source_images, target_images = source_images.to(output_device), target_images.to(output_device)
            output_images += source_images[:, :3, :, :]

            # Discriminator pass
            discriminator_loss = discriminator_loss_func(output_images.detach(), target_images, source_images)
            # discriminator_loss /= accum_iter
            discriminator_loss.backward()

            if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
                optimizer_d.step()
                optimizer_d.zero_grad()

            # Generator pass
            # Calculate the loss
            generator_loss, l1_loss, per_loss, adv_loss = generator_loss_func(output_images, target_images,
                                                                              source_images)
            generator_loss, l1_loss, per_loss, adv_loss = [i / accum_iter for i in
                                                           [generator_loss, l1_loss, per_loss, adv_loss]]
            generator_loss.backward()

            if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
                optimizer_g.step()
                optimizer_g.zero_grad()

            # Print scripts information (if needed)
            print(
                f'Training Epoch [{epoch + 1}/{num_epochs}], Gen Loss: {generator_loss.item()}, L1: {l1_loss.item()}, P: {per_loss.item()}, A: {adv_loss.item()}, Dis Loss: {discriminator_loss.item()}')
            if wandb_tracking:
                wandb.log({
                    'Training Epoch': epoch + 1,
                    'Gen Loss': generator_loss.item(),
                    'L1': l1_loss.item(),
                    'P': per_loss.item(),
                    'A': adv_loss.item(),
                    'Dis Loss': discriminator_loss.item()
                })

        torch.save(unet_model.state_dict(), 'recent_unet_model.pth')
        torch.save(discriminator_model.state_dict(), 'recent_discriminator_model.pth')

        # Validation
        if epoch % val_freq == 0:
            unet_model.eval()
            total_val_loss = 0.0
            with torch.no_grad():
                for val_batch in val_dataloader:
                    val_source_images, val_target_images = val_batch

                    # if not multiGPU:
                    # if multi GPU, nn.DataParallel will already put the batches on the right devices.
                    # Otherwise, we do it manually
                    val_source_images = val_source_images.to(device)
                    val_target_images = val_target_images.to(device)

                    # Forward pass
                    val_output_images = unet_model(val_source_images)

                    # if multiGPU:
                    #     output_device = val_output_images.get_device()
                    #     val_source_images, val_target_images = val_source_images.to(output_device), \
                    #                                            val_target_images.to(output_device)

                    # Calculate the loss
                    generator_loss, _, _, _ = generator_loss_func(val_output_images, val_target_images,
                                                                  val_source_images)
                    total_val_loss += generator_loss.item()

            average_val_loss = total_val_loss / len(val_dataloader)

            # Print validation information
            print(f'Validation Epoch [{epoch + 1}/{num_epochs}], Average Loss: {average_val_loss}')
            if wandb_tracking:
                wandb.log({
                    'Training Epoch': epoch + 1,
                    'Val Loss': average_val_loss,
                })

            # Save the model with the best validation loss
            if average_val_loss < best_val_loss:
                best_val_loss = average_val_loss
                torch.save(unet_model.state_dict(), 'best_unet_model.pth')
                torch.save(discriminator_model.state_dict(), 'best_discriminator_model.pth')

    if wandb_tracking:
        wandb.finish()


if __name__ == "__main__":
    # Define command-line arguments
    parser = argparse.ArgumentParser(description="Training Script")
    parser.add_argument("--root_dir", type=str, default='data/processed/train',
                        help="Path to the training data. Note the format: To use the dataloader, the directory should be filled with folders containing image files of various ages, where the file name is the age.")
    parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch, if scripts is resumed")
    parser.add_argument("--num_epochs", type=int, default=2000, help="End epoch")
    parser.add_argument("--load_model_g", type=str, default='',
                        help="Path to pretrained generator model. Leave blank to train from scratch")
    parser.add_argument("--load_model_d", type=str, default='',
                        help="Path to pretrained discriminator model. Leave blank to train from scratch")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
    parser.add_argument("--batch_size", type=int, default=3, help="Batch size")
    parser.add_argument("--accum_iter", type=int, default=3, help="Number of batches after which weights are updated")
    parser.add_argument("--val_freq", type=int, default=1, help="Validation frequency (epochs)")
    parser.add_argument("--lr", type=float, default=0.00001, help="Learning rate for generator")
    parser.add_argument("--lr_d", type=float, default=0.00001, help="Learning rate for discriminator")
    parser.add_argument("--wandb_tracking", help="A binary (True/False) argument for using WandB tracking or not")
    parser.add_argument("--desc", type=str, default='', help="Description for WandB")

    # Parse command-line arguments
    args = parser.parse_args()

    # Call the scripts function with parsed arguments
    train_model(args.root_dir, args.start_epoch, args.num_epochs, args.load_model_g, args.load_model_d,
                args.num_workers, args.val_freq, args.batch_size, args.accum_iter, args.lr, args.lr_d,
                args.wandb_tracking, args.desc)
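Note that train.py imports CustomDataset and transform from utils/dataloader.py, which is not part of this commit. From how train_model consumes the batches, each sample must be a (source, target) pair where source is a 5-channel tensor (RGB plus a source-age and a target-age plane, ages scaled to [0, 1]) and target is the 3-channel ground-truth image at the target age. The class below is a hypothetical stand-in with that interface, useful only for dry-running the training loop on random data; it is not the project's actual dataloader.

import torch
from torch.utils.data import Dataset

class RandomAgingPairs(Dataset):
    """Hypothetical stand-in for utils.dataloader.CustomDataset: yields random
    (source, target) pairs shaped the way train.py expects."""

    def __init__(self, length=8, size=512):
        self.length = length
        self.size = size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        rgb = torch.rand(3, self.size, self.size)                  # input face at the source age
        source_age = torch.full((1, self.size, self.size), 0.25)   # e.g. age 25 encoded as 0.25
        target_age = torch.full((1, self.size, self.size), 0.70)   # e.g. age 70 encoded as 0.70
        source = torch.cat([rgb, source_age, target_age], dim=0)   # 5 channels, as UNet expects
        target = torch.rand(3, self.size, self.size)               # ground truth at the target age
        return source, target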