teja141290 committed on
Commit 412f263 · 1 Parent(s): 35758f6

Initial commit with Git LFS tracking

.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.png filter=lfs diff=lfs merge=lfs -text
+*.jpg filter=lfs diff=lfs merge=lfs -text
+*.jpeg filter=lfs diff=lfs merge=lfs -text
+*.gif filter=lfs diff=lfs merge=lfs -text
+*.mp4 filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
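With these patterns, the binary assets in this commit are stored as small Git LFS pointer files rather than raw blobs. A pointer is a short text stub of the form below; the digest and size here are placeholders, while the real pointers for the .pth weights and orig.mp4 appear verbatim further down in this diff.

version https://git-lfs.github.com/spec/v1
oid sha256:<64-hex-character digest of the file contents>
size <file size in bytes>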
assets/docs/ex1.gif ADDED

Git LFS Details

  • SHA256: cf0ea42c8b1d0f87d185184d4d638855392c19fa204879b32ad8e54fbc633443
  • Pointer size: 132 Bytes
  • Size of remote file: 1.47 MB
assets/docs/ex2.gif ADDED

Git LFS Details

  • SHA256: 89d87e92e7844a2b1d51e5ce69f3b15ecf657aaf8c433b967db33180411418e2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.88 MB
assets/docs/ex3.gif ADDED

Git LFS Details

  • SHA256: a049aae887abc1d4d7faccd02b023ac909dccda4c2273b2cbc49a9fa71b263ba
  • Pointer size: 132 Bytes
  • Size of remote file: 1.92 MB
assets/docs/ex4.gif ADDED

Git LFS Details

  • SHA256: ffc437e988cda30318ed8dab6c348da8ec3cbf6f780718d4f863767818907467
  • Pointer size: 132 Bytes
  • Size of remote file: 1.68 MB
assets/docs/ex5.gif ADDED

Git LFS Details

  • SHA256: 07c6035a1425dfe2c52a2efeb0847bd41877b7dba396f1164d8ab62b65870716
  • Pointer size: 132 Bytes
  • Size of remote file: 2.5 MB
assets/docs/ex5_img.png ADDED

Git LFS Details

  • SHA256: 8299d8c625bc4b4fb8a787c5c5cad89c9a94eb693b1255f36310a21d66b6508e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.3 MB
assets/docs/sam_ex.gif ADDED

Git LFS Details

  • SHA256: 2b9bd47e3a83afa7652cde402c419c764a3c16ba1671c8361f130ae3dda5a9fa
  • Pointer size: 132 Bytes
  • Size of remote file: 6.01 MB
assets/docs/vid20.gif ADDED

Git LFS Details

  • SHA256: 80ffacdc59cbd233b2c3d8eaa9ddb27342cad7826a5969f0805e798931f96e36
  • Pointer size: 132 Bytes
  • Size of remote file: 6.57 MB
assets/docs/vid35orig.gif ADDED

Git LFS Details

  • SHA256: fe8380b1f68fbc936c3692a64cc4344dbd842692ff33b9b9ae1e703bb57a47ca
  • Pointer size: 132 Bytes
  • Size of remote file: 6.55 MB
assets/docs/vid60.gif ADDED

Git LFS Details

  • SHA256: 657d5f7f487aaab4a9e03bc61530921a321a1d95b7934c122f2cf821712a7b6b
  • Pointer size: 132 Bytes
  • Size of remote file: 6.72 MB
assets/gradio_example_images/1.png ADDED

Git LFS Details

  • SHA256: cff048edd766cea2eb970abf86d3a6b581680589b43c37623835efc8955ab48f
  • Pointer size: 131 Bytes
  • Size of remote file: 987 kB
assets/gradio_example_images/2.png ADDED

Git LFS Details

  • SHA256: 3a28998e8c61101a1071c77c64fa692900f5f4b3cc75b8f8ad5aaed87e19be2f
  • Pointer size: 132 Bytes
  • Size of remote file: 2.41 MB
assets/gradio_example_images/3.png ADDED

Git LFS Details

  • SHA256: 8ac6a2f780110fb2124275f4bb5518989afa9fa1aa390a59109f3fe72c58446a
  • Pointer size: 132 Bytes
  • Size of remote file: 2.25 MB
assets/gradio_example_images/4.png ADDED

Git LFS Details

  • SHA256: 7d22eae5966510727bcde2be670ffd86c176aea5b043aef599ae77e62a3b4e06
  • Pointer size: 132 Bytes
  • Size of remote file: 1.99 MB
assets/gradio_example_images/5.png ADDED

Git LFS Details

  • SHA256: 0df5f6af3546dd561a8de090d70dea7a025c81a70b3134e7289af880d7d1953c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.14 MB
assets/gradio_example_images/6.png ADDED

Git LFS Details

  • SHA256: f60c9b9d5f28cf10db7c6b5b7e859bab0157b1c372bbf5f711c47b546da04869
  • Pointer size: 132 Bytes
  • Size of remote file: 1.36 MB
assets/gradio_example_images/orig.mp4 ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:107f1bb068638619d3712dfd72fa21b4fba8d072faa9768a85090d3369c70b8e
+size 3675407
assets/mask1024.jpg ADDED

Git LFS Details

  • SHA256: d28fcdfd77b6d22c24153fbea0ea39b84cc967063fbeeee9389e1513e8ba8565
  • Pointer size: 131 Bytes
  • Size of remote file: 207 kB
assets/mask512.jpg ADDED

Git LFS Details

  • SHA256: 062526a2a4d73c91f06c90059a78d48e46a5dc7c73a5652175c42c3a4ad635c3
  • Pointer size: 130 Bytes
  • Size of remote file: 10.5 kB
model/__pycache__/models.cpython-310.pyc ADDED
Binary file (3.22 kB)
 
model/best_discriminator_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa0362ca43e381848ac34fd8b44a3d65d9eb6100b1bcb77aa706c0e1c58e06f1
+size 2668758
model/best_unet_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:230b6a007c65af43a67dcbbb46f4504cff71031e6d410952ef195ba6db90e942
+size 124275652
model/losses.py ADDED
@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+import lpips  # LPIPS library for perceptual loss
+
+class GeneratorLoss(nn.Module):
+    def __init__(self, discriminator_model, l1_weight=1.0, perceptual_weight=1.0, adversarial_weight=0.05,
+                 device="cpu"):
+        super(GeneratorLoss, self).__init__()
+        self.discriminator_model = discriminator_model
+        self.l1_weight = l1_weight
+        self.perceptual_weight = perceptual_weight
+        self.adversarial_weight = adversarial_weight
+        self.criterion_l1 = nn.L1Loss()
+        self.criterion_adversarial = nn.BCEWithLogitsLoss()
+        self.criterion_perceptual = lpips.LPIPS(net='vgg').to(device)
+
+    def forward(self, output, target, source):
+        # L1 loss
+
+        l1_loss = self.criterion_l1(output, target)
+
+        # Perceptual loss
+        perceptual_loss = torch.mean(self.criterion_perceptual(output, target))
+
+        # Adversarial loss
+        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)
+        fake_prediction = self.discriminator_model(fake_input)
+
+        adversarial_loss = self.criterion_adversarial(fake_prediction, torch.ones_like(fake_prediction))
+
+        # Combine losses
+        generator_loss = self.l1_weight * l1_loss + self.perceptual_weight * perceptual_loss + \
+                         self.adversarial_weight * adversarial_loss
+
+        return generator_loss, l1_loss, perceptual_loss, adversarial_loss
+
+class DiscriminatorLoss(nn.Module):
+    def __init__(self, discriminator_model, fake_weight=1.0, real_weight=2.0, mock_weight=.5):
+        super(DiscriminatorLoss, self).__init__()
+        self.discriminator_model = discriminator_model
+        self.criterion_adversarial = nn.BCEWithLogitsLoss()
+        self.fake_weight = fake_weight
+        self.real_weight = real_weight
+        self.mock_weight = mock_weight
+
+    def forward(self, output, target, source):
+        # Adversarial loss
+        fake_input = torch.cat([output, source[:, 4:5, :, :]], dim=1)  # prediction img with target age
+        real_input = torch.cat([target, source[:, 4:5, :, :]], dim=1)  # target img with target age
+
+        mock_input1 = torch.cat([source[:, :3, :, :], source[:, 4:5, :, :]], dim=1)  # source img with target age
+        mock_input2 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with source age
+        mock_input3 = torch.cat([output, source[:, 3:4, :, :]], dim=1)  # prediction img with source age
+        mock_input4 = torch.cat([target, source[:, 3:4, :, :]], dim=1)  # target img with source age (same combination as mock_input2)
+
+        fake_pred, real_pred = self.discriminator_model(fake_input), self.discriminator_model(real_input)
+        mock_pred1, mock_pred2, mock_pred3, mock_pred4 = (self.discriminator_model(mock_input1),
+                                                          self.discriminator_model(mock_input2),
+                                                          self.discriminator_model(mock_input3),
+                                                          self.discriminator_model(mock_input4))
+
+        discriminator_loss = (self.fake_weight * self.criterion_adversarial(fake_pred, torch.zeros_like(fake_pred)) +
+                              self.real_weight * self.criterion_adversarial(real_pred, torch.ones_like(real_pred)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred1, torch.zeros_like(mock_pred1)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred2, torch.zeros_like(mock_pred2)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred3, torch.zeros_like(mock_pred3)) +
+                              self.mock_weight * self.criterion_adversarial(mock_pred4, torch.zeros_like(mock_pred4))
+                              )
+
+        return discriminator_loss
model/models.py ADDED
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+import antialiased_cnns
+
+
+class DownLayer(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(DownLayer, self).__init__()
+        self.layer = nn.Sequential(
+            nn.MaxPool2d(kernel_size=2, stride=1),
+            antialiased_cnns.BlurPool(in_channels, stride=2),
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.layer(x)
+
+
+class UpLayer(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(UpLayer, self).__init__()
+        # Conv transpose upsampling
+
+        self.blur_upsample = nn.Sequential(
+            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0),
+            antialiased_cnns.BlurPool(out_channels, stride=1)
+        )
+
+        self.layer = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True),
+            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+            nn.LeakyReLU(inplace=True)
+        )
+
+    def forward(self, x, skip):
+        x = self.blur_upsample(x)
+        x = torch.cat([x, skip], dim=1)  # Concatenate with skip connection
+        return self.layer(x)
+
+
+class UNet(nn.Module):
+    def __init__(self):
+        super(UNet, self).__init__()
+        self.init_conv = nn.Sequential(
+            nn.Conv2d(5, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
+            nn.LeakyReLU(inplace=True),
+            nn.Conv2d(64, 64, kernel_size=3, padding=1),  # output: 512 x 512 x 64
+            nn.LeakyReLU(inplace=True)
+        )
+
+        self.down1 = DownLayer(64, 128)  # output: 256 x 256 x 128
+        self.down2 = DownLayer(128, 256)  # output: 128 x 128 x 256
+        self.down3 = DownLayer(256, 512)  # output: 64 x 64 x 512
+        self.down4 = DownLayer(512, 1024)  # output: 32 x 32 x 1024
+        self.up1 = UpLayer(1024, 512)  # output: 64 x 64 x 512
+        self.up2 = UpLayer(512, 256)  # output: 128 x 128 x 256
+        self.up3 = UpLayer(256, 128)  # output: 256 x 256 x 128
+        self.up4 = UpLayer(128, 64)  # output: 512 x 512 x 64
+        self.final_conv = nn.Conv2d(64, 3, kernel_size=1)  # output: 512 x 512 x 3
+
+    def forward(self, x):
+        x0 = self.init_conv(x)
+        x1 = self.down1(x0)
+        x2 = self.down2(x1)
+        x3 = self.down3(x2)
+        x4 = self.down4(x3)
+        x = self.up1(x4, x3)
+        x = self.up2(x, x2)
+        x = self.up3(x, x1)
+        x = self.up4(x, x0)
+        x = self.final_conv(x)
+        return x
+
+
+class PatchGANDiscriminator(nn.Module):
+    def __init__(self, input_channels=3):
+        super(PatchGANDiscriminator, self).__init__()
+        self.model = nn.Sequential(
+            nn.Conv2d(input_channels, 64, kernel_size=4, stride=2, padding=1),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(128),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(256),
+            nn.LeakyReLU(0.2, inplace=True),
+
+            nn.Conv2d(256, 1, kernel_size=4, stride=1, padding=1)
+            # Output layer with 1 channel for binary classification
+        )
+
+    def forward(self, x):
+        return self.model(x)
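For orientation, here is a minimal sketch (not part of the committed files) of how the UNet above is driven by the rest of this commit: it takes a 5-channel input, an RGB crop plus two constant planes holding the normalized source and target ages, and predicts a 3-channel residual that is added back onto the RGB input (see scripts/test_functions.py and scripts/train.py). The 512x512 size matches the window_size used at inference; the ages 25 and 70 are illustrative.

import torch
from model.models import UNet   # assumes the repo root is on sys.path, as in scripts/gradio_demo.py

unet = UNet().eval()

rgb = torch.rand(1, 3, 512, 512)                      # RGB crop in [0, 1]
source_age = torch.full((1, 1, 512, 512), 25 / 100)   # ages are normalized by /100 in scripts/test_functions.py
target_age = torch.full((1, 1, 512, 512), 70 / 100)

x = torch.cat([rgb, source_age, target_age], dim=1)   # 1 x 5 x 512 x 512, matching init_conv's 5 input channels

with torch.no_grad():
    delta = unet(x)                                   # 1 x 3 x 512 x 512 residual

aged = (rgb + delta).clamp(0, 1)                      # the residual is added onto the input RGB, as in train.py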
requirements.txt ADDED
@@ -0,0 +1,12 @@
+torch
+torchvision
+antialiased_cnns
+lpips
+ffmpy
+av
+gradio
+cmake
+face_recognition
+dlib
+numpy
+Pillow
scripts/__pycache__/test_functions.cpython-310.pyc ADDED
Binary file (6.11 kB)
 
scripts/app.py ADDED
@@ -0,0 +1,101 @@
+import gradio as gr
+import torch
+from model.models import UNet
+from scripts.test_functions import process_image, process_video
+
+window_size = 512
+stride = 256
+steps = 18
+frame_count = 0
+
+def get_model():
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    unet_model = UNet().to(device)
+    unet_model.load_state_dict(torch.load("model/best_unet_model.pth", map_location=device))
+    unet_model.eval()
+    return unet_model
+
+unet_model = get_model()
+
+def block_img(image, source_age, target_age):
+    from PIL import Image as PILImage
+    import numpy as np
+    if isinstance(image, str):
+        image = PILImage.open(image).convert('RGB')
+    elif isinstance(image, np.ndarray) and image.dtype == object:
+        image = image.astype(np.uint8)
+    return process_image(unet_model, image, video=False, source_age=source_age,
+                         target_age=target_age, window_size=window_size, stride=stride)
+
+def block_img_vid(image, source_age):
+    from PIL import Image as PILImage
+    import numpy as np
+    if isinstance(image, str):
+        image = PILImage.open(image).convert('RGB')
+    elif isinstance(image, np.ndarray) and image.dtype == object:
+        image = image.astype(np.uint8)
+    return process_image(unet_model, image, video=True, source_age=source_age,
+                         target_age=0, window_size=window_size, stride=stride, steps=steps)
+
+def block_vid(video_path, source_age, target_age):
+    return process_video(unet_model, video_path, source_age, target_age,
+                         window_size=window_size, stride=stride, frame_count=frame_count)
+
+demo_img = gr.Interface(
+    fn=block_img,
+    inputs=[
+        gr.Image(type="pil"),
+        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
+        gr.Slider(10, 90, value=80, step=1, label="Target age", info="Choose the age you want to become")
+    ],
+    outputs="image",
+    examples=[
+        ['assets/gradio_example_images/1.png', 20, 80],
+        ['assets/gradio_example_images/2.png', 75, 40],
+        ['assets/gradio_example_images/3.png', 30, 70],
+        ['assets/gradio_example_images/4.png', 22, 60],
+        ['assets/gradio_example_images/5.png', 28, 75],
+        ['assets/gradio_example_images/6.png', 35, 15]
+    ],
+    description="Input an image of a person and age them from the source age to the target age."
+)
+
+demo_img_vid = gr.Interface(
+    fn=block_img_vid,
+    inputs=[
+        gr.Image(type="pil"),
+        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
+    ],
+    outputs=gr.Video(),
+    examples=[
+        ['assets/gradio_example_images/1.png', 20],
+        ['assets/gradio_example_images/2.png', 75],
+        ['assets/gradio_example_images/3.png', 30],
+        ['assets/gradio_example_images/4.png', 22],
+        ['assets/gradio_example_images/5.png', 28],
+        ['assets/gradio_example_images/6.png', 35]
+    ],
+    description="Input an image of a person and a video will be returned of the person at different ages."
+)
+
+demo_vid = gr.Interface(
+    fn=block_vid,
+    inputs=[
+        gr.Video(),
+        gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
+        gr.Slider(10, 90, value=80, step=1, label="Target age", info="Choose the age you want to become")
+    ],
+    outputs=gr.Video(),
+    examples=[
+        ['assets/gradio_example_images/orig.mp4', 35, 60],
+    ],
+    description="Input a video of a person, and it will be aged frame-by-frame."
+)
+
+demo = gr.TabbedInterface([demo_img, demo_img_vid, demo_vid],
+                          tab_names=['Image inference demo', 'Image animation demo', 'Video inference demo'],
+                          title="Face Re-Aging Demo",
+                          )
+
+if __name__ == "__main__":
+    demo.launch()
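For completeness, a hedged sketch (not part of the committed files) of calling the same pipeline outside Gradio; the example image and ages mirror the checked-in Gradio examples, the repo root is assumed to be the working directory, and the output filename aged_80.png is hypothetical:

import torch
from PIL import Image

from model.models import UNet
from scripts.test_functions import process_image

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
unet_model = UNet().to(device)
unet_model.load_state_dict(torch.load("model/best_unet_model.pth", map_location=device))
unet_model.eval()

image = Image.open("assets/gradio_example_images/1.png").convert("RGB")
aged = process_image(unet_model, image, video=False, source_age=20, target_age=80,
                     window_size=512, stride=256)   # non-video path returns a PIL image
aged.save("aged_80.png")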
scripts/gradio_demo.py ADDED
@@ -0,0 +1,116 @@
+import gradio as gr
+import torch
+import argparse
+
+import sys
+sys.path.append(".")
+
+from model.models import UNet
+from scripts.test_functions import process_image, process_video
+
+# default settings
+window_size = 512
+stride = 256
+steps = 18
+frame_count = 0
+
+def run(model_path):
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    unet_model = UNet().to(device)
+    unet_model.load_state_dict(torch.load(model_path, map_location=device))
+    unet_model.eval()
+
+    def block_img(image, source_age, target_age):
+        from PIL import Image as PILImage
+        import numpy as np
+        # If image is a file path (from examples), load it
+        if isinstance(image, str):
+            image = PILImage.open(image).convert('RGB')
+        # If image is a numpy array with dtype object (sometimes from Gradio), convert to uint8
+        elif isinstance(image, np.ndarray) and image.dtype == object:
+            image = image.astype(np.uint8)
+        return process_image(unet_model, image, video=False, source_age=source_age,
+                             target_age=target_age, window_size=window_size, stride=stride)
+
+    def block_img_vid(image, source_age):
+        from PIL import Image as PILImage
+        import numpy as np
+        if isinstance(image, str):
+            image = PILImage.open(image).convert('RGB')
+        elif isinstance(image, np.ndarray) and image.dtype == object:
+            image = image.astype(np.uint8)
+        return process_image(unet_model, image, video=True, source_age=source_age,
+                             target_age=0, window_size=window_size, stride=stride, steps=steps)
+
+    def block_vid(video_path, source_age, target_age):
+        return process_video(unet_model, video_path, source_age, target_age,
+                             window_size=window_size, stride=stride, frame_count=frame_count)
+
+    demo_img = gr.Interface(
+        fn=block_img,
+        inputs=[
+            gr.Image(type="pil"),
+            gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
+            gr.Slider(10, 90, value=80, step=1, label="Target age", info="Choose the age you want to become")
+        ],
+        outputs="image",
+        examples=[
+            ['assets/gradio_example_images/1.png', 20, 80],
+            ['assets/gradio_example_images/2.png', 75, 40],
+            ['assets/gradio_example_images/3.png', 30, 70],
+            ['assets/gradio_example_images/4.png', 22, 60],
+            ['assets/gradio_example_images/5.png', 28, 75],
+            ['assets/gradio_example_images/6.png', 35, 15]
+        ],
+        description="Input an image of a person and age them from the source age to the target age."
+    )
+
+    demo_img_vid = gr.Interface(
+        fn=block_img_vid,
+        inputs=[
+            gr.Image(type="pil"),
+            gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
+        ],
+        outputs=gr.Video(),
+        examples=[
+            ['assets/gradio_example_images/1.png', 20],
+            ['assets/gradio_example_images/2.png', 75],
+            ['assets/gradio_example_images/3.png', 30],
+            ['assets/gradio_example_images/4.png', 22],
+            ['assets/gradio_example_images/5.png', 28],
+            ['assets/gradio_example_images/6.png', 35]
+        ],
+        description="Input an image of a person and a video will be returned of the person at different ages."
+    )
+
+    demo_vid = gr.Interface(
+        fn=block_vid,
+        inputs=[
+            gr.Video(),
+            gr.Slider(10, 90, value=20, step=1, label="Current age", info="Choose your current age"),
+            gr.Slider(10, 90, value=80, step=1, label="Target age", info="Choose the age you want to become")
+        ],
+        outputs=gr.Video(),
+        examples=[
+            ['assets/gradio_example_images/orig.mp4', 35, 60],
+        ],
+        description="Input a video of a person, and it will be aged frame-by-frame."
+    )
+
+    demo = gr.TabbedInterface([demo_img, demo_img_vid, demo_vid],
+                              tab_names=['Image inference demo', 'Image animation demo', 'Video inference demo'],
+                              title="Face Re-Aging Demo",
+                              )
+
+    demo.launch()
+
+
+if __name__ == "__main__":
+    # Define command-line arguments
+    parser = argparse.ArgumentParser(description="Testing script - Image demo")
+    parser.add_argument("--model_path", type=str, default="model/best_unet_model.pth", help="Path to the model")
+
+    # Parse command-line arguments
+    args = parser.parse_args()
+
+    run(args.model_path)
scripts/test_functions.py ADDED
@@ -0,0 +1,229 @@
+import face_recognition
+import numpy as np
+import os
+import torch
+from torch.autograd import Variable
+from torchvision import transforms
+from torchvision.io import write_video
+import tempfile
+import subprocess
+import json
+from ffmpy import FFmpeg, FFprobe
+from PIL import Image
+
+mask_file = torch.from_numpy(np.array(Image.open('assets/mask1024.jpg').convert('L'))) / 255
+small_mask_file = torch.from_numpy(np.array(Image.open('assets/mask512.jpg').convert('L'))) / 255
+
+def sliding_window_tensor(input_tensor, window_size, stride, your_model, mask=mask_file, small_mask=small_mask_file):
+    """
+    Apply the aging operation to the input tensor using a sliding-window method. This is done on the GPU, if available.
+    """
+
+    input_tensor = input_tensor.to(next(your_model.parameters()).device)
+    mask = mask.to(next(your_model.parameters()).device)
+    small_mask = small_mask.to(next(your_model.parameters()).device)
+
+    n, c, h, w = input_tensor.size()
+    output_tensor = torch.zeros((n, 3, h, w), dtype=input_tensor.dtype, device=input_tensor.device)
+
+    count_tensor = torch.zeros((n, 3, h, w), dtype=torch.float32, device=input_tensor.device)
+
+    add = 2 if window_size % stride != 0 else 1
+
+    for y in range(0, h - window_size + add, stride):
+        for x in range(0, w - window_size + add, stride):
+            window = input_tensor[:, :, y:y + window_size, x:x + window_size]
+
+            # Apply the same preprocessing as during training
+            input_variable = Variable(window, requires_grad=False)  # Assuming GPU is available
+
+            # Forward pass
+            with torch.no_grad():
+                output = your_model(input_variable)
+
+            output_tensor[:, :, y:y + window_size, x:x + window_size] += output * small_mask
+            count_tensor[:, :, y:y + window_size, x:x + window_size] += small_mask
+
+    count_tensor = torch.clamp(count_tensor, min=1.0)
+
+    # Average the overlapping regions
+    output_tensor /= count_tensor
+
+    # Apply mask
+    output_tensor *= mask
+
+    return output_tensor.cpu()
+
+
+def process_image(your_model, image, video, source_age, target_age=0,
+                  window_size=512, stride=256, steps=18):
+    input_size = (1024, 1024)
+    # Robustly handle image input for face_recognition
+    from PIL import Image as PILImage
+    import numpy as np
+    if isinstance(image, PILImage.Image):
+        image = image.convert('RGB')
+        image = np.array(image)
+    elif isinstance(image, np.ndarray):
+        if image.ndim == 2:  # grayscale
+            image = np.stack([image]*3, axis=-1)
+        elif image.shape[2] == 4:  # RGBA
+            image = image[..., :3]
+        if image.dtype == np.float32 or image.dtype == np.float64:
+            if image.max() <= 1.0:
+                image = (image * 255).astype(np.uint8)
+            else:
+                image = image.astype(np.uint8)
+        elif image.dtype != np.uint8:
+            image = image.astype(np.uint8)
+    else:
+        image = np.array(PILImage.fromarray(image).convert('RGB'))
+    # Ensure shape is (H, W, 3) and contiguous
+    if image.ndim != 3 or image.shape[2] != 3:
+        raise ValueError(f"Image must have shape (H, W, 3), got {image.shape}")
+    image = np.ascontiguousarray(image, dtype=np.uint8)
+    print(f"[DEBUG] image type: {type(image)}, shape: {image.shape}, dtype: {image.dtype}, contiguous: {image.flags['C_CONTIGUOUS']}")
+    if video:  # h264 codec requires frame size to be divisible by 2.
+        width, height, depth = image.shape
+        new_width = width if width % 2 == 0 else width - 1
+        new_height = height if height % 2 == 0 else height - 1
+        image.resize((new_width, new_height, depth))
+
+    # Diagnostic: try face_recognition on this image, and if it fails, save and reload
+    try:
+        fl = face_recognition.face_locations(image)[0]
+    except Exception as e:
+        print(f"[DEBUG] face_locations failed: {e}. Saving image for test...")
+        import tempfile
+        from PIL import Image as PILImage
+        temp_path = tempfile.mktemp(suffix='.png')
+        PILImage.fromarray(image).save(temp_path)
+        print(f"[DEBUG] Saved image to {temp_path}. Trying face_recognition.load_image_file...")
+        loaded_img = face_recognition.load_image_file(temp_path)
+        print(f"[DEBUG] loaded_img type: {type(loaded_img)}, shape: {loaded_img.shape}, dtype: {loaded_img.dtype}")
+        fl = face_recognition.face_locations(loaded_img)[0]
+
+    # calculate margins
+    margin_y_t = int((fl[2] - fl[0]) * .63 * .85)  # larger as the forehead is often cut off
+    margin_y_b = int((fl[2] - fl[0]) * .37 * .85)
+    margin_x = int((fl[1] - fl[3]) // (2 / .85))
+    margin_y_t += 2 * margin_x - margin_y_t - margin_y_b  # make sure square is preserved
+
+    l_y = max([fl[0] - margin_y_t, 0])
+    r_y = min([fl[2] + margin_y_b, image.shape[0]])
+    l_x = max([fl[3] - margin_x, 0])
+    r_x = min([fl[1] + margin_x, image.shape[1]])
+
+    # crop image
+    cropped_image = image[l_y:r_y, l_x:r_x, :]
+
+    # Resizing
+    orig_size = cropped_image.shape[:2]
+
+    cropped_image = transforms.ToTensor()(cropped_image)
+
+    cropped_image_resized = transforms.Resize(input_size, interpolation=Image.BILINEAR, antialias=True)(cropped_image)
+
+    source_age_channel = torch.full_like(cropped_image_resized[:1, :, :], source_age / 100)
+    target_age_channel = torch.full_like(cropped_image_resized[:1, :, :], target_age / 100)
+    input_tensor = torch.cat([cropped_image_resized, source_age_channel, target_age_channel], dim=0).unsqueeze(0)
+
+    image = transforms.ToTensor()(image)
+
+    if video:
+        # aging in steps
+        interval = .8 / steps
+        aged_cropped_images = torch.zeros((steps, 3, input_size[1], input_size[0]))
+        for i in range(0, steps):
+            input_tensor[:, -1, :, :] += interval
+
+            # performing actions on image
+            aged_cropped_images[i, ...] = sliding_window_tensor(input_tensor, window_size, stride, your_model)
+
+        # resize back to original size
+        aged_cropped_images_resized = transforms.Resize(orig_size, interpolation=Image.BILINEAR, antialias=True)(
+            aged_cropped_images)
+
+        # re-apply
+        image = image.repeat(steps, 1, 1, 1)
+
+        image[:, :, l_y:r_y, l_x:r_x] += aged_cropped_images_resized
+        image = torch.clamp(image, 0, 1)
+        image = (image * 255).to(torch.uint8)
+
+        output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
+
+        write_video(output_file.name, image.permute(0, 2, 3, 1), 2)
+
+        return output_file.name
+
+    else:
+        # performing actions on image
+        aged_cropped_image = sliding_window_tensor(input_tensor, window_size, stride, your_model)
+
+        # resize back to original size
+        aged_cropped_image_resized = transforms.Resize(orig_size, interpolation=Image.BILINEAR, antialias=True)(
+            aged_cropped_image)
+
+        # re-apply
+        image[:, l_y:r_y, l_x:r_x] += aged_cropped_image_resized.squeeze(0)
+        image = torch.clamp(image, 0, 1)
+
+        return transforms.functional.to_pil_image(image)
+
+
+def process_video(your_model, video_path, source_age, target_age, window_size=512, stride=256, frame_count=0):
+    """
+    Apply the aging to a video.
+    We age from source_age to target_age and return the path to the new video.
+    To limit the number of frames in a video, we can set frame_count.
+    """
+
+    # Extracting frames and placing them in a temporary directory
+    frames_dir = tempfile.TemporaryDirectory()
+    output_template = os.path.join(frames_dir.name, '%04d.jpg')
+
+    if frame_count:
+        ff = FFmpeg(
+            inputs={video_path: None},
+            outputs={output_template: ['-vf', f'select=lt(n\,{frame_count})', '-q:v', '1']}
+        )
+    else:
+        ff = FFmpeg(
+            inputs={video_path: None},
+            outputs={output_template: ['-q:v', '1']}
+        )
+
+    ff.run()
+
+    # Getting framerate (for reconstruction later)
+    ff = FFprobe(inputs={video_path: None},
+                 global_options=['-v', 'error', '-select_streams', 'v', '-show_entries', 'stream=r_frame_rate', '-of',
+                                 'default=noprint_wrappers=1:nokey=1'])
+    stdout, _ = ff.run(stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    frame_rate = eval(stdout.decode('utf-8').strip())
+
+
+    # Applying process_image to frames
+    processed_dir = tempfile.TemporaryDirectory()
+
+    for name in os.listdir(frames_dir.name):
+        image_path = os.path.join(frames_dir.name, name)
+        image = Image.open(image_path).convert('RGB')
+        image_aged = process_image(your_model, image, False, source_age, target_age, window_size, stride)
+        image_aged.save(os.path.join(processed_dir.name, name))
+
+    # Generating a new video
+    input_template = os.path.join(processed_dir.name, '%04d.jpg')
+    output_file = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
+    ff = FFmpeg(
+        inputs={input_template: f'-framerate {frame_rate}'}, global_options=['-y'],
+        outputs={output_file.name: ['-c:v', 'libx264', '-pix_fmt', 'yuv420p']}
+    )
+
+    ff.run()
+
+    frames_dir.cleanup()
+    processed_dir.cleanup()
+
+    return output_file.name
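A self-contained toy sketch (not part of the committed files) of the overlap-averaging idea that sliding_window_tensor implements above: each window's prediction is accumulated into an output buffer, a count buffer tracks how often each pixel was covered, and dividing the two averages the overlaps. The toy uses an 8x8 "image", a 4-pixel window with stride 2, an identity function in place of the model, and a uniform mask instead of assets/mask512.jpg.

import torch

h = w = 8                     # toy size; the real pipeline works on 1024x1024 crops
window, stride = 4, 2         # the real code uses window_size=512, stride=256

x = torch.arange(h * w, dtype=torch.float32).reshape(1, 1, h, w)
out = torch.zeros_like(x)
count = torch.zeros_like(x)

add = 2 if window % stride != 0 else 1
for y in range(0, h - window + add, stride):
    for x0 in range(0, w - window + add, stride):
        patch = x[:, :, y:y + window, x0:x0 + window]    # the model would run on this window
        out[:, :, y:y + window, x0:x0 + window] += patch
        count[:, :, y:y + window, x0:x0 + window] += 1

out /= count.clamp(min=1.0)    # average overlapping contributions
assert torch.allclose(out, x)  # with an identity "model", averaging recovers the input exactly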
scripts/train.py ADDED
@@ -0,0 +1,216 @@
+import torch
+import torch.optim as optim
+from torch.utils.data import DataLoader, random_split
+import argparse
+
+import sys
+sys.path.append(".")
+
+from model.models import UNet, PatchGANDiscriminator
+from model.losses import GeneratorLoss, DiscriminatorLoss
+from utils.dataloader import CustomDataset, transform
+
+
+def train_model(root_dir, start_epoch, num_epochs, load_model_g, load_model_d, num_workers,
+                val_freq, batch_size, accum_iter, lr, lr_d, wandb_tracking, desc):
+    if wandb_tracking:
+        import wandb
+
+        wandb.init(project="FRAN",
+                   # track hyperparameters and run metadata
+                   config={
+                       "lr": lr,
+                       "lr_d": lr_d,
+                       "dataset": root_dir,
+                       "epochs": num_epochs,
+                       "batch_size": batch_size,
+                       "description": desc
+                   }
+                   )
+
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+    print(f"device: {device}")
+    if torch.cuda.device_count() > 0:
+        print(f"{torch.cuda.device_count()} GPU(s)")
+    if torch.cuda.device_count() > 1:
+        print("multi-GPU training is currently not supported.")
+
+    # Create an instance of the dataset and split it into training and validation sets
+    dataset = CustomDataset(root_dir=root_dir, transform=transform)
+
+    # Use 80% of the data for training and 20% for validation
+    train_size = int(0.8 * len(dataset))
+    val_size = len(dataset) - train_size
+    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
+
+    # Create data loaders for training and validation
+    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
+    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+    # Create instances of the U-Net, discriminator, and loss models
+    unet_model = UNet()
+    discriminator_model = PatchGANDiscriminator(input_channels=4)
+
+    if load_model_g:
+        unet_model.load_state_dict(torch.load(load_model_g, map_location=device))
+        print(f'loaded {load_model_g} for unet_model')
+    if load_model_d:
+        discriminator_model.load_state_dict(torch.load(load_model_d, map_location=device))
+        print(f'loaded {load_model_d} for discriminator_model')
+
+    unet_model = unet_model.to(device)
+    discriminator_model = discriminator_model.to(device)
+
+    # if multiGPU:
+    #     unet_model = nn.DataParallel(unet_model)
+    #     discriminator_model = nn.DataParallel(discriminator_model)
+
+    # Create loss instances
+    generator_loss_func = GeneratorLoss(discriminator_model, l1_weight=1.0, perceptual_weight=1.0,
+                                        adversarial_weight=0.05, device=device)
+    discriminator_loss_func = DiscriminatorLoss(discriminator_model)
+
+    # Create instances of the Adam optimizer
+    optimizer_g = optim.Adam(unet_model.parameters(), lr=lr)
+    optimizer_d = optim.Adam(discriminator_model.parameters(), lr=lr_d)
+
+    # Training and validation loop
+    best_val_loss = float('inf')
+
+    for epoch in range(start_epoch - 1, num_epochs):
+        # Training
+        unet_model.train()
+        discriminator_model.train()
+        batch_idx = 0
+        for batch in train_dataloader:
+            batch_idx += 1
+            source_images, target_images = batch
+
+            # if not multiGPU:
+            # if multi GPU, nn.DataParallel will already put the batches on the right devices.
+            # Otherwise, we do it manually
+            source_images = source_images.to(device)
+            target_images = target_images.to(device)
+
+            # Zero gradients
+            # optimizer_g.zero_grad()
+            # optimizer_d.zero_grad()
+
+            # Forward pass
+            output_images = unet_model(source_images)
+            # if multiGPU:
+            #     output_device = output_images.get_device()
+            #     source_images, target_images = source_images.to(output_device), target_images.to(output_device)
+            output_images += source_images[:, :3, :, :]
+
+            # Discriminator pass
+            discriminator_loss = discriminator_loss_func(output_images.detach(), target_images, source_images)
+            # discriminator_loss /= accum_iter
+            discriminator_loss.backward()
+
+            if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
+                optimizer_d.step()
+                optimizer_d.zero_grad()
+
+            # Generator pass
+            # Calculate the loss
+            generator_loss, l1_loss, per_loss, adv_loss = generator_loss_func(output_images, target_images,
+                                                                              source_images)
+            generator_loss, l1_loss, per_loss, adv_loss = [i / accum_iter for i in
+                                                           [generator_loss, l1_loss, per_loss, adv_loss]]
+            generator_loss.backward()
+
+            if (batch_idx % accum_iter == 0) or (batch_idx == len(train_dataloader)):
+                optimizer_g.step()
+                optimizer_g.zero_grad()
+
+            # Print training information (if needed)
+            print(
+                f'Training Epoch [{epoch + 1}/{num_epochs}], Gen Loss: {generator_loss.item()}, L1: {l1_loss.item()}, P: {per_loss.item()}, A: {adv_loss.item()}, Dis Loss: {discriminator_loss.item()}')
+            if wandb_tracking:
+                wandb.log({
+                    'Training Epoch': epoch + 1,
+                    'Gen Loss': generator_loss.item(),
+                    'L1': l1_loss.item(),
+                    'P': per_loss.item(),
+                    'A': adv_loss.item(),
+                    'Dis Loss': discriminator_loss.item()
+                })
+
+        torch.save(unet_model.state_dict(), 'recent_unet_model.pth')
+        torch.save(discriminator_model.state_dict(), 'recent_discriminator_model.pth')
+
+        # Validation
+        if epoch % val_freq == 0:
+            unet_model.eval()
+            total_val_loss = 0.0
+            with torch.no_grad():
+                for val_batch in val_dataloader:
+                    val_source_images, val_target_images = val_batch
+
+                    # if not multiGPU:
+                    # if multi GPU, nn.DataParallel will already put the batches on the right devices.
+                    # Otherwise, we do it manually
+                    val_source_images = val_source_images.to(device)
+                    val_target_images = val_target_images.to(device)
+
+                    # Forward pass
+                    val_output_images = unet_model(val_source_images)
+
+                    # if multiGPU:
+                    #     output_device = val_output_images.get_device()
+                    #     val_source_images, val_target_images = val_source_images.to(output_device), \
+                    #                                            val_target_images.to(output_device)
+
+                    # Calculate the loss
+                    generator_loss, _, _, _ = generator_loss_func(val_output_images, val_target_images,
+                                                                  val_source_images)
+                    total_val_loss += generator_loss.item()
+
+            average_val_loss = total_val_loss / len(val_dataloader)
+
+            # Print validation information
+            print(f'Validation Epoch [{epoch + 1}/{num_epochs}], Average Loss: {average_val_loss}')
+            if wandb_tracking:
+                wandb.log({
+                    'Training Epoch': epoch + 1,
+                    'Val Loss': average_val_loss,
+                })
+
+            # Save the model with the best validation loss
+            if average_val_loss < best_val_loss:
+                best_val_loss = average_val_loss
+                torch.save(unet_model.state_dict(), 'best_unet_model.pth')
+                torch.save(discriminator_model.state_dict(), 'best_discriminator_model.pth')
+
+    if wandb_tracking:
+        wandb.finish()
+
+
+if __name__ == "__main__":
+    # Define command-line arguments
+    parser = argparse.ArgumentParser(description="Training Script")
+    parser.add_argument("--root_dir", type=str, default='data/processed/train',
+                        help="Path to the training data. Note the format: to use the dataloader, the directory should be filled with folders containing image files of various ages, where the file name is the age.")
+    parser.add_argument("--start_epoch", type=int, default=0, help="Start epoch, if training is resumed")
+    parser.add_argument("--num_epochs", type=int, default=2000, help="End epoch")
+    parser.add_argument("--load_model_g", type=str, default='',
+                        help="Path to pretrained generator model. Leave blank to train from scratch")
+    parser.add_argument("--load_model_d", type=str, default='',
+                        help="Path to pretrained discriminator model. Leave blank to train from scratch")
+    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers")
+    parser.add_argument("--batch_size", type=int, default=3, help="Batch size")
+    parser.add_argument("--accum_iter", type=int, default=3, help="Number of batches after which weights are updated")
+    parser.add_argument("--val_freq", type=int, default=1, help="Validation frequency (epochs)")
+    parser.add_argument("--lr", type=float, default=0.00001, help="Learning rate for generator")
+    parser.add_argument("--lr_d", type=float, default=0.00001, help="Learning rate for discriminator")
+    parser.add_argument("--wandb_tracking", help="A binary (True/False) argument for using WandB tracking or not")
+    parser.add_argument("--desc", type=str, default='', help="Description for WandB")
+
+    # Parse command-line arguments
+    args = parser.parse_args()
+
+    # Call the training function with parsed arguments
+    train_model(args.root_dir, args.start_epoch, args.num_epochs, args.load_model_g, args.load_model_d,
+                args.num_workers, args.val_freq, args.batch_size, args.accum_iter, args.lr, args.lr_d,
+                args.wandb_tracking, args.desc)
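For reference, a hedged sketch (not part of the committed files) of the --root_dir layout described by the help text above. The folder names, the .jpg extension, and the exact pairing logic are assumptions, since utils/dataloader.py is not included in this diff.

# Assumed layout: one folder per subject, images named by age.
#
#   data/processed/train/
#       person_0001/          # hypothetical folder name
#           20.jpg            # this subject at age 20
#           45.jpg
#           80.jpg
#       person_0002/
#           15.jpg
#           60.jpg
#
# The dataset is then built exactly as in train_model() above:
from utils.dataloader import CustomDataset, transform

dataset = CustomDataset(root_dir="data/processed/train", transform=transform)
print(len(dataset))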