Upload 10 files
- src/cyclegan_turbo.py +254 -0
- src/image_prep.py +12 -0
- src/inference_paired.py +65 -0
- src/inference_unpaired.py +53 -0
- src/model.py +73 -0
- src/my_utils/dino_struct.py +185 -0
- src/my_utils/training_utils.py +409 -0
- src/pix2pix_turbo.py +227 -0
- src/train_cyclegan_turbo.py +389 -0
- src/train_pix2pix_turbo.py +307 -0
src/cyclegan_turbo.py
ADDED
@@ -0,0 +1,254 @@
import os
import sys
import copy
import torch
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
p = "src/"
sys.path.append(p)
from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd, download_url


class VAE_encode(nn.Module):
    def __init__(self, vae, vae_b2a=None):
        super(VAE_encode, self).__init__()
        self.vae = vae
        self.vae_b2a = vae_b2a

    def forward(self, x, direction):
        assert direction in ["a2b", "b2a"]
        if direction == "a2b":
            _vae = self.vae
        else:
            _vae = self.vae_b2a
        return _vae.encode(x).latent_dist.sample() * _vae.config.scaling_factor


class VAE_decode(nn.Module):
    def __init__(self, vae, vae_b2a=None):
        super(VAE_decode, self).__init__()
        self.vae = vae
        self.vae_b2a = vae_b2a

    def forward(self, x, direction):
        assert direction in ["a2b", "b2a"]
        if direction == "a2b":
            _vae = self.vae
        else:
            _vae = self.vae_b2a
        assert _vae.encoder.current_down_blocks is not None
        _vae.decoder.incoming_skip_acts = _vae.encoder.current_down_blocks
        x_decoded = (_vae.decode(x / _vae.config.scaling_factor).sample).clamp(-1, 1)
        return x_decoded


def initialize_unet(rank, return_lora_module_names=False):
    unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
    unet.requires_grad_(False)
    unet.train()
    l_target_modules_encoder, l_target_modules_decoder, l_modules_others = [], [], []
    l_grep = ["to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_in", "conv_shortcut", "conv_out", "proj_out", "proj_in", "ff.net.2", "ff.net.0.proj"]
    for n, p in unet.named_parameters():
        if "bias" in n or "norm" in n: continue
        for pattern in l_grep:
            if pattern in n and ("down_blocks" in n or "conv_in" in n):
                l_target_modules_encoder.append(n.replace(".weight", ""))
                break
            elif pattern in n and "up_blocks" in n:
                l_target_modules_decoder.append(n.replace(".weight", ""))
                break
            elif pattern in n:
                l_modules_others.append(n.replace(".weight", ""))
                break
    lora_conf_encoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_target_modules_encoder, lora_alpha=rank)
    lora_conf_decoder = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_target_modules_decoder, lora_alpha=rank)
    lora_conf_others = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_modules_others, lora_alpha=rank)
    unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
    unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
    unet.add_adapter(lora_conf_others, adapter_name="default_others")
    unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
    if return_lora_module_names:
        return unet, l_target_modules_encoder, l_target_modules_decoder, l_modules_others
    else:
        return unet


def initialize_vae(rank=4, return_lora_module_names=False):
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
    vae.requires_grad_(False)
    vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
    vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
    vae.requires_grad_(True)
    vae.train()
    # add the skip connection convs
    vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
    vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
    vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
    vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda().requires_grad_(True)
    torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5)
    torch.nn.init.constant_(vae.decoder.skip_conv_2.weight, 1e-5)
    torch.nn.init.constant_(vae.decoder.skip_conv_3.weight, 1e-5)
    torch.nn.init.constant_(vae.decoder.skip_conv_4.weight, 1e-5)
    vae.decoder.ignore_skip = False
    vae.decoder.gamma = 1
    l_vae_target_modules = ["conv1", "conv2", "conv_in", "conv_shortcut",
        "conv", "conv_out", "skip_conv_1", "skip_conv_2", "skip_conv_3",
        "skip_conv_4", "to_k", "to_q", "to_v", "to_out.0",
    ]
    vae_lora_config = LoraConfig(r=rank, init_lora_weights="gaussian", target_modules=l_vae_target_modules)
    vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
    if return_lora_module_names:
        return vae, l_vae_target_modules
    else:
        return vae


class CycleGAN_Turbo(torch.nn.Module):
    def __init__(self, pretrained_name=None, pretrained_path=None, ckpt_folder="checkpoints", lora_rank_unet=8, lora_rank_vae=4):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
        self.sched = make_1step_sched()
        vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")
        vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
        vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
        # add the skip connection convs
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.ignore_skip = False
        self.unet, self.vae = unet, vae
        if pretrained_name == "day_to_night":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/day2night.pkl"
            self.load_ckpt_from_url(url, ckpt_folder)
            self.timesteps = torch.tensor([999], device="cuda").long()
            self.caption = "driving in the night"
            self.direction = "a2b"
        elif pretrained_name == "night_to_day":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/night2day.pkl"
            self.load_ckpt_from_url(url, ckpt_folder)
            self.timesteps = torch.tensor([999], device="cuda").long()
            self.caption = "driving in the day"
            self.direction = "b2a"
        elif pretrained_name == "clear_to_rainy":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/clear2rainy.pkl"
            self.load_ckpt_from_url(url, ckpt_folder)
            self.timesteps = torch.tensor([999], device="cuda").long()
            self.caption = "driving in heavy rain"
            self.direction = "a2b"
        elif pretrained_name == "rainy_to_clear":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/rainy2clear.pkl"
            self.load_ckpt_from_url(url, ckpt_folder)
            self.timesteps = torch.tensor([999], device="cuda").long()
            self.caption = "driving in the day"
            self.direction = "b2a"

        elif pretrained_path is not None:
            sd = torch.load(pretrained_path)
            self.load_ckpt_from_state_dict(sd)
            self.timesteps = torch.tensor([999], device="cuda").long()
            self.caption = None
            self.direction = None

        self.vae_enc.cuda()
        self.vae_dec.cuda()
        self.unet.cuda()

    def load_ckpt_from_state_dict(self, sd):
        lora_conf_encoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_encoder"], lora_alpha=sd["rank_unet"])
        lora_conf_decoder = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_target_modules_decoder"], lora_alpha=sd["rank_unet"])
        lora_conf_others = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["l_modules_others"], lora_alpha=sd["rank_unet"])
        self.unet.add_adapter(lora_conf_encoder, adapter_name="default_encoder")
        self.unet.add_adapter(lora_conf_decoder, adapter_name="default_decoder")
        self.unet.add_adapter(lora_conf_others, adapter_name="default_others")
        for n, p in self.unet.named_parameters():
            name_sd = n.replace(".default_encoder.weight", ".weight")
            if "lora" in n and "default_encoder" in n:
                p.data.copy_(sd["sd_encoder"][name_sd])
        for n, p in self.unet.named_parameters():
            name_sd = n.replace(".default_decoder.weight", ".weight")
            if "lora" in n and "default_decoder" in n:
                p.data.copy_(sd["sd_decoder"][name_sd])
        for n, p in self.unet.named_parameters():
            name_sd = n.replace(".default_others.weight", ".weight")
            if "lora" in n and "default_others" in n:
                p.data.copy_(sd["sd_other"][name_sd])
        self.unet.set_adapter(["default_encoder", "default_decoder", "default_others"])

        vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
        self.vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
        self.vae.decoder.gamma = 1
        self.vae_b2a = copy.deepcopy(self.vae)
        self.vae_enc = VAE_encode(self.vae, vae_b2a=self.vae_b2a)
        self.vae_enc.load_state_dict(sd["sd_vae_enc"])
        self.vae_dec = VAE_decode(self.vae, vae_b2a=self.vae_b2a)
        self.vae_dec.load_state_dict(sd["sd_vae_dec"])

    def load_ckpt_from_url(self, url, ckpt_folder):
        os.makedirs(ckpt_folder, exist_ok=True)
        outf = os.path.join(ckpt_folder, os.path.basename(url))
        download_url(url, outf)
        sd = torch.load(outf)
        self.load_ckpt_from_state_dict(sd)

    @staticmethod
    def forward_with_networks(x, direction, vae_enc, unet, vae_dec, sched, timesteps, text_emb):
        B = x.shape[0]
        assert direction in ["a2b", "b2a"]
        x_enc = vae_enc(x, direction=direction).to(x.dtype)
        model_pred = unet(x_enc, timesteps, encoder_hidden_states=text_emb,).sample
        x_out = torch.stack([sched.step(model_pred[i], timesteps[i], x_enc[i], return_dict=True).prev_sample for i in range(B)])
        x_out_decoded = vae_dec(x_out, direction=direction)
        return x_out_decoded

    @staticmethod
    def get_traininable_params(unet, vae_a2b, vae_b2a):
        # add all unet parameters
        params_gen = list(unet.conv_in.parameters())
        unet.conv_in.requires_grad_(True)
        unet.set_adapters(["default_encoder", "default_decoder", "default_others"])
        for n, p in unet.named_parameters():
            if "lora" in n and "default" in n:
                assert p.requires_grad
                params_gen.append(p)

        # add all vae_a2b parameters
        for n, p in vae_a2b.named_parameters():
            if "lora" in n and "vae_skip" in n:
                assert p.requires_grad
                params_gen.append(p)
        params_gen = params_gen + list(vae_a2b.decoder.skip_conv_1.parameters())
        params_gen = params_gen + list(vae_a2b.decoder.skip_conv_2.parameters())
        params_gen = params_gen + list(vae_a2b.decoder.skip_conv_3.parameters())
        params_gen = params_gen + list(vae_a2b.decoder.skip_conv_4.parameters())

        # add all vae_b2a parameters
        for n, p in vae_b2a.named_parameters():
            if "lora" in n and "vae_skip" in n:
                assert p.requires_grad
                params_gen.append(p)
        params_gen = params_gen + list(vae_b2a.decoder.skip_conv_1.parameters())
        params_gen = params_gen + list(vae_b2a.decoder.skip_conv_2.parameters())
        params_gen = params_gen + list(vae_b2a.decoder.skip_conv_3.parameters())
        params_gen = params_gen + list(vae_b2a.decoder.skip_conv_4.parameters())
        return params_gen

    def forward(self, x_t, direction=None, caption=None, caption_emb=None):
        if direction is None:
            assert self.direction is not None
            direction = self.direction
        if caption is None and caption_emb is None:
            assert self.caption is not None
            caption = self.caption
        if caption_emb is not None:
            caption_enc = caption_emb
        else:
            caption_tokens = self.tokenizer(caption, max_length=self.tokenizer.model_max_length,
                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.to(x_t.device)
            caption_enc = self.text_encoder(caption_tokens)[0].detach().clone()
        return self.forward_with_networks(x_t, direction, self.vae_enc, self.unet, self.vae_dec, self.sched, self.timesteps, caption_enc)
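For orientation, here is a minimal inference sketch built on the class above. It is not part of the uploaded files; it assumes a CUDA device, network access for the bundled checkpoint, and a 512x512 input tensor normalized to [-1, 1].

# Illustrative usage sketch for CycleGAN_Turbo (assumes CUDA and a [-1, 1] input tensor).
import torch
from cyclegan_turbo import CycleGAN_Turbo

model = CycleGAN_Turbo(pretrained_name="day_to_night")   # downloads day2night.pkl into checkpoints/
model.eval()
with torch.no_grad():
    x = torch.randn(1, 3, 512, 512, device="cuda")       # stand-in for a normalized input image
    fake = model(x)                                       # uses the checkpoint's caption and "a2b" direction
print(fake.shape)                                         # torch.Size([1, 3, 512, 512])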
src/image_prep.py
ADDED
@@ -0,0 +1,12 @@
import numpy as np
from PIL import Image
import cv2


def canny_from_pil(image, low_threshold=100, high_threshold=200):
    image = np.array(image)
    image = cv2.Canny(image, low_threshold, high_threshold)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    control_image = Image.fromarray(image)
    return control_image
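A quick usage sketch for the helper above (not part of the upload); the input path is a placeholder and the image is assumed to be RGB.

# Canny edges returned as a 3-channel PIL image, matching what the paired inference script feeds the model.
from PIL import Image
from image_prep import canny_from_pil

img = Image.open("example.png").convert("RGB")   # hypothetical input path
edges = canny_from_pil(img, low_threshold=100, high_threshold=200)
print(edges.mode, edges.size)                    # "RGB", same width/height as the input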
src/inference_paired.py
ADDED
@@ -0,0 +1,65 @@
import os
import argparse
import numpy as np
from PIL import Image
import torch
from torchvision import transforms
import torchvision.transforms.functional as F
from pix2pix_turbo import Pix2Pix_Turbo
from image_prep import canny_from_pil

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
    parser.add_argument('--prompt', type=str, required=True, help='the prompt to be used')
    parser.add_argument('--model_name', type=str, default='', help='name of the pretrained model to be used')
    parser.add_argument('--model_path', type=str, default='', help='path to a model state dict to be used')
    parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
    parser.add_argument('--low_threshold', type=int, default=100, help='Canny low threshold')
    parser.add_argument('--high_threshold', type=int, default=200, help='Canny high threshold')
    parser.add_argument('--gamma', type=float, default=0.4, help='The sketch interpolation guidance amount')
    parser.add_argument('--seed', type=int, default=42, help='Random seed to be used')
    args = parser.parse_args()

    # exactly one of model_name and model_path should be provided
    if (args.model_name == '') == (args.model_path == ''):
        raise ValueError('Exactly one of model_name or model_path should be provided')

    os.makedirs(args.output_dir, exist_ok=True)

    # initialize the model
    model = Pix2Pix_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
    model.set_eval()

    # make sure that the input image dimensions are a multiple of 8
    input_image = Image.open(args.input_image).convert('RGB')
    new_width = input_image.width - input_image.width % 8
    new_height = input_image.height - input_image.height % 8
    input_image = input_image.resize((new_width, new_height), Image.LANCZOS)
    bname = os.path.basename(args.input_image)

    # translate the image
    with torch.no_grad():
        if args.model_name == 'edge_to_image':
            canny = canny_from_pil(input_image, args.low_threshold, args.high_threshold)
            canny_viz_inv = Image.fromarray(255 - np.array(canny))
            canny_viz_inv.save(os.path.join(args.output_dir, bname.replace('.png', '_canny.png')))
            c_t = F.to_tensor(canny).unsqueeze(0).cuda()
            output_image = model(c_t, args.prompt)

        elif args.model_name == 'sketch_to_image_stochastic':
            image_t = F.to_tensor(input_image) < 0.5
            c_t = image_t.unsqueeze(0).cuda().float()
            torch.manual_seed(args.seed)
            B, C, H, W = c_t.shape
            noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
            output_image = model(c_t, args.prompt, deterministic=False, r=args.gamma, noise_map=noise)

        else:
            c_t = F.to_tensor(input_image).unsqueeze(0).cuda()
            output_image = model(c_t, args.prompt)

    output_pil = transforms.ToPILImage()(output_image[0].cpu() * 0.5 + 0.5)

    # save the output image
    output_pil.save(os.path.join(args.output_dir, bname))
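The same edge_to_image flow can be driven from Python instead of the command line; the sketch below mirrors the branch above and is illustrative only. The input path and prompt are placeholders, and it assumes the pretrained "edge_to_image" weights are available for download.

# Programmatic equivalent of the edge_to_image branch above (illustrative only).
import torch
import torchvision.transforms.functional as F
from PIL import Image
from pix2pix_turbo import Pix2Pix_Turbo
from image_prep import canny_from_pil

model = Pix2Pix_Turbo(pretrained_name="edge_to_image", pretrained_path="")
model.set_eval()
img = Image.open("input.png").convert("RGB")          # hypothetical path
img = img.resize((img.width - img.width % 8, img.height - img.height % 8), Image.LANCZOS)
c_t = F.to_tensor(canny_from_pil(img)).unsqueeze(0).cuda()
with torch.no_grad():
    out = model(c_t, "a photo of a modern kitchen")    # hypothetical prompt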
src/inference_unpaired.py
ADDED
@@ -0,0 +1,53 @@
import os
import argparse
from PIL import Image
import torch
from torchvision import transforms
from cyclegan_turbo import CycleGAN_Turbo
from my_utils.training_utils import build_transform


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--input_image', type=str, required=True, help='path to the input image')
    parser.add_argument('--prompt', type=str, required=False, help='the prompt to be used. It is required when loading a custom model_path.')
    parser.add_argument('--model_name', type=str, default=None, help='name of the pretrained model to be used')
    parser.add_argument('--model_path', type=str, default=None, help='path to a local model state dict to be used')
    parser.add_argument('--output_dir', type=str, default='output', help='the directory to save the output')
    parser.add_argument('--image_prep', type=str, default='resize_512x512', help='the image preparation method')
    parser.add_argument('--direction', type=str, default=None, help='the direction of translation. None for pretrained models, a2b or b2a for custom paths.')
    args = parser.parse_args()

    # exactly one of model_name and model_path should be provided
    if (args.model_name is None) == (args.model_path is None):
        raise ValueError('Exactly one of model_name or model_path should be provided')

    if args.model_path is not None and args.prompt is None:
        raise ValueError('prompt is required when loading a custom model_path.')

    if args.model_name is not None:
        assert args.prompt is None, 'prompt should not be provided when loading a pretrained model.'
        assert args.direction is None, 'direction should not be provided when loading a pretrained model.'

    # initialize the model
    model = CycleGAN_Turbo(pretrained_name=args.model_name, pretrained_path=args.model_path)
    model.eval()
    model.unet.enable_xformers_memory_efficient_attention()

    T_val = build_transform(args.image_prep)

    input_image = Image.open(args.input_image).convert('RGB')
    # translate the image
    with torch.no_grad():
        input_img = T_val(input_image)
        x_t = transforms.ToTensor()(input_img)
        x_t = transforms.Normalize([0.5], [0.5])(x_t).unsqueeze(0).cuda()
        output = model(x_t, direction=args.direction, caption=args.prompt)

    output_pil = transforms.ToPILImage()(output[0].cpu() * 0.5 + 0.5)
    output_pil = output_pil.resize((input_image.width, input_image.height), Image.LANCZOS)

    # save the output image
    bname = os.path.basename(args.input_image)
    os.makedirs(args.output_dir, exist_ok=True)
    output_pil.save(os.path.join(args.output_dir, bname))
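The script above maps the input into [-1, 1] with Normalize([0.5], [0.5]) and maps the model output back with * 0.5 + 0.5 before converting to PIL. A tiny self-contained sketch of that round trip:

# The [-1, 1] <-> [0, 1] round trip used above, shown in isolation.
import torch
from torchvision import transforms

x01 = torch.rand(3, 8, 8)                         # a tensor in [0, 1], as produced by ToTensor()
x11 = transforms.Normalize([0.5], [0.5])(x01)     # (x - 0.5) / 0.5  ->  [-1, 1]
back = x11 * 0.5 + 0.5                            # inverse mapping applied before ToPILImage()
assert torch.allclose(back, x01)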
src/model.py
ADDED
@@ -0,0 +1,73 @@
import os
import requests
from tqdm import tqdm
from diffusers import DDPMScheduler


def make_1step_sched():
    noise_scheduler_1step = DDPMScheduler.from_pretrained("stabilityai/sd-turbo", subfolder="scheduler")
    noise_scheduler_1step.set_timesteps(1, device="cuda")
    noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
    return noise_scheduler_1step


def my_vae_encoder_fwd(self, sample):
    sample = self.conv_in(sample)
    l_blocks = []
    # down
    for down_block in self.down_blocks:
        l_blocks.append(sample)
        sample = down_block(sample)
    # middle
    sample = self.mid_block(sample)
    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    self.current_down_blocks = l_blocks
    return sample


def my_vae_decoder_fwd(self, sample, latent_embeds=None):
    sample = self.conv_in(sample)
    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
    # middle
    sample = self.mid_block(sample, latent_embeds)
    sample = sample.to(upscale_dtype)
    if not self.ignore_skip:
        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
        # up
        for idx, up_block in enumerate(self.up_blocks):
            skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx] * self.gamma)
            # add skip
            sample = sample + skip_in
            sample = up_block(sample, latent_embeds)
    else:
        for idx, up_block in enumerate(self.up_blocks):
            sample = up_block(sample, latent_embeds)
    # post-process
    if latent_embeds is None:
        sample = self.conv_norm_out(sample)
    else:
        sample = self.conv_norm_out(sample, latent_embeds)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample


def download_url(url, outf):
    if not os.path.exists(outf):
        print(f"Downloading checkpoint to {outf}")
        response = requests.get(url, stream=True)
        total_size_in_bytes = int(response.headers.get('content-length', 0))
        block_size = 1024  # 1 Kibibyte
        progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
        with open(outf, 'wb') as file:
            for data in response.iter_content(block_size):
                progress_bar.update(len(data))
                file.write(data)
        progress_bar.close()
        if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
            print("ERROR, something went wrong")
        print(f"Downloaded successfully to {outf}")
    else:
        print(f"Skipping download, {outf} already exists")
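The VAE monkey-patching in cyclegan_turbo.py relies on `__get__` to bind these plain functions as methods of existing encoder/decoder instances. A minimal sketch of that binding pattern, on a toy class rather than the real VAE:

# How the forward patching works: __get__ binds a free function as a method of an existing instance.
class Encoder:
    def forward(self, x):
        return x

def my_forward(self, x):
    self.last_input = x          # extra state, analogous to current_down_blocks
    return x * 2

enc = Encoder()
enc.forward = my_forward.__get__(enc, enc.__class__)   # same pattern as my_vae_encoder_fwd above
print(enc.forward(3), enc.last_input)                  # 6 3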
src/my_utils/dino_struct.py
ADDED
@@ -0,0 +1,185 @@
import torch
import torchvision
import torch.nn.functional as F


def attn_cosine_sim(x, eps=1e-08):
    x = x[0]  # TEMP: getting rid of redundant dimension, TBF
    norm1 = x.norm(dim=2, keepdim=True)
    factor = torch.clamp(norm1 @ norm1.permute(0, 2, 1), min=eps)
    sim_matrix = (x @ x.permute(0, 2, 1)) / factor
    return sim_matrix


class VitExtractor:
    BLOCK_KEY = 'block'
    ATTN_KEY = 'attn'
    PATCH_IMD_KEY = 'patch_imd'
    QKV_KEY = 'qkv'
    KEY_LIST = [BLOCK_KEY, ATTN_KEY, PATCH_IMD_KEY, QKV_KEY]

    def __init__(self, model_name, device):
        # pdb.set_trace()
        self.model = torch.hub.load('facebookresearch/dino:main', model_name).to(device)
        self.model.eval()
        self.model_name = model_name
        self.hook_handlers = []
        self.layers_dict = {}
        self.outputs_dict = {}
        for key in VitExtractor.KEY_LIST:
            self.layers_dict[key] = []
            self.outputs_dict[key] = []
        self._init_hooks_data()

    def _init_hooks_data(self):
        self.layers_dict[VitExtractor.BLOCK_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        self.layers_dict[VitExtractor.ATTN_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        self.layers_dict[VitExtractor.QKV_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        self.layers_dict[VitExtractor.PATCH_IMD_KEY] = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
        for key in VitExtractor.KEY_LIST:
            # self.layers_dict[key] = kwargs[key] if key in kwargs.keys() else []
            self.outputs_dict[key] = []

    def _register_hooks(self, **kwargs):
        for block_idx, block in enumerate(self.model.blocks):
            if block_idx in self.layers_dict[VitExtractor.BLOCK_KEY]:
                self.hook_handlers.append(block.register_forward_hook(self._get_block_hook()))
            if block_idx in self.layers_dict[VitExtractor.ATTN_KEY]:
                self.hook_handlers.append(block.attn.attn_drop.register_forward_hook(self._get_attn_hook()))
            if block_idx in self.layers_dict[VitExtractor.QKV_KEY]:
                self.hook_handlers.append(block.attn.qkv.register_forward_hook(self._get_qkv_hook()))
            if block_idx in self.layers_dict[VitExtractor.PATCH_IMD_KEY]:
                self.hook_handlers.append(block.attn.register_forward_hook(self._get_patch_imd_hook()))

    def _clear_hooks(self):
        for handler in self.hook_handlers:
            handler.remove()
        self.hook_handlers = []

    def _get_block_hook(self):
        def _get_block_output(model, input, output):
            self.outputs_dict[VitExtractor.BLOCK_KEY].append(output)

        return _get_block_output

    def _get_attn_hook(self):
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.ATTN_KEY].append(output)

        return _get_attn_output

    def _get_qkv_hook(self):
        def _get_qkv_output(model, inp, output):
            self.outputs_dict[VitExtractor.QKV_KEY].append(output)

        return _get_qkv_output

    # TODO: CHECK ATTN OUTPUT TUPLE
    def _get_patch_imd_hook(self):
        def _get_attn_output(model, inp, output):
            self.outputs_dict[VitExtractor.PATCH_IMD_KEY].append(output[0])

        return _get_attn_output

    def get_feature_from_input(self, input_img):  # List([B, N, D])
        self._register_hooks()
        self.model(input_img)
        feature = self.outputs_dict[VitExtractor.BLOCK_KEY]
        self._clear_hooks()
        self._init_hooks_data()
        return feature

    def get_qkv_feature_from_input(self, input_img):
        self._register_hooks()
        self.model(input_img)
        feature = self.outputs_dict[VitExtractor.QKV_KEY]
        self._clear_hooks()
        self._init_hooks_data()
        return feature

    def get_attn_feature_from_input(self, input_img):
        self._register_hooks()
        self.model(input_img)
        feature = self.outputs_dict[VitExtractor.ATTN_KEY]
        self._clear_hooks()
        self._init_hooks_data()
        return feature

    def get_patch_size(self):
        return 8 if "8" in self.model_name else 16

    def get_width_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return w // patch_size

    def get_height_patch_num(self, input_img_shape):
        b, c, h, w = input_img_shape
        patch_size = self.get_patch_size()
        return h // patch_size

    def get_patch_num(self, input_img_shape):
        patch_num = 1 + (self.get_height_patch_num(input_img_shape) * self.get_width_patch_num(input_img_shape))
        return patch_num

    def get_head_num(self):
        if "dino" in self.model_name:
            return 6 if "s" in self.model_name else 12
        return 6 if "small" in self.model_name else 12

    def get_embedding_dim(self):
        if "dino" in self.model_name:
            return 384 if "s" in self.model_name else 768
        return 384 if "small" in self.model_name else 768

    def get_queries_from_qkv(self, qkv, input_img_shape):
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        q = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[0]
        return q

    def get_keys_from_qkv(self, qkv, input_img_shape):
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        k = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[1]
        return k

    def get_values_from_qkv(self, qkv, input_img_shape):
        patch_num = self.get_patch_num(input_img_shape)
        head_num = self.get_head_num()
        embedding_dim = self.get_embedding_dim()
        v = qkv.reshape(patch_num, 3, head_num, embedding_dim // head_num).permute(1, 2, 0, 3)[2]
        return v

    def get_keys_from_input(self, input_img, layer_num):
        qkv_features = self.get_qkv_feature_from_input(input_img)[layer_num]
        keys = self.get_keys_from_qkv(qkv_features, input_img.shape)
        return keys

    def get_keys_self_sim_from_input(self, input_img, layer_num):
        keys = self.get_keys_from_input(input_img, layer_num=layer_num)
        h, t, d = keys.shape
        concatenated_keys = keys.transpose(0, 1).reshape(t, h * d)
        ssim_map = attn_cosine_sim(concatenated_keys[None, None, ...])
        return ssim_map


class DinoStructureLoss:
    def __init__(self):
        self.extractor = VitExtractor(model_name="dino_vitb8", device="cuda")
        self.preprocess = torchvision.transforms.Compose([
            torchvision.transforms.Resize(224),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
        ])

    def calculate_global_ssim_loss(self, outputs, inputs):
        loss = 0.0
        for a, b in zip(inputs, outputs):  # avoid memory limitations
            with torch.no_grad():
                target_keys_self_sim = self.extractor.get_keys_self_sim_from_input(a.unsqueeze(0), layer_num=11)
            keys_ssim = self.extractor.get_keys_self_sim_from_input(b.unsqueeze(0), layer_num=11)
            loss += F.mse_loss(keys_ssim, target_keys_self_sim)
        return loss
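An illustrative way to score structural similarity between an input and its translation with the class above (not part of the upload); it assumes CUDA, network access for the DINO weights via torch.hub, and placeholder image paths.

# Illustrative use of DinoStructureLoss on two RGB images.
import torch
from PIL import Image
from my_utils.dino_struct import DinoStructureLoss

dino = DinoStructureLoss()                            # loads dino_vitb8 through torch.hub
a = dino.preprocess(Image.open("real.png").convert("RGB")).unsqueeze(0).cuda()   # hypothetical paths
b = dino.preprocess(Image.open("fake.png").convert("RGB")).unsqueeze(0).cuda()
with torch.no_grad():
    loss = dino.calculate_global_ssim_loss(b, a)      # arguments are (outputs, inputs)
print(float(loss))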
src/my_utils/training_utils.py
ADDED
@@ -0,0 +1,409 @@
1 |
+
import os
|
2 |
+
import random
|
3 |
+
import argparse
|
4 |
+
import json
|
5 |
+
import torch
|
6 |
+
from PIL import Image
|
7 |
+
from torchvision import transforms
|
8 |
+
import torchvision.transforms.functional as F
|
9 |
+
from glob import glob
|
10 |
+
|
11 |
+
|
12 |
+
def parse_args_paired_training(input_args=None):
|
13 |
+
"""
|
14 |
+
Parses command-line arguments used for configuring an paired session (pix2pix-Turbo).
|
15 |
+
This function sets up an argument parser to handle various training options.
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
argparse.Namespace: The parsed command-line arguments.
|
19 |
+
"""
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
# args for the loss function
|
22 |
+
parser.add_argument("--gan_disc_type", default="vagan_clip")
|
23 |
+
parser.add_argument("--gan_loss_type", default="multilevel_sigmoid_s")
|
24 |
+
parser.add_argument("--lambda_gan", default=0.5, type=float)
|
25 |
+
parser.add_argument("--lambda_lpips", default=5, type=float)
|
26 |
+
parser.add_argument("--lambda_l2", default=1.0, type=float)
|
27 |
+
parser.add_argument("--lambda_clipsim", default=5.0, type=float)
|
28 |
+
|
29 |
+
# dataset options
|
30 |
+
parser.add_argument("--dataset_folder", required=True, type=str)
|
31 |
+
parser.add_argument("--train_image_prep", default="resized_crop_512", type=str)
|
32 |
+
parser.add_argument("--test_image_prep", default="resized_crop_512", type=str)
|
33 |
+
|
34 |
+
# validation eval args
|
35 |
+
parser.add_argument("--eval_freq", default=100, type=int)
|
36 |
+
parser.add_argument("--track_val_fid", default=False, action="store_true")
|
37 |
+
parser.add_argument("--num_samples_eval", type=int, default=100, help="Number of samples to use for all evaluation")
|
38 |
+
|
39 |
+
parser.add_argument("--viz_freq", type=int, default=100, help="Frequency of visualizing the outputs.")
|
40 |
+
parser.add_argument("--tracker_project_name", type=str, default="train_pix2pix_turbo", help="The name of the wandb project to log to.")
|
41 |
+
|
42 |
+
# details about the model architecture
|
43 |
+
parser.add_argument("--pretrained_model_name_or_path")
|
44 |
+
parser.add_argument("--revision", type=str, default=None,)
|
45 |
+
parser.add_argument("--variant", type=str, default=None,)
|
46 |
+
parser.add_argument("--tokenizer_name", type=str, default=None)
|
47 |
+
parser.add_argument("--lora_rank_unet", default=8, type=int)
|
48 |
+
parser.add_argument("--lora_rank_vae", default=4, type=int)
|
49 |
+
|
50 |
+
# training details
|
51 |
+
parser.add_argument("--output_dir", required=True)
|
52 |
+
parser.add_argument("--cache_dir", default=None,)
|
53 |
+
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
54 |
+
parser.add_argument("--resolution", type=int, default=512,)
|
55 |
+
parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
|
56 |
+
parser.add_argument("--num_training_epochs", type=int, default=10)
|
57 |
+
parser.add_argument("--max_train_steps", type=int, default=10_000,)
|
58 |
+
parser.add_argument("--checkpointing_steps", type=int, default=500,)
|
59 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of updates steps to accumulate before performing a backward/update pass.",)
|
60 |
+
parser.add_argument("--gradient_checkpointing", action="store_true",)
|
61 |
+
parser.add_argument("--learning_rate", type=float, default=5e-6)
|
62 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant",
|
63 |
+
help=(
|
64 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
65 |
+
' "constant", "constant_with_warmup"]'
|
66 |
+
),
|
67 |
+
)
|
68 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
|
69 |
+
parser.add_argument("--lr_num_cycles", type=int, default=1,
|
70 |
+
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
|
71 |
+
)
|
72 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
73 |
+
|
74 |
+
parser.add_argument("--dataloader_num_workers", type=int, default=0,)
|
75 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
76 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
77 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
78 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
79 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
80 |
+
parser.add_argument("--allow_tf32", action="store_true",
|
81 |
+
help=(
|
82 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
83 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
84 |
+
),
|
85 |
+
)
|
86 |
+
parser.add_argument("--report_to", type=str, default="wandb",
|
87 |
+
help=(
|
88 |
+
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
89 |
+
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
90 |
+
),
|
91 |
+
)
|
92 |
+
parser.add_argument("--mixed_precision", type=str, default=None, choices=["no", "fp16", "bf16"],)
|
93 |
+
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
|
94 |
+
parser.add_argument("--set_grads_to_none", action="store_true",)
|
95 |
+
|
96 |
+
if input_args is not None:
|
97 |
+
args = parser.parse_args(input_args)
|
98 |
+
else:
|
99 |
+
args = parser.parse_args()
|
100 |
+
|
101 |
+
return args
|
102 |
+
|
103 |
+
|
104 |
+
def parse_args_unpaired_training():
|
105 |
+
"""
|
106 |
+
Parses command-line arguments used for configuring an unpaired session (CycleGAN-Turbo).
|
107 |
+
This function sets up an argument parser to handle various training options.
|
108 |
+
|
109 |
+
Returns:
|
110 |
+
argparse.Namespace: The parsed command-line arguments.
|
111 |
+
"""
|
112 |
+
|
113 |
+
parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
|
114 |
+
|
115 |
+
# fixed random seed
|
116 |
+
parser.add_argument("--seed", type=int, default=42, help="A seed for reproducible training.")
|
117 |
+
|
118 |
+
# args for the loss function
|
119 |
+
parser.add_argument("--gan_disc_type", default="vagan_clip")
|
120 |
+
parser.add_argument("--gan_loss_type", default="multilevel_sigmoid")
|
121 |
+
parser.add_argument("--lambda_gan", default=0.5, type=float)
|
122 |
+
parser.add_argument("--lambda_idt", default=1, type=float)
|
123 |
+
parser.add_argument("--lambda_cycle", default=1, type=float)
|
124 |
+
parser.add_argument("--lambda_cycle_lpips", default=10.0, type=float)
|
125 |
+
parser.add_argument("--lambda_idt_lpips", default=1.0, type=float)
|
126 |
+
|
127 |
+
# args for dataset and dataloader options
|
128 |
+
parser.add_argument("--dataset_folder", required=True, type=str)
|
129 |
+
parser.add_argument("--train_img_prep", required=True)
|
130 |
+
parser.add_argument("--val_img_prep", required=True)
|
131 |
+
parser.add_argument("--dataloader_num_workers", type=int, default=0)
|
132 |
+
parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader.")
|
133 |
+
parser.add_argument("--max_train_epochs", type=int, default=100)
|
134 |
+
parser.add_argument("--max_train_steps", type=int, default=None)
|
135 |
+
|
136 |
+
# args for the model
|
137 |
+
parser.add_argument("--pretrained_model_name_or_path", default="stabilityai/sd-turbo")
|
138 |
+
parser.add_argument("--revision", default=None, type=str)
|
139 |
+
parser.add_argument("--variant", default=None, type=str)
|
140 |
+
parser.add_argument("--lora_rank_unet", default=128, type=int)
|
141 |
+
parser.add_argument("--lora_rank_vae", default=4, type=int)
|
142 |
+
|
143 |
+
# args for validation and logging
|
144 |
+
parser.add_argument("--viz_freq", type=int, default=20)
|
145 |
+
parser.add_argument("--output_dir", type=str, required=True)
|
146 |
+
parser.add_argument("--report_to", type=str, default="wandb")
|
147 |
+
parser.add_argument("--tracker_project_name", type=str, required=True)
|
148 |
+
parser.add_argument("--validation_steps", type=int, default=500,)
|
149 |
+
parser.add_argument("--validation_num_images", type=int, default=-1, help="Number of images to use for validation. -1 to use all images.")
|
150 |
+
parser.add_argument("--checkpointing_steps", type=int, default=500)
|
151 |
+
|
152 |
+
# args for the optimization options
|
153 |
+
parser.add_argument("--learning_rate", type=float, default=5e-6,)
|
154 |
+
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
155 |
+
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
156 |
+
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
157 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
158 |
+
parser.add_argument("--max_grad_norm", default=10.0, type=float, help="Max gradient norm.")
|
159 |
+
parser.add_argument("--lr_scheduler", type=str, default="constant", help=(
|
160 |
+
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
161 |
+
' "constant", "constant_with_warmup"]'
|
162 |
+
),
|
163 |
+
)
|
164 |
+
parser.add_argument("--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler.")
|
165 |
+
parser.add_argument("--lr_num_cycles", type=int, default=1, help="Number of hard resets of the lr in cosine_with_restarts scheduler.",)
|
166 |
+
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
|
167 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
168 |
+
|
169 |
+
# memory saving options
|
170 |
+
parser.add_argument("--allow_tf32", action="store_true",
|
171 |
+
help=(
|
172 |
+
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
173 |
+
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
174 |
+
),
|
175 |
+
)
|
176 |
+
parser.add_argument("--gradient_checkpointing", action="store_true",
|
177 |
+
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.")
|
178 |
+
parser.add_argument("--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers.")
|
179 |
+
|
180 |
+
args = parser.parse_args()
|
181 |
+
return args
|
182 |
+
|
183 |
+
|
184 |
+
def build_transform(image_prep):
|
185 |
+
"""
|
186 |
+
Constructs a transformation pipeline based on the specified image preparation method.
|
187 |
+
|
188 |
+
Parameters:
|
189 |
+
- image_prep (str): A string describing the desired image preparation
|
190 |
+
|
191 |
+
Returns:
|
192 |
+
- torchvision.transforms.Compose: A composable sequence of transformations to be applied to images.
|
193 |
+
"""
|
194 |
+
if image_prep == "resized_crop_512":
|
195 |
+
T = transforms.Compose([
|
196 |
+
transforms.Resize(512, interpolation=transforms.InterpolationMode.LANCZOS),
|
197 |
+
transforms.CenterCrop(512),
|
198 |
+
])
|
199 |
+
elif image_prep == "resize_286_randomcrop_256x256_hflip":
|
200 |
+
T = transforms.Compose([
|
201 |
+
transforms.Resize((286, 286), interpolation=Image.LANCZOS),
|
202 |
+
transforms.RandomCrop((256, 256)),
|
203 |
+
transforms.RandomHorizontalFlip(),
|
204 |
+
])
|
205 |
+
elif image_prep in ["resize_256", "resize_256x256"]:
|
206 |
+
T = transforms.Compose([
|
207 |
+
transforms.Resize((256, 256), interpolation=Image.LANCZOS)
|
208 |
+
])
|
209 |
+
elif image_prep in ["resize_512", "resize_512x512"]:
|
210 |
+
T = transforms.Compose([
|
211 |
+
transforms.Resize((512, 512), interpolation=Image.LANCZOS)
|
212 |
+
])
|
213 |
+
elif image_prep == "no_resize":
|
214 |
+
T = transforms.Lambda(lambda x: x)
|
215 |
+
return T
|
216 |
+
|
217 |
+
|
218 |
+
class PairedDataset(torch.utils.data.Dataset):
|
219 |
+
def __init__(self, dataset_folder, split, image_prep, tokenizer):
|
220 |
+
"""
|
221 |
+
Itialize the paired dataset object for loading and transforming paired data samples
|
222 |
+
from specified dataset folders.
|
223 |
+
|
224 |
+
This constructor sets up the paths to input and output folders based on the specified 'split',
|
225 |
+
loads the captions (or prompts) for the input images, and prepares the transformations and
|
226 |
+
tokenizer to be applied on the data.
|
227 |
+
|
228 |
+
Parameters:
|
229 |
+
- dataset_folder (str): The root folder containing the dataset, expected to include
|
230 |
+
sub-folders for different splits (e.g., 'train_A', 'train_B').
|
231 |
+
- split (str): The dataset split to use ('train' or 'test'), used to select the appropriate
|
232 |
+
sub-folders and caption files within the dataset folder.
|
233 |
+
- image_prep (str): The image preprocessing transformation to apply to each image.
|
234 |
+
- tokenizer: The tokenizer used for tokenizing the captions (or prompts).
|
235 |
+
"""
|
236 |
+
super().__init__()
|
237 |
+
if split == "train":
|
238 |
+
self.input_folder = os.path.join(dataset_folder, "train_A")
|
239 |
+
self.output_folder = os.path.join(dataset_folder, "train_B")
|
240 |
+
captions = os.path.join(dataset_folder, "train_prompts.json")
|
241 |
+
elif split == "test":
|
242 |
+
self.input_folder = os.path.join(dataset_folder, "test_A")
|
243 |
+
self.output_folder = os.path.join(dataset_folder, "test_B")
|
244 |
+
captions = os.path.join(dataset_folder, "test_prompts.json")
|
245 |
+
with open(captions, "r") as f:
|
246 |
+
self.captions = json.load(f)
|
247 |
+
self.img_names = list(self.captions.keys())
|
248 |
+
self.T = build_transform(image_prep)
|
249 |
+
self.tokenizer = tokenizer
|
250 |
+
|
251 |
+
def __len__(self):
|
252 |
+
"""
|
253 |
+
Returns:
|
254 |
+
int: The total number of items in the dataset.
|
255 |
+
"""
|
256 |
+
return len(self.captions)
|
257 |
+
|
258 |
+
def __getitem__(self, idx):
|
259 |
+
"""
|
260 |
+
Retrieves a dataset item given its index. Each item consists of an input image,
|
261 |
+
its corresponding output image, the captions associated with the input image,
|
262 |
+
and the tokenized form of this caption.
|
263 |
+
|
264 |
+
This method performs the necessary preprocessing on both the input and output images,
|
265 |
+
including scaling and normalization, as well as tokenizing the caption using a provided tokenizer.
|
266 |
+
|
267 |
+
Parameters:
|
268 |
+
- idx (int): The index of the item to retrieve.
|
269 |
+
|
270 |
+
Returns:
|
271 |
+
dict: A dictionary containing the following key-value pairs:
|
272 |
+
- "output_pixel_values": a tensor of the preprocessed output image with pixel values
|
273 |
+
scaled to [-1, 1].
|
274 |
+
- "conditioning_pixel_values": a tensor of the preprocessed input image with pixel values
|
275 |
+
scaled to [0, 1].
|
276 |
+
- "caption": the text caption.
|
277 |
+
- "input_ids": a tensor of the tokenized caption.
|
278 |
+
|
279 |
+
Note:
|
280 |
+
The actual preprocessing steps (scaling and normalization) for images are defined externally
|
281 |
+
and passed to this class through the `image_prep` parameter during initialization. The
|
282 |
+
tokenization process relies on the `tokenizer` also provided at initialization, which
|
283 |
+
should be compatible with the models intended to be used with this dataset.
|
284 |
+
"""
|
285 |
+
img_name = self.img_names[idx]
|
286 |
+
input_img = Image.open(os.path.join(self.input_folder, img_name))
|
287 |
+
output_img = Image.open(os.path.join(self.output_folder, img_name))
|
288 |
+
caption = self.captions[img_name]
|
289 |
+
|
290 |
+
# input images scaled to 0,1
|
291 |
+
img_t = self.T(input_img)
|
292 |
+
img_t = F.to_tensor(img_t)
|
293 |
+
# output images scaled to -1,1
|
294 |
+
output_t = self.T(output_img)
|
295 |
+
output_t = F.to_tensor(output_t)
|
296 |
+
output_t = F.normalize(output_t, mean=[0.5], std=[0.5])
|
297 |
+
|
298 |
+
input_ids = self.tokenizer(
|
299 |
+
caption, max_length=self.tokenizer.model_max_length,
|
300 |
+
padding="max_length", truncation=True, return_tensors="pt"
|
301 |
+
).input_ids
|
302 |
+
|
303 |
+
return {
|
304 |
+
"output_pixel_values": output_t,
|
305 |
+
"conditioning_pixel_values": img_t,
            "caption": caption,
            "input_ids": input_ids,
        }


class UnpairedDataset(torch.utils.data.Dataset):
    def __init__(self, dataset_folder, split, image_prep, tokenizer):
        """
        A dataset class for loading unpaired data samples from two distinct domains (source and target),
        typically used in unsupervised learning tasks like image-to-image translation.

        The class supports loading images from specified dataset folders, applying predefined image
        preprocessing transformations, and utilizing fixed textual prompts (captions) for each domain,
        tokenized using a provided tokenizer.

        Parameters:
        - dataset_folder (str): Base directory of the dataset containing subdirectories (train_A, train_B, test_A, test_B)
        - split (str): Indicates the dataset split to use. Expected values are 'train' or 'test'.
        - image_prep (str): The image preprocessing transformation to apply to each image.
        - tokenizer: The tokenizer used for tokenizing the captions (or prompts).
        """
        super().__init__()
        if split == "train":
            self.source_folder = os.path.join(dataset_folder, "train_A")
            self.target_folder = os.path.join(dataset_folder, "train_B")
        elif split == "test":
            self.source_folder = os.path.join(dataset_folder, "test_A")
            self.target_folder = os.path.join(dataset_folder, "test_B")
        self.tokenizer = tokenizer
        with open(os.path.join(dataset_folder, "fixed_prompt_a.txt"), "r") as f:
            self.fixed_caption_src = f.read().strip()
            self.input_ids_src = self.tokenizer(
                self.fixed_caption_src, max_length=self.tokenizer.model_max_length,
                padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids

        with open(os.path.join(dataset_folder, "fixed_prompt_b.txt"), "r") as f:
            self.fixed_caption_tgt = f.read().strip()
            self.input_ids_tgt = self.tokenizer(
                self.fixed_caption_tgt, max_length=self.tokenizer.model_max_length,
                padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids
        # find all images in the source and target folders with all IMG extensions
        self.l_imgs_src = []
        for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif"]:
            self.l_imgs_src.extend(glob(os.path.join(self.source_folder, ext)))
        self.l_imgs_tgt = []
        for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.gif"]:
            self.l_imgs_tgt.extend(glob(os.path.join(self.target_folder, ext)))
        self.T = build_transform(image_prep)

    def __len__(self):
        """
        Returns:
        int: The total number of items in the dataset.
        """
        return len(self.l_imgs_src) + len(self.l_imgs_tgt)

    def __getitem__(self, index):
        """
        Fetches a pair of unaligned images from the source and target domains along with their
        corresponding tokenized captions.

        For the source domain, if the requested index is within the range of available images,
        the specific image at that index is chosen. If the index exceeds the number of source
        images, a random source image is selected. For the target domain, an image is always
        randomly selected, irrespective of the index, to maintain the unpaired nature of the dataset.

        Both images are preprocessed with the image transformation `T` and normalized. The fixed
        captions for both domains are included along with their tokenized forms.

        Parameters:
        - index (int): The index of the source image to retrieve.

        Returns:
        dict: A dictionary containing processed data for a single training example, with the following keys:
        - "pixel_values_src": The processed source image
        - "pixel_values_tgt": The processed target image
        - "caption_src": The fixed caption of the source domain.
        - "caption_tgt": The fixed caption of the target domain.
        - "input_ids_src": The source domain's fixed caption tokenized.
        - "input_ids_tgt": The target domain's fixed caption tokenized.
        """
        if index < len(self.l_imgs_src):
            img_path_src = self.l_imgs_src[index]
        else:
            img_path_src = random.choice(self.l_imgs_src)
        img_path_tgt = random.choice(self.l_imgs_tgt)
        img_pil_src = Image.open(img_path_src).convert("RGB")
        img_pil_tgt = Image.open(img_path_tgt).convert("RGB")
        img_t_src = F.to_tensor(self.T(img_pil_src))
        img_t_tgt = F.to_tensor(self.T(img_pil_tgt))
        img_t_src = F.normalize(img_t_src, mean=[0.5], std=[0.5])
        img_t_tgt = F.normalize(img_t_tgt, mean=[0.5], std=[0.5])
        return {
            "pixel_values_src": img_t_src,
            "pixel_values_tgt": img_t_tgt,
            "caption_src": self.fixed_caption_src,
            "caption_tgt": self.fixed_caption_tgt,
            "input_ids_src": self.input_ids_src,
            "input_ids_tgt": self.input_ids_tgt,
        }
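For orientation, a minimal usage sketch of UnpairedDataset (not part of the uploaded files). It assumes `src/` is on PYTHONPATH; the dataset folder and the image_prep string are hypothetical placeholders whose valid values depend on build_transform.

# Sketch only: build the unpaired dataset and inspect one sample.
from transformers import AutoTokenizer
from my_utils.training_utils import UnpairedDataset

tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer", use_fast=False)
dataset = UnpairedDataset(dataset_folder="data/my_unpaired_dataset",  # hypothetical path
                          split="train", image_prep="resize_256",     # placeholder prep string
                          tokenizer=tokenizer)
sample = dataset[0]
print(sample["pixel_values_src"].shape, sample["caption_src"])
print(sample["input_ids_src"].shape)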
src/pix2pix_turbo.py
ADDED
@@ -0,0 +1,227 @@
import os
import requests
import sys
import copy
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.utils.peft_utils import set_weights_and_activate_adapters
from peft import LoraConfig
p = "src/"
sys.path.append(p)
from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd


class TwinConv(torch.nn.Module):
    def __init__(self, convin_pretrained, convin_curr):
        super(TwinConv, self).__init__()
        self.conv_in_pretrained = copy.deepcopy(convin_pretrained)
        self.conv_in_curr = copy.deepcopy(convin_curr)
        self.r = None

    def forward(self, x):
        x1 = self.conv_in_pretrained(x).detach()
        x2 = self.conv_in_curr(x)
        return x1 * (1 - self.r) + x2 * (self.r)


class Pix2Pix_Turbo(torch.nn.Module):
    def __init__(self, pretrained_name=None, pretrained_path=None, ckpt_folder="checkpoints", lora_rank_unet=8, lora_rank_vae=4):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()
        self.sched = make_1step_sched()

        vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
        vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
        vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)
        # add the skip connection convs
        vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False).cuda()
        vae.decoder.ignore_skip = False
        unet = UNet2DConditionModel.from_pretrained("stabilityai/sd-turbo", subfolder="unet")

        if pretrained_name == "edge_to_image":
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/edge_to_image_loras.pkl"
            os.makedirs(ckpt_folder, exist_ok=True)
            outf = os.path.join(ckpt_folder, "edge_to_image_loras.pkl")
            if not os.path.exists(outf):
                print(f"Downloading checkpoint to {outf}")
                response = requests.get(url, stream=True)
                total_size_in_bytes = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 Kibibyte
                progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
                with open(outf, 'wb') as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
                progress_bar.close()
                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
                    print("ERROR, something went wrong")
                print(f"Downloaded successfully to {outf}")
            p_ckpt = outf
            sd = torch.load(p_ckpt, map_location="cpu")
            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
            vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
            vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
            _sd_vae = vae.state_dict()
            for k in sd["state_dict_vae"]:
                _sd_vae[k] = sd["state_dict_vae"][k]
            vae.load_state_dict(_sd_vae)
            unet.add_adapter(unet_lora_config)
            _sd_unet = unet.state_dict()
            for k in sd["state_dict_unet"]:
                _sd_unet[k] = sd["state_dict_unet"][k]
            unet.load_state_dict(_sd_unet)

        elif pretrained_name == "sketch_to_image_stochastic":
            # download from url
            url = "https://www.cs.cmu.edu/~img2img-turbo/models/sketch_to_image_stochastic_lora.pkl"
            os.makedirs(ckpt_folder, exist_ok=True)
            outf = os.path.join(ckpt_folder, "sketch_to_image_stochastic_lora.pkl")
            if not os.path.exists(outf):
                print(f"Downloading checkpoint to {outf}")
                response = requests.get(url, stream=True)
                total_size_in_bytes = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 Kibibyte
                progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
                with open(outf, 'wb') as file:
                    for data in response.iter_content(block_size):
                        progress_bar.update(len(data))
                        file.write(data)
                progress_bar.close()
                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
                    print("ERROR, something went wrong")
                print(f"Downloaded successfully to {outf}")
            p_ckpt = outf
            convin_pretrained = copy.deepcopy(unet.conv_in)
            unet.conv_in = TwinConv(convin_pretrained, unet.conv_in)
            sd = torch.load(p_ckpt, map_location="cpu")
            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
            vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
            vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
            _sd_vae = vae.state_dict()
            for k in sd["state_dict_vae"]:
                _sd_vae[k] = sd["state_dict_vae"][k]
            vae.load_state_dict(_sd_vae)
            unet.add_adapter(unet_lora_config)
            _sd_unet = unet.state_dict()
            for k in sd["state_dict_unet"]:
                _sd_unet[k] = sd["state_dict_unet"][k]
            unet.load_state_dict(_sd_unet)

        elif pretrained_path is not None:
            sd = torch.load(pretrained_path, map_location="cpu")
            unet_lora_config = LoraConfig(r=sd["rank_unet"], init_lora_weights="gaussian", target_modules=sd["unet_lora_target_modules"])
            vae_lora_config = LoraConfig(r=sd["rank_vae"], init_lora_weights="gaussian", target_modules=sd["vae_lora_target_modules"])
            vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
            _sd_vae = vae.state_dict()
            for k in sd["state_dict_vae"]:
                _sd_vae[k] = sd["state_dict_vae"][k]
            vae.load_state_dict(_sd_vae)
            unet.add_adapter(unet_lora_config)
            _sd_unet = unet.state_dict()
            for k in sd["state_dict_unet"]:
                _sd_unet[k] = sd["state_dict_unet"][k]
            unet.load_state_dict(_sd_unet)

        elif pretrained_name is None and pretrained_path is None:
            print("Initializing model with random weights")
            torch.nn.init.constant_(vae.decoder.skip_conv_1.weight, 1e-5)
            torch.nn.init.constant_(vae.decoder.skip_conv_2.weight, 1e-5)
            torch.nn.init.constant_(vae.decoder.skip_conv_3.weight, 1e-5)
            torch.nn.init.constant_(vae.decoder.skip_conv_4.weight, 1e-5)
            target_modules_vae = ["conv1", "conv2", "conv_in", "conv_shortcut", "conv", "conv_out",
                                  "skip_conv_1", "skip_conv_2", "skip_conv_3", "skip_conv_4",
                                  "to_k", "to_q", "to_v", "to_out.0",
                                  ]
            vae_lora_config = LoraConfig(r=lora_rank_vae, init_lora_weights="gaussian",
                                         target_modules=target_modules_vae)
            vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
            target_modules_unet = [
                "to_k", "to_q", "to_v", "to_out.0", "conv", "conv1", "conv2", "conv_shortcut", "conv_out",
                "proj_in", "proj_out", "ff.net.2", "ff.net.0.proj"
            ]
            unet_lora_config = LoraConfig(r=lora_rank_unet, init_lora_weights="gaussian",
                                          target_modules=target_modules_unet
                                          )
            unet.add_adapter(unet_lora_config)
            self.lora_rank_unet = lora_rank_unet
            self.lora_rank_vae = lora_rank_vae
            self.target_modules_vae = target_modules_vae
            self.target_modules_unet = target_modules_unet

        # unet.enable_xformers_memory_efficient_attention()
        unet.to("cuda")
        vae.to("cuda")
        self.unet, self.vae = unet, vae
        self.vae.decoder.gamma = 1
        self.timesteps = torch.tensor([999], device="cuda").long()
        self.text_encoder.requires_grad_(False)

    def set_eval(self):
        self.unet.eval()
        self.vae.eval()
        self.unet.requires_grad_(False)
        self.vae.requires_grad_(False)

    def set_train(self):
        self.unet.train()
        self.vae.train()
        for n, _p in self.unet.named_parameters():
            if "lora" in n:
                _p.requires_grad = True
        self.unet.conv_in.requires_grad_(True)
        for n, _p in self.vae.named_parameters():
            if "lora" in n:
                _p.requires_grad = True
        self.vae.decoder.skip_conv_1.requires_grad_(True)
        self.vae.decoder.skip_conv_2.requires_grad_(True)
        self.vae.decoder.skip_conv_3.requires_grad_(True)
        self.vae.decoder.skip_conv_4.requires_grad_(True)

    def forward(self, c_t, prompt=None, prompt_tokens=None, deterministic=True, r=1.0, noise_map=None):
        # either the prompt or the prompt_tokens should be provided
        assert (prompt is None) != (prompt_tokens is None), "Either prompt or prompt_tokens should be provided"

        if prompt is not None:
            # encode the text prompt
            caption_tokens = self.tokenizer(prompt, max_length=self.tokenizer.model_max_length,
                                            padding="max_length", truncation=True, return_tensors="pt").input_ids.cuda()
            caption_enc = self.text_encoder(caption_tokens)[0]
        else:
            caption_enc = self.text_encoder(prompt_tokens)[0]
        if deterministic:
            encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
            model_pred = self.unet(encoded_control, self.timesteps, encoder_hidden_states=caption_enc,).sample
            x_denoised = self.sched.step(model_pred, self.timesteps, encoded_control, return_dict=True).prev_sample
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
        else:
            # scale the lora weights based on the r value
            self.unet.set_adapters(["default"], weights=[r])
            set_weights_and_activate_adapters(self.vae, ["vae_skip"], [r])
            encoded_control = self.vae.encode(c_t).latent_dist.sample() * self.vae.config.scaling_factor
            # combine the input and noise
            unet_input = encoded_control * r + noise_map * (1 - r)
            self.unet.conv_in.r = r
            unet_output = self.unet(unet_input, self.timesteps, encoder_hidden_states=caption_enc,).sample
            self.unet.conv_in.r = None
            x_denoised = self.sched.step(unet_output, self.timesteps, unet_input, return_dict=True).prev_sample
            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
            self.vae.decoder.gamma = r
            output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
        return output_image

    def save_model(self, outf):
        sd = {}
        sd["unet_lora_target_modules"] = self.target_modules_unet
        sd["vae_lora_target_modules"] = self.target_modules_vae
        sd["rank_unet"] = self.lora_rank_unet
        sd["rank_vae"] = self.lora_rank_vae
        sd["state_dict_unet"] = {k: v for k, v in self.unet.state_dict().items() if "lora" in k or "conv_in" in k}
        sd["state_dict_vae"] = {k: v for k, v in self.vae.state_dict().items() if "lora" in k or "skip" in k}
        torch.save(sd, outf)
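For orientation, a minimal inference sketch for Pix2Pix_Turbo (not part of the uploaded files). It follows the deterministic path of forward() above with the released edge_to_image weights; the input file name and prompt are placeholders, a CUDA device and `src/` on PYTHONPATH are assumed, and the conditioning tensor is left in [0, 1] here, the exact preprocessing should match how the weights were trained.

# Sketch only: one deterministic edge-to-image step.
import torch
from PIL import Image
from torchvision import transforms
from pix2pix_turbo import Pix2Pix_Turbo

model = Pix2Pix_Turbo(pretrained_name="edge_to_image")  # downloads the LoRA checkpoint on first use
model.set_eval()
c_t = transforms.ToTensor()(Image.open("edge_map.png").convert("RGB")).unsqueeze(0).cuda()  # hypothetical input
with torch.no_grad():
    out = model(c_t, prompt="a photo of a modern house", deterministic=True)
transforms.ToPILImage()(out[0].cpu() * 0.5 + 0.5).save("output.png")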
src/train_cyclegan_turbo.py
ADDED
@@ -0,0 +1,389 @@
import os
import gc
import copy
import lpips
import torch
import wandb
from glob import glob
import numpy as np
from accelerate import Accelerator
from accelerate.utils import set_seed
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, CLIPTextModel
from diffusers.optimization import get_scheduler
from peft.utils import get_peft_model_state_dict
from cleanfid.fid import get_folder_features, build_feature_extractor, frechet_distance
import vision_aided_loss
from model import make_1step_sched
from cyclegan_turbo import CycleGAN_Turbo, VAE_encode, VAE_decode, initialize_unet, initialize_vae
from my_utils.training_utils import UnpairedDataset, build_transform, parse_args_unpaired_training
from my_utils.dino_struct import DinoStructureLoss


def main(args):
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, log_with=args.report_to)
    set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained("stabilityai/sd-turbo", subfolder="tokenizer", revision=args.revision, use_fast=False,)
    noise_scheduler_1step = make_1step_sched()
    text_encoder = CLIPTextModel.from_pretrained("stabilityai/sd-turbo", subfolder="text_encoder").cuda()

    unet, l_modules_unet_encoder, l_modules_unet_decoder, l_modules_unet_others = initialize_unet(args.lora_rank_unet, return_lora_module_names=True)
    vae_a2b, vae_lora_target_modules = initialize_vae(args.lora_rank_vae, return_lora_module_names=True)

    weight_dtype = torch.float32
    vae_a2b.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder.requires_grad_(False)

    if args.gan_disc_type == "vagan_clip":
        net_disc_a = vision_aided_loss.Discriminator(cv_type='clip', loss_type=args.gan_loss_type, device="cuda")
        net_disc_a.cv_ensemble.requires_grad_(False)  # Freeze feature extractor
        net_disc_b = vision_aided_loss.Discriminator(cv_type='clip', loss_type=args.gan_loss_type, device="cuda")
        net_disc_b.cv_ensemble.requires_grad_(False)  # Freeze feature extractor

    crit_cycle, crit_idt = torch.nn.L1Loss(), torch.nn.L1Loss()

    if args.enable_xformers_memory_efficient_attention:
        unet.enable_xformers_memory_efficient_attention()

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    unet.conv_in.requires_grad_(True)
    vae_b2a = copy.deepcopy(vae_a2b)
    params_gen = CycleGAN_Turbo.get_traininable_params(unet, vae_a2b, vae_b2a)

    vae_enc = VAE_encode(vae_a2b, vae_b2a=vae_b2a)
    vae_dec = VAE_decode(vae_a2b, vae_b2a=vae_b2a)

    optimizer_gen = torch.optim.AdamW(params_gen, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2),
                                      weight_decay=args.adam_weight_decay, eps=args.adam_epsilon,)

    params_disc = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
    optimizer_disc = torch.optim.AdamW(params_disc, lr=args.learning_rate, betas=(args.adam_beta1, args.adam_beta2),
                                       weight_decay=args.adam_weight_decay, eps=args.adam_epsilon,)

    dataset_train = UnpairedDataset(dataset_folder=args.dataset_folder, image_prep=args.train_img_prep, split="train", tokenizer=tokenizer)
    train_dataloader = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
    T_val = build_transform(args.val_img_prep)
    fixed_caption_src = dataset_train.fixed_caption_src
    fixed_caption_tgt = dataset_train.fixed_caption_tgt
    l_images_src_test = []
    for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp"]:
        l_images_src_test.extend(glob(os.path.join(args.dataset_folder, "test_A", ext)))
    l_images_tgt_test = []
    for ext in ["*.jpg", "*.jpeg", "*.png", "*.bmp"]:
        l_images_tgt_test.extend(glob(os.path.join(args.dataset_folder, "test_B", ext)))
    l_images_src_test, l_images_tgt_test = sorted(l_images_src_test), sorted(l_images_tgt_test)

    # make the reference FID statistics
    if accelerator.is_main_process:
        feat_model = build_feature_extractor("clean", "cuda", use_dataparallel=False)
        """
        FID reference statistics for A -> B translation
        """
        output_dir_ref = os.path.join(args.output_dir, "fid_reference_a2b")
        os.makedirs(output_dir_ref, exist_ok=True)
        # transform all images according to the validation transform and save them
        for _path in tqdm(l_images_tgt_test):
            _img = T_val(Image.open(_path).convert("RGB"))
            outf = os.path.join(output_dir_ref, os.path.basename(_path)).replace(".jpg", ".png")
            if not os.path.exists(outf):
                _img.save(outf)
        # compute the features for the reference images
        ref_features = get_folder_features(output_dir_ref, model=feat_model, num_workers=0, num=None,
                                           shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
                                           mode="clean", custom_fn_resize=None, description="", verbose=True,
                                           custom_image_tranform=None)
        a2b_ref_mu, a2b_ref_sigma = np.mean(ref_features, axis=0), np.cov(ref_features, rowvar=False)
        """
        FID reference statistics for B -> A translation
        """
        # transform all images according to the validation transform and save them
        output_dir_ref = os.path.join(args.output_dir, "fid_reference_b2a")
        os.makedirs(output_dir_ref, exist_ok=True)
        for _path in tqdm(l_images_src_test):
            _img = T_val(Image.open(_path).convert("RGB"))
            outf = os.path.join(output_dir_ref, os.path.basename(_path)).replace(".jpg", ".png")
            if not os.path.exists(outf):
                _img.save(outf)
        # compute the features for the reference images
        ref_features = get_folder_features(output_dir_ref, model=feat_model, num_workers=0, num=None,
                                           shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
                                           mode="clean", custom_fn_resize=None, description="", verbose=True,
                                           custom_image_tranform=None)
        b2a_ref_mu, b2a_ref_sigma = np.mean(ref_features, axis=0), np.cov(ref_features, rowvar=False)

    lr_scheduler_gen = get_scheduler(args.lr_scheduler, optimizer=optimizer_gen,
                                     num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
                                     num_training_steps=args.max_train_steps * accelerator.num_processes,
                                     num_cycles=args.lr_num_cycles, power=args.lr_power)
    lr_scheduler_disc = get_scheduler(args.lr_scheduler, optimizer=optimizer_disc,
                                      num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
                                      num_training_steps=args.max_train_steps * accelerator.num_processes,
                                      num_cycles=args.lr_num_cycles, power=args.lr_power)

    net_lpips = lpips.LPIPS(net='vgg')
    net_lpips.cuda()
    net_lpips.requires_grad_(False)

    fixed_a2b_tokens = tokenizer(fixed_caption_tgt, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
    fixed_a2b_emb_base = text_encoder(fixed_a2b_tokens.cuda().unsqueeze(0))[0].detach()
    fixed_b2a_tokens = tokenizer(fixed_caption_src, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt").input_ids[0]
    fixed_b2a_emb_base = text_encoder(fixed_b2a_tokens.cuda().unsqueeze(0))[0].detach()
    del text_encoder, tokenizer  # free up some memory

    unet, vae_enc, vae_dec, net_disc_a, net_disc_b = accelerator.prepare(unet, vae_enc, vae_dec, net_disc_a, net_disc_b)
    net_lpips, optimizer_gen, optimizer_disc, train_dataloader, lr_scheduler_gen, lr_scheduler_disc = accelerator.prepare(
        net_lpips, optimizer_gen, optimizer_disc, train_dataloader, lr_scheduler_gen, lr_scheduler_disc
    )
    if accelerator.is_main_process:
        accelerator.init_trackers(args.tracker_project_name, config=dict(vars(args)))

    first_epoch = 0
    global_step = 0
    progress_bar = tqdm(range(0, args.max_train_steps), initial=global_step, desc="Steps",
                        disable=not accelerator.is_local_main_process,)
    # turn off eff. attn for the disc
    for name, module in net_disc_a.named_modules():
        if "attn" in name:
            module.fused_attn = False
    for name, module in net_disc_b.named_modules():
        if "attn" in name:
            module.fused_attn = False

    for epoch in range(first_epoch, args.max_train_epochs):
        for step, batch in enumerate(train_dataloader):
            l_acc = [unet, net_disc_a, net_disc_b, vae_enc, vae_dec]
            with accelerator.accumulate(*l_acc):
                img_a = batch["pixel_values_src"].to(dtype=weight_dtype)
                img_b = batch["pixel_values_tgt"].to(dtype=weight_dtype)

                bsz = img_a.shape[0]
                fixed_a2b_emb = fixed_a2b_emb_base.repeat(bsz, 1, 1).to(dtype=weight_dtype)
                fixed_b2a_emb = fixed_b2a_emb_base.repeat(bsz, 1, 1).to(dtype=weight_dtype)
                timesteps = torch.tensor([noise_scheduler_1step.config.num_train_timesteps - 1] * bsz, device=img_a.device).long()

                """
                Cycle Objective
                """
                # A -> fake B -> rec A
                cyc_fake_b = CycleGAN_Turbo.forward_with_networks(img_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
                cyc_rec_a = CycleGAN_Turbo.forward_with_networks(cyc_fake_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
                loss_cycle_a = crit_cycle(cyc_rec_a, img_a) * args.lambda_cycle
                loss_cycle_a += net_lpips(cyc_rec_a, img_a).mean() * args.lambda_cycle_lpips
                # B -> fake A -> rec B
                cyc_fake_a = CycleGAN_Turbo.forward_with_networks(img_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
                cyc_rec_b = CycleGAN_Turbo.forward_with_networks(cyc_fake_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
                loss_cycle_b = crit_cycle(cyc_rec_b, img_b) * args.lambda_cycle
                loss_cycle_b += net_lpips(cyc_rec_b, img_b).mean() * args.lambda_cycle_lpips
                accelerator.backward(loss_cycle_a + loss_cycle_b, retain_graph=False)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)

                optimizer_gen.step()
                lr_scheduler_gen.step()
                optimizer_gen.zero_grad()

                """
                Generator Objective (GAN) for task a->b and b->a (fake inputs)
                """
                fake_a = CycleGAN_Turbo.forward_with_networks(img_b, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
                fake_b = CycleGAN_Turbo.forward_with_networks(img_a, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
                loss_gan_a = net_disc_a(fake_b, for_G=True).mean() * args.lambda_gan
                loss_gan_b = net_disc_b(fake_a, for_G=True).mean() * args.lambda_gan
                accelerator.backward(loss_gan_a + loss_gan_b, retain_graph=False)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
                optimizer_gen.step()
                lr_scheduler_gen.step()
                optimizer_gen.zero_grad()

                """
                Identity Objective
                """
                idt_a = CycleGAN_Turbo.forward_with_networks(img_b, "a2b", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_a2b_emb)
                loss_idt_a = crit_idt(idt_a, img_b) * args.lambda_idt
                loss_idt_a += net_lpips(idt_a, img_b).mean() * args.lambda_idt_lpips
                idt_b = CycleGAN_Turbo.forward_with_networks(img_a, "b2a", vae_enc, unet, vae_dec, noise_scheduler_1step, timesteps, fixed_b2a_emb)
                loss_idt_b = crit_idt(idt_b, img_a) * args.lambda_idt
                loss_idt_b += net_lpips(idt_b, img_a).mean() * args.lambda_idt_lpips
                loss_g_idt = loss_idt_a + loss_idt_b
                accelerator.backward(loss_g_idt, retain_graph=False)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_gen, args.max_grad_norm)
                optimizer_gen.step()
                lr_scheduler_gen.step()
                optimizer_gen.zero_grad()

                """
                Discriminator for task a->b and b->a (fake inputs)
                """
                loss_D_A_fake = net_disc_a(fake_b.detach(), for_real=False).mean() * args.lambda_gan
                loss_D_B_fake = net_disc_b(fake_a.detach(), for_real=False).mean() * args.lambda_gan
                loss_D_fake = (loss_D_A_fake + loss_D_B_fake) * 0.5
                accelerator.backward(loss_D_fake, retain_graph=False)
                if accelerator.sync_gradients:
                    params_to_clip = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer_disc.step()
                lr_scheduler_disc.step()
                optimizer_disc.zero_grad()

                """
                Discriminator for task a->b and b->a (real inputs)
                """
                loss_D_A_real = net_disc_a(img_b, for_real=True).mean() * args.lambda_gan
                loss_D_B_real = net_disc_b(img_a, for_real=True).mean() * args.lambda_gan
                loss_D_real = (loss_D_A_real + loss_D_B_real) * 0.5
                accelerator.backward(loss_D_real, retain_graph=False)
                if accelerator.sync_gradients:
                    params_to_clip = list(net_disc_a.parameters()) + list(net_disc_b.parameters())
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer_disc.step()
                lr_scheduler_disc.step()
                optimizer_disc.zero_grad()

            logs = {}
            logs["cycle_a"] = loss_cycle_a.detach().item()
            logs["cycle_b"] = loss_cycle_b.detach().item()
            logs["gan_a"] = loss_gan_a.detach().item()
            logs["gan_b"] = loss_gan_b.detach().item()
            logs["disc_a"] = loss_D_A_fake.detach().item() + loss_D_A_real.detach().item()
            logs["disc_b"] = loss_D_B_fake.detach().item() + loss_D_B_real.detach().item()
            logs["idt_a"] = loss_idt_a.detach().item()
            logs["idt_b"] = loss_idt_b.detach().item()

            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    eval_unet = accelerator.unwrap_model(unet)
                    eval_vae_enc = accelerator.unwrap_model(vae_enc)
                    eval_vae_dec = accelerator.unwrap_model(vae_dec)
                    if global_step % args.viz_freq == 1:
                        for tracker in accelerator.trackers:
                            if tracker.name == "wandb":
                                viz_img_a = batch["pixel_values_src"].to(dtype=weight_dtype)
                                viz_img_b = batch["pixel_values_tgt"].to(dtype=weight_dtype)
                                log_dict = {
                                    "train/real_a": [wandb.Image(viz_img_a[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)],
                                    "train/real_b": [wandb.Image(viz_img_b[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)],
                                }
                                log_dict["train/rec_a"] = [wandb.Image(cyc_rec_a[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)]
                                log_dict["train/rec_b"] = [wandb.Image(cyc_rec_b[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)]
                                log_dict["train/fake_b"] = [wandb.Image(fake_b[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)]
                                log_dict["train/fake_a"] = [wandb.Image(fake_a[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(bsz)]
                                tracker.log(log_dict)
                        gc.collect()
                        torch.cuda.empty_cache()

                    if global_step % args.checkpointing_steps == 1:
                        outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
                        sd = {}
                        sd["l_target_modules_encoder"] = l_modules_unet_encoder
                        sd["l_target_modules_decoder"] = l_modules_unet_decoder
                        sd["l_modules_others"] = l_modules_unet_others
                        sd["rank_unet"] = args.lora_rank_unet
                        sd["sd_encoder"] = get_peft_model_state_dict(eval_unet, adapter_name="default_encoder")
                        sd["sd_decoder"] = get_peft_model_state_dict(eval_unet, adapter_name="default_decoder")
                        sd["sd_other"] = get_peft_model_state_dict(eval_unet, adapter_name="default_others")
                        sd["rank_vae"] = args.lora_rank_vae
                        sd["vae_lora_target_modules"] = vae_lora_target_modules
                        sd["sd_vae_enc"] = eval_vae_enc.state_dict()
                        sd["sd_vae_dec"] = eval_vae_dec.state_dict()
                        torch.save(sd, outf)
                        gc.collect()
                        torch.cuda.empty_cache()

                    # compute val FID and DINO-Struct scores
                    if global_step % args.validation_steps == 1:
                        _timesteps = torch.tensor([noise_scheduler_1step.config.num_train_timesteps - 1] * 1, device="cuda").long()
                        net_dino = DinoStructureLoss()
                        """
                        Evaluate "A->B"
                        """
                        fid_output_dir = os.path.join(args.output_dir, f"fid-{global_step}/samples_a2b")
                        os.makedirs(fid_output_dir, exist_ok=True)
                        l_dino_scores_a2b = []
                        # get val input images from domain a
                        for idx, input_img_path in enumerate(tqdm(l_images_src_test)):
                            if idx > args.validation_num_images and args.validation_num_images > 0:
                                break
                            outf = os.path.join(fid_output_dir, f"{idx}.png")
                            with torch.no_grad():
                                input_img = T_val(Image.open(input_img_path).convert("RGB"))
                                img_a = transforms.ToTensor()(input_img)
                                img_a = transforms.Normalize([0.5], [0.5])(img_a).unsqueeze(0).cuda()
                                eval_fake_b = CycleGAN_Turbo.forward_with_networks(img_a, "a2b", eval_vae_enc, eval_unet,
                                                                                   eval_vae_dec, noise_scheduler_1step, _timesteps, fixed_a2b_emb[0:1])
                                eval_fake_b_pil = transforms.ToPILImage()(eval_fake_b[0] * 0.5 + 0.5)
                                eval_fake_b_pil.save(outf)
                                a = net_dino.preprocess(input_img).unsqueeze(0).cuda()
                                b = net_dino.preprocess(eval_fake_b_pil).unsqueeze(0).cuda()
                                dino_ssim = net_dino.calculate_global_ssim_loss(a, b).item()
                                l_dino_scores_a2b.append(dino_ssim)
                        dino_score_a2b = np.mean(l_dino_scores_a2b)
                        gen_features = get_folder_features(fid_output_dir, model=feat_model, num_workers=0, num=None,
                                                           shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
                                                           mode="clean", custom_fn_resize=None, description="", verbose=True,
                                                           custom_image_tranform=None)
                        ed_mu, ed_sigma = np.mean(gen_features, axis=0), np.cov(gen_features, rowvar=False)
                        score_fid_a2b = frechet_distance(a2b_ref_mu, a2b_ref_sigma, ed_mu, ed_sigma)
                        print(f"step={global_step}, fid(a2b)={score_fid_a2b:.2f}, dino(a2b)={dino_score_a2b:.3f}")

                        """
                        compute FID for "B->A"
                        """
                        fid_output_dir = os.path.join(args.output_dir, f"fid-{global_step}/samples_b2a")
                        os.makedirs(fid_output_dir, exist_ok=True)
                        l_dino_scores_b2a = []
                        # get val input images from domain b
                        for idx, input_img_path in enumerate(tqdm(l_images_tgt_test)):
                            if idx > args.validation_num_images and args.validation_num_images > 0:
                                break
                            outf = os.path.join(fid_output_dir, f"{idx}.png")
                            with torch.no_grad():
                                input_img = T_val(Image.open(input_img_path).convert("RGB"))
                                img_b = transforms.ToTensor()(input_img)
                                img_b = transforms.Normalize([0.5], [0.5])(img_b).unsqueeze(0).cuda()
                                eval_fake_a = CycleGAN_Turbo.forward_with_networks(img_b, "b2a", eval_vae_enc, eval_unet,
                                                                                   eval_vae_dec, noise_scheduler_1step, _timesteps, fixed_b2a_emb[0:1])
                                eval_fake_a_pil = transforms.ToPILImage()(eval_fake_a[0] * 0.5 + 0.5)
                                eval_fake_a_pil.save(outf)
                                a = net_dino.preprocess(input_img).unsqueeze(0).cuda()
                                b = net_dino.preprocess(eval_fake_a_pil).unsqueeze(0).cuda()
                                dino_ssim = net_dino.calculate_global_ssim_loss(a, b).item()
                                l_dino_scores_b2a.append(dino_ssim)
                        dino_score_b2a = np.mean(l_dino_scores_b2a)
                        gen_features = get_folder_features(fid_output_dir, model=feat_model, num_workers=0, num=None,
                                                           shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
                                                           mode="clean", custom_fn_resize=None, description="", verbose=True,
                                                           custom_image_tranform=None)
                        ed_mu, ed_sigma = np.mean(gen_features, axis=0), np.cov(gen_features, rowvar=False)
                        score_fid_b2a = frechet_distance(b2a_ref_mu, b2a_ref_sigma, ed_mu, ed_sigma)
                        print(f"step={global_step}, fid(b2a)={score_fid_b2a}, dino(b2a)={dino_score_b2a:.3f}")
                        logs["val/fid_a2b"], logs["val/fid_b2a"] = score_fid_a2b, score_fid_b2a
                        logs["val/dino_struct_a2b"], logs["val/dino_struct_b2a"] = dino_score_a2b, dino_score_b2a
                        del net_dino  # free up memory

                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                if global_step >= args.max_train_steps:
                    break


if __name__ == "__main__":
    args = parse_args_unpaired_training()
    main(args)
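A small sketch (not part of the uploaded files) of inspecting a checkpoint written by the checkpointing block above; the path is a hypothetical placeholder.

# Sketch only: look at what the CycleGAN-Turbo training loop stores per checkpoint.
import torch

sd = torch.load("output/checkpoints/model_1001.pkl", map_location="cpu")
print(sorted(sd.keys()))  # LoRA ranks, target-module lists, UNet adapter and VAE encoder/decoder state dicts
print("unet rank:", sd["rank_unet"], "vae rank:", sd["rank_vae"])
print("encoder-adapter tensors:", len(sd["sd_encoder"]))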
src/train_pix2pix_turbo.py
ADDED
@@ -0,0 +1,307 @@
import os
import gc
import lpips
import clip
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.utils import set_seed
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

import diffusers
from diffusers.utils.import_utils import is_xformers_available
from diffusers.optimization import get_scheduler

import wandb
from cleanfid.fid import get_folder_features, build_feature_extractor, fid_from_feats

from pix2pix_turbo import Pix2Pix_Turbo
from my_utils.training_utils import parse_args_paired_training, PairedDataset


def main(args):
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
    )

    if accelerator.is_local_main_process:
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    if args.seed is not None:
        set_seed(args.seed)

    if accelerator.is_main_process:
        os.makedirs(os.path.join(args.output_dir, "checkpoints"), exist_ok=True)
        os.makedirs(os.path.join(args.output_dir, "eval"), exist_ok=True)

    if args.pretrained_model_name_or_path == "stabilityai/sd-turbo":
        net_pix2pix = Pix2Pix_Turbo(lora_rank_unet=args.lora_rank_unet, lora_rank_vae=args.lora_rank_vae)
        net_pix2pix.set_train()

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            net_pix2pix.unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available, please install it by running `pip install xformers`")

    if args.gradient_checkpointing:
        net_pix2pix.unet.enable_gradient_checkpointing()

    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    if args.gan_disc_type == "vagan_clip":
        import vision_aided_loss
        net_disc = vision_aided_loss.Discriminator(cv_type='clip', loss_type=args.gan_loss_type, device="cuda")
    else:
        raise NotImplementedError(f"Discriminator type {args.gan_disc_type} not implemented")

    net_disc = net_disc.cuda()
    net_disc.requires_grad_(True)
    net_disc.cv_ensemble.requires_grad_(False)
    net_disc.train()

    net_lpips = lpips.LPIPS(net='vgg').cuda()
    net_clip, _ = clip.load("ViT-B/32", device="cuda")
    net_clip.requires_grad_(False)
    net_clip.eval()

    net_lpips.requires_grad_(False)

    # make the optimizer
    layers_to_opt = []
    for n, _p in net_pix2pix.unet.named_parameters():
        if "lora" in n:
            assert _p.requires_grad
            layers_to_opt.append(_p)
    layers_to_opt += list(net_pix2pix.unet.conv_in.parameters())
    for n, _p in net_pix2pix.vae.named_parameters():
        if "lora" in n and "vae_skip" in n:
            assert _p.requires_grad
            layers_to_opt.append(_p)
    layers_to_opt = layers_to_opt + list(net_pix2pix.vae.decoder.skip_conv_1.parameters()) + \
        list(net_pix2pix.vae.decoder.skip_conv_2.parameters()) + \
        list(net_pix2pix.vae.decoder.skip_conv_3.parameters()) + \
        list(net_pix2pix.vae.decoder.skip_conv_4.parameters())

    optimizer = torch.optim.AdamW(layers_to_opt, lr=args.learning_rate,
                                  betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
                                  eps=args.adam_epsilon,)
    lr_scheduler = get_scheduler(args.lr_scheduler, optimizer=optimizer,
                                 num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
                                 num_training_steps=args.max_train_steps * accelerator.num_processes,
                                 num_cycles=args.lr_num_cycles, power=args.lr_power,)

    optimizer_disc = torch.optim.AdamW(net_disc.parameters(), lr=args.learning_rate,
                                       betas=(args.adam_beta1, args.adam_beta2), weight_decay=args.adam_weight_decay,
                                       eps=args.adam_epsilon,)
    lr_scheduler_disc = get_scheduler(args.lr_scheduler, optimizer=optimizer_disc,
                                      num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
                                      num_training_steps=args.max_train_steps * accelerator.num_processes,
                                      num_cycles=args.lr_num_cycles, power=args.lr_power)

    dataset_train = PairedDataset(dataset_folder=args.dataset_folder, image_prep=args.train_image_prep, split="train", tokenizer=net_pix2pix.tokenizer)
    dl_train = torch.utils.data.DataLoader(dataset_train, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers)
    dataset_val = PairedDataset(dataset_folder=args.dataset_folder, image_prep=args.test_image_prep, split="test", tokenizer=net_pix2pix.tokenizer)
    dl_val = torch.utils.data.DataLoader(dataset_val, batch_size=1, shuffle=False, num_workers=0)

    # Prepare everything with our `accelerator`.
    net_pix2pix, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc = accelerator.prepare(
        net_pix2pix, net_disc, optimizer, optimizer_disc, dl_train, lr_scheduler, lr_scheduler_disc
    )
    net_clip, net_lpips = accelerator.prepare(net_clip, net_lpips)
    # renorm with image net statistics
    t_clip_renorm = transforms.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move all networks to device and cast to weight_dtype
    net_pix2pix.to(accelerator.device, dtype=weight_dtype)
    net_disc.to(accelerator.device, dtype=weight_dtype)
    net_lpips.to(accelerator.device, dtype=weight_dtype)
    net_clip.to(accelerator.device, dtype=weight_dtype)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        accelerator.init_trackers(args.tracker_project_name, config=tracker_config)

    progress_bar = tqdm(range(0, args.max_train_steps), initial=0, desc="Steps",
                        disable=not accelerator.is_local_main_process,)

    # turn off eff. attn for the discriminator
    for name, module in net_disc.named_modules():
        if "attn" in name:
            module.fused_attn = False

    # compute the reference stats for FID tracking
    if accelerator.is_main_process and args.track_val_fid:
        feat_model = build_feature_extractor("clean", "cuda", use_dataparallel=False)

        def fn_transform(x):
            x_pil = Image.fromarray(x)
            out_pil = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.LANCZOS)(x_pil)
            return np.array(out_pil)

        ref_stats = get_folder_features(os.path.join(args.dataset_folder, "test_B"), model=feat_model, num_workers=0, num=None,
                                        shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
                                        mode="clean", custom_image_tranform=fn_transform, description="", verbose=True)

    # start the training loop
    global_step = 0
    for epoch in range(0, args.num_training_epochs):
        for step, batch in enumerate(dl_train):
            l_acc = [net_pix2pix, net_disc]
            with accelerator.accumulate(*l_acc):
                x_src = batch["conditioning_pixel_values"]
                x_tgt = batch["output_pixel_values"]
                B, C, H, W = x_src.shape
                # forward pass
                x_tgt_pred = net_pix2pix(x_src, prompt_tokens=batch["input_ids"], deterministic=True)
                # Reconstruction loss
                loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean") * args.lambda_l2
                loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean() * args.lambda_lpips
                loss = loss_l2 + loss_lpips
                # CLIP similarity loss
                if args.lambda_clipsim > 0:
                    x_tgt_pred_renorm = t_clip_renorm(x_tgt_pred * 0.5 + 0.5)
                    x_tgt_pred_renorm = F.interpolate(x_tgt_pred_renorm, (224, 224), mode="bilinear", align_corners=False)
                    caption_tokens = clip.tokenize(batch["caption"], truncate=True).to(x_tgt_pred.device)
                    clipsim, _ = net_clip(x_tgt_pred_renorm, caption_tokens)
                    loss_clipsim = (1 - clipsim.mean() / 100)
                    loss += loss_clipsim * args.lambda_clipsim
                accelerator.backward(loss, retain_graph=False)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

                """
                Generator loss: fool the discriminator
                """
                x_tgt_pred = net_pix2pix(x_src, prompt_tokens=batch["input_ids"], deterministic=True)
                lossG = net_disc(x_tgt_pred, for_G=True).mean() * args.lambda_gan
                accelerator.backward(lossG)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(layers_to_opt, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

                """
                Discriminator loss: fake image vs real image
                """
                # real image
                lossD_real = net_disc(x_tgt.detach(), for_real=True).mean() * args.lambda_gan
                accelerator.backward(lossD_real.mean())
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm)
                optimizer_disc.step()
                lr_scheduler_disc.step()
                optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)
                # fake image
                lossD_fake = net_disc(x_tgt_pred.detach(), for_real=False).mean() * args.lambda_gan
                accelerator.backward(lossD_fake.mean())
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(net_disc.parameters(), args.max_grad_norm)
                optimizer_disc.step()
                optimizer_disc.zero_grad(set_to_none=args.set_grads_to_none)
                lossD = lossD_real + lossD_fake

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    logs = {}
                    # log all the losses
                    logs["lossG"] = lossG.detach().item()
                    logs["lossD"] = lossD.detach().item()
                    logs["loss_l2"] = loss_l2.detach().item()
                    logs["loss_lpips"] = loss_lpips.detach().item()
                    if args.lambda_clipsim > 0:
                        logs["loss_clipsim"] = loss_clipsim.detach().item()
                    progress_bar.set_postfix(**logs)

                    # viz some images
                    if global_step % args.viz_freq == 1:
                        log_dict = {
                            "train/source": [wandb.Image(x_src[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)],
                            "train/target": [wandb.Image(x_tgt[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)],
                            "train/model_output": [wandb.Image(x_tgt_pred[idx].float().detach().cpu(), caption=f"idx={idx}") for idx in range(B)],
                        }
                        for k in log_dict:
                            logs[k] = log_dict[k]

                    # checkpoint the model
                    if global_step % args.checkpointing_steps == 1:
                        outf = os.path.join(args.output_dir, "checkpoints", f"model_{global_step}.pkl")
                        accelerator.unwrap_model(net_pix2pix).save_model(outf)

                    # compute validation set FID, L2, LPIPS, CLIP-SIM
                    if global_step % args.eval_freq == 1:
                        l_l2, l_lpips, l_clipsim = [], [], []
                        if args.track_val_fid:
                            os.makedirs(os.path.join(args.output_dir, "eval", f"fid_{global_step}"), exist_ok=True)
                        for step, batch_val in enumerate(dl_val):
                            if step >= args.num_samples_eval:
                                break
                            x_src = batch_val["conditioning_pixel_values"].cuda()
                            x_tgt = batch_val["output_pixel_values"].cuda()
                            B, C, H, W = x_src.shape
                            assert B == 1, "Use batch size 1 for eval."
                            with torch.no_grad():
                                # forward pass
                                x_tgt_pred = accelerator.unwrap_model(net_pix2pix)(x_src, prompt_tokens=batch_val["input_ids"].cuda(), deterministic=True)
                                # compute the reconstruction losses
                                loss_l2 = F.mse_loss(x_tgt_pred.float(), x_tgt.float(), reduction="mean")
                                loss_lpips = net_lpips(x_tgt_pred.float(), x_tgt.float()).mean()
                                # compute clip similarity loss
                                x_tgt_pred_renorm = t_clip_renorm(x_tgt_pred * 0.5 + 0.5)
                                x_tgt_pred_renorm = F.interpolate(x_tgt_pred_renorm, (224, 224), mode="bilinear", align_corners=False)
                                caption_tokens = clip.tokenize(batch_val["caption"], truncate=True).to(x_tgt_pred.device)
                                clipsim, _ = net_clip(x_tgt_pred_renorm, caption_tokens)
                                clipsim = clipsim.mean()

                                l_l2.append(loss_l2.item())
                                l_lpips.append(loss_lpips.item())
                                l_clipsim.append(clipsim.item())
                            # save output images to file for FID evaluation
                            if args.track_val_fid:
                                output_pil = transforms.ToPILImage()(x_tgt_pred[0].cpu() * 0.5 + 0.5)
                                outf = os.path.join(args.output_dir, "eval", f"fid_{global_step}", f"val_{step}.png")
                                output_pil.save(outf)
                        if args.track_val_fid:
                            curr_stats = get_folder_features(os.path.join(args.output_dir, "eval", f"fid_{global_step}"), model=feat_model, num_workers=0, num=None,
                                                             shuffle=False, seed=0, batch_size=8, device=torch.device("cuda"),
                                                             mode="clean", custom_image_tranform=fn_transform, description="", verbose=True)
                            fid_score = fid_from_feats(ref_stats, curr_stats)
                            logs["val/clean_fid"] = fid_score
                        logs["val/l2"] = np.mean(l_l2)
                        logs["val/lpips"] = np.mean(l_lpips)
                        logs["val/clipsim"] = np.mean(l_clipsim)
                        gc.collect()
                        torch.cuda.empty_cache()
                    accelerator.log(logs, step=global_step)


if __name__ == "__main__":
    args = parse_args_paired_training()
    main(args)
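A closing sketch (not part of the uploaded files): reloading a checkpoint written by save_model() through the pretrained_path branch of Pix2Pix_Turbo and running it on one image. The paths and prompt are hypothetical placeholders, and the input preprocessing should match whatever the paired training data used.

# Sketch only: reload a trained checkpoint and run one deterministic forward pass.
import torch
from PIL import Image
from torchvision import transforms
from pix2pix_turbo import Pix2Pix_Turbo

model = Pix2Pix_Turbo(pretrained_path="output/checkpoints/model_1001.pkl")  # hypothetical path
model.set_eval()
x_src = transforms.ToTensor()(Image.open("test_input.png").convert("RGB")).unsqueeze(0).cuda()
with torch.no_grad():
    pred = model(x_src, prompt="a photo", deterministic=True)
transforms.ToPILImage()(pred[0].cpu() * 0.5 + 0.5).save("test_output.png")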