echen01 committed
Commit 33dd132
1 Parent(s): f7bf9fb

update demo

Files changed (44):
  1. .gitignore +3 -0
  2. app.py +48 -13
  3. checkpoints/model_gradio_demo_input.pt +2 -2
  4. configs/global_config.py +1 -1
  5. configs/hyperparameters.py +2 -2
  6. configs/paths_config.py +1 -3
  7. dnnlib/__pycache__/__init__.cpython-39.pyc +0 -0
  8. dnnlib/__pycache__/util.cpython-39.pyc +0 -0
  9. flagged/1880s/tmpqihfju9s.png +3 -0
  10. embeddings/2010/PTI/input/0.pt → flagged/1890s/tmphowwncle.png +2 -2
  11. flagged/1900s/tmpiyqo8vid.png +3 -0
  12. flagged/1910s/tmplsvkbwmm.png +3 -0
  13. flagged/1920s/tmpm88oq3m5.png +3 -0
  14. flagged/1930s/tmp1xg1cf_k.png +3 -0
  15. flagged/1940s/tmpyxtlostd.png +3 -0
  16. flagged/1950s/tmp00ekzg14.png +3 -0
  17. flagged/1960s/tmp0ptji34n.png +3 -0
  18. flagged/1970s/tmppty6lfdc.png +3 -0
  19. flagged/1980s/tmpo1znpbho.png +3 -0
  20. flagged/1990s/tmpkoptsgzw.png +3 -0
  21. flagged/2000s/tmpn85t4qg9.png +3 -0
  22. flagged/2010s/tmpxwhqu3bi.png +3 -0
  23. flagged/Cropped Input/tmpewip1qud.png +3 -0
  24. flagged/Input Image/tmpqdqusbe1.png +3 -0
  25. flagged/log.csv +2 -0
  26. imgs/00061_1920.png +3 -0
  27. imgs/cropped/input.png +2 -2
  28. imgs/input.png +2 -2
  29. run_pti.py +5 -5
  30. torch_utils/__pycache__/__init__.cpython-39.pyc +0 -0
  31. torch_utils/__pycache__/custom_ops.cpython-39.pyc +0 -0
  32. torch_utils/__pycache__/misc.cpython-39.pyc +0 -0
  33. torch_utils/__pycache__/persistence.cpython-39.pyc +0 -0
  34. torch_utils/ops/__pycache__/__init__.cpython-39.pyc +0 -0
  35. torch_utils/ops/__pycache__/bias_act.cpython-39.pyc +0 -0
  36. torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc +0 -0
  37. torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc +0 -0
  38. torch_utils/ops/__pycache__/fma.cpython-39.pyc +0 -0
  39. torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc +0 -0
  40. torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc +0 -0
  41. training/coaches/base_coach.py +4 -3
  42. training/coaches/multi_id_coach.py +2 -2
  43. training/coaches/single_id_coach.py +2 -2
  44. utils/models_utils.py +2 -2
.gitignore CHANGED
@@ -158,3 +158,6 @@ cython_debug/
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
+checkpoints/*
+embeddings/*
+imgs/cropped/*
app.py CHANGED
@@ -5,7 +5,8 @@ import torch
 import math
 from torchvision import transforms
 from run_pti import run_PTI
-device = "cpu"
+from configs import global_config, paths_config
+device = global_config.device
 years = [str(y) for y in range(1880, 2020, 10)]
 decades = [y + "s" for y in years]
 
@@ -32,14 +33,15 @@ def run_alignment(image_path,idx=None):
     return aligned_image
 
 def predict(inp, in_decade):
+    in_year = in_decade[:-1]
     #with torch.no_grad():
     inp.save("imgs/input.png")
     inversion = run_alignment("imgs/input.png", idx=0)
     inversion.save("imgs/cropped/input.png")
-    run_PTI(run_name="gradio_demo", use_wandb=False, use_multi_id_training=False)
+    run_PTI(run_name="gradio_demo", in_year=in_year, use_wandb=False, use_multi_id_training=False)
     #inversion = Image.open("imgs/cropped/input.png")
 
-    in_year = in_decade[:-1]
+
     pti_models = {}
 
     for year in years:
@@ -57,14 +59,16 @@ def predict(inp, in_decade):
         p += delta
 
     space = 0
-    dst = Image.new("RGB", (256 * (len(years) + 1) + (space * len(years)), 256), color='white')
+    #dst = Image.new("RGB", (256 * (len(years) + 1) + (space * len(years)), 256), color='white')
 
 
-    w_pti = torch.load(f"embeddings/{in_year}/PTI/input/0.pt", map_location=device)
+    w_pti = torch.load(f"embeddings/gradio/PTI/input/0.pt", map_location=device)
 
     border_width = 10
     #fill_color = 'red'
-    dst.paste(inversion, (0, 0))
+    #dst.paste(inversion, (0, 0))
+    dst = []
+    dst.append(inversion)
 
 
 
@@ -76,14 +80,45 @@ def predict(inp, in_decade):
         # if year == in_year:
        #     img = img.crop((border_width, border_width, 256 - border_width, 256-border_width))
        #     img = PIL.ImageOps.expand(img, border=border_width, fill=fill_color)
-        dst.paste(img, ((256 + space) * (i+1), 0))
+        #dst.paste(img, ((256 + space) * (i+1), 0))
+        dst.append(img)
     dst
     return dst
 
 
-gr.Interface(fn=predict,
-             inputs=[gr.Image(label="Input Image", type="pil"), gr.Dropdown(label="Input Decade", choices=decades, value="2010s")],
-             outputs=gr.Image(label="Decade Transformations", type="pil"),
-             examples=[["imgs/Steven-Yeun.jpg", "2010s"]]
-
-             ).launch() #.launch(server_name="0.0.0.0", server_port=8098)
+
+with gr.Blocks() as demo:
+    with gr.Row():
+        with gr.Column():
+
+            in_img = gr.Image(label="Input Image", type="pil")
+            in_year = gr.Dropdown(label="Input Decade", choices=decades, value="2010s")
+            submit = gr.Button(value="Submit")
+            examples = gr.Examples(examples=[["imgs/Steven-Yeun.jpg", "2010s"], ["imgs/00061_1920.png", "1920s"]], inputs=[in_img, in_year])
+        with gr.Column() as outs:
+            with gr.Row():
+                cropped = gr.Image(label=f"Cropped Input", type="pil").style(height=256, width=256)
+                out_1880 = gr.Image(label=f"1880", type="pil").style(height=256, width=256)
+                out_1890 = gr.Image(label=f"1890", type="pil").style(height=256, width=256)
+            with gr.Row():
+                out_1900 = gr.Image(label=f"1900", type="pil").style(height=256, width=256)
+                out_1910 = gr.Image(label=f"1910", type="pil").style(height=256, width=256)
+                out_1920 = gr.Image(label=f"1920", type="pil").style(height=256, width=256)
+            with gr.Row():
+                out_1930 = gr.Image(label=f"1930", type="pil").style(height=256, width=256)
+                out_1940 = gr.Image(label=f"1940", type="pil").style(height=256, width=256)
+                out_1950 = gr.Image(label=f"1950", type="pil").style(height=256, width=256)
+            with gr.Row():
+                out_1960 = gr.Image(label=f"1960", type="pil").style(height=256, width=256)
+                out_1970 = gr.Image(label=f"1970", type="pil").style(height=256, width=256)
+                out_1980 = gr.Image(label=f"1980", type="pil").style(height=256, width=256)
+            with gr.Row():
+                out_1990 = gr.Image(label=f"1990", type="pil").style(height=256, width=256)
+                out_2000 = gr.Image(label=f"2000", type="pil").style(height=256, width=256)
+                out_2010 = gr.Image(label=f"2010", type="pil").style(height=256, width=256)
+
+    outs = [cropped, out_1880, out_1890, out_1900, out_1910, out_1920, out_1930, out_1940, out_1950, out_1960, out_1970, out_1980, out_1990, out_2000, out_2010]
+    submit.click(predict, inputs=[in_img, in_year], outputs=outs)
+
+
+demo.launch() #server_name="0.0.0.0", server_port=8098)
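Note on the new UI wiring in app.py: the demo moves from a single gr.Interface call to a gr.Blocks layout, and predict now returns a list of images instead of pasting them onto one canvas; submit.click maps each returned value to one output component, in order. A minimal runnable sketch of that pattern, assuming Gradio 3.x (the .style(height=..., width=...) calls in the diff are Gradio 3.x API) and using a placeholder predict rather than the demo's real alignment + PTI pipeline:

import gradio as gr

decades = [f"{y}s" for y in range(1880, 2020, 10)]  # "1880s" ... "2010s"

def predict(inp, in_decade):
    # Placeholder: the real demo crops the face, tunes per-decade generators with PTI,
    # and returns 15 PIL images (cropped input followed by one image per decade).
    return [inp] * 15

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            in_img = gr.Image(label="Input Image", type="pil")
            in_year = gr.Dropdown(label="Input Decade", choices=decades, value="2010s")
            submit = gr.Button(value="Submit")
        with gr.Column():
            # One output slot per value returned by predict; click() maps them positionally.
            outs = [gr.Image(type="pil") for _ in range(15)]
    submit.click(predict, inputs=[in_img, in_year], outputs=outs)

if __name__ == "__main__":
    demo.launch()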
checkpoints/model_gradio_demo_input.pt CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:65ee5644ec8ab0966a4eb51971995c071f8178db765750d45e32e0ed18a09738
-size 99867041
+oid sha256:0b43a87853f7d90d798348bdce0934a72d03c2127fec6d28d026e6f023063ce9
+size 99869025
configs/global_config.py CHANGED
@@ -1,5 +1,5 @@
 ## Device
-cuda_visible_devices = "0"
+cuda_visible_devices = "1"
 device = "cpu"
 
 ## Logs
configs/hyperparameters.py CHANGED
@@ -12,14 +12,14 @@ regulizer_lpips_lambda = 0.1
 regulizer_alpha = 30
 
 ## Loss
-use_mask = True
+use_mask = False
 pt_l2_lambda = 0.7
 pt_lpips_lambda = 1
 color_transfer_lambda = 0 # 1e6
 id_lambda = 1
 
 ## Steps
-LPIPS_value_threshold = 0.01 # 0.06
+LPIPS_value_threshold = 0.03 # 0.06
 max_pti_steps = 350
 first_inv_steps = 450
 max_images_to_invert = 10
configs/paths_config.py CHANGED
@@ -4,8 +4,6 @@ year = "2010"
 e4e = "./pretrained_models/e4e_ffhq_encode.pt"
 
 
-stylegan2_ada_ffhq = f"pretrained_models/{year}.pkl"
-
 style_clip_pretrained_mappers = ""
 ir_se50 = "pretrained_models/model_ir_se50.pth"
 dlib = "./pretrained_models/align.dat"
@@ -22,7 +20,7 @@ experiments_output_dir = "./output"
 input_data_path = (
     f"imgs/cropped"
 )
-input_data_id = f"{year}"
+input_data_id = "gradio"
 
 
 
dnnlib/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/dnnlib/__pycache__/__init__.cpython-39.pyc and b/dnnlib/__pycache__/__init__.cpython-39.pyc differ
 
dnnlib/__pycache__/util.cpython-39.pyc CHANGED
Binary files a/dnnlib/__pycache__/util.cpython-39.pyc and b/dnnlib/__pycache__/util.cpython-39.pyc differ
 
flagged/1880s/tmpqihfju9s.png ADDED
Git LFS Details

  • SHA256: 17bae676f7b0b6a5dd2e9fafb734068213b698740ac819c7141c8421115f052c
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
embeddings/2010/PTI/input/0.pt → flagged/1890s/tmphowwncle.png RENAMED
File without changes
flagged/1900s/tmpiyqo8vid.png ADDED
Git LFS Details

  • SHA256: e5999133f20362fc1773c28069e60ae7ea0b9eac0ab6ff7e90422fc8b9e28f75
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB
flagged/1910s/tmplsvkbwmm.png ADDED
Git LFS Details

  • SHA256: f8714b7ed0cea3ef837ddd02dcd462dcddbef3b212dad29bf494d6ea64125d74
  • Pointer size: 130 Bytes
  • Size of remote file: 85.7 kB
flagged/1920s/tmpm88oq3m5.png ADDED
Git LFS Details

  • SHA256: a74da6dd9e66dd8467a90509001c1b2547f96faa1fae26eae80acabf57d1055c
  • Pointer size: 130 Bytes
  • Size of remote file: 95.8 kB
flagged/1930s/tmp1xg1cf_k.png ADDED
Git LFS Details

  • SHA256: 003e137999b2172edc03b3908c22881194842aa1c54d86059b63b142ce95161d
  • Pointer size: 130 Bytes
  • Size of remote file: 96.2 kB
flagged/1940s/tmpyxtlostd.png ADDED
Git LFS Details

  • SHA256: d73414199aaf1eb2e7b8cad5b175dde899954754150f5b67338723830852b7bc
  • Pointer size: 130 Bytes
  • Size of remote file: 90.2 kB
flagged/1950s/tmp00ekzg14.png ADDED
Git LFS Details

  • SHA256: bccbbe1683c1403dba761417691f4507f95f13a46d2ee36c4421605454a82d4f
  • Pointer size: 130 Bytes
  • Size of remote file: 89.1 kB
flagged/1960s/tmp0ptji34n.png ADDED
Git LFS Details

  • SHA256: 5334bdbc823e859a79fd81bf4c12c93af2752ef665f2beea8ea5d8adf362a3ee
  • Pointer size: 131 Bytes
  • Size of remote file: 110 kB
flagged/1970s/tmppty6lfdc.png ADDED
Git LFS Details

  • SHA256: 946fefa001dd05e67a738094b13a1134abaf214023c3e792dcdc09053bcd9af1
  • Pointer size: 130 Bytes
  • Size of remote file: 93.3 kB
flagged/1980s/tmpo1znpbho.png ADDED
Git LFS Details

  • SHA256: 509a761d7b32d4931e6d6eeeec2bd29ef5e2d9a63e8805e0e75e267505722cfe
  • Pointer size: 130 Bytes
  • Size of remote file: 98.2 kB
flagged/1990s/tmpkoptsgzw.png ADDED
Git LFS Details

  • SHA256: 828bd9b28a0c92b4880567aed2dd4934c148776826840b6d5a5f19a846248a5e
  • Pointer size: 130 Bytes
  • Size of remote file: 85.7 kB
flagged/2000s/tmpn85t4qg9.png ADDED
Git LFS Details

  • SHA256: 0c7b2d14d835c80a51281858ea2752459173b366ec856262feec6196e3148110
  • Pointer size: 131 Bytes
  • Size of remote file: 104 kB
flagged/2010s/tmpxwhqu3bi.png ADDED
Git LFS Details

  • SHA256: 78879d72514a82c14ca4f189ce61c30394c8054eebabb7525cc7c114f5a5b581
  • Pointer size: 130 Bytes
  • Size of remote file: 99.5 kB
flagged/Cropped Input/tmpewip1qud.png ADDED
Git LFS Details

  • SHA256: e45fca429a0605b5b8686d5a5395a113e727deceda0f8647d46fe53120e90431
  • Pointer size: 130 Bytes
  • Size of remote file: 67.6 kB
flagged/Input Image/tmpqdqusbe1.png ADDED
Git LFS Details

  • SHA256: e99509f3093eaa1f9b15a39ee91b8db5c5bedb713838820f416b0edb58f04645
  • Pointer size: 131 Bytes
  • Size of remote file: 707 kB
flagged/log.csv ADDED
@@ -0,0 +1,2 @@
+ Input Image,Input Decade,Cropped Input,1880s,1890s,1900s,1910s,1920s,1930s,1940s,1950s,1960s,1970s,1980s,1990s,2000s,2010s,flag,username,timestamp
+ /share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/Input Image/tmpqdqusbe1.png,1920s,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/Cropped Input/tmpewip1qud.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1880s/tmpqihfju9s.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1890s/tmphowwncle.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1900s/tmpiyqo8vid.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1910s/tmplsvkbwmm.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1920s/tmpm88oq3m5.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1930s/tmp1xg1cf_k.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1940s/tmpyxtlostd.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1950s/tmp00ekzg14.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1960s/tmp0ptji34n.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1970s/tmppty6lfdc.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1980s/tmpo1znpbho.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/1990s/tmpkoptsgzw.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/2000s/tmpn85t4qg9.png,/share/phoenix/nfs04/S7/emc348/faces-through-time/flagged/2010s/tmpxwhqu3bi.png,,,2023-06-07 16:44:52.259412
imgs/00061_1920.png ADDED
Git LFS Details

  • SHA256: 8ebc10d6fc4142e993c64b4f417ea01468336af9f938e0bc565311dd74b210ee
  • Pointer size: 130 Bytes
  • Size of remote file: 88.2 kB
imgs/cropped/input.png CHANGED
Git LFS Details (before)

  • SHA256: ba7b8df0bffe226c723eb22c537e66ff9de844e6aae7845a6c88e696f03b6a40
  • Pointer size: 131 Bytes
  • Size of remote file: 105 kB
Git LFS Details (after)

  • SHA256: b51b0109bc1afa021e4b603585a7c36a6f6d30c7c7c220cc3b9bb28c1bb728d2
  • Pointer size: 130 Bytes
  • Size of remote file: 77 kB
imgs/input.png CHANGED
Git LFS Details (before)

  • SHA256: 3f8c1b42d80f44efcf0cb03a301072284a0b7ad6ae6f11871be44b7fae79613e
  • Pointer size: 133 Bytes
  • Size of remote file: 13.7 MB
Git LFS Details (after)

  • SHA256: bd037d2906ebd271abe73ea02a8762aee9e341c88d5a99a45039eb154dafd40d
  • Pointer size: 130 Bytes
  • Size of remote file: 89 kB
run_pti.py CHANGED
@@ -12,7 +12,7 @@ from training.coaches.single_id_coach import SingleIDCoach
 from utils.ImagesDataset import ImagesDataset
 
 
-def run_PTI(run_name="", use_wandb=False, use_multi_id_training=False):
+def run_PTI(run_name="", in_year="2010", use_wandb=False, use_multi_id_training=False):
     os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
     os.environ["CUDA_VISIBLE_DEVICES"] = global_config.cuda_visible_devices
 
@@ -40,9 +40,9 @@ def run_PTI(run_name="", use_wandb=False, use_multi_id_training=False):
     dataloader = DataLoader(dataset, batch_size=1, shuffle=False)
 
     if use_multi_id_training:
-        coach = MultiIDCoach(dataloader, use_wandb)
+        coach = MultiIDCoach(dataloader, in_year, use_wandb)
     else:
-        coach = SingleIDCoach(dataloader, use_wandb)
+        coach = SingleIDCoach(dataloader, in_year, use_wandb)
 
     coach.train()
 
@@ -50,6 +50,6 @@ def run_PTI(run_name="", use_wandb=False, use_multi_id_training=False):
 
 
 if __name__ == "__main__":
-    run_name = f"pti_{paths_config.year}"
+    run_name = f"pti"
     print(run_name)
-    run_PTI(run_name=run_name, use_wandb=False, use_multi_id_training=False)
+    run_PTI(run_name=run_name, in_year="2010", use_wandb=False, use_multi_id_training=False)
torch_utils/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/torch_utils/__pycache__/__init__.cpython-39.pyc and b/torch_utils/__pycache__/__init__.cpython-39.pyc differ
 
torch_utils/__pycache__/custom_ops.cpython-39.pyc CHANGED
Binary files a/torch_utils/__pycache__/custom_ops.cpython-39.pyc and b/torch_utils/__pycache__/custom_ops.cpython-39.pyc differ
 
torch_utils/__pycache__/misc.cpython-39.pyc CHANGED
Binary files a/torch_utils/__pycache__/misc.cpython-39.pyc and b/torch_utils/__pycache__/misc.cpython-39.pyc differ
 
torch_utils/__pycache__/persistence.cpython-39.pyc CHANGED
Binary files a/torch_utils/__pycache__/persistence.cpython-39.pyc and b/torch_utils/__pycache__/persistence.cpython-39.pyc differ
 
torch_utils/ops/__pycache__/__init__.cpython-39.pyc CHANGED
Binary files a/torch_utils/ops/__pycache__/__init__.cpython-39.pyc and b/torch_utils/ops/__pycache__/__init__.cpython-39.pyc differ
 
torch_utils/ops/__pycache__/bias_act.cpython-39.pyc CHANGED
Binary files a/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc and b/torch_utils/ops/__pycache__/bias_act.cpython-39.pyc differ
 
torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc CHANGED
Binary files a/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc and b/torch_utils/ops/__pycache__/conv2d_gradfix.cpython-39.pyc differ
 
torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc CHANGED
Binary files a/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc and b/torch_utils/ops/__pycache__/conv2d_resample.cpython-39.pyc differ
 
torch_utils/ops/__pycache__/fma.cpython-39.pyc CHANGED
Binary files a/torch_utils/ops/__pycache__/fma.cpython-39.pyc and b/torch_utils/ops/__pycache__/fma.cpython-39.pyc differ
 
torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc CHANGED
Binary files a/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc and b/torch_utils/ops/__pycache__/grid_sample_gradfix.cpython-39.pyc differ
 
torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc CHANGED
Binary files a/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc and b/torch_utils/ops/__pycache__/upfirdn2d.cpython-39.pyc differ
 
training/coaches/base_coach.py CHANGED
@@ -23,7 +23,7 @@ import copy
 
 
 class BaseCoach:
-    def __init__(self, data_loader, use_wandb):
+    def __init__(self, data_loader, in_year, use_wandb):
 
         self.use_wandb = use_wandb
         self.data_loader = data_loader
@@ -56,6 +56,7 @@ class BaseCoach:
            .to(global_config.device)
            .eval()
        )
+        self.in_year = in_year
 
        if hyperparameters.use_mask:
            self.mask = mask.Mask(device=global_config.device)
@@ -69,11 +70,11 @@ class BaseCoach:
    def restart_training(self):
 
        # Initialize networks
-       self.G = load_old_G()
+       self.G = load_old_G(self.in_year)
 
        toogle_grad(self.G, True)
 
-       self.original_G = load_old_G()
+       self.original_G = load_old_G(self.in_year)
 
        self.space_regulizer = Space_Regulizer(self.original_G, self.lpips_loss)
        self.optimizer = self.configure_optimizers()
training/coaches/multi_id_coach.py CHANGED
@@ -10,8 +10,8 @@ from utils.log_utils import log_images_from_w
 
 
 class MultiIDCoach(BaseCoach):
-    def __init__(self, data_loader, use_wandb):
-        super().__init__(data_loader, use_wandb)
+    def __init__(self, data_loader, in_year, use_wandb):
+        super().__init__(data_loader, in_year, use_wandb)
 
     def train(self):
         self.G.synthesis.train()
training/coaches/single_id_coach.py CHANGED
@@ -9,8 +9,8 @@ import copy
 
 
 class SingleIDCoach(BaseCoach):
-    def __init__(self, data_loader, use_wandb):
-        super().__init__(data_loader, use_wandb)
+    def __init__(self, data_loader, in_year, use_wandb):
+        super().__init__(data_loader, in_year, use_wandb)
 
     def train(self):
 
utils/models_utils.py CHANGED
@@ -18,8 +18,8 @@ def load_tuned_G(run_id, type):
     return new_G
 
 
-def load_old_G():
-    with open(paths_config.stylegan2_ada_ffhq, 'rb') as f:
+def load_old_G(in_year):
+    with open(f"pretrained_models/{in_year}.pkl", 'rb') as f:
        old_G = pickle.load(f)['G_ema'].to(global_config.device).eval()
        old_G = old_G.float()
        return old_G
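Taken together with run_pti.py, base_coach.py, and the coach subclasses, this change threads the decade selected in the UI down to generator loading: predict passes in_year to run_PTI, the coach stores it, and load_old_G(in_year) opens pretrained_models/{in_year}.pkl. A hedged usage sketch (assumes the per-decade generator pickles such as pretrained_models/1920.pkl exist, as implied by the paths in this diff):

from run_pti import run_PTI

# "1920" makes load_old_G open pretrained_models/1920.pkl before tuning.
run_PTI(run_name="gradio_demo", in_year="1920",
        use_wandb=False, use_multi_id_training=False)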