Spaces:
Runtime error
Runtime error
echen01
commited on
Commit
•
33dd132
1
Parent(s):
f7bf9fb
update demo
Browse files- .gitignore +3 -0
- app.py +48 -13
- checkpoints/model_gradio_demo_input.pt +2 -2
- configs/global_config.py +1 -1
- configs/hyperparameters.py +2 -2
- configs/paths_config.py +1 -3
- dnnlib/__pycache__/__init__.cpython-39.pyc +0 -0
- dnnlib/__pycache__/util.cpython-39.pyc +0 -0
- flagged/1880s/tmpqihfju9s.png +3 -0
- embeddings/2010/PTI/input/0.pt → flagged/1890s/tmphowwncle.png +2 -2
- flagged/1900s/tmpiyqo8vid.png +3 -0
- flagged/1910s/tmplsvkbwmm.png +3 -0
- flagged/1920s/tmpm88oq3m5.png +3 -0
- flagged/1930s/tmp1xg1cf_k.png +3 -0
- flagged/1940s/tmpyxtlostd.png +3 -0
- flagged/1950s/tmp00ekzg14.png +3 -0
- flagged/1960s/tmp0ptji34n.png +3 -0
- flagged/1970s/tmppty6lfdc.png +3 -0
- flagged/1980s/tmpo1znpbho.png +3 -0
- flagged/1990s/tmpkoptsgzw.png +3 -0
- flagged/2000s/tmpn85t4qg9.png +3 -0
- flagged/2010s/tmpxwhqu3bi.png +3 -0
- flagged/Cropped Input/tmpewip1qud.png +3 -0
- flagged/Input Image/tmpqdqusbe1.png +3 -0
- flagged/log.csv +2 -0
- imgs/00061_1920.png +3 -0
- imgs/cropped/input.png +2 -2
- imgs/input.png +2 -2
- run_pti.py +5 -5
- torch_utils/__pycache__/__init__.cpython-39.pyc +0 -0
- torch_utils/__pycache__/custom_ops.cpython-39.pyc +0 -0
- torch_utils/__pycache__/misc.cpython-39.pyc +0 -0
- torch_utils/__pycache__/persistence.cpython-39.pyc +0 -0
- torch_utils/ops/__pycache__/__init__.cpython-39.pyc +0 -0
- torch_utils/ops/__pycache__/bias_act.cpython-39.pyc +0 -0
- torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc +0 -0
- torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc +0 -0
- torch_utils/ops/__pycache__/fma.cpython-39.pyc +0 -0
- torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc +0 -0
- torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc +0 -0
- training/coaches/base_coach.py +4 -3
- training/coaches/multi_id_coach.py +2 -2
- training/coaches/single_id_coach.py +2 -2
- utils/models_utils.py +2 -2
.gitignore
CHANGED
@@ -158,3 +158,6 @@ cython_debug/
|
|
158 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
#.idea/
|
|
|
|
|
|
|
|
158 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
#.idea/
|
161 |
+
checkpoints/*
|
162 |
+
embeddings/*
|
163 |
+
imgs/cropped/*
|
app.py
CHANGED
@@ -5,7 +5,8 @@ import torch
|
|
5 |
import math
|
6 |
from torchvision import transforms
|
7 |
from run_pti import run_PTI
|
8 |
-
|
|
|
9 |
years = [str(y) for y in range(1880, 2020, 10)]
|
10 |
decades = [y + "s" for y in years]
|
11 |
|
@@ -32,14 +33,15 @@ def run_alignment(image_path,idx=None):
|
|
32 |
return aligned_image
|
33 |
|
34 |
def predict(inp, in_decade):
|
|
|
35 |
#with torch.no_grad():
|
36 |
inp.save("imgs/input.png")
|
37 |
inversion = run_alignment("imgs/input.png", idx=0)
|
38 |
inversion.save("imgs/cropped/input.png")
|
39 |
-
run_PTI(run_name="gradio_demo", use_wandb=False, use_multi_id_training=False)
|
40 |
#inversion = Image.open("imgs/cropped/input.png")
|
41 |
|
42 |
-
|
43 |
pti_models = {}
|
44 |
|
45 |
for year in years:
|
@@ -57,14 +59,16 @@ def predict(inp, in_decade):
|
|
57 |
p += delta
|
58 |
|
59 |
space = 0
|
60 |
-
dst = Image.new("RGB", (256 * (len(years) + 1) + (space * len(years)), 256), color='white')
|
61 |
|
62 |
|
63 |
-
w_pti = torch.load(f"embeddings/
|
64 |
|
65 |
border_width = 10
|
66 |
#fill_color = 'red'
|
67 |
-
dst.paste(inversion, (0, 0))
|
|
|
|
|
68 |
|
69 |
|
70 |
|
@@ -76,14 +80,45 @@ def predict(inp, in_decade):
|
|
76 |
# if year == in_year:
|
77 |
# img = img.crop((border_width, border_width, 256 - border_width, 256-border_width))
|
78 |
# img = PIL.ImageOps.expand(img, border=border_width, fill=fill_color)
|
79 |
-
dst.paste(img, ((256 + space) * (i+1), 0))
|
|
|
80 |
dst
|
81 |
return dst
|
82 |
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
import math
|
6 |
from torchvision import transforms
|
7 |
from run_pti import run_PTI
|
8 |
+
from configs import global_config, paths_config
|
9 |
+
device = global_config.device
|
10 |
years = [str(y) for y in range(1880, 2020, 10)]
|
11 |
decades = [y + "s" for y in years]
|
12 |
|
|
|
33 |
return aligned_image
|
34 |
|
35 |
def predict(inp, in_decade):
|
36 |
+
in_year = in_decade[:-1]
|
37 |
#with torch.no_grad():
|
38 |
inp.save("imgs/input.png")
|
39 |
inversion = run_alignment("imgs/input.png", idx=0)
|
40 |
inversion.save("imgs/cropped/input.png")
|
41 |
+
run_PTI(run_name="gradio_demo", in_year=in_year, use_wandb=False, use_multi_id_training=False)
|
42 |
#inversion = Image.open("imgs/cropped/input.png")
|
43 |
|
44 |
+
|
45 |
pti_models = {}
|
46 |
|
47 |
for year in years:
|
|
|
59 |
p += delta
|
60 |
|
61 |
space = 0
|
62 |
+
#dst = Image.new("RGB", (256 * (len(years) + 1) + (space * len(years)), 256), color='white')
|
63 |
|
64 |
|
65 |
+
w_pti = torch.load(f"embeddings/gradio/PTI/input/0.pt", map_location=device)
|
66 |
|
67 |
border_width = 10
|
68 |
#fill_color = 'red'
|
69 |
+
#dst.paste(inversion, (0, 0))
|
70 |
+
dst = []
|
71 |
+
dst.append(inversion)
|
72 |
|
73 |
|
74 |
|
|
|
80 |
# if year == in_year:
|
81 |
# img = img.crop((border_width, border_width, 256 - border_width, 256-border_width))
|
82 |
# img = PIL.ImageOps.expand(img, border=border_width, fill=fill_color)
|
83 |
+
#dst.paste(img, ((256 + space) * (i+1), 0))
|
84 |
+
dst.append(img)
|
85 |
dst
|
86 |
return dst
|
87 |
|
88 |
|
89 |
+
|
90 |
+
with gr.Blocks() as demo:
|
91 |
+
with gr.Row():
|
92 |
+
with gr.Column():
|
93 |
+
|
94 |
+
in_img = gr.Image(label="Input Image", type="pil")
|
95 |
+
in_year = gr.Dropdown(label="Input Decade", choices=decades, value="2010s")
|
96 |
+
submit = gr.Button(value="Submit")
|
97 |
+
examples = gr.Examples(examples=[["imgs/Steven-Yeun.jpg", "2010s"], ["imgs/00061_1920.png", "1920s"]], inputs=[in_img, in_year])
|
98 |
+
with gr.Column() as outs:
|
99 |
+
with gr.Row():
|
100 |
+
cropped = gr.Image(label=f"Cropped Input", type="pil").style(height=256, width=256)
|
101 |
+
out_1880 = gr.Image(label=f"1880", type="pil").style(height=256, width=256)
|
102 |
+
out_1890 = gr.Image(label=f"1890", type="pil").style(height=256, width=256)
|
103 |
+
with gr.Row():
|
104 |
+
out_1900 = gr.Image(label=f"1900", type="pil").style(height=256, width=256)
|
105 |
+
out_1910 = gr.Image(label=f"1910", type="pil").style(height=256, width=256)
|
106 |
+
out_1920 = gr.Image(label=f"1920", type="pil").style(height=256, width=256)
|
107 |
+
with gr.Row():
|
108 |
+
out_1930 = gr.Image(label=f"1930", type="pil").style(height=256, width=256)
|
109 |
+
out_1940 = gr.Image(label=f"1940", type="pil").style(height=256, width=256)
|
110 |
+
out_1950 = gr.Image(label=f"1950", type="pil").style(height=256, width=256)
|
111 |
+
with gr.Row():
|
112 |
+
out_1960 = gr.Image(label=f"1960", type="pil").style(height=256, width=256)
|
113 |
+
out_1970 = gr.Image(label=f"1970", type="pil").style(height=256, width=256)
|
114 |
+
out_1980 = gr.Image(label=f"1980", type="pil").style(height=256, width=256)
|
115 |
+
with gr.Row():
|
116 |
+
out_1990 = gr.Image(label=f"1990", type="pil").style(height=256, width=256)
|
117 |
+
out_2000 = gr.Image(label=f"2000", type="pil").style(height=256, width=256)
|
118 |
+
out_2010 = gr.Image(label=f"2010", type="pil").style(height=256, width=256)
|
119 |
+
|
120 |
+
outs = [cropped, out_1880, out_1890, out_1900, out_1910, out_1920, out_1930, out_1940, out_1950, out_1960, out_1970, out_1980, out_1990, out_2000, out_2010]
|
121 |
+
submit.click(predict, inputs=[in_img, in_year], outputs=outs)
|
122 |
+
|
123 |
+
|
124 |
+
demo.launch() #server_name="0.0.0.0", server_port=8098)
|
checkpoints/model_gradio_demo_input.pt
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0b43a87853f7d90d798348bdce0934a72d03c2127fec6d28d026e6f023063ce9
|
3 |
+
size 99869025
|
configs/global_config.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
## Device
|
2 |
-
cuda_visible_devices = "
|
3 |
device = "cpu"
|
4 |
|
5 |
## Logs
|
|
|
1 |
## Device
|
2 |
+
cuda_visible_devices = "1"
|
3 |
device = "cpu"
|
4 |
|
5 |
## Logs
|
configs/hyperparameters.py
CHANGED
@@ -12,14 +12,14 @@ regulizer_lpips_lambda = 0.1
|
|
12 |
regulizer_alpha = 30
|
13 |
|
14 |
## Loss
|
15 |
-
use_mask =
|
16 |
pt_l2_lambda = 0.7
|
17 |
pt_lpips_lambda = 1
|
18 |
color_transfer_lambda = 0 # 1e6
|
19 |
id_lambda = 1
|
20 |
|
21 |
## Steps
|
22 |
-
LPIPS_value_threshold = 0.
|
23 |
max_pti_steps = 350
|
24 |
first_inv_steps = 450
|
25 |
max_images_to_invert = 10
|
|
|
12 |
regulizer_alpha = 30
|
13 |
|
14 |
## Loss
|
15 |
+
use_mask = False
|
16 |
pt_l2_lambda = 0.7
|
17 |
pt_lpips_lambda = 1
|
18 |
color_transfer_lambda = 0 # 1e6
|
19 |
id_lambda = 1
|
20 |
|
21 |
## Steps
|
22 |
+
LPIPS_value_threshold = 0.03 # 0.06
|
23 |
max_pti_steps = 350
|
24 |
first_inv_steps = 450
|
25 |
max_images_to_invert = 10
|
configs/paths_config.py
CHANGED
@@ -4,8 +4,6 @@ year = "2010"
|
|
4 |
e4e = "./pretrained_models/e4e_ffhq_encode.pt"
|
5 |
|
6 |
|
7 |
-
stylegan2_ada_ffhq = f"pretrained_models/{year}.pkl"
|
8 |
-
|
9 |
style_clip_pretrained_mappers = ""
|
10 |
ir_se50 = "pretrained_models/model_ir_se50.pth"
|
11 |
dlib = "./pretrained_models/align.dat"
|
@@ -22,7 +20,7 @@ experiments_output_dir = "./output"
|
|
22 |
input_data_path = (
|
23 |
f"imgs/cropped"
|
24 |
)
|
25 |
-
input_data_id =
|
26 |
|
27 |
|
28 |
|
|
|
4 |
e4e = "./pretrained_models/e4e_ffhq_encode.pt"
|
5 |
|
6 |
|
|
|
|
|
7 |
style_clip_pretrained_mappers = ""
|
8 |
ir_se50 = "pretrained_models/model_ir_se50.pth"
|
9 |
dlib = "./pretrained_models/align.dat"
|
|
|
20 |
input_data_path = (
|
21 |
f"imgs/cropped"
|
22 |
)
|
23 |
+
input_data_id = "gradio"
|
24 |
|
25 |
|
26 |
|
dnnlib/__pycache__/__init__.cpython-39.pyc
CHANGED
Binary files a/dnnlib/__pycache__/__init__.cpython-39.pyc and b/dnnlib/__pycache__/__init__.cpython-39.pyc differ
|
|
dnnlib/__pycache__/util.cpython-39.pyc
CHANGED
Binary files a/dnnlib/__pycache__/util.cpython-39.pyc and b/dnnlib/__pycache__/util.cpython-39.pyc differ
|
|
flagged/1880s/tmpqihfju9s.png
ADDED
Git LFS Details
|
embeddings/2010/PTI/input/0.pt → flagged/1890s/tmphowwncle.png
RENAMED
File without changes
|
flagged/1900s/tmpiyqo8vid.png
ADDED
Git LFS Details
|
flagged/1910s/tmplsvkbwmm.png
ADDED
Git LFS Details
|
flagged/1920s/tmpm88oq3m5.png
ADDED
Git LFS Details
|
flagged/1930s/tmp1xg1cf_k.png
ADDED
Git LFS Details
|
flagged/1940s/tmpyxtlostd.png
ADDED
Git LFS Details
|
flagged/1950s/tmp00ekzg14.png
ADDED
Git LFS Details
|
flagged/1960s/tmp0ptji34n.png
ADDED
Git LFS Details
|
flagged/1970s/tmppty6lfdc.png
ADDED
Git LFS Details
|
flagged/1980s/tmpo1znpbho.png
ADDED
Git LFS Details
|
flagged/1990s/tmpkoptsgzw.png
ADDED
Git LFS Details
|
flagged/2000s/tmpn85t4qg9.png
ADDED
Git LFS Details
|
flagged/2010s/tmpxwhqu3bi.png
ADDED
Git LFS Details
|
flagged/Cropped Input/tmpewip1qud.png
ADDED
Git LFS Details
|
flagged/Input Image/tmpqdqusbe1.png
ADDED
Git LFS Details
|
flagged/log.csv
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
Input Image,Input Decade,Cropped Input,1880s,1890s,1900s,1910s,1920s,1930s,1940s,1950s,1960s,1970s,1980s,1990s,2000s,2010s,flag,username,timestamp
|
2 |
+
/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/Input Image/tmpqdqusbe1.png,1920s,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/Cropped Input/tmpewip1qud.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1880s/tmpqihfju9s.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1890s/tmphowwncle.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1900s/tmpiyqo8vid.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1910s/tmplsvkbwmm.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1920s/tmpm88oq3m5.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1930s/tmp1xg1cf_k.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1940s/tmpyxtlostd.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1950s/tmp00ekzg14.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1960s/tmp0ptji34n.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1970s/tmppty6lfdc.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1980s/tmpo1znpbho.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1990s/tmpkoptsgzw.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/2000s/tmpn85t4qg9.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/2010s/tmpxwhqu3bi.png,,,2023-06-07 16:44:52.259412
|
imgs/00061_1920.png
ADDED
Git LFS Details
|
imgs/cropped/input.png
CHANGED
Git LFS Details
|
Git LFS Details
|
imgs/input.png
CHANGED
Git LFS Details
|
Git LFS Details
|
run_pti.py
CHANGED
@@ -12,7 +12,7 @@ from training.coaches.single_id_coach import SingleIDCoach
|
|
12 |
from utils.ImagesDataset import ImagesDataset
|
13 |
|
14 |
|
15 |
-
def run_PTI(run_name="", use_wandb=False, use_multi_id_training=False):
|
16 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
17 |
os.environ["CUDA_VISIBLE_DEVICES"] = global_config.cuda_visible_devices
|
18 |
|
@@ -40,9 +40,9 @@ def run_PTI(run_name="", use_wandb=False, use_multi_id_training=False):
|
|
40 |
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
|
41 |
|
42 |
if use_multi_id_training:
|
43 |
-
coach = MultiIDCoach(dataloader, use_wandb)
|
44 |
else:
|
45 |
-
coach = SingleIDCoach(dataloader, use_wandb)
|
46 |
|
47 |
coach.train()
|
48 |
|
@@ -50,6 +50,6 @@ def run_PTI(run_name="", use_wandb=False, use_multi_id_training=False):
|
|
50 |
|
51 |
|
52 |
if __name__ == "__main__":
|
53 |
-
run_name = f"
|
54 |
print(run_name)
|
55 |
-
run_PTI(run_name=run_name, use_wandb=False, use_multi_id_training=False)
|
|
|
12 |
from utils.ImagesDataset import ImagesDataset
|
13 |
|
14 |
|
15 |
+
def run_PTI(run_name="", in_year="2010", use_wandb=False, use_multi_id_training=False):
|
16 |
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
17 |
os.environ["CUDA_VISIBLE_DEVICES"] = global_config.cuda_visible_devices
|
18 |
|
|
|
40 |
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
|
41 |
|
42 |
if use_multi_id_training:
|
43 |
+
coach = MultiIDCoach(dataloader, in_year, use_wandb)
|
44 |
else:
|
45 |
+
coach = SingleIDCoach(dataloader, in_year, use_wandb)
|
46 |
|
47 |
coach.train()
|
48 |
|
|
|
50 |
|
51 |
|
52 |
if __name__ == "__main__":
|
53 |
+
run_name = f"pti"
|
54 |
print(run_name)
|
55 |
+
run_PTI(run_name=run_name, in_year="2010", use_wandb=False, use_multi_id_training=False)
|
torch_utils/__pycache__/__init__.cpython-39.pyc
CHANGED
Binary files a/torch_utils/__pycache__/__init__.cpython-39.pyc and b/torch_utils/__pycache__/__init__.cpython-39.pyc differ
|
|
torch_utils/__pycache__/custom_ops.cpython-39.pyc
CHANGED
Binary files a/torch_utils/__pycache__/custom_ops.cpython-39.pyc and b/torch_utils/__pycache__/custom_ops.cpython-39.pyc differ
|
|
torch_utils/__pycache__/misc.cpython-39.pyc
CHANGED
Binary files a/torch_utils/__pycache__/misc.cpython-39.pyc and b/torch_utils/__pycache__/misc.cpython-39.pyc differ
|
|
torch_utils/__pycache__/persistence.cpython-39.pyc
CHANGED
Binary files a/torch_utils/__pycache__/persistence.cpython-39.pyc and b/torch_utils/__pycache__/persistence.cpython-39.pyc differ
|
|
torch_utils/ops/__pycache__/__init__.cpython-39.pyc
CHANGED
Binary files a/torch_utils/ops/__pycache__/__init__.cpython-39.pyc and b/torch_utils/ops/__pycache__/__init__.cpython-39.pyc differ
|
|
torch_utils/ops/__pycache__/bias_act.cpython-39.pyc
CHANGED
Binary files a/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc and b/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc differ
|
|
torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc
CHANGED
Binary files a/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc and b/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc differ
|
|
torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc
CHANGED
Binary files a/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc and b/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc differ
|
|
torch_utils/ops/__pycache__/fma.cpython-39.pyc
CHANGED
Binary files a/torch_utils/ops/__pycache__/fma.cpython-39.pyc and b/torch_utils/ops/__pycache__/fma.cpython-39.pyc differ
|
|
torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc
CHANGED
Binary files a/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc and b/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc differ
|
|
torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc
CHANGED
Binary files a/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc and b/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc differ
|
|
training/coaches/base_coach.py
CHANGED
@@ -23,7 +23,7 @@ import copy
|
|
23 |
|
24 |
|
25 |
class BaseCoach:
|
26 |
-
def __init__(self, data_loader, use_wandb):
|
27 |
|
28 |
self.use_wandb = use_wandb
|
29 |
self.data_loader = data_loader
|
@@ -56,6 +56,7 @@ class BaseCoach:
|
|
56 |
.to(global_config.device)
|
57 |
.eval()
|
58 |
)
|
|
|
59 |
|
60 |
if hyperparameters.use_mask:
|
61 |
self.mask = mask.Mask(device=global_config.device)
|
@@ -69,11 +70,11 @@ class BaseCoach:
|
|
69 |
def restart_training(self):
|
70 |
|
71 |
# Initialize networks
|
72 |
-
self.G = load_old_G()
|
73 |
|
74 |
toogle_grad(self.G, True)
|
75 |
|
76 |
-
self.original_G = load_old_G()
|
77 |
|
78 |
self.space_regulizer = Space_Regulizer(self.original_G, self.lpips_loss)
|
79 |
self.optimizer = self.configure_optimizers()
|
|
|
23 |
|
24 |
|
25 |
class BaseCoach:
|
26 |
+
def __init__(self, data_loader, in_year, use_wandb):
|
27 |
|
28 |
self.use_wandb = use_wandb
|
29 |
self.data_loader = data_loader
|
|
|
56 |
.to(global_config.device)
|
57 |
.eval()
|
58 |
)
|
59 |
+
self.in_year = in_year
|
60 |
|
61 |
if hyperparameters.use_mask:
|
62 |
self.mask = mask.Mask(device=global_config.device)
|
|
|
70 |
def restart_training(self):
|
71 |
|
72 |
# Initialize networks
|
73 |
+
self.G = load_old_G(self.in_year)
|
74 |
|
75 |
toogle_grad(self.G, True)
|
76 |
|
77 |
+
self.original_G = load_old_G(self.in_year)
|
78 |
|
79 |
self.space_regulizer = Space_Regulizer(self.original_G, self.lpips_loss)
|
80 |
self.optimizer = self.configure_optimizers()
|
training/coaches/multi_id_coach.py
CHANGED
@@ -10,8 +10,8 @@ from utils.log_utils import log_images_from_w
|
|
10 |
|
11 |
|
12 |
class MultiIDCoach(BaseCoach):
|
13 |
-
def __init__(self, data_loader, use_wandb):
|
14 |
-
super().__init__(data_loader, use_wandb)
|
15 |
|
16 |
def train(self):
|
17 |
self.G.synthesis.train()
|
|
|
10 |
|
11 |
|
12 |
class MultiIDCoach(BaseCoach):
|
13 |
+
def __init__(self, data_loader, in_year, use_wandb):
|
14 |
+
super().__init__(data_loader, in_year, use_wandb)
|
15 |
|
16 |
def train(self):
|
17 |
self.G.synthesis.train()
|
training/coaches/single_id_coach.py
CHANGED
@@ -9,8 +9,8 @@ import copy
|
|
9 |
|
10 |
|
11 |
class SingleIDCoach(BaseCoach):
|
12 |
-
def __init__(self, data_loader, use_wandb):
|
13 |
-
super().__init__(data_loader, use_wandb)
|
14 |
|
15 |
def train(self):
|
16 |
|
|
|
9 |
|
10 |
|
11 |
class SingleIDCoach(BaseCoach):
|
12 |
+
def __init__(self, data_loader, in_year, use_wandb):
|
13 |
+
super().__init__(data_loader, in_year, use_wandb)
|
14 |
|
15 |
def train(self):
|
16 |
|
utils/models_utils.py
CHANGED
@@ -18,8 +18,8 @@ def load_tuned_G(run_id, type):
|
|
18 |
return new_G
|
19 |
|
20 |
|
21 |
-
def load_old_G():
|
22 |
-
with open(
|
23 |
old_G = pickle.load(f)['G_ema'].to(global_config.device).eval()
|
24 |
old_G = old_G.float()
|
25 |
return old_G
|
|
|
18 |
return new_G
|
19 |
|
20 |
|
21 |
+
def load_old_G(in_year):
|
22 |
+
with open(f"pretrained_models/{in_year}.pkl", 'rb') as f:
|
23 |
old_G = pickle.load(f)['G_ema'].to(global_config.device).eval()
|
24 |
old_G = old_G.float()
|
25 |
return old_G
|